Diffstat (limited to 'rag')
 -rw-r--r--  rag/document.go |  56
 -rw-r--r--  rag/query.go    | 245
 -rw-r--r--  rag/vector.go   |  62
3 files changed, 363 insertions, 0 deletions
diff --git a/rag/document.go b/rag/document.go
new file mode 100644
index 0000000..48c907e
--- /dev/null
+++ b/rag/document.go
@@ -0,0 +1,56 @@
+package rag
+
+import (
+	"context"
+	"errors"
+)
+
+type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
+
+func NewEmbeddingFuncDefault() EmbeddingFunc {
+	return func(context.Context, string) ([]float32, error) { return nil, errors.New("no default embedding function configured") }
+}
+
+// Document represents a single document.
+type Document struct {
+	ID        string
+	Metadata  map[string]string
+	Embedding []float32
+	Content   string
+}
+
+// NewDocument creates a new document, including its embeddings.
+// Metadata is optional.
+// If the embeddings are not provided, they are created using the embedding function.
+// You can leave the content empty if you only want to store embeddings.
+// If embeddingFunc is nil, the default embedding function is used.
+//
+// If you want to create a document without embeddings, for example to let [Collection.AddDocuments]
+// create them concurrently, you can create a document with `Document{...}`
+// instead of using this constructor.
+func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) {
+	if id == "" {
+		return Document{}, errors.New("id is empty")
+	}
+	if len(embedding) == 0 && content == "" {
+		return Document{}, errors.New("either embedding or content must be filled")
+	}
+	if embeddingFunc == nil {
+		embeddingFunc = NewEmbeddingFuncDefault()
+	}
+
+	if len(embedding) == 0 {
+		var err error
+		embedding, err = embeddingFunc(ctx, content)
+		if err != nil {
+			return Document{}, err
+		}
+	}
+
+	return Document{
+		ID:        id,
+		Metadata:  metadata,
+		Embedding: embedding,
+		Content:   content,
+	}, nil
+}
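For orientation, a minimal usage sketch of NewDocument. This is hypothetical and not part of the change: exampleNewDocument and fakeEmbed are made-up names, fakeEmbed stands in for a real embedding backend, and the snippet assumes it lives in package rag with "context" imported.

    // Hypothetical sketch, same package; assumes: import "context".
    func exampleNewDocument() (Document, error) {
    	// fakeEmbed stands in for a real embedding backend (e.g. an API client).
    	fakeEmbed := func(ctx context.Context, text string) ([]float32, error) {
    		return []float32{0.6, 0.8}, nil // pretend embedding, already normalized
    	}
    	// No embedding is passed (nil), so NewDocument calls fakeEmbed on the content.
    	return NewDocument(context.Background(), "doc-1",
    		map[string]string{"category": "news"}, nil, "some content", fakeEmbed)
    }
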
diff --git a/rag/query.go b/rag/query.go
new file mode 100644
index 0000000..e120a7a
--- /dev/null
+++ b/rag/query.go
@@ -0,0 +1,245 @@
+package rag
+
+import (
+	"cmp"
+	"container/heap"
+	"context"
+	"fmt"
+	"runtime"
+	"slices"
+	"strings"
+	"sync"
+)
+
+var supportedFilters = []string{"$contains", "$not_contains"}
+
+type docSim struct {
+	docID      string
+	similarity float32
+}
+
+// docMaxHeap is a min-heap of docSims ordered by similarity; its root is the smallest kept similarity, so it can be evicted cheaply.
+// See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap
+type docMaxHeap []docSim
+
+func (h docMaxHeap) Len() int           { return len(h) }
+func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
+func (h docMaxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+
+func (h *docMaxHeap) Push(x any) {
+	// Push and Pop use pointer receivers because they modify the slice's length,
+	// not just its contents.
+	*h = append(*h, x.(docSim))
+}
+
+func (h *docMaxHeap) Pop() any {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// maxDocSims keeps the n highest similarities in a min-heap of fixed size.
+// It's safe for concurrent use, but the result of values() is not.
+// In our benchmarks this was faster than sorting a slice of docSims at the end.
+type maxDocSims struct {
+	h    docMaxHeap
+	lock sync.RWMutex
+	size int
+}
+
+// newMaxDocSims creates a new maxDocSims with a fixed size.
+func newMaxDocSims(size int) *maxDocSims {
+	return &maxDocSims{
+		h:    make(docMaxHeap, 0, size),
+		size: size,
+	}
+}
+
+// add inserts a new docSim into the heap, keeping only the top n similarities.
+func (d *maxDocSims) add(doc docSim) {
+	d.lock.Lock()
+	defer d.lock.Unlock()
+	if d.h.Len() < d.size {
+		heap.Push(&d.h, doc)
+	} else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity {
+		// Replace the smallest similarity if the new doc's similarity is higher.
+		heap.Pop(&d.h)
+		heap.Push(&d.h, doc)
+	}
+}
+
+// values returns the docSims in the heap, sorted by similarity (descending).
+// It takes the write lock because the in-place sort mutates the heap's slice.
+// Only work with the result after all calls to add() have finished.
+func (d *maxDocSims) values() []docSim {
+	d.lock.Lock()
+	defer d.lock.Unlock()
+	slices.SortFunc(d.h, func(i, j docSim) int {
+		return cmp.Compare(j.similarity, i.similarity)
+	})
+	return d.h
+}
+
+// filterDocs filters a map of documents by metadata and content.
+// It does this concurrently.
+func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
+	filteredDocs := make([]*Document, 0, len(docs))
+	filteredDocsLock := sync.Mutex{}
+
+	// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
+	numCPUs := runtime.NumCPU()
+	numDocs := len(docs)
+	concurrency := numCPUs
+	if numDocs < numCPUs {
+		concurrency = numDocs
+	}
+
+	docChan := make(chan *Document, concurrency*2)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < concurrency; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for doc := range docChan {
+				if documentMatchesFilters(doc, where, whereDocument) {
+					filteredDocsLock.Lock()
+					filteredDocs = append(filteredDocs, doc)
+					filteredDocsLock.Unlock()
+				}
+			}
+		}()
+	}
+
+	for _, doc := range docs {
+		docChan <- doc
+	}
+	close(docChan)
+
+	wg.Wait()
+
+	// filteredDocs was initialized with a potentially large capacity, so return
+	// nil instead of an empty slice.
+	if len(filteredDocs) == 0 {
+		filteredDocs = nil
+	}
+	return filteredDocs
+}
+
+// documentMatchesFilters checks if a document matches the given filters.
+// When calling this function, the whereDocument keys must already be validated!
+func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
+	// A document's metadata must have *all* the fields in the where clause.
+	for k, v := range where {
+		// TODO: Do we want to check for existence of the key? I.e. should
+		// a where clause with empty string as value match a document's
+		// metadata that doesn't have the key at all?
+		if document.Metadata[k] != v {
+			return false
+		}
+	}
+
+	// A document must satisfy *all* filters, until we support the `$or` operator.
+	for k, v := range whereDocument {
+		switch k {
+		case "$contains":
+			if !strings.Contains(document.Content, v) {
+				return false
+			}
+		case "$not_contains":
+			if strings.Contains(document.Content, v) {
+				return false
+			}
+		default:
+			// No handling (error) required because we already validated the
+			// operators. This simplifies the concurrency logic (no err var
+			// and lock, no context to cancel).
+		}
+	}
+
+	return true
+}
+
+func getMostSimilarDocs(ctx context.Context, queryVectors, negativeVector []float32, negativeFilterThreshold float32, docs []*Document, n int) ([]docSim, error) {
+	nMaxDocs := newMaxDocSims(n)
+
+	// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
+	numDocs := len(docs)
+	concurrency := min(runtime.NumCPU(), numDocs)
+	// Guard against division by zero below when there are no documents.
+	if concurrency == 0 {
+		return nil, nil
+	}
+
+	var sharedErr error
+	sharedErrLock := sync.Mutex{}
+	ctx, cancel := context.WithCancelCause(ctx)
+	defer cancel(nil)
+	setSharedErr := func(err error) {
+		sharedErrLock.Lock()
+		defer sharedErrLock.Unlock()
+		// Another goroutine might have already set the error.
+		if sharedErr == nil {
+			sharedErr = err
+			// Cancel the operation for all other goroutines.
+			cancel(sharedErr)
+		}
+	}
+
+	wg := sync.WaitGroup{}
+	// Instead of using a channel to pass documents into the goroutines, we just
+	// split the slice into sub-slices and pass those to the goroutines.
+	// This turned out to be faster in the query benchmarks.
+	subSliceSize := len(docs) / concurrency // Integer division can leave a remainder, e.g. 10/3 = 3 with remainder 1.
+	rem := len(docs) % concurrency
+	for i := 0; i < concurrency; i++ {
+		start := i * subSliceSize
+		end := start + subSliceSize
+		// Add the remainder to the last goroutine's sub-slice.
+		if i == concurrency-1 {
+			end += rem
+		}
+
+		wg.Add(1)
+		go func(subSlice []*Document) {
+			defer wg.Done()
+			for _, doc := range subSlice {
+				// Stop work if another goroutine encountered an error.
+				if ctx.Err() != nil {
+					return
+				}
+
+				// As the vectors are normalized, the dot product is the cosine similarity.
+				sim, err := dotProduct(queryVectors, doc.Embedding)
+				if err != nil {
+					setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
+					return
+				}
+
+				if negativeFilterThreshold > 0 {
+					nsim, err := dotProduct(negativeVector, doc.Embedding)
+					if err != nil {
+						setSharedErr(fmt.Errorf("couldn't calculate negative similarity for document '%s': %w", doc.ID, err))
+						return
+					}
+
+					if nsim > negativeFilterThreshold {
+						continue
+					}
+				}
+
+				nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
+			}
+		}(docs[start:end])
+	}
+
+	wg.Wait()
+
+	if sharedErr != nil {
+		return nil, sharedErr
+	}
+
+	return nMaxDocs.values(), nil
+}
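To make the filter semantics concrete: metadata must match every `where` entry, and content is checked against the already-validated `$contains`/`$not_contains` operators. A hypothetical in-package sketch (exampleFilter and the sample documents are made up, not part of this change):

    // Hypothetical sketch, same package. Both filter kinds must pass.
    func exampleFilter() []*Document {
    	docs := map[string]*Document{
    		"a": {ID: "a", Metadata: map[string]string{"lang": "go"}, Content: "heaps and goroutines"},
    		"b": {ID: "b", Metadata: map[string]string{"lang": "go"}, Content: "plain prose"},
    	}
    	// Keeps only "a": both docs match the metadata filter,
    	// but only "a" contains the substring "heap".
    	return filterDocs(docs,
    		map[string]string{"lang": "go"},        // where (metadata)
    		map[string]string{"$contains": "heap"}, // whereDocument (content)
    	)
    }
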
diff --git a/rag/vector.go b/rag/vector.go
new file mode 100644
index 0000000..7a69a1b
--- /dev/null
+++ b/rag/vector.go
@@ -0,0 +1,62 @@
+package rag
+
+import (
+	"errors"
+	"math"
+)
+
+const isNormalizedPrecisionTolerance = 1e-6
+
+// dotProduct calculates the dot product between two vectors.
+// It's the same as cosine similarity for normalized vectors.
+// The resulting value represents the similarity, so a higher value means the
+// vectors are more similar.
+func dotProduct(a, b []float32) (float32, error) {
+	// The vectors must have the same length.
+	if len(a) != len(b) {
+		return 0, errors.New("vectors must have the same length")
+	}
+
+	var dotProduct float32
+	for i := range a {
+		dotProduct += a[i] * b[i]
+	}
+
+	return dotProduct, nil
+}
+
+func normalizeVector(v []float32) []float32 {
+	var norm float32
+	for _, val := range v {
+		norm += val * val
+	}
+	norm = float32(math.Sqrt(float64(norm)))
+
+	res := make([]float32, len(v))
+	for i, val := range v {
+		res[i] = val / norm
+	}
+
+	return res
+}
+
+// subtractVector returns the element-wise difference a-b as a new vector.
+func subtractVector(a, b []float32) []float32 {
+	res := make([]float32, len(a))
+
+	for i := range a {
+		res[i] = a[i] - b[i]
+	}
+
+	return res
+}
+
+// isNormalized checks if the vector is normalized.
+func isNormalized(v []float32) bool {
+	var sqSum float64
+	for _, val := range v {
+		sqSum += float64(val) * float64(val)
+	}
+	magnitude := math.Sqrt(sqSum)
+	return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance
+}
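As a quick sanity check of the math: (3, 4) has length 5, so it normalizes to (0.6, 0.8), and the dot product of that vector with itself is 0.36 + 0.64 = 1. A hypothetical in-package sketch (exampleVectorMath is a made-up name):

    // Hypothetical sketch, same package; assumes: import "fmt".
    func exampleVectorMath() {
    	v := normalizeVector([]float32{3, 4}) // -> (0.6, 0.8)
    	sim, err := dotProduct(v, v)
    	if err != nil {
    		panic(err) // unreachable here: both arguments have the same length
    	}
    	fmt.Println(sim, isNormalized(v)) // prints a value ~1 and true
    }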

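Taken together, a typical query path filters candidates, then ranks them against a normalized query embedding, optionally dropping anything too close to a negative example. A hedged in-package sketch; exampleQuery, its arguments, and the metadata filter are all hypothetical, and it assumes document embeddings are normalized with the same dimensionality as the query and negative vectors:

    // Hypothetical sketch, same package; assumes: import "context".
    func exampleQuery(ctx context.Context, docs map[string]*Document, query, negative []float32) ([]docSim, error) {
    	candidates := filterDocs(docs, map[string]string{"lang": "go"}, nil)
    	return getMostSimilarDocs(ctx,
    		normalizeVector(query),    // query embedding
    		normalizeVector(negative), // docs more similar than 0.5 to this are dropped
    		0.5, candidates, 3)        // threshold, candidates, top-n
    }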