-rw-r--r-- | llm.go           | 43 |
-rw-r--r-- | models/models.go | 17 |
2 files changed, 43 insertions, 17 deletions
diff --git a/llm.go b/llm.go
--- a/llm.go
+++ b/llm.go
@@ -9,7 +9,7 @@ import (
 )
 
 type ChunkParser interface {
-	ParseChunk([]byte) (string, bool, error)
+	ParseChunk([]byte) (*models.TextChunk, error)
 	FormMsg(msg, role string, cont bool) (io.Reader, error)
 	GetToken() string
 }
@@ -114,39 +114,47 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
 	return bytes.NewReader(data), nil
 }
 
-func (lcp LlamaCPPeer) ParseChunk(data []byte) (string, bool, error) {
+func (lcp LlamaCPPeer) ParseChunk(data []byte) (*models.TextChunk, error) {
 	llmchunk := models.LlamaCPPResp{}
+	resp := &models.TextChunk{}
 	if err := json.Unmarshal(data, &llmchunk); err != nil {
 		logger.Error("failed to decode", "error", err, "line", string(data))
-		return "", false, err
+		return nil, err
 	}
+	resp.Chunk = llmchunk.Content
 	if llmchunk.Stop {
 		if llmchunk.Content != "" {
 			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
 		}
-		return llmchunk.Content, true, nil
+		resp.Finished = true
 	}
-	return llmchunk.Content, false, nil
+	return resp, nil
 }
 
 func (op OpenAIer) GetToken() string {
 	return ""
 }
 
-func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) {
+func (op OpenAIer) ParseChunk(data []byte) (*models.TextChunk, error) {
 	llmchunk := models.LLMRespChunk{}
 	if err := json.Unmarshal(data, &llmchunk); err != nil {
 		logger.Error("failed to decode", "error", err, "line", string(data))
-		return "", false, err
+		return nil, err
+	}
+	resp := &models.TextChunk{
+		Chunk:     llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content,
+		ToolChunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments,
 	}
-	content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content
 	if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
-		if content != "" {
+		if resp.Chunk != "" {
 			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
 		}
-		return content, true, nil
+		resp.Finished = true
 	}
-	return content, false, nil
+	if resp.ToolChunk != "" {
+		resp.ToolResp = true
+	}
+	return resp, nil
 }
 
 func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
@@ -171,19 +179,22 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
 }
 
 // deepseek
-func (ds DeepSeekerCompletion) ParseChunk(data []byte) (string, bool, error) {
+func (ds DeepSeekerCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
 	llmchunk := models.DSCompletionResp{}
 	if err := json.Unmarshal(data, &llmchunk); err != nil {
 		logger.Error("failed to decode", "error", err, "line", string(data))
-		return "", false, err
+		return nil, err
+	}
+	resp := &models.TextChunk{
+		Chunk: llmchunk.Choices[0].Text,
 	}
 	if llmchunk.Choices[0].FinishReason != "" {
-		if llmchunk.Choices[0].Text != "" {
+		if resp.Chunk != "" {
 			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
 		}
-		return llmchunk.Choices[0].Text, true, nil
+		resp.Finished = true
 	}
-	return llmchunk.Choices[0].Text, false, nil
+	return resp, nil
 }
 
 func (ds DeepSeekerCompletion) GetToken() string {
diff --git a/models/models.go b/models/models.go
index c88417f..9ca12d9 100644
--- a/models/models.go
+++ b/models/models.go
@@ -30,13 +30,21 @@ type LLMResp struct {
 	ID string `json:"id"`
 }
 
+type ToolDeltaResp struct {
+	Index    int `json:"index"`
+	Function struct {
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
 // for streaming
 type LLMRespChunk struct {
 	Choices []struct {
 		FinishReason string `json:"finish_reason"`
 		Index        int    `json:"index"`
 		Delta        struct {
-			Content string `json:"content"`
+			Content   string          `json:"content"`
+			ToolCalls []ToolDeltaResp `json:"tool_calls"`
 		} `json:"delta"`
 	} `json:"choices"`
 	Created int `json:"created"`
@@ -50,6 +58,13 @@ type LLMRespChunk struct {
 	} `json:"usage"`
 }
 
+type TextChunk struct {
+	Chunk     string
+	ToolChunk string
+	Finished  bool
+	ToolResp  bool
+}
+
 type RoleMsg struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
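
The new return type changes every ChunkParser call site. Below is a minimal consumer sketch, assuming a loop that feeds raw streamed payloads to ParseChunk; TextChunk and the narrowed ChunkParser view are taken from this diff, while stubParser, drainStream, and the sample payloads are hypothetical names used only to make the example self-contained and runnable.

package main

import (
	"fmt"
	"strings"
)

// TextChunk mirrors the struct added to models/models.go in this diff.
type TextChunk struct {
	Chunk     string // plain assistant text carried by this chunk
	ToolChunk string // fragment of streamed tool-call arguments
	Finished  bool   // set when the backend reports a finish/stop reason
	ToolResp  bool   // set when ToolChunk is non-empty
}

// ChunkParser is the part of the llm.go interface this sketch needs.
type ChunkParser interface {
	ParseChunk([]byte) (*TextChunk, error)
}

// stubParser fakes a backend: every payload becomes plain text, and the
// literal payload "DONE" marks the end of the stream.
type stubParser struct{}

func (stubParser) ParseChunk(data []byte) (*TextChunk, error) {
	if string(data) == "DONE" {
		return &TextChunk{Finished: true}, nil
	}
	return &TextChunk{Chunk: string(data)}, nil
}

// drainStream folds streamed chunks into the final assistant text and the
// accumulated tool-call arguments, stopping at the first Finished chunk.
func drainStream(p ChunkParser, payloads [][]byte) (string, string, error) {
	var text, tool strings.Builder
	for _, raw := range payloads {
		chunk, err := p.ParseChunk(raw)
		if err != nil {
			return "", "", err
		}
		text.WriteString(chunk.Chunk)
		if chunk.ToolResp {
			// Tool-call arguments arrive in fragments, just like text.
			tool.WriteString(chunk.ToolChunk)
		}
		if chunk.Finished {
			break
		}
	}
	return text.String(), tool.String(), nil
}

func main() {
	payloads := [][]byte{[]byte("Hel"), []byte("lo"), []byte("DONE")}
	text, tool, err := drainStream(stubParser{}, payloads)
	fmt.Printf("text=%q tool=%q err=%v\n", text, tool, err) // text="Hello" tool="" err=<nil>
}

One design note on the hunk above: the new OpenAIer.ParseChunk reads Delta.ToolCalls[0] unconditionally, so a chunk whose delta carries no tool calls would make that index panic; guarding the access with a len(ToolCalls) check before building the TextChunk may be worth considering.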