summaryrefslogtreecommitdiff
path: root/bot.go
diff options
context:
space:
mode:
Diffstat (limited to 'bot.go')
-rw-r--r--bot.go24
1 files changed, 16 insertions, 8 deletions
diff --git a/bot.go b/bot.go
index 423a7cb..72a0c44 100644
--- a/bot.go
+++ b/bot.go
@@ -34,6 +34,7 @@ var (
logLevel = new(slog.LevelVar)
activeChatName string
chunkChan = make(chan string, 10)
+ openAIToolChan = make(chan string, 10)
streamDone = make(chan bool, 1)
chatBody *models.ChatBody
store storage.FullRepo
@@ -189,8 +190,8 @@ func sendMsgToLLM(body io.Reader) {
for {
var (
answerText string
- content string
stop bool
+ chunk *models.TextChunk
)
counter++
// to stop from spiriling in infinity read of bad bytes that happens with poor connection
@@ -225,7 +226,7 @@ func sendMsgToLLM(body io.Reader) {
if bytes.Equal(line, []byte("ROUTER PROCESSING\n")) {
continue
}
- content, stop, err = chunkParser.ParseChunk(line)
+ chunk, err = chunkParser.ParseChunk(line)
if err != nil {
logger.Error("error parsing response body", "error", err,
"line", string(line), "url", cfg.CurrentAPI)
@@ -239,18 +240,19 @@ func sendMsgToLLM(body io.Reader) {
break
}
if stop {
- if content != "" {
- logger.Warn("text inside of finish llmchunk", "chunk", content, "counter", counter)
+ if chunk.Chunk != "" {
+ logger.Warn("text inside of finish llmchunk", "chunk", chunk, "counter", counter)
}
streamDone <- true
break
}
if counter == 0 {
- content = strings.TrimPrefix(content, " ")
+ chunk.Chunk = strings.TrimPrefix(chunk.Chunk, " ")
}
// bot sends way too many \n
- answerText = strings.ReplaceAll(content, "\n\n", "\n")
+ answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
chunkChan <- answerText
+ openAIToolChan <- chunk.ToolChunk
interrupt:
if interruptResp { // read bytes, so it would not get into beginning of the next req
interruptResp = false
@@ -362,6 +364,7 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
}
}
respText := strings.Builder{}
+ toolResp := strings.Builder{}
out:
for {
select {
@@ -374,6 +377,10 @@ out:
// audioStream.TextChan <- chunk
extra.TTSTextChan <- chunk
}
+ case toolChunk := <-openAIToolChan:
+ fmt.Fprint(tv, toolChunk)
+ toolResp.WriteString(toolChunk)
+ tv.ScrollToEnd()
case <-streamDone:
botRespMode = false
if cfg.TTS_ENABLED {
@@ -402,10 +409,11 @@ out:
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
logger.Warn("failed to update storage", "error", err, "name", activeChatName)
}
- findCall(respText.String(), tv)
+ // INFO: for completion only; openai has it's own tool struct
+ findCall(respText.String(), toolResp.String(), tv)
}
-func findCall(msg string, tv *tview.TextView) {
+func findCall(msg, toolCall string, tv *tview.TextView) {
fc := models.FuncCall{}
jsStr := toolCallRE.FindString(msg)
if jsStr == "" {