summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2025-12-10 15:07:06 +0300
committerGrail Finder <wohilas@gmail.com>2025-12-10 15:07:06 +0300
commitad36d1c3e0b545c3e0517ec384087075ad77f63e (patch)
tree9a5208f1e9c1b918949bb9322bfcf3abc899c15b
parent8af2a59a9a12667bae2ce138259d432ba81f8e03 (diff)
Fix: rag panics
-rw-r--r--bot.go2
-rw-r--r--config.example.toml4
-rw-r--r--rag/rag.go87
3 files changed, 70 insertions, 23 deletions
diff --git a/bot.go b/bot.go
index 1988969..5c8ca1e 100644
--- a/bot.go
+++ b/bot.go
@@ -429,13 +429,11 @@ func chatRagUse(qText string) (string, error) {
logger.Error("failed to get embs", "error", err, "index", i, "question", q)
continue
}
-
// Create EmbeddingResp struct for the search
embeddingResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0, // Not used in search but required for the struct
}
-
vecs, err := ragger.SearchEmb(embeddingResp)
if err != nil {
logger.Error("failed to query embs", "error", err, "index", i, "question", q)
diff --git a/config.example.toml b/config.example.toml
index 47e4408..113b7ea 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -10,7 +10,7 @@ DeepSeekModel = "deepseek-reasoner"
OpenRouterCompletionAPI = "https://openrouter.ai/api/v1/completions"
OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions"
# OpenRouterToken = ""
-EmbedURL = "http://localhost:8080/v1/embeddings"
+EmbedURL = "http://localhost:8082/v1/embeddings"
ShowSys = true
LogFile = "log.txt"
UserRole = "user"
@@ -19,7 +19,7 @@ AssistantRole = "assistant"
SysDir = "sysprompts"
ChunkLimit = 100000
# rag settings
-RAGBatchSize = 10
+RAGBatchSize = 1
RAGWordLimit = 80
RAGWorkers = 2
RAGDir = "ragimport"
diff --git a/rag/rag.go b/rag/rag.go
index f50a913..b29b9eb 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -23,6 +23,7 @@ var (
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
)
+
type RAG struct {
logger *slog.Logger
store storage.FullRepo
@@ -58,7 +59,12 @@ func (r *RAG) LoadRAG(fpath string) error {
return err
}
r.logger.Debug("rag: loaded file", "fp", fpath)
- LongJobStatusCh <- LoadedFileRAGStatus
+ select {
+ case LongJobStatusCh <- LoadedFileRAGStatus:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
+ // Channel is full or closed, ignore the message to prevent panic
+ }
fileText := string(data)
tokenizer, err := english.NewSentenceTokenizer(nil)
@@ -116,11 +122,10 @@ func (r *RAG) LoadRAG(fpath string) error {
batchCh = make(chan map[int][]string, maxChSize)
vectorCh = make(chan []models.VectorRow, maxChSize)
errCh = make(chan error, 1)
- doneCh = make(chan bool, 1)
+ wg = new(sync.WaitGroup)
lock = new(sync.Mutex)
)
- defer close(doneCh)
defer close(errCh)
defer close(batchCh)
@@ -139,19 +144,41 @@ func (r *RAG) LoadRAG(fpath string) error {
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents))
r.logger.Debug(finishedBatchesMsg)
- LongJobStatusCh <- finishedBatchesMsg
+ select {
+ case LongJobStatusCh <- finishedBatchesMsg:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg)
+ // Channel is full or closed, ignore the message to prevent panic
+ }
- // Start worker goroutines
+ // Start worker goroutines with WaitGroup
+ wg.Add(int(r.cfg.RAGWorkers))
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
- go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
+ go func(workerID int) {
+ defer wg.Done()
+ r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath))
+ }(w)
}
- // Wait for embedding to be done
- <-doneCh
- err = <-errCh
- if err != nil {
- return err
+
+ // Use a goroutine to close the batchCh when all batches are sent
+ go func() {
+ wg.Wait()
+ close(vectorCh) // Close vectorCh when all workers are done
+ }()
+
+ // Check for errors from workers
+ // Use a non-blocking check for errors
+ select {
+ case err := <-errCh:
+ if err != nil {
+ r.logger.Error("error during RAG processing", "error", err)
+ return err
+ }
+ default:
+ // No immediate error, continue
}
- // Write vectors to storage
+
+ // Write vectors to storage - this will block until vectorCh is closed
return r.writeVectors(vectorCh)
}
@@ -161,14 +188,24 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
for _, vector := range batch {
if err := r.storage.WriteVector(&vector); err != nil {
r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
- LongJobStatusCh <- ErrRAGStatus
+ select {
+ case LongJobStatusCh <- ErrRAGStatus:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
+ // Channel is full or closed, ignore the message to prevent panic
+ }
return err // Stop the entire RAG operation on DB error
}
}
r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh))
if len(vectorCh) == 0 {
r.logger.Debug("finished writing vectors")
- LongJobStatusCh <- FinishedRAGStatus
+ select {
+ case LongJobStatusCh <- FinishedRAGStatus:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
+ // Channel is full or closed, ignore the message to prevent panic
+ }
return nil
}
}
@@ -176,12 +213,18 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
}
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
- vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
+ vectorCh chan<- []models.VectorRow, errCh chan error, filename string) {
var err error
+
defer func() {
- if len(doneCh) == 0 {
- doneCh <- true
- errCh <- err
+ // For errCh, make sure we only send if there's actually an error and the channel can accept it
+ if err != nil {
+ select {
+ case errCh <- err:
+ default:
+ // errCh might be full or closed, log but don't panic
+ r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
+ }
}
}()
@@ -211,7 +254,13 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
}
r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id)
- LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
+ statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id)
+ select {
+ case LongJobStatusCh <- statusMsg:
+ default:
+ r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
+ // Channel is full or closed, ignore the message to prevent panic
+ }
}
}