summaryrefslogtreecommitdiff
path: root/storage/vector.go
diff options
context:
space:
mode:
Diffstat (limited to 'storage/vector.go')
-rw-r--r--storage/vector.go62
1 files changed, 20 insertions, 42 deletions
diff --git a/storage/vector.go b/storage/vector.go
index b3e5654..6958634 100644
--- a/storage/vector.go
+++ b/storage/vector.go
@@ -66,35 +66,13 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
if err != nil {
return err
}
- stmt, _, err := p.s3Conn.Prepare(
- fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName))
- if err != nil {
- p.logger.Error("failed to prep a stmt", "error", err)
- return err
- }
- defer stmt.Close()
+
serializedEmbeddings := SerializeVector(row.Embeddings)
- if err := stmt.BindBlob(1, serializedEmbeddings); err != nil {
- p.logger.Error("failed to bind", "error", err)
- return err
- }
- if err := stmt.BindText(2, row.Slug); err != nil {
- p.logger.Error("failed to bind", "error", err)
- return err
- }
- if err := stmt.BindText(3, row.RawText); err != nil {
- p.logger.Error("failed to bind", "error", err)
- return err
- }
- if err := stmt.BindText(4, row.FileName); err != nil {
- p.logger.Error("failed to bind", "error", err)
- return err
- }
- err = stmt.Exec()
- if err != nil {
- return err
- }
- return nil
+
+ query := fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
+ _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
+
+ return err
}
func decodeUnsafe(bs []byte) []float32 {
@@ -110,30 +88,30 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
func (p ProviderSQL) ListFiles() ([]string, error) {
q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384)
- stmt, _, err := p.s3Conn.Prepare(q)
+ rows, err := p.db.Query(q)
if err != nil {
return nil, err
}
- defer stmt.Close()
+ defer rows.Close()
+
resp := []string{}
- for stmt.Step() {
- resp = append(resp, stmt.ColumnText(0))
+ for rows.Next() {
+ var filename string
+ if err := rows.Scan(&filename); err != nil {
+ return nil, err
+ }
+ resp = append(resp, filename)
}
- if err := stmt.Err(); err != nil {
+
+ if err := rows.Err(); err != nil {
return nil, err
}
+
return resp, nil
}
func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384)
- stmt, _, err := p.s3Conn.Prepare(q)
- if err != nil {
- return err
- }
- defer stmt.Close()
- if err := stmt.BindText(1, filename); err != nil {
- return err
- }
- return stmt.Exec()
+ _, err := p.db.Exec(q, filename)
+ return err
}