diff options
author | Grail Finder <wohilas@gmail.com> | 2025-01-24 09:25:08 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2025-01-24 09:25:08 +0300 |
commit | 976d6423ac0f0b80efb2c933d8c3e6a816507c54 (patch) | |
tree | 2610428b5d5b8230fa9a5c2e530a1fdd85ae2d21 | |
parent | 3374080ba01e601c2c05aecc114b6c0a53b60f76 (diff) |
Refactor: different endpoint types
-rw-r--r-- | llm.go | 116 |
1 files changed, 116 insertions, 0 deletions
@@ -0,0 +1,116 @@ +package main + +import ( + "bytes" + "elefant/models" + "encoding/json" + "io" + "strings" +) + +type ChunkParser interface { + ParseChunk([]byte) (string, bool, error) + FormMsg(msg, role string) (io.Reader, error) +} + +func initChunkParser() { + chunkParser = LlamaCPPeer{} + if strings.Contains(cfg.APIURL, "v1") { + logger.Info("chosen openai parser") + chunkParser = OpenAIer{} + return + } + logger.Info("chosen llamacpp parser") +} + +type LlamaCPPeer struct { +} +type OpenAIer struct { +} + +func (lcp LlamaCPPeer) FormMsg(msg, role string) (io.Reader, error) { + if msg != "" { // otherwise let the bot continue + // if role == cfg.UserRole { + // msg = msg + cfg.AssistantRole + ":" + // } + newMsg := models.RoleMsg{Role: role, Content: msg} + chatBody.Messages = append(chatBody.Messages, newMsg) + // if rag + if cfg.RAGEnabled { + ragResp, err := chatRagUse(newMsg.Content) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil, err + } + ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + } + } + messages := make([]string, len(chatBody.Messages)) + for i, m := range chatBody.Messages { + messages[i] = m.ToPrompt() + } + prompt := strings.Join(messages, "\n") + botMsgStart := "\n" + cfg.AssistantRole + ":\n" + payload := models.NewLCPReq(prompt+botMsgStart, role) + data, err := json.Marshal(payload) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} + +func (lcp LlamaCPPeer) ParseChunk(data []byte) (string, bool, error) { + llmchunk := models.LlamaCPPResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return "", false, err + } + if llmchunk.Stop { + if llmchunk.Content != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + return llmchunk.Content, true, nil + } + return llmchunk.Content, false, nil +} + +func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) { + llmchunk := models.LLMRespChunk{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return "", false, err + } + content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content + if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { + if content != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + return content, true, nil + } + return content, false, nil +} + +func (op OpenAIer) FormMsg(msg, role string) (io.Reader, error) { + if msg != "" { // otherwise let the bot continue + newMsg := models.RoleMsg{Role: role, Content: msg} + chatBody.Messages = append(chatBody.Messages, newMsg) + // if rag + if cfg.RAGEnabled { + ragResp, err := chatRagUse(newMsg.Content) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil, err + } + ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + } + } + data, err := json.Marshal(chatBody) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} |