summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llm.go98
-rw-r--r--rag/rag.go308
-rw-r--r--tools.go77
3 files changed, 383 insertions, 100 deletions
diff --git a/llm.go b/llm.go
index 6697dfa..ebda29b 100644
--- a/llm.go
+++ b/llm.go
@@ -11,7 +11,6 @@ import (
var imageAttachmentPath string // Global variable to track image attachment for next message
var lastImg string // for ctrl+j
-var RAGMsg = "Retrieved context for user's query:\n"
// containsToolSysMsg checks if the toolSysMsg already exists in the chat body
func containsToolSysMsg() bool {
@@ -142,22 +141,6 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg)
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("RAG response received", "response_len", len(ragResp),
- "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
- }
// sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -301,23 +284,6 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role,
"content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages))
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("LCPChat: RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("LCPChat: failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("LCPChat: RAG response received",
- "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("LCPChat: RAG message added to chat body", "role", ragMsg.Role,
- "rag_content_len", len(ragMsg.Content), "message_count_after_rag", len(chatBody.Messages))
- }
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// openai /v1/chat does not support custom roles; needs to be user, assistant, system
// Add persona suffix to the last user message to indicate who the assistant should reply as
@@ -389,22 +355,6 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg)
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("DeepSeekerCompletion: RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("DeepSeekerCompletion: failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("DeepSeekerCompletion: RAG response received",
- "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages))
- }
// sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -474,22 +424,6 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg)
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("RAG response received", "response_len", len(ragResp),
- "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
- }
// Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// Add persona suffix to the last user message to indicate who the assistant should reply as
@@ -552,22 +486,6 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg)
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("RAG response received", "response_len",
- len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
- }
// sending description of the tools and how to use them
if cfg.ToolUse && !resume && role == cfg.UserRole && !containsToolSysMsg() {
chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
@@ -670,22 +588,6 @@ func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
newMsg = *processMessageTag(&newMsg)
chatBody.Messages = append(chatBody.Messages, newMsg)
}
- // if rag - add as system message to avoid conflicts with tool usage
- if !resume && cfg.RAGEnabled {
- um := chatBody.Messages[len(chatBody.Messages)-1].Content
- logger.Debug("RAG is enabled, preparing RAG context", "user_message", um)
- ragResp, err := chatRagUse(um)
- if err != nil {
- logger.Error("failed to form a rag msg", "error", err)
- return nil, err
- }
- logger.Debug("RAG response received", "response_len", len(ragResp),
- "response_preview", ragResp[:min(len(ragResp), 100)])
- // Use system role for RAG context to avoid conflicts with tool usage
- ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp}
- chatBody.Messages = append(chatBody.Messages, ragMsg)
- logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages))
- }
// Create copy of chat body with standardized user role
filteredMessages, _ := filterMessagesForCurrentCharacter(chatBody.Messages)
// Add persona suffix to the last user message to indicate who the assistant should reply as
diff --git a/rag/rag.go b/rag/rag.go
index f554924..b49bd97 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -9,6 +9,8 @@ import (
"log/slog"
"os"
"path"
+ "regexp"
+ "sort"
"strings"
"sync"
@@ -195,3 +197,309 @@ func (r *RAG) ListLoaded() ([]string, error) {
func (r *RAG) RemoveFile(filename string) error {
return r.storage.RemoveEmbByFileName(filename)
}
+
+var (
+ queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
+ importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
+ stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
+)
+
+func (r *RAG) RefineQuery(query string) string {
+ original := query
+ query = strings.TrimSpace(query)
+
+ if len(query) == 0 {
+ return original
+ }
+
+ if len(query) <= 3 {
+ return original
+ }
+
+ query = strings.ToLower(query)
+
+ for _, stopWord := range stopWords {
+ wordPattern := `\b` + stopWord + `\b`
+ re := regexp.MustCompile(wordPattern)
+ query = re.ReplaceAllString(query, "")
+ }
+
+ query = strings.TrimSpace(query)
+
+ if len(query) < 5 {
+ return original
+ }
+
+ if queryRefinementPattern.MatchString(original) {
+ cleaned := queryRefinementPattern.ReplaceAllString(original, "")
+ cleaned = strings.TrimSpace(cleaned)
+ if len(cleaned) >= 5 {
+ return cleaned
+ }
+ }
+
+ query = r.extractImportantPhrases(query)
+
+ if len(query) < 5 {
+ return original
+ }
+
+ return query
+}
+
+func (r *RAG) extractImportantPhrases(query string) string {
+ words := strings.Fields(query)
+
+ var important []string
+ for _, word := range words {
+ word = strings.Trim(word, ".,!?;:'\"()[]{}")
+
+ isImportant := false
+ for _, kw := range importantKeywords {
+ if strings.Contains(strings.ToLower(word), kw) {
+ isImportant = true
+ break
+ }
+ }
+
+ if isImportant || len(word) > 3 {
+ important = append(important, word)
+ }
+ }
+
+ if len(important) == 0 {
+ return query
+ }
+
+ return strings.Join(important, " ")
+}
+
+func (r *RAG) GenerateQueryVariations(query string) []string {
+ variations := []string{query}
+
+ if len(query) < 5 {
+ return variations
+ }
+
+ parts := strings.Fields(query)
+ if len(parts) == 0 {
+ return variations
+ }
+
+ if len(parts) >= 2 {
+ trimmed := strings.Join(parts[:len(parts)-1], " ")
+ if len(trimmed) >= 5 {
+ variations = append(variations, trimmed)
+ }
+ }
+
+ if len(parts) >= 2 {
+ trimmed := strings.Join(parts[1:], " ")
+ if len(trimmed) >= 5 {
+ variations = append(variations, trimmed)
+ }
+ }
+
+ if !strings.HasSuffix(query, " explanation") {
+ variations = append(variations, query+" explanation")
+ }
+ if !strings.HasPrefix(query, "what is ") {
+ variations = append(variations, "what is "+query)
+ }
+ if !strings.HasSuffix(query, " details") {
+ variations = append(variations, query+" details")
+ }
+ if !strings.HasSuffix(query, " summary") {
+ variations = append(variations, query+" summary")
+ }
+
+ return variations
+}
+
+func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
+ type scoredResult struct {
+ row models.VectorRow
+ distance float32
+ }
+
+ scored := make([]scoredResult, 0, len(results))
+
+ for i := range results {
+ row := results[i]
+
+ score := float32(0)
+
+ rawTextLower := strings.ToLower(row.RawText)
+ queryLower := strings.ToLower(query)
+
+ if strings.Contains(rawTextLower, queryLower) {
+ score += 10
+ }
+
+ queryWords := strings.Fields(queryLower)
+ matchCount := 0
+ for _, word := range queryWords {
+ if len(word) > 2 && strings.Contains(rawTextLower, word) {
+ matchCount++
+ }
+ }
+ if len(queryWords) > 0 {
+ score += float32(matchCount) / float32(len(queryWords)) * 5
+ }
+
+ if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
+ score += 3
+ }
+
+ distance := row.Distance - score/100
+
+ scored = append(scored, scoredResult{row: row, distance: distance})
+ }
+
+ sort.Slice(scored, func(i, j int) bool {
+ return scored[i].distance < scored[j].distance
+ })
+
+ unique := make([]models.VectorRow, 0)
+ seen := make(map[string]bool)
+
+ for i := range scored {
+ if !seen[scored[i].row.Slug] {
+ seen[scored[i].row.Slug] = true
+ unique = append(unique, scored[i].row)
+ }
+ }
+
+ if len(unique) > 10 {
+ unique = unique[:10]
+ }
+
+ return unique
+}
+
+func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
+ if len(results) == 0 {
+ return "No relevant information found in the vector database.", nil
+ }
+
+ var contextBuilder strings.Builder
+ contextBuilder.WriteString("User Query: ")
+ contextBuilder.WriteString(query)
+ contextBuilder.WriteString("\n\nRetrieved Context:\n")
+
+ for i, row := range results {
+ contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
+ contextBuilder.WriteString(row.RawText)
+ contextBuilder.WriteString("\n\n")
+ }
+
+ contextBuilder.WriteString("Instructions: ")
+ contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
+ contextBuilder.WriteString("Extract only the most relevant information. ")
+ contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
+ contextBuilder.WriteString("Cite sources by filename when relevant. ")
+ contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
+
+ synthesisPrompt := contextBuilder.String()
+
+ emb, err := r.LineToVector(synthesisPrompt)
+ if err != nil {
+ r.logger.Error("failed to embed synthesis prompt", "error", err)
+ return "", err
+ }
+
+ embResp := &models.EmbeddingResp{
+ Embedding: emb,
+ Index: 0,
+ }
+
+ topResults, err := r.SearchEmb(embResp)
+ if err != nil {
+ r.logger.Error("failed to search for synthesis context", "error", err)
+ return "", err
+ }
+
+ if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
+ return topResults[0].RawText, nil
+ }
+
+ var finalAnswer strings.Builder
+ finalAnswer.WriteString("Based on the retrieved context:\n\n")
+
+ for i, row := range results {
+ if i >= 5 {
+ break
+ }
+ finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
+ }
+
+ return finalAnswer.String(), nil
+}
+
+func truncateString(s string, maxLen int) string {
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "..."
+}
+
+func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
+ refined := r.RefineQuery(query)
+ variations := r.GenerateQueryVariations(refined)
+
+ allResults := make([]models.VectorRow, 0)
+ seen := make(map[string]bool)
+
+ for _, q := range variations {
+ emb, err := r.LineToVector(q)
+ if err != nil {
+ r.logger.Error("failed to embed query variation", "error", err, "query", q)
+ continue
+ }
+
+ embResp := &models.EmbeddingResp{
+ Embedding: emb,
+ Index: 0,
+ }
+
+ results, err := r.SearchEmb(embResp)
+ if err != nil {
+ r.logger.Error("failed to search embeddings", "error", err, "query", q)
+ continue
+ }
+
+ for _, row := range results {
+ if !seen[row.Slug] {
+ seen[row.Slug] = true
+ allResults = append(allResults, row)
+ }
+ }
+ }
+
+ reranked := r.RerankResults(allResults, query)
+
+ if len(reranked) > limit {
+ reranked = reranked[:limit]
+ }
+
+ return reranked, nil
+}
+
+var (
+ ragInstance *RAG
+ ragOnce sync.Once
+)
+
+func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
+ ragOnce.Do(func() {
+ if c == nil || l == nil || s == nil {
+ return
+ }
+ ragInstance = New(l, s, c)
+ })
+ return nil
+}
+
+func GetInstance() *RAG {
+ return ragInstance
+}
diff --git a/tools.go b/tools.go
index c397137..5cd2770 100644
--- a/tools.go
+++ b/tools.go
@@ -16,6 +16,7 @@ import (
"sync"
"time"
+ "gf-lt/rag"
"github.com/GrailFinder/searchagent/searcher"
)
@@ -58,9 +59,9 @@ Your current tools:
"when_to_use": "when asked to search the web for information; returns clean summary without html,css and other web elements; limit is optional (default 3)"
},
{
-"name":"websearch_raw",
+"name":"rag_search",
"args": ["query", "limit"],
-"when_to_use": "when asked to search the web for information; returns raw data as is without processing; limit is optional (default 3)"
+"when_to_use": "when asked to search the local document database for information; performs query refinement, semantic search, reranking, and synthesis; returns clean summary with sources; limit is optional (default 3)"
},
{
"name":"read_url",
@@ -146,6 +147,7 @@ under the topic: Adam's number is stored:
After that you are free to respond to the user.
`
webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
+ ragSearchSysPrompt = `Synthesize the document search results, extracting key information and presenting a concise answer. Provide sources and document IDs where relevant.`
readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
summarySysPrompt = `Please provide a concise summary of the following conversation. Focus on key points, decisions, and actions. Provide only the summary, no additional commentary.`
basicCard = &models.CharCard{
@@ -170,6 +172,10 @@ func init() {
panic("failed to init seachagent; error: " + err.Error())
}
WebSearcher = sa
+
+ if err := rag.Init(cfg, logger, store); err != nil {
+ logger.Warn("failed to init rag; rag_search tool will not be available", "error", err)
+ }
}
// getWebAgentClient returns a singleton AgentClient for web agents.
@@ -196,6 +202,8 @@ func getWebAgentClient() *agent.AgentClient {
func registerWebAgents() {
webAgentsOnce.Do(func() {
client := getWebAgentClient()
+ // Register rag_search agent
+ agent.Register("rag_search", agent.NewWebAgentB(client, ragSearchSysPrompt))
// Register websearch agent
agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
// Register read_url agent
@@ -239,6 +247,48 @@ func websearch(args map[string]string) []byte {
return data
}
+// rag search (searches local document database)
+func ragsearch(args map[string]string) []byte {
+ query, ok := args["query"]
+ if !ok || query == "" {
+ msg := "query not provided to rag_search tool"
+ logger.Error(msg)
+ return []byte(msg)
+ }
+ limitS, ok := args["limit"]
+ if !ok || limitS == "" {
+ limitS = "3"
+ }
+ limit, err := strconv.Atoi(limitS)
+ if err != nil || limit == 0 {
+ logger.Warn("ragsearch limit; passed bad value; setting to default (3)",
+ "limit_arg", limitS, "error", err)
+ limit = 3
+ }
+
+ ragInstance := rag.GetInstance()
+ if ragInstance == nil {
+ msg := "rag not initialized; rag_search tool is not available"
+ logger.Error(msg)
+ return []byte(msg)
+ }
+
+ results, err := ragInstance.Search(query, limit)
+ if err != nil {
+ msg := "rag search failed; error: " + err.Error()
+ logger.Error(msg)
+ return []byte(msg)
+ }
+
+ data, err := json.Marshal(results)
+ if err != nil {
+ msg := "failed to marshal rag search result; error: " + err.Error()
+ logger.Error(msg)
+ return []byte(msg)
+ }
+ return data
+}
+
// web search raw (returns raw data without processing)
func websearchRaw(args map[string]string) []byte {
// make http request return bytes
@@ -997,6 +1047,7 @@ var fnMap = map[string]fnSig{
"recall": recall,
"recall_topics": recallTopics,
"memorise": memorise,
+ "rag_search": ragsearch,
"websearch": websearch,
"websearch_raw": websearchRaw,
"read_url": readURL,
@@ -1033,6 +1084,28 @@ func callToolWithAgent(name string, args map[string]string) []byte {
// openai style def
var baseTools = []models.Tool{
+ // rag_search
+ models.Tool{
+ Type: "function",
+ Function: models.ToolFunc{
+ Name: "rag_search",
+ Description: "Search local document database given query, limit of sources (default 3). Performs query refinement, semantic search, reranking, and synthesis.",
+ Parameters: models.ToolFuncParams{
+ Type: "object",
+ Required: []string{"query", "limit"},
+ Properties: map[string]models.ToolArgProps{
+ "query": models.ToolArgProps{
+ Type: "string",
+ Description: "search query",
+ },
+ "limit": models.ToolArgProps{
+ Type: "string",
+ Description: "limit of the document results",
+ },
+ },
+ },
+ },
+ },
// websearch
models.Tool{
Type: "function",