diff options
Diffstat (limited to 'storage')
-rw-r--r-- | storage/vector.go | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/storage/vector.go b/storage/vector.go index 23a72e9..1579686 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -4,7 +4,6 @@ import ( "elefant/models" "errors" "fmt" - "log" "unsafe" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" @@ -12,7 +11,7 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q []float32) (*models.VectorRow, error) + SearchClosest(q []float32) ([]models.VectorRow, error) } var ( @@ -79,7 +78,11 @@ func decodeUnsafe(bs []byte) []float32 { return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) } -func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { +func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } stmt, _, err := p.s3Conn.Prepare( fmt.Sprintf(`SELECT id, @@ -91,35 +94,35 @@ func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { WHERE embedding MATCH ? ORDER BY distance LIMIT 4 - `, vecTableName)) + `, tableName)) if err != nil { - log.Fatal(err) + return nil, err } query, err := sqlite_vec.SerializeFloat32(q[:]) if err != nil { - log.Fatal(err) + return nil, err } if err := stmt.BindBlob(1, query); err != nil { p.logger.Error("failed to bind", "error", err) return nil, err } - resp := make([]models.VectorRow, 4) - i := 0 + resp := []models.VectorRow{} for stmt.Step() { - resp[i].ID = uint32(stmt.ColumnInt64(0)) - resp[i].Distance = float32(stmt.ColumnFloat(1)) + res := models.VectorRow{} + res.ID = uint32(stmt.ColumnInt64(0)) + res.Distance = float32(stmt.ColumnFloat(1)) emb := stmt.ColumnRawText(2) - resp[i].Embeddings = decodeUnsafe(emb) - resp[i].Slug = stmt.ColumnText(3) - resp[i].RawText = stmt.ColumnText(4) - i++ + res.Embeddings = decodeUnsafe(emb) + res.Slug = stmt.ColumnText(3) + res.RawText = stmt.ColumnText(4) + resp = append(resp, res) } if err := stmt.Err(); err != nil { - log.Fatal(err) + return nil, err } err = stmt.Close() if err != nil { - log.Fatal(err) + return nil, err } - return nil, nil + return resp, nil } |