diff options
Diffstat (limited to 'rag_new/rag.go')
| -rw-r--r-- | rag_new/rag.go | 260 | 
1 files changed, 260 insertions, 0 deletions
| diff --git a/rag_new/rag.go b/rag_new/rag.go new file mode 100644 index 0000000..d012087 --- /dev/null +++ b/rag_new/rag.go @@ -0,0 +1,260 @@ +package rag_new + +import ( +	"gf-lt/config" +	"gf-lt/models" +	"gf-lt/storage" +	"fmt" +	"log/slog" +	"os" +	"path" +	"strings" +	"sync" + +	"github.com/neurosnap/sentences/english" +) + +var ( +	// Status messages for TUI integration +	LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking +	FinishedRAGStatus   = "finished loading RAG file; press Enter" +	LoadedFileRAGStatus = "loaded file" +	ErrRAGStatus        = "some error occurred; failed to transfer data to vector db" +) + +type RAG struct { +	logger *slog.Logger +	store  storage.FullRepo +	cfg    *config.Config +	embedder Embedder +	storage *VectorStorage +} + +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { +	// Initialize with API embedder by default, could be configurable later +	embedder := NewAPIEmbedder(l, cfg) +	 +	rag := &RAG{ +		logger: l, +		store:  s, +		cfg:    cfg, +		embedder: embedder, +		storage: NewVectorStorage(l, s), +	} +	 +	// Create the necessary tables +	if err := rag.storage.CreateTables(); err != nil { +		l.Error("failed to create vector tables", "error", err) +	} +	 +	return rag +} + +func wordCounter(sentence string) int { +	return len(strings.Split(strings.TrimSpace(sentence), " ")) +} + +func (r *RAG) LoadRAG(fpath string) error { +	data, err := os.ReadFile(fpath) +	if err != nil { +		return err +	} +	r.logger.Debug("rag: loaded file", "fp", fpath) +	LongJobStatusCh <- LoadedFileRAGStatus +	 +	fileText := string(data) +	tokenizer, err := english.NewSentenceTokenizer(nil) +	if err != nil { +		return err +	} +	sentences := tokenizer.Tokenize(fileText) +	sents := make([]string, len(sentences)) +	for i, s := range sentences { +		sents[i] = s.Text +	} +	 +	// Group sentences into paragraphs based on word limit +	paragraphs := []string{} +	par := strings.Builder{} +	for i := 0; i < len(sents); i++ { +		// Only add sentences that aren't empty +		if strings.TrimSpace(sents[i]) != "" { +			if par.Len() > 0 { +				par.WriteString(" ") // Add space between sentences +			} +			par.WriteString(sents[i]) +		} +		 +		if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { +			paragraph := strings.TrimSpace(par.String()) +			if paragraph != "" { +				paragraphs = append(paragraphs, paragraph) +			} +			par.Reset() +		} +	} +	 +	// Handle any remaining content in the paragraph buffer +	if par.Len() > 0 { +		paragraph := strings.TrimSpace(par.String()) +		if paragraph != "" { +			paragraphs = append(paragraphs, paragraph) +		} +	} +	 +	// Adjust batch size if needed +	if len(paragraphs) < int(r.cfg.RAGBatchSize) && len(paragraphs) > 0 { +		r.cfg.RAGBatchSize = len(paragraphs) +	} +	 +	if len(paragraphs) == 0 { +		return fmt.Errorf("no valid paragraphs found in file") +	} +	 +	var ( +		maxChSize = 100 +		left      = 0 +		right     = r.cfg.RAGBatchSize +		batchCh   = make(chan map[int][]string, maxChSize) +		vectorCh  = make(chan []models.VectorRow, maxChSize) +		errCh     = make(chan error, 1) +		doneCh    = make(chan bool, 1) +		lock      = new(sync.Mutex) +	) +	 +	defer close(doneCh) +	defer close(errCh) +	defer close(batchCh) +	 +	// Fill input channel with batches +	ctn := 0 +	totalParagraphs := len(paragraphs) +	for { +		if int(right) > totalParagraphs { +			batchCh <- map[int][]string{left: paragraphs[left:]} +			break +		} +		batchCh <- map[int][]string{left: paragraphs[left:right]} +		left, right = right, right+r.cfg.RAGBatchSize +		ctn++ +	} +	 +	finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents)) +	r.logger.Debug(finishedBatchesMsg) +	LongJobStatusCh <- finishedBatchesMsg +	 +	// Start worker goroutines +	for w := 0; w < int(r.cfg.RAGWorkers); w++ { +		go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) +	} +	 +	// Wait for embedding to be done +	<-doneCh +	 +	// Write vectors to storage +	return r.writeVectors(vectorCh) +} + +func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { +	for { +		for batch := range vectorCh { +			for _, vector := range batch { +				if err := r.storage.WriteVector(&vector); err != nil { +					r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) +					LongJobStatusCh <- ErrRAGStatus +					continue // a duplicate is not critical +				} +			} +			r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) +			if len(vectorCh) == 0 { +				r.logger.Debug("finished writing vectors") +				LongJobStatusCh <- FinishedRAGStatus +				return nil +			} +		} +	} +} + +func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, +	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { +	defer func() { +		if len(doneCh) == 0 { +			doneCh <- true +		} +	}() +	 +	for { +		lock.Lock() +		if len(inputCh) == 0 { +			lock.Unlock() +			return +		} +		 +		select { +		case linesMap := <-inputCh: +			for leftI, lines := range linesMap { +				if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil { +					r.logger.Error("error fetching embeddings", "error", err, "worker", id) +					lock.Unlock() +					return +				} +			} +			lock.Unlock() +		case err := <-errCh: +			r.logger.Error("got an error from error channel", "error", err) +			lock.Unlock() +			return +		default: +			lock.Unlock() +		} +		 +		r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) +		LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) +	} +} + +func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error { +	embeddings, err := r.embedder.Embed(lines) +	if err != nil { +		r.logger.Error("failed to embed lines", "err", err.Error()) +		errCh <- err +		return err +	} +	 +	if len(embeddings) == 0 { +		err := fmt.Errorf("no embeddings returned") +		r.logger.Error("empty embeddings") +		errCh <- err +		return err +	} +	 +	vectors := make([]models.VectorRow, len(embeddings)) +	for i, emb := range embeddings { +		vector := models.VectorRow{ +			Embeddings: emb, +			RawText:    lines[i], +			Slug:       fmt.Sprintf("%s_%d", slug, i), +			FileName:   filename, +		} +		vectors[i] = vector +	} +	 +	vectorCh <- vectors +	return nil +} + +func (r *RAG) LineToVector(line string) ([]float32, error) { +	return r.embedder.EmbedSingle(line) +} + +func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { +	return r.storage.SearchClosest(emb.Embedding) +} + +func (r *RAG) ListLoaded() ([]string, error) { +	return r.storage.ListFiles() +} + +func (r *RAG) RemoveFile(filename string) error { +	return r.storage.RemoveEmbByFileName(filename) +}
\ No newline at end of file | 
