path: root/rag
author    Grail Finder <wohilas@gmail.com>  2025-09-05 20:09:35 +0300
committer Grail Finder <wohilas@gmail.com>  2025-09-05 20:09:35 +0300
commit    e0bd66e9ad7176ce85937c72f6e312b06fe259ab (patch)
tree      712308cc87bdd8c2b9b990a98fe97ee5585ac5cd /rag
parent    0068cd17ff1985102dd5ab22df899ea6b1065ff0 (diff)
Feat: rag chromem (feat/rag)
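
Adds a minimal RAG layer adapted from chromem-go: document creation with a
pluggable embedding function (document.go), concurrent metadata/content
filtering plus top-n cosine-similarity search (query.go), and small vector
helpers (vector.go).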
Diffstat (limited to 'rag')
-rw-r--r--  rag/document.go   60
-rw-r--r--  rag/query.go     246
-rw-r--r--  rag/vector.go     62
3 files changed, 368 insertions, 0 deletions
diff --git a/rag/document.go b/rag/document.go
new file mode 100644
index 0000000..48c907e
--- /dev/null
+++ b/rag/document.go
@@ -0,0 +1,60 @@
+package rag
+
+import (
+ "context"
+ "errors"
+)
+
+type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
+
+// NewEmbeddingFuncDefault returns a placeholder embedding function that
+// always returns an error; callers must supply their own EmbeddingFunc.
+func NewEmbeddingFuncDefault() EmbeddingFunc {
+ return func(_ context.Context, _ string) ([]float32, error) {
+ return nil, errors.New("no default embedding function configured")
+ }
+}
+
+// Document represents a single document.
+type Document struct {
+ ID string
+ Metadata map[string]string
+ Embedding []float32
+ Content string
+}
+
+// NewDocument creates a new document, including its embeddings.
+// Metadata is optional.
+// If the embeddings are not provided, they are created using the embedding function.
+// You can leave the content empty if you only want to store embeddings.
+// If embeddingFunc is nil, the default embedding function is used (currently a placeholder that returns an error).
+//
+// If you want to create a document without embeddings, for example to let [Collection.AddDocuments]
+// create them concurrently, you can create a document with `rag.Document{...}`
+// instead of using this constructor.
+func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) {
+ if id == "" {
+ return Document{}, errors.New("id is empty")
+ }
+ if len(embedding) == 0 && content == "" {
+ return Document{}, errors.New("either embedding or content must be filled")
+ }
+ if embeddingFunc == nil {
+ embeddingFunc = NewEmbeddingFuncDefault()
+ }
+
+ if len(embedding) == 0 {
+ var err error
+ embedding, err = embeddingFunc(ctx, content)
+ if err != nil {
+ return Document{}, err
+ }
+ }
+
+ return Document{
+ ID: id,
+ Metadata: metadata,
+ Embedding: embedding,
+ Content: content,
+ }, nil
+}
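
A minimal usage sketch for the above (not part of the diff; fakeEmbed is a
hypothetical stand-in for a real embedding model call):

	// fakeEmbed pretends to embed text; a real EmbeddingFunc would call an
	// embedding model API and return a normalized vector.
	fakeEmbed := func(ctx context.Context, text string) ([]float32, error) {
		return []float32{0.6, 0.8}, nil // 0.6^2 + 0.8^2 = 1, already normalized
	}
	doc, err := rag.NewDocument(context.Background(), "doc-1",
		map[string]string{"lang": "go"}, nil, "some content", fakeEmbed)
	if err != nil {
		// handle error
	}
	_ = doc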
diff --git a/rag/query.go b/rag/query.go
new file mode 100644
index 0000000..e120a7a
--- /dev/null
+++ b/rag/query.go
@@ -0,0 +1,246 @@
+package rag
+
+import (
+ "cmp"
+ "container/heap"
+ "context"
+ "fmt"
+ "runtime"
+ "slices"
+ "strings"
+ "sync"
+)
+
+var supportedFilters = []string{"$contains", "$not_contains"}
+
+type docSim struct {
+ docID string
+ similarity float32
+}
+
+// docMaxHeap keeps docSims in a heap ordered by similarity. Despite the name,
+// the root holds the *smallest* similarity, so it can be cheaply replaced.
+// See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap
+type docMaxHeap []docSim
+
+func (h docMaxHeap) Len() int { return len(h) }
+func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
+func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+
+func (h *docMaxHeap) Push(x any) {
+ // Push and Pop use pointer receivers because they modify the slice's length,
+ // not just its contents.
+ *h = append(*h, x.(docSim))
+}
+
+func (h *docMaxHeap) Pop() any {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[0 : n-1]
+ return x
+}
+
+// maxDocSims manages a fixed-size heap of docSims, keeping the n highest
+// similarities. It's safe for concurrent use, but the slice returned by
+// values() is not. In our benchmarks this was faster than sorting at the end.
+type maxDocSims struct {
+ h docMaxHeap
+ lock sync.RWMutex
+ size int
+}
+
+// newMaxDocSims creates a new maxDocSims with a fixed size.
+func newMaxDocSims(size int) *maxDocSims {
+ return &maxDocSims{
+ h: make(docMaxHeap, 0, size),
+ size: size,
+ }
+}
+
+// add inserts a new docSim into the heap, keeping only the top n similarities.
+func (d *maxDocSims) add(doc docSim) {
+ d.lock.Lock()
+ defer d.lock.Unlock()
+ if d.h.Len() < d.size {
+ heap.Push(&d.h, doc)
+ } else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity {
+ // Replace the smallest similarity if the new doc's similarity is higher
+ heap.Pop(&d.h)
+ heap.Push(&d.h, doc)
+ }
+}
+
+// values returns the docSims in the heap, sorted by similarity (descending).
+// Sorting mutates the underlying slice, so an exclusive lock is taken.
+// Only work with the result after all calls to add() have finished.
+func (d *maxDocSims) values() []docSim {
+ d.lock.Lock()
+ defer d.lock.Unlock()
+ slices.SortFunc(d.h, func(i, j docSim) int {
+ return cmp.Compare(j.similarity, i.similarity)
+ })
+ return d.h
+}
+
+// filterDocs filters a map of documents by metadata and content.
+// It does this concurrently.
+func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
+ filteredDocs := make([]*Document, 0, len(docs))
+ filteredDocsLock := sync.Mutex{}
+
+ // Determine concurrency. Use number of docs or CPUs, whichever is smaller.
+ numCPUs := runtime.NumCPU()
+ numDocs := len(docs)
+ concurrency := numCPUs
+ if numDocs < numCPUs {
+ concurrency = numDocs
+ }
+
+ docChan := make(chan *Document, concurrency*2)
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for doc := range docChan {
+ if documentMatchesFilters(doc, where, whereDocument) {
+ filteredDocsLock.Lock()
+ filteredDocs = append(filteredDocs, doc)
+ filteredDocsLock.Unlock()
+ }
+ }
+ }()
+ }
+
+ for _, doc := range docs {
+ docChan <- doc
+ }
+ close(docChan)
+
+ wg.Wait()
+
+ // filteredDocs was initialized with a potentially large capacity, so
+ // return nil instead of an empty slice when nothing matched.
+ if len(filteredDocs) == 0 {
+ filteredDocs = nil
+ }
+ return filteredDocs
+}
+
+// documentMatchesFilters checks if a document matches the given filters.
+// When calling this function, the whereDocument keys must already be validated!
+func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
+ // A document's metadata must have *all* the fields in the where clause.
+ for k, v := range where {
+ // TODO: Do we want to check for existence of the key? I.e. should
+ // a where clause with empty string as value match a document's
+ // metadata that doesn't have the key at all?
+ if document.Metadata[k] != v {
+ return false
+ }
+ }
+
+ // A document must satisfy *all* filters, until we support the `$or` operator.
+ for k, v := range whereDocument {
+ switch k {
+ case "$contains":
+ if !strings.Contains(document.Content, v) {
+ return false
+ }
+ case "$not_contains":
+ if strings.Contains(document.Content, v) {
+ return false
+ }
+ default:
+ // No handling (error) required because we already validated the
+ // operators. This simplifies the concurrency logic (no err var
+ // and lock, no context to cancel).
+ }
+ }
+
+ return true
+}
+
+func getMostSimilarDocs(ctx context.Context, queryVectors, negativeVector []float32, negativeFilterThreshold float32, docs []*Document, n int) ([]docSim, error) {
+ nMaxDocs := newMaxDocSims(n)
+
+ // Determine concurrency. Use number of docs or CPUs, whichever is smaller.
+ numCPUs := runtime.NumCPU()
+ numDocs := len(docs)
+ concurrency := numCPUs
+ if numDocs < numCPUs {
+ concurrency = numDocs
+ }
+
+ var sharedErr error
+ sharedErrLock := sync.Mutex{}
+ ctx, cancel := context.WithCancelCause(ctx)
+ defer cancel(nil)
+ setSharedErr := func(err error) {
+ sharedErrLock.Lock()
+ defer sharedErrLock.Unlock()
+ // Another goroutine might have already set the error.
+ if sharedErr == nil {
+ sharedErr = err
+ // Cancel the operation for all other goroutines.
+ cancel(sharedErr)
+ }
+ }
+
+ wg := sync.WaitGroup{}
+ // Instead of using a channel to pass documents into the goroutines, we just
+ // split the slice into sub-slices and pass those to the goroutines.
+ // This turned out to be faster in the query benchmarks.
+ subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 10/3 = 3; leaves 1
+ rem := len(docs) % concurrency
+ for i := 0; i < concurrency; i++ {
+ start := i * subSliceSize
+ end := start + subSliceSize
+ // Add remainder to last goroutine
+ if i == concurrency-1 {
+ end += rem
+ }
+
+ wg.Add(1)
+ go func(subSlice []*Document) {
+ defer wg.Done()
+ for _, doc := range subSlice {
+ // Stop work if another goroutine encountered an error.
+ if ctx.Err() != nil {
+ return
+ }
+
+ // As the vectors are normalized, the dot product is the cosine similarity.
+ sim, err := dotProduct(queryVectors, doc.Embedding)
+ if err != nil {
+ setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
+ return
+ }
+
+ if negativeFilterThreshold > 0 {
+ nsim, err := dotProduct(negativeVector, doc.Embedding)
+ if err != nil {
+ setSharedErr(fmt.Errorf("couldn't calculate negative similarity for document '%s': %w", doc.ID, err))
+ return
+ }
+
+ if nsim > negativeFilterThreshold {
+ continue
+ }
+ }
+
+ nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
+ }
+ }(docs[start:end])
+ }
+
+ wg.Wait()
+
+ if sharedErr != nil {
+ return nil, sharedErr
+ }
+
+ return nMaxDocs.values(), nil
+}
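
An in-package sketch of the top-n heap behavior above (illustrative only;
docSim and maxDocSims are unexported, so this would live in a test inside
package rag):

	top := newMaxDocSims(2)
	top.add(docSim{docID: "a", similarity: 0.1})
	top.add(docSim{docID: "b", similarity: 0.9})
	top.add(docSim{docID: "c", similarity: 0.5}) // evicts "a", the current minimum
	for _, s := range top.values() {
		fmt.Println(s.docID, s.similarity) // "b 0.9", then "c 0.5"
	}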
diff --git a/rag/vector.go b/rag/vector.go
new file mode 100644
index 0000000..7a69a1b
--- /dev/null
+++ b/rag/vector.go
@@ -0,0 +1,62 @@
+package rag
+
+import (
+ "errors"
+ "math"
+)
+
+const isNormalizedPrecisionTolerance = 1e-6
+
+// dotProduct calculates the dot product between two vectors.
+// It's the same as cosine similarity for normalized vectors.
+// The resulting value represents the similarity, so a higher value means the
+// vectors are more similar.
+func dotProduct(a, b []float32) (float32, error) {
+ // The vectors must have the same length
+ if len(a) != len(b) {
+ return 0, errors.New("vectors must have the same length")
+ }
+
+ var dotProduct float32
+ for i := range a {
+ dotProduct += a[i] * b[i]
+ }
+
+ return dotProduct, nil
+}
+
+func normalizeVector(v []float32) []float32 {
+ var norm float32
+ for _, val := range v {
+ norm += val * val
+ }
+ norm = float32(math.Sqrt(float64(norm)))
+
+ res := make([]float32, len(v))
+ for i, val := range v {
+ res[i] = val / norm
+ }
+
+ return res
+}
+
+// subtractVector returns the element-wise difference a-b as a new vector
+// (it does not modify a); it assumes len(a) == len(b).
+func subtractVector(a, b []float32) []float32 {
+ res := make([]float32, len(a))
+
+ for i := range a {
+ res[i] = a[i] - b[i]
+ }
+
+ return res
+}
+
+// isNormalized checks if the vector is normalized.
+func isNormalized(v []float32) bool {
+ var sqSum float64
+ for _, val := range v {
+ sqSum += float64(val) * float64(val)
+ }
+ magnitude := math.Sqrt(sqSum)
+ return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance
+}