diff options
author | Grail Finder <wohilas@gmail.com> | 2025-01-11 17:29:21 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2025-01-11 17:29:21 +0300 |
commit | 85f96aa4013f9cedaf333c6d1027fe6d901cf561 (patch) | |
tree | 8aa04f6eb5d50c6a97e84632a3ce270444e84e27 | |
parent | f40d8afe08c524fc7f9df0dfa0802342af2d2c3d (diff) |
Feat: RAG file loading status textview (branch: feat/rag)
-rw-r--r-- | rag/main.go | 52 | ||||
-rw-r--r-- | tables.go | 67 | ||||
-rw-r--r-- | tui.go | 22 |
3 files changed, 92 insertions, 49 deletions
diff --git a/rag/main.go b/rag/main.go index f1be167..d4065e5 100644 --- a/rag/main.go +++ b/rag/main.go @@ -18,6 +18,14 @@ import ( "github.com/neurosnap/sentences/english" ) +var ( + LongJobStatusCh = make(chan string, 1) + // messages + FinishedRAGStatus = "finished loading RAG file; press Enter" + LoadedFileRAGStatus = "loaded file" + ErrRAGStatus = "some error occured; failed to transfer data to vector db" +) + type RAG struct { logger *slog.Logger store storage.FullRepo @@ -42,6 +50,7 @@ func (r *RAG) LoadRAG(fpath string) error { return err } r.logger.Info("rag: loaded file", "fp", fpath) + LongJobStatusCh <- LoadedFileRAGStatus fileText := string(data) tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { @@ -49,7 +58,6 @@ func (r *RAG) LoadRAG(fpath string) error { } sentences := tokenizer.Tokenize(fileText) sents := make([]string, len(sentences)) - r.logger.Info("rag: sentences", "#", len(sents)) for i, s := range sentences { sents[i] = s.Text } @@ -60,16 +68,14 @@ func (r *RAG) LoadRAG(fpath string) error { batchSize = 100 maxChSize = 1000 // - // psize = 3 wordLimit = 80 - // - left = 0 - right = batchSize - batchCh = make(chan map[int][]string, maxChSize) - vectorCh = make(chan []models.VectorRow, maxChSize) - errCh = make(chan error, 1) - doneCh = make(chan bool, 1) - lock = new(sync.Mutex) + left = 0 + right = batchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + doneCh = make(chan bool, 1) + lock = new(sync.Mutex) ) defer close(doneCh) defer close(errCh) @@ -84,13 +90,6 @@ func (r *RAG) LoadRAG(fpath string) error { par.Reset() } } - // for i := 0; i < len(sents); i += psize { - // if len(sents) < i+psize { - // paragraphs = append(paragraphs, strings.Join(sents[i:], " ")) - // break - // } - // paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " ")) - // } if len(paragraphs) < batchSize { batchSize = len(paragraphs) } @@ -105,7 
+104,9 @@ func (r *RAG) LoadRAG(fpath string) error { left, right = right, right+batchSize ctn++ } - r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents)) + finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents)) + r.logger.Info(finishedBatchesMsg) + LongJobStatusCh <- finishedBatchesMsg for w := 0; w < workers; w++ { go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) } @@ -121,6 +122,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { for _, vector := range batch { if err := r.store.WriteVector(&vector); err != nil { r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) + LongJobStatusCh <- ErrRAGStatus continue // a duplicate is not critical // return err } @@ -128,6 +130,7 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) if len(vectorCh) == 0 { r.logger.Info("finished writing vectors") + LongJobStatusCh <- FinishedRAGStatus defer close(vectorCh) return nil } @@ -150,10 +153,6 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[ case linesMap := <-inputCh: for leftI, v := range linesMap { r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename) - // if leftI+200 >= limit { // last batch - // // doneCh <- true - // return - // } } lock.Unlock() case err := <-errCh: @@ -162,6 +161,7 @@ func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[ return } r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id) + LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) } } @@ -183,8 +183,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod } 
req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) resp, err := http.DefaultClient.Do(req) - // nolint - // resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) errCh <- err @@ -194,9 +192,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod if resp.StatusCode != 200 { r.logger.Error("non 200 resp", "code", resp.StatusCode) return - // err = fmt.Errorf("non 200 resp; url: %s; code %d", r.cfg.EmbedURL, resp.StatusCode) - // errCh <- err - // return } emb := [][]float32{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { @@ -224,7 +219,6 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod } func (r *RAG) LineToVector(line string) ([]float32, error) { - // payload, err := json.Marshal(map[string]string{"content": line}) lines := []string{line} payload, err := json.Marshal( map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, @@ -241,7 +235,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { } req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) resp, err := http.DefaultClient.Do(req) - // resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload)) if err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) return nil, err @@ -252,7 +245,6 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { r.logger.Error(err.Error()) return nil, err } - // emb := models.EmbeddingResp{} emb := [][]float32{} if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { r.logger.Error("failed to embedd line", "err:", err.Error()) @@ -1,8 +1,12 @@ package main import ( + "fmt" "os" "path" + "time" + + "elefant/rag" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -85,11 +89,21 @@ func makeChatTable(chatList []string) *tview.Table { return chatActTable } -func makeRAGTable(fileList 
[]string) *tview.Table { +// func makeRAGTable(fileList []string) *tview.Table { +func makeRAGTable(fileList []string) *tview.Flex { actions := []string{"load", "delete"} rows, cols := len(fileList), len(actions)+1 fileTable := tview.NewTable(). SetBorders(true) + longStatusView := tview.NewTextView() + longStatusView.SetText("status text") + longStatusView.SetBorder(true).SetTitle("status") + longStatusView.SetChangedFunc(func() { + app.Draw() + }) + ragflex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(longStatusView, 0, 10, false). + AddItem(fileTable, 0, 60, true) for r := 0; r < rows; r++ { for c := 0; c < cols; c++ { color := tcell.ColorWhite @@ -106,6 +120,33 @@ func makeRAGTable(fileList []string) *tview.Table { } } } + errCh := make(chan error, 1) + go func() { + defer pages.RemovePage(RAGPage) + for { + select { + case err := <-errCh: + if err == nil { + logger.Error("somehow got a nil err", "error", err) + continue + } + logger.Error("got an err in rag status", "error", err, "textview", longStatusView) + longStatusView.SetText(fmt.Sprintf("%v", err)) + close(errCh) + return + case status := <-rag.LongJobStatusCh: + logger.Info("reading status channel", "status", status) + longStatusView.SetText(status) + // fmt.Fprintln(longStatusView, status) + // app.Sync() + if status == rag.FinishedRAGStatus { + close(errCh) + time.Sleep(2 * time.Second) + return + } + } + } + }() fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { if key == tcell.KeyEsc || key == tcell.KeyF1 { pages.RemovePage(RAGPage) @@ -115,7 +156,7 @@ func makeRAGTable(fileList []string) *tview.Table { fileTable.SetSelectable(true, true) } }).SetSelectedFunc(func(row int, column int) { - defer pages.RemovePage(RAGPage) + // defer pages.RemovePage(RAGPage) tc := fileTable.GetCell(row, column) tc.SetTextColor(tcell.ColorRed) fileTable.SetSelectable(false, false) @@ -124,14 +165,18 @@ func makeRAGTable(fileList []string) *tview.Table { switch tc.Text { case "load": 
fpath = path.Join(cfg.RAGDir, fpath) - if err := ragger.LoadRAG(fpath); err != nil { - logger.Error("failed to embed file", "chat", fpath, "error", err) - // pages.RemovePage(RAGPage) - return - } - pages.RemovePage(RAGPage) - colorText() - updateStatusLine() + longStatusView.SetText("clicked load") + go func() { + if err := ragger.LoadRAG(fpath); err != nil { + logger.Error("failed to embed file", "chat", fpath, "error", err) + errCh <- err + // pages.RemovePage(RAGPage) + return + } + }() + // make new page and write status updates to it + // colorText() + // updateStatusLine() return case "delete": fpath = path.Join(cfg.RAGDir, fpath) @@ -148,7 +193,7 @@ func makeRAGTable(fileList []string) *tview.Table { return } }) - return fileTable + return ragflex } func makeLoadedRAGTable(fileList []string) *tview.Table { @@ -3,6 +3,7 @@ package main import ( "elefant/models" "elefant/pngmeta" + "elefant/rag" "fmt" "os" "strconv" @@ -26,14 +27,17 @@ var ( sysModal *tview.Modal indexPickWindow *tview.InputField renameWindow *tview.InputField + // + longJobStatusCh = make(chan string, 1) // pages - historyPage = "historyPage" - agentPage = "agentPage" - editMsgPage = "editMsgPage" - indexPage = "indexPage" - helpPage = "helpPage" - renamePage = "renamePage" - RAGPage = "RAGPage " + historyPage = "historyPage" + agentPage = "agentPage" + editMsgPage = "editMsgPage" + indexPage = "indexPage" + helpPage = "helpPage" + renamePage = "renamePage" + RAGPage = "RAGPage " + longStatusPage = "longStatusPage" // help text helpText = ` [yellow]Esc[white]: send msg @@ -155,6 +159,7 @@ func init() { position = tview.NewTextView(). SetDynamicColors(true). SetTextAlign(tview.AlignCenter) + flex = tview.NewFlex().SetDirection(tview.FlexRow). AddItem(textView, 0, 40, false). AddItem(textArea, 0, 10, true). 
@@ -466,6 +471,7 @@ func init() { } fileList = append(fileList, f.Name()) } + rag.LongJobStatusCh <- "first msg" chatRAGTable := makeRAGTable(fileList) pages.AddPage(RAGPage, chatRAGTable, true, true) return nil @@ -482,7 +488,7 @@ func init() { if strings.HasSuffix(prevText, nl) { nl = "" } - if msgText != "" { + if msgText != "" { // continue fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", nl, len(chatBody.Messages), cfg.UserRole, msgText) textArea.SetText("", true) |