diff options
Diffstat (limited to 'storage/vector.go')
-rw-r--r-- | storage/vector.go | 50 |
1 files changed, 44 insertions, 6 deletions
diff --git a/storage/vector.go b/storage/vector.go index fe479d8..5e9069c 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -12,17 +12,19 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error SearchClosest(q []float32) ([]models.VectorRow, error) + ListFiles() ([]string, error) + RemoveEmbByFileName(filename string) error } var ( - vecTableName = "embeddings" - vecTableName384 = "embeddings_384" + vecTableName5120 = "embeddings_5120" + vecTableName384 = "embeddings_384" ) func fetchTableName(emb []float32) (string, error) { switch len(emb) { case 5120: - return vecTableName, nil + return vecTableName5120, nil case 384: return vecTableName384, nil default: @@ -36,7 +38,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { return err } stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", tableName)) + fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)) if err != nil { p.logger.Error("failed to prep a stmt", "error", err) return err @@ -66,6 +68,10 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { p.logger.Error("failed to bind", "error", err) return err } + if err := stmt.BindText(4, row.FileName); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } err = stmt.Exec() if err != nil { return err @@ -87,11 +93,12 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { distance, embedding, slug, - raw_text + raw_text, + filename FROM %s WHERE embedding MATCH ? ORDER BY distance - LIMIT 4 + LIMIT 3 `, tableName)) if err != nil { return nil, err @@ -112,6 +119,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { res.Embeddings = decodeUnsafe(emb) res.Slug = stmt.ColumnText(2) res.RawText = stmt.ColumnText(3) + res.FileName = stmt.ColumnText(4) resp = append(resp, res) } if err := stmt.Err(); err != nil { @@ -123,3 +131,33 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { } return resp, nil } + +func (p ProviderSQL) ListFiles() ([]string, error) { + q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) + stmt, _, err := p.s3Conn.Prepare(q) + if err != nil { + return nil, err + } + defer stmt.Close() + resp := []string{} + for stmt.Step() { + resp = append(resp, stmt.ColumnText(0)) + } + if err := stmt.Err(); err != nil { + return nil, err + } + return resp, nil +} + +func (p ProviderSQL) RemoveEmbByFileName(filename string) error { + q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) + stmt, _, err := p.s3Conn.Prepare(q) + if err != nil { + return err + } + defer stmt.Close() + if err := stmt.BindText(1, filename); err != nil { + return err + } + return stmt.Exec() +} |