diff options
| field | value | |
|---|---|---|
| author | Grail Finder <wohilas@gmail.com> | 2025-12-10 15:07:06 +0300 |
| committer | Grail Finder <wohilas@gmail.com> | 2025-12-10 15:07:06 +0300 |
| commit | ad36d1c3e0b545c3e0517ec384087075ad77f63e (patch) | |
| tree | 9a5208f1e9c1b918949bb9322bfcf3abc899c15b | |
| parent | 8af2a59a9a12667bae2ce138259d432ba81f8e03 (diff) | |
Fix: rag panics
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | bot.go | 2 |
| -rw-r--r-- | config.example.toml | 4 |
| -rw-r--r-- | rag/rag.go | 87 |

3 files changed, 70 insertions, 23 deletions
@@ -429,13 +429,11 @@ func chatRagUse(qText string) (string, error) { logger.Error("failed to get embs", "error", err, "index", i, "question", q) continue } - // Create EmbeddingResp struct for the search embeddingResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, // Not used in search but required for the struct } - vecs, err := ragger.SearchEmb(embeddingResp) if err != nil { logger.Error("failed to query embs", "error", err, "index", i, "question", q) diff --git a/config.example.toml b/config.example.toml index 47e4408..113b7ea 100644 --- a/config.example.toml +++ b/config.example.toml @@ -10,7 +10,7 @@ DeepSeekModel = "deepseek-reasoner" OpenRouterCompletionAPI = "https://openrouter.ai/api/v1/completions" OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions" # OpenRouterToken = "" -EmbedURL = "http://localhost:8080/v1/embeddings" +EmbedURL = "http://localhost:8082/v1/embeddings" ShowSys = true LogFile = "log.txt" UserRole = "user" @@ -19,7 +19,7 @@ AssistantRole = "assistant" SysDir = "sysprompts" ChunkLimit = 100000 # rag settings -RAGBatchSize = 10 +RAGBatchSize = 1 RAGWordLimit = 80 RAGWorkers = 2 RAGDir = "ragimport" @@ -23,6 +23,7 @@ var ( ErrRAGStatus = "some error occurred; failed to transfer data to vector db" ) + type RAG struct { logger *slog.Logger store storage.FullRepo @@ -58,7 +59,12 @@ func (r *RAG) LoadRAG(fpath string) error { return err } r.logger.Debug("rag: loaded file", "fp", fpath) - LongJobStatusCh <- LoadedFileRAGStatus + select { + case LongJobStatusCh <- LoadedFileRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } fileText := string(data) tokenizer, err := english.NewSentenceTokenizer(nil) @@ -116,11 +122,10 @@ func (r *RAG) LoadRAG(fpath string) error { batchCh = make(chan map[int][]string, maxChSize) vectorCh = make(chan []models.VectorRow, maxChSize) errCh 
= make(chan error, 1) - doneCh = make(chan bool, 1) + wg = new(sync.WaitGroup) lock = new(sync.Mutex) ) - defer close(doneCh) defer close(errCh) defer close(batchCh) @@ -139,19 +144,41 @@ func (r *RAG) LoadRAG(fpath string) error { finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents)) r.logger.Debug(finishedBatchesMsg) - LongJobStatusCh <- finishedBatchesMsg + select { + case LongJobStatusCh <- finishedBatchesMsg: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg) + // Channel is full or closed, ignore the message to prevent panic + } - // Start worker goroutines + // Start worker goroutines with WaitGroup + wg.Add(int(r.cfg.RAGWorkers)) for w := 0; w < int(r.cfg.RAGWorkers); w++ { - go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) + go func(workerID int) { + defer wg.Done() + r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath)) + }(w) } - // Wait for embedding to be done - <-doneCh - err = <-errCh - if err != nil { - return err + + // Use a goroutine to close the batchCh when all batches are sent + go func() { + wg.Wait() + close(vectorCh) // Close vectorCh when all workers are done + }() + + // Check for errors from workers + // Use a non-blocking check for errors + select { + case err := <-errCh: + if err != nil { + r.logger.Error("error during RAG processing", "error", err) + return err + } + default: + // No immediate error, continue } - // Write vectors to storage + + // Write vectors to storage - this will block until vectorCh is closed return r.writeVectors(vectorCh) } @@ -161,14 +188,24 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { for _, vector := range batch { if err := r.storage.WriteVector(&vector); err != nil { r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) - 
LongJobStatusCh <- ErrRAGStatus + select { + case LongJobStatusCh <- ErrRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } return err // Stop the entire RAG operation on DB error } } r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) if len(vectorCh) == 0 { r.logger.Debug("finished writing vectors") - LongJobStatusCh <- FinishedRAGStatus + select { + case LongJobStatusCh <- FinishedRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } return nil } } @@ -176,12 +213,18 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { } func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, - vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { + vectorCh chan<- []models.VectorRow, errCh chan error, filename string) { var err error + defer func() { - if len(doneCh) == 0 { - doneCh <- true - errCh <- err + // For errCh, make sure we only send if there's actually an error and the channel can accept it + if err != nil { + select { + case errCh <- err: + default: + // errCh might be full or closed, log but don't panic + r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err) + } } }() @@ -211,7 +254,13 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in } r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) - LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) + statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) + select { + case LongJobStatusCh <- statusMsg: + 
default: + r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) + // Channel is full or closed, ignore the message to prevent panic + } } } |
