diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | bot.go | 46 | ||||
-rw-r--r-- | config.example.toml | 1 | ||||
-rw-r--r-- | config/config.go | 7 | ||||
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | main.go | 2 | ||||
-rw-r--r-- | pngmeta/metareader.go | 3 | ||||
-rw-r--r-- | rag.go | 184 | ||||
-rw-r--r-- | storage/migrations/002_add_vector.up.sql | 7 | ||||
-rw-r--r-- | storage/vector.go | 64 | ||||
-rw-r--r-- | tui.go | 103 |
13 files changed, 404 insertions, 18 deletions
@@ -6,3 +6,4 @@ history/ *.db config.toml sysprompts/* +history_bak/ @@ -32,6 +32,7 @@ - it is a bit clumsy to mix chats in db and chars from the external files, maybe load external files in db on startup? - lets say we have two (or more) agents with the same name across multiple chats. These agents go and ask db for topics they memorised. Now they can access topics that aren't meant for them. (so memory should have an option: shareable; that indicates if that memory can be shared across chats); - delete chat option; +- server mode: no tui but api calls with the func calling, rag, other middleware; ### FIX: - bot responding (or hanging) blocks everything; + @@ -15,6 +15,7 @@ import ( "strings" "time" + "github.com/neurosnap/sentences/english" "github.com/rivo/tview" ) @@ -40,6 +41,16 @@ func formMsg(chatBody *models.ChatBody, newMsg, role string) io.Reader { if newMsg != "" { // otherwise let the bot continue newMsg := models.RoleMsg{Role: role, Content: newMsg} chatBody.Messages = append(chatBody.Messages, newMsg) + // if rag + if cfg.RAGEnabled { + ragResp, err := chatRagUse(newMsg.Content) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil + } + ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + } } data, err := json.Marshal(chatBody) if err != nil { @@ -107,6 +118,40 @@ func sendMsgToLLM(body io.Reader) { } } +func chatRagUse(qText string) (string, error) { + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return "", err + } + // TODO: this where llm should find the questions in text and ask them + questionsS := tokenizer.Tokenize(qText) + questions := make([]string, len(questionsS)) + for i, q := range questionsS { + questions[i] = q.Text + } + respVecs := []*models.VectorRow{} + for i, q := range questions { + emb, err := lineToVector(q) + if err != nil { + logger.Error("failed to get embs", "error", err, "index", i, "question", q) + continue + } + vec, err := searchEmb(emb) + if err != nil { + logger.Error("failed to get embs", "error", err, "index", i, "question", q) + continue + } + respVecs = append(respVecs, vec) + // logger.Info("returned vector from query search", "question", q, "vec", vec) + } + // get raw text + resps := []string{} + for _, rv := range respVecs { + resps = append(resps, rv.RawText) + } + return strings.Join(resps, "\n"), nil +} + func chatRound(userMsg, role string, tv *tview.TextView, regen bool) { botRespMode = true reader := formMsg(chatBody, userMsg, role) @@ -294,4 +339,5 @@ func init() { Stream: true, Messages: lastChat, } + // tempLoad() } diff --git a/config.example.toml b/config.example.toml index c52f267..d0a9841 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,4 +1,5 @@ APIURL = "http://localhost:8080/v1/chat/completions" +EmbedURL = "http://localhost:8080/v1/embeddings" ShowSys = true LogFile = "log.txt" UserRole = "user" diff --git a/config/config.go b/config/config.go index ce1b877..3c79564 100644 --- a/config/config.go +++ b/config/config.go @@ -8,7 +8,6 @@ import ( type Config struct { APIURL string `toml:"APIURL"` - EmbedURL string `toml:"EmbedURL"` ShowSys bool `toml:"ShowSys"` LogFile string `toml:"LogFile"` UserRole string `toml:"UserRole"` @@ -19,6 +18,11 @@ type Config struct { ToolIcon string `toml:"ToolIcon"` SysDir string `toml:"SysDir"` ChunkLimit uint32 `toml:"ChunkLimit"` + // embeddings + RAGEnabled bool `toml:"RAGEnabled"` + EmbedURL string `toml:"EmbedURL"` + HFToken string `toml:"HFToken"` + RAGDir string `toml:"RAGDir"` } func LoadConfigOrDefault(fn string) *Config { @@ -30,6 +34,7 @@ func LoadConfigOrDefault(fn string) *Config { if err != nil { fmt.Println("failed to read config from file, loading default") config.APIURL = "http://localhost:8080/v1/chat/completions" + config.RAGEnabled = false config.EmbedURL = "http://localhost:8080/v1/embiddings" config.ShowSys = true config.LogFile = "log.txt" @@ -9,6 +9,7 @@ require ( github.com/glebarez/go-sqlite v1.22.0 github.com/jmoiron/sqlx v1.4.0 github.com/ncruces/go-sqlite3 v0.21.3 + github.com/neurosnap/sentences v1.1.2 github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 ) @@ -34,6 +34,8 @@ github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1 github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= +github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw= +github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 h1:YIJ+B1hePP6AgynC5TcqpO0H9k3SSoZa2BGyL6vDUzM= @@ -10,7 +10,7 @@ var ( botRespMode = false editMode = false selectedIndex = int(-1) - indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s" + indexLine = "F12 to show keys help; bot resp mode: %v; char: %s; chat: %s; RAGEnabled: %v" focusSwitcher = map[tview.Primitive]tview.Primitive{} ) diff --git a/pngmeta/metareader.go b/pngmeta/metareader.go index 5542e86..44b8ca4 100644 --- a/pngmeta/metareader.go +++ b/pngmeta/metareader.go @@ -103,6 +103,9 @@ func ReadDirCards(dirname, uname string) ([]*models.CharCard, error) { } resp := []*models.CharCard{} for _, f := range files { + if f.IsDir() { + continue + } if strings.HasSuffix(f.Name(), ".png") { fpath := path.Join(dirname, f.Name()) cc, err := ReadCard(fpath, uname) @@ -2,27 +2,209 @@ package main import ( "bytes" + "context" "elefant/models" "encoding/json" + "errors" + "fmt" + "net/http" + "os" + + "github.com/neurosnap/sentences/english" ) +func loadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + for i, s := range sentences { + sents[i] = s.Text + } + var ( + // TODO: to config + workers = 5 + batchSize = 200 + // + left = 0 + right = batchSize + batchCh = make(chan map[int][]string) + vectorCh = make(chan []models.VectorRow) + errCh = make(chan error) + ) + if len(sents) < batchSize { + batchSize = len(sents) + } + // fill input channel + for { + if right > len(sents) { + batchCh <- map[int][]string{left: sents[left:]} + break + } + batchCh <- map[int][]string{left: sents[left:right]} + left, right = right, right+batchSize + } + // TODO: cancel complains, replace ctx with done chan + ctx, cancel := context.WithCancel(context.Background()) + for w := 0; w < workers; w++ { + go batchToVectorHFAsync(ctx, cancel, len(sents), batchCh, vectorCh, errCh) + } + // write to db + return writeVectors(vectorCh) +} + +func writeVectors(vectorCh <-chan []models.VectorRow) error { + for batch := range vectorCh { + for _, vector := range batch { + if err := store.WriteVector(&vector); err != nil { + return err + } + } + } + return nil +} + +func batchToVectorHFAsync(ctx context.Context, close context.CancelFunc, limit int, + inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error) { + for { + select { + case linesMap := <-inputCh: + for leftI, v := range linesMap { + FecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI)) + if leftI+200 >= limit { // last batch + close() + return + } + } + case <-ctx.Done(): + logger.Error("got ctx done") + return + case err := <-errCh: + logger.Error("got an error", "error", err) + close() + return + } + } +} + +func FecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + logger.Error("failed to marshal payload", "err:", err.Error()) + errCh <- err + return + } + req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) + resp, err := httpClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + errCh <- err + return + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + errCh <- err + return + } + if len(emb) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + errCh <- err + return + } + vectors := make([]models.VectorRow, len(emb)) + for i, e := range emb { + vector := models.VectorRow{ + Embeddings: e, + RawText: lines[i], + Slug: slug, + } + vectors[i] = vector + } + vectorCh <- vectors +} + +func batchToVectorHF(lines []string) ([][]float32, error) { + payload, err := json.Marshal( + map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, + ) + if err != nil { + logger.Error("failed to marshal payload", "err:", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", cfg.EmbedURL, bytes.NewReader(payload)) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.HFToken)) + resp, err := httpClient.Do(req) + // nolint + // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + return nil, err + } + emb := [][]float32{} + if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + logger.Error("failed to embedd line", "err:", err.Error()) + return nil, err + } + if len(emb) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } + return emb, nil +} + func lineToVector(line string) (*models.EmbeddingResp, error) { payload, err := json.Marshal(map[string]string{"content": line}) if err != nil { logger.Error("failed to marshal payload", "err:", err.Error()) return nil, err } + // nolint resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { logger.Error("failed to embedd line", "err:", err.Error()) return nil, err } defer resp.Body.Close() + if resp.StatusCode != 200 { + logger.Error("non 200 resp", "code", resp.StatusCode) + return nil, err + } emb := models.EmbeddingResp{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { logger.Error("failed to embedd line", "err:", err.Error()) return nil, err } + if len(emb.Embedding) == 0 { + logger.Error("empty emb") + err = errors.New("empty emb") + return nil, err + } return &emb, nil } @@ -36,5 +218,5 @@ func saveLine(topic, line string, emb *models.EmbeddingResp) error { } func searchEmb(emb *models.EmbeddingResp) (*models.VectorRow, error) { - return store.SearchClosest([5120]float32(emb.Embedding)) + return store.SearchClosest(emb.Embedding) } diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql index 4fcc9aa..f64aecb 100644 --- a/storage/migrations/002_add_vector.up.sql +++ b/storage/migrations/002_add_vector.up.sql @@ -4,3 +4,10 @@ CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0( slug TEXT NOT NULL, raw_text TEXT NOT NULL ); + +CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embedding FLOAT[384], + slug TEXT NOT NULL, + raw_text TEXT NOT NULL +); diff --git a/storage/vector.go b/storage/vector.go index bc46734..23a72e9 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -2,6 +2,7 @@ package storage import ( "elefant/models" + "errors" "fmt" "log" "unsafe" @@ -11,29 +12,61 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q [5120]float32) (*models.VectorRow, error) + SearchClosest(q []float32) (*models.VectorRow, error) } -var vecTableName = "embeddings" +var ( + vecTableName = "embeddings" + vecTableName384 = "embeddings_384" +) + +func fetchTableName(emb []float32) (string, error) { + switch len(emb) { + case 5120: + return vecTableName, nil + case 384: + return vecTableName384, nil + default: + return "", fmt.Errorf("no table for the size of %d", len(emb)) + } +} func (p ProviderSQL) WriteVector(row *models.VectorRow) error { + tableName, err := fetchTableName(row.Embeddings) + if err != nil { + return err + } stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", vecTableName)) - defer stmt.Close() + fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", tableName)) if err != nil { p.logger.Error("failed to prep a stmt", "error", err) return err } + defer stmt.Close() v, err := sqlite_vec.SerializeFloat32(row.Embeddings) if err != nil { p.logger.Error("failed to serialize vector", "emb-len", len(row.Embeddings), "error", err) return err } - stmt.BindInt(1, int(row.ID)) - stmt.BindBlob(2, v) - stmt.BindText(3, row.Slug) - stmt.BindText(4, row.RawText) + if v == nil { + err = errors.New("empty vector after serialization") + p.logger.Error("empty vector after serialization", + "emb-len", len(row.Embeddings), "text", row.RawText, "error", err) + return err + } + if err := stmt.BindBlob(1, v); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(2, row.Slug); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(3, row.RawText); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } err = stmt.Exec() if err != nil { p.logger.Error("failed exec a stmt", "error", err) @@ -46,19 +79,19 @@ func decodeUnsafe(bs []byte) []float32 { return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) } -func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) { - stmt, _, err := p.s3Conn.Prepare(` - SELECT +func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { + stmt, _, err := p.s3Conn.Prepare( + fmt.Sprintf(`SELECT id, distance, embedding, slug, raw_text - FROM vec_items + FROM %s WHERE embedding MATCH ? ORDER BY distance LIMIT 4 - `) + `, vecTableName)) if err != nil { log.Fatal(err) } @@ -66,7 +99,10 @@ func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) { if err != nil { log.Fatal(err) } - stmt.BindBlob(1, query) + if err := stmt.BindBlob(1, query); err != nil { + p.logger.Error("failed to bind", "error", err) + return nil, err + } resp := make([]models.VectorRow, 4) i := 0 for stmt.Step() { @@ -4,6 +4,7 @@ import ( "elefant/models" "elefant/pngmeta" "fmt" + "os" "strconv" "strings" "time" @@ -32,6 +33,7 @@ var ( indexPage = "indexPage" helpPage = "helpPage" renamePage = "renamePage" + RAGPage = "RAGPage " // help text helpText = ` [yellow]Esc[white]: send msg @@ -130,6 +132,79 @@ func makeChatTable(chatList []string) *tview.Table { return chatActTable } +func makeRAGTable(fileList []string) *tview.Table { + actions := []string{"load", "rename", "delete"} + rows, cols := len(fileList), len(actions)+1 + chatActTable := tview.NewTable(). + SetBorders(true) + for r := 0; r < rows; r++ { + for c := 0; c < cols; c++ { + color := tcell.ColorWhite + if c < 1 { + chatActTable.SetCell(r, c, + tview.NewTableCell(fileList[r]). + SetTextColor(color). + SetAlign(tview.AlignCenter)) + } else { + chatActTable.SetCell(r, c, + tview.NewTableCell(actions[c-1]). + SetTextColor(color). + SetAlign(tview.AlignCenter)) + } + } + } + chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 { + pages.RemovePage(RAGPage) + return + } + if key == tcell.KeyEnter { + chatActTable.SetSelectable(true, true) + } + }).SetSelectedFunc(func(row int, column int) { + tc := chatActTable.GetCell(row, column) + tc.SetTextColor(tcell.ColorRed) + chatActTable.SetSelectable(false, false) + fpath := fileList[row] + // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text) + switch tc.Text { + case "load": + if err := loadRAG(fpath); err != nil { + logger.Error("failed to read history file", "chat", fpath) + pages.RemovePage(RAGPage) + return + } + pages.RemovePage(RAGPage) + colorText() + updateStatusLine() + return + case "rename": + pages.RemovePage(RAGPage) + pages.AddPage(renamePage, renameWindow, true, true) + return + case "delete": + sc, ok := chatMap[fpath] + if !ok { + // no chat found + pages.RemovePage(RAGPage) + return + } + if err := store.RemoveChat(sc.ID); err != nil { + logger.Error("failed to remove chat from db", "chat_id", sc.ID, "chat_name", sc.Name) + } + if err := notifyUser("chat deleted", fpath+" was deleted"); err != nil { + logger.Error("failed to send notification", "error", err) + } + pages.RemovePage(RAGPage) + return + default: + pages.RemovePage(RAGPage) + return + } + }) + return chatActTable +} + // // code block colors get interrupted by " & * // func codeBlockColor(text string) string { // fi := strings.Index(text, "```") @@ -153,7 +228,7 @@ func colorText() { } func updateStatusLine() { - position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName)) + position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled)) } func initSysCards() ([]string, error) { @@ -379,6 +454,7 @@ func init() { textView.SetText(chatToText(cfg.ShowSys)) colorText() textView.ScrollToEnd() + // init sysmap _, err := initSysCards() if err != nil { logger.Error("failed to init sys cards", "error", err) @@ -456,6 +532,12 @@ func init() { pages.AddPage(indexPage, indexPickWindow, true, true) return nil } + if event.Key() == tcell.KeyF11 { + // xor + cfg.RAGEnabled = cfg.RAGEnabled != true + updateStatusLine() + return nil + } if event.Key() == tcell.KeyF12 { // help window cheatsheet pages.AddPage(helpPage, helpView, true, true) @@ -496,6 +578,25 @@ func init() { updateStatusLine() return nil } + if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" { + // rag load + // menu of the text files from defined rag directory + files, err := os.ReadDir(cfg.RAGDir) + if err != nil { + logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err) + return nil + } + fileList := []string{} + for _, f := range files { + if f.IsDir() { + continue + } + fileList = append(fileList, f.Name()) + } + chatRAGTable := makeRAGTable(fileList) + pages.AddPage(RAGPage, chatRAGTable, true, true) + return nil + } // cannot send msg in editMode or botRespMode if event.Key() == tcell.KeyEscape && !editMode && !botRespMode { position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName)) |