diff options
-rw-r--r-- | bot.go | 21 | ||||
-rw-r--r-- | main.go | 2 | ||||
-rw-r--r-- | rag.go | 222 | ||||
-rw-r--r-- | rag/main.go | 240 | ||||
-rw-r--r-- | storage/vector.go | 37 | ||||
-rw-r--r-- | tui.go | 8 |
6 files changed, 282 insertions, 248 deletions
@@ -5,6 +5,7 @@ import ( "bytes" "elefant/config" "elefant/models" + "elefant/rag" "elefant/storage" "encoding/json" "fmt" @@ -33,6 +34,7 @@ var ( defaultStarter = []models.RoleMsg{} defaultStarterBytes = []byte{} interruptResp = false + ragger *rag.RAG ) // ==== @@ -129,26 +131,34 @@ func chatRagUse(qText string) (string, error) { for i, q := range questionsS { questions[i] = q.Text } - respVecs := []*models.VectorRow{} + respVecs := []models.VectorRow{} for i, q := range questions { - emb, err := lineToVector(q) + emb, err := ragger.LineToVector(q) if err != nil { logger.Error("failed to get embs", "error", err, "index", i, "question", q) continue } - vec, err := searchEmb(emb) + // e := &models.EmbeddingResp{ + // Embedding: emb, + // } + // vecs, err := ragger.SearchEmb(e) + vecs, err := store.SearchClosest(emb) if err != nil { - logger.Error("failed to get embs", "error", err, "index", i, "question", q) + logger.Error("failed to query embs", "error", err, "index", i, "question", q) continue } - respVecs = append(respVecs, vec) + respVecs = append(respVecs, vecs...) // logger.Info("returned vector from query search", "question", q, "vec", vec) } // get raw text resps := []string{} + logger.Info("sqlvec resp", "vecs", respVecs) for _, rv := range respVecs { resps = append(resps, rv.RawText) } + if len(resps) == 0 { + return "No related results from vector storage.", nil + } return strings.Join(resps, "\n"), nil } @@ -326,6 +336,7 @@ func init() { if store == nil { os.Exit(1) } + ragger = rag.New(logger, store, cfg) // https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md // load all chats in memory if _, err := loadHistoryChats(); err != nil { @@ -10,7 +10,7 @@ var ( botRespMode = false editMode = false selectedIndex = int(-1) - indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s; RAGEnabled: %v" + indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s; RAGEnabled: %v; EmbedURL: %s" focusSwitcher = map[tview.Primitive]tview.Primitive{} ) @@ -1,222 +0,0 @@ -package main - -import ( - "bytes" - "context" - "elefant/models" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - - "github.com/neurosnap/sentences/english" -) - -func loadRAG(fpath string) error { - data, err := os.ReadFile(fpath) - if err != nil { - return err - } - fileText := string(data) - tokenizer, err := english.NewSentenceTokenizer(nil) - if err != nil { - return err - } - sentences := tokenizer.Tokenize(fileText) - sents := make([]string, len(sentences)) - for i, s := range sentences { - sents[i] = s.Text - } - var ( - // TODO: to config - workers = 5 - batchSize = 200 - // - left = 0 - right = batchSize - batchCh = make(chan map[int][]string) - vectorCh = make(chan []models.VectorRow) - errCh = make(chan error) - ) - if len(sents) < batchSize { - batchSize = len(sents) - } - // fill input channel - for { - if right > len(sents) { - batchCh <- map[int][]string{left: sents[left:]} - break - } - batchCh <- map[int][]string{left: sents[left:right]} - left, right = right, right+batchSize - } - // TODO: cancel complains, replace ctx with done chan - ctx, cancel := context.WithCancel(context.Background()) - for w := 0; w < workers; w++ { - go batchToVectorHFAsync(ctx, cancel, len(sents), batchCh, vectorCh, errCh) - } - // write to db - return writeVectors(vectorCh) -} - -func writeVectors(vectorCh <-chan []models.VectorRow) error { - for batch := range vectorCh { - for _, vector := range batch { - if err := store.WriteVector(&vector); err != nil { - return err - } - } - } - return nil -} - -func batchToVectorHFAsync(ctx context.Context, close context.CancelFunc, limit int, - inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error) { - for { - select { - case linesMap := <-inputCh: - for leftI, v := range linesMap { - FecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) - if leftI+200 >= limit { // last batch - close() - return - } - } - case <-ctx.Done(): - logger.Error("got ctx done") - return - case err := <-errCh: - logger.Error("got an error", "error", err) - close() - return - } - } -} - -func FecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - errCh <- err - return - } - req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) - resp, err := httpClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - errCh <- err - return - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - if len(emb) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - errCh <- err - return - } - vectors := make([]models.VectorRow, len(emb)) - for i, e := range emb { - vector := models.VectorRow{ - Embeddings: e, - RawText: lines[i], - Slug: slug, - } - vectors[i] = vector - } - vectorCh <- vectors -} - -func batchToVectorHF(lines []string) ([][]float32, error) { - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - return nil, err - } - req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) - resp, err := httpClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - return nil, err - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - if len(emb) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - return nil, err - } - return emb, nil -} - -func lineToVector(line string) (*models.EmbeddingResp, error) { - payload, err := json.Marshal(map[string]string{"content": line}) - if err != nil { - logger.Error("failed to marshal payload", "err:", err.Error()) - return nil, err - } - // nolint - resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) - if err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - logger.Error("non 200 resp", "code", resp.StatusCode) - return nil, err - } - emb := models.EmbeddingResp{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - if len(emb.Embedding) == 0 { - logger.Error("empty emb") - err = errors.New("empty emb") - return nil, err - } - return &emb, nil -} - -func saveLine(topic, line string, emb *models.EmbeddingResp) error { - row := &models.VectorRow{ - Embeddings: emb.Embedding, - Slug: topic, - RawText: line, - } - return store.WriteVector(row) -} - -func searchEmb(emb *models.EmbeddingResp) (*models.VectorRow, error) { - return store.SearchClosest(emb.Embedding) -} diff --git a/rag/main.go b/rag/main.go new file mode 100644 index 0000000..a7084bf --- /dev/null +++ b/rag/main.go @@ -0,0 +1,240 @@ +package rag + +import ( + "bytes" + "elefant/config" + "elefant/models" + "elefant/storage" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/neurosnap/sentences/english" +) + +type RAG struct { + logger *slog.Logger + store storage.FullRepo + cfg *config.Config +} + +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { + return &RAG{ + logger: l, + store: s, + cfg: cfg, + } +} + +func (r *RAG) LoadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + r.logger.Info("rag: loaded file", "fp", fpath) + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + r.logger.Info("rag: sentences", "#", len(sents)) + for i, s := range sentences { + sents[i] = s.Text + } + // TODO: maybe better to decide batch size based on sentences len + var ( + // TODO: to config + workers = 5 + batchSize = 200 + maxChSize = 1000 + // + left = 0 + right = batchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + doneCh = make(chan bool, 1) + ) + if len(sents) < batchSize { + batchSize = len(sents) + } + // fill input channel + ctn := 0 + for { + if right > len(sents) { + batchCh <- map[int][]string{left: sents[left:]} + break + } + batchCh <- map[int][]string{left: sents[left:right]} + left, right = right, right+batchSize + ctn++ + } + r.logger.Info("finished batching", "batches#", len(batchCh)) + for w := 0; w < workers; w++ { + go r.batchToVectorHFAsync(len(sents), batchCh, vectorCh, errCh, doneCh) + } + // write to db + return r.writeVectors(vectorCh, doneCh) +} + +func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan bool) error { + for { + select { + case batch := <-vectorCh: + for _, vector := range batch { + if err := r.store.WriteVector(&vector); err != nil { + return err + } + } + r.logger.Info("wrote batch to db", "size", len(batch)) + case <-doneCh: + r.logger.Info("rag finished") + return nil + } + } +} + +func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string, + vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool) { + r.logger.Info("to vector batches", "batches#", len(inputCh)) + for { + select { + case linesMap := <-inputCh: + // r.logger.Info("batch from ch") + for leftI, v := range linesMap { + // r.logger.Info("fetching", "index", leftI) + r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) + if leftI+200 >= limit { // last batch + doneCh <- true + return + } + // r.logger.Info("done feitching", "index", leftI) + } + case <-doneCh: + r.logger.Info("got done") + close(errCh) + close(doneCh) + return + case err := <-errCh: + r.logger.Error("got an error", "error", err) + return + } + } +} + +func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + r.logger.Error("failed to marshal payload", "err:", err.Error()) + errCh <- err + return + } + req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to create new req", "err:", err.Error()) + errCh <- err + return + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken)) + resp, err := http.DefaultClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + r.logger.Error("non 200 resp", "code", resp.StatusCode) + return + // err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode) + // errCh <- err + // return + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + if len(emb) == 0 { + r.logger.Error("empty emb") + err = errors.New("empty emb") + errCh <- err + return + } + vectors := make([]models.VectorRow, len(emb)) + for i, e := range emb { + vector := models.VectorRow{ + Embeddings: e, + RawText: lines[i], + Slug: slug, + } + vectors[i] = vector + } + vectorCh <- vectors +} + +func (r *RAG) LineToVector(line string) ([]float32, error) { + // payload, err := json.Marshal(map[string]string{"content": line}) + lines := []string{line} + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + r.logger.Error("failed to marshal payload", "err:", err.Error()) + return nil, err + } + // nolint + req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to create new req", "err:", err.Error()) + return nil, err + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken)) + resp, err := http.DefaultClient.Do(req) + // resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 resp; code: %v\n", resp.StatusCode) + r.logger.Error(err.Error()) + return nil, err + } + // emb := models.EmbeddingResp{} + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + r.logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + if len(emb) == 0 || len(emb[0]) == 0 { + r.logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } + return emb[0], nil +} + +func (r *RAG) saveLine(topic, line string, emb *models.EmbeddingResp) error { + row := &models.VectorRow{ + Embeddings: emb.Embedding, + Slug: topic, + RawText: line, + } + return r.store.WriteVector(row) +} + +func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { + return r.store.SearchClosest(emb.Embedding) +} diff --git a/storage/vector.go b/storage/vector.go index 23a72e9..1579686 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -4,7 +4,6 @@ import ( "elefant/models" "errors" "fmt" - "log" "unsafe" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" @@ -12,7 +11,7 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q []float32) (*models.VectorRow, error) + SearchClosest(q []float32) ([]models.VectorRow, error) } var ( @@ -79,7 +78,11 @@ func decodeUnsafe(bs []byte) []float32 { return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) } -func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { +func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } stmt, _, err := p.s3Conn.Prepare( fmt.Sprintf(`SELECT id, @@ -91,35 +94,35 @@ func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { WHERE embedding MATCH ? ORDER BY distance LIMIT 4 - `, vecTableName)) + `, tableName)) if err != nil { - log.Fatal(err) + return nil, err } query, err := sqlite_vec.SerializeFloat32(q[:]) if err != nil { - log.Fatal(err) + return nil, err } if err := stmt.BindBlob(1, query); err != nil { p.logger.Error("failed to bind", "error", err) return nil, err } - resp := make([]models.VectorRow, 4) - i := 0 + resp := []models.VectorRow{} for stmt.Step() { - resp[i].ID = uint32(stmt.ColumnInt64(0)) - resp[i].Distance = float32(stmt.ColumnFloat(1)) + res := models.VectorRow{} + res.ID = uint32(stmt.ColumnInt64(0)) + res.Distance = float32(stmt.ColumnFloat(1)) emb := stmt.ColumnRawText(2) - resp[i].Embeddings = decodeUnsafe(emb) - resp[i].Slug = stmt.ColumnText(3) - resp[i].RawText = stmt.ColumnText(4) - i++ + res.Embeddings = decodeUnsafe(emb) + res.Slug = stmt.ColumnText(3) + res.RawText = stmt.ColumnText(4) + resp = append(resp, res) } if err := stmt.Err(); err != nil { - log.Fatal(err) + return nil, err } err = stmt.Close() if err != nil { - log.Fatal(err) + return nil, err } - return nil, nil + return resp, nil } @@ -5,6 +5,7 @@ import ( "elefant/pngmeta" "fmt" "os" + "path" "strconv" "strings" "time" @@ -169,8 +170,9 @@ func makeRAGTable(fileList []string) *tview.Table { // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text) switch tc.Text { case "load": - if err := loadRAG(fpath); err != nil { - logger.Error("failed to read history file", "chat", fpath) + fpath = path.Join(cfg.RAGDir, fpath) + if err := ragger.LoadRAG(fpath); err != nil { + logger.Error("failed to embed file", "chat", fpath, "error", err) pages.RemovePage(RAGPage) return } @@ -228,7 +230,7 @@ func colorText() { } func updateStatusLine() { - position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled)) + position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.EmbedURL)) } func initSysCards() ([]string, error) { |