diff options
Diffstat (limited to 'rag.go')
-rw-r--r-- | rag.go | 222 |
1 files changed, 0 insertions, 222 deletions
@@ -1,222 +0,0 @@ -package main - -import ( - "bytes" - "context" - "elefant/models" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - - "github.com/neurosnap/sentences/english" -) - -func loadRAG(fpath string) error { - data, err := os.ReadFile(fpath) - if err != nil { - return err - } - fileText := string(data) - tokenizer, err := english.NewSentenceTokenizer(nil) - if err != nil { - return err - } - sentences := tokenizer.Tokenize(fileText) - sents := make([]string, len(sentences)) - for i, s := range sentences { - sents[i] = s.Text - } - var ( - // TODO: to config - workers = 5 - batchSize = 200 - // - left = 0 - right = batchSize - batchCh = make(chan map[int][]string) - vectorCh = make(chan []models.VectorRow) - errCh = make(chan error) - ) - if len(sents) < batchSize { - batchSize = len(sents) - } - // fill input channel - for { - if right > len(sents) { - batchCh <- map[int][]string{left: sents[left:]} - break - } - batchCh <- map[int][]string{left: sents[left:right]} - left, right = right, right+batchSize - } - // TODO: cancel complains, replace ctx with done chan - ctx, cancel := context.WithCancel(context.Background()) - for w := 0; w < workers; w++ { - go batchToVectorHFAsync(ctx, cancel, len(sents), batchCh, vectorCh, errCh) - } - // write to db - return writeVectors(vectorCh) -} - -func writeVectors(vectorCh <-chan []models.VectorRow) error { - for batch := range vectorCh { - for _, vector := range batch { - if err := store.WriteVector(&vector); err != nil { - return err - } - } - } - return nil -} - -func batchToVectorHFAsync(ctx context.Context, close context.CancelFunc, limit int, - inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error) { - for { - select { - case linesMap := <-inputCh: - for leftI, v := range linesMap { - FecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) - if leftI+200 >= limit { // last batch - close() - return - } - } - case <-ctx.Done(): - logger.Error("got ctx done") - return - case err := <-errCh: - logger.Error("got an error", "error", err) - close() - return - } - } -} - -func FecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - errCh <- err - return - } - req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) - resp, err := httpClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - errCh <- err - return - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - if len(emb) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - errCh <- err - return - } - vectors := make([]models.VectorRow, len(emb)) - for i, e := range emb { - vector := models.VectorRow{ - Embeddings: e, - RawText: lines[i], - Slug: slug, - } - vectors[i] = vector - } - vectorCh <- vectors -} - -func batchToVectorHF(lines []string) ([][]float32, error) { - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - return nil, err - } - req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) - resp, err := httpClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - return nil, err - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - if len(emb) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - return nil, err - } - return emb, nil -} - -func lineToVector(line string) (*models.EmbeddingResp, error) { - payload, err := json.Marshal(map[string]string{"content": line}) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - return nil, err - } - // nolint - resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - return nil, err - } - emb := models.EmbeddingResp{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - if len(emb.Embedding) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - return nil, err - } - return &emb, nil -} - -func saveLine(topic, line string, emb *models.EmbeddingResp) error { - row := &models.VectorRow{ - Embeddings: emb.Embedding, - Slug: topic, - RawText: line, - } - return store.WriteVector(row) -} - -func searchEmb(emb *models.EmbeddingResp) (*models.VectorRow, error) { - return store.SearchClosest(emb.Embedding) -} |