summaryrefslogtreecommitdiff
path: root/storage/vector.go
blob: 15796862e5789cb8a91be8ef08dfe12a795a680f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package storage

import (
	"elefant/models"
	"errors"
	"fmt"
	"unsafe"

	sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
)

// VectorRepo persists embedding rows and looks up the stored rows closest
// to a query vector.
type VectorRepo interface {
	// WriteVector stores a single row.
	WriteVector(row *models.VectorRow) error
	// SearchClosest returns the stored rows closest to the query vector q.
	SearchClosest(q []float32) ([]models.VectorRow, error)
}

// Table names for vector storage. Each table holds embeddings of one fixed
// width; fetchTableName picks the right one from a vector's length.
var (
	vecTableName    = "embeddings"     // 5120-wide embeddings
	vecTableName384 = "embeddings_384" // 384-wide embeddings
)

// fetchTableName returns the name of the table that stores vectors with
// len(emb) dimensions. Only widths 384 and 5120 are supported; any other
// length (including an empty or nil slice) yields an error.
func fetchTableName(emb []float32) (string, error) {
	dims := len(emb)
	if dims == 384 {
		return vecTableName384, nil
	}
	if dims == 5120 {
		return vecTableName, nil
	}
	return "", fmt.Errorf("no table for the size of %d", dims)
}

// WriteVector inserts row into the table chosen by fetchTableName from the
// length of row.Embeddings. The embedding is bound as the blob produced by
// sqlite_vec.SerializeFloat32, followed by the slug and the raw text.
// Every failure after table selection is logged through p.logger and then
// returned as-is; the prepared statement is closed on all paths.
func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
	table, err := fetchTableName(row.Embeddings)
	if err != nil {
		return err
	}
	query := fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text) VALUES (?, ?, ?)", table)
	stmt, _, err := p.s3Conn.Prepare(query)
	if err != nil {
		p.logger.Error("failed to prep a stmt", "error", err)
		return err
	}
	defer stmt.Close()

	blob, err := sqlite_vec.SerializeFloat32(row.Embeddings)
	switch {
	case err != nil:
		p.logger.Error("failed to serialize vector",
			"emb-len", len(row.Embeddings), "error", err)
		return err
	case blob == nil:
		err = errors.New("empty vector after serialization")
		p.logger.Error("empty vector after serialization",
			"emb-len", len(row.Embeddings), "text", row.RawText, "error", err)
		return err
	}

	// Parameter indexes are 1-based and follow the INSERT column order.
	binders := [...]func() error{
		func() error { return stmt.BindBlob(1, blob) },
		func() error { return stmt.BindText(2, row.Slug) },
		func() error { return stmt.BindText(3, row.RawText) },
	}
	for _, bind := range binders {
		if bindErr := bind(); bindErr != nil {
			p.logger.Error("failed to bind", "error", bindErr)
			return bindErr
		}
	}

	if err = stmt.Exec(); err != nil {
		p.logger.Error("failed exec a stmt", "error", err)
		return err
	}
	return nil
}

// decodeUnsafe interprets bs as packed float32 values and returns them in a
// freshly allocated slice. Trailing bytes that do not form a whole float32
// are ignored; an input shorter than 4 bytes (including nil) yields an empty
// slice instead of panicking.
//
// The bytes are copied rather than aliased. Callers may pass memory whose
// lifetime they do not control (e.g. a column buffer owned by SQLite), and a
// []byte is not guaranteed to be 4-byte aligned. Copying byte-wise into
// float32-typed storage avoids both problems.
//
// NOTE(review): values are decoded in the host's native byte order. This
// assumes the stored blobs were written in that same order — confirm before
// running on a big-endian host.
func decodeUnsafe(bs []byte) []float32 {
	const float32Size = int(unsafe.Sizeof(float32(0)))
	out := make([]float32, len(bs)/float32Size)
	if len(out) == 0 {
		return out
	}
	// View out's (properly aligned) backing array as bytes and copy into it.
	dst := unsafe.Slice((*byte)(unsafe.Pointer(&out[0])), len(out)*float32Size)
	copy(dst, bs)
	return out
}

// SearchClosest queries the table chosen by fetchTableName from len(q) with
// `embedding MATCH ?` and returns up to 4 rows ordered by ascending
// distance. Each row carries its id, distance, stored embedding, slug and
// raw text. On success the slice is non-nil (possibly empty); on any error
// it is nil. Only a bind failure is logged; other errors are returned
// without logging.
func (p ProviderSQL) SearchClosest(q []float32) (rows []models.VectorRow, err error) {
	tableName, err := fetchTableName(q)
	if err != nil {
		return nil, err
	}
	stmt, _, err := p.s3Conn.Prepare(
		fmt.Sprintf(`SELECT
			id,
			distance,
			embedding,
			slug,
			raw_text
		FROM %s
		WHERE embedding MATCH ?
		ORDER BY distance
		LIMIT 4
	`, tableName))
	if err != nil {
		return nil, err
	}
	// Close the statement on every path so early error returns do not leak
	// it. A Close error is surfaced only when no earlier error is returned.
	defer func() {
		if cerr := stmt.Close(); cerr != nil && err == nil {
			rows, err = nil, cerr
		}
	}()

	query, err := sqlite_vec.SerializeFloat32(q)
	if err != nil {
		return nil, err
	}
	if err = stmt.BindBlob(1, query); err != nil {
		p.logger.Error("failed to bind", "error", err)
		return nil, err
	}

	resp := []models.VectorRow{}
	for stmt.Step() {
		res := models.VectorRow{}
		res.ID = uint32(stmt.ColumnInt64(0))
		res.Distance = float32(stmt.ColumnFloat(1))
		// ColumnBlob copies into Go-owned memory. The raw column buffer is
		// owned by SQLite and may be invalidated by the next Step, so it
		// must not outlive this iteration.
		if emb := stmt.ColumnBlob(2, nil); len(emb) > 0 {
			res.Embeddings = decodeUnsafe(emb)
		}
		res.Slug = stmt.ColumnText(3)
		res.RawText = stmt.ColumnText(4)
		resp = append(resp, res)
	}
	if err = stmt.Err(); err != nil {
		return nil, err
	}
	return resp, nil
}