From f40d8afe08c524fc7f9df0dfa0802342af2d2c3d Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Thu, 9 Jan 2025 19:58:08 +0300 Subject: Fix: flow control --- rag/main.go | 59 ++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 23 deletions(-) (limited to 'rag') diff --git a/rag/main.go b/rag/main.go index 58ff448..f1be167 100644 --- a/rag/main.go +++ b/rag/main.go @@ -13,6 +13,7 @@ import ( "os" "path" "strings" + "sync" "github.com/neurosnap/sentences/english" ) @@ -56,7 +57,7 @@ func (r *RAG) LoadRAG(fpath string) error { var ( // TODO: to config workers = 5 - batchSize = 200 + batchSize = 100 maxChSize = 1000 // // psize = 3 @@ -68,7 +69,11 @@ func (r *RAG) LoadRAG(fpath string) error { vectorCh = make(chan []models.VectorRow, maxChSize) errCh = make(chan error, 1) doneCh = make(chan bool, 1) + lock = new(sync.Mutex) ) + defer close(doneCh) + defer close(errCh) + defer close(batchCh) // group sentences paragraphs := []string{} par := strings.Builder{} @@ -100,18 +105,19 @@ func (r *RAG) LoadRAG(fpath string) error { left, right = right, right+batchSize ctn++ } - r.logger.Info("finished batching", "batches#", len(batchCh)) + r.logger.Info("finished batching", "batches#", len(batchCh), "paragraphs", len(paragraphs), "sentences", len(sents)) for w := 0; w < workers; w++ { - go r.batchToVectorHFAsync(len(paragraphs), batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) + go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) } + // wait for emb to be done + <-doneCh // write to db - return r.writeVectors(vectorCh, doneCh) + return r.writeVectors(vectorCh) } -func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan bool) error { +func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { for { - select { - case batch := <-vectorCh: + for batch := range vectorCh { for _, vector := range batch { if err := r.store.WriteVector(&vector); err != nil { r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) @@ -119,36 +125,43 @@ func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan boo // return err } } - r.logger.Info("wrote batch to db", "size", len(batch)) - case <-doneCh: - r.logger.Info("rag finished") - return nil + r.logger.Info("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) + if len(vectorCh) == 0 { + r.logger.Info("finished writing vectors") + defer close(vectorCh) + return nil + } } } } -func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string, +func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { - r.logger.Info("to vector batches", "batches#", len(inputCh)) for { + lock.Lock() + if len(inputCh) == 0 { + if len(doneCh) == 0 { + doneCh <- true + } + lock.Unlock() + return + } select { case linesMap := <-inputCh: for leftI, v := range linesMap { r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename) - if leftI+200 >= limit { // last batch - doneCh <- true - return - } + // if leftI+200 >= limit { // last batch + // // doneCh <- true + // return + // } } - case <-doneCh: - r.logger.Info("got done") - close(errCh) - close(doneCh) - return + lock.Unlock() case err := <-errCh: r.logger.Error("got an error", "error", err) + lock.Unlock() return } + r.logger.Info("to vector batches", "batches#", len(inputCh), "worker#", id) } } @@ -202,7 +215,7 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod vector := models.VectorRow{ Embeddings: e, RawText: lines[i], - Slug: slug, + Slug: fmt.Sprintf("%s_%d", slug, i), FileName: filename, } vectors[i] = vector -- cgit v1.2.3