summaryrefslogtreecommitdiff
path: root/llm.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2025-02-08 18:28:47 +0300
committerGrail Finder <wohilas@gmail.com>2025-02-08 18:28:47 +0300
commitc85766139371bb4324826fa8716b3478eea898c1 (patch)
tree2b58fcff3c79751a4d7e5034e035f6f270cb8bc8 /llm.go
parent884004a855980444319769d9b10f9cf6e3ba33cd (diff)
Feat: add tool reminder bind
Diffstat (limited to 'llm.go')
-rw-r--r--llm.go10
1 files changed, 9 insertions, 1 deletions
diff --git a/llm.go b/llm.go
index be7f418..a5f70bf 100644
--- a/llm.go
+++ b/llm.go
@@ -51,8 +51,11 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string) (io.Reader, error) {
messages[i] = m.ToPrompt()
}
prompt := strings.Join(messages, "\n")
+ if cfg.ToolUse && msg != "" {
+ prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg
+ }
botMsgStart := "\n" + cfg.AssistantRole + ":\n"
- payload := models.NewLCPReq(prompt+botMsgStart, role, defaultLCPProps)
+ payload := models.NewLCPReq(prompt+botMsgStart, cfg, defaultLCPProps)
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
@@ -106,6 +109,11 @@ func (op OpenAIer) FormMsg(msg, role string) (io.Reader, error) {
ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
}
+ if cfg.ToolUse {
+ toolMsg := models.RoleMsg{Role: cfg.ToolRole,
+ Content: toolSysMsg}
+ chatBody.Messages = append(chatBody.Messages, toolMsg)
+ }
}
data, err := json.Marshal(chatBody)
if err != nil {