diff options
author | Grail Finder <wohilas@gmail.com> | 2025-01-05 20:51:31 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2025-01-05 20:51:31 +0300 |
commit | b822b3a1613ef7f1c9ed8fa5aaddfaffbfc513a4 (patch) | |
tree | afa89e8eb8916e5e970cac9fb70eaddd97c12ae6 /rag | |
parent | 4736e43631ed21fd14741daa1dde746687d330fa (diff) |
Refactor: rag to sep package
Diffstat (limited to 'rag')
-rw-r--r-- | rag/main.go | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/rag/main.go b/rag/main.go new file mode 100644 index 0000000..a7084bf --- /dev/null +++ b/rag/main.go @@ -0,0 +1,240 @@ +package rag + +import ( + "bytes" + "elefant/config" + "elefant/models" + "elefant/storage" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/neurosnap/sentences/english" +) + +type RAG struct { + logger *slog.Logger + store storage.FullRepo + cfg *config.Config +} + +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { + return &RAG{ + logger: l, + store: s, + cfg: cfg, + } +} + +func (r *RAG) LoadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + r.logger.Info("rag: loaded file", "fp", fpath) + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + r.logger.Info("rag: sentences", "#", len(sents)) + for i, s := range sentences { + sents[i] = s.Text + } + // TODO: maybe better to decide batch size based on sentences len + var ( + // TODO: to config + workers = 5 + batchSize = 200 + maxChSize = 1000 + // + left = 0 + right = batchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + doneCh = make(chan bool, 1) + ) + if len(sents) < batchSize { + batchSize = len(sents) + } + // fill input channel + ctn := 0 + for { + if right > len(sents) { + batchCh <- map[int][]string{left: sents[left:]} + break + } + batchCh <- map[int][]string{left: sents[left:right]} + left, right = right, right+batchSize + ctn++ + } + r.logger.Info("finished batching", "batches#", len(batchCh)) + for w := 0; w < workers; w++ { + go r.batchToVectorHFAsync(len(sents), batchCh, vectorCh, errCh, doneCh) + } + // write to db + return r.writeVectors(vectorCh, doneCh) +} + +func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan bool) error { + for { + select { + case batch := <-vectorCh: + for _, vector := range batch { + if err := r.store.WriteVector(&vector); err != nil { + return err + } + } + r.logger.Info("wrote batch to db", "size", len(batch)) + case <-doneCh: + r.logger.Info("rag finished") + return nil + } + } +} + +func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string, + vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool) { + r.logger.Info("to vector batches", "batches#", len(inputCh)) + for { + select { + case linesMap := <-inputCh: + // r.logger.Info("batch from ch") + for leftI, v := range linesMap { + // r.logger.Info("fetching", "index", leftI) + r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) + if leftI+200 >= limit { // last batch + doneCh <- true + return + } + // r.logger.Info("done feitching", "index", leftI) + } + case <-doneCh: + r.logger.Info("got done") + close(errCh) + close(doneCh) + return + case err := <-errCh: + r.logger.Error("got an error", "error", err) + return + } + } +} + +func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + r.logger.Error("failed to marshal payload", "err:", err.Error()) + errCh <- err + return + } + req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to create new req", "err:", err.Error()) + errCh <- err + return + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken)) + resp, err := http.DefaultClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + r.logger.Error("non 200 resp", "code", resp.StatusCode) + return + // err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode) + // errCh <- err + // return + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + if len(emb) == 0 { + r.logger.Error("empty emb") + err = errors.New("empty emb") + errCh <- err + return + } + vectors := make([]models.VectorRow, len(emb)) + for i, e := range emb { + vector := models.VectorRow{ + Embeddings: e, + RawText: lines[i], + Slug: slug, + } + vectors[i] = vector + } + vectorCh <- vectors +} + +func (r *RAG) LineToVector(line string) ([]float32, error) { + // payload, err := json.Marshal(map[string]string{"content": line}) + lines := []string{line} + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + r.logger.Error("failed to marshal payload", "err:", err.Error()) + return nil, err + } + // nolint + req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to create new req", "err:", err.Error()) + return nil, err + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken)) + resp, err := http.DefaultClient.Do(req) + // resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 resp; code: %v\n", resp.StatusCode) + r.logger.Error(err.Error()) + return nil, err + } + // emb := models.EmbeddingResp{} + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + if len(emb) == 0 || len(emb[0]) == 0 { + r.logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } + return emb[0], nil +} + +func (r *RAG) saveLine(topic, line string, emb *models.EmbeddingResp) error { + row := &models.VectorRow{ + Embeddings: emb.Embedding, + Slug: topic, + RawText: line, + } + return r.store.WriteVector(row) +} + +func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { + return r.store.SearchClosest(emb.Embedding) +} |