summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rag/rag.go103
1 files changed, 52 insertions, 51 deletions
diff --git a/rag/rag.go b/rag/rag.go
index d8b6978..71c4ce8 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -23,7 +23,6 @@ var (
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
)
-
type RAG struct {
logger *slog.Logger
store storage.FullRepo
@@ -122,10 +121,11 @@ func (r *RAG) LoadRAG(fpath string) error {
batchCh = make(chan map[int][]string, maxChSize)
vectorCh = make(chan []models.VectorRow, maxChSize)
errCh = make(chan error, 1)
+ doneCh = make(chan struct{})
wg = new(sync.WaitGroup)
- lock = new(sync.Mutex)
)
+ defer close(doneCh)
defer close(errCh)
defer close(batchCh)
@@ -156,18 +156,20 @@ func (r *RAG) LoadRAG(fpath string) error {
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go func(workerID int) {
defer wg.Done()
- r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
+ r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}(w)
}
- // Use a goroutine to close the batchCh when all batches are sent
+ // Close batchCh to signal workers no more data is coming
+ close(batchCh)
+
+ // Wait for all workers to finish, then close vectorCh
go func() {
wg.Wait()
- close(vectorCh) // Close vectorCh when all workers are done
+ close(vectorCh)
}()
- // Check for errors from workers
- // Use a non-blocking check for errors
+ // Check for errors from workers - this will block until an error occurs or all workers finish
select {
case err := <-errCh:
if err != nil {
@@ -179,12 +181,28 @@ func (r *RAG) LoadRAG(fpath string) error {
}
// Write vectors to storage - this will block until vectorCh is closed
- return r.writeVectors(vectorCh)
+ return r.writeVectors(vectorCh, errCh)
}
-func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
+func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
+ // Use a select to handle both vectorCh and errCh
for {
- for batch := range vectorCh {
+ select {
+ case err := <-errCh:
+ if err != nil {
+ r.logger.Error("error during RAG processing in writeVectors", "error", err)
+ return err
+ }
+ case batch, ok := <-vectorCh:
+ if !ok {
+ r.logger.Debug("vector channel closed, finished writing vectors")
+ select {
+ case LongJobStatusCh <- FinishedRAGStatus:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
+ }
+ return nil
+ }
for _, vector := range batch {
if err := r.storage.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
@@ -192,74 +210,57 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
case LongJobStatusCh <- ErrRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
- // Channel is full or closed, ignore the message to prevent panic
}
- return err // Stop the entire RAG operation on DB error
- }
- }
- r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
- if len(vectorCh) == 0 {
- r.logger.Debug("finished writing vectors")
- select {
- case LongJobStatusCh <- FinishedRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
- // Channel is full or closed, ignore the message to prevent panic
+ return err
}
- return nil
}
+ r.logger.Debug("wrote batch to db", "size", len(batch))
}
}
}
-func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
- vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
+func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
+ vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
var err error
defer func() {
- // For errCh, make sure we only send if there's actually an error and the channel can accept it
if err != nil {
select {
case errCh <- err:
default:
- // errCh might be full or closed, log but don't panic
r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
}
}
}()
for {
- lock.Lock()
- if len(inputCh) == 0 {
- lock.Unlock()
- return
- }
-
select {
- case linesMap := <-inputCh:
+ case <-doneCh:
+ r.logger.Debug("worker received done signal", "worker", id)
+ return
+ case linesMap, ok := <-inputCh:
+ if !ok {
+ r.logger.Debug("input channel closed, worker exiting", "worker", id)
+ return
+ }
for leftI, lines := range linesMap {
+ select {
+ case <-doneCh:
+ return
+ default:
+ }
if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
r.logger.Error("error fetching embeddings", "error", err, "worker", id)
- lock.Unlock()
return
}
}
- lock.Unlock()
- case err = <-errCh:
- r.logger.Error("got an error from error channel", "error", err)
- lock.Unlock()
- return
- default:
- lock.Unlock()
- }
-
- r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
- statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
- select {
- case LongJobStatusCh <- statusMsg:
- default:
- r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
- // Channel is full or closed, ignore the message to prevent panic
+ r.logger.Debug("processed batch", "worker#", id)
+ statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
+ select {
+ case LongJobStatusCh <- statusMsg:
+ default:
+ r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
+ }
}
}
}