diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-03-06 18:58:23 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-03-06 18:58:23 +0300 |
| commit | 17b68bc21fae99c17ec48e046e67a643b9d159bb (patch) | |
| tree | 00b2da2f55876e720aecccc10dbc59232da768db /rag/storage.go | |
| parent | edfd43c52ae3f2fa16f6ab5d64cb48218a2c0a64 (diff) | |
Enha (rag): async writes
Diffstat (limited to 'rag/storage.go')
| -rw-r--r-- | rag/storage.go | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/rag/storage.go b/rag/storage.go index 110cea2..1e6b013 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -102,6 +102,92 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { return nil } +// WriteVectors stores multiple embedding vectors in a single transaction +func (vs *VectorStorage) WriteVectors(rows []*models.VectorRow) error { + if len(rows) == 0 { + return nil + } + // SQLite has limit of 999 parameters per statement, each row uses 4 parameters + const maxBatchSize = 200 // 200 * 4 = 800 < 999 + if len(rows) > maxBatchSize { + // Process in chunks + for i := 0; i < len(rows); i += maxBatchSize { + end := i + maxBatchSize + if end > len(rows) { + end = len(rows) + } + if err := vs.WriteVectors(rows[i:end]); err != nil { + return err + } + } + return nil + } + // All rows should have same embedding size (same model) + firstSize := len(rows[0].Embeddings) + for i, row := range rows { + if len(row.Embeddings) != firstSize { + return fmt.Errorf("embedding size mismatch: row %d has size %d, expected %d", i, len(row.Embeddings), firstSize) + } + } + tableName, err := vs.getTableName(rows[0].Embeddings) + if err != nil { + return err + } + + // Start transaction + tx, err := vs.sqlxDB.Beginx() + if err != nil { + return err + } + defer func() { + if err != nil { + tx.Rollback() + } + }() + + // Build batch insert for embeddings table + embeddingPlaceholders := make([]string, 0, len(rows)) + embeddingArgs := make([]any, 0, len(rows)*4) + for _, row := range rows { + embeddingPlaceholders = append(embeddingPlaceholders, "(?, ?, ?, ?)") + embeddingArgs = append(embeddingArgs, SerializeVector(row.Embeddings), row.Slug, row.RawText, row.FileName) + } + embeddingQuery := fmt.Sprintf( + "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES %s", + tableName, + strings.Join(embeddingPlaceholders, ", "), + ) + if _, err := tx.Exec(embeddingQuery, embeddingArgs...); err != nil { + vs.logger.Error("failed to write vectors batch", "error", err, "batch_size", len(rows)) + return err + } + + // Build batch insert for FTS table + ftsPlaceholders := make([]string, 0, len(rows)) + ftsArgs := make([]any, 0, len(rows)*4) + embeddingSize := len(rows[0].Embeddings) + for _, row := range rows { + ftsPlaceholders = append(ftsPlaceholders, "(?, ?, ?, ?)") + ftsArgs = append(ftsArgs, row.Slug, row.RawText, row.FileName, embeddingSize) + } + ftsQuery := fmt.Sprintf( + "INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) VALUES %s", + strings.Join(ftsPlaceholders, ", "), + ) + if _, err := tx.Exec(ftsQuery, ftsArgs...); err != nil { + vs.logger.Error("failed to write FTS batch", "error", err, "batch_size", len(rows)) + return err + } + + err = tx.Commit() + if err != nil { + vs.logger.Error("failed to commit transaction", "error", err) + return err + } + vs.logger.Debug("wrote vectors batch", "batch_size", len(rows)) + return nil +} + // getTableName determines which table to use based on embedding size func (vs *VectorStorage) getTableName(emb []float32) (string, error) { size := len(emb) |
