diff options
author | Grail Finder <wohilas@gmail.com> | 2025-01-04 18:13:13 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2025-01-04 18:13:13 +0300 |
commit | 4736e43631ed21fd14741daa1dde746687d330fa (patch) | |
tree | c5dcee2930e8681397a369ae7869671bdcf568dc /storage | |
parent | 461d19aa2512fea7ac07e50c3178609850ef07c3 (diff) |
Feat (RAG): tying tui calls to rag funcs [WIP; skip-ci]
RAG itself is annoying to properly implement, plucking sentences with no
context is useless. Also it should not be a part of main package, same
for goes for tui. The number of global vars is absurd.
Diffstat (limited to 'storage')
-rw-r--r-- | storage/migrations/002_add_vector.up.sql | 7 | ||||
-rw-r--r-- | storage/vector.go | 64 |
2 files changed, 57 insertions, 14 deletions
diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql index 4fcc9aa..f64aecb 100644 --- a/storage/migrations/002_add_vector.up.sql +++ b/storage/migrations/002_add_vector.up.sql @@ -4,3 +4,10 @@ CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0( slug TEXT NOT NULL, raw_text TEXT NOT NULL ); + +CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embedding FLOAT[384], + slug TEXT NOT NULL, + raw_text TEXT NOT NULL +); diff --git a/storage/vector.go b/storage/vector.go index bc46734..23a72e9 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -2,6 +2,7 @@ package storage import ( "elefant/models" + "errors" "fmt" "log" "unsafe" @@ -11,29 +12,61 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q [5120]float32) (*models.VectorRow, error) + SearchClosest(q []float32) (*models.VectorRow, error) } -var vecTableName = "embeddings" +var ( + vecTableName = "embeddings" + vecTableName384 = "embeddings_384" +) + +func fetchTableName(emb []float32) (string, error) { + switch len(emb) { + case 5120: + return vecTableName, nil + case 384: + return vecTableName384, nil + default: + return "", fmt.Errorf("no table for the size of %d", len(emb)) + } +} func (p ProviderSQL) WriteVector(row *models.VectorRow) error { + tableName, err := fetchTableName(row.Embeddings) + if err != nil { + return err + } stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", vecTableName)) - defer stmt.Close() + fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", tableName)) if err != nil { p.logger.Error("failed to prep a stmt", "error", err) return err } + defer stmt.Close() v, err := sqlite_vec.SerializeFloat32(row.Embeddings) if err != nil { p.logger.Error("failed to serialize vector", "emb-len", len(row.Embeddings), "error", err) return err } - stmt.BindInt(1, int(row.ID)) - stmt.BindBlob(2, v) - stmt.BindText(3, row.Slug) - stmt.BindText(4, row.RawText) + if v == nil { + err = errors.New("empty vector after serialization") + p.logger.Error("empty vector after serialization", + "emb-len", len(row.Embeddings), "text", row.RawText, "error", err) + return err + } + if err := stmt.BindBlob(1, v); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(2, row.Slug); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(3, row.RawText); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } err = stmt.Exec() if err != nil { p.logger.Error("failed exec a stmt", "error", err) @@ -46,19 +79,19 @@ func decodeUnsafe(bs []byte) []float32 { return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) } -func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) { - stmt, _, err := p.s3Conn.Prepare(` - SELECT +func (p ProviderSQL) SearchClosest(q []float32) (*models.VectorRow, error) { + stmt, _, err := p.s3Conn.Prepare( + fmt.Sprintf(`SELECT id, distance, embedding, slug, raw_text - FROM vec_items + FROM %s WHERE embedding MATCH ? ORDER BY distance LIMIT 4 - `) + `, vecTableName)) if err != nil { log.Fatal(err) } @@ -66,7 +99,10 @@ func (p ProviderSQL) SearchClosest(q [5120]float32) (*models.VectorRow, error) { if err != nil { log.Fatal(err) } - stmt.BindBlob(1, query) + if err := stmt.BindBlob(1, query); err != nil { + p.logger.Error("failed to bind", "error", err) + return nil, err + } resp := make([]models.VectorRow, 4) i := 0 for stmt.Step() { |