diff options
Diffstat (limited to 'rag_new/rag.go')
-rw-r--r-- | rag_new/rag.go | 260 |
1 files changed, 260 insertions, 0 deletions
diff --git a/rag_new/rag.go b/rag_new/rag.go
new file mode 100644
index 0000000..d012087
--- /dev/null
+++ b/rag_new/rag.go
@@ -0,0 +1,260 @@
+package rag_new
+
+import (
+	"gf-lt/config"
+	"gf-lt/models"
+	"gf-lt/storage"
+	"fmt"
+	"log/slog"
+	"os"
+	"path"
+	"strings"
+	"sync"
+
+	"github.com/neurosnap/sentences/english"
+)
+
+var (
+	// Status messages for TUI integration
+	LongJobStatusCh     = make(chan string, 10) // Increased buffer size to prevent blocking
+	FinishedRAGStatus   = "finished loading RAG file; press Enter"
+	LoadedFileRAGStatus = "loaded file"
+	ErrRAGStatus        = "some error occurred; failed to transfer data to vector db"
+)
+
+type RAG struct {
+	logger   *slog.Logger
+	store    storage.FullRepo
+	cfg      *config.Config
+	embedder Embedder
+	storage  *VectorStorage
+}
+
+func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
+	// Initialize with API embedder by default, could be configurable later
+	embedder := NewAPIEmbedder(l, cfg)
+
+	rag := &RAG{
+		logger:   l,
+		store:    s,
+		cfg:      cfg,
+		embedder: embedder,
+		storage:  NewVectorStorage(l, s),
+	}
+
+	// Create the necessary tables
+	if err := rag.storage.CreateTables(); err != nil {
+		l.Error("failed to create vector tables", "error", err)
+	}
+
+	return rag
+}
+
+func wordCounter(sentence string) int {
+	return len(strings.Split(strings.TrimSpace(sentence), " "))
+}
+
+func (r *RAG) LoadRAG(fpath string) error {
+	data, err := os.ReadFile(fpath)
+	if err != nil {
+		return err
+	}
+	r.logger.Debug("rag: loaded file", "fp", fpath)
+	LongJobStatusCh <- LoadedFileRAGStatus
+
+	fileText := string(data)
+	tokenizer, err := english.NewSentenceTokenizer(nil)
+	if err != nil {
+		return err
+	}
+	sentences := tokenizer.Tokenize(fileText)
+	sents := make([]string, len(sentences))
+	for i, s := range sentences {
+		sents[i] = s.Text
+	}
+
+	// Group sentences into paragraphs based on word limit
+	paragraphs := []string{}
+	par := strings.Builder{}
+	for i := 0; i < len(sents); i++ {
+		// Only add sentences that aren't empty
+		if strings.TrimSpace(sents[i]) != "" {
+			if par.Len() > 0 {
+				par.WriteString(" ") // Add space between sentences
+			}
+			par.WriteString(sents[i])
+		}
+
+		if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
+			paragraph := strings.TrimSpace(par.String())
+			if paragraph != "" {
+				paragraphs = append(paragraphs, paragraph)
+			}
+			par.Reset()
+		}
+	}
+
+	// Handle any remaining content in the paragraph buffer
+	if par.Len() > 0 {
+		paragraph := strings.TrimSpace(par.String())
+		if paragraph != "" {
+			paragraphs = append(paragraphs, paragraph)
+		}
+	}
+
+	// Adjust batch size if needed
+	if len(paragraphs) < int(r.cfg.RAGBatchSize) && len(paragraphs) > 0 {
+		r.cfg.RAGBatchSize = len(paragraphs)
+	}
+
+	if len(paragraphs) == 0 {
+		return fmt.Errorf("no valid paragraphs found in file")
+	}
+
+	var (
+		maxChSize = 100
+		left      = 0
+		right     = r.cfg.RAGBatchSize
+		batchCh   = make(chan map[int][]string, maxChSize)
+		vectorCh  = make(chan []models.VectorRow, maxChSize)
+		errCh     = make(chan error, 1)
+		doneCh    = make(chan bool, 1)
+		lock      = new(sync.Mutex)
+	)
+
+	defer close(doneCh)
+	defer close(errCh)
+	defer close(batchCh)
+
+	// Fill input channel with batches
+	ctn := 0
+	totalParagraphs := len(paragraphs)
+	for {
+		if int(right) > totalParagraphs {
+			batchCh <- map[int][]string{left: paragraphs[left:]}
+			break
+		}
+		batchCh <- map[int][]string{left: paragraphs[left:right]}
+		left, right = right, right+r.cfg.RAGBatchSize
+		ctn++
+	}
+
+	finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents))
+	r.logger.Debug(finishedBatchesMsg)
+	LongJobStatusCh <- finishedBatchesMsg
+
+	// Start worker goroutines
+	for w := 0; w < int(r.cfg.RAGWorkers); w++ {
+		go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
+	}
+
+	// Wait for embedding to be done
+	<-doneCh
+
+	// Write vectors to storage
+	return r.writeVectors(vectorCh)
+}
+
+func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
+	for {
+		for batch := range vectorCh {
+			for _, vector := range batch {
+				if err := r.storage.WriteVector(&vector); err != nil {
+					r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug)
+					LongJobStatusCh <- ErrRAGStatus
+					continue // a duplicate is not critical
+				}
+			}
+			r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
+			if len(vectorCh) == 0 {
+				r.logger.Debug("finished writing vectors")
+				LongJobStatusCh <- FinishedRAGStatus
+				return nil
+			}
+		}
+	}
+}
+
+func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
+	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
+	defer func() {
+		if len(doneCh) == 0 {
+			doneCh <- true
+		}
+	}()
+
+	for {
+		lock.Lock()
+		if len(inputCh) == 0 {
+			lock.Unlock()
+			return
+		}
+
+		select {
+		case linesMap := <-inputCh:
+			for leftI, lines := range linesMap {
+				if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
+					r.logger.Error("error fetching embeddings", "error", err, "worker", id)
+					lock.Unlock()
+					return
+				}
+			}
+			lock.Unlock()
+		case err := <-errCh:
+			r.logger.Error("got an error from error channel", "error", err)
+			lock.Unlock()
+			return
+		default:
+			lock.Unlock()
+		}
+
+		r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
+		LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
+	}
+}
+
+func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
+	embeddings, err := r.embedder.Embed(lines)
+	if err != nil {
+		r.logger.Error("failed to embed lines", "err", err.Error())
+		errCh <- err
+		return err
+	}
+
+	if len(embeddings) == 0 {
+		err := fmt.Errorf("no embeddings returned")
+		r.logger.Error("empty embeddings")
+		errCh <- err
+		return err
+	}
+
+	vectors := make([]models.VectorRow, len(embeddings))
+	for i, emb := range embeddings {
+		vector := models.VectorRow{
+			Embeddings: emb,
+			RawText:    lines[i],
+			Slug:       fmt.Sprintf("%s_%d", slug, i),
+			FileName:   filename,
+		}
+		vectors[i] = vector
+	}
+
+	vectorCh <- vectors
+	return nil
+}
+
+func (r *RAG) LineToVector(line string) ([]float32, error) {
+	return r.embedder.EmbedSingle(line)
+}
+
+func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
+	return r.storage.SearchClosest(emb.Embedding)
+}
+
+func (r *RAG) ListLoaded() ([]string, error) {
+	return r.storage.ListFiles()
+}
+
+func (r *RAG) RemoveFile(filename string) error {
+	return r.storage.RemoveEmbByFileName(filename)
+}
\ No newline at end of file
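
Usage sketch (not part of the commit): the snippet below shows how this RAG type might be wired up from a caller. The config loader (config.LoadConfigOrDefault), the storage constructor (storage.NewProviderSQL), and the models.EmbeddingResp.Embedding field type are assumptions for illustration; substitute the project's real constructors.

// illustrative only; assumed constructors are marked below
package main

import (
	"log/slog"
	"os"

	"gf-lt/config"
	"gf-lt/models"
	"gf-lt/rag_new"
	"gf-lt/storage"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	cfg := config.LoadConfigOrDefault(logger, "config.toml") // hypothetical loader; use the project's real one
	store := storage.NewProviderSQL("gf-lt.db", logger)      // hypothetical constructor; use the project's real one

	r := rag_new.New(logger, store, cfg)

	// Load and embed a file in the background; progress messages arrive on LongJobStatusCh.
	go func() {
		if err := r.LoadRAG("notes.txt"); err != nil {
			logger.Error("load rag", "error", err)
		}
	}()
	for msg := range rag_new.LongJobStatusCh {
		logger.Info("rag status", "msg", msg)
		if msg == rag_new.FinishedRAGStatus {
			break
		}
	}

	// Query: embed a single line, then search for the closest stored vectors.
	emb, err := r.LineToVector("what does the document say about batching?")
	if err != nil {
		logger.Error("embed", "error", err)
		return
	}
	// Assumes EmbeddingResp.Embedding is a []float32, matching LineToVector's return type.
	closest, err := r.SearchEmb(&models.EmbeddingResp{Embedding: emb})
	if err != nil {
		logger.Error("search", "error", err)
		return
	}
	for _, v := range closest {
		logger.Info("match", "slug", v.Slug, "text", v.RawText)
	}
}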