diff options
Diffstat (limited to 'rag')
-rw-r--r-- | rag/main.go | 52 |
1 files changed, 22 insertions, 30 deletions
diff --git a/rag/main.go b/rag/main.go index f1be167..d4065e5 100644 --- a/rag/main.go +++ b/rag/main.go @@ -18,6 +18,14 @@ import ( "github.com/neurosnap/sentences/english" ) +var ( + LongJobStatusCh = make(chan string, 1) + // messages + FinishedRAGStatus = "finished loading RAG file; press Enter" + LoadedFileRAGStatus = "loaded file" + ErrRAGStatus = "some error occured; failed to transfer data to vector db" +) + type RAG struct { logger *slog.Logger store storage.FullRepo @@ -42,6 +50,7 @@ func (r *RAG) LoadRAG(fpath string) error { return err } r.logger.Info("rag: loaded file", "fp", fpath) + LongJobStatusCh <- LoadedFileRAGStatus fileText := string(data) tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { @@ -49,7 +58,6 @@ func (r *RAG) LoadRAG(fpath string) error { } sentences := tokenizer.Tokenize(fileText) sents := make([]string, len(sentences)) - r.logger.Info("rag: sentences", "#", len(sents)) for i, s := range sentences { sents[i] = s.Text } @@ -60,16 +68,14 @@ func (r *RAG) LoadRAG(fpath string) error { batchSize = 100 maxChSize = 1000 // - // psize = 3 wordLimit = 80 - // - left = 0 - right = batchSize - batchCh = make(chan map[int][]string, maxChSize) - vectorCh = make(chan []models.VectorRow, maxChSize) - errCh = make(chan error, 1) - doneCh = make(chan bool, 1) - lock = new(sync.Mutex) + left = 0 + right = batchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + doneCh = make(chan bool, 1) + lock = new(sync.Mutex) ) defer close(doneCh) defer close(errCh) @@ -84,13 +90,6 @@ func (r *RAG) LoadRAG(fpath string) error { par.Reset() } } - // for i := 0; i < len(sents); i += psize { - // if len(sents) < i+psize { - // paragraphs = append(paragraphs, strings.Join(sents[i:], " ")) - // break - // } - // paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " ")) - // } if len(paragraphs) < batchSize { batchSize = len(paragraphs) } @@ -105,7 +104,9 @@ func (r *RAG) LoadRAG(fpath string) error { left, right = right, right+batchSize ctn++ } - r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents)) + finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents)) + r.logger.Info(finishedBatchesMsg) + LongJobStatusCh <- finishedBatchesMsg for w := 0; w < workers; w++ { go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) } @@ -121,6 +122,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { for _, vector := range batch { if err := r.store.WriteVector(&vector); err != nil { r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) + LongJobStatusCh <- ErrRAGStatus continue // a duplicate is not critical // return err } @@ -128,6 +130,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) if len(vectorCh) == 0 { r.logger.Info("finished writing vectors") + LongJobStatusCh <- FinishedRAGStatus defer close(vectorCh) return nil } @@ -150,10 +153,6 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[ case linesMap := <-inputCh: for leftI, v := range linesMap { r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename) - // if leftI+200 >= limit { // last batch - // // doneCh <- true - // return - // } } lock.Unlock() case err := <-errCh: @@ -162,6 +161,7 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[ return } r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id) + LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) } } @@ -183,8 +183,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod } req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) resp, err := http.DefaultClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) errCh <- err @@ -194,9 +192,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod if resp.StatusCode != 200 { r.logger.Error("non 200 resp", "code", resp.StatusCode) return - // err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode) - // errCh <- err - // return } emb := [][]float32{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { @@ -224,7 +219,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod } func (r *RAG) LineToVector(line string) ([]float32, error) { - // payload, err := json.Marshal(map[string]string{"content": line}) lines := []string{line} payload, err := json.Marshal( map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, @@ -241,7 +235,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { } req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) resp, err := http.DefaultClient.Do(req) - // resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) return nil, err @@ -252,7 +245,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { r.logger.Error(err.Error()) return nil, err } - // emb := models.EmbeddingResp{} emb := [][]float32{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) |