diff options
Diffstat (limited to 'storage/vector.go')
| -rw-r--r-- | storage/vector.go | 38 |
1 files changed, 15 insertions, 23 deletions
diff --git a/storage/vector.go b/storage/vector.go index 75f5c9a..e3bbb89 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "gf-lt/models" + "sort" "unsafe" "github.com/jmoiron/sqlx" @@ -11,7 +12,7 @@ import ( type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q []float32) ([]models.VectorRow, error) + SearchClosest(q []float32, limit int) ([]models.VectorRow, error) ListFiles() ([]string, error) RemoveEmbByFileName(filename string) error DB() *sqlx.DB @@ -79,7 +80,7 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { return err } -func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { +func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) { tableName, err := fetchTableName(q) if err != nil { return nil, err @@ -94,7 +95,7 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { vector models.VectorRow distance float32 } - var topResults []SearchResult + var allResults []SearchResult for rows.Next() { var ( embeddingsBlob []byte @@ -119,28 +120,19 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { }, distance: distance, } - - // Add to top results and maintain only top results - topResults = append(topResults, result) - - // Sort and keep only top results - // We'll keep the top 3 closest vectors - if len(topResults) > 3 { - // Simple sort and truncate to maintain only 3 best matches - for i := 0; i < len(topResults); i++ { - for j := i + 1; j < len(topResults); j++ { - if topResults[i].distance > topResults[j].distance { - topResults[i], topResults[j] = topResults[j], topResults[i] - } - } - } - topResults = topResults[:3] - } + allResults = append(allResults, result) + } + // Sort by distance + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].distance < allResults[j].distance + }) + // Truncate to limit + if len(allResults) > limit { + allResults = allResults[:limit] } - // Convert back to VectorRow slice - results := make([]models.VectorRow, len(topResults)) - for i, result := range topResults { + results := make([]models.VectorRow, len(allResults)) + for i, result := range allResults { result.vector.Distance = result.distance results[i] = result.vector } |
