diff options
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/embedder.go | 145 | ||||
| -rw-r--r-- | rag/rag.go | 334 | ||||
| -rw-r--r-- | rag/storage.go | 278 |
3 files changed, 757 insertions, 0 deletions
diff --git a/rag/embedder.go b/rag/embedder.go new file mode 100644 index 0000000..bed1b41 --- /dev/null +++ b/rag/embedder.go @@ -0,0 +1,145 @@ +package rag + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "gf-lt/config" + "gf-lt/models" + "log/slog" + "net/http" +) + +// Embedder defines the interface for embedding text +type Embedder interface { + Embed(text string) ([]float32, error) + EmbedSlice(lines []string) ([][]float32, error) +} + +// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.) +type APIEmbedder struct { + logger *slog.Logger + client *http.Client + cfg *config.Config +} + +func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder { + return &APIEmbedder{ + logger: l, + client: &http.Client{}, + cfg: cfg, + } +} + +func (a *APIEmbedder) Embed(text string) ([]float32, error) { + payload, err := json.Marshal( + map[string]any{"input": text, "encoding_format": "float"}, + ) + if err != nil { + a.logger.Error("failed to marshal payload", "err", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + a.logger.Error("failed to create new req", "err", err.Error()) + return nil, err + } + if a.cfg.HFToken != "" { + req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) + } + resp, err := a.client.Do(req) + if err != nil { + a.logger.Error("failed to embed text", "err", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) + a.logger.Error(err.Error()) + return nil, err + } + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + a.logger.Error("failed to decode embedding response", "err", err.Error()) + return nil, err + } + if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 { + err = errors.New("empty embedding response") + a.logger.Error("empty embedding response") + return nil, err + } + return embResp.Data[0].Embedding, nil +} + +func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { + payload, err := json.Marshal( + map[string]any{"input": lines, "encoding_format": "float"}, + ) + if err != nil { + a.logger.Error("failed to marshal payload", "err", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + a.logger.Error("failed to create new req", "err", err.Error()) + return nil, err + } + if a.cfg.HFToken != "" { + req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) + } + resp, err := a.client.Do(req) + if err != nil { + a.logger.Error("failed to embed text", "err", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) + a.logger.Error(err.Error()) + return nil, err + } + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + a.logger.Error("failed to decode embedding response", "err", err.Error()) + return nil, err + } + if len(embResp.Data) == 0 { + err = errors.New("empty embedding response") + a.logger.Error("empty embedding response") + return nil, err + } + + // Collect all embeddings from the response + embeddings := make([][]float32, len(embResp.Data)) + for i := range embResp.Data { + if len(embResp.Data[i].Embedding) == 0 { + err = fmt.Errorf("empty embedding at index %d", i) + a.logger.Error("empty embedding", "index", i) + return nil, err + } + embeddings[i] = embResp.Data[i].Embedding + } + + // Sort embeddings by index to match the order of input lines + // API responses may not be in order + for _, data := range embResp.Data { + if data.Index >= len(embeddings) || data.Index < 0 { + err = fmt.Errorf("invalid embedding index %d", data.Index) + a.logger.Error("invalid embedding index", "index", data.Index) + return nil, err + } + embeddings[data.Index] = data.Embedding + } + + return embeddings, nil +} + +// TODO: ONNXEmbedder implementation would go here +// This would require: +// 1. Loading ONNX models locally +// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) +// 3. Converting text to embeddings without external API calls +// +// For now, we'll focus on the API implementation which is already working in the current system, +// and can be extended later when we have ONNX runtime integration diff --git a/rag/rag.go b/rag/rag.go new file mode 100644 index 0000000..b29b9eb --- /dev/null +++ b/rag/rag.go @@ -0,0 +1,334 @@ +package rag + +import ( + "errors" + "fmt" + "gf-lt/config" + "gf-lt/models" + "gf-lt/storage" + "log/slog" + "os" + "path" + "strings" + "sync" + + "github.com/neurosnap/sentences/english" +) + +var ( + // Status messages for TUI integration + LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking + FinishedRAGStatus = "finished loading RAG file; press Enter" + LoadedFileRAGStatus = "loaded file" + ErrRAGStatus = "some error occurred; failed to transfer data to vector db" +) + + +type RAG struct { + logger *slog.Logger + store storage.FullRepo + cfg *config.Config + embedder Embedder + storage *VectorStorage +} + +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { + // Initialize with API embedder by default, could be configurable later + embedder := NewAPIEmbedder(l, cfg) + + rag := &RAG{ + logger: l, + store: s, + cfg: cfg, + embedder: embedder, + storage: NewVectorStorage(l, s), + } + + // Note: Vector tables are created via database migrations, not at runtime + + return rag +} + +func wordCounter(sentence string) int { + return len(strings.Split(strings.TrimSpace(sentence), " ")) +} + +func (r *RAG) LoadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + r.logger.Debug("rag: loaded file", "fp", fpath) + select { + case LongJobStatusCh <- LoadedFileRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + for i, s := range sentences { + sents[i] = s.Text + } + + // Group sentences into paragraphs based on word limit + paragraphs := []string{} + par := strings.Builder{} + for i := 0; i < len(sents); i++ { + // Only add sentences that aren't empty + if strings.TrimSpace(sents[i]) != "" { + if par.Len() > 0 { + par.WriteString(" ") // Add space between sentences + } + par.WriteString(sents[i]) + } + + if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { + paragraph := strings.TrimSpace(par.String()) + if paragraph != "" { + paragraphs = append(paragraphs, paragraph) + } + par.Reset() + } + } + + // Handle any remaining content in the paragraph buffer + if par.Len() > 0 { + paragraph := strings.TrimSpace(par.String()) + if paragraph != "" { + paragraphs = append(paragraphs, paragraph) + } + } + + // Adjust batch size if needed + if len(paragraphs) < int(r.cfg.RAGBatchSize) && len(paragraphs) > 0 { + r.cfg.RAGBatchSize = len(paragraphs) + } + + if len(paragraphs) == 0 { + return errors.New("no valid paragraphs found in file") + } + + var ( + maxChSize = 100 + left = 0 + right = r.cfg.RAGBatchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + wg = new(sync.WaitGroup) + lock = new(sync.Mutex) + ) + + defer close(errCh) + defer close(batchCh) + + // Fill input channel with batches + ctn := 0 + totalParagraphs := len(paragraphs) + for { + if int(right) > totalParagraphs { + batchCh <- map[int][]string{left: paragraphs[left:]} + break + } + batchCh <- map[int][]string{left: paragraphs[left:right]} + left, right = right, right+r.cfg.RAGBatchSize + ctn++ + } + + finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents)) + r.logger.Debug(finishedBatchesMsg) + select { + case LongJobStatusCh <- finishedBatchesMsg: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg) + // Channel is full or closed, ignore the message to prevent panic + } + + // Start worker goroutines with WaitGroup + wg.Add(int(r.cfg.RAGWorkers)) + for w := 0; w < int(r.cfg.RAGWorkers); w++ { + go func(workerID int) { + defer wg.Done() + r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath)) + }(w) + } + + // Use a goroutine to close the batchCh when all batches are sent + go func() { + wg.Wait() + close(vectorCh) // Close vectorCh when all workers are done + }() + + // Check for errors from workers + // Use a non-blocking check for errors + select { + case err := <-errCh: + if err != nil { + r.logger.Error("error during RAG processing", "error", err) + return err + } + default: + // No immediate error, continue + } + + // Write vectors to storage - this will block until vectorCh is closed + return r.writeVectors(vectorCh) +} + +func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { + for { + for batch := range vectorCh { + for _, vector := range batch { + if err := r.storage.WriteVector(&vector); err != nil { + r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) + select { + case LongJobStatusCh <- ErrRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + return err // Stop the entire RAG operation on DB error + } + } + r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) + if len(vectorCh) == 0 { + r.logger.Debug("finished writing vectors") + select { + case LongJobStatusCh <- FinishedRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + return nil + } + } + } +} + +func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, + vectorCh chan<- []models.VectorRow, errCh chan error, filename string) { + var err error + + defer func() { + // For errCh, make sure we only send if there's actually an error and the channel can accept it + if err != nil { + select { + case errCh <- err: + default: + // errCh might be full or closed, log but don't panic + r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err) + } + } + }() + + for { + lock.Lock() + if len(inputCh) == 0 { + lock.Unlock() + return + } + + select { + case linesMap := <-inputCh: + for leftI, lines := range linesMap { + if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil { + r.logger.Error("error fetching embeddings", "error", err, "worker", id) + lock.Unlock() + return + } + } + lock.Unlock() + case err = <-errCh: + r.logger.Error("got an error from error channel", "error", err) + lock.Unlock() + return + default: + lock.Unlock() + } + + r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) + statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) + select { + case LongJobStatusCh <- statusMsg: + default: + r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) + // Channel is full or closed, ignore the message to prevent panic + } + } +} + +func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error { + // Filter out empty lines before sending to embedder + nonEmptyLines := make([]string, 0, len(lines)) + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" { + nonEmptyLines = append(nonEmptyLines, trimmed) + } + } + + // Skip if no non-empty lines + if len(nonEmptyLines) == 0 { + // Send empty result but don't error + vectorCh <- []models.VectorRow{} + return nil + } + + embeddings, err := r.embedder.EmbedSlice(nonEmptyLines) + if err != nil { + r.logger.Error("failed to embed lines", "err", err.Error()) + errCh <- err + return err + } + + if len(embeddings) == 0 { + err := errors.New("no embeddings returned") + r.logger.Error("empty embeddings") + errCh <- err + return err + } + + if len(embeddings) != len(nonEmptyLines) { + err := errors.New("mismatch between number of lines and embeddings returned") + r.logger.Error("embedding mismatch", "err", err.Error()) + errCh <- err + return err + } + + // Create a VectorRow for each line in the batch + vectors := make([]models.VectorRow, len(nonEmptyLines)) + for i, line := range nonEmptyLines { + vectors[i] = models.VectorRow{ + Embeddings: embeddings[i], + RawText: line, + Slug: fmt.Sprintf("%s_%d", slug, i), + FileName: filename, + } + } + + vectorCh <- vectors + return nil +} + +func (r *RAG) LineToVector(line string) ([]float32, error) { + return r.embedder.Embed(line) +} + +func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { + return r.storage.SearchClosest(emb.Embedding) +} + +func (r *RAG) ListLoaded() ([]string, error) { + return r.storage.ListFiles() +} + +func (r *RAG) RemoveFile(filename string) error { + return r.storage.RemoveEmbByFileName(filename) +} diff --git a/rag/storage.go b/rag/storage.go new file mode 100644 index 0000000..782c504 --- /dev/null +++ b/rag/storage.go @@ -0,0 +1,278 @@ +package rag + +import ( + "encoding/binary" + "fmt" + "gf-lt/models" + "gf-lt/storage" + "log/slog" + "sort" + "strings" + "unsafe" + + "github.com/jmoiron/sqlx" +) + +// VectorStorage handles storing and retrieving vectors from SQLite +type VectorStorage struct { + logger *slog.Logger + sqlxDB *sqlx.DB + store storage.FullRepo +} + +func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorage { + return &VectorStorage{ + logger: logger, + sqlxDB: store.DB(), // Use the new DB() method + store: store, + } +} + + +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} + +// WriteVector stores an embedding vector in the database +func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { + tableName, err := vs.getTableName(row.Embeddings) + if err != nil { + return err + } + + // Serialize the embeddings to binary + serializedEmbeddings := SerializeVector(row.Embeddings) + + query := fmt.Sprintf( + "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", + tableName, + ) + + if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { + vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) + return err + } + + return nil +} + +// getTableName determines which table to use based on embedding size +func (vs *VectorStorage) getTableName(emb []float32) (string, error) { + size := len(emb) + + // Check if we support this embedding size + supportedSizes := map[int]bool{ + 384: true, + 768: true, + 1024: true, + 1536: true, + 2048: true, + 3072: true, + 4096: true, + 5120: true, + } + + if supportedSizes[size] { + return fmt.Sprintf("embeddings_%d", size), nil + } + + return "", fmt.Errorf("no table for embedding size of %d", size) +} + +// SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation +func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) { + tableName, err := vs.getTableName(query) + if err != nil { + return nil, err + } + + // For better performance, instead of loading all vectors at once, + // we'll implement batching and potentially add L2 distance-based pre-filtering + // since cosine similarity is related to L2 distance for normalized vectors + + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := vs.sqlxDB.Query(querySQL) + if err != nil { + return nil, err + } + defer rows.Close() + + // Use a min-heap or simple slice to keep track of top 3 closest vectors + type SearchResult struct { + vector models.VectorRow + distance float32 + } + + var topResults []SearchResult + + // Process vectors one by one to avoid loading everything into memory + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + vs.logger.Error("failed to scan row", "error", err) + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(query, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top 3 + topResults = append(topResults, result) + + // Sort and keep only top 3 + sort.Slice(topResults, func(i, j int) bool { + return topResults[i].distance < topResults[j].distance + }) + + if len(topResults) > 3 { + topResults = topResults[:3] // Keep only closest 3 + } + } + + // Convert back to VectorRow slice + results := make([]models.VectorRow, 0, len(topResults)) + for _, result := range topResults { + result.vector.Distance = result.distance + results = append(results, result.vector) + } + + return results, nil +} + +// ListFiles returns a list of all loaded files +func (vs *VectorStorage) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query all supported tables and combine results + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := "SELECT DISTINCT filename FROM " + table + rows, err := vs.sqlxDB.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + + return allFiles, nil +} + +// RemoveEmbByFileName removes all embeddings associated with a specific filename +func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { + var errors []string + + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := vs.sqlxDB.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; ")) + } + + return nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + |
