diff options
author | Grail Finder <wohilas@gmail.com> | 2025-01-04 18:13:13 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2025-01-04 18:13:13 +0300 |
commit | 4736e43631ed21fd14741daa1dde746687d330fa (patch) | |
tree | c5dcee2930e8681397a369ae7869671bdcf568dc /rag.go | |
parent | 461d19aa2512fea7ac07e50c3178609850ef07c3 (diff) |
Feat (RAG): tying tui calls to rag funcs [WIP; skip-ci]
RAG itself is annoying to properly implement, plucking sentences with no
context is useless. Also it should not be a part of main package, same
for goes for tui. The number of global vars is absurd.
Diffstat (limited to 'rag.go')
-rw-r--r-- | rag.go | 184 |
1 files changed, 183 insertions, 1 deletions
@@ -2,27 +2,209 @@ package main import ( "bytes" + "context" "elefant/models" "encoding/json" + "errors" + "fmt" + "net/http" + "os" + + "github.com/neurosnap/sentences/english" ) +func loadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + for i, s := range sentences { + sents[i] = s.Text + } + var ( + // TODO: to config + workers = 5 + batchSize = 200 + // + left = 0 + right = batchSize + batchCh = make(chan map[int][]string) + vectorCh = make(chan []models.VectorRow) + errCh = make(chan error) + ) + if len(sents) < batchSize { + batchSize = len(sents) + } + // fill input channel + for { + if right > len(sents) { + batchCh <- map[int][]string{left: sents[left:]} + break + } + batchCh <- map[int][]string{left: sents[left:right]} + left, right = right, right+batchSize + } + // TODO: cancel complains, replace ctx with done chan + ctx, cancel := context.WithCancel(context.Background()) + for w := 0; w < workers; w++ { + go batchToVectorHFAsync(ctx, cancel, len(sents), batchCh, vectorCh, errCh) + } + // write to db + return writeVectors(vectorCh) +} + +func writeVectors(vectorCh <-chan []models.VectorRow) error { + for batch := range vectorCh { + for _, vector := range batch { + if err := store.WriteVector(&vector); err != nil { + return err + } + } + } + return nil +} + +func batchToVectorHFAsync(ctx context.Context, close context.CancelFunc, limit int, + inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error) { + for { + select { + case linesMap := <-inputCh: + for leftI, v := range linesMap { + FecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) + if leftI+200 >= limit { // last batch + close() + return + } + } + case <-ctx.Done(): + logger.Error("got ctx done") + return + case err := <-errCh: + logger.Error("got an error", "error", err) + close() + return + } + } +} + +func FecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + logger.Error("failed to marshal payload", "err:", err.Error()) + errCh <- err + return + } + req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) + resp, err := httpClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + errCh <- err + return + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + if len(emb) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + errCh <- err + return + } + vectors := make([]models.VectorRow, len(emb)) + for i, e := range emb { + vector := models.VectorRow{ + Embeddings: e, + RawText: lines[i], + Slug: slug, + } + vectors[i] = vector + } + vectorCh <- vectors +} + +func batchToVectorHF(lines []string) ([][]float32, error) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + logger.Error("failed to marshal payload", "err:", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) + resp, err := httpClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + return nil, err + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + if len(emb) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } + return emb, nil +} + func lineToVector(line string) (*models.EmbeddingResp, error) { payload, err := json.Marshal(map[string]string{"content": line}) if err != nil { logger.Error("failed to marshal payload", "err:", err.Error()) return nil, err } + // nolint resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { logger.Error("failed to embedd line", "err:", err.Error()) return nil, err } defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + return nil, err + } emb := models.EmbeddingResp{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { logger.Error("failed to embedd line", "err:", err.Error()) return nil, err } + if len(emb.Embedding) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } return &emb, nil } @@ -36,5 +218,5 @@ func saveLine(topic, line string, emb *models.EmbeddingResp) error { } func searchEmb(emb *models.EmbeddingResp) (*models.VectorRow, error) { - return store.SearchClosest([5120]float32(emb.Embedding)) + return store.SearchClosest(emb.Embedding) } |