author    Grail Finder <wohilas@gmail.com>  2025-01-24 09:25:08 +0300
committer Grail Finder <wohilas@gmail.com>  2025-01-24 09:25:08 +0300
commit    976d6423ac0f0b80efb2c933d8c3e6a816507c54 (patch)
tree      2610428b5d5b8230fa9a5c2e530a1fdd85ae2d21
parent    3374080ba01e601c2c05aecc114b6c0a53b60f76 (diff)
Refactor: different endpoint types
-rw-r--r--  llm.go  116
1 file changed, 116 insertions, 0 deletions
diff --git a/llm.go b/llm.go
new file mode 100644
index 0000000..3fb248a
--- /dev/null
+++ b/llm.go
@@ -0,0 +1,116 @@
+package main
+
+import (
+ "bytes"
+ "elefant/models"
+ "encoding/json"
+ "io"
+ "strings"
+)
+
+type ChunkParser interface {
+ ParseChunk([]byte) (string, bool, error)
+ FormMsg(msg, role string) (io.Reader, error)
+}
+
+func initChunkParser() {
+ chunkParser = LlamaCPPeer{}
+ if strings.Contains(cfg.APIURL, "v1") {
+ logger.Info("chosen openai parser")
+ chunkParser = OpenAIer{}
+ return
+ }
+ logger.Info("chosen llamacpp parser")
+}
+
+type LlamaCPPeer struct {
+}
+type OpenAIer struct {
+}
+
+func (lcp LlamaCPPeer) FormMsg(msg, role string) (io.Reader, error) {
+ if msg != "" { // otherwise let the bot continue
+ // if role == cfg.UserRole {
+ // msg = msg + cfg.AssistantRole + ":"
+ // }
+ newMsg := models.RoleMsg{Role: role, Content: msg}
+ chatBody.Messages = append(chatBody.Messages, newMsg)
+ // if rag
+ if cfg.RAGEnabled {
+ ragResp, err := chatRagUse(newMsg.Content)
+ if err != nil {
+ logger.Error("failed to form a rag msg", "error", err)
+ return nil, err
+ }
+ ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
+ chatBody.Messages = append(chatBody.Messages, ragMsg)
+ }
+ }
+ messages := make([]string, len(chatBody.Messages))
+ for i, m := range chatBody.Messages {
+ messages[i] = m.ToPrompt()
+ }
+ prompt := strings.Join(messages, "\n")
+ botMsgStart := "\n" + cfg.AssistantRole + ":\n"
+ payload := models.NewLCPReq(prompt+botMsgStart, role)
+ data, err := json.Marshal(payload)
+ if err != nil {
+ logger.Error("failed to form a msg", "error", err)
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+func (lcp LlamaCPPeer) ParseChunk(data []byte) (string, bool, error) {
+ llmchunk := models.LlamaCPPResp{}
+ if err := json.Unmarshal(data, &llmchunk); err != nil {
+ logger.Error("failed to decode", "error", err, "line", string(data))
+ return "", false, err
+ }
+ if llmchunk.Stop {
+ if llmchunk.Content != "" {
+ logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
+ }
+ return llmchunk.Content, true, nil
+ }
+ return llmchunk.Content, false, nil
+}
+
+func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) {
+ llmchunk := models.LLMRespChunk{}
+ if err := json.Unmarshal(data, &llmchunk); err != nil {
+ logger.Error("failed to decode", "error", err, "line", string(data))
+ return "", false, err
+ }
+ content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content
+ if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
+ if content != "" {
+ logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
+ }
+ return content, true, nil
+ }
+ return content, false, nil
+}
+
+func (op OpenAIer) FormMsg(msg, role string) (io.Reader, error) {
+ if msg != "" { // otherwise let the bot continue
+ newMsg := models.RoleMsg{Role: role, Content: msg}
+ chatBody.Messages = append(chatBody.Messages, newMsg)
+ // if rag
+ if cfg.RAGEnabled {
+ ragResp, err := chatRagUse(newMsg.Content)
+ if err != nil {
+ logger.Error("failed to form a rag msg", "error", err)
+ return nil, err
+ }
+ ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
+ chatBody.Messages = append(chatBody.Messages, ragMsg)
+ }
+ }
+ data, err := json.Marshal(chatBody)
+ if err != nil {
+ logger.Error("failed to form a msg", "error", err)
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
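
The elefant/models package is not part of this diff. From the call sites above, the types it must provide would look roughly like the following; the field names are inferred from usage, and the JSON tags follow the llama.cpp and OpenAI streaming formats, so treat all of it as an assumption. NewLCPReq is also referenced, but its shape cannot be recovered from this commit, so it is left out.

// Inferred sketches of the elefant/models types used by llm.go.
type RoleMsg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ToPrompt renders a message for the plain-completion endpoint; the
// exact layout here is a guess consistent with FormMsg's "\n" joins.
func (m RoleMsg) ToPrompt() string {
	return m.Role + ":\n" + m.Content
}

// Streaming chunk from llama.cpp's /completion endpoint.
type LlamaCPPResp struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

// Streaming chunk from an OpenAI-compatible /v1/chat/completions endpoint.
type LLMRespChunk struct {
	Choices []struct {
		FinishReason string `json:"finish_reason"`
		Delta        struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}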
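
With the ToPrompt sketch above and an assumed cfg.AssistantRole of "assistant", LlamaCPPeer.FormMsg would flatten a two-message history into a single prompt ending with botMsgStart, leaving the assistant's turn open for the model to complete:

user:
hello
assistant:
hi there
assistant: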
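
The commit defines the parsers but not the transport that drives them. A minimal caller sketch, assuming SSE-style "data: " framing and plain net/http (none of which appears in this commit; streamCompletion is a hypothetical name, and bufio, bytes, net/http, and strings would need to be imported):

// Hypothetical caller (not in this commit): stream a completion through
// whichever parser initChunkParser selected.
func streamCompletion(cp ChunkParser, msg, role string) (string, error) {
	payload, err := cp.FormMsg(msg, role)
	if err != nil {
		return "", err
	}
	resp, err := http.Post(cfg.APIURL, "application/json", payload)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	var sb strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		// OpenAI-compatible servers frame chunks as "data: {...}";
		// llama.cpp's streaming /completion endpoint does the same.
		line := bytes.TrimSpace(bytes.TrimPrefix(scanner.Bytes(), []byte("data: ")))
		if len(line) == 0 || bytes.Equal(line, []byte("[DONE]")) {
			continue
		}
		chunk, stop, err := cp.ParseChunk(line)
		if err != nil {
			return sb.String(), err
		}
		sb.WriteString(chunk)
		if stop {
			break
		}
	}
	return sb.String(), scanner.Err()
}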
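
One sharp edge in the committed code: OpenAIer.ParseChunk indexes Choices[len(Choices)-1] without checking for an empty slice, so a keep-alive or malformed chunk with no choices would panic. A defensive variant, sketched under the same assumed types:

// Sketch of a guarded OpenAIer.ParseChunk; behavior is unchanged for
// well-formed chunks, but an empty Choices slice no longer panics.
func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) {
	llmchunk := models.LLMRespChunk{}
	if err := json.Unmarshal(data, &llmchunk); err != nil {
		logger.Error("failed to decode", "error", err, "line", string(data))
		return "", false, err
	}
	if len(llmchunk.Choices) == 0 {
		// keep-alive or empty chunk: nothing to emit, not a stop
		return "", false, nil
	}
	last := llmchunk.Choices[len(llmchunk.Choices)-1]
	if last.FinishReason == "stop" {
		if last.Delta.Content != "" {
			logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
		}
		return last.Delta.Content, true, nil
	}
	return last.Delta.Content, false, nil
}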