| author | Grail Finder <wohilas@gmail.com> | 2025-01-09 15:49:59 +0300 | 
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2025-01-09 15:49:59 +0300 | 
| commit | 363bbae2c756f448d8cdac50305902d68d45c26c (patch) | |
| tree | 4206136d7342006c32e7dbaf2891bce608969427 /rag | |
| parent | 7bbedd93cf078fc7496a6779cf9eda6e588e64c0 (diff) | |
Fix: RAG updates
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/main.go | 70 |
|---|---|---|

1 file changed, 47 insertions, 23 deletions
```diff
diff --git a/rag/main.go b/rag/main.go
index da919b4..58ff448 100644
--- a/rag/main.go
+++ b/rag/main.go
@@ -11,6 +11,8 @@ import (
 	"log/slog"
 	"net/http"
 	"os"
+	"path"
+	"strings"
 
 	"github.com/neurosnap/sentences/english"
 )
@@ -29,6 +31,10 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
 	}
 }
 
+func wordCounter(sentence string) int {
+	return len(strings.Split(sentence, " "))
+}
+
 func (r *RAG) LoadRAG(fpath string) error {
 	data, err := os.ReadFile(fpath)
 	if err != nil {
@@ -53,6 +59,9 @@ func (r *RAG) LoadRAG(fpath string) error {
 		batchSize = 200
 		maxChSize = 1000
 		//
+		// psize     = 3
+		wordLimit = 80
+		//
 		left     = 0
 		right    = batchSize
 		batchCh  = make(chan map[int][]string, maxChSize)
@@ -60,23 +69,40 @@
 		errCh    = make(chan error, 1)
 		doneCh   = make(chan bool, 1)
 	)
-	if len(sents) < batchSize {
-		batchSize = len(sents)
+	// group sentences
+	paragraphs := []string{}
+	par := strings.Builder{}
+	for i := 0; i < len(sents); i++ {
+		par.WriteString(sents[i])
+		if wordCounter(par.String()) > wordLimit {
+			paragraphs = append(paragraphs, par.String())
+			par.Reset()
+		}
+	}
+	// for i := 0; i < len(sents); i += psize {
+	// 	if len(sents) < i+psize {
+	// 		paragraphs = append(paragraphs, strings.Join(sents[i:], " "))
+	// 		break
+	// 	}
+	// 	paragraphs = append(paragraphs, strings.Join(sents[i:i+psize], " "))
+	// }
+	if len(paragraphs) < batchSize {
+		batchSize = len(paragraphs)
 	}
 	// fill input channel
 	ctn := 0
 	for {
-		if right > len(sents) {
-			batchCh <- map[int][]string{left: sents[left:]}
+		if right > len(paragraphs) {
+			batchCh <- map[int][]string{left: paragraphs[left:]}
 			break
 		}
-		batchCh <- map[int][]string{left: sents[left:right]}
+		batchCh <- map[int][]string{left: paragraphs[left:right]}
 		left, right = right, right+batchSize
 		ctn++
 	}
 	r.logger.Info("finished batching", "batches#", len(batchCh))
 	for w := 0; w < workers; w++ {
-		go r.batchToVectorHFAsync(len(sents), batchCh, vectorCh, errCh, doneCh)
+		go r.batchToVectorHFAsync(len(paragraphs), batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
 	}
 	// write to db
 	return r.writeVectors(vectorCh, doneCh)
@@ -102,20 +128,17 @@ func (r *RAG) writeVectors(vectorCh <-chan []models.VectorRow, doneCh <-chan boo
 }
 
 func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string,
-	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool) {
+	vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
 	r.logger.Info("to vector batches", "batches#", len(inputCh))
 	for {
 		select {
 		case linesMap := <-inputCh:
-			// r.logger.Info("batch from ch")
 			for leftI, v := range linesMap {
-				// r.logger.Info("fetching", "index", leftI)
-				r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("test_%d", leftI))
+				r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename)
 				if leftI+200 >= limit { // last batch
 					doneCh <- true
 					return
 				}
-				// r.logger.Info("done feitching", "index", leftI)
 			}
 		case <-doneCh:
 			r.logger.Info("got done")
@@ -129,7 +152,7 @@ func (r *RAG) batchToVectorHFAsync(limit int, inputCh <-chan map[int][]string,
 	}
 }
 
-func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug string) {
+func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) {
 	payload, err := json.Marshal(
 		map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}},
 	)
@@ -138,13 +161,14 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
 		errCh <- err
 		return
 	}
+	// nolint
 	req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload))
 	if err != nil {
 		r.logger.Error("failed to create new req", "err:", err.Error())
 		errCh <- err
 		return
 	}
-	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken))
+	req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
 	resp, err := http.DefaultClient.Do(req)
 	// nolint
 	// resp, err := httpClient.Post(cfg.EmbedURL, "application/json", bytes.NewReader(payload))
@@ -179,6 +203,7 @@ func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []mod
 			Embeddings: e,
 			RawText:    lines[i],
 			Slug:       slug,
+			FileName:   filename,
 		}
 		vectors[i] = vector
 	}
@@ -201,7 +226,7 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
 		r.logger.Error("failed to create new req", "err:", err.Error())
 		return nil, err
 	}
-	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.cfg.HFToken))
+	req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken)
 	resp, err := http.DefaultClient.Do(req)
 	// resp, err := req.Post(r.cfg.EmbedURL, "application/json", bytes.NewReader(payload))
 	if err != nil {
@@ -228,15 +253,14 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
 	return emb[0], nil
 }
 
-// func (r *RAG) saveLine(topic, line string, emb *models.EmbeddingResp) error {
-// 	row := &models.VectorRow{
-// 		Embeddings: emb.Embedding,
-// 		Slug:       topic,
-// 		RawText:    line,
-// 	}
-// 	return r.store.WriteVector(row)
-// }
-
 func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
 	return r.store.SearchClosest(emb.Embedding)
 }
+
+func (r *RAG) ListLoaded() ([]string, error) {
+	return r.store.ListFiles()
+}
+
+func (r *RAG) RemoveFile(filename string) error {
+	return r.store.RemoveEmbByFileName(filename)
+}
```
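The patch replaces fixed-size sentence batches in `LoadRAG` with word-limited chunks: sentences accumulate in a `strings.Builder` until the buffer exceeds `wordLimit` (80) words, and each flushed buffer becomes one retrieval unit. Below is a minimal standalone sketch of that grouping. `groupSentences` is a hypothetical wrapper added here for illustration; the space separator between sentences and the flush of the final short chunk are choices of this sketch, not taken from the patch, while `wordCounter` mirrors the helper introduced in the diff.

```go
package main

import (
	"fmt"
	"strings"
)

// wordCounter mirrors the helper added in the patch: it counts
// space-separated tokens in a sentence.
func wordCounter(sentence string) int {
	return len(strings.Split(sentence, " "))
}

// groupSentences is an illustrative stand-in for the grouping loop in
// LoadRAG: sentences are appended to a builder until the accumulated text
// exceeds wordLimit words, then the buffer is emitted as one chunk.
func groupSentences(sents []string, wordLimit int) []string {
	paragraphs := []string{}
	par := strings.Builder{}
	for _, s := range sents {
		if par.Len() > 0 {
			par.WriteString(" ") // separate sentences so the word count stays meaningful
		}
		par.WriteString(s)
		if wordCounter(par.String()) > wordLimit {
			paragraphs = append(paragraphs, par.String())
			par.Reset()
		}
	}
	// keep whatever is left below the limit as a final chunk
	if par.Len() > 0 {
		paragraphs = append(paragraphs, par.String())
	}
	return paragraphs
}

func main() {
	sents := []string{
		"RAG chunks are built from sentences.",
		"Each chunk is capped at a word limit.",
		"Short sentences therefore get merged together.",
	}
	for i, p := range groupSentences(sents, 10) {
		fmt.Printf("chunk %d (%d words): %s\n", i, wordCounter(p), p)
	}
}
```

The commented-out `psize` loop left in the diff shows the earlier idea of grouping a fixed number of sentences; the word-limit approach keeps chunks closer to a uniform size regardless of sentence length before they are batched and sent to the embedding endpoint.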
