summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore4
-rw-r--r--.golangci.yml50
-rw-r--r--Makefile4
-rw-r--r--README.md63
-rw-r--r--assets/ex01.pngbin0 -> 69006 bytes
-rw-r--r--bot.go213
-rw-r--r--config.example.toml11
-rw-r--r--config/config.go51
-rw-r--r--extra/cluedo.go73
-rw-r--r--extra/cluedo_test.go50
-rw-r--r--extra/stt.go166
-rw-r--r--extra/tts.go212
-rw-r--r--extra/twentyq.go11
-rw-r--r--extra/vad.go1
-rw-r--r--go.mod8
-rw-r--r--go.sum22
-rw-r--r--llm.go171
-rw-r--r--main.go2
-rw-r--r--main_test.go2
-rw-r--r--models/extra.go8
-rw-r--r--models/models.go170
-rw-r--r--pngmeta/altwriter.go133
-rw-r--r--pngmeta/metareader.go6
-rw-r--r--pngmeta/metareader_test.go164
-rw-r--r--pngmeta/partswriter.go214
-rw-r--r--rag/main.go28
-rw-r--r--server.go2
-rw-r--r--session.go21
-rw-r--r--storage/memory.go2
-rw-r--r--storage/storage.go2
-rw-r--r--storage/storage_test.go173
-rw-r--r--storage/vector.go2
-rw-r--r--sysprompts/cluedo.json7
-rw-r--r--tables.go242
-rw-r--r--tools.go23
-rw-r--r--tui.go205
36 files changed, 2096 insertions, 420 deletions
diff --git a/.gitignore b/.gitignore
index d99fe9d..6ec208c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,8 @@ history/
*.db
config.toml
sysprompts/*
+!sysprompts/cluedo.json
history_bak/
+.aider*
+tags
+gf-lt
diff --git a/.golangci.yml b/.golangci.yml
index 66732bf..d377c38 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,32 +1,44 @@
+version: "2"
run:
- timeout: 1m
concurrency: 2
tests: false
-
linters:
- enable-all: false
- disable-all: true
+ default: none
enable:
+ - bodyclose
- errcheck
- - gosimple
+ - fatcontext
- govet
- ineffassign
+ - noctx
+ - perfsprint
+ - prealloc
- staticcheck
- - typecheck
- unused
- - prealloc
- presets:
- - performance
-
-linters-settings:
- funlen:
- lines: 80
- statements: 50
- lll:
- line-length: 80
-
+ settings:
+ funlen:
+ lines: 80
+ statements: 50
+ lll:
+ line-length: 80
+ exclusions:
+ generated: lax
+ presets:
+ - comments
+ - common-false-positives
+ - legacy
+ - std-error-handling
+ paths:
+ - third_party$
+ - builtin$
+ - examples$
issues:
- exclude:
- # Display all issues
max-issues-per-linter: 0
max-same-issues: 0
+formatters:
+ exclusions:
+ generated: lax
+ paths:
+ - third_party$
+ - builtin$
+ - examples$
diff --git a/Makefile b/Makefile
index 4e96ed5..87304cc 100644
--- a/Makefile
+++ b/Makefile
@@ -1,10 +1,10 @@
.PHONY: setconfig run lint
run: setconfig
- go build -o elefant && ./elefant
+ go build -o gf-lt && ./gf-lt
server: setconfig
- go build -o elefant && ./elefant -port 3333
+ go build -o gf-lt && ./gf-lt -port 3333
setconfig:
find config.toml &>/dev/null || cp config.example.toml config.toml
diff --git a/README.md b/README.md
index 4925eda..6c29107 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,62 @@
-#### tui with middleware for LLM use
+### gf-lt (grail finder's llm tui)
+terminal user interface for large language models.
+made with use of [tview](https://github.com/rivo/tview)
+
+#### has/supports
+- character card spec;
+- llama.cpp api, deepseek (other ones were not tested);
+- showing images (not really, for now only if your char card is png it could show it);
+- tts/stt (if whisper.cpp server / fastapi tts server are provided);
+
+#### does not have/support
+- images; (ctrl+j will show an image of the card you use, but that is about it);
+- RAG; (RAG was implemented, but I found it unusable and then sql extention broke, so no RAG);
+- MCP; (agentic is implemented, but as a raw and predefined functions for llm to use. see [tools.go](https://github.com/GrailFinder/gf-lt/blob/master/tools.go));
+
+#### usage examples
+![usage example](assets/ex01.png)
+
+#### how to install
+(requires golang)
+clone the project
+```
+cd gf-lt
+make
+```
+
+#### keybindings
+while running you can press f12 for list of keys;
+```
+Esc: send msg
+PgUp/Down: switch focus between input and chat widgets
+F1: manage chats
+F2: regen last
+F3: delete last msg
+F4: edit msg
+F5: toggle system
+F6: interrupt bot resp
+F7: copy last msg to clipboard (linux xclip)
+F8: copy n msg to clipboard (linux xclip)
+F9: table to copy from; with all code blocks
+F11: import chat file
+F12: show this help page
+Ctrl+w: resume generation on the last msg
+Ctrl+s: load new char/agent
+Ctrl+e: export chat to json file
+Ctrl+n: start a new chat
+Ctrl+c: close programm
+Ctrl+p: props edit form (min-p, dry, etc.)
+Ctrl+v: switch between /completion and /chat api (if provided in config)
+Ctrl+r: start/stop recording from your microphone (needs stt server)
+Ctrl+t: remove thinking (<think>) and tool messages from context (delete from chat)
+Ctrl+l: update connected model name (llamacpp)
+Ctrl+k: switch tool use (recommend tool use to llm after user msg)
+Ctrl+j: if chat agent is char.png will show the image; then any key to return
+Ctrl+a: interrupt tts (needs tts server)
+```
+
+#### setting up config
+```
+cp config.example.toml config.toml
+```
+set values as you need them to be.
diff --git a/assets/ex01.png b/assets/ex01.png
new file mode 100644
index 0000000..b0f5ae3
--- /dev/null
+++ b/assets/ex01.png
Binary files differ
diff --git a/bot.go b/bot.go
index c3ce273..0503548 100644
--- a/bot.go
+++ b/bot.go
@@ -2,17 +2,22 @@ package main
import (
"bufio"
- "elefant/config"
- "elefant/models"
- "elefant/rag"
- "elefant/storage"
+ "bytes"
+ "context"
"encoding/json"
"fmt"
+ "gf-lt/config"
+ "gf-lt/extra"
+ "gf-lt/models"
+ "gf-lt/rag"
+ "gf-lt/storage"
"io"
"log/slog"
+ "net"
"net/http"
"os"
"path"
+ "strconv"
"strings"
"time"
@@ -20,9 +25,10 @@ import (
"github.com/rivo/tview"
)
-var httpClient = http.Client{}
-
var (
+ httpClient = &http.Client{}
+ cluedoState *extra.CluedoRoundInfo // Current game state
+ playerOrder []string // Turn order tracking
cfg *config.Config
logger *slog.Logger
logLevel = new(slog.LevelVar)
@@ -36,9 +42,11 @@ var (
defaultStarterBytes = []byte{}
interruptResp = false
ragger *rag.RAG
- currentModel = "none"
chunkParser ChunkParser
- defaultLCPProps = map[string]float32{
+ //nolint:unused // TTS_ENABLED conditionally uses this
+ orator extra.Orator
+ asr extra.STT
+ defaultLCPProps = map[string]float32{
"temperature": 0.8,
"dry_multiplier": 0.0,
"min_p": 0.05,
@@ -46,7 +54,30 @@ var (
}
)
+func createClient(connectTimeout time.Duration) *http.Client {
+ // Custom transport with connection timeout
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ // Create a dialer with connection timeout
+ dialer := &net.Dialer{
+ Timeout: connectTimeout,
+ KeepAlive: 30 * time.Second, // Optional
+ }
+ return dialer.DialContext(ctx, network, addr)
+ },
+ // Other transport settings (optional)
+ TLSHandshakeTimeout: connectTimeout,
+ ResponseHeaderTimeout: connectTimeout,
+ }
+ // Client with no overall timeout (or set to streaming-safe duration)
+ return &http.Client{
+ Transport: transport,
+ Timeout: 0, // No overall timeout (for streaming)
+ }
+}
+
func fetchModelName() *models.LLMModels {
+ // TODO: to config
api := "http://localhost:8080/v1/models"
//nolint
resp, err := httpClient.Get(api)
@@ -61,16 +92,63 @@ func fetchModelName() *models.LLMModels {
return nil
}
if resp.StatusCode != 200 {
- currentModel = "disconnected"
+ chatBody.Model = "disconnected"
return nil
}
- currentModel = path.Base(llmModel.Data[0].ID)
+ chatBody.Model = path.Base(llmModel.Data[0].ID)
return &llmModel
}
+// nolint
+func fetchDSBalance() *models.DSBalance {
+ url := "https://api.deepseek.com/user/balance"
+ method := "GET"
+ // nolint
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ logger.Warn("failed to create request", "error", err)
+ return nil
+ }
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
+ res, err := httpClient.Do(req)
+ if err != nil {
+ logger.Warn("failed to make request", "error", err)
+ return nil
+ }
+ defer res.Body.Close()
+ resp := models.DSBalance{}
+ if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
+ return nil
+ }
+ return &resp
+}
+
func sendMsgToLLM(body io.Reader) {
+ choseChunkParser()
+ bodyBytes, _ := io.ReadAll(body)
+ ok := json.Valid(bodyBytes)
+ if !ok {
+ panic("invalid json")
+ }
+ // nolint
+ req, err := http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes))
+ if err != nil {
+ logger.Error("newreq error", "error", err)
+ if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+ logger.Error("failed to notify", "error", err)
+ }
+ streamDone <- true
+ return
+ }
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
+ req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
+ req.Header.Set("Accept-Encoding", "gzip")
// nolint
- resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
+ // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
+ resp, err := httpClient.Do(req)
if err != nil {
logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -97,12 +175,13 @@ func sendMsgToLLM(body io.Reader) {
}
line, err := reader.ReadBytes('\n')
if err != nil {
- logger.Error("error reading response body", "error", err, "line", string(line))
- if err.Error() != "EOF" {
- streamDone <- true
- break
- }
- continue
+ logger.Error("error reading response body", "error", err, "line", string(line),
+ "reqbody", string(bodyBytes), "user_role", cfg.UserRole, "parser", chunkParser, "link", cfg.CurrentAPI)
+ // if err.Error() != "EOF" {
+ streamDone <- true
+ break
+ // }
+ // continue
}
if len(line) <= 1 {
if interruptResp {
@@ -113,9 +192,20 @@ func sendMsgToLLM(body io.Reader) {
// starts with -> data:
line = line[6:]
logger.Debug("debugging resp", "line", string(line))
+ if bytes.Equal(line, []byte("[DONE]\n")) {
+ streamDone <- true
+ break
+ }
content, stop, err = chunkParser.ParseChunk(line)
if err != nil {
- logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
+ logger.Error("error parsing response body", "error", err,
+ "line", string(line), "url", cfg.CurrentAPI)
+ streamDone <- true
+ break
+ }
+ // Handle error messages in response content
+ if string(line) != "" && strings.Contains(strings.ToLower(string(line)), "error") {
+ logger.Error("API error response detected", "line", line, "url", cfg.CurrentAPI)
streamDone <- true
break
}
@@ -183,9 +273,49 @@ func roleToIcon(role string) string {
return "<" + role + ">: "
}
+// FIXME: it should not be here; move to extra
+func checkGame(role string, tv *tview.TextView) {
+ // Handle Cluedo game flow
+ // should go before form msg, since formmsg takes chatBody and makes ioreader out of it
+ // role is almost always user, unless it's regen or resume
+ // cannot get in this block, since cluedoState is nil;
+ if cfg.EnableCluedo {
+ // Initialize Cluedo game if needed
+ if cluedoState == nil {
+ playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2}
+ cluedoState = extra.CluedoPrepCards(playerOrder)
+ }
+ // notifyUser("got in cluedo", "yay")
+ currentPlayer := playerOrder[0]
+ playerOrder = append(playerOrder[1:], currentPlayer) // Rotate turns
+ if role == cfg.UserRole {
+ fmt.Fprintf(tv, "Your (%s) cards: %s\n", currentPlayer, cluedoState.GetPlayerCards(currentPlayer))
+ } else {
+ chatBody.Messages = append(chatBody.Messages, models.RoleMsg{
+ Role: cfg.ToolRole,
+ Content: cluedoState.GetPlayerCards(currentPlayer),
+ })
+ }
+ }
+}
+
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
botRespMode = true
- // reader := formMsg(chatBody, userMsg, role)
+ defer func() { botRespMode = false }()
+ // check that there is a model set to use if is not local
+ if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
+ if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
+ if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
+ logger.Warn("failed ot notify user", "error", err)
+ return
+ }
+ return
+ }
+ }
+ if !resume {
+ checkGame(role, tv)
+ }
+ choseChunkParser()
reader, err := chunkParser.FormMsg(userMsg, role, resume)
if reader == nil || err != nil {
logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -193,7 +323,6 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
}
go sendMsgToLLM(reader)
logger.Debug("looking at vars in chatRound", "msg", userMsg, "regen", regen, "resume", resume)
- // TODO: consider case where user msg is regened (not assistant one)
if !resume {
fmt.Fprintf(tv, "[-:-:b](%d) ", len(chatBody.Messages))
fmt.Fprint(tv, roleToIcon(cfg.AssistantRole))
@@ -211,8 +340,18 @@ out:
fmt.Fprint(tv, chunk)
respText.WriteString(chunk)
tv.ScrollToEnd()
+ // Send chunk to audio stream handler
+ if cfg.TTS_ENABLED {
+ // audioStream.TextChan <- chunk
+ extra.TTSTextChan <- chunk
+ }
case <-streamDone:
botRespMode = false
+ if cfg.TTS_ENABLED {
+ // audioStream.TextChan <- chunk
+ extra.TTSFlushChan <- true
+ logger.Info("sending flushchan signal")
+ }
break out
}
}
@@ -282,13 +421,15 @@ func chatToText(showSys bool) string {
func removeThinking(chatBody *models.ChatBody) {
msgs := []models.RoleMsg{}
for _, msg := range chatBody.Messages {
- rm := models.RoleMsg{}
+ // Filter out tool messages and thinking markers
if msg.Role == cfg.ToolRole {
continue
}
// find thinking and remove it
- rm.Content = thinkRE.ReplaceAllString(msg.Content, "")
- rm.Role = msg.Role
+ rm := models.RoleMsg{
+ Role: msg.Role,
+ Content: thinkRE.ReplaceAllString(msg.Content, ""),
+ }
msgs = append(msgs, rm)
}
chatBody.Messages = msgs
@@ -296,8 +437,15 @@ func removeThinking(chatBody *models.ChatBody) {
func applyCharCard(cc *models.CharCard) {
cfg.AssistantRole = cc.Role
+ // FIXME: remove
+ // Initialize Cluedo if enabled and matching role
+ if cfg.EnableCluedo && cc.Role == "CluedoPlayer" {
+ playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2}
+ cluedoState = extra.CluedoPrepCards(playerOrder)
+ }
history, err := loadAgentsLastChat(cfg.AssistantRole)
if err != nil {
+ // TODO: too much action for err != nil; loadAgentsLastChat needs to be split up
logger.Warn("failed to load last agent chat;", "agent", cc.Role, "err", err)
history = []models.RoleMsg{
{Role: "system", Content: cc.SysPrompt},
@@ -353,6 +501,7 @@ func init() {
//
logLevel.Set(slog.LevelInfo)
logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
+ // TODO: rename and/or put in cfg
store = storage.NewProviderSQL("test.db", logger)
if store == nil {
os.Exit(1)
@@ -366,11 +515,21 @@ func init() {
}
lastChat := loadOldChatOrGetNew()
chatBody = &models.ChatBody{
- Model: "modl_name",
+ Model: "modelname",
Stream: true,
Messages: lastChat,
}
- initChunkParser()
- // go runModelNameTicker(time.Second * 120)
- // tempLoad()
+ // Initialize Cluedo if enabled and matching role
+ if cfg.EnableCluedo && cfg.AssistantRole == "CluedoPlayer" {
+ playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2}
+ cluedoState = extra.CluedoPrepCards(playerOrder)
+ }
+ choseChunkParser()
+ httpClient = createClient(time.Second * 15)
+ if cfg.TTS_ENABLED {
+ orator = extra.NewOrator(logger, cfg)
+ }
+ if cfg.STT_ENABLED {
+ asr = extra.NewWhisperSTT(logger, cfg.STT_URL, 16000)
+ }
}
diff --git a/config.example.toml b/config.example.toml
index 80e3640..229f657 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -8,3 +8,14 @@ ToolRole = "tool"
AssistantRole = "assistant"
SysDir = "sysprompts"
ChunkLimit = 100000
+# rag settings
+RAGBatchSize = 100
+RAGWordLimit = 80
+RAGWorkers = 5
+# extra tts
+TTS_ENABLED = false
+TTS_URL = "http://localhost:8880/v1/audio/speech"
+TTS_SPEED = 1.0
+# extra stt
+STT_ENABLED = false
+STT_URL = "http://localhost:8081/inference"
diff --git a/config/config.go b/config/config.go
index f26a82e..e612aa7 100644
--- a/config/config.go
+++ b/config/config.go
@@ -7,10 +7,13 @@ import (
)
type Config struct {
- ChatAPI string `toml:"ChatAPI"`
- CompletionAPI string `toml:"CompletionAPI"`
- CurrentAPI string
- APIMap map[string]string
+ EnableCluedo bool `toml:"EnableCluedo"` // Cluedo game mode toggle
+ CluedoRole2 string `toml:"CluedoRole2"` // Secondary AI role name
+ ChatAPI string `toml:"ChatAPI"`
+ CompletionAPI string `toml:"CompletionAPI"`
+ CurrentAPI string
+ CurrentProvider string
+ APIMap map[string]string
//
ShowSys bool `toml:"ShowSys"`
LogFile string `toml:"LogFile"`
@@ -26,6 +29,23 @@ type Config struct {
EmbedURL string `toml:"EmbedURL"`
HFToken string `toml:"HFToken"`
RAGDir string `toml:"RAGDir"`
+ // rag settings
+ RAGWorkers uint32 `toml:"RAGWorkers"`
+ RAGBatchSize int `toml:"RAGBatchSize"`
+ RAGWordLimit uint32 `toml:"RAGWordLimit"`
+ // deepseek
+ DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
+ DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
+ DeepSeekToken string `toml:"DeepSeekToken"`
+ DeepSeekModel string `toml:"DeepSeekModel"`
+ ApiLinks []string
+ // TTS
+ TTS_URL string `toml:"TTS_URL"`
+ TTS_ENABLED bool `toml:"TTS_ENABLED"`
+ TTS_SPEED float32 `toml:"TTS_SPEED"`
+ // STT
+ STT_URL string `toml:"STT_URL"`
+ STT_ENABLED bool `toml:"STT_ENABLED"`
}
func LoadConfigOrDefault(fn string) *Config {
@@ -35,9 +55,11 @@ func LoadConfigOrDefault(fn string) *Config {
config := &Config{}
_, err := toml.DecodeFile(fn, &config)
if err != nil {
- fmt.Println("failed to read config from file, loading default")
+ fmt.Println("failed to read config from file, loading default", "error", err)
config.ChatAPI = "http://localhost:8080/v1/chat/completions"
config.CompletionAPI = "http://localhost:8080/completion"
+ config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions"
+ config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions"
config.RAGEnabled = false
config.EmbedURL = "http://localhost:8080/v1/embiddings"
config.ShowSys = true
@@ -47,15 +69,24 @@ func LoadConfigOrDefault(fn string) *Config {
config.AssistantRole = "assistant"
config.SysDir = "sysprompts"
config.ChunkLimit = 8192
+ //
+ config.RAGBatchSize = 100
+ config.RAGWordLimit = 80
+ config.RAGWorkers = 5
+ // tts
+ config.TTS_ENABLED = false
+ config.TTS_URL = "http://localhost:8880/v1/audio/speech"
}
config.CurrentAPI = config.ChatAPI
config.APIMap = map[string]string{
- config.ChatAPI: config.CompletionAPI,
+ config.ChatAPI: config.CompletionAPI,
+ config.CompletionAPI: config.DeepSeekChatAPI,
+ config.DeepSeekChatAPI: config.DeepSeekCompletionAPI,
+ config.DeepSeekCompletionAPI: config.ChatAPI,
}
- if config.CompletionAPI != "" {
- config.CurrentAPI = config.CompletionAPI
- config.APIMap = map[string]string{
- config.CompletionAPI: config.ChatAPI,
+ for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} {
+ if el != "" {
+ config.ApiLinks = append(config.ApiLinks, el)
}
}
// if any value is empty fill with default
diff --git a/extra/cluedo.go b/extra/cluedo.go
new file mode 100644
index 0000000..1ef11cc
--- /dev/null
+++ b/extra/cluedo.go
@@ -0,0 +1,73 @@
+package extra
+
+import (
+ "math/rand"
+ "strings"
+)
+
+var (
+ rooms = []string{"HALL", "LOUNGE", "DINING ROOM", "KITCHEN", "BALLROOM", "CONSERVATORY", "BILLIARD ROOM", "LIBRARY", "STUDY"}
+ weapons = []string{"CANDLESTICK", "DAGGER", "LEAD PIPE", "REVOLVER", "ROPE", "SPANNER"}
+ people = []string{"Miss Scarlett", "Colonel Mustard", "Mrs. White", "Reverend Green", "Mrs. Peacock", "Professor Plum"}
+)
+
+type MurderTrifecta struct {
+ Murderer string
+ Weapon string
+ Room string
+}
+
+type CluedoRoundInfo struct {
+ Answer MurderTrifecta
+ PlayersCards map[string][]string
+}
+
+func (c *CluedoRoundInfo) GetPlayerCards(player string) string {
+ // maybe format it a little
+ return "cards of " + player + "are " + strings.Join(c.PlayersCards[player], ",")
+}
+
+func CluedoPrepCards(playerOrder []string) *CluedoRoundInfo {
+ res := &CluedoRoundInfo{}
+ // Select murder components
+ trifecta := MurderTrifecta{
+ Murderer: people[rand.Intn(len(people))],
+ Weapon: weapons[rand.Intn(len(weapons))],
+ Room: rooms[rand.Intn(len(rooms))],
+ }
+ // Collect non-murder cards
+ var notInvolved []string
+ for _, room := range rooms {
+ if room != trifecta.Room {
+ notInvolved = append(notInvolved, room)
+ }
+ }
+ for _, weapon := range weapons {
+ if weapon != trifecta.Weapon {
+ notInvolved = append(notInvolved, weapon)
+ }
+ }
+ for _, person := range people {
+ if person != trifecta.Murderer {
+ notInvolved = append(notInvolved, person)
+ }
+ }
+ // Shuffle and distribute cards
+ rand.Shuffle(len(notInvolved), func(i, j int) {
+ notInvolved[i], notInvolved[j] = notInvolved[j], notInvolved[i]
+ })
+ players := map[string][]string{}
+ cardsPerPlayer := len(notInvolved) / len(playerOrder)
+ // playerOrder := []string{"{{user}}", "{{char}}", "{{char2}}"}
+ for i, player := range playerOrder {
+ start := i * cardsPerPlayer
+ end := (i + 1) * cardsPerPlayer
+ if end > len(notInvolved) {
+ end = len(notInvolved)
+ }
+ players[player] = notInvolved[start:end]
+ }
+ res.Answer = trifecta
+ res.PlayersCards = players
+ return res
+}
diff --git a/extra/cluedo_test.go b/extra/cluedo_test.go
new file mode 100644
index 0000000..e7a53b1
--- /dev/null
+++ b/extra/cluedo_test.go
@@ -0,0 +1,50 @@
+package extra
+
+import (
+ "testing"
+)
+
+func TestPrepCards(t *testing.T) {
+ // Run the function to get the murder combination and player cards
+ roundInfo := CluedoPrepCards([]string{"{{user}}", "{{char}}", "{{char2}}"})
+ // Create a map to track all distributed cards
+ distributedCards := make(map[string]bool)
+ // Check that the murder combination cards are not distributed to players
+ murderCards := []string{roundInfo.Answer.Murderer, roundInfo.Answer.Weapon, roundInfo.Answer.Room}
+ for _, card := range murderCards {
+ if distributedCards[card] {
+ t.Errorf("Murder card %s was distributed to a player", card)
+ }
+ }
+ // Check each player's cards
+ for player, cards := range roundInfo.PlayersCards {
+ for _, card := range cards {
+ // Ensure the card is not part of the murder combination
+ for _, murderCard := range murderCards {
+ if card == murderCard {
+ t.Errorf("Player %s has a murder card: %s", player, card)
+ }
+ }
+ // Ensure the card is unique and not already distributed
+ if distributedCards[card] {
+ t.Errorf("Card %s is duplicated in player %s's hand", card, player)
+ }
+ distributedCards[card] = true
+ }
+ }
+ // Verify that all non-murder cards are distributed
+ allCards := append(append([]string{}, rooms...), weapons...)
+ allCards = append(allCards, people...)
+ for _, card := range allCards {
+ isMurderCard := false
+ for _, murderCard := range murderCards {
+ if card == murderCard {
+ isMurderCard = true
+ break
+ }
+ }
+ if !isMurderCard && !distributedCards[card] {
+ t.Errorf("Card %s was not distributed to any player", card)
+ }
+ }
+}
diff --git a/extra/stt.go b/extra/stt.go
new file mode 100644
index 0000000..ce107b4
--- /dev/null
+++ b/extra/stt.go
@@ -0,0 +1,166 @@
+package extra
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "mime/multipart"
+ "net/http"
+ "regexp"
+ "strings"
+
+ "github.com/gordonklaus/portaudio"
+)
+
+var specialRE = regexp.MustCompile(`\[.*?\]`)
+
+type STT interface {
+ StartRecording() error
+ StopRecording() (string, error)
+ IsRecording() bool
+}
+
+type StreamCloser interface {
+ Close() error
+}
+
+type WhisperSTT struct {
+ logger *slog.Logger
+ ServerURL string
+ SampleRate int
+ AudioBuffer *bytes.Buffer
+ recording bool
+}
+
+func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT {
+ return &WhisperSTT{
+ logger: logger,
+ ServerURL: serverURL,
+ SampleRate: sampleRate,
+ AudioBuffer: new(bytes.Buffer),
+ }
+}
+
+func (stt *WhisperSTT) StartRecording() error {
+ if err := stt.microphoneStream(stt.SampleRate); err != nil {
+ return fmt.Errorf("failed to init microphone: %w", err)
+ }
+ stt.recording = true
+ return nil
+}
+
+func (stt *WhisperSTT) StopRecording() (string, error) {
+ stt.recording = false
+ // wait loop to finish?
+ if stt.AudioBuffer == nil {
+ err := errors.New("unexpected nil AudioBuffer")
+ stt.logger.Error(err.Error())
+ return "", err
+ }
+ // Create WAV header first
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ // Add audio file part
+ part, err := writer.CreateFormFile("file", "recording.wav")
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Stream directly to multipart writer: header + raw data
+ dataSize := stt.AudioBuffer.Len()
+ stt.writeWavHeader(part, dataSize)
+ if _, err := io.Copy(part, stt.AudioBuffer); err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Reset buffer for next recording
+ stt.AudioBuffer.Reset()
+ // Add response format field
+ err = writer.WriteField("response_format", "text")
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ if writer.Close() != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Send request
+ resp, err := http.Post(stt.ServerURL, writer.FormDataContentType(), body) //nolint:noctx
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ defer resp.Body.Close()
+ // Read and print response
+ responseTextBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ resptext := strings.TrimRight(string(responseTextBytes), "\n")
+ // in case there are special tokens like [_BEG_]
+ resptext = specialRE.ReplaceAllString(resptext, "")
+ return strings.TrimSpace(strings.ReplaceAll(resptext, "\n ", "\n")), nil
+}
+
+func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) {
+ header := make([]byte, 44)
+ copy(header[0:4], "RIFF")
+ binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize))
+ copy(header[8:12], "WAVE")
+ copy(header[12:16], "fmt ")
+ binary.LittleEndian.PutUint32(header[16:20], 16)
+ binary.LittleEndian.PutUint16(header[20:22], 1)
+ binary.LittleEndian.PutUint16(header[22:24], 1)
+ binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate))
+ binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8))
+ binary.LittleEndian.PutUint16(header[32:34], 1*(16/8))
+ binary.LittleEndian.PutUint16(header[34:36], 16)
+ copy(header[36:40], "data")
+ binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize))
+ if _, err := w.Write(header); err != nil {
+ stt.logger.Error("writeWavHeader", "error", err)
+ }
+}
+
+func (stt *WhisperSTT) IsRecording() bool {
+ return stt.recording
+}
+
+func (stt *WhisperSTT) microphoneStream(sampleRate int) error {
+ if err := portaudio.Initialize(); err != nil {
+ return fmt.Errorf("portaudio init failed: %w", err)
+ }
+ in := make([]int16, 64)
+ stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in)
+ if err != nil {
+ if paErr := portaudio.Terminate(); paErr != nil {
+ return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr)
+ }
+ return fmt.Errorf("failed to open microphone: %w", err)
+ }
+ go func(stream *portaudio.Stream) {
+ if err := stream.Start(); err != nil {
+ stt.logger.Error("microphoneStream", "error", err)
+ return
+ }
+ for {
+ if !stt.IsRecording() {
+ return
+ }
+ if err := stream.Read(); err != nil {
+ stt.logger.Error("reading stream", "error", err)
+ return
+ }
+ if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil {
+ stt.logger.Error("writing to buffer", "error", err)
+ return
+ }
+ }
+ }(stream)
+ return nil
+}
diff --git a/extra/tts.go b/extra/tts.go
new file mode 100644
index 0000000..31e6887
--- /dev/null
+++ b/extra/tts.go
@@ -0,0 +1,212 @@
+package extra
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "gf-lt/config"
+ "gf-lt/models"
+ "io"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gopxl/beep/v2"
+ "github.com/gopxl/beep/v2/mp3"
+ "github.com/gopxl/beep/v2/speaker"
+ "github.com/neurosnap/sentences/english"
+)
+
+var (
+ TTSTextChan = make(chan string, 10000)
+ TTSFlushChan = make(chan bool, 1)
+ TTSDoneChan = make(chan bool, 1)
+ // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`)
+)
+
+type Orator interface {
+ Speak(text string) error
+ Stop()
+ // pause and resume?
+ GetLogger() *slog.Logger
+}
+
+// impl https://github.com/remsky/Kokoro-FastAPI
+type KokoroOrator struct {
+ logger *slog.Logger
+ URL string
+ Format models.AudioFormat
+ Stream bool
+ Speed float32
+ Language string
+ Voice string
+ currentStream *beep.Ctrl // Added for playback control
+ textBuffer strings.Builder
+ // textBuffer bytes.Buffer
+}
+
+func (o *KokoroOrator) stoproutine() {
+ <-TTSDoneChan
+ o.logger.Info("orator got done signal")
+ o.Stop()
+ // drain the channel
+ for len(TTSTextChan) > 0 {
+ <-TTSTextChan
+ }
+}
+
+func (o *KokoroOrator) readroutine() {
+ tokenizer, _ := english.NewSentenceTokenizer(nil)
+ // var sentenceBuf bytes.Buffer
+ // var remainder strings.Builder
+ for {
+ select {
+ case chunk := <-TTSTextChan:
+ // sentenceBuf.WriteString(chunk)
+ // text := sentenceBuf.String()
+ _, err := o.textBuffer.WriteString(chunk)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ text := o.textBuffer.String()
+ sentences := tokenizer.Tokenize(text)
+ o.logger.Info("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences))
+ for i, sentence := range sentences {
+ if i == len(sentences)-1 { // last sentence
+ o.textBuffer.Reset()
+ _, err := o.textBuffer.WriteString(sentence.Text)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ continue // if only one (often incomplete) sentence; wait for next chunk
+ }
+ o.logger.Info("calling Speak with sentence", "sent", sentence.Text)
+ if err := o.Speak(sentence.Text); err != nil {
+ o.logger.Error("tts failed", "sentence", sentence.Text, "error", err)
+ }
+ }
+ case <-TTSFlushChan:
+ o.logger.Info("got flushchan signal start")
+ // lln is done get the whole message out
+ if len(TTSTextChan) > 0 { // otherwise might get stuck
+ for chunk := range TTSTextChan {
+ _, err := o.textBuffer.WriteString(chunk)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ if len(TTSTextChan) == 0 {
+ break
+ }
+ }
+ }
+ // INFO: if there is a lot of text it will take some time to make with tts at once
+ // to avoid this pause, it might be better to keep splitting on sentences
+ // but keepinig in mind that remainder could be ommited by tokenizer
+ // Flush remaining text
+ remaining := o.textBuffer.String()
+ o.textBuffer.Reset()
+ if remaining != "" {
+ o.logger.Info("calling Speak with remainder", "rem", remaining)
+ if err := o.Speak(remaining); err != nil {
+ o.logger.Error("tts failed", "sentence", remaining, "error", err)
+ }
+ }
+ }
+ }
+}
+
+func NewOrator(log *slog.Logger, cfg *config.Config) Orator {
+ orator := &KokoroOrator{
+ logger: log,
+ URL: cfg.TTS_URL,
+ Format: models.AFMP3,
+ Stream: false,
+ Speed: cfg.TTS_SPEED,
+ Language: "a",
+ Voice: "af_bella(1)+af_sky(1)",
+ }
+ go orator.readroutine()
+ go orator.stoproutine()
+ return orator
+}
+
+func (o *KokoroOrator) GetLogger() *slog.Logger {
+ return o.logger
+}
+
+func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) {
+ payload := map[string]interface{}{
+ "input": text,
+ "voice": o.Voice,
+ "response_format": o.Format,
+ "download_format": o.Format,
+ "stream": o.Stream,
+ "speed": o.Speed,
+ // "return_download_link": true,
+ "lang_code": o.Language,
+ }
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal payload: %w", err)
+ }
+ req, err := http.NewRequest("POST", o.URL, bytes.NewBuffer(payloadBytes)) //nolint:noctx
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("accept", "application/json")
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %w", err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+ return resp.Body, nil
+}
+
+func (o *KokoroOrator) Speak(text string) error {
+ o.logger.Info("fn: Speak is called", "text-len", len(text))
+ body, err := o.requestSound(text)
+ if err != nil {
+ o.logger.Error("request failed", "error", err)
+ return fmt.Errorf("request failed: %w", err)
+ }
+ defer body.Close()
+ // Decode the mp3 audio from response body
+ streamer, format, err := mp3.Decode(body)
+ if err != nil {
+ o.logger.Error("mp3 decode failed", "error", err)
+ return fmt.Errorf("mp3 decode failed: %w", err)
+ }
+ defer streamer.Close()
+ // here it spams with errors that speaker cannot be initialized more than once, but how would we deal with many audio records then?
+ if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
+ o.logger.Debug("failed to init speaker", "error", err)
+ }
+ done := make(chan bool)
+ // Create controllable stream and store reference
+ o.currentStream = &beep.Ctrl{Streamer: beep.Seq(streamer, beep.Callback(func() {
+ close(done)
+ o.currentStream = nil
+ })), Paused: false}
+ speaker.Play(o.currentStream)
+ <-done // we hang in this routine;
+ return nil
+}
+
+func (o *KokoroOrator) Stop() {
+ // speaker.Clear()
+ o.logger.Info("attempted to stop orator", "orator", o)
+ speaker.Lock()
+ defer speaker.Unlock()
+ if o.currentStream != nil {
+ // o.currentStream.Paused = true
+ o.currentStream.Streamer = nil
+ }
+}
diff --git a/extra/twentyq.go b/extra/twentyq.go
new file mode 100644
index 0000000..30c08cc
--- /dev/null
+++ b/extra/twentyq.go
@@ -0,0 +1,11 @@
+package extra
+
+import "math/rand"
+
+var (
+ chars = []string{"Shrek", "Garfield", "Jack the Ripper"}
+)
+
+func GetRandomChar() string {
+ return chars[rand.Intn(len(chars))]
+}
diff --git a/extra/vad.go b/extra/vad.go
new file mode 100644
index 0000000..2a9e238
--- /dev/null
+++ b/extra/vad.go
@@ -0,0 +1 @@
+package extra
diff --git a/go.mod b/go.mod
index eeaa6f6..cc1e743 100644
--- a/go.mod
+++ b/go.mod
@@ -1,4 +1,4 @@
-module elefant
+module gf-lt
go 1.23.2
@@ -7,6 +7,8 @@ require (
github.com/asg017/sqlite-vec-go-bindings v0.1.6
github.com/gdamore/tcell/v2 v2.7.4
github.com/glebarez/go-sqlite v1.22.0
+ github.com/gopxl/beep/v2 v2.1.0
+ github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/jmoiron/sqlx v1.4.0
github.com/ncruces/go-sqlite3 v0.21.3
github.com/neurosnap/sentences v1.1.2
@@ -15,12 +17,16 @@ require (
require (
github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/ebitengine/oto/v3 v3.2.0 // indirect
+ github.com/ebitengine/purego v0.7.1 // indirect
github.com/gdamore/encoding v1.0.0 // indirect
github.com/google/uuid v1.6.0 // indirect
+ github.com/hajimehoshi/go-mp3 v0.3.4 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
+ github.com/pkg/errors v0.9.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
diff --git a/go.sum b/go.sum
index fe84d96..1fffadf 100644
--- a/go.sum
+++ b/go.sum
@@ -4,8 +4,14 @@ github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww=
github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/ebitengine/oto/v3 v3.2.0 h1:FuggTJTSI3/3hEYwZEIN0CZVXYT29ZOdCu+z/f4QjTw=
+github.com/ebitengine/oto/v3 v3.2.0/go.mod h1:dOKXShvy1EQbIXhXPFcKLargdnFqH0RjptecvyAxhyw=
+github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA=
+github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko=
github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg=
github.com/gdamore/tcell/v2 v2.7.4 h1:sg6/UnTM9jGpZU+oFYAsDahfchWAFW8Xx2yFinNSAYU=
@@ -18,6 +24,13 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gopxl/beep/v2 v2.1.0 h1:Jv95iHw3aNWoAa/J78YyXvOvMHH2ZGeAYD5ug8tVt8c=
+github.com/gopxl/beep/v2 v2.1.0/go.mod h1:sQvj2oSsu8fmmDWH3t0DzIe0OZzTW6/TJEHW4Ku+22o=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
+github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68=
+github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo=
+github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -36,6 +49,10 @@ github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw=
github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 h1:YIJ+B1hePP6AgynC5TcqpO0H9k3SSoZa2BGyL6vDUzM=
@@ -44,6 +61,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
@@ -62,6 +81,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -85,6 +105,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/libc v1.37.6 h1:orZH3c5wmhIQFTXF+Nt+eeauyd+ZIt2BX6ARe+kD+aw=
modernc.org/libc v1.37.6/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
diff --git a/llm.go b/llm.go
index 85f9d51..3307467 100644
--- a/llm.go
+++ b/llm.go
@@ -2,7 +2,7 @@ package main
import (
"bytes"
- "elefant/models"
+ "gf-lt/models"
"encoding/json"
"io"
"strings"
@@ -13,22 +13,47 @@ type ChunkParser interface {
FormMsg(msg, role string, cont bool) (io.Reader, error)
}
-func initChunkParser() {
+func choseChunkParser() {
chunkParser = LlamaCPPeer{}
- if strings.Contains(cfg.CurrentAPI, "v1") {
- logger.Debug("chosen /v1/chat parser")
+ switch cfg.CurrentAPI {
+ case "http://localhost:8080/completion":
+ chunkParser = LlamaCPPeer{}
+ logger.Debug("chosen llamacppeer", "link", cfg.CurrentAPI)
+ return
+ case "http://localhost:8080/v1/chat/completions":
chunkParser = OpenAIer{}
+ logger.Debug("chosen openair", "link", cfg.CurrentAPI)
+ return
+ case "https://api.deepseek.com/beta/completions":
+ chunkParser = DeepSeekerCompletion{}
+ logger.Debug("chosen deepseekercompletio", "link", cfg.CurrentAPI)
+ return
+ case "https://api.deepseek.com/chat/completions":
+ chunkParser = DeepSeekerChat{}
+ logger.Debug("chosen deepseekerchat", "link", cfg.CurrentAPI)
return
+ default:
+ chunkParser = LlamaCPPeer{}
}
- logger.Debug("chosen llamacpp /completion parser")
+ // if strings.Contains(cfg.CurrentAPI, "chat") {
+ // logger.Debug("chosen chat parser")
+ // chunkParser = OpenAIer{}
+ // return
+ // }
+ // logger.Debug("chosen llamacpp /completion parser")
}
type LlamaCPPeer struct {
}
type OpenAIer struct {
}
+type DeepSeekerCompletion struct {
+}
+type DeepSeekerChat struct {
+}
func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
+ logger.Debug("formmsg llamacppeer", "link", cfg.CurrentAPI)
if msg != "" { // otherwise let the bot to continue
newMsg := models.RoleMsg{Role: role, Content: msg}
chatBody.Messages = append(chatBody.Messages, newMsg)
@@ -53,18 +78,21 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
}
prompt := strings.Join(messages, "\n")
// strings builder?
- // if cfg.ToolUse && msg != "" && !resume {
if !resume {
botMsgStart := "\n" + cfg.AssistantRole + ":\n"
prompt += botMsgStart
}
- // if cfg.ThinkUse && msg != "" && !cfg.ToolUse {
if cfg.ThinkUse && !cfg.ToolUse {
prompt += "<think>"
}
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt)
- payload := models.NewLCPReq(prompt, cfg, defaultLCPProps)
+ var payload any
+ payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
+ if strings.Contains(chatBody.Model, "deepseek") {
+ payload = models.NewDSCompletionReq(prompt, chatBody.Model,
+ defaultLCPProps["temp"], cfg)
+ }
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
@@ -105,6 +133,7 @@ func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) {
}
func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
+ logger.Debug("formmsg openaier", "link", cfg.CurrentAPI)
if cfg.ToolUse && !resume {
// prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg
// add to chat body
@@ -131,3 +160,129 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
}
return bytes.NewReader(data), nil
}
+
+// deepseek
+func (ds DeepSeekerCompletion) ParseChunk(data []byte) (string, bool, error) {
+ llmchunk := models.DSCompletionResp{}
+ if err := json.Unmarshal(data, &llmchunk); err != nil {
+ logger.Error("failed to decode", "error", err, "line", string(data))
+ return "", false, err
+ }
+ if llmchunk.Choices[0].FinishReason != "" {
+ if llmchunk.Choices[0].Text != "" {
+ logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
+ }
+ return llmchunk.Choices[0].Text, true, nil
+ }
+ return llmchunk.Choices[0].Text, false, nil
+}
+
+func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) {
+ logger.Debug("formmsg deepseekercompletion", "link", cfg.CurrentAPI)
+ if msg != "" { // otherwise let the bot to continue
+ newMsg := models.RoleMsg{Role: role, Content: msg}
+ chatBody.Messages = append(chatBody.Messages, newMsg)
+ // if rag
+ if cfg.RAGEnabled {
+ ragResp, err := chatRagUse(newMsg.Content)
+ if err != nil {
+ logger.Error("failed to form a rag msg", "error", err)
+ return nil, err
+ }
+ ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
+ chatBody.Messages = append(chatBody.Messages, ragMsg)
+ }
+ }
+ if cfg.ToolUse && !resume {
+ // add to chat body
+ chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
+ }
+ messages := make([]string, len(chatBody.Messages))
+ for i, m := range chatBody.Messages {
+ messages[i] = m.ToPrompt()
+ }
+ prompt := strings.Join(messages, "\n")
+ // strings builder?
+ if !resume {
+ botMsgStart := "\n" + cfg.AssistantRole + ":\n"
+ prompt += botMsgStart
+ }
+ if cfg.ThinkUse && !cfg.ToolUse {
+ prompt += "<think>"
+ }
+ logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
+ "msg", msg, "resume", resume, "prompt", prompt)
+ payload := models.NewDSCompletionReq(prompt, chatBody.Model,
+ defaultLCPProps["temp"], cfg)
+ data, err := json.Marshal(payload)
+ if err != nil {
+ logger.Error("failed to form a msg", "error", err)
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+func (ds DeepSeekerChat) ParseChunk(data []byte) (string, bool, error) {
+ llmchunk := models.DSChatStreamResp{}
+ if err := json.Unmarshal(data, &llmchunk); err != nil {
+ logger.Error("failed to decode", "error", err, "line", string(data))
+ return "", false, err
+ }
+ if llmchunk.Choices[0].FinishReason != "" {
+ if llmchunk.Choices[0].Delta.Content != "" {
+ logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
+ }
+ return llmchunk.Choices[0].Delta.Content, true, nil
+ }
+ if llmchunk.Choices[0].Delta.ReasoningContent != "" {
+ return llmchunk.Choices[0].Delta.ReasoningContent, false, nil
+ }
+ return llmchunk.Choices[0].Delta.Content, false, nil
+}
+
+func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
+ logger.Debug("formmsg deepseekerchat", "link", cfg.CurrentAPI)
+ if cfg.ToolUse && !resume {
+ // prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg
+ // add to chat body
+ chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
+ }
+ if msg != "" { // otherwise let the bot continue
+ newMsg := models.RoleMsg{Role: role, Content: msg}
+ chatBody.Messages = append(chatBody.Messages, newMsg)
+ // if rag
+ if cfg.RAGEnabled {
+ ragResp, err := chatRagUse(newMsg.Content)
+ if err != nil {
+ logger.Error("failed to form a rag msg", "error", err)
+ return nil, err
+ }
+ ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
+ chatBody.Messages = append(chatBody.Messages, ragMsg)
+ }
+ }
+ // Create copy of chat body with standardized user role
+ // modifiedBody := *chatBody
+ bodyCopy := &models.ChatBody{
+ Messages: make([]models.RoleMsg, len(chatBody.Messages)),
+ Model: chatBody.Model,
+ Stream: chatBody.Stream,
+ }
+ // modifiedBody.Messages = make([]models.RoleMsg, len(chatBody.Messages))
+ for i, msg := range chatBody.Messages {
+ logger.Debug("checking roles", "#", i, "role", msg.Role)
+ if msg.Role == cfg.UserRole || i == 1 {
+ bodyCopy.Messages[i].Role = "user"
+ logger.Debug("replaced role in body", "#", i)
+ } else {
+ bodyCopy.Messages[i] = msg
+ }
+ }
+ dsBody := models.NewDSCharReq(*bodyCopy)
+ data, err := json.Marshal(dsBody)
+ if err != nil {
+ logger.Error("failed to form a msg", "error", err)
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
diff --git a/main.go b/main.go
index 73275e8..c73cf3c 100644
--- a/main.go
+++ b/main.go
@@ -12,7 +12,7 @@ var (
botRespMode = false
editMode = false
selectedIndex = int(-1)
- indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | RAGEnabled: [orange:-:b]%v[-:-:-] (F11) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p)"
+ indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r)"
focusSwitcher = map[tview.Primitive]tview.Primitive{}
)
diff --git a/main_test.go b/main_test.go
index 0046ca2..fb0a774 100644
--- a/main_test.go
+++ b/main_test.go
@@ -1,7 +1,7 @@
package main
import (
- "elefant/models"
+ "gf-lt/models"
"fmt"
"strings"
"testing"
diff --git a/models/extra.go b/models/extra.go
new file mode 100644
index 0000000..e1ca80f
--- /dev/null
+++ b/models/extra.go
@@ -0,0 +1,8 @@
+package models
+
+type AudioFormat string
+
+const (
+ AFWav AudioFormat = "wav"
+ AFMP3 AudioFormat = "mp3"
+)
diff --git a/models/models.go b/models/models.go
index bb61abf..918e35e 100644
--- a/models/models.go
+++ b/models/models.go
@@ -1,8 +1,8 @@
package models
import (
- "elefant/config"
"fmt"
+ "gf-lt/config"
"strings"
)
@@ -76,6 +76,27 @@ type ChatBody struct {
Messages []RoleMsg `json:"messages"`
}
+func (cb *ChatBody) Rename(oldname, newname string) {
+ for i, m := range cb.Messages {
+ cb.Messages[i].Content = strings.ReplaceAll(m.Content, oldname, newname)
+ cb.Messages[i].Role = strings.ReplaceAll(m.Role, oldname, newname)
+ }
+}
+
+func (cb *ChatBody) ListRoles() []string {
+ namesMap := make(map[string]struct{})
+ for _, m := range cb.Messages {
+ namesMap[m.Role] = struct{}{}
+ }
+ resp := make([]string, len(namesMap))
+ i := 0
+ for k := range namesMap {
+ resp[i] = k
+ i++
+ }
+ return resp
+}
+
type ChatToolsBody struct {
Model string `json:"model"`
Messages []RoleMsg `json:"messages"`
@@ -103,6 +124,143 @@ type ChatToolsBody struct {
ToolChoice string `json:"tool_choice"`
}
+type DSChatReq struct {
+ Messages []RoleMsg `json:"messages"`
+ Model string `json:"model"`
+ Stream bool `json:"stream"`
+ FrequencyPenalty int `json:"frequency_penalty"`
+ MaxTokens int `json:"max_tokens"`
+ PresencePenalty int `json:"presence_penalty"`
+ Temperature float32 `json:"temperature"`
+ TopP float32 `json:"top_p"`
+ // ResponseFormat struct {
+ // Type string `json:"type"`
+ // } `json:"response_format"`
+ // Stop any `json:"stop"`
+ // StreamOptions any `json:"stream_options"`
+ // Tools any `json:"tools"`
+ // ToolChoice string `json:"tool_choice"`
+ // Logprobs bool `json:"logprobs"`
+ // TopLogprobs any `json:"top_logprobs"`
+}
+
+func NewDSCharReq(cb ChatBody) DSChatReq {
+ return DSChatReq{
+ Messages: cb.Messages,
+ Model: cb.Model,
+ Stream: cb.Stream,
+ MaxTokens: 2048,
+ PresencePenalty: 0,
+ FrequencyPenalty: 0,
+ Temperature: 1.0,
+ TopP: 1.0,
+ }
+}
+
+type DSCompletionReq struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Echo bool `json:"echo"`
+ FrequencyPenalty int `json:"frequency_penalty"`
+ // Logprobs int `json:"logprobs"`
+ MaxTokens int `json:"max_tokens"`
+ PresencePenalty int `json:"presence_penalty"`
+ Stop any `json:"stop"`
+ Stream bool `json:"stream"`
+ StreamOptions any `json:"stream_options"`
+ Suffix any `json:"suffix"`
+ Temperature float32 `json:"temperature"`
+ TopP float32 `json:"top_p"`
+}
+
+func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
+ return DSCompletionReq{
+ Model: model,
+ Prompt: prompt,
+ Temperature: temp,
+ Stream: true,
+ Echo: false,
+ MaxTokens: 2048,
+ PresencePenalty: 0,
+ FrequencyPenalty: 0,
+ TopP: 1.0,
+ Stop: []string{
+ cfg.UserRole + ":\n", "<|im_end|>",
+ cfg.ToolRole + ":\n",
+ cfg.AssistantRole + ":\n",
+ },
+ }
+}
+
+type DSCompletionResp struct {
+ ID string `json:"id"`
+ Choices []struct {
+ FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ Logprobs struct {
+ TextOffset []int `json:"text_offset"`
+ TokenLogprobs []int `json:"token_logprobs"`
+ Tokens []string `json:"tokens"`
+ TopLogprobs []struct {
+ } `json:"top_logprobs"`
+ } `json:"logprobs"`
+ Text string `json:"text"`
+ } `json:"choices"`
+ Created int `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Object string `json:"object"`
+ Usage struct {
+ CompletionTokens int `json:"completion_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ PromptCacheHitTokens int `json:"prompt_cache_hit_tokens"`
+ PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ CompletionTokensDetails struct {
+ ReasoningTokens int `json:"reasoning_tokens"`
+ } `json:"completion_tokens_details"`
+ } `json:"usage"`
+}
+
+type DSChatResp struct {
+ Choices []struct {
+ Delta struct {
+ Content string `json:"content"`
+ Role any `json:"role"`
+ } `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ Logprobs any `json:"logprobs"`
+ } `json:"choices"`
+ Created int `json:"created"`
+ ID string `json:"id"`
+ Model string `json:"model"`
+ Object string `json:"object"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Usage struct {
+ CompletionTokens int `json:"completion_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ } `json:"usage"`
+}
+
+type DSChatStreamResp struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []struct {
+ Index int `json:"index"`
+ Delta struct {
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content"`
+ } `json:"delta"`
+ Logprobs any `json:"logprobs"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+}
+
type EmbeddingResp struct {
Embedding []float32 `json:"embedding"`
Index uint32 `json:"index"`
@@ -190,3 +348,13 @@ type LlamaCPPResp struct {
Content string `json:"content"`
Stop bool `json:"stop"`
}
+
+type DSBalance struct {
+ IsAvailable bool `json:"is_available"`
+ BalanceInfos []struct {
+ Currency string `json:"currency"`
+ TotalBalance string `json:"total_balance"`
+ GrantedBalance string `json:"granted_balance"`
+ ToppedUpBalance string `json:"topped_up_balance"`
+ } `json:"balance_infos"`
+}
diff --git a/pngmeta/altwriter.go b/pngmeta/altwriter.go
new file mode 100644
index 0000000..206b563
--- /dev/null
+++ b/pngmeta/altwriter.go
@@ -0,0 +1,133 @@
+package pngmeta
+
+import (
+ "bytes"
+ "gf-lt/models"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "io"
+ "os"
+)
+
+const (
+ pngHeader = "\x89PNG\r\n\x1a\n"
+ textChunkType = "tEXt"
+)
+
+// WriteToPng embeds the metadata into the specified PNG file and writes the result to outfile.
+func WriteToPng(metadata *models.CharCardSpec, sourcePath, outfile string) error {
+ pngData, err := os.ReadFile(sourcePath)
+ if err != nil {
+ return err
+ }
+ jsonData, err := json.Marshal(metadata)
+ if err != nil {
+ return err
+ }
+ base64Data := base64.StdEncoding.EncodeToString(jsonData)
+ embedData := PngEmbed{
+ Key: "gf-lt", // Replace with appropriate key constant
+ Value: base64Data,
+ }
+ var outputBuffer bytes.Buffer
+ if _, err := outputBuffer.Write([]byte(pngHeader)); err != nil {
+ return err
+ }
+ chunks, iend, err := processChunks(pngData[8:])
+ if err != nil {
+ return err
+ }
+ for _, chunk := range chunks {
+ outputBuffer.Write(chunk)
+ }
+ newChunk, err := createTextChunk(embedData)
+ if err != nil {
+ return err
+ }
+ outputBuffer.Write(newChunk)
+ outputBuffer.Write(iend)
+ return os.WriteFile(outfile, outputBuffer.Bytes(), 0666)
+}
+
+// processChunks extracts non-tEXt chunks and locates the IEND chunk
+func processChunks(data []byte) ([][]byte, []byte, error) {
+ var (
+ chunks [][]byte
+ iendChunk []byte
+ reader = bytes.NewReader(data)
+ )
+ for {
+ var chunkLength uint32
+ if err := binary.Read(reader, binary.BigEndian, &chunkLength); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ return nil, nil, fmt.Errorf("error reading chunk length: %w", err)
+ }
+ chunkType := make([]byte, 4)
+ if _, err := reader.Read(chunkType); err != nil {
+ return nil, nil, fmt.Errorf("error reading chunk type: %w", err)
+ }
+ chunkData := make([]byte, chunkLength)
+ if _, err := reader.Read(chunkData); err != nil {
+ return nil, nil, fmt.Errorf("error reading chunk data: %w", err)
+ }
+ crc := make([]byte, 4)
+ if _, err := reader.Read(crc); err != nil {
+ return nil, nil, fmt.Errorf("error reading CRC: %w", err)
+ }
+ fullChunk := bytes.NewBuffer(nil)
+ if err := binary.Write(fullChunk, binary.BigEndian, chunkLength); err != nil {
+ return nil, nil, fmt.Errorf("error writing chunk length: %w", err)
+ }
+ if _, err := fullChunk.Write(chunkType); err != nil {
+ return nil, nil, fmt.Errorf("error writing chunk type: %w", err)
+ }
+ if _, err := fullChunk.Write(chunkData); err != nil {
+ return nil, nil, fmt.Errorf("error writing chunk data: %w", err)
+ }
+ if _, err := fullChunk.Write(crc); err != nil {
+ return nil, nil, fmt.Errorf("error writing CRC: %w", err)
+ }
+ switch string(chunkType) {
+ case "IEND":
+ iendChunk = fullChunk.Bytes()
+ return chunks, iendChunk, nil
+ case textChunkType:
+ continue // Skip existing tEXt chunks
+ default:
+ chunks = append(chunks, fullChunk.Bytes())
+ }
+ }
+ return nil, nil, errors.New("IEND chunk not found")
+}
+
+// createTextChunk generates a valid tEXt chunk with proper CRC
+func createTextChunk(embed PngEmbed) ([]byte, error) {
+ content := bytes.NewBuffer(nil)
+ content.WriteString(embed.Key)
+ content.WriteByte(0) // Null separator
+ content.WriteString(embed.Value)
+ data := content.Bytes()
+ crc := crc32.NewIEEE()
+ crc.Write([]byte(textChunkType))
+ crc.Write(data)
+ chunk := bytes.NewBuffer(nil)
+ if err := binary.Write(chunk, binary.BigEndian, uint32(len(data))); err != nil {
+ return nil, fmt.Errorf("error writing chunk length: %w", err)
+ }
+ if _, err := chunk.Write([]byte(textChunkType)); err != nil {
+ return nil, fmt.Errorf("error writing chunk type: %w", err)
+ }
+ if _, err := chunk.Write(data); err != nil {
+ return nil, fmt.Errorf("error writing chunk data: %w", err)
+ }
+ if err := binary.Write(chunk, binary.BigEndian, crc.Sum32()); err != nil {
+ return nil, fmt.Errorf("error writing CRC: %w", err)
+ }
+ return chunk.Bytes(), nil
+}
diff --git a/pngmeta/metareader.go b/pngmeta/metareader.go
index df2a8d4..369345a 100644
--- a/pngmeta/metareader.go
+++ b/pngmeta/metareader.go
@@ -2,7 +2,7 @@ package pngmeta
import (
"bytes"
- "elefant/models"
+ "gf-lt/models"
"encoding/base64"
"encoding/json"
"errors"
@@ -22,8 +22,6 @@ const (
writeHeader = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"
)
-var tEXtChunkDataSpecification = "%s\x00%s"
-
type PngEmbed struct {
Key string
Value string
@@ -97,7 +95,7 @@ func ReadCard(fname, uname string) (*models.CharCard, error) {
return nil, err
}
if charSpec.Name == "" {
- return nil, fmt.Errorf("failed to find role; fname %s\n", fname)
+ return nil, fmt.Errorf("failed to find role; fname %s", fname)
}
return charSpec.Simplify(uname, fname), nil
}
diff --git a/pngmeta/metareader_test.go b/pngmeta/metareader_test.go
index 5d9a0e2..f88de06 100644
--- a/pngmeta/metareader_test.go
+++ b/pngmeta/metareader_test.go
@@ -1,7 +1,19 @@
package pngmeta
import (
+ "bytes"
+ "gf-lt/models"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
"fmt"
+ "image"
+ "image/color"
+ "image/png"
+ "io"
+ "os"
+ "path/filepath"
"testing"
)
@@ -28,3 +40,155 @@ func TestReadMeta(t *testing.T) {
})
}
}
+
+// Test helper: Create a simple PNG image with test shapes
+func createTestImage(t *testing.T) string {
+ img := image.NewRGBA(image.Rect(0, 0, 200, 200))
+ // Fill background with white
+ for y := 0; y < 200; y++ {
+ for x := 0; x < 200; x++ {
+ img.Set(x, y, color.White)
+ }
+ }
+ // Draw a red square
+ for y := 50; y < 150; y++ {
+ for x := 50; x < 150; x++ {
+ img.Set(x, y, color.RGBA{R: 255, A: 255})
+ }
+ }
+ // Draw a blue circle
+ center := image.Point{100, 100}
+ radius := 40
+ for y := center.Y - radius; y <= center.Y+radius; y++ {
+ for x := center.X - radius; x <= center.X+radius; x++ {
+ dx := x - center.X
+ dy := y - center.Y
+ if dx*dx+dy*dy <= radius*radius {
+ img.Set(x, y, color.RGBA{B: 255, A: 255})
+ }
+ }
+ }
+ // Create temp file
+ tmpDir := t.TempDir()
+ fpath := filepath.Join(tmpDir, "test-image.png")
+ f, err := os.Create(fpath)
+ if err != nil {
+ t.Fatalf("Error creating temp file: %v", err)
+ }
+ defer f.Close()
+ if err := png.Encode(f, img); err != nil {
+ t.Fatalf("Error encoding PNG: %v", err)
+ }
+ return fpath
+}
+
+func TestWriteToPng(t *testing.T) {
+ // Create test image
+ srcPath := createTestImage(t)
+ dstPath := filepath.Join(filepath.Dir(srcPath), "output.png")
+ // dstPath := "test.png"
+ // Create test metadata
+ metadata := &models.CharCardSpec{
+ Description: "Test image containing a red square and blue circle on white background",
+ }
+ // Embed metadata
+ if err := WriteToPng(metadata, srcPath, dstPath); err != nil {
+ t.Fatalf("WriteToPng failed: %v", err)
+ }
+ // Verify output file exists
+ if _, err := os.Stat(dstPath); os.IsNotExist(err) {
+ t.Fatalf("Output file not created: %v", err)
+ }
+ // Read and verify metadata
+ t.Run("VerifyMetadata", func(t *testing.T) {
+ data, err := os.ReadFile(dstPath)
+ if err != nil {
+ t.Fatalf("Error reading output file: %v", err)
+ }
+ // Verify PNG header
+ if string(data[:8]) != pngHeader {
+ t.Errorf("Invalid PNG header")
+ }
+ // Extract metadata
+ embedded := extractMetadata(t, data)
+ if embedded.Description != metadata.Description {
+ t.Errorf("Metadata mismatch\nWant: %q\nGot: %q",
+ metadata.Description, embedded.Description)
+ }
+ })
+ // Optional: Add cleanup if needed
+ // t.Cleanup(func() {
+ // os.Remove(dstPath)
+ // })
+}
+
+// Helper to extract embedded metadata from PNG bytes
+func extractMetadata(t *testing.T, data []byte) *models.CharCardSpec {
+ r := bytes.NewReader(data[8:]) // Skip PNG header
+ for {
+ var length uint32
+ if err := binary.Read(r, binary.BigEndian, &length); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ t.Fatalf("Error reading chunk length: %v", err)
+ }
+ chunkType := make([]byte, 4)
+ if _, err := r.Read(chunkType); err != nil {
+ t.Fatalf("Error reading chunk type: %v", err)
+ }
+ // Read chunk data
+ chunkData := make([]byte, length)
+ if _, err := r.Read(chunkData); err != nil {
+ t.Fatalf("Error reading chunk data: %v", err)
+ }
+ // Read and discard CRC
+ if _, err := r.Read(make([]byte, 4)); err != nil {
+ t.Fatalf("Error reading CRC: %v", err)
+ }
+ if string(chunkType) == embType {
+ parts := bytes.SplitN(chunkData, []byte{0}, 2)
+ if len(parts) != 2 {
+ t.Fatalf("Invalid tEXt chunk format")
+ }
+ decoded, err := base64.StdEncoding.DecodeString(string(parts[1]))
+ if err != nil {
+ t.Fatalf("Base64 decode error: %v", err)
+ }
+ var result models.CharCardSpec
+ if err := json.Unmarshal(decoded, &result); err != nil {
+ t.Fatalf("JSON unmarshal error: %v", err)
+ }
+ return &result
+ }
+ }
+ t.Fatal("Metadata not found in PNG")
+ return nil
+}
+
+func readTextChunk(t *testing.T, r io.ReadSeeker) *models.CharCardSpec {
+ var length uint32
+ binary.Read(r, binary.BigEndian, &length)
+ chunkType := make([]byte, 4)
+ r.Read(chunkType)
+ data := make([]byte, length)
+ r.Read(data)
+ // Read CRC (but skip validation for test purposes)
+ crc := make([]byte, 4)
+ r.Read(crc)
+ parts := bytes.SplitN(data, []byte{0}, 2) // Split key-value pair
+ if len(parts) != 2 {
+ t.Fatalf("Invalid tEXt chunk format")
+ }
+ // key := string(parts[0])
+ value := parts[1]
+ decoded, err := base64.StdEncoding.DecodeString(string(value))
+ if err != nil {
+ t.Fatalf("Base64 decode error: %v; value: %s", err, string(value))
+ }
+ var result models.CharCardSpec
+ if err := json.Unmarshal(decoded, &result); err != nil {
+ t.Fatalf("JSON unmarshal error: %v", err)
+ }
+ return &result
+}
diff --git a/pngmeta/partswriter.go b/pngmeta/partswriter.go
index 7c36daf..7282df6 100644
--- a/pngmeta/partswriter.go
+++ b/pngmeta/partswriter.go
@@ -1,116 +1,112 @@
package pngmeta
-import (
- "bytes"
- "elefant/models"
- "encoding/base64"
- "encoding/binary"
- "encoding/json"
- "errors"
- "fmt"
- "hash/crc32"
- "io"
- "os"
-)
+// import (
+// "bytes"
+// "encoding/binary"
+// "errors"
+// "fmt"
+// "hash/crc32"
+// "io"
+// )
-type Writer struct {
- w io.Writer
-}
+// type Writer struct {
+// w io.Writer
+// }
-func NewPNGWriter(w io.Writer) (*Writer, error) {
- if _, err := io.WriteString(w, writeHeader); err != nil {
- return nil, err
- }
- return &Writer{w}, nil
-}
+// func NewPNGWriter(w io.Writer) (*Writer, error) {
+// if _, err := io.WriteString(w, writeHeader); err != nil {
+// return nil, err
+// }
+// return &Writer{w}, nil
+// }
-func (w *Writer) WriteChunk(length int32, typ string, r io.Reader) error {
- if err := binary.Write(w.w, binary.BigEndian, length); err != nil {
- return err
- }
- if _, err := w.w.Write([]byte(typ)); err != nil {
- return err
- }
- checksummer := crc32.NewIEEE()
- checksummer.Write([]byte(typ))
- if _, err := io.CopyN(io.MultiWriter(w.w, checksummer), r, int64(length)); err != nil {
- return err
- }
- if err := binary.Write(w.w, binary.BigEndian, checksummer.Sum32()); err != nil {
- return err
- }
- return nil
-}
+// func (w *Writer) WriteChunk(length int32, typ string, r io.Reader) error {
+// if err := binary.Write(w.w, binary.BigEndian, length); err != nil {
+// return err
+// }
+// if _, err := w.w.Write([]byte(typ)); err != nil {
+// return err
+// }
+// checksummer := crc32.NewIEEE()
+// checksummer.Write([]byte(typ))
+// if _, err := io.CopyN(io.MultiWriter(w.w, checksummer), r, int64(length)); err != nil {
+// return err
+// }
+// if err := binary.Write(w.w, binary.BigEndian, checksummer.Sum32()); err != nil {
+// return err
+// }
+// return nil
+// }
-func WriteToPng(c *models.CharCardSpec, fpath, outfile string) error {
- data, err := os.ReadFile(fpath)
- if err != nil {
- return err
- }
- jsonData, err := json.Marshal(c)
- if err != nil {
- return err
- }
- // Base64 encode the JSON data
- base64Data := base64.StdEncoding.EncodeToString(jsonData)
- pe := PngEmbed{
- Key: cKey,
- Value: base64Data,
- }
- w, err := WritetEXtToPngBytes(data, pe)
- if err != nil {
- return err
- }
- return os.WriteFile(outfile, w.Bytes(), 0666)
-}
+// func WWriteToPngriteToPng(c *models.CharCardSpec, fpath, outfile string) error {
+// data, err := os.ReadFile(fpath)
+// if err != nil {
+// return err
+// }
+// jsonData, err := json.Marshal(c)
+// if err != nil {
+// return err
+// }
+// // Base64 encode the JSON data
+// base64Data := base64.StdEncoding.EncodeToString(jsonData)
+// pe := PngEmbed{
+// Key: cKey,
+// Value: base64Data,
+// }
+// w, err := WritetEXtToPngBytes(data, pe)
+// if err != nil {
+// return err
+// }
+// return os.WriteFile(outfile, w.Bytes(), 0666)
+// }
-func WritetEXtToPngBytes(inputBytes []byte, pe PngEmbed) (outputBytes bytes.Buffer, err error) {
- if !(string(inputBytes[:8]) == header) {
- return outputBytes, errors.New("wrong file format")
- }
- reader := bytes.NewReader(inputBytes)
- pngr, err := NewPNGStepReader(reader)
- if err != nil {
- return outputBytes, fmt.Errorf("NewReader(): %s", err)
- }
- pngw, err := NewPNGWriter(&outputBytes)
- if err != nil {
- return outputBytes, fmt.Errorf("NewWriter(): %s", err)
- }
- for {
- chunk, err := pngr.Next()
- if err != nil {
- if errors.Is(err, io.EOF) {
- break
- }
- return outputBytes, fmt.Errorf("NextChunk(): %s", err)
- }
- if chunk.Type() != embType {
- // IENDChunkType will only appear on the final iteration of a valid PNG
- if chunk.Type() == IEND {
- // This is where we inject tEXtChunkType as the penultimate chunk with the new value
- newtEXtChunk := []byte(fmt.Sprintf(tEXtChunkDataSpecification, pe.Key, pe.Value))
- if err := pngw.WriteChunk(int32(len(newtEXtChunk)), embType, bytes.NewBuffer(newtEXtChunk)); err != nil {
- return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
- }
- // Now we end the buffer with IENDChunkType chunk
- if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil {
- return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
- }
- } else {
- // writes back original chunk to buffer
- if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil {
- return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
- }
- }
- } else {
- if _, err := io.Copy(io.Discard, chunk); err != nil {
- return outputBytes, fmt.Errorf("io.Copy(io.Discard, chunk): %s", err)
- }
- }
- if err := chunk.Close(); err != nil {
- return outputBytes, fmt.Errorf("chunk.Close(): %s", err)
- }
- }
- return outputBytes, nil
-}
+// func WritetEXtToPngBytes(inputBytes []byte, pe PngEmbed) (outputBytes bytes.Buffer, err error) {
+// if !(string(inputBytes[:8]) == header) {
+// return outputBytes, errors.New("wrong file format")
+// }
+// reader := bytes.NewReader(inputBytes)
+// pngr, err := NewPNGStepReader(reader)
+// if err != nil {
+// return outputBytes, fmt.Errorf("NewReader(): %s", err)
+// }
+// pngw, err := NewPNGWriter(&outputBytes)
+// if err != nil {
+// return outputBytes, fmt.Errorf("NewWriter(): %s", err)
+// }
+// for {
+// chunk, err := pngr.Next()
+// if err != nil {
+// if errors.Is(err, io.EOF) {
+// break
+// }
+// return outputBytes, fmt.Errorf("NextChunk(): %s", err)
+// }
+// if chunk.Type() != embType {
+// // IENDChunkType will only appear on the final iteration of a valid PNG
+// if chunk.Type() == IEND {
+// // This is where we inject tEXtChunkType as the penultimate chunk with the new value
+// newtEXtChunk := []byte(fmt.Sprintf(tEXtChunkDataSpecification, pe.Key, pe.Value))
+// if err := pngw.WriteChunk(int32(len(newtEXtChunk)), embType, bytes.NewBuffer(newtEXtChunk)); err != nil {
+// return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
+// }
+// // Now we end the buffer with IENDChunkType chunk
+// if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil {
+// return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
+// }
+// } else {
+// // writes back original chunk to buffer
+// if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil {
+// return outputBytes, fmt.Errorf("WriteChunk(): %s", err)
+// }
+// }
+// } else {
+// if _, err := io.Copy(io.Discard, chunk); err != nil {
+// return outputBytes, fmt.Errorf("io.Copy(io.Discard, chunk): %s", err)
+// }
+// }
+// if err := chunk.Close(); err != nil {
+// return outputBytes, fmt.Errorf("chunk.Close(): %s", err)
+// }
+// }
+// return outputBytes, nil
+// }
diff --git a/rag/main.go b/rag/main.go
index 5f2aa00..b7e0c00 100644
--- a/rag/main.go
+++ b/rag/main.go
@@ -2,9 +2,9 @@ package rag
import (
"bytes"
- "elefant/config"
- "elefant/models"
- "elefant/storage"
+ "gf-lt/config"
+ "gf-lt/models"
+ "gf-lt/storage"
"encoding/json"
"errors"
"fmt"
@@ -61,16 +61,10 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences {
sents[i] = s.Text
}
- // TODO: maybe better to decide batch size based on sentences len
var (
- // TODO: to config
- workers = 5
- batchSize = 100
maxChSize = 1000
- //
- wordLimit = 80
left = 0
- right = batchSize
+ right = r.cfg.RAGBatchSize
batchCh = make(chan map[int][]string, maxChSize)
vectorCh = make(chan []models.VectorRow, maxChSize)
errCh = make(chan error, 1)
@@ -85,29 +79,29 @@ func (r *RAG) LoadRAG(fpath string) error {
par := strings.Builder{}
for i := 0; i < len(sents); i++ {
par.WriteString(sents[i])
- if wordCounter(par.String()) > wordLimit {
+ if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
paragraphs = append(paragraphs, par.String())
par.Reset()
}
}
- if len(paragraphs) < batchSize {
- batchSize = len(paragraphs)
+ if len(paragraphs) < int(r.cfg.RAGBatchSize) {
+ r.cfg.RAGBatchSize = len(paragraphs)
}
// fill input channel
ctn := 0
for {
- if right > len(paragraphs) {
+ if int(right) > len(paragraphs) {
batchCh <- map[int][]string{left: paragraphs[left:]}
break
}
batchCh <- map[int][]string{left: paragraphs[left:right]}
- left, right = right, right+batchSize
+ left, right = right, right+r.cfg.RAGBatchSize
ctn++
}
finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents))
r.logger.Debug(finishedBatchesMsg)
LongJobStatusCh <- finishedBatchesMsg
- for w := 0; w < workers; w++ {
+ for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}
// wait for emb to be done
@@ -241,7 +235,7 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
- err = fmt.Errorf("non 200 resp; code: %v\n", resp.StatusCode)
+ err = fmt.Errorf("non 200 resp; code: %v", resp.StatusCode)
r.logger.Error(err.Error())
return nil, err
}
diff --git a/server.go b/server.go
index 5a1a1c3..5654855 100644
--- a/server.go
+++ b/server.go
@@ -1,7 +1,7 @@
package main
import (
- "elefant/config"
+ "gf-lt/config"
"encoding/json"
"fmt"
"net/http"
diff --git a/session.go b/session.go
index 7d790f3..dbfa645 100644
--- a/session.go
+++ b/session.go
@@ -1,12 +1,13 @@
package main
import (
- "elefant/models"
+ "gf-lt/models"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
+ "path/filepath"
"strings"
"time"
)
@@ -34,6 +35,24 @@ func exportChat() error {
return os.WriteFile(activeChatName+".json", data, 0666)
}
+func importChat(filename string) error {
+ data, err := os.ReadFile(filename)
+ if err != nil {
+ return err
+ }
+ messages := []models.RoleMsg{}
+ if err := json.Unmarshal(data, &messages); err != nil {
+ return err
+ }
+ activeChatName = filepath.Base(filename)
+ chatBody.Messages = messages
+ cfg.AssistantRole = messages[1].Role
+ if cfg.AssistantRole == cfg.UserRole {
+ cfg.AssistantRole = messages[2].Role
+ }
+ return nil
+}
+
func updateStorageChat(name string, msgs []models.RoleMsg) error {
var err error
chat, ok := chatMap[name]
diff --git a/storage/memory.go b/storage/memory.go
index c9fc853..406182f 100644
--- a/storage/memory.go
+++ b/storage/memory.go
@@ -1,6 +1,6 @@
package storage
-import "elefant/models"
+import "gf-lt/models"
type Memories interface {
Memorise(m *models.Memory) (*models.Memory, error)
diff --git a/storage/storage.go b/storage/storage.go
index f759700..7911e13 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -1,7 +1,7 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"log/slog"
_ "github.com/glebarez/go-sqlite"
diff --git a/storage/storage_test.go b/storage/storage_test.go
index ff3b5e6..a1c4cf4 100644
--- a/storage/storage_test.go
+++ b/storage/storage_test.go
@@ -1,18 +1,15 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"fmt"
- "log"
"log/slog"
"os"
"testing"
"time"
- sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
- "github.com/ncruces/go-sqlite3"
)
func TestMemories(t *testing.T) {
@@ -177,87 +174,87 @@ func TestChatHistory(t *testing.T) {
}
}
-func TestVecTable(t *testing.T) {
- // healthcheck
- db, err := sqlite3.Open(":memory:")
- if err != nil {
- t.Fatal(err)
- }
- stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
- if err != nil {
- t.Fatal(err)
- }
- stmt.Step()
- log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
- stmt.Close()
- // migration
- err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
- if err != nil {
- t.Fatal(err)
- }
- // data prep and insert
- items := map[int][]float32{
- 1: {0.1, 0.1, 0.1, 0.1},
- 2: {0.2, 0.2, 0.2, 0.2},
- 3: {0.3, 0.3, 0.3, 0.3},
- 4: {0.4, 0.4, 0.4, 0.4},
- 5: {0.5, 0.5, 0.5, 0.5},
- }
- q := []float32{0.28, 0.3, 0.3, 0.3}
- stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
- if err != nil {
- t.Fatal(err)
- }
- for id, values := range items {
- v, err := sqlite_vec.SerializeFloat32(values)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindInt(1, id)
- stmt.BindBlob(2, v)
- stmt.BindText(3, "some_chat")
- err = stmt.Exec()
- if err != nil {
- t.Fatal(err)
- }
- stmt.Reset()
- }
- stmt.Close()
- // select | vec search
- stmt, _, err = db.Prepare(`
- SELECT
- rowid,
- distance,
- embedding
- FROM vec_items
- WHERE embedding MATCH ?
- ORDER BY distance
- LIMIT 3
- `)
- if err != nil {
- t.Fatal(err)
- }
- query, err := sqlite_vec.SerializeFloat32(q)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindBlob(1, query)
- for stmt.Step() {
- rowid := stmt.ColumnInt64(0)
- distance := stmt.ColumnFloat(1)
- emb := stmt.ColumnRawText(2)
- floats := decodeUnsafe(emb)
- log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
- }
- if err := stmt.Err(); err != nil {
- t.Fatal(err)
- }
- err = stmt.Close()
- if err != nil {
- t.Fatal(err)
- }
- err = db.Close()
- if err != nil {
- t.Fatal(err)
- }
-}
+// func TestVecTable(t *testing.T) {
+// // healthcheck
+// db, err := sqlite3.Open(":memory:")
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Step()
+// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
+// stmt.Close()
+// // migration
+// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// // data prep and insert
+// items := map[int][]float32{
+// 1: {0.1, 0.1, 0.1, 0.1},
+// 2: {0.2, 0.2, 0.2, 0.2},
+// 3: {0.3, 0.3, 0.3, 0.3},
+// 4: {0.4, 0.4, 0.4, 0.4},
+// 5: {0.5, 0.5, 0.5, 0.5},
+// }
+// q := []float32{0.4, 0.3, 0.3, 0.3}
+// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// for id, values := range items {
+// v, err := sqlite_vec.SerializeFloat32(values)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindInt(1, id)
+// stmt.BindBlob(2, v)
+// stmt.BindText(3, "some_chat")
+// err = stmt.Exec()
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Reset()
+// }
+// stmt.Close()
+// // select | vec search
+// stmt, _, err = db.Prepare(`
+// SELECT
+// rowid,
+// distance,
+// embedding
+// FROM vec_items
+// WHERE embedding MATCH ?
+// ORDER BY distance
+// LIMIT 3
+// `)
+// if err != nil {
+// t.Fatal(err)
+// }
+// query, err := sqlite_vec.SerializeFloat32(q)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindBlob(1, query)
+// for stmt.Step() {
+// rowid := stmt.ColumnInt64(0)
+// distance := stmt.ColumnFloat(1)
+// emb := stmt.ColumnRawText(2)
+// floats := decodeUnsafe(emb)
+// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
+// }
+// if err := stmt.Err(); err != nil {
+// t.Fatal(err)
+// }
+// err = stmt.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// err = db.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// }
diff --git a/storage/vector.go b/storage/vector.go
index 5e9069c..71005e4 100644
--- a/storage/vector.go
+++ b/storage/vector.go
@@ -1,7 +1,7 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"errors"
"fmt"
"unsafe"
diff --git a/sysprompts/cluedo.json b/sysprompts/cluedo.json
new file mode 100644
index 0000000..0c90cb5
--- /dev/null
+++ b/sysprompts/cluedo.json
@@ -0,0 +1,7 @@
+{
+ "sys_prompt": "A game of cluedo. Players are {{user}}, {{char}}, {{char2}};\n\nrooms: hall, lounge, dinning room kitchen, ballroom, conservatory, billiard room, library, study;\nweapons: candlestick, dagger, lead pipe, revolver, rope, spanner;\npeople: miss Scarlett, colonel Mustard, mrs. White, reverend Green, mrs. Peacock, professor Plum;\n\nA murder happened in a mansion with 9 rooms. Victim is dr. Black.\nPlayers goal is to find out who commited a murder, in what room and with what weapon.\nWeapons, people and rooms not involved in murder are distributed between players (as cards) by tool agent.\nThe objective of the game is to deduce the details of the murder. There are six characters, six murder weapons, and nine rooms, leaving the players with 324 possibilities. As soon as a player enters a room, they may make a suggestion as to the details, naming a suspect, the room they are in, and the weapon. For example: \"I suspect Professor Plum, in the Dining Room, with the candlestick\".\nOnce a player makes a suggestion, the others are called upon to disprove it.\nBefore the player's move, tool agent will remind that players their cards. There are two types of moves: making a suggestion (suggestion_move) and disproving other player suggestion (evidence_move);\nIn this version player wins when the correct details are named in the suggestion_move.\n\n<example_game>\n{{user}}:\nlet's start a game of cluedo!\ntool: cards of {{char}} are 'LEAD PIPE', 'BALLROOM', 'CONSERVATORY', 'STUDY', 'Mrs. White'; suggestion_move;\n{{char}}:\n(putting miss Scarlet into the Hall with the Revolver) \"I suspect miss Scarlett, in the Hall, with the revolver.\"\ntool: cards of {{char2}} are 'SPANNER', 'DAGGER', 'Professor Plum', 'LIBRARY', 'Mrs. Peacock'; evidence_move;\n{{char2}}:\n\"No objections.\" (no cards matching the suspicion of {{char}})\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINNING ROOM', 'CANDLESTICK', 'HALL'; evidence_move;\n{{user}}:\n\"I object. Miss Scarlett is innocent.\" (shows card with 'Miss Scarlett')\ntool: cards of {{char2}} are 'SPANNER', 'DAGGER', 'Professor Plum', 'LIBRARY', 'Mrs. Peacock'; suggestion_move;\n{{char2}}:\n*So it was not Miss Scarlett, good to know.*\n(moves Mrs. White to the Billiard Room) \"It might have been Mrs. White, in the Billiard Room, with the Revolver.\"\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINNING ROOM', 'CANDLESTICK', 'HALL'; evidence_move;\n{{user}}:\n(no matching cards for the assumption of {{char2}}) \"Sounds possible to me.\"\ntool: cards of {{char}} are 'LEAD PIPE', 'BALLROOM', 'CONSERVATORY', 'STUDY', 'Mrs. White'; evidence_move;\n{{char}}:\n(shows Mrs. White card) \"No. Was not Mrs. White\"\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINNING ROOM', 'CANDLESTICK', 'HALL'; suggestion_move;\n{{user}}:\n*So not Mrs. White...* (moves Reverend Green into the Billiard Room) \"I suspect Reverend Green, in the Billiard Room, with the Revolver.\"\ntool: Correct. It was Reverend Green in the Billiard Room, with the revolver. {{user}} wins.\n</example_game>",
+ "role": "CluedoPlayer",
+ "role2": "CluedoEnjoyer",
+ "filepath": "sysprompts/cluedo.json",
+ "first_msg": "Hey guys! Want to play cluedo?"
+}
diff --git a/tables.go b/tables.go
index e281dd2..c4c97b9 100644
--- a/tables.go
+++ b/tables.go
@@ -7,16 +7,16 @@ import (
"strings"
"time"
- "elefant/models"
- "elefant/pngmeta"
- "elefant/rag"
+ "gf-lt/models"
+ "gf-lt/pngmeta"
+ "gf-lt/rag"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
)
func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
- actions := []string{"load", "rename", "delete", "update card"}
+ actions := []string{"load", "rename", "delete", "update card", "move sysprompt onto 1st msg", "new_chat_from_card"}
chatList := make([]string, len(chatMap))
i := 0
for name := range chatMap {
@@ -26,9 +26,7 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
rows, cols := len(chatMap), len(actions)+2
chatActTable := tview.NewTable().
SetBorders(true)
- // for chatName, chat := range chatMap {
for r := 0; r < rows; r++ {
- // r := 0
for c := 0; c < cols; c++ {
color := tcell.ColorWhite
switch c {
@@ -49,7 +47,6 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
SetAlign(tview.AlignCenter))
}
}
- // r++
}
chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
if key == tcell.KeyEsc || key == tcell.KeyF1 {
@@ -65,7 +62,6 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
chatActTable.SetSelectable(false, false)
selectedChat := chatList[row]
defer pages.RemovePage(historyPage)
- // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
switch tc.Text {
case "load":
history, err := loadHistoryChat(selectedChat)
@@ -114,12 +110,12 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
}
return
}
- if chatBody.Messages[0].Role != "system" || chatBody.Messages[1].Role != agentName {
- if err := notifyUser("error", "unexpected chat structure; card: "+agentName); err != nil {
- logger.Warn("failed ot notify", "error", err)
- }
- return
- }
+ // if chatBody.Messages[0].Role != "system" || chatBody.Messages[1].Role != agentName {
+ // if err := notifyUser("error", "unexpected chat structure; card: "+agentName); err != nil {
+ // logger.Warn("failed ot notify", "error", err)
+ // }
+ // return
+ // }
// change sys_prompt + first msg
cc.SysPrompt = chatBody.Messages[0].Content
cc.FirstMsg = chatBody.Messages[1].Content
@@ -128,6 +124,40 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
"error", err)
}
return
+ case "move sysprompt onto 1st msg":
+ chatBody.Messages[1].Content = chatBody.Messages[0].Content + chatBody.Messages[1].Content
+ chatBody.Messages[0].Content = rpDefenitionSysMsg
+ textView.SetText(chatToText(cfg.ShowSys))
+ activeChatName = selectedChat
+ pages.RemovePage(historyPage)
+ return
+ case "new_chat_from_card":
+ // Reread card from file and start fresh chat
+ fi := strings.Index(selectedChat, "_")
+ agentName := selectedChat[fi+1:]
+ cc, ok := sysMap[agentName]
+ if !ok {
+ logger.Warn("no such card", "agent", agentName)
+ if err := notifyUser("error", "no such card: "+agentName); err != nil {
+ logger.Warn("failed to notify", "error", err)
+ }
+ return
+ }
+ // Reload card from disk
+ newCard, err := pngmeta.ReadCard(cc.FilePath, cfg.UserRole)
+ if err != nil {
+ logger.Error("failed to reload charcard", "path", cc.FilePath, "error", err)
+ if err := notifyUser("error", "failed to reload card: "+cc.FilePath); err != nil {
+ logger.Warn("failed to notify", "error", err)
+ }
+ return
+ }
+ // Update sysMap with fresh card data
+ sysMap[agentName] = newCard
+ applyCharCard(newCard)
+ startNewChat()
+ pages.RemovePage(historyPage)
+ return
default:
return
}
@@ -135,7 +165,7 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
return chatActTable
}
-// func makeRAGTable(fileList []string) *tview.Table {
+// nolint:unused
func makeRAGTable(fileList []string) *tview.Flex {
actions := []string{"load", "delete"}
rows, cols := len(fileList), len(actions)+1
@@ -237,59 +267,59 @@ func makeRAGTable(fileList []string) *tview.Flex {
return ragflex
}
-func makeLoadedRAGTable(fileList []string) *tview.Table {
- actions := []string{"delete"}
- rows, cols := len(fileList), len(actions)+1
- fileTable := tview.NewTable().
- SetBorders(true)
- for r := 0; r < rows; r++ {
- for c := 0; c < cols; c++ {
- color := tcell.ColorWhite
- if c < 1 {
- fileTable.SetCell(r, c,
- tview.NewTableCell(fileList[r]).
- SetTextColor(color).
- SetAlign(tview.AlignCenter))
- } else {
- fileTable.SetCell(r, c,
- tview.NewTableCell(actions[c-1]).
- SetTextColor(color).
- SetAlign(tview.AlignCenter))
- }
- }
- }
- fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
- if key == tcell.KeyEsc || key == tcell.KeyF1 {
- pages.RemovePage(RAGPage)
- return
- }
- if key == tcell.KeyEnter {
- fileTable.SetSelectable(true, true)
- }
- }).SetSelectedFunc(func(row int, column int) {
- defer pages.RemovePage(RAGPage)
- tc := fileTable.GetCell(row, column)
- tc.SetTextColor(tcell.ColorRed)
- fileTable.SetSelectable(false, false)
- fpath := fileList[row]
- // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text)
- switch tc.Text {
- case "delete":
- if err := ragger.RemoveFile(fpath); err != nil {
- logger.Error("failed to delete file", "filename", fpath, "error", err)
- return
- }
- if err := notifyUser("chat deleted", fpath+" was deleted"); err != nil {
- logger.Error("failed to send notification", "error", err)
- }
- return
- default:
- // pages.RemovePage(RAGPage)
- return
- }
- })
- return fileTable
-}
+// func makeLoadedRAGTable(fileList []string) *tview.Table {
+// actions := []string{"delete"}
+// rows, cols := len(fileList), len(actions)+1
+// fileTable := tview.NewTable().
+// SetBorders(true)
+// for r := 0; r < rows; r++ {
+// for c := 0; c < cols; c++ {
+// color := tcell.ColorWhite
+// if c < 1 {
+// fileTable.SetCell(r, c,
+// tview.NewTableCell(fileList[r]).
+// SetTextColor(color).
+// SetAlign(tview.AlignCenter))
+// } else {
+// fileTable.SetCell(r, c,
+// tview.NewTableCell(actions[c-1]).
+// SetTextColor(color).
+// SetAlign(tview.AlignCenter))
+// }
+// }
+// }
+// fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
+// if key == tcell.KeyEsc || key == tcell.KeyF1 {
+// pages.RemovePage(RAGPage)
+// return
+// }
+// if key == tcell.KeyEnter {
+// fileTable.SetSelectable(true, true)
+// }
+// }).SetSelectedFunc(func(row int, column int) {
+// defer pages.RemovePage(RAGPage)
+// tc := fileTable.GetCell(row, column)
+// tc.SetTextColor(tcell.ColorRed)
+// fileTable.SetSelectable(false, false)
+// fpath := fileList[row]
+// // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text)
+// switch tc.Text {
+// case "delete":
+// if err := ragger.RemoveFile(fpath); err != nil {
+// logger.Error("failed to delete file", "filename", fpath, "error", err)
+// return
+// }
+// if err := notifyUser("chat deleted", fpath+" was deleted"); err != nil {
+// logger.Error("failed to send notification", "error", err)
+// }
+// return
+// default:
+// // pages.RemovePage(RAGPage)
+// return
+// }
+// })
+// return fileTable
+// }
func makeAgentTable(agentList []string) *tview.Table {
actions := []string{"load"}
@@ -427,3 +457,79 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table {
})
return table
}
+
+func makeImportChatTable(filenames []string) *tview.Table {
+ actions := []string{"load"}
+ rows, cols := len(filenames), len(actions)+1
+ chatActTable := tview.NewTable().
+ SetBorders(true)
+ for r := 0; r < rows; r++ {
+ for c := 0; c < cols; c++ {
+ color := tcell.ColorWhite
+ if c < 1 {
+ chatActTable.SetCell(r, c,
+ tview.NewTableCell(filenames[r]).
+ SetTextColor(color).
+ SetAlign(tview.AlignCenter))
+ } else {
+ chatActTable.SetCell(r, c,
+ tview.NewTableCell(actions[c-1]).
+ SetTextColor(color).
+ SetAlign(tview.AlignCenter))
+ }
+ }
+ }
+ chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
+ if key == tcell.KeyEsc || key == tcell.KeyF1 {
+ pages.RemovePage(historyPage)
+ return
+ }
+ if key == tcell.KeyEnter {
+ chatActTable.SetSelectable(true, true)
+ }
+ }).SetSelectedFunc(func(row int, column int) {
+ tc := chatActTable.GetCell(row, column)
+ tc.SetTextColor(tcell.ColorRed)
+ chatActTable.SetSelectable(false, false)
+ selected := filenames[row]
+ // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
+ switch tc.Text {
+ case "load":
+ if err := importChat(selected); err != nil {
+ logger.Warn("failed to import chat", "filename", selected)
+ pages.RemovePage(historyPage)
+ return
+ }
+ colorText()
+ updateStatusLine()
+ // redraw the text in text area
+ textView.SetText(chatToText(cfg.ShowSys))
+ pages.RemovePage(historyPage)
+ app.SetFocus(textArea)
+ return
+ case "rename":
+ pages.RemovePage(historyPage)
+ pages.AddPage(renamePage, renameWindow, true, true)
+ return
+ case "delete":
+ sc, ok := chatMap[selected]
+ if !ok {
+ // no chat found
+ pages.RemovePage(historyPage)
+ return
+ }
+ if err := store.RemoveChat(sc.ID); err != nil {
+ logger.Error("failed to remove chat from db", "chat_id", sc.ID, "chat_name", sc.Name)
+ }
+ if err := notifyUser("chat deleted", selected+" was deleted"); err != nil {
+ logger.Error("failed to send notification", "error", err)
+ }
+ pages.RemovePage(historyPage)
+ return
+ default:
+ pages.RemovePage(historyPage)
+ return
+ }
+ })
+ return chatActTable
+}
diff --git a/tools.go b/tools.go
index a380bf5..fe95ce5 100644
--- a/tools.go
+++ b/tools.go
@@ -1,7 +1,7 @@
package main
import (
- "elefant/models"
+ "gf-lt/models"
"fmt"
"regexp"
"strings"
@@ -9,11 +9,16 @@ import (
)
var (
- toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`)
- quotesRE = regexp.MustCompile(`(".*?")`)
- starRE = regexp.MustCompile(`(\*.*?\*)`)
- thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`)
- codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`)
+ toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`)
+ quotesRE = regexp.MustCompile(`(".*?")`)
+ starRE = regexp.MustCompile(`(\*.*?\*)`)
+ thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`)
+ codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`)
+ rpDefenitionSysMsg = `
+For this roleplay immersion is at most importance.
+Every character thinks and acts based on their personality and setting of the roleplay.
+Meta discussions outside of roleplay is allowed if clearly labeled as out of character, for example: (ooc: {msg}) or <ooc>{msg}</ooc>.
+`
basicSysMsg = `Large Language Model that helps user with any of his requests.`
toolSysMsg = `You can do functions call if needed.
Your current tools:
@@ -21,7 +26,7 @@ Your current tools:
[
{
"name":"recall",
-"args": "topic",
+"args": ["topic"],
"when_to_use": "when asked about topic that user previously asked to memorise"
},
{
@@ -31,7 +36,7 @@ Your current tools:
},
{
"name":"recall_topics",
-"args": null,
+"args": [],
"when_to_use": "to see what topics are saved in memory"
}
]
@@ -41,7 +46,7 @@ To make a function call return a json object within __tool_call__ tags;
__tool_call__
{
"name":"recall",
-"args": "Adam's number"
+"args": ["Adam's number"]
}
__tool_call__
</example_request>
diff --git a/tui.go b/tui.go
index 16d63e5..2b5c599 100644
--- a/tui.go
+++ b/tui.go
@@ -1,13 +1,16 @@
package main
import (
- "elefant/models"
- "elefant/pngmeta"
"fmt"
+ "gf-lt/extra"
+ "gf-lt/models"
+ "gf-lt/pngmeta"
"image"
_ "image/jpeg"
_ "image/png"
"os"
+ "path"
+ "slices"
"strconv"
"strings"
@@ -35,11 +38,12 @@ var (
indexPage = "indexPage"
helpPage = "helpPage"
renamePage = "renamePage"
- RAGPage = "RAGPage "
+ RAGPage = "RAGPage"
propsPage = "propsPage"
codeBlockPage = "codeBlockPage"
imgPage = "imgPage"
// help text
+ // [yellow]F10[white]: manage loaded rag files (that already in vector db)
helpText = `
[yellow]Esc[white]: send msg
[yellow]PgUp/Down[white]: switch focus between input and chat widgets
@@ -52,8 +56,7 @@ var (
[yellow]F7[white]: copy last msg to clipboard (linux xclip)
[yellow]F8[white]: copy n msg to clipboard (linux xclip)
[yellow]F9[white]: table to copy from; with all code blocks
-[yellow]F10[white]: manage loaded rag files (that already in vector db)
-[yellow]F11[white]: switch RAGEnabled boolean
+[yellow]F11[white]: import chat file
[yellow]F12[white]: show this help page
[yellow]Ctrl+w[white]: resume generation on the last msg
[yellow]Ctrl+s[white]: load new char/agent
@@ -62,11 +65,12 @@ var (
[yellow]Ctrl+c[white]: close programm
[yellow]Ctrl+p[white]: props edit form (min-p, dry, etc.)
[yellow]Ctrl+v[white]: switch between /completion and /chat api (if provided in config)
-[yellow]Ctrl+r[white]: menu of files that can be loaded in vector db (RAG)
+[yellow]Ctrl+r[white]: start/stop recording from your microphone (needs stt server)
[yellow]Ctrl+t[white]: remove thinking (<think>) and tool messages from context (delete from chat)
[yellow]Ctrl+l[white]: update connected model name (llamacpp)
[yellow]Ctrl+k[white]: switch tool use (recommend tool use to llm after user msg)
[yellow]Ctrl+j[white]: if chat agent is char.png will show the image; then any key to return
+[yellow]Ctrl+a[white]: interrupt tts (needs tts server)
Press Enter to go back
`
@@ -115,10 +119,10 @@ func colorText() {
text = thinkRE.ReplaceAllStringFunc(text, func(match string) string {
// Style the code block and store it
styled := fmt.Sprintf("[red::i]%s[-:-:-]", match)
- thinkBlocks = append(codeBlocks, styled)
+ thinkBlocks = append(thinkBlocks, styled)
// Generate a unique placeholder (e.g., "__CODE_BLOCK_0__")
id := fmt.Sprintf(placeholderThink, counterThink)
- counter++
+ counterThink++
return id
})
// Step 2: Apply other regex styles to the non-code parts
@@ -129,6 +133,7 @@ func colorText() {
for i, cb := range codeBlocks {
text = strings.Replace(text, fmt.Sprintf(placeholder, i), cb, 1)
}
+ logger.Debug("thinking debug", "blocks", thinkBlocks)
for i, tb := range thinkBlocks {
text = strings.Replace(text, fmt.Sprintf(placeholderThink, i), tb, 1)
}
@@ -136,7 +141,11 @@ func colorText() {
}
func updateStatusLine() {
- position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
+ isRecording := false
+ if asr != nil {
+ isRecording = asr.IsRecording()
+ }
+ position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level(), isRecording))
}
func initSysCards() ([]string, error) {
@@ -158,6 +167,36 @@ func initSysCards() ([]string, error) {
return labels, nil
}
+func renameUser(oldname, newname string) {
+ if oldname == "" {
+ // not provided; deduce who user is
+ // INFO: if user not yet spoke, it is hard to replace mentions in sysprompt and first message about thme
+ roles := chatBody.ListRoles()
+ for _, role := range roles {
+ if role == cfg.AssistantRole {
+ continue
+ }
+ if role == cfg.ToolRole {
+ continue
+ }
+ if role == "system" {
+ continue
+ }
+ oldname = role
+ break
+ }
+ if oldname == "" {
+ // still
+ logger.Warn("fn: renameUser; failed to find old name", "newname", newname)
+ return
+ }
+ }
+ viewText := textView.GetText(false)
+ viewText = strings.ReplaceAll(viewText, oldname, newname)
+ chatBody.Rename(oldname, newname)
+ textView.SetText(viewText)
+}
+
func startNewChat() {
id, err := store.ChatGetMaxID()
if err != nil {
@@ -199,10 +238,23 @@ func makePropsForm(props map[string]float32) *tview.Form {
AddTextView("Notes", "Props for llamacpp completion call", 40, 2, true, false).
AddCheckbox("Insert <think> (/completion only)", cfg.ThinkUse, func(checked bool) {
cfg.ThinkUse = checked
- }).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
+ }).AddCheckbox("RAG use", cfg.RAGEnabled, func(checked bool) {
+ cfg.RAGEnabled = checked
+ }).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
func(option string, optionIndex int) {
setLogLevel(option)
- }).
+ }).AddDropDown("Select an api: ", slices.Insert(cfg.ApiLinks, 0, cfg.CurrentAPI), 0,
+ func(option string, optionIndex int) {
+ cfg.CurrentAPI = option
+ }).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0,
+ func(option string, optionIndex int) {
+ chatBody.Model = option
+ }).AddInputField("username: ", cfg.UserRole, 32, tview.InputFieldMaxLength(32), func(text string) {
+ if text != "" {
+ renameUser(cfg.UserRole, text)
+ cfg.UserRole = text
+ }
+ }).
AddButton("Quit", func() {
pages.RemovePage(propsPage)
})
@@ -283,7 +335,9 @@ func init() {
SetPlaceholder("Replace msg...")
editArea.SetBorder(true).SetTitle("input")
editArea.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
- if event.Key() == tcell.KeyEscape && editMode {
+ // if event.Key() == tcell.KeyEscape && editMode {
+ if event.Key() == tcell.KeyEscape {
+ logger.Warn("edit debug; esc is pressed")
defer colorText()
editedMsg := editArea.GetText()
if editedMsg == "" {
@@ -291,7 +345,6 @@ func init() {
logger.Error("failed to send notification", "error", err)
}
pages.RemovePage(editMsgPage)
- editMode = false
return nil
}
chatBody.Messages[selectedIndex].Content = editedMsg
@@ -310,8 +363,8 @@ func init() {
SetDoneFunc(func(key tcell.Key) {
defer indexPickWindow.SetText("")
pages.RemovePage(indexPage)
- colorText()
- updateStatusLine()
+ // colorText()
+ // updateStatusLine()
})
indexPickWindow.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
switch event.Key() {
@@ -340,6 +393,7 @@ func init() {
}
m := chatBody.Messages[selectedIndex]
if editMode && event.Key() == tcell.KeyEnter {
+ pages.RemovePage(indexPage)
pages.AddPage(editMsgPage, editArea, true, true)
editArea.SetText(m.Content, true)
}
@@ -450,6 +504,8 @@ func init() {
if event.Key() == tcell.KeyF2 {
// regen last msg
chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1]
+ // there is no case where user msg is regenerated
+ // lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role
textView.SetText(chatToText(cfg.ShowSys))
go chatRound("", cfg.UserRole, textView, true, false)
return nil
@@ -525,26 +581,43 @@ func init() {
// updateStatusLine()
return nil
}
- if event.Key() == tcell.KeyF10 {
- // list rag loaded in db
- loadedFiles, err := ragger.ListLoaded()
+ // if event.Key() == tcell.KeyF10 {
+ // // list rag loaded in db
+ // loadedFiles, err := ragger.ListLoaded()
+ // if err != nil {
+ // logger.Error("failed to list regfiles in db", "error", err)
+ // return nil
+ // }
+ // if len(loadedFiles) == 0 {
+ // if err := notifyUser("loaded RAG", "no files in db"); err != nil {
+ // logger.Error("failed to send notification", "error", err)
+ // }
+ // return nil
+ // }
+ // dbRAGTable := makeLoadedRAGTable(loadedFiles)
+ // pages.AddPage(RAGPage, dbRAGTable, true, true)
+ // return nil
+ // }
+ if event.Key() == tcell.KeyF11 {
+ // read files in chat_exports
+ dirname := "chat_exports"
+ filelist, err := os.ReadDir(dirname)
if err != nil {
- logger.Error("failed to list regfiles in db", "error", err)
- return nil
- }
- if len(loadedFiles) == 0 {
- if err := notifyUser("loaded RAG", "no files in db"); err != nil {
+ if err := notifyUser("failed to load exports", err.Error()); err != nil {
logger.Error("failed to send notification", "error", err)
}
- return nil
}
- dbRAGTable := makeLoadedRAGTable(loadedFiles)
- pages.AddPage(RAGPage, dbRAGTable, true, true)
- return nil
- }
- if event.Key() == tcell.KeyF11 {
- // xor
- cfg.RAGEnabled = !cfg.RAGEnabled
+ fli := []string{}
+ for _, f := range filelist {
+ if f.IsDir() || !strings.HasSuffix(f.Name(), ".json") {
+ continue
+ }
+ fpath := path.Join(dirname, f.Name())
+ fli = append(fli, fpath)
+ }
+ // check error
+ exportsTable := makeImportChatTable(fli)
+ pages.AddPage(historyPage, exportsTable, true, true)
updateStatusLine()
return nil
}
@@ -590,15 +663,18 @@ func init() {
}
if event.Key() == tcell.KeyCtrlV {
// switch between /chat and /completion api
- prevAPI := cfg.CurrentAPI
newAPI := cfg.APIMap[cfg.CurrentAPI]
if newAPI == "" {
// do not switch
return nil
}
- cfg.APIMap[newAPI] = prevAPI
cfg.CurrentAPI = newAPI
- initChunkParser()
+ if strings.Contains(cfg.CurrentAPI, "deepseek") {
+ chatBody.Model = "deepseek-chat"
+ } else {
+ chatBody.Model = "local"
+ }
+ choseChunkParser()
updateStatusLine()
return nil
}
@@ -631,24 +707,55 @@ func init() {
pages.AddPage(imgPage, imgView, true, true)
return nil
}
- if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" {
- // rag load
- // menu of the text files from defined rag directory
- files, err := os.ReadDir(cfg.RAGDir)
- if err != nil {
- logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err)
+ // TODO: move to menu or table
+ // if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" {
+ // // rag load
+ // // menu of the text files from defined rag directory
+ // files, err := os.ReadDir(cfg.RAGDir)
+ // if err != nil {
+ // logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err)
+ // return nil
+ // }
+ // fileList := []string{}
+ // for _, f := range files {
+ // if f.IsDir() {
+ // continue
+ // }
+ // fileList = append(fileList, f.Name())
+ // }
+ // chatRAGTable := makeRAGTable(fileList)
+ // pages.AddPage(RAGPage, chatRAGTable, true, true)
+ // return nil
+ // }
+ if event.Key() == tcell.KeyCtrlR && cfg.STT_ENABLED {
+ defer updateStatusLine()
+ if asr.IsRecording() {
+ userSpeech, err := asr.StopRecording()
+ if err != nil {
+ logger.Error("failed to inference user speech", "error", err)
+ return nil
+ }
+ if userSpeech != "" {
+ // append indtead of replacing
+ prevText := textArea.GetText()
+ textArea.SetText(prevText+userSpeech, true)
+ } else {
+ logger.Warn("empty user speech")
+ }
return nil
}
- fileList := []string{}
- for _, f := range files {
- if f.IsDir() {
- continue
- }
- fileList = append(fileList, f.Name())
+ if err := asr.StartRecording(); err != nil {
+ logger.Error("failed to start recording user speech", "error", err)
+ return nil
+ }
+ }
+ // I need keybind for tts to shut up
+ if event.Key() == tcell.KeyCtrlA {
+ textArea.SetText("pressed ctrl+A", true)
+ if cfg.TTS_ENABLED {
+ // audioStream.TextChan <- chunk
+ extra.TTSDoneChan <- true
}
- chatRAGTable := makeRAGTable(fileList)
- pages.AddPage(RAGPage, chatRAGTable, true, true)
- return nil
}
if event.Key() == tcell.KeyCtrlW {
// INFO: continue bot/text message