summaryrefslogtreecommitdiff
path: root/llm.go
diff options
context:
space:
mode:
Diffstat (limited to 'llm.go')
-rw-r--r--llm.go10
1 files changed, 9 insertions, 1 deletions
diff --git a/llm.go b/llm.go
index be7f418..a5f70bf 100644
--- a/llm.go
+++ b/llm.go
@@ -51,8 +51,11 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string) (io.Reader, error) {
messages[i] = m.ToPrompt()
}
prompt := strings.Join(messages, "\n")
+ if cfg.ToolUse && msg != "" {
+ prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg
+ }
botMsgStart := "\n" + cfg.AssistantRole + ":\n"
- payload := models.NewLCPReq(prompt+botMsgStart, role, defaultLCPProps)
+ payload := models.NewLCPReq(prompt+botMsgStart, cfg, defaultLCPProps)
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
@@ -106,6 +109,11 @@ func (op OpenAIer) FormMsg(msg, role string) (io.Reader, error) {
ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
chatBody.Messages = append(chatBody.Messages, ragMsg)
}
+ if cfg.ToolUse {
+ toolMsg := models.RoleMsg{Role: cfg.ToolRole,
+ Content: toolSysMsg}
+ chatBody.Messages = append(chatBody.Messages, toolMsg)
+ }
}
data, err := json.Marshal(chatBody)
if err != nil {