author | Grail Finder <wohilas@gmail.com> | 2025-01-09 15:49:59 +0300
committer | Grail Finder <wohilas@gmail.com> | 2025-01-09 15:49:59 +0300
commit | 363bbae2c756f448d8cdac50305902d68d45c26c (patch)
tree | 4206136d7342006c32e7dbaf2891bce608969427 /rag
parent | 7bbedd93cf078fc7496a6779cf9eda6e588e64c0 (diff)
Fix: RAG updates
Diffstat (limited to 'rag')
-rw-r--r-- | rag/main.go | 70
1 files changed, 47 insertions, 23 deletions
```diff
diff --git a/rag/main.go b/rag/main.go
index da919b4..58ff448 100644
--- a/rag/main.go
+++ b/rag/main.go
@@ -11,6 +11,8 @@ import (
 	"log/slog"
 	"net/http"
 	"os"
+	"path"
+	"strings"
 
 	"github.com/neurosnap/sentences/english"
 )
@@ -29,6 +31,10 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
 	}
 }
 
+func wordCounter(sentence string) int {
+	return len(strings.Split(sentence, " "))
+}
+
 func (r *RAG) LoadRAG(fpath string) error {
 	data, err := os.ReadFile(fpath)
 	if err != nil {
@@ -53,6 +59,9 @@ func (r *RAG) LoadRAG(fpath string) error {
 		batchSize = 200
 		maxChSize = 1000
 		//
+		// psize = 3
+		wordLimit = 80
+		//
 		left = 0
 		right = batchSize
 		batchCh = make(chan map[int][]string, maxChSize)
@@ -60,23 +69,40 @@ func (r *RAG) LoadRAG(fpath string) error {
 		errCh = make(chan error, 1)
 		doneCh = make(chan bool, 1)
 	)
-	if len(sents) < batchSize {
-		batchSize = len(sents)
+	// group sentences
+	paragraphs := []string{}
+	par := strings.Builder{}
+	for i := 0; i < len(sents); i++ {
+		par.WriteString(sents[i])
+		if wordCounter(par.String()) > wordLimit {
+			paragraphs = append(paragraphs, par.String())
+			par.Reset()
+		}
+	}
+	// for i := 0; i < len(sents); i += psize {
+	// 	if len(sents) < i+psize {
+	// 		paragraphs = append(paragraphs, strings.Join(sents[i:], " "))
+	// 		break
+	// 	}
+	// 	paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " "))
+	// }
+	if len(paragraphs) < batchSize {
+		batchSize = len(paragraphs)
 	}
 	// fill input channel
 	ctn := 0
 	for {
-		if right > len(sents) {
-			batchCh <- map[int][]string{left: sents[left:]}
+		if right > len(paragraphs) {
+			batchCh <- map[int][]string{left: paragraphs[left:]}
 			break
 		}
-		batchCh <- map[int][]string{left: sents[left:right]}
+		batchCh <- map[int][]string{left: paragraphs[left:right]}
 		left, right = right, right+batchSize
 		ctn++
 	}
 	r.logger.Info("finished batching", "batches#", len(batchCh))
 	for w := 0; w < workers; w++ {
-		go r.batchToVectorHFAsync(len(sents), batchCh, vectorCh, errCh, doneCh)
+		go r.batchToVectorHFAsync(len(paragraphs), batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
 	}
 	// write to db
 	return r.writeVectors(vectorCh, doneCh)
@@ -102,20 +128,17 @@ func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan boo
 }
 
 func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string,
-	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool) {
+	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
 	r.logger.Info("to vector batches", "batches#", len(inputCh))
 	for {
 		select {
 		case linesMap := <-inputCh:
-			// r.logger.Info("batch from ch")
 			for leftI, v := range linesMap {
-				// r.logger.Info("fetching", "index", leftI)
-				r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI))
+				r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
 				if leftI+200 >= limit { // last batch
 					doneCh <- true
 					return
 				}
-				// r.logger.Info("done feitching", "index", leftI)
 			}
 		case <-doneCh:
 			r.logger.Info("got done")
@@ -129,7 +152,7 @@ func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string,
 	}
 }
 
-func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) {
+func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) {
 	payload, err := json.Marshal(
 		map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
 	)
@@ -138,13 +161,14 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
 		errCh <- err
 		return
 	}
+	// nolint
 	req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload))
 	if err != nil {
 		r.logger.Error("failed to create new req", "err:", err.Error())
 		errCh <- err
 		return
 	}
-	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken))
+	req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
 	resp, err := http.DefaultClient.Do(req)
 	// nolint
 	// resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
@@ -179,6 +203,7 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
 			Embeddings: e,
 			RawText:    lines[i],
 			Slug:       slug,
+			FileName:   filename,
 		}
 		vectors[i] = vector
 	}
@@ -201,7 +226,7 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
 		r.logger.Error("failed to create new req", "err:", err.Error())
 		return nil, err
 	}
-	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken))
+	req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
 	resp, err := http.DefaultClient.Do(req)
 	// resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload))
 	if err != nil {
@@ -228,15 +253,14 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
 	return emb[0], nil
 }
 
-// func (r *RAG) saveLine(topic, line string, emb *models.EmbeddingResp) error {
-// 	row := &models.VectorRow{
-// 		Embeddings: emb.Embedding,
-// 		Slug:       topic,
-// 		RawText:    line,
-// 	}
-// 	return r.store.WriteVector(row)
-// }
-
 func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
 	return r.store.SearchClosest(emb.Embedding)
 }
+
+func (r *RAG) ListLoaded() ([]string, error) {
+	return r.store.ListFiles()
+}
+
+func (r *RAG) RemoveFile(filename string) error {
+	return r.store.RemoveEmbByFileName(filename)
+}
```
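The main behavioural change here is that LoadRAG no longer embeds raw sentences one by one: sentences are accumulated into paragraphs of roughly wordLimit (80) words before batching, and every stored vector now carries the source file's base name so embeddings can later be listed (ListLoaded) and removed per file (RemoveFile). The sketch below is a hypothetical standalone program that isolates the grouping step; groupSentences, the demo sentences in main, and the final flush of a short leftover buffer are illustrative additions, not code from the repository.

```go
package main

import (
	"fmt"
	"strings"
)

// wordCounter mirrors the helper added in this commit: a rough word count
// obtained by splitting the accumulated text on single spaces.
func wordCounter(sentence string) int {
	return len(strings.Split(sentence, " "))
}

// groupSentences is a standalone version of the grouping loop in LoadRAG:
// sentences are appended to a builder until the accumulated text exceeds
// wordLimit words, then the buffer is emitted as one paragraph.
// Flushing the short leftover buffer at the end is an assumption made for
// this demo; the diff itself only flushes once the limit is exceeded.
func groupSentences(sents []string, wordLimit int) []string {
	paragraphs := []string{}
	par := strings.Builder{}
	for i := 0; i < len(sents); i++ {
		par.WriteString(sents[i])
		if wordCounter(par.String()) > wordLimit {
			paragraphs = append(paragraphs, par.String())
			par.Reset()
		}
	}
	if par.Len() > 0 {
		paragraphs = append(paragraphs, par.String())
	}
	return paragraphs
}

func main() {
	// Example sentences as a sentence tokenizer might return them.
	sents := []string{
		"RAG loading starts by splitting the file into sentences. ",
		"Those sentences are grouped into paragraphs of roughly eighty words. ",
		"Each paragraph is embedded and stored together with the file name. ",
	}
	// A small limit is used here so the grouping is visible with three sentences.
	for i, p := range groupSentences(sents, 10) {
		fmt.Printf("paragraph %d (%d words): %q\n", i, wordCounter(p), p)
	}
}
```

Grouping by accumulated word count, rather than by a fixed number of sentences (the commented-out psize approach left in the diff), keeps chunk sizes roughly uniform regardless of how long individual sentences are.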