summaryrefslogtreecommitdiff
path: root/storage
diff options
context:
space:
mode:
Diffstat (limited to 'storage')
-rw-r--r--storage/memory.go2
-rw-r--r--storage/storage.go2
-rw-r--r--storage/storage_test.go173
-rw-r--r--storage/vector.go2
4 files changed, 88 insertions, 91 deletions
diff --git a/storage/memory.go b/storage/memory.go
index c9fc853..406182f 100644
--- a/storage/memory.go
+++ b/storage/memory.go
@@ -1,6 +1,6 @@
package storage
-import "elefant/models"
+import "gf-lt/models"
type Memories interface {
Memorise(m *models.Memory) (*models.Memory, error)
diff --git a/storage/storage.go b/storage/storage.go
index f759700..7911e13 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -1,7 +1,7 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"log/slog"
_ "github.com/glebarez/go-sqlite"
diff --git a/storage/storage_test.go b/storage/storage_test.go
index ff3b5e6..a1c4cf4 100644
--- a/storage/storage_test.go
+++ b/storage/storage_test.go
@@ -1,18 +1,15 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"fmt"
- "log"
"log/slog"
"os"
"testing"
"time"
- sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
- "github.com/ncruces/go-sqlite3"
)
func TestMemories(t *testing.T) {
@@ -177,87 +174,87 @@ func TestChatHistory(t *testing.T) {
}
}
-func TestVecTable(t *testing.T) {
- // healthcheck
- db, err := sqlite3.Open(":memory:")
- if err != nil {
- t.Fatal(err)
- }
- stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
- if err != nil {
- t.Fatal(err)
- }
- stmt.Step()
- log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
- stmt.Close()
- // migration
- err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
- if err != nil {
- t.Fatal(err)
- }
- // data prep and insert
- items := map[int][]float32{
- 1: {0.1, 0.1, 0.1, 0.1},
- 2: {0.2, 0.2, 0.2, 0.2},
- 3: {0.3, 0.3, 0.3, 0.3},
- 4: {0.4, 0.4, 0.4, 0.4},
- 5: {0.5, 0.5, 0.5, 0.5},
- }
- q := []float32{0.28, 0.3, 0.3, 0.3}
- stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
- if err != nil {
- t.Fatal(err)
- }
- for id, values := range items {
- v, err := sqlite_vec.SerializeFloat32(values)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindInt(1, id)
- stmt.BindBlob(2, v)
- stmt.BindText(3, "some_chat")
- err = stmt.Exec()
- if err != nil {
- t.Fatal(err)
- }
- stmt.Reset()
- }
- stmt.Close()
- // select | vec search
- stmt, _, err = db.Prepare(`
- SELECT
- rowid,
- distance,
- embedding
- FROM vec_items
- WHERE embedding MATCH ?
- ORDER BY distance
- LIMIT 3
- `)
- if err != nil {
- t.Fatal(err)
- }
- query, err := sqlite_vec.SerializeFloat32(q)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindBlob(1, query)
- for stmt.Step() {
- rowid := stmt.ColumnInt64(0)
- distance := stmt.ColumnFloat(1)
- emb := stmt.ColumnRawText(2)
- floats := decodeUnsafe(emb)
- log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
- }
- if err := stmt.Err(); err != nil {
- t.Fatal(err)
- }
- err = stmt.Close()
- if err != nil {
- t.Fatal(err)
- }
- err = db.Close()
- if err != nil {
- t.Fatal(err)
- }
-}
+// func TestVecTable(t *testing.T) {
+// // healthcheck
+// db, err := sqlite3.Open(":memory:")
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Step()
+// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
+// stmt.Close()
+// // migration
+// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// // data prep and insert
+// items := map[int][]float32{
+// 1: {0.1, 0.1, 0.1, 0.1},
+// 2: {0.2, 0.2, 0.2, 0.2},
+// 3: {0.3, 0.3, 0.3, 0.3},
+// 4: {0.4, 0.4, 0.4, 0.4},
+// 5: {0.5, 0.5, 0.5, 0.5},
+// }
+// q := []float32{0.4, 0.3, 0.3, 0.3}
+// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// for id, values := range items {
+// v, err := sqlite_vec.SerializeFloat32(values)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindInt(1, id)
+// stmt.BindBlob(2, v)
+// stmt.BindText(3, "some_chat")
+// err = stmt.Exec()
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Reset()
+// }
+// stmt.Close()
+// // select | vec search
+// stmt, _, err = db.Prepare(`
+// SELECT
+// rowid,
+// distance,
+// embedding
+// FROM vec_items
+// WHERE embedding MATCH ?
+// ORDER BY distance
+// LIMIT 3
+// `)
+// if err != nil {
+// t.Fatal(err)
+// }
+// query, err := sqlite_vec.SerializeFloat32(q)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindBlob(1, query)
+// for stmt.Step() {
+// rowid := stmt.ColumnInt64(0)
+// distance := stmt.ColumnFloat(1)
+// emb := stmt.ColumnRawText(2)
+// floats := decodeUnsafe(emb)
+// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
+// }
+// if err := stmt.Err(); err != nil {
+// t.Fatal(err)
+// }
+// err = stmt.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// err = db.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// }
diff --git a/storage/vector.go b/storage/vector.go
index 5e9069c..71005e4 100644
--- a/storage/vector.go
+++ b/storage/vector.go
@@ -1,7 +1,7 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"errors"
"fmt"
"unsafe"