diff options
Diffstat (limited to 'rag/storage.go')
| -rw-r--r-- | rag/storage.go | 230 |
1 files changed, 208 insertions, 22 deletions
diff --git a/rag/storage.go b/rag/storage.go index 52f6859..a53f767 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -1,6 +1,7 @@ package rag import ( + "database/sql" "encoding/binary" "fmt" "gf-lt/models" @@ -62,6 +63,17 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { if err != nil { return err } + embeddingSize := len(row.Embeddings) + // Start transaction + tx, err := vs.sqlxDB.Beginx() + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() // Serialize the embeddings to binary serializedEmbeddings := SerializeVector(row.Embeddings) @@ -69,10 +81,102 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName, ) - if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { + if _, err := tx.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) return err } + // Insert into FTS table + ftsQuery := `INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES (?, ?, ?, ?)` + if _, err := tx.Exec(ftsQuery, row.Slug, row.RawText, row.FileName, embeddingSize); err != nil { + vs.logger.Error("failed to write to FTS table", "error", err, "slug", row.Slug) + return err + } + err = tx.Commit() + if err != nil { + vs.logger.Error("failed to commit transaction", "error", err) + return err + } + return nil +} + +// WriteVectors stores multiple embedding vectors in a single transaction +func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { + if len(rows) == 0 { + return nil + } + // SQLite has limit of 999 parameters per statement, each row uses 4 parameters + const maxBatchSize = 200 // 200 * 4 = 800 < 999 + if len(rows) > maxBatchSize { + // Process in chunks + for i := 0; i < len(rows); i += maxBatchSize { + end := i + maxBatchSize + if end > len(rows) { + end = len(rows) + } + if err := vs.WriteVectors(rows[i:end]); err != nil { + return err + } + } + return nil + } + // All rows should have same embedding size (same model) + firstSize := len(rows[0].Embeddings) + for i, row := range rows { + if len(row.Embeddings) != firstSize { + return fmt.Errorf("embedding size mismatch: row %d has size %d, expected %d", i, len(row.Embeddings), firstSize) + } + } + tableName, err := vs.getTableName(rows[0].Embeddings) + if err != nil { + return err + } + // Start transaction + tx, err := vs.sqlxDB.Beginx() + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + // Build batch insert for embeddings table + embeddingPlaceholders := make([]string, 0, len(rows)) + embeddingArgs := make([]any, 0, len(rows)*4) + for _, row := range rows { + embeddingPlaceholders = append(embeddingPlaceholders, "(?, ?, ?, ?)") + embeddingArgs = append(embeddingArgs, SerializeVector(row.Embeddings), row.Slug, row.RawText, row.FileName) + } + embeddingQuery := fmt.Sprintf( + "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES %s", + tableName, + strings.Join(embeddingPlaceholders, ", "), + ) + if _, err := tx.Exec(embeddingQuery, embeddingArgs...); err != nil { + vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows)) + return err + } + // Build batch insert for FTS table + ftsPlaceholders := make([]string, 0, len(rows)) + ftsArgs := make([]any, 0, len(rows)*4) + embeddingSize := len(rows[0].Embeddings) + for _, row := range rows { + ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)") + ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize) + } + ftsQuery := "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES " + + strings.Join(ftsPlaceholders, ", ") + if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil { + vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows)) + return err + } + err = tx.Commit() + if err != nil { + vs.logger.Error("failed to commit transaction", "error", err) + return err + } + vs.logger.Debug("wrote vectors batch", "batch_size", len(rows)) return nil } @@ -98,30 +202,25 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) { } // SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation -func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) { +func (vs *VectorStorage) SearchClosest(query []float32, limit int) ([]models.VectorRow, error) { + if limit <= 0 { + limit = 10 + } tableName, err := vs.getTableName(query) if err != nil { return nil, err } - - // For better performance, instead of loading all vectors at once, - // we'll implement batching and potentially add L2 distance-based pre-filtering - // since cosine similarity is related to L2 distance for normalized vectors - querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName rows, err := vs.sqlxDB.Query(querySQL) if err != nil { return nil, err } defer rows.Close() - - // Use a min-heap or simple slice to keep track of top 3 closest vectors type SearchResult struct { vector models.VectorRow distance float32 } var topResults []SearchResult - // Process vectors one by one to avoid loading everything into memory for rows.Next() { var ( embeddingsBlob []byte @@ -132,12 +231,9 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err vs.logger.Error("failed to scan row", "error", err) continue } - storedEmbeddings := DeserializeVector(embeddingsBlob) - - // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) similarity := cosineSimilarity(query, storedEmbeddings) - distance := 1 - similarity // Convert to distance where 0 is most similar + distance := 1 - similarity result := SearchResult{ vector: models.VectorRow{ @@ -149,20 +245,14 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err distance: distance, } - // Add to top results and maintain only top 3 topResults = append(topResults, result) - - // Sort and keep only top 3 sort.Slice(topResults, func(i, j int) bool { return topResults[i].distance < topResults[j].distance }) - - if len(topResults) > 3 { - topResults = topResults[:3] // Keep only closest 3 + if len(topResults) > limit { + topResults = topResults[:limit] } } - - // Convert back to VectorRow slice results := make([]models.VectorRow, 0, len(topResults)) for _, result := range topResults { result.vector.Distance = result.distance @@ -171,6 +261,98 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err return results, nil } +// GetVectorBySlug retrieves a vector row by its slug +func (vs *VectorStorage) GetVectorBySlug(slug string) (*models.VectorRow, error) { + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := fmt.Sprintf("SELECT embeddings, slug, raw_text, filename FROM %s WHERE slug = ?", table) + row := vs.sqlxDB.QueryRow(query, slug) + var ( + embeddingsBlob []byte + retrievedSlug, rawText, fileName string + ) + if err := row.Scan(&embeddingsBlob, &retrievedSlug, &rawText, &fileName); err != nil { + // No row in this table, continue to next size + continue + } + storedEmbeddings := DeserializeVector(embeddingsBlob) + return &models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: retrievedSlug, + RawText: rawText, + FileName: fileName, + }, nil + } + return nil, fmt.Errorf("vector with slug %s not found", slug) +} + +// SearchKeyword performs full-text search using FTS5 +func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.VectorRow, error) { + // Use FTS5 bm25 ranking. bm25 returns negative values where more negative is better. + // We'll order by bm25 (ascending) and limit. + ftsQuery := `SELECT slug, raw_text, filename, bm25(fts_embeddings) as score + FROM fts_embeddings + WHERE fts_embeddings MATCH ? + ORDER BY score + LIMIT ?` + + // Try original query first + rows, err := vs.sqlxDB.Query(ftsQuery, query, limit) + if err != nil { + return nil, fmt.Errorf("FTS search failed: %w", err) + } + results, err := vs.scanRows(rows) + rows.Close() + if err != nil { + return nil, err + } + + // If no results and query contains multiple terms, try OR fallback + if len(results) == 0 && strings.Contains(query, " ") && !strings.Contains(strings.ToUpper(query), " OR ") { + // Build OR query: term1 OR term2 OR term3 + terms := strings.Fields(query) + if len(terms) > 1 { + orQuery := strings.Join(terms, " OR ") + rows, err := vs.sqlxDB.Query(ftsQuery, orQuery, limit) + if err != nil { + // Return original empty results rather than error + return results, nil + } + orResults, err := vs.scanRows(rows) + rows.Close() + if err == nil { + results = orResults + } + } + } + return results, nil +} + +// scanRows converts SQL rows to VectorRow slice +func (vs *VectorStorage) scanRows(rows *sql.Rows) ([]models.VectorRow, error) { + var results []models.VectorRow + for rows.Next() { + var slug, rawText, fileName string + var score float64 + if err := rows.Scan(&slug, &rawText, &fileName, &score); err != nil { + vs.logger.Error("failed to scan FTS row", "error", err) + continue + } + // Convert BM25 score to distance-like metric (lower is better) + // BM25 is negative, more negative is better. Keep as negative. + distance := float32(score) // Keep negative, more negative is better + // No clamping needed; negative distances are fine + results = append(results, models.VectorRow{ + Slug: slug, + RawText: rawText, + FileName: fileName, + Distance: distance, + }) + } + return results, nil +} + // ListFiles returns a list of all loaded files func (vs *VectorStorage) ListFiles() ([]string, error) { fileLists := make([][]string, 0) @@ -215,6 +397,10 @@ func (vs *VectorStorage) ListFiles() ([]string, error) { // RemoveEmbByFileName removes all embeddings associated with a specific filename func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { var errors []string + // Delete from FTS table first + if _, err := vs.sqlxDB.Exec("DELETE FROM fts_embeddings WHERE filename = ?", filename); err != nil { + errors = append(errors, err.Error()) + } embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} for _, size := range embeddingSizes { table := fmt.Sprintf("embeddings_%d", size) |
