53 files changed, 8090 insertions, 1309 deletions
@@ -1,9 +1,15 @@ *.txt *.json testlog -elefant history/ *.db config.toml sysprompts/* +!sysprompts/cluedo.json history_bak/ +.aider* +tags +gf-lt +gflt +chat_exports/*.json +ragimport diff --git a/.golangci.yml b/.golangci.yml index 66732bf..2c7e552 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,32 +1,43 @@ +version: "2" run: - timeout: 1m concurrency: 2 tests: false - linters: - enable-all: false - disable-all: true + default: none enable: + - bodyclose - errcheck - - gosimple + - fatcontext - govet - ineffassign + - perfsprint + - prealloc - staticcheck - - typecheck - unused - - prealloc - presets: - - performance - -linters-settings: - funlen: - lines: 80 - statements: 50 - lll: - line-length: 80 - + settings: + funlen: + lines: 80 + statements: 50 + lll: + line-length: 80 + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ issues: - exclude: - # Display all issues max-issues-per-linter: 0 max-same-issues: 0 +formatters: + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ @@ -1,13 +1,79 @@ -.PHONY: setconfig run lint +.PHONY: setconfig run lint setup-whisper build-whisper download-whisper-model docker-up docker-down docker-logs run: setconfig - go build -o elefant && ./elefant + go build -o gf-lt && ./gf-lt server: setconfig - go build -o elefant && ./elefant -port 3333 + go build -o gf-lt && ./gf-lt -port 3333 setconfig: find config.toml &>/dev/null || cp config.example.toml config.toml lint: ## Run linters. Use make install-linters first. golangci-lint run -c .golangci.yml ./... + +# Whisper STT Setup (in batteries directory) +setup-whisper: build-whisper download-whisper-model + +build-whisper: ## Build whisper.cpp from source in batteries directory + @echo "Building whisper.cpp from source in batteries directory..." + @if [ ! -d "batteries/whisper.cpp" ]; then \ + echo "Cloning whisper.cpp repository to batteries directory..."; \ + git clone https://github.com/ggml-org/whisper.cpp.git batteries/whisper.cpp; \ + fi + cd batteries/whisper.cpp && cmake -B build -DGGML_CUDA=ON -DWHISPER_SDL2=ON; cmake --build build --config Release -j 8 + @echo "Whisper binary built successfully!" + +download-whisper-model: ## Download Whisper model for STT in batteries directory + @echo "Downloading Whisper model for STT..." + @if [ ! -d "batteries/whisper.cpp" ]; then \ + echo "Please run 'make setup-whisper' first to clone the repository."; \ + exit 1; \ + fi + @cd batteries/whisper.cpp && bash ./models/download-ggml-model.sh large-v3-turbo-q5_0 + @echo "Whisper model downloaded successfully!" + +# Docker targets for STT/TTS services (in batteries directory) +docker-up: ## Start all Docker Compose services for STT and TTS from batteries directory + @echo "Starting Docker services for STT (whisper) and TTS (kokoro)..." + @echo "Note: The Whisper model will be downloaded automatically inside the container on first run" + docker-compose -f batteries/docker-compose.yml up -d + @echo "Docker services started. STT available at http://localhost:8081, TTS available at http://localhost:8880" + +docker-up-whisper: ## Start only the Whisper STT service + @echo "Starting Whisper STT service only..." + @echo "Note: The Whisper model will be downloaded automatically inside the container on first run" + docker-compose -f batteries/docker-compose.yml up -d whisper + @echo "Whisper STT service started. 
Available at http://localhost:8081"
+
+docker-up-kokoro: ## Start only the Kokoro TTS service
+	@echo "Starting Kokoro TTS service only..."
+	docker-compose -f batteries/docker-compose.yml up -d kokoro-tts
+	@echo "Kokoro TTS service started. Available at http://localhost:8880"
+
+docker-down: ## Stop all Docker Compose services from batteries directory
+	@echo "Stopping Docker services..."
+	docker-compose -f batteries/docker-compose.yml down
+	@echo "Docker services stopped"
+
+docker-down-whisper: ## Stop only the Whisper STT service
+	@echo "Stopping Whisper STT service..."
+	docker-compose -f batteries/docker-compose.yml down whisper
+	@echo "Whisper STT service stopped"
+
+docker-down-kokoro: ## Stop only the Kokoro TTS service
+	@echo "Stopping Kokoro TTS service..."
+	docker-compose -f batteries/docker-compose.yml down kokoro-tts
+	@echo "Kokoro TTS service stopped"
+
+docker-logs: ## View logs from all Docker services in batteries directory
+	@echo "Displaying logs from Docker services..."
+	docker-compose -f batteries/docker-compose.yml logs -f
+
+docker-logs-whisper: ## View logs from Whisper STT service only
+	@echo "Displaying logs from Whisper STT service..."
+	docker-compose -f batteries/docker-compose.yml logs -f whisper
+
+docker-logs-kokoro: ## View logs from Kokoro TTS service only
+	@echo "Displaying logs from Kokoro TTS service..."
+	docker-compose -f batteries/docker-compose.yml logs -f kokoro-tts
@@ -1 +1,82 @@
-#### tui with middleware for LLM use
+### gf-lt (grail finder's llm tui)
+terminal user interface for large language models.
+built with [tview](https://github.com/rivo/tview)
+
+#### has/supports
+- character card spec;
+- llama.cpp api, deepseek, openrouter (others were not tested);
+- showing images (limited: for now only a png char card can be displayed);
+- tts/stt (if whisper.cpp server / fastapi tts server are provided);
+- image input;
+
+#### does not have/support
+- RAG; (RAG was implemented, but I found it unusable and then the sql extension broke, so no RAG);
+- MCP; (agentic use is implemented, but only as raw, predefined functions for the llm to call. see [tools.go](https://github.com/GrailFinder/gf-lt/blob/master/tools.go));
+
+#### usage examples
+
+
+#### how to install
+(requires golang)
+clone the project
+```
+cd gf-lt
+make
+```
+
+#### keybindings
+while running you can press f12 for the list of keys;
+```
+Esc: send msg
+PgUp/Down: switch focus between input and chat widgets
+F1: manage chats
+F2: regen last
+F3: delete last msg
+F4: edit msg
+F5: toggle system
+F6: interrupt bot resp
+F7: copy last msg to clipboard (linux xclip)
+F8: copy n msg to clipboard (linux xclip)
+F9: table to copy from; with all code blocks
+F10: switch if LLM will respond on this message (for user to write multiple messages in a row)
+F11: import chat file
+F12: show this help page
+Ctrl+w: resume generation on the last msg
+Ctrl+s: load new char/agent
+Ctrl+e: export chat to json file
+Ctrl+c: close program
+Ctrl+n: start a new chat
+Ctrl+o: open file picker for img input
+Ctrl+p: props edit form (min-p, dry, etc.)
+Ctrl+v: switch between /completion and /chat api (if provided in config)
+Ctrl+r: start/stop recording from your microphone (needs stt server)
+Ctrl+t: remove thinking (<think>) and tool messages from context (delete from chat)
+Ctrl+l: update connected model name (llamacpp)
+Ctrl+k: switch tool use (recommend tool use to llm after user msg)
+Ctrl+j: if the chat agent has a char.png card, show the image; press any key to return
+Ctrl+a: interrupt tts (needs tts server)
+Ctrl+q: cycle through mentioned chars in chat, to pick persona to send next msg as
+Ctrl+x: cycle through mentioned chars in chat, to pick persona to send next msg as (for llm)
+```
+
+#### setting up config
+```
+cp config.example.toml config.toml
+```
+set the values as you need them.
+
+#### setting up STT/TTS services
+For speech-to-text (STT) and text-to-speech (TTS) functionality:
+1. The project uses Whisper.cpp for STT and Kokoro for TTS
+2. Docker Compose automatically downloads the required Whisper model on first run
+3. To start all services: `make docker-up`
+4. To start only STT service: `make docker-up-whisper`
+5. To start only TTS service: `make docker-up-kokoro`
+6. To stop all services: `make docker-down`
+7. To stop only STT service: `make docker-down-whisper`
+8. To stop only TTS service: `make docker-down-kokoro`
+9. To view all service logs: `make docker-logs`
+10. To view only STT service logs: `make docker-logs-whisper`
+11. To view only TTS service logs: `make docker-logs-kokoro`
+12. The STT service runs on http://localhost:8081
+13. The TTS service runs on http://localhost:8880
diff --git a/assets/ex01.png b/assets/ex01.png
Binary files differ
new file mode 100644
index 0000000..b0f5ae3
--- /dev/null
+++ b/assets/ex01.png
diff --git a/batteries/docker-compose.yml b/batteries/docker-compose.yml
new file mode 100644
index 0000000..7cf401b
--- /dev/null
+++ b/batteries/docker-compose.yml
@@ -0,0 +1,51 @@
+services:
+  # Whisper.cpp STT service
+  whisper:
+    image: ghcr.io/ggml-org/whisper.cpp:main-cuda
+    container_name: whisper-stt
+    ports:
+      - "8081:8081"
+    volumes:
+      - whisper_models:/app/models
+    working_dir: /app
+    entrypoint: ""
+    command: >
+      sh -c "
+        if [ ! -f /app/models/ggml-large-v3-turbo.bin ]; then
+          echo 'Downloading ggml-large-v3-turbo model...'
+ ./download-ggml-model.sh large-v3-turbo /app/models + fi && + ./build/bin/whisper-server -m /app/models/ggml-large-v3-turbo.bin -t 4 -p 1 --port 8081 --host 0.0.0.0 + " + environment: + - WHISPER_LOG_LEVEL=3 + # Restart policy in case the service fails + restart: unless-stopped + + + # Kokoro-FastAPI TTS service + kokoro-tts: + # image: ghcr.io/remsky/kokoro-fastapi-cpu:latest + image: ghcr.io/remsky/kokoro-fastapi-gpu:latest + container_name: kokoro-tts + ports: + - "8880:8880" + environment: + - API_LOG_LEVEL=INFO + # For GPU support, uncomment the following lines: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + restart: unless-stopped + +volumes: + models: + driver: local + audio: + driver: local + whisper_models: + driver: local diff --git a/batteries/whisper.cpp b/batteries/whisper.cpp new file mode 160000 +Subproject a88b93f85f08fc6045e5d8a8c3f94b7be0ac8bc @@ -2,14 +2,19 @@ package main import ( "bufio" - "elefant/config" - "elefant/models" - "elefant/rag" - "elefant/storage" + "bytes" + "context" "encoding/json" "fmt" + "gf-lt/config" + "gf-lt/extra" + "gf-lt/models" + "gf-lt/rag" + "gf-lt/storage" + "html" "io" "log/slog" + "net" "net/http" "os" "path" @@ -20,14 +25,18 @@ import ( "github.com/rivo/tview" ) -var httpClient = http.Client{} - var ( - cfg *config.Config - logger *slog.Logger - logLevel = new(slog.LevelVar) + httpClient = &http.Client{} + cluedoState *extra.CluedoRoundInfo // Current game state + playerOrder []string // Turn order tracking + cfg *config.Config + logger *slog.Logger + logLevel = new(slog.LevelVar) +) +var ( activeChatName string chunkChan = make(chan string, 10) + openAIToolChan = make(chan string, 10) streamDone = make(chan bool, 1) chatBody *models.ChatBody store storage.FullRepo @@ -36,41 +45,295 @@ var ( defaultStarterBytes = []byte{} interruptResp = false ragger *rag.RAG - currentModel = "none" chunkParser ChunkParser - defaultLCPProps = map[string]float32{ + lastToolCall *models.FuncCall + lastToolCallID string // Store the ID of the most recent tool call + //nolint:unused // TTS_ENABLED conditionally uses this + orator extra.Orator + asr extra.STT + defaultLCPProps = map[string]float32{ "temperature": 0.8, "dry_multiplier": 0.0, "min_p": 0.05, "n_predict": -1.0, } + ORFreeModels = []string{ + "google/gemini-2.0-flash-exp:free", + "deepseek/deepseek-chat-v3-0324:free", + "mistralai/mistral-small-3.2-24b-instruct:free", + "qwen/qwen3-14b:free", + "google/gemma-3-27b-it:free", + "meta-llama/llama-3.3-70b-instruct:free", + } + LocalModels = []string{} ) -func fetchModelName() *models.LLMModels { - api := "http://localhost:8080/v1/models" +// cleanNullMessages removes messages with null or empty content to prevent API issues +func cleanNullMessages(messages []models.RoleMsg) []models.RoleMsg { + // // deletes tool calls which we don't want for now + // cleaned := make([]models.RoleMsg, 0, len(messages)) + // for _, msg := range messages { + // // is there a sense for this check at all? 
+ // if msg.HasContent() || msg.ToolCallID != "" || msg.Role == cfg.AssistantRole || msg.Role == cfg.WriteNextMsgAsCompletionAgent { + // cleaned = append(cleaned, msg) + // } else { + // // Log filtered messages for debugging + // logger.Warn("filtering out message during cleaning", "role", msg.Role, "content", msg.Content, "tool_call_id", msg.ToolCallID, "has_content", msg.HasContent()) + // } + // } + return consolidateConsecutiveAssistantMessages(messages) +} + +// consolidateConsecutiveAssistantMessages merges consecutive assistant messages into a single message +func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models.RoleMsg { + if len(messages) == 0 { + return messages + } + + consolidated := make([]models.RoleMsg, 0, len(messages)) + currentAssistantMsg := models.RoleMsg{} + isBuildingAssistantMsg := false + + for i := 0; i < len(messages); i++ { + msg := messages[i] + + if msg.Role == cfg.AssistantRole || msg.Role == cfg.WriteNextMsgAsCompletionAgent { + // If this is an assistant message, start or continue building + if !isBuildingAssistantMsg { + // Start accumulating assistant message + currentAssistantMsg = msg.Copy() + isBuildingAssistantMsg = true + } else { + // Continue accumulating - append content to the current assistant message + if currentAssistantMsg.IsContentParts() || msg.IsContentParts() { + // Handle structured content + if !currentAssistantMsg.IsContentParts() { + // Preserve the original ToolCallID before conversion + originalToolCallID := currentAssistantMsg.ToolCallID + // Convert existing content to content parts + currentAssistantMsg = models.NewMultimodalMsg(currentAssistantMsg.Role, []interface{}{models.TextContentPart{Type: "text", Text: currentAssistantMsg.Content}}) + // Restore the original ToolCallID to preserve tool call linking + currentAssistantMsg.ToolCallID = originalToolCallID + } + if msg.IsContentParts() { + currentAssistantMsg.ContentParts = append(currentAssistantMsg.ContentParts, msg.GetContentParts()...) 
+ } else if msg.Content != "" { + currentAssistantMsg.AddTextPart(msg.Content) + } + } else { + // Simple string content + if currentAssistantMsg.Content != "" { + currentAssistantMsg.Content += "\n" + msg.Content + } else { + currentAssistantMsg.Content = msg.Content + } + // ToolCallID is already preserved since we're not creating a new message object when just concatenating content + } + } + } else { + // This is not an assistant message + // If we were building an assistant message, add it to the result + if isBuildingAssistantMsg { + consolidated = append(consolidated, currentAssistantMsg) + isBuildingAssistantMsg = false + } + // Add the non-assistant message + consolidated = append(consolidated, msg) + } + } + + // Don't forget the last assistant message if we were building one + if isBuildingAssistantMsg { + consolidated = append(consolidated, currentAssistantMsg) + } + + return consolidated +} + +// GetLogLevel returns the current log level as a string +func GetLogLevel() string { + level := logLevel.Level() + switch level { + case slog.LevelDebug: + return "Debug" + case slog.LevelInfo: + return "Info" + case slog.LevelWarn: + return "Warn" + default: + // For any other values, return "Info" as default + return "Info" + } +} + +func createClient(connectTimeout time.Duration) *http.Client { + // Custom transport with connection timeout + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Create a dialer with connection timeout + dialer := &net.Dialer{ + Timeout: connectTimeout, + KeepAlive: 30 * time.Second, // Optional + } + return dialer.DialContext(ctx, network, addr) + }, + // Other transport settings (optional) + TLSHandshakeTimeout: connectTimeout, + ResponseHeaderTimeout: connectTimeout, + } + // Client with no overall timeout (or set to streaming-safe duration) + return &http.Client{ + Transport: transport, + Timeout: 0, // No overall timeout (for streaming) + } +} + +func fetchLCPModelName() *models.LCPModels { //nolint - resp, err := httpClient.Get(api) + resp, err := httpClient.Get(cfg.FetchModelNameAPI) if err != nil { - logger.Warn("failed to get model", "link", api, "error", err) + chatBody.Model = "disconnected" + logger.Warn("failed to get model", "link", cfg.FetchModelNameAPI, "error", err) + if err := notifyUser("error", "request failed "+cfg.FetchModelNameAPI); err != nil { + logger.Debug("failed to notify user", "error", err, "fn", "fetchLCPModelName") + } return nil } defer resp.Body.Close() - llmModel := models.LLMModels{} + llmModel := models.LCPModels{} if err := json.NewDecoder(resp.Body).Decode(&llmModel); err != nil { - logger.Warn("failed to decode resp", "link", api, "error", err) + logger.Warn("failed to decode resp", "link", cfg.FetchModelNameAPI, "error", err) return nil } if resp.StatusCode != 200 { - currentModel = "disconnected" + chatBody.Model = "disconnected" return nil } - currentModel = path.Base(llmModel.Data[0].ID) + chatBody.Model = path.Base(llmModel.Data[0].ID) return &llmModel } +// nolint +func fetchDSBalance() *models.DSBalance { + url := "https://api.deepseek.com/user/balance" + method := "GET" + // nolint + req, err := http.NewRequest(method, url, nil) + if err != nil { + logger.Warn("failed to create request", "error", err) + return nil + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken) + res, err := httpClient.Do(req) + if err != nil { + logger.Warn("failed to make request", "error", err) + return nil 
+ } + defer res.Body.Close() + resp := models.DSBalance{} + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + return nil + } + return &resp +} + +func fetchORModels(free bool) ([]string, error) { + resp, err := http.Get("https://openrouter.ai/api/v1/models") + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err := fmt.Errorf("failed to fetch or models; status: %s", resp.Status) + return nil, err + } + data := &models.ORModels{} + if err := json.NewDecoder(resp.Body).Decode(data); err != nil { + return nil, err + } + freeModels := data.ListModels(free) + return freeModels, nil +} + +func fetchLCPModels() ([]string, error) { + resp, err := http.Get(cfg.FetchModelNameAPI) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err := fmt.Errorf("failed to fetch or models; status: %s", resp.Status) + return nil, err + } + data := &models.LCPModels{} + if err := json.NewDecoder(resp.Body).Decode(data); err != nil { + return nil, err + } + localModels := data.ListModels() + return localModels, nil +} + func sendMsgToLLM(body io.Reader) { + choseChunkParser() + + var req *http.Request + var err error + + // Capture and log the request body for debugging + if _, ok := body.(*io.LimitedReader); ok { + // If it's a LimitedReader, we need to handle it differently + logger.Debug("request body type is LimitedReader", "parser", chunkParser, "link", cfg.CurrentAPI) + req, err = http.NewRequest("POST", cfg.CurrentAPI, body) + if err != nil { + logger.Error("newreq error", "error", err) + if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { + logger.Error("failed to notify", "error", err) + } + streamDone <- true + return + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) + req.Header.Set("Accept-Encoding", "gzip") + } else { + // For other reader types, capture and log the body content + bodyBytes, err := io.ReadAll(body) + if err != nil { + logger.Error("failed to read request body for logging", "error", err) + // Create request with original body if reading fails + req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes)) + if err != nil { + logger.Error("newreq error", "error", err) + if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { + logger.Error("failed to notify", "error", err) + } + streamDone <- true + return + } + } else { + // Log the request body for debugging + logger.Debug("sending request to API", "api", cfg.CurrentAPI, "body", string(bodyBytes)) + // Create request with the captured body + req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes)) + if err != nil { + logger.Error("newreq error", "error", err) + if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { + logger.Error("failed to notify", "error", err) + } + streamDone <- true + return + } + } + + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) + req.Header.Set("Accept-Encoding", "gzip") + } // nolint - resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) + resp, err := httpClient.Do(req) if err != nil { logger.Error("llamacpp api", "error", err) if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { @@ -85,8 +348,7 @@ func sendMsgToLLM(body 
io.Reader) { for { var ( answerText string - content string - stop bool + chunk *models.TextChunk ) counter++ // to stop from spiriling in infinity read of bad bytes that happens with poor connection @@ -97,12 +359,16 @@ func sendMsgToLLM(body io.Reader) { } line, err := reader.ReadBytes('\n') if err != nil { - logger.Error("error reading response body", "error", err, "line", string(line)) - if err.Error() != "EOF" { - streamDone <- true - break + logger.Error("error reading response body", "error", err, "line", string(line), + "user_role", cfg.UserRole, "parser", chunkParser, "link", cfg.CurrentAPI) + // if err.Error() != "EOF" { + if err := notifyUser("API error", err.Error()); err != nil { + logger.Error("failed to notify", "error", err) } - continue + streamDone <- true + break + // } + // continue } if len(line) <= 1 { if interruptResp { @@ -113,25 +379,51 @@ func sendMsgToLLM(body io.Reader) { // starts with -> data: line = line[6:] logger.Debug("debugging resp", "line", string(line)) - content, stop, err = chunkParser.ParseChunk(line) + if bytes.Equal(line, []byte("[DONE]\n")) { + streamDone <- true + break + } + if bytes.Equal(line, []byte("ROUTER PROCESSING\n")) { + continue + } + chunk, err = chunkParser.ParseChunk(line) if err != nil { - logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI) + logger.Error("error parsing response body", "error", err, + "line", string(line), "url", cfg.CurrentAPI) + if err := notifyUser("LLM Response Error", "Failed to parse LLM response: "+err.Error()); err != nil { + logger.Error("failed to notify user", "error", err) + } streamDone <- true break } - if stop { - if content != "" { - logger.Warn("text inside of finish llmchunk", "chunk", content, "counter", counter) + // Handle error messages in response content + // example needed, since llm could use the word error in the normal msg + // if string(line) != "" && strings.Contains(strings.ToLower(string(line)), "error") { + // logger.Error("API error response detected", "line", line, "url", cfg.CurrentAPI) + // streamDone <- true + // break + // } + if chunk.Finished { + if chunk.Chunk != "" { + logger.Warn("text inside of finish llmchunk", "chunk", chunk, "counter", counter) + answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n") + chunkChan <- answerText } streamDone <- true break } if counter == 0 { - content = strings.TrimPrefix(content, " ") + chunk.Chunk = strings.TrimPrefix(chunk.Chunk, " ") } // bot sends way too many \n - answerText = strings.ReplaceAll(content, "\n\n", "\n") + answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n") chunkChan <- answerText + openAIToolChan <- chunk.ToolChunk + if chunk.FuncName != "" { + lastToolCall.Name = chunk.FuncName + // Store the tool call ID for the response + lastToolCallID = chunk.ToolID + } interrupt: if interruptResp { // read bytes, so it would not get into beginning of the next req interruptResp = false @@ -143,60 +435,131 @@ func sendMsgToLLM(body io.Reader) { } func chatRagUse(qText string) (string, error) { + logger.Debug("Starting RAG query", "original_query", qText) tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { + logger.Error("failed to create sentence tokenizer", "error", err) return "", err } - // TODO: this where llm should find the questions in text and ask them + // this where llm should find the questions in text and ask them questionsS := tokenizer.Tokenize(qText) questions := make([]string, len(questionsS)) for i, q := range questionsS { questions[i] 
= q.Text + logger.Debug("RAG question extracted", "index", i, "question", q.Text) + } + + if len(questions) == 0 { + logger.Warn("No questions extracted from query text", "query", qText) + return "No related results from RAG vector storage.", nil } + respVecs := []models.VectorRow{} for i, q := range questions { + logger.Debug("Processing RAG question", "index", i, "question", q) emb, err := ragger.LineToVector(q) if err != nil { - logger.Error("failed to get embs", "error", err, "index", i, "question", q) + logger.Error("failed to get embeddings for RAG", "error", err, "index", i, "question", q) continue } - vecs, err := store.SearchClosest(emb) + logger.Debug("Got embeddings for question", "index", i, "question_len", len(q), "embedding_len", len(emb)) + + // Create EmbeddingResp struct for the search + embeddingResp := &models.EmbeddingResp{ + Embedding: emb, + Index: 0, // Not used in search but required for the struct + } + vecs, err := ragger.SearchEmb(embeddingResp) if err != nil { - logger.Error("failed to query embs", "error", err, "index", i, "question", q) + logger.Error("failed to query embeddings in RAG", "error", err, "index", i, "question", q) continue } + logger.Debug("RAG search returned vectors", "index", i, "question", q, "vector_count", len(vecs)) respVecs = append(respVecs, vecs...) } + // get raw text resps := []string{} - logger.Debug("sqlvec resp", "vecs len", len(respVecs)) + logger.Debug("RAG query final results", "total_vecs_found", len(respVecs)) for _, rv := range respVecs { resps = append(resps, rv.RawText) + logger.Debug("RAG result", "slug", rv.Slug, "filename", rv.FileName, "raw_text_len", len(rv.RawText)) } + if len(resps) == 0 { - return "No related results from vector storage.", nil + logger.Info("No RAG results found for query", "original_query", qText, "question_count", len(questions)) + return "No related results from RAG vector storage.", nil } - return strings.Join(resps, "\n"), nil + + result := strings.Join(resps, "\n") + logger.Debug("RAG query completed", "result_len", len(result), "response_count", len(resps)) + return result, nil } func roleToIcon(role string) string { return "<" + role + ">: " } +// FIXME: it should not be here; move to extra +func checkGame(role string, tv *tview.TextView) { + // Handle Cluedo game flow + // should go before form msg, since formmsg takes chatBody and makes ioreader out of it + // role is almost always user, unless it's regen or resume + // cannot get in this block, since cluedoState is nil; + if cfg.EnableCluedo { + // Initialize Cluedo game if needed + if cluedoState == nil { + playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2} + cluedoState = extra.CluedoPrepCards(playerOrder) + } + // notifyUser("got in cluedo", "yay") + currentPlayer := playerOrder[0] + playerOrder = append(playerOrder[1:], currentPlayer) // Rotate turns + if role == cfg.UserRole { + fmt.Fprintf(tv, "Your (%s) cards: %s\n", currentPlayer, cluedoState.GetPlayerCards(currentPlayer)) + } else { + chatBody.Messages = append(chatBody.Messages, models.RoleMsg{ + Role: cfg.ToolRole, + Content: cluedoState.GetPlayerCards(currentPlayer), + }) + } + } +} + func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) { botRespMode = true - // reader := formMsg(chatBody, userMsg, role) + botPersona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + botPersona = cfg.WriteNextMsgAsCompletionAgent + } + defer func() { botRespMode = false }() + // check that there is a model set to use if is not 
local + if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI { + if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" { + if err := notifyUser("bad request", "wrong deepseek model name"); err != nil { + logger.Warn("failed ot notify user", "error", err) + return + } + return + } + } + if !resume { + checkGame(role, tv) + } + choseChunkParser() reader, err := chunkParser.FormMsg(userMsg, role, resume) if reader == nil || err != nil { logger.Error("empty reader from msgs", "role", role, "error", err) return } + if cfg.SkipLLMResp { + return + } go sendMsgToLLM(reader) logger.Debug("looking at vars in chatRound", "msg", userMsg, "regen", regen, "resume", resume) - // TODO: consider case where user msg is regened (not assistant one) if !resume { fmt.Fprintf(tv, "[-:-:b](%d) ", len(chatBody.Messages)) - fmt.Fprint(tv, roleToIcon(cfg.AssistantRole)) + fmt.Fprint(tv, roleToIcon(botPersona)) fmt.Fprint(tv, "[-:-:-]\n") if cfg.ThinkUse && !strings.Contains(cfg.CurrentAPI, "v1") { // fmt.Fprint(tv, "<think>") @@ -204,6 +567,7 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) { } } respText := strings.Builder{} + toolResp := strings.Builder{} out: for { select { @@ -211,8 +575,22 @@ out: fmt.Fprint(tv, chunk) respText.WriteString(chunk) tv.ScrollToEnd() + // Send chunk to audio stream handler + if cfg.TTS_ENABLED { + // audioStream.TextChan <- chunk + extra.TTSTextChan <- chunk + } + case toolChunk := <-openAIToolChan: + fmt.Fprint(tv, toolChunk) + toolResp.WriteString(toolChunk) + tv.ScrollToEnd() case <-streamDone: botRespMode = false + if cfg.TTS_ENABLED { + // audioStream.TextChan <- chunk + extra.TTSFlushChan <- true + logger.Debug("sending flushchan signal") + } break out } } @@ -223,9 +601,23 @@ out: // lastM.Content = lastM.Content + respText.String() } else { chatBody.Messages = append(chatBody.Messages, models.RoleMsg{ - Role: cfg.AssistantRole, Content: respText.String(), + Role: botPersona, Content: respText.String(), }) } + + logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages)) + for i, msg := range chatBody.Messages { + logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } + + // // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template + cleanChatBody() + + logger.Debug("chatRound: after cleanChatBody", "messages_after_clean", len(chatBody.Messages)) + for i, msg := range chatBody.Messages { + logger.Debug("chatRound: after cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } + colorText() updateStatusLine() // bot msg is done; @@ -234,42 +626,130 @@ out: if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { logger.Warn("failed to update storage", "error", err, "name", activeChatName) } - findCall(respText.String(), tv) + findCall(respText.String(), toolResp.String(), tv) } -func findCall(msg string, tv *tview.TextView) { - fc := models.FuncCall{} - jsStr := toolCallRE.FindString(msg) - if jsStr == "" { - return +// cleanChatBody removes messages with null or empty content to prevent API issues +func cleanChatBody() { + if chatBody != nil && chatBody.Messages != nil { + originalLen := len(chatBody.Messages) + logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen) + 
for i, msg := range chatBody.Messages { + logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } + + chatBody.Messages = cleanNullMessages(chatBody.Messages) + + logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages)) + for i, msg := range chatBody.Messages { + logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } } - prefix := "__tool_call__\n" - suffix := "\n__tool_call__" - jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) - if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { - logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr) - return +} + +func findCall(msg, toolCall string, tv *tview.TextView) { + fc := &models.FuncCall{} + if toolCall != "" { + // HTML-decode the tool call string to handle encoded characters like < -> <= + decodedToolCall := html.UnescapeString(toolCall) + openAIToolMap := make(map[string]string) + // respect tool call + if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil { + logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err) + // Send error response to LLM so it can retry or handle the error + toolResponseMsg := models.RoleMsg{ + Role: cfg.ToolRole, + Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err), + ToolCallID: lastToolCallID, // Use the stored tool call ID + } + chatBody.Messages = append(chatBody.Messages, toolResponseMsg) + // Clear the stored tool call ID after using it + lastToolCallID = "" + // Trigger the assistant to continue processing with the error message + chatRound("", cfg.AssistantRole, tv, false, false) + return + } + lastToolCall.Args = openAIToolMap + fc = lastToolCall + // Ensure lastToolCallID is set if it's available in the tool call + if lastToolCallID == "" && len(openAIToolMap) > 0 { + // Attempt to extract ID from the parsed tool call if not already set + if id, exists := openAIToolMap["id"]; exists { + lastToolCallID = id + } + } + } else { + jsStr := toolCallRE.FindString(msg) + if jsStr == "" { + return + } + prefix := "__tool_call__\n" + suffix := "\n__tool_call__" + jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) + // HTML-decode the JSON string to handle encoded characters like < -> <= + decodedJsStr := html.UnescapeString(jsStr) + if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil { + logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr) + // Send error response to LLM so it can retry or handle the error + toolResponseMsg := models.RoleMsg{ + Role: cfg.ToolRole, + Content: fmt.Sprintf("Error processing tool call: %v. 
Please check the JSON format and try again.", err), + } + chatBody.Messages = append(chatBody.Messages, toolResponseMsg) + logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", len(chatBody.Messages)) + // Trigger the assistant to continue processing with the error message + chatRound("", cfg.AssistantRole, tv, false, false) + return + } } // call a func f, ok := fnMap[fc.Name] if !ok { - m := fc.Name + "%s is not implemented" - chatRound(m, cfg.ToolRole, tv, false, false) + m := fc.Name + " is not implemented" + // Create tool response message with the proper tool_call_id + toolResponseMsg := models.RoleMsg{ + Role: cfg.ToolRole, + Content: m, + ToolCallID: lastToolCallID, // Use the stored tool call ID + } + chatBody.Messages = append(chatBody.Messages, toolResponseMsg) + logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) + // Clear the stored tool call ID after using it + lastToolCallID = "" + + // Trigger the assistant to continue processing with the new tool response + // by calling chatRound with empty content to continue the assistant's response + chatRound("", cfg.AssistantRole, tv, false, false) return } - resp := f(fc.Args...) - toolMsg := fmt.Sprintf("tool response: %+v", string(resp)) - chatRound(toolMsg, cfg.ToolRole, tv, false, false) + resp := f(fc.Args) + toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting + logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc) + fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", + "\n", len(chatBody.Messages), cfg.ToolRole, toolMsg) + // Create tool response message with the proper tool_call_id + toolResponseMsg := models.RoleMsg{ + Role: cfg.ToolRole, + Content: toolMsg, + ToolCallID: lastToolCallID, // Use the stored tool call ID + } + chatBody.Messages = append(chatBody.Messages, toolResponseMsg) + logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) + // Clear the stored tool call ID after using it + lastToolCallID = "" + // Trigger the assistant to continue processing with the new tool response + // by calling chatRound with empty content to continue the assistant's response + chatRound("", cfg.AssistantRole, tv, false, false) } func chatToTextSlice(showSys bool) []string { resp := make([]string, len(chatBody.Messages)) for i, msg := range chatBody.Messages { - // INFO: skips system msg - if !showSys && (msg.Role != cfg.AssistantRole && msg.Role != cfg.UserRole) { + // INFO: skips system msg and tool msg + if !showSys && (msg.Role == cfg.ToolRole || msg.Role == "system") { continue } - resp[i] = msg.ToText(i, cfg) + resp[i] = msg.ToText(i) } return resp } @@ -282,41 +762,57 @@ func chatToText(showSys bool) string { func removeThinking(chatBody *models.ChatBody) { msgs := []models.RoleMsg{} for _, msg := range chatBody.Messages { - rm := models.RoleMsg{} + // Filter out tool messages and thinking markers if msg.Role == cfg.ToolRole { continue } // find thinking and remove it - rm.Content = thinkRE.ReplaceAllString(msg.Content, "") - rm.Role = msg.Role + rm := models.RoleMsg{ + Role: msg.Role, + Content: thinkRE.ReplaceAllString(msg.Content, 
""), + } msgs = append(msgs, rm) } chatBody.Messages = msgs } +func addNewChat(chatName string) { + id, err := store.ChatGetMaxID() + if err != nil { + logger.Error("failed to get max chat id from db;", "id:", id) + // INFO: will rewrite first chat + } + chat := &models.Chat{ + ID: id + 1, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Agent: cfg.AssistantRole, + } + if chatName == "" { + chatName = fmt.Sprintf("%d_%s", chat.ID, cfg.AssistantRole) + } + chat.Name = chatName + chatMap[chat.Name] = chat + activeChatName = chat.Name +} + func applyCharCard(cc *models.CharCard) { cfg.AssistantRole = cc.Role + // FIXME: remove + // Initialize Cluedo if enabled and matching role + if cfg.EnableCluedo && cc.Role == "CluedoPlayer" { + playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2} + cluedoState = extra.CluedoPrepCards(playerOrder) + } history, err := loadAgentsLastChat(cfg.AssistantRole) if err != nil { + // too much action for err != nil; loadAgentsLastChat needs to be split up logger.Warn("failed to load last agent chat;", "agent", cc.Role, "err", err) history = []models.RoleMsg{ {Role: "system", Content: cc.SysPrompt}, {Role: cfg.AssistantRole, Content: cc.FirstMsg}, } - id, err := store.ChatGetMaxID() - if err != nil { - logger.Error("failed to get max chat id from db;", "id:", id) - // INFO: will rewrite first chat - } - chat := &models.Chat{ - ID: id + 1, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - Agent: cfg.AssistantRole, - } - chat.Name = fmt.Sprintf("%d_%s", chat.ID, cfg.AssistantRole) - chatMap[chat.Name] = chat - activeChatName = chat.Name + addNewChat("") } chatBody.Messages = history } @@ -331,7 +827,13 @@ func charToStart(agentName string) bool { } func init() { - cfg = config.LoadConfigOrDefault("config.toml") + var err error + cfg, err = config.LoadConfig("config.toml") + if err != nil { + fmt.Println("failed to load config.toml") + os.Exit(1) + return + } defaultStarter = []models.RoleMsg{ {Role: "system", Content: basicSysMsg}, {Role: cfg.AssistantRole, Content: defaultFirstMsg}, @@ -339,21 +841,21 @@ func init() { logfile, err := os.OpenFile(cfg.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - logger.Error("failed to open log file", "error", err, "filename", cfg.LogFile) + slog.Error("failed to open log file", "error", err, "filename", cfg.LogFile) return } defaultStarterBytes, err = json.Marshal(defaultStarter) if err != nil { - logger.Error("failed to marshal defaultStarter", "error", err) + slog.Error("failed to marshal defaultStarter", "error", err) return } // load cards basicCard.Role = cfg.AssistantRole - toolCard.Role = cfg.AssistantRole + // toolCard.Role = cfg.AssistantRole // logLevel.Set(slog.LevelInfo) logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel})) - store = storage.NewProviderSQL("test.db", logger) + store = storage.NewProviderSQL(cfg.DBPATH, logger) if store == nil { os.Exit(1) } @@ -364,13 +866,40 @@ func init() { logger.Error("failed to load chat", "error", err) return } + lastToolCall = &models.FuncCall{} lastChat := loadOldChatOrGetNew() chatBody = &models.ChatBody{ - Model: "modl_name", + Model: "modelname", Stream: true, Messages: lastChat, } - initChunkParser() - // go runModelNameTicker(time.Second * 120) - // tempLoad() + // Initialize Cluedo if enabled and matching role + if cfg.EnableCluedo && cfg.AssistantRole == "CluedoPlayer" { + playerOrder = []string{cfg.UserRole, cfg.AssistantRole, cfg.CluedoRole2} + cluedoState = extra.CluedoPrepCards(playerOrder) 
+ } + if cfg.OpenRouterToken != "" { + go func() { + ORModels, err := fetchORModels(true) + if err != nil { + logger.Error("failed to fetch or models", "error", err) + } else { + ORFreeModels = ORModels + } + }() + } + go func() { + LocalModels, err = fetchLCPModels() + if err != nil { + logger.Error("failed to fetch llama.cpp models", "error", err) + } + }() + choseChunkParser() + httpClient = createClient(time.Second * 15) + if cfg.TTS_ENABLED { + orator = extra.NewOrator(logger, cfg) + } + if cfg.STT_ENABLED { + asr = extra.NewSTT(logger, cfg) + } } diff --git a/bot_test.go b/bot_test.go new file mode 100644 index 0000000..2d59c3c --- /dev/null +++ b/bot_test.go @@ -0,0 +1,155 @@ +package main + +import ( + "gf-lt/config" + "gf-lt/models" + "reflect" + "testing" +) + +func TestConsolidateConsecutiveAssistantMessages(t *testing.T) { + // Mock config for testing + testCfg := &config.Config{ + AssistantRole: "assistant", + WriteNextMsgAsCompletionAgent: "", + } + cfg = testCfg + + tests := []struct { + name string + input []models.RoleMsg + expected []models.RoleMsg + }{ + { + name: "no consecutive assistant messages", + input: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + {Role: "user", Content: "How are you?"}, + }, + expected: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + {Role: "user", Content: "How are you?"}, + }, + }, + { + name: "consecutive assistant messages should be consolidated", + input: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "First part"}, + {Role: "assistant", Content: "Second part"}, + {Role: "user", Content: "Thanks"}, + }, + expected: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "First part\nSecond part"}, + {Role: "user", Content: "Thanks"}, + }, + }, + { + name: "multiple sets of consecutive assistant messages", + input: []models.RoleMsg{ + {Role: "user", Content: "First question"}, + {Role: "assistant", Content: "First answer part 1"}, + {Role: "assistant", Content: "First answer part 2"}, + {Role: "user", Content: "Second question"}, + {Role: "assistant", Content: "Second answer part 1"}, + {Role: "assistant", Content: "Second answer part 2"}, + {Role: "assistant", Content: "Second answer part 3"}, + }, + expected: []models.RoleMsg{ + {Role: "user", Content: "First question"}, + {Role: "assistant", Content: "First answer part 1\nFirst answer part 2"}, + {Role: "user", Content: "Second question"}, + {Role: "assistant", Content: "Second answer part 1\nSecond answer part 2\nSecond answer part 3"}, + }, + }, + { + name: "single assistant message (no consolidation needed)", + input: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + }, + expected: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + }, + }, + { + name: "only assistant messages", + input: []models.RoleMsg{ + {Role: "assistant", Content: "First"}, + {Role: "assistant", Content: "Second"}, + {Role: "assistant", Content: "Third"}, + }, + expected: []models.RoleMsg{ + {Role: "assistant", Content: "First\nSecond\nThird"}, + }, + }, + { + name: "user messages at the end are preserved", + input: []models.RoleMsg{ + {Role: "assistant", Content: "First"}, + {Role: "assistant", Content: "Second"}, + {Role: "user", Content: "Final user message"}, + }, + expected: []models.RoleMsg{ + {Role: "assistant", Content: "First\nSecond"}, + 
{Role: "user", Content: "Final user message"}, + }, + }, + { + name: "tool call ids preserved in consolidation", + input: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "First part", ToolCallID: "call_123"}, + {Role: "assistant", Content: "Second part", ToolCallID: "call_123"}, // Same ID + {Role: "user", Content: "Thanks"}, + }, + expected: []models.RoleMsg{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "First part\nSecond part", ToolCallID: "call_123"}, + {Role: "user", Content: "Thanks"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := consolidateConsecutiveAssistantMessages(tt.input) + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result)) + t.Logf("Result: %+v", result) + t.Logf("Expected: %+v", tt.expected) + return + } + + for i, expectedMsg := range tt.expected { + if i >= len(result) { + t.Errorf("Result has fewer messages than expected at index %d", i) + continue + } + + actualMsg := result[i] + if actualMsg.Role != expectedMsg.Role { + t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role) + } + + if actualMsg.Content != expectedMsg.Content { + t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content) + } + + if actualMsg.ToolCallID != expectedMsg.ToolCallID { + t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID) + } + } + + // Additional check: ensure no messages were lost + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected) + } + }) + } +}
\ No newline at end of file diff --git a/config.example.toml b/config.example.toml index 80e3640..113b7ea 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,6 +1,16 @@ ChatAPI = "http://localhost:8080/v1/chat/completions" CompletionAPI = "http://localhost:8080/completion" -EmbedURL = "http://localhost:8080/v1/embeddings" +FetchModelNameAPI = "http://localhost:8080/v1/models" +# in case you have deepseek token +DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions" +DeepSeekChatAPI = "https://api.deepseek.com/chat/completions" +DeepSeekModel = "deepseek-reasoner" +# DeepSeekToken = "" +# in case you have opentouter token +OpenRouterCompletionAPI = "https://openrouter.ai/api/v1/completions" +OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions" +# OpenRouterToken = "" +EmbedURL = "http://localhost:8082/v1/embeddings" ShowSys = true LogFile = "log.txt" UserRole = "user" @@ -8,3 +18,23 @@ ToolRole = "tool" AssistantRole = "assistant" SysDir = "sysprompts" ChunkLimit = 100000 +# rag settings +RAGBatchSize = 1 +RAGWordLimit = 80 +RAGWorkers = 2 +RAGDir = "ragimport" +# extra tts +TTS_ENABLED = false +TTS_URL = "http://localhost:8880/v1/audio/speech" +TTS_SPEED = 1.2 +# extra stt +STT_ENABLED = false +STT_TYPE = "WHISPER_SERVER" # WHISPER_SERVER or WHISPER_BINARY +STT_URL = "http://localhost:8081/inference" +WhisperBinaryPath = "./batteries/whisper.cpp/build/bin/whisper-cli" # Path to whisper binary (for WHISPER_BINARY mode) +WhisperModelPath = "./batteries/whisper.cpp/ggml-large-v3-turbo-q5_0.bin" # Path to whisper model file (for WHISPER_BINARY mode) +STT_LANG = "en" # Language for speech recognition (for WHISPER_BINARY mode) +STT_SR = 16000 # Sample rate for audio recording +DBPATH = "gflt.db" +FilePickerDir = "." # Directory where file picker should start +FilePickerExts = "png,jpg,jpeg,gif,webp" # Comma-separated list of allowed file extensions for file picker diff --git a/config/config.go b/config/config.go index f26a82e..eef8035 100644 --- a/config/config.go +++ b/config/config.go @@ -1,63 +1,124 @@ package config import ( - "fmt" + "os" "github.com/BurntSushi/toml" ) type Config struct { - ChatAPI string `toml:"ChatAPI"` - CompletionAPI string `toml:"CompletionAPI"` - CurrentAPI string - APIMap map[string]string + EnableCluedo bool `toml:"EnableCluedo"` // Cluedo game mode toggle + CluedoRole2 string `toml:"CluedoRole2"` // Secondary AI role name + ChatAPI string `toml:"ChatAPI"` + CompletionAPI string `toml:"CompletionAPI"` + CurrentAPI string + CurrentProvider string + APIMap map[string]string + FetchModelNameAPI string `toml:"FetchModelNameAPI"` + // ToolsAPI list? 
+ SearchAPI string `toml:"SearchAPI"` + SearchDescribe string `toml:"SearchDescribe"` // - ShowSys bool `toml:"ShowSys"` - LogFile string `toml:"LogFile"` - UserRole string `toml:"UserRole"` - ToolRole string `toml:"ToolRole"` - ToolUse bool `toml:"ToolUse"` - ThinkUse bool `toml:"ThinkUse"` - AssistantRole string `toml:"AssistantRole"` - SysDir string `toml:"SysDir"` - ChunkLimit uint32 `toml:"ChunkLimit"` + ShowSys bool `toml:"ShowSys"` + LogFile string `toml:"LogFile"` + UserRole string `toml:"UserRole"` + ToolRole string `toml:"ToolRole"` + ToolUse bool `toml:"ToolUse"` + ThinkUse bool `toml:"ThinkUse"` + AssistantRole string `toml:"AssistantRole"` + SysDir string `toml:"SysDir"` + ChunkLimit uint32 `toml:"ChunkLimit"` + WriteNextMsgAs string + WriteNextMsgAsCompletionAgent string + SkipLLMResp bool // embeddings RAGEnabled bool `toml:"RAGEnabled"` EmbedURL string `toml:"EmbedURL"` HFToken string `toml:"HFToken"` RAGDir string `toml:"RAGDir"` + // rag settings + RAGWorkers uint32 `toml:"RAGWorkers"` + RAGBatchSize int `toml:"RAGBatchSize"` + RAGWordLimit uint32 `toml:"RAGWordLimit"` + // deepseek + DeepSeekChatAPI string `toml:"DeepSeekChatAPI"` + DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"` + DeepSeekToken string `toml:"DeepSeekToken"` + DeepSeekModel string `toml:"DeepSeekModel"` + ApiLinks []string + // openrouter + OpenRouterChatAPI string `toml:"OpenRouterChatAPI"` + OpenRouterCompletionAPI string `toml:"OpenRouterCompletionAPI"` + OpenRouterToken string `toml:"OpenRouterToken"` + OpenRouterModel string `toml:"OpenRouterModel"` + // TTS + TTS_URL string `toml:"TTS_URL"` + TTS_ENABLED bool `toml:"TTS_ENABLED"` + TTS_SPEED float32 `toml:"TTS_SPEED"` + // STT + STT_TYPE string `toml:"STT_TYPE"` // WHISPER_SERVER, WHISPER_BINARY + STT_URL string `toml:"STT_URL"` + STT_SR int `toml:"STT_SR"` + STT_ENABLED bool `toml:"STT_ENABLED"` + WhisperBinaryPath string `toml:"WhisperBinaryPath"` + WhisperModelPath string `toml:"WhisperModelPath"` + STT_LANG string `toml:"STT_LANG"` + DBPATH string `toml:"DBPATH"` + FilePickerDir string `toml:"FilePickerDir"` + FilePickerExts string `toml:"FilePickerExts"` } -func LoadConfigOrDefault(fn string) *Config { +func LoadConfig(fn string) (*Config, error) { if fn == "" { fn = "config.toml" } config := &Config{} _, err := toml.DecodeFile(fn, &config) if err != nil { - fmt.Println("failed to read config from file, loading default") - config.ChatAPI = "http://localhost:8080/v1/chat/completions" - config.CompletionAPI = "http://localhost:8080/completion" - config.RAGEnabled = false - config.EmbedURL = "http://localhost:8080/v1/embiddings" - config.ShowSys = true - config.LogFile = "log.txt" - config.UserRole = "user" - config.ToolRole = "tool" - config.AssistantRole = "assistant" - config.SysDir = "sysprompts" - config.ChunkLimit = 8192 + return nil, err } config.CurrentAPI = config.ChatAPI config.APIMap = map[string]string{ - config.ChatAPI: config.CompletionAPI, + config.ChatAPI: config.CompletionAPI, + config.CompletionAPI: config.DeepSeekChatAPI, + config.DeepSeekChatAPI: config.DeepSeekCompletionAPI, + config.DeepSeekCompletionAPI: config.OpenRouterCompletionAPI, + config.OpenRouterCompletionAPI: config.OpenRouterChatAPI, + config.OpenRouterChatAPI: config.ChatAPI, } - if config.CompletionAPI != "" { - config.CurrentAPI = config.CompletionAPI - config.APIMap = map[string]string{ - config.CompletionAPI: config.ChatAPI, + // check env if keys not in config + if config.OpenRouterToken == "" { + config.OpenRouterToken = 
os.Getenv("OPENROUTER_API_KEY") + } + if config.DeepSeekToken == "" { + config.DeepSeekToken = os.Getenv("DEEPSEEK_API_KEY") + } + // Build ApiLinks slice with only non-empty API links + // Only include DeepSeek APIs if DeepSeekToken is provided + if config.DeepSeekToken != "" { + if config.DeepSeekChatAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.DeepSeekChatAPI) + } + if config.DeepSeekCompletionAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.DeepSeekCompletionAPI) + } + } + // Only include OpenRouter APIs if OpenRouterToken is provided + if config.OpenRouterToken != "" { + if config.OpenRouterChatAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.OpenRouterChatAPI) + } + if config.OpenRouterCompletionAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.OpenRouterCompletionAPI) } } + // Always include basic APIs + if config.ChatAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.ChatAPI) + } + if config.CompletionAPI != "" { + config.ApiLinks = append(config.ApiLinks, config.CompletionAPI) + } // if any value is empty fill with default - return config + return config, nil } diff --git a/extra/cluedo.go b/extra/cluedo.go new file mode 100644 index 0000000..1ef11cc --- /dev/null +++ b/extra/cluedo.go @@ -0,0 +1,73 @@ +package extra + +import ( + "math/rand" + "strings" +) + +var ( + rooms = []string{"HALL", "LOUNGE", "DINING ROOM", "KITCHEN", "BALLROOM", "CONSERVATORY", "BILLIARD ROOM", "LIBRARY", "STUDY"} + weapons = []string{"CANDLESTICK", "DAGGER", "LEAD PIPE", "REVOLVER", "ROPE", "SPANNER"} + people = []string{"Miss Scarlett", "Colonel Mustard", "Mrs. White", "Reverend Green", "Mrs. Peacock", "Professor Plum"} +) + +type MurderTrifecta struct { + Murderer string + Weapon string + Room string +} + +type CluedoRoundInfo struct { + Answer MurderTrifecta + PlayersCards map[string][]string +} + +func (c *CluedoRoundInfo) GetPlayerCards(player string) string { + // maybe format it a little + return "cards of " + player + "are " + strings.Join(c.PlayersCards[player], ",") +} + +func CluedoPrepCards(playerOrder []string) *CluedoRoundInfo { + res := &CluedoRoundInfo{} + // Select murder components + trifecta := MurderTrifecta{ + Murderer: people[rand.Intn(len(people))], + Weapon: weapons[rand.Intn(len(weapons))], + Room: rooms[rand.Intn(len(rooms))], + } + // Collect non-murder cards + var notInvolved []string + for _, room := range rooms { + if room != trifecta.Room { + notInvolved = append(notInvolved, room) + } + } + for _, weapon := range weapons { + if weapon != trifecta.Weapon { + notInvolved = append(notInvolved, weapon) + } + } + for _, person := range people { + if person != trifecta.Murderer { + notInvolved = append(notInvolved, person) + } + } + // Shuffle and distribute cards + rand.Shuffle(len(notInvolved), func(i, j int) { + notInvolved[i], notInvolved[j] = notInvolved[j], notInvolved[i] + }) + players := map[string][]string{} + cardsPerPlayer := len(notInvolved) / len(playerOrder) + // playerOrder := []string{"{{user}}", "{{char}}", "{{char2}}"} + for i, player := range playerOrder { + start := i * cardsPerPlayer + end := (i + 1) * cardsPerPlayer + if end > len(notInvolved) { + end = len(notInvolved) + } + players[player] = notInvolved[start:end] + } + res.Answer = trifecta + res.PlayersCards = players + return res +} diff --git a/extra/cluedo_test.go b/extra/cluedo_test.go new file mode 100644 index 0000000..e7a53b1 --- /dev/null +++ b/extra/cluedo_test.go @@ -0,0 +1,50 @@ +package extra + +import ( + "testing" +) 
+ +func TestPrepCards(t *testing.T) { + // Run the function to get the murder combination and player cards + roundInfo := CluedoPrepCards([]string{"{{user}}", "{{char}}", "{{char2}}"}) + // Create a map to track all distributed cards + distributedCards := make(map[string]bool) + // Check that the murder combination cards are not distributed to players + murderCards := []string{roundInfo.Answer.Murderer, roundInfo.Answer.Weapon, roundInfo.Answer.Room} + for _, card := range murderCards { + if distributedCards[card] { + t.Errorf("Murder card %s was distributed to a player", card) + } + } + // Check each player's cards + for player, cards := range roundInfo.PlayersCards { + for _, card := range cards { + // Ensure the card is not part of the murder combination + for _, murderCard := range murderCards { + if card == murderCard { + t.Errorf("Player %s has a murder card: %s", player, card) + } + } + // Ensure the card is unique and not already distributed + if distributedCards[card] { + t.Errorf("Card %s is duplicated in player %s's hand", card, player) + } + distributedCards[card] = true + } + } + // Verify that all non-murder cards are distributed + allCards := append(append([]string{}, rooms...), weapons...) + allCards = append(allCards, people...) + for _, card := range allCards { + isMurderCard := false + for _, murderCard := range murderCards { + if card == murderCard { + isMurderCard = true + break + } + } + if !isMurderCard && !distributedCards[card] { + t.Errorf("Card %s was not distributed to any player", card) + } + } +} diff --git a/extra/stt.go b/extra/stt.go new file mode 100644 index 0000000..e33a94d --- /dev/null +++ b/extra/stt.go @@ -0,0 +1,199 @@ +package extra + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "gf-lt/config" + "io" + "log/slog" + "mime/multipart" + "net/http" + "regexp" + "strings" + "syscall" + + "github.com/gordonklaus/portaudio" +) + +var specialRE = regexp.MustCompile(`\[.*?\]`) + +type STT interface { + StartRecording() error + StopRecording() (string, error) + IsRecording() bool +} + +type StreamCloser interface { + Close() error +} + +func NewSTT(logger *slog.Logger, cfg *config.Config) STT { + switch cfg.STT_TYPE { + case "WHISPER_BINARY": + logger.Debug("stt init, chosen whisper binary") + return NewWhisperBinary(logger, cfg) + case "WHISPER_SERVER": + logger.Debug("stt init, chosen whisper server") + return NewWhisperServer(logger, cfg) + } + return NewWhisperServer(logger, cfg) +} + +type WhisperServer struct { + logger *slog.Logger + ServerURL string + SampleRate int + AudioBuffer *bytes.Buffer + recording bool +} + +func NewWhisperServer(logger *slog.Logger, cfg *config.Config) *WhisperServer { + return &WhisperServer{ + logger: logger, + ServerURL: cfg.STT_URL, + SampleRate: cfg.STT_SR, + AudioBuffer: new(bytes.Buffer), + } +} + +func (stt *WhisperServer) StartRecording() error { + if err := stt.microphoneStream(stt.SampleRate); err != nil { + return fmt.Errorf("failed to init microphone: %w", err) + } + stt.recording = true + return nil +} + +func (stt *WhisperServer) StopRecording() (string, error) { + stt.recording = false + // wait loop to finish? 
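A sketch of how the rest of the app is expected to drive the STT interface above: NewSTT picks the backend from cfg.STT_TYPE, recording is started and stopped around user input, and StopRecording returns the transcript. The config literal and the whisper endpoint path are illustrative assumptions, not values from the patch.

package main

import (
	"fmt"
	"log/slog"
	"os"
	"time"

	"gf-lt/config"
	"gf-lt/extra"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	// Assumed values; in gf-lt these come from config.toml.
	cfg := &config.Config{
		STT_TYPE: "WHISPER_SERVER",
		STT_URL:  "http://localhost:8081/inference", // assumed whisper.cpp server endpoint
		STT_SR:   16000,
	}
	stt := extra.NewSTT(logger, cfg)
	if err := stt.StartRecording(); err != nil {
		logger.Error("failed to start recording", "error", err)
		return
	}
	time.Sleep(5 * time.Second) // in the TUI this window is controlled by a keybinding
	text, err := stt.StopRecording()
	if err != nil {
		logger.Error("transcription failed", "error", err)
		return
	}
	fmt.Println(text)
}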
+ if stt.AudioBuffer == nil { + err := errors.New("unexpected nil AudioBuffer") + stt.logger.Error(err.Error()) + return "", err + } + // Create WAV header first + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + // Add audio file part + part, err := writer.CreateFormFile("file", "recording.wav") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Stream directly to multipart writer: header + raw data + dataSize := stt.AudioBuffer.Len() + stt.writeWavHeader(part, dataSize) + if _, err := io.Copy(part, stt.AudioBuffer); err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Reset buffer for next recording + stt.AudioBuffer.Reset() + // Add response format field + err = writer.WriteField("response_format", "text") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + if writer.Close() != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Send request + resp, err := http.Post(stt.ServerURL, writer.FormDataContentType(), body) //nolint:noctx + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + defer resp.Body.Close() + // Read and print response + responseTextBytes, err := io.ReadAll(resp.Body) + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + resptext := strings.TrimRight(string(responseTextBytes), "\n") + // in case there are special tokens like [_BEG_] + resptext = specialRE.ReplaceAllString(resptext, "") + return strings.TrimSpace(strings.ReplaceAll(resptext, "\n ", "\n")), nil +} + +func (stt *WhisperServer) writeWavHeader(w io.Writer, dataSize int) { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize)) + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + binary.LittleEndian.PutUint32(header[16:20], 16) + binary.LittleEndian.PutUint16(header[20:22], 1) + binary.LittleEndian.PutUint16(header[22:24], 1) + binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate)) + binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8)) + binary.LittleEndian.PutUint16(header[32:34], 1*(16/8)) + binary.LittleEndian.PutUint16(header[34:36], 16) + copy(header[36:40], "data") + binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize)) + if _, err := w.Write(header); err != nil { + stt.logger.Error("writeWavHeader", "error", err) + } +} + +func (stt *WhisperServer) IsRecording() bool { + return stt.recording +} + +func (stt *WhisperServer) microphoneStream(sampleRate int) error { + // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init + origStderr, errDup := syscall.Dup(syscall.Stderr) + if errDup != nil { + return fmt.Errorf("failed to dup stderr: %w", errDup) + } + nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + _ = syscall.Close(origStderr) // Close the dup'd fd if open fails + return fmt.Errorf("failed to open /dev/null: %w", err) + } + // redirect stderr + _ = syscall.Dup2(nullFD, syscall.Stderr) + // Initialize PortAudio (this is where ALSA warnings occur) + defer func() { + // Restore stderr + _ = syscall.Dup2(origStderr, syscall.Stderr) + _ = syscall.Close(origStderr) + _ = syscall.Close(nullFD) + }() + if err := portaudio.Initialize(); err != nil { + return fmt.Errorf("portaudio init failed: %w", err) + } + in := make([]int16, 64) + stream, err := 
portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in) + if err != nil { + if paErr := portaudio.Terminate(); paErr != nil { + return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr) + } + return fmt.Errorf("failed to open microphone: %w", err) + } + go func(stream *portaudio.Stream) { + if err := stream.Start(); err != nil { + stt.logger.Error("microphoneStream", "error", err) + return + } + for { + if !stt.IsRecording() { + return + } + if err := stream.Read(); err != nil { + stt.logger.Error("reading stream", "error", err) + return + } + if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil { + stt.logger.Error("writing to buffer", "error", err) + return + } + } + }(stream) + return nil +} diff --git a/extra/tts.go b/extra/tts.go new file mode 100644 index 0000000..c6f373a --- /dev/null +++ b/extra/tts.go @@ -0,0 +1,212 @@ +package extra + +import ( + "bytes" + "encoding/json" + "fmt" + "gf-lt/config" + "gf-lt/models" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gopxl/beep/v2" + "github.com/gopxl/beep/v2/mp3" + "github.com/gopxl/beep/v2/speaker" + "github.com/neurosnap/sentences/english" +) + +var ( + TTSTextChan = make(chan string, 10000) + TTSFlushChan = make(chan bool, 1) + TTSDoneChan = make(chan bool, 1) + // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`) +) + +type Orator interface { + Speak(text string) error + Stop() + // pause and resume? + GetLogger() *slog.Logger +} + +// impl https://github.com/remsky/Kokoro-FastAPI +type KokoroOrator struct { + logger *slog.Logger + URL string + Format models.AudioFormat + Stream bool + Speed float32 + Language string + Voice string + currentStream *beep.Ctrl // Added for playback control + textBuffer strings.Builder + // textBuffer bytes.Buffer +} + +func (o *KokoroOrator) stoproutine() { + <-TTSDoneChan + o.logger.Debug("orator got done signal") + o.Stop() + // drain the channel + for len(TTSTextChan) > 0 { + <-TTSTextChan + } +} + +func (o *KokoroOrator) readroutine() { + tokenizer, _ := english.NewSentenceTokenizer(nil) + // var sentenceBuf bytes.Buffer + // var remainder strings.Builder + for { + select { + case chunk := <-TTSTextChan: + // sentenceBuf.WriteString(chunk) + // text := sentenceBuf.String() + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + text := o.textBuffer.String() + sentences := tokenizer.Tokenize(text) + o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences)) + for i, sentence := range sentences { + if i == len(sentences)-1 { // last sentence + o.textBuffer.Reset() + _, err := o.textBuffer.WriteString(sentence.Text) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + continue // if only one (often incomplete) sentence; wait for next chunk + } + o.logger.Debug("calling Speak with sentence", "sent", sentence.Text) + if err := o.Speak(sentence.Text); err != nil { + o.logger.Error("tts failed", "sentence", sentence.Text, "error", err) + } + } + case <-TTSFlushChan: + o.logger.Debug("got flushchan signal start") + // lln is done get the whole message out + if len(TTSTextChan) > 0 { // otherwise might get stuck + for chunk := range TTSTextChan { + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + if len(TTSTextChan) == 0 { + break + } + } + } + // INFO: if there is 
a lot of text it will take some time to make with tts at once + // to avoid this pause, it might be better to keep splitting on sentences + // but keepinig in mind that remainder could be ommited by tokenizer + // Flush remaining text + remaining := o.textBuffer.String() + o.textBuffer.Reset() + if remaining != "" { + o.logger.Debug("calling Speak with remainder", "rem", remaining) + if err := o.Speak(remaining); err != nil { + o.logger.Error("tts failed", "sentence", remaining, "error", err) + } + } + } + } +} + +func NewOrator(log *slog.Logger, cfg *config.Config) Orator { + orator := &KokoroOrator{ + logger: log, + URL: cfg.TTS_URL, + Format: models.AFMP3, + Stream: false, + Speed: cfg.TTS_SPEED, + Language: "a", + Voice: "af_bella(1)+af_sky(1)", + } + go orator.readroutine() + go orator.stoproutine() + return orator +} + +func (o *KokoroOrator) GetLogger() *slog.Logger { + return o.logger +} + +func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) { + payload := map[string]interface{}{ + "input": text, + "voice": o.Voice, + "response_format": o.Format, + "download_format": o.Format, + "stream": o.Stream, + "speed": o.Speed, + // "return_download_link": true, + "lang_code": o.Language, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + req, err := http.NewRequest("POST", o.URL, bytes.NewBuffer(payloadBytes)) //nolint:noctx + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return resp.Body, nil +} + +func (o *KokoroOrator) Speak(text string) error { + o.logger.Debug("fn: Speak is called", "text-len", len(text)) + body, err := o.requestSound(text) + if err != nil { + o.logger.Error("request failed", "error", err) + return fmt.Errorf("request failed: %w", err) + } + defer body.Close() + // Decode the mp3 audio from response body + streamer, format, err := mp3.Decode(body) + if err != nil { + o.logger.Error("mp3 decode failed", "error", err) + return fmt.Errorf("mp3 decode failed: %w", err) + } + defer streamer.Close() + // here it spams with errors that speaker cannot be initialized more than once, but how would we deal with many audio records then? 
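The channel protocol tying the UI to this orator is only implicit in the routines above; here is a sketch of the producer side, with assumed config values for the Kokoro endpoint: streamed deltas go into TTSTextChan, one value on TTSFlushChan flushes the buffered remainder when the message ends, and TTSDoneChan interrupts playback.

package main

import (
	"log/slog"
	"os"
	"time"

	"gf-lt/config"
	"gf-lt/extra"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	cfg := &config.Config{
		TTS_URL:   "http://localhost:8880/v1/audio/speech", // assumed Kokoro-FastAPI endpoint
		TTS_SPEED: 1.0,
	}
	_ = extra.NewOrator(logger, cfg) // starts readroutine and stoproutine
	// Producer side: push streamed text deltas, then flush once the message ends.
	for _, delta := range []string{"Hello ", "there. ", "How are ", "you today?"} {
		extra.TTSTextChan <- delta
	}
	extra.TTSFlushChan <- true
	time.Sleep(10 * time.Second) // crude wait so the request and playback can finish
	extra.TTSDoneChan <- true    // stop playback and drain any leftover text
}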
+ if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil { + o.logger.Debug("failed to init speaker", "error", err) + } + done := make(chan bool) + // Create controllable stream and store reference + o.currentStream = &beep.Ctrl{Streamer: beep.Seq(streamer, beep.Callback(func() { + close(done) + o.currentStream = nil + })), Paused: false} + speaker.Play(o.currentStream) + <-done // we hang in this routine; + return nil +} + +func (o *KokoroOrator) Stop() { + // speaker.Clear() + o.logger.Debug("attempted to stop orator", "orator", o) + speaker.Lock() + defer speaker.Unlock() + if o.currentStream != nil { + // o.currentStream.Paused = true + o.currentStream.Streamer = nil + } +} diff --git a/extra/twentyq.go b/extra/twentyq.go new file mode 100644 index 0000000..30c08cc --- /dev/null +++ b/extra/twentyq.go @@ -0,0 +1,11 @@ +package extra + +import "math/rand" + +var ( + chars = []string{"Shrek", "Garfield", "Jack the Ripper"} +) + +func GetRandomChar() string { + return chars[rand.Intn(len(chars))] +} diff --git a/extra/vad.go b/extra/vad.go new file mode 100644 index 0000000..2a9e238 --- /dev/null +++ b/extra/vad.go @@ -0,0 +1 @@ +package extra diff --git a/extra/websearch.go b/extra/websearch.go new file mode 100644 index 0000000..99bc1b6 --- /dev/null +++ b/extra/websearch.go @@ -0,0 +1,13 @@ +package extra + +import "github.com/GrailFinder/searchagent/searcher" + +var WebSearcher searcher.WebSurfer + +func init() { + sa, err := searcher.NewWebSurfer(searcher.SearcherTypeScraper, "") + if err != nil { + panic("failed to init seachagent; error: " + err.Error()) + } + WebSearcher = sa +} diff --git a/extra/whisper_binary.go b/extra/whisper_binary.go new file mode 100644 index 0000000..a016a30 --- /dev/null +++ b/extra/whisper_binary.go @@ -0,0 +1,318 @@ +package extra + +import ( + "bytes" + "context" + "errors" + "fmt" + "gf-lt/config" + "io" + "log/slog" + "os" + "os/exec" + "strings" + "sync" + "syscall" + + "github.com/gordonklaus/portaudio" +) + +type WhisperBinary struct { + logger *slog.Logger + whisperPath string + modelPath string + lang string + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + recording bool + audioBuffer []int16 +} + +func NewWhisperBinary(logger *slog.Logger, cfg *config.Config) *WhisperBinary { + ctx, cancel := context.WithCancel(context.Background()) + // Set ALSA error handler first + return &WhisperBinary{ + logger: logger, + whisperPath: cfg.WhisperBinaryPath, + modelPath: cfg.WhisperModelPath, + lang: cfg.STT_LANG, + ctx: ctx, + cancel: cancel, + } +} + +func (w *WhisperBinary) StartRecording() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.recording { + return errors.New("recording is already in progress") + } + // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init + origStderr, errDup := syscall.Dup(syscall.Stderr) + if errDup != nil { + return fmt.Errorf("failed to dup stderr: %w", errDup) + } + nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + _ = syscall.Close(origStderr) // Close the dup'd fd if open fails + return fmt.Errorf("failed to open /dev/null: %w", err) + } + // redirect stderr + _ = syscall.Dup2(nullFD, syscall.Stderr) + // Initialize PortAudio (this is where ALSA warnings occur) + portaudioErr := portaudio.Initialize() + defer func() { + // Restore stderr + _ = syscall.Dup2(origStderr, syscall.Stderr) + _ = syscall.Close(origStderr) + _ = syscall.Close(nullFD) + }() + if portaudioErr != nil { + return fmt.Errorf("portaudio 
init failed: %w", portaudioErr) + } + // Initialize audio buffer + w.audioBuffer = make([]int16, 0) + in := make([]int16, 1024) // buffer size + stream, err := portaudio.OpenDefaultStream(1, 0, 16000.0, len(in), in) + if err != nil { + if paErr := portaudio.Terminate(); paErr != nil { + return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr) + } + return fmt.Errorf("failed to open microphone: %w", err) + } + go w.recordAudio(stream, in) + w.recording = true + w.logger.Debug("Recording started") + return nil +} + +func (w *WhisperBinary) recordAudio(stream *portaudio.Stream, in []int16) { + defer func() { + w.logger.Debug("recordAudio defer function called") + _ = stream.Stop() // Stop the stream + _ = portaudio.Terminate() // ignoring error as we're shutting down + w.logger.Debug("recordAudio terminated") + }() + w.logger.Debug("Starting audio stream") + if err := stream.Start(); err != nil { + w.logger.Error("Failed to start audio stream", "error", err) + return + } + w.logger.Debug("Audio stream started, entering recording loop") + for { + select { + case <-w.ctx.Done(): + w.logger.Debug("Context done, exiting recording loop") + return + default: + // Check recording status with minimal lock time + w.mu.Lock() + recording := w.recording + w.mu.Unlock() + + if !recording { + w.logger.Debug("Recording flag is false, exiting recording loop") + return + } + if err := stream.Read(); err != nil { + w.logger.Error("Error reading from stream", "error", err) + return + } + // Append samples to buffer - only acquire lock when necessary + w.mu.Lock() + if w.audioBuffer == nil { + w.audioBuffer = make([]int16, 0) + } + // Make a copy of the input buffer to avoid overwriting + tempBuffer := make([]int16, len(in)) + copy(tempBuffer, in) + w.audioBuffer = append(w.audioBuffer, tempBuffer...) 
+ w.mu.Unlock() + } + } +} + +func (w *WhisperBinary) StopRecording() (string, error) { + w.logger.Debug("StopRecording called") + w.mu.Lock() + if !w.recording { + w.mu.Unlock() + return "", errors.New("not currently recording") + } + w.logger.Debug("Setting recording to false and cancelling context") + w.recording = false + w.cancel() // This will stop the recording goroutine + w.mu.Unlock() + // // Small delay to allow the recording goroutine to react to context cancellation + // time.Sleep(20 * time.Millisecond) + // Save the recorded audio to a temporary file + tempFile, err := w.saveAudioToTempFile() + if err != nil { + w.logger.Error("Error saving audio to temp file", "error", err) + return "", fmt.Errorf("failed to save audio to temp file: %w", err) + } + w.logger.Debug("Saved audio to temp file", "file", tempFile) + // Run the whisper binary with a separate context to avoid cancellation during transcription + cmd := exec.Command(w.whisperPath, "-m", w.modelPath, "-l", w.lang, tempFile, "2>/dev/null") + var outBuf bytes.Buffer + cmd.Stdout = &outBuf + // Redirect stderr to suppress ALSA warnings and other stderr output + cmd.Stderr = io.Discard // Suppress stderr output from whisper binary + w.logger.Debug("Running whisper binary command") + if err := cmd.Run(); err != nil { + // Clean up audio buffer + w.mu.Lock() + w.audioBuffer = nil + w.mu.Unlock() + // Since we're suppressing stderr, we'll just log that the command failed + w.logger.Error("Error running whisper binary", "error", err) + return "", fmt.Errorf("whisper binary failed: %w", err) + } + result := outBuf.String() + w.logger.Debug("Whisper binary completed", "result", result) + // Clean up audio buffer + w.mu.Lock() + w.audioBuffer = nil + w.mu.Unlock() + // Clean up the temporary file after transcription + w.logger.Debug("StopRecording completed") + os.Remove(tempFile) + result = strings.TrimRight(result, "\n") + // in case there are special tokens like [_BEG_] + result = specialRE.ReplaceAllString(result, "") + return strings.TrimSpace(strings.ReplaceAll(result, "\n ", "\n")), nil +} + +// saveAudioToTempFile saves the recorded audio data to a temporary WAV file +func (w *WhisperBinary) saveAudioToTempFile() (string, error) { + w.logger.Debug("saveAudioToTempFile called") + // Create temporary WAV file + tempFile, err := os.CreateTemp("", "recording_*.wav") + if err != nil { + w.logger.Error("Failed to create temp file", "error", err) + return "", fmt.Errorf("failed to create temp file: %w", err) + } + w.logger.Debug("Created temp file", "file", tempFile.Name()) + defer tempFile.Close() + + // Write WAV header and data + w.logger.Debug("About to write WAV file", "file", tempFile.Name()) + err = w.writeWAVFile(tempFile.Name()) + if err != nil { + w.logger.Error("Error writing WAV file", "error", err) + return "", fmt.Errorf("failed to write WAV file: %w", err) + } + w.logger.Debug("WAV file written successfully", "file", tempFile.Name()) + + return tempFile.Name(), nil +} + +// writeWAVFile creates a WAV file from the recorded audio data +func (w *WhisperBinary) writeWAVFile(filename string) error { + w.logger.Debug("writeWAVFile called", "filename", filename) + // Open file for writing + file, err := os.Create(filename) + if err != nil { + w.logger.Error("Error creating file", "error", err) + return err + } + defer file.Close() + + w.logger.Debug("About to acquire mutex in writeWAVFile") + w.mu.Lock() + w.logger.Debug("Locked mutex, copying audio buffer") + audioData := make([]int16, len(w.audioBuffer)) + 
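One detail worth flagging in StopRecording above: exec.Command does not invoke a shell, so the literal "2>/dev/null" argument is handed to the whisper binary as an extra positional parameter rather than redirecting anything; the cmd.Stderr = io.Discard line already suppresses stderr. A minimal equivalent without the stray argument, where the binary, model, and audio paths are placeholders:

package main

import (
	"bytes"
	"fmt"
	"io"
	"os/exec"
)

func transcribe(whisperPath, modelPath, lang, wavPath string) (string, error) {
	cmd := exec.Command(whisperPath, "-m", modelPath, "-l", lang, wavPath)
	var out bytes.Buffer
	cmd.Stdout = &out
	cmd.Stderr = io.Discard // suppress whisper/ALSA noise instead of shell redirection
	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("whisper binary failed: %w", err)
	}
	return out.String(), nil
}

func main() {
	// Placeholder paths; real values come from WhisperBinaryPath / WhisperModelPath in config.
	text, err := transcribe("batteries/whisper.cpp/build/bin/whisper-cli",
		"batteries/whisper.cpp/models/ggml-large-v3-turbo-q5_0.bin", "en", "recording.wav")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(text)
}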
copy(audioData, w.audioBuffer) + w.mu.Unlock() + w.logger.Debug("Unlocked mutex", "audio_data_length", len(audioData)) + + if len(audioData) == 0 { + w.logger.Warn("No audio data to write") + return errors.New("no audio data to write") + } + + // Calculate data size (number of samples * size of int16) + dataSize := len(audioData) * 2 // 2 bytes per int16 sample + w.logger.Debug("Calculated data size", "size", dataSize) + + // Write WAV header with the correct data size + header := w.createWAVHeader(16000, 1, 16, dataSize) + _, err = file.Write(header) + if err != nil { + w.logger.Error("Error writing WAV header", "error", err) + return err + } + w.logger.Debug("WAV header written successfully") + + // Write audio data + w.logger.Debug("About to write audio data samples") + for i, sample := range audioData { + // Write little-endian 16-bit sample + _, err := file.Write([]byte{byte(sample), byte(sample >> 8)}) + if err != nil { + w.logger.Error("Error writing sample", "index", i, "error", err) + return err + } + // Log progress every 10000 samples to avoid too much output + if i%10000 == 0 { + w.logger.Debug("Written samples", "count", i) + } + } + w.logger.Debug("All audio data written successfully") + + return nil +} + +// createWAVHeader creates a WAV file header +func (w *WhisperBinary) createWAVHeader(sampleRate, channels, bitsPerSample int, dataSize int) []byte { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + // Total file size will be updated later + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + // fmt chunk size (16 for PCM) + header[16] = 16 + header[17] = 0 + header[18] = 0 + header[19] = 0 + // Audio format (1 = PCM) + header[20] = 1 + header[21] = 0 + // Number of channels + header[22] = byte(channels) + header[23] = 0 + // Sample rate + header[24] = byte(sampleRate) + header[25] = byte(sampleRate >> 8) + header[26] = byte(sampleRate >> 16) + header[27] = byte(sampleRate >> 24) + // Byte rate + byteRate := sampleRate * channels * bitsPerSample / 8 + header[28] = byte(byteRate) + header[29] = byte(byteRate >> 8) + header[30] = byte(byteRate >> 16) + header[31] = byte(byteRate >> 24) + // Block align + blockAlign := channels * bitsPerSample / 8 + header[32] = byte(blockAlign) + header[33] = 0 + // Bits per sample + header[34] = byte(bitsPerSample) + header[35] = 0 + // "data" subchunk + copy(header[36:40], "data") + // Data size + header[40] = byte(dataSize) + header[41] = byte(dataSize >> 8) + header[42] = byte(dataSize >> 16) + header[43] = byte(dataSize >> 24) + + return header +} + +func (w *WhisperBinary) IsRecording() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.recording +} @@ -1,34 +1,41 @@ -module elefant +module gf-lt -go 1.23.2 +go 1.25.1 require ( - github.com/BurntSushi/toml v1.4.0 - github.com/asg017/sqlite-vec-go-bindings v0.1.6 - github.com/gdamore/tcell/v2 v2.7.4 + github.com/BurntSushi/toml v1.5.0 + github.com/GrailFinder/searchagent v0.2.0 + github.com/gdamore/tcell/v2 v2.13.2 github.com/glebarez/go-sqlite v1.22.0 + github.com/gopxl/beep/v2 v2.1.1 + github.com/gordonklaus/portaudio v0.0.0-20250206071425-98a94950218b github.com/jmoiron/sqlx v1.4.0 - github.com/ncruces/go-sqlite3 v0.21.3 github.com/neurosnap/sentences v1.1.2 - github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 + github.com/rivo/tview v0.42.0 ) require ( + github.com/PuerkitoBio/goquery v1.11.0 // indirect + github.com/andybalholm/cascadia v1.3.3 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/gdamore/encoding v1.0.0 // indirect + 
github.com/ebitengine/oto/v3 v3.4.0 // indirect + github.com/ebitengine/purego v0.9.1 // indirect + github.com/gdamore/encoding v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/hajimehoshi/go-mp3 v0.3.4 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/ncruces/julianday v1.0.0 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/tetratelabs/wazero v1.8.2 // indirect - golang.org/x/sys v0.28.0 // indirect - golang.org/x/term v0.17.0 // indirect - golang.org/x/text v0.21.0 // indirect - modernc.org/libc v1.37.6 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.7.2 // indirect - modernc.org/sqlite v1.28.0 // indirect + golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/term v0.38.0 // indirect + golang.org/x/text v0.32.0 // indirect + modernc.org/libc v1.67.1 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.40.1 // indirect ) @@ -1,95 +1,178 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= -github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= -github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/GrailFinder/searchagent v0.2.0 h1:U2GVjLh/9xZt0xX9OcYk9Q2fMkyzyTiADPUmUisRdtQ= +github.com/GrailFinder/searchagent v0.2.0/go.mod h1:d66tn5+22LI8IGJREUsRBT60P0sFdgQgvQRqyvgItrs= +github.com/PuerkitoBio/goquery v1.11.0 h1:jZ7pwMQXIITcUXNH83LLk+txlaEy6NVOfTuP43xxfqw= +github.com/PuerkitoBio/goquery v1.11.0/go.mod h1:wQHgxUOU3JGuj3oD/QFfxUdlzW6xPHfqyHre6VMY4DQ= +github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= +github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= -github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= -github.com/gdamore/tcell/v2 v2.7.4 h1:sg6/UnTM9jGpZU+oFYAsDahfchWAFW8Xx2yFinNSAYU= -github.com/gdamore/tcell/v2 v2.7.4/go.mod h1:dSXtXTSK0VsW1biw65DZLZ2NKr7j0qP/0J7ONmsraWg= +github.com/ebitengine/oto/v3 v3.4.0 h1:br0PgASsEWaoWn38b2Goe7m1GKFYfNgnsjSd5Gg+/bQ= +github.com/ebitengine/oto/v3 v3.4.0/go.mod 
h1:IOleLVD0m+CMak3mRVwsYY8vTctQgOM0iiL6S7Ar7eI= +github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= +github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw= +github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo= +github.com/gdamore/tcell/v2 v2.13.2 h1:5j4srfF8ow3HICOv/61/sOhQtA25qxEB2XR3Q/Bhx2g= +github.com/gdamore/tcell/v2 v2.13.2/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopxl/beep/v2 v2.1.1 h1:6FYIYMm2qPAdWkjX+7xwKrViS1x0Po5kDMdRkq8NVbU= +github.com/gopxl/beep/v2 v2.1.1/go.mod h1:ZAm9TGQ9lvpoiFLd4zf5B1IuyxZhgRACMId1XJbaW0E= +github.com/gordonklaus/portaudio v0.0.0-20250206071425-98a94950218b h1:WEuQWBxelOGHA6z9lABqaMLMrfwVyMdN3UgRLT+YUPo= +github.com/gordonklaus/portaudio v0.0.0-20250206071425-98a94950218b/go.mod h1:esZFQEUwqC+l76f2R8bIWSwXMaPbp79PppwZ1eJhFco= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1CdMQ= -github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= -github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= -github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw= github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 h1:YIJ+B1hePP6AgynC5TcqpO0H9k3SSoZa2BGyL6vDUzM= -github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592/go.mod h1:02iFIz7K/A9jGCvrizLPvoqr4cEIx7q54RH5Qudkrss= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rivo/tview v0.42.0 h1:b/ftp+RxtDsHSaynXTbJb+/n/BxDEi+W3UfF5jILK6c= +github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4= -github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 h1:MDfG8Cvcqlt9XXrmEiD4epKn7VJHZO84hejP9Jmp0MM= +golang.org/x/exp v0.0.0-20251209150349-8475f28825e9/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod 
h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys 
v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -modernc.org/libc v1.37.6 h1:orZH3c5wmhIQFTXF+Nt+eeauyd+ZIt2BX6ARe+kD+aw= -modernc.org/libc v1.37.6/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= -modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= -modernc.org/sqlite v1.28.0 h1:Zx+LyDDmXczNnEQdvPuEfcFVA2ZPyaD7UCZDjef3BHQ= -modernc.org/sqlite v1.28.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0= 
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.1 h1:bFaqOaa5/zbWYJo8aW0tXPX21hXsngG2M7mckCnFSVk= +modernc.org/libc v1.67.1/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY= +modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/helpfuncs.go b/helpfuncs.go new file mode 100644 index 0000000..df49ae5 --- /dev/null +++ b/helpfuncs.go @@ -0,0 +1,231 @@ +package main + +import ( + "fmt" + "gf-lt/models" + "gf-lt/pngmeta" + "image" + "os" + "path" + "strings" +) + +func colorText() { + text := textView.GetText(false) + quoteReplacer := strings.NewReplacer( + `”`, `"`, + `“`, `"`, + `“`, `"`, + `”`, `"`, + `**`, `*`, + ) + text = quoteReplacer.Replace(text) + // Step 1: Extract code blocks and replace them with unique placeholders + var codeBlocks []string + placeholder := "__CODE_BLOCK_%d__" + counter := 0 + // thinking + var thinkBlocks []string + placeholderThink := "__THINK_BLOCK_%d__" + counterThink := 0 + // Replace code blocks with placeholders and store their styled versions + text = codeBlockRE.ReplaceAllStringFunc(text, func(match string) string { + // Style the code block and store it + styled := fmt.Sprintf("[red::i]%s[-:-:-]", match) + codeBlocks = append(codeBlocks, styled) + // Generate a unique placeholder (e.g., "__CODE_BLOCK_0__") + id := fmt.Sprintf(placeholder, counter) + counter++ + return id + }) + text = thinkRE.ReplaceAllStringFunc(text, func(match string) string { + 
// Style the code block and store it + styled := fmt.Sprintf("[red::i]%s[-:-:-]", match) + thinkBlocks = append(thinkBlocks, styled) + // Generate a unique placeholder (e.g., "__CODE_BLOCK_0__") + id := fmt.Sprintf(placeholderThink, counterThink) + counterThink++ + return id + }) + // Step 2: Apply other regex styles to the non-code parts + text = quotesRE.ReplaceAllString(text, `[orange::-]$1[-:-:-]`) + text = starRE.ReplaceAllString(text, `[turquoise::i]$1[-:-:-]`) + // text = thinkRE.ReplaceAllString(text, `[yellow::i]$1[-:-:-]`) + // Step 3: Restore the styled code blocks from placeholders + for i, cb := range codeBlocks { + text = strings.Replace(text, fmt.Sprintf(placeholder, i), cb, 1) + } + logger.Debug("thinking debug", "blocks", thinkBlocks) + for i, tb := range thinkBlocks { + text = strings.Replace(text, fmt.Sprintf(placeholderThink, i), tb, 1) + } + textView.SetText(text) +} + +func updateStatusLine() { + position.SetText(makeStatusLine()) + helpView.SetText(fmt.Sprintf(helpText, makeStatusLine())) +} + +func initSysCards() ([]string, error) { + labels := []string{} + labels = append(labels, sysLabels...) + cards, err := pngmeta.ReadDirCards(cfg.SysDir, cfg.UserRole, logger) + if err != nil { + logger.Error("failed to read sys dir", "error", err) + return nil, err + } + for _, cc := range cards { + if cc.Role == "" { + logger.Warn("empty role", "file", cc.FilePath) + continue + } + sysMap[cc.Role] = cc + labels = append(labels, cc.Role) + } + return labels, nil +} + +func startNewChat() { + id, err := store.ChatGetMaxID() + if err != nil { + logger.Error("failed to get chat id", "error", err) + } + if ok := charToStart(cfg.AssistantRole); !ok { + logger.Warn("no such sys msg", "name", cfg.AssistantRole) + } + // set chat body + chatBody.Messages = chatBody.Messages[:2] + textView.SetText(chatToText(cfg.ShowSys)) + newChat := &models.Chat{ + ID: id + 1, + Name: fmt.Sprintf("%d_%s", id+1, cfg.AssistantRole), + Msgs: string(defaultStarterBytes), + Agent: cfg.AssistantRole, + } + activeChatName = newChat.Name + chatMap[newChat.Name] = newChat + updateStatusLine() + colorText() +} + +func renameUser(oldname, newname string) { + if oldname == "" { + // not provided; deduce who user is + // INFO: if user not yet spoke, it is hard to replace mentions in sysprompt and first message about thme + roles := chatBody.ListRoles() + for _, role := range roles { + if role == cfg.AssistantRole { + continue + } + if role == "tool" { + continue + } + if role == "system" { + continue + } + oldname = role + break + } + if oldname == "" { + // still + logger.Warn("fn: renameUser; failed to find old name", "newname", newname) + return + } + } + viewText := textView.GetText(false) + viewText = strings.ReplaceAll(viewText, oldname, newname) + chatBody.Rename(oldname, newname) + textView.SetText(viewText) +} + +func setLogLevel(sl string) { + switch sl { + case "Debug": + logLevel.Set(-4) + case "Info": + logLevel.Set(0) + case "Warn": + logLevel.Set(4) + } +} + +func listRolesWithUser() []string { + roles := chatBody.ListRoles() + // Remove user role if it exists in the list (to avoid duplicates and ensure it's at position 0) + filteredRoles := make([]string, 0, len(roles)) + for _, role := range roles { + if role != cfg.UserRole { + filteredRoles = append(filteredRoles, role) + } + } + // Prepend user role to the beginning of the list + result := append([]string{cfg.UserRole}, filteredRoles...) 
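colorText above protects code and think spans with numbered placeholders, styles the rest of the text, and then restores the protected spans. The shape of that technique in isolation; the regex is a simplified stand-in for the project's codeBlockRE, and the middle pass only stands in for the quote/italics styling.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Simplified stand-in for the project's codeBlockRE.
var codeBlockRE = regexp.MustCompile("(?s)```.*?```")

func protectAndStyle(text string) string {
	var blocks []string
	// 1. Swap protected spans for numbered placeholders, keeping their styled form aside.
	text = codeBlockRE.ReplaceAllStringFunc(text, func(m string) string {
		blocks = append(blocks, fmt.Sprintf("[red::i]%s[-:-:-]", m))
		return fmt.Sprintf("__CODE_BLOCK_%d__", len(blocks)-1)
	})
	// 2. Run the cheaper styling passes on the rest (stand-in for quotesRE/starRE).
	text = strings.ReplaceAll(text, "*", "")
	// 3. Put the protected spans back.
	for i, b := range blocks {
		text = strings.Replace(text, fmt.Sprintf("__CODE_BLOCK_%d__", i), b, 1)
	}
	return text
}

func main() {
	fmt.Println(protectAndStyle("say *hi*\n```go\nprintln(\"*kept*\")\n```"))
}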
+ return result +} + +func loadImage() { + filepath := defaultImage + cc, ok := sysMap[cfg.AssistantRole] + if ok { + if strings.HasSuffix(cc.FilePath, ".png") { + filepath = cc.FilePath + } + } + file, err := os.Open(filepath) + if err != nil { + panic(err) + } + defer file.Close() + img, _, err := image.Decode(file) + if err != nil { + panic(err) + } + imgView.SetImage(img) +} + +func strInSlice(s string, sl []string) bool { + for _, el := range sl { + if strings.EqualFold(s, el) { + return true + } + } + return false +} + +func makeStatusLine() string { + isRecording := false + if asr != nil { + isRecording = asr.IsRecording() + } + persona := cfg.UserRole + if cfg.WriteNextMsgAs != "" { + persona = cfg.WriteNextMsgAs + } + botPersona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + botPersona = cfg.WriteNextMsgAsCompletionAgent + } + // Add image attachment info to status line + var imageInfo string + if imageAttachmentPath != "" { + // Get just the filename from the path + imageName := path.Base(imageAttachmentPath) + imageInfo = fmt.Sprintf(" | attached img: [orange:-:b]%s[-:-:-]", imageName) + } else { + imageInfo = "" + } + + // Add shell mode status to status line + var shellModeInfo string + if shellMode { + shellModeInfo = " | [green:-:b]SHELL MODE[-:-:-]" + } else { + shellModeInfo = "" + } + + statusLine := fmt.Sprintf(indexLineCompletion, botRespMode, cfg.AssistantRole, activeChatName, + cfg.ToolUse, chatBody.Model, cfg.SkipLLMResp, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level(), + isRecording, persona, botPersona, injectRole) + return statusLine + imageInfo + shellModeInfo +} @@ -2,45 +2,124 @@ package main import ( "bytes" - "elefant/models" "encoding/json" + "errors" + "gf-lt/models" "io" "strings" ) +var imageAttachmentPath string // Global variable to track image attachment for next message +var lastImg string // for ctrl+j +var RAGMsg = "Retrieved context for user's query:\n" + +// SetImageAttachment sets an image to be attached to the next message sent to the LLM +func SetImageAttachment(imagePath string) { + imageAttachmentPath = imagePath + lastImg = imagePath +} + +// ClearImageAttachment clears any pending image attachment and updates UI +func ClearImageAttachment() { + imageAttachmentPath = "" +} + type ChunkParser interface { - ParseChunk([]byte) (string, bool, error) + ParseChunk([]byte) (*models.TextChunk, error) FormMsg(msg, role string, cont bool) (io.Reader, error) + GetToken() string } -func initChunkParser() { - chunkParser = LlamaCPPeer{} - if strings.Contains(cfg.CurrentAPI, "v1") { - logger.Debug("chosen /v1/chat parser") - chunkParser = OpenAIer{} +func choseChunkParser() { + chunkParser = LCPCompletion{} + switch cfg.CurrentAPI { + case "http://localhost:8080/completion": + chunkParser = LCPCompletion{} + logger.Debug("chosen lcpcompletion", "link", cfg.CurrentAPI) + return + case "http://localhost:8080/v1/chat/completions": + chunkParser = LCPChat{} + logger.Debug("chosen lcpchat", "link", cfg.CurrentAPI) + return + case "https://api.deepseek.com/beta/completions": + chunkParser = DeepSeekerCompletion{} + logger.Debug("chosen deepseekercompletio", "link", cfg.CurrentAPI) + return + case "https://api.deepseek.com/chat/completions": + chunkParser = DeepSeekerChat{} + logger.Debug("chosen deepseekerchat", "link", cfg.CurrentAPI) return + case "https://openrouter.ai/api/v1/completions": + chunkParser = OpenRouterCompletion{} + logger.Debug("chosen openroutercompletion", "link", cfg.CurrentAPI) + return + case 
"https://openrouter.ai/api/v1/chat/completions": + chunkParser = OpenRouterChat{} + logger.Debug("chosen openrouterchat", "link", cfg.CurrentAPI) + return + default: + chunkParser = LCPCompletion{} } - logger.Debug("chosen llamacpp /completion parser") } -type LlamaCPPeer struct { +type LCPCompletion struct { +} +type LCPChat struct { +} +type DeepSeekerCompletion struct { } -type OpenAIer struct { +type DeepSeekerChat struct { +} +type OpenRouterCompletion struct { + Model string +} +type OpenRouterChat struct { + Model string } -func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) { +func (lcp LCPCompletion) GetToken() string { + return "" +} + +func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg lcpcompletion", "link", cfg.CurrentAPI) + localImageAttachmentPath := imageAttachmentPath + var multimodalData []string + if localImageAttachmentPath != "" { + imageURL, err := models.CreateImageURLFromPath(localImageAttachmentPath) + if err != nil { + logger.Error("failed to create image URL from path for completion", "error", err, "path", localImageAttachmentPath) + return nil, err + } + // Extract base64 part from data URL (e.g., "data:image/jpeg;base64,...") + parts := strings.SplitN(imageURL, ",", 2) + if len(parts) == 2 { + multimodalData = append(multimodalData, parts[1]) + } else { + logger.Error("invalid image data URL format", "url", imageURL) + return nil, errors.New("invalid image data URL format") + } + imageAttachmentPath = "" // Clear the attachment after use + } if msg != "" { // otherwise let the bot to continue newMsg := models.RoleMsg{Role: role, Content: msg} chatBody.Messages = append(chatBody.Messages, newMsg) - // if rag + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage if cfg.RAGEnabled { - ragResp, err := chatRagUse(newMsg.Content) + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) if err != nil { logger.Error("failed to form a rag msg", "error", err) return nil, err } - ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + logger.Debug("RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) } } if cfg.ToolUse && !resume { @@ -53,18 +132,32 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) } prompt := strings.Join(messages, "\n") // strings builder? 
- // if cfg.ToolUse && msg != "" && !resume { if !resume { - botMsgStart := "\n" + cfg.AssistantRole + ":\n" + botPersona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + botPersona = cfg.WriteNextMsgAsCompletionAgent + } + botMsgStart := "\n" + botPersona + ":\n" prompt += botMsgStart } - // if cfg.ThinkUse && msg != "" && !cfg.ToolUse { if cfg.ThinkUse && !cfg.ToolUse { prompt += "<think>" } + // Add multimodal media markers to the prompt text when multimodal data is present + // This is required by llama.cpp multimodal models so they know where to insert media + if len(multimodalData) > 0 { + // Add a media marker for each item in the multimodal data + var sb strings.Builder + sb.WriteString(prompt) + for range multimodalData { + sb.WriteString(" <__media__>") // llama.cpp default multimodal marker + } + prompt = sb.String() + } + logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, - "msg", msg, "resume", resume, "prompt", prompt) - payload := models.NewLCPReq(prompt, cfg, defaultLCPProps) + "msg", msg, "resume", resume, "prompt", prompt, "multimodal_data_count", len(multimodalData)) + payload := models.NewLCPReq(prompt, chatBody.Model, multimodalData, defaultLCPProps, chatBody.MakeStopSlice()) data, err := json.Marshal(payload) if err != nil { logger.Error("failed to form a msg", "error", err) @@ -73,38 +166,260 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) return bytes.NewReader(data), nil } -func (lcp LlamaCPPeer) ParseChunk(data []byte) (string, bool, error) { +func (lcp LCPCompletion) ParseChunk(data []byte) (*models.TextChunk, error) { llmchunk := models.LlamaCPPResp{} + resp := &models.TextChunk{} if err := json.Unmarshal(data, &llmchunk); err != nil { logger.Error("failed to decode", "error", err, "line", string(data)) - return "", false, err + return nil, err } + resp.Chunk = llmchunk.Content if llmchunk.Stop { if llmchunk.Content != "" { logger.Error("text inside of finish llmchunk", "chunk", llmchunk) } - return llmchunk.Content, true, nil + resp.Finished = true } - return llmchunk.Content, false, nil + return resp, nil +} + +func (op LCPChat) GetToken() string { + return "" } -func (op OpenAIer) ParseChunk(data []byte) (string, bool, error) { +func (op LCPChat) ParseChunk(data []byte) (*models.TextChunk, error) { llmchunk := models.LLMRespChunk{} if err := json.Unmarshal(data, &llmchunk); err != nil { logger.Error("failed to decode", "error", err, "line", string(data)) - return "", false, err + return nil, err + } + + // Handle multiple choices safely + if len(llmchunk.Choices) == 0 { + logger.Warn("LCPChat ParseChunk: no choices in response", "data", string(data)) + return &models.TextChunk{Finished: true}, nil + } + + resp := &models.TextChunk{ + Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content, + } + + // Check for tool calls in all choices, not just the last one + for _, choice := range llmchunk.Choices { + if len(choice.Delta.ToolCalls) > 0 { + toolCall := choice.Delta.ToolCalls[0] + resp.ToolChunk = toolCall.Function.Arguments + fname := toolCall.Function.Name + if fname != "" { + resp.FuncName = fname + } + // Capture the tool call ID if available + resp.ToolID = toolCall.ID + break // Process only the first tool call + } } - content := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content + if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { - if content != "" { + if resp.Chunk != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + 
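The stream chunks that LCPCompletion.ParseChunk consumes from llama.cpp's /completion endpoint carry only a content delta and a stop flag. A tiny self-contained round-trip of that shape; the struct is a local stand-in for models.LlamaCPPResp, with JSON field names following llama.cpp's streaming output.

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for models.LlamaCPPResp with the two fields ParseChunk relies on.
type llamaCPPResp struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

func main() {
	for _, raw := range []string{
		`{"content":"Hel","stop":false}`,
		`{"content":"lo.","stop":false}`,
		`{"content":"","stop":true}`,
	} {
		var c llamaCPPResp
		if err := json.Unmarshal([]byte(raw), &c); err != nil {
			panic(err)
		}
		fmt.Printf("chunk=%q finished=%v\n", c.Content, c.Stop)
	}
}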
} + resp.Finished = true + } + if resp.ToolChunk != "" { + resp.ToolResp = true + } + return resp, nil +} + +func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg lcpchat", "link", cfg.CurrentAPI) + // Capture the image attachment path at the beginning to avoid race conditions + // with API rotation that might clear the global variable + localImageAttachmentPath := imageAttachmentPath + if msg != "" { // otherwise let the bot continue + // Create the message with support for multimodal content + var newMsg models.RoleMsg + // Check if we have an image to add to this message + if localImageAttachmentPath != "" { + // Create a multimodal message with both text and image + newMsg = models.NewMultimodalMsg(role, []interface{}{}) + // Add the text content + newMsg.AddTextPart(msg) + // Add the image content + imageURL, err := models.CreateImageURLFromPath(localImageAttachmentPath) + if err != nil { + logger.Error("failed to create image URL from path", "error", err, "path", localImageAttachmentPath) + // If image processing fails, fall back to simple text message + newMsg = models.NewRoleMsg(role, msg) + } else { + newMsg.AddImagePart(imageURL) + } + // Only clear the global image attachment after successfully processing it in this API call + imageAttachmentPath = "" // Clear the attachment after use + } else { + // Create a simple text message + newMsg = models.NewRoleMsg(role, msg) + } + chatBody.Messages = append(chatBody.Messages, newMsg) + logger.Debug("LCPChat FormMsg: added message to chatBody", "role", newMsg.Role, "content_len", len(newMsg.Content), "message_count_after_add", len(chatBody.Messages)) + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage + if cfg.RAGEnabled { + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("LCPChat: RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) + if err != nil { + logger.Error("LCPChat: failed to form a rag msg", "error", err) + return nil, err + } + logger.Debug("LCPChat: RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("LCPChat: RAG message added to chat body", "role", ragMsg.Role, "rag_content_len", len(ragMsg.Content), "message_count_after_rag", len(chatBody.Messages)) + } + } + // openai /v1/chat does not support custom roles; needs to be user, assistant, system + bodyCopy := &models.ChatBody{ + Messages: make([]models.RoleMsg, len(chatBody.Messages)), + Model: chatBody.Model, + Stream: chatBody.Stream, + } + for i, msg := range chatBody.Messages { + if msg.Role == cfg.UserRole { + bodyCopy.Messages[i] = msg + bodyCopy.Messages[i].Role = "user" + } else { + bodyCopy.Messages[i] = msg + } + } + // Clean null/empty messages to prevent API issues + bodyCopy.Messages = cleanNullMessages(bodyCopy.Messages) + req := models.OpenAIReq{ + ChatBody: bodyCopy, + Tools: nil, + } + if cfg.ToolUse && !resume && role != cfg.ToolRole { + req.Tools = baseTools // set tools to use + } + data, err := json.Marshal(req) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} + +// deepseek +func (ds DeepSeekerCompletion) ParseChunk(data []byte) (*models.TextChunk, error) { 
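+	// Decodes one streamed chunk from DeepSeek's /completions endpoint.
+	// Choices[0] is indexed directly, which assumes every chunk carries at least
+	// one choice; LCPChat.ParseChunk above guards the equivalent case explicitly.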
+ llmchunk := models.DSCompletionResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return nil, err + } + resp := &models.TextChunk{ + Chunk: llmchunk.Choices[0].Text, + } + if llmchunk.Choices[0].FinishReason != "" { + if resp.Chunk != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + resp.Finished = true + } + return resp, nil +} + +func (ds DeepSeekerCompletion) GetToken() string { + return cfg.DeepSeekToken +} + +func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg deepseekercompletion", "link", cfg.CurrentAPI) + if msg != "" { // otherwise let the bot to continue + newMsg := models.RoleMsg{Role: role, Content: msg} + chatBody.Messages = append(chatBody.Messages, newMsg) + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage + // TODO: perhaps RAG should be a func/tool call instead? + if cfg.RAGEnabled { + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("DeepSeekerCompletion: RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) + if err != nil { + logger.Error("DeepSeekerCompletion: failed to form a rag msg", "error", err) + return nil, err + } + logger.Debug("DeepSeekerCompletion: RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages)) + } + } + if cfg.ToolUse && !resume { + // add to chat body + chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) + } + messages := make([]string, len(chatBody.Messages)) + for i, m := range chatBody.Messages { + messages[i] = m.ToPrompt() + } + prompt := strings.Join(messages, "\n") + // strings builder? 
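+	// The chat history is now flattened into a single completion prompt, one
+	// "role:\ncontent" block per message, for example (illustrative roles only):
+	//
+	//	system:
+	//	<system prompt>
+	//	user:
+	//	hello
+	//
+	// The "<persona>:" header appended below cues the model to answer as that
+	// persona, and chatBody.MakeStopSlice() supplies matching "role:\n" stop
+	// strings so generation halts before the model writes another speaker's turn.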
+ if !resume { + botPersona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + botPersona = cfg.WriteNextMsgAsCompletionAgent + } + botMsgStart := "\n" + botPersona + ":\n" + prompt += botMsgStart + } + if cfg.ThinkUse && !cfg.ToolUse { + prompt += "<think>" + } + logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, + "msg", msg, "resume", resume, "prompt", prompt) + payload := models.NewDSCompletionReq(prompt, chatBody.Model, + defaultLCPProps["temp"], chatBody.MakeStopSlice()) + data, err := json.Marshal(payload) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} + +func (ds DeepSeekerChat) ParseChunk(data []byte) (*models.TextChunk, error) { + llmchunk := models.DSChatStreamResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return nil, err + } + resp := &models.TextChunk{} + if llmchunk.Choices[0].FinishReason != "" { + if llmchunk.Choices[0].Delta.Content != "" { logger.Error("text inside of finish llmchunk", "chunk", llmchunk) } - return content, true, nil + resp.Chunk = llmchunk.Choices[0].Delta.Content + resp.Finished = true + } else { + if llmchunk.Choices[0].Delta.ReasoningContent != "" { + resp.Chunk = llmchunk.Choices[0].Delta.ReasoningContent + } else { + resp.Chunk = llmchunk.Choices[0].Delta.Content + } } - return content, false, nil + return resp, nil +} + +func (ds DeepSeekerChat) GetToken() string { + return cfg.DeepSeekToken } -func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) { +func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg deepseekerchat", "link", cfg.CurrentAPI) if cfg.ToolUse && !resume { // prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg // add to chat body @@ -113,18 +428,232 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) { if msg != "" { // otherwise let the bot continue newMsg := models.RoleMsg{Role: role, Content: msg} chatBody.Messages = append(chatBody.Messages, newMsg) - // if rag + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage + if cfg.RAGEnabled { + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil, err + } + logger.Debug("RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) + } + } + bodyCopy := &models.ChatBody{ + Messages: make([]models.RoleMsg, len(chatBody.Messages)), + Model: chatBody.Model, + Stream: chatBody.Stream, + } + for i, msg := range chatBody.Messages { + if msg.Role == cfg.UserRole || i == 1 { + bodyCopy.Messages[i] = msg + bodyCopy.Messages[i].Role = "user" + } else { + bodyCopy.Messages[i] = msg + } + } + // Clean null/empty messages to prevent API issues + bodyCopy.Messages = cleanNullMessages(bodyCopy.Messages) + dsBody := models.NewDSChatReq(*bodyCopy) + data, err := json.Marshal(dsBody) + if err != nil { + logger.Error("failed to 
form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} + +// openrouter +func (or OpenRouterCompletion) ParseChunk(data []byte) (*models.TextChunk, error) { + llmchunk := models.OpenRouterCompletionResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return nil, err + } + resp := &models.TextChunk{ + Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Text, + } + if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { + if resp.Chunk != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + resp.Finished = true + } + return resp, nil +} + +func (or OpenRouterCompletion) GetToken() string { + return cfg.OpenRouterToken +} + +func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg openroutercompletion", "link", cfg.CurrentAPI) + if msg != "" { // otherwise let the bot to continue + newMsg := models.RoleMsg{Role: role, Content: msg} + chatBody.Messages = append(chatBody.Messages, newMsg) + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage + if cfg.RAGEnabled { + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil, err + } + logger.Debug("RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) + } + } + if cfg.ToolUse && !resume { + // add to chat body + chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) + } + messages := make([]string, len(chatBody.Messages)) + for i, m := range chatBody.Messages { + messages[i] = m.ToPrompt() + } + prompt := strings.Join(messages, "\n") + // strings builder? 
+ if !resume { + botPersona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + botPersona = cfg.WriteNextMsgAsCompletionAgent + } + botMsgStart := "\n" + botPersona + ":\n" + prompt += botMsgStart + } + if cfg.ThinkUse && !cfg.ToolUse { + prompt += "<think>" + } + ss := chatBody.MakeStopSlice() + logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, + "msg", msg, "resume", resume, "prompt", prompt, "stop_strings", ss) + payload := models.NewOpenRouterCompletionReq(chatBody.Model, prompt, defaultLCPProps, ss) + data, err := json.Marshal(payload) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} + +// chat +func (or OpenRouterChat) ParseChunk(data []byte) (*models.TextChunk, error) { + llmchunk := models.OpenRouterChatResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return nil, err + } + resp := &models.TextChunk{ + Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content, + } + // Handle tool calls similar to LCPChat + if len(llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls) > 0 { + toolCall := llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0] + resp.ToolChunk = toolCall.Function.Arguments + fname := toolCall.Function.Name + if fname != "" { + resp.FuncName = fname + } + // Capture the tool call ID if available + resp.ToolID = toolCall.ID + } + if resp.ToolChunk != "" { + resp.ToolResp = true + } + if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { + if resp.Chunk != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + resp.Finished = true + } + return resp, nil +} + +func (or OpenRouterChat) GetToken() string { + return cfg.OpenRouterToken +} + +func (or OpenRouterChat) FormMsg(msg, role string, resume bool) (io.Reader, error) { + logger.Debug("formmsg open router completion", "link", cfg.CurrentAPI) + // Capture the image attachment path at the beginning to avoid race conditions + // with API rotation that might clear the global variable + localImageAttachmentPath := imageAttachmentPath + if msg != "" { // otherwise let the bot continue + var newMsg models.RoleMsg + // Check if we have an image to add to this message + if localImageAttachmentPath != "" { + // Create a multimodal message with both text and image + newMsg = models.NewMultimodalMsg(role, []interface{}{}) + // Add the text content + newMsg.AddTextPart(msg) + // Add the image content + imageURL, err := models.CreateImageURLFromPath(localImageAttachmentPath) + if err != nil { + logger.Error("failed to create image URL from path", "error", err, "path", localImageAttachmentPath) + // If image processing fails, fall back to simple text message + newMsg = models.NewRoleMsg(role, msg) + } else { + newMsg.AddImagePart(imageURL) + } + // Only clear the global image attachment after successfully processing it in this API call + imageAttachmentPath = "" // Clear the attachment after use + } else { + // Create a simple text message + newMsg = models.NewRoleMsg(role, msg) + } + chatBody.Messages = append(chatBody.Messages, newMsg) + } + if !resume { + // if rag - add as system message to avoid conflicts with tool usage if cfg.RAGEnabled { - ragResp, err := chatRagUse(newMsg.Content) + um := chatBody.Messages[len(chatBody.Messages)-1].Content + logger.Debug("RAG is enabled, preparing RAG context", "user_message", um) + ragResp, err := chatRagUse(um) if err != nil { 
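+				// a RAG failure aborts the request instead of silently sending an un-grounded prompt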
logger.Error("failed to form a rag msg", "error", err) return nil, err } - ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + logger.Debug("RAG response received", "response_len", len(ragResp), "response_preview", ragResp[:min(len(ragResp), 100)]) + // Use system role for RAG context to avoid conflicts with tool usage + ragMsg := models.RoleMsg{Role: "system", Content: RAGMsg + ragResp} chatBody.Messages = append(chatBody.Messages, ragMsg) + logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) } } - data, err := json.Marshal(chatBody) + // Create copy of chat body with standardized user role + bodyCopy := &models.ChatBody{ + Messages: make([]models.RoleMsg, len(chatBody.Messages)), + Model: chatBody.Model, + Stream: chatBody.Stream, + } + for i, msg := range chatBody.Messages { + bodyCopy.Messages[i] = msg + // Standardize role if it's a user role + if bodyCopy.Messages[i].Role == cfg.UserRole { + bodyCopy.Messages[i] = msg + bodyCopy.Messages[i].Role = "user" + } + } + // Clean null/empty messages to prevent API issues + bodyCopy.Messages = cleanNullMessages(bodyCopy.Messages) + orBody := models.NewOpenRouterChatReq(*bodyCopy, defaultLCPProps) + if cfg.ToolUse && !resume && role != cfg.ToolRole { + orBody.Tools = baseTools // set tools to use + } + data, err := json.Marshal(orBody) if err != nil { logger.Error("failed to form a msg", "error", err) return nil, err @@ -9,11 +9,18 @@ import ( ) var ( - botRespMode = false - editMode = false - selectedIndex = int(-1) - indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | RAGEnabled: [orange:-:b]%v[-:-:-] (F11) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p)" - focusSwitcher = map[tview.Primitive]tview.Primitive{} + botRespMode = false + editMode = false + roleEditMode = false + injectRole = true + selectedIndex = int(-1) + currentAPIIndex = 0 // Index to track current API in ApiLinks slice + currentORModelIndex = 0 // Index to track current OpenRouter model in ORFreeModels slice + currentLocalModelIndex = 0 // Index to track current llama.cpp model + shellMode = false + // indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q)" + indexLineCompletion = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | card's char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l) | skip LLM resp: [orange:-:b]%v[-:-:-] (F10)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | Insert <think>: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r) | Writing as: [orange:-:b]%s[-:-:-] (ctrl+q) | Bot will write as [orange:-:b]%s[-:-:-] (ctrl+x) | role_inject [orange:-:b]%v[-:-:-]" + focusSwitcher = 
map[tview.Primitive]tview.Primitive{} ) func isASCII(s string) bool { diff --git a/main_test.go b/main_test.go index 0046ca2..84d23ba 100644 --- a/main_test.go +++ b/main_test.go @@ -1,8 +1,9 @@ package main import ( - "elefant/models" + "gf-lt/models" "fmt" + "gf-lt/config" "strings" "testing" ) @@ -25,17 +26,17 @@ func TestRemoveThinking(t *testing.T) { }, } for i, tc := range cases { - t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { - mNum := len(tc.cb.Messages) - removeThinking(tc.cb) - if len(tc.cb.Messages) != mNum-int(tc.toolMsgs) { - t.Error("failed to delete tools msg", tc.cb.Messages, cfg.ToolRole) - } - for _, msg := range tc.cb.Messages { - if strings.Contains(msg.Content, "<think>") { - t.Errorf("msg contains think tag; msg: %s\n", msg.Content) - } - } - }) - } + t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { + cfg = &config.Config{ToolRole: "tool"} // Initialize cfg.ToolRole for test + mNum := len(tc.cb.Messages) + removeThinking(tc.cb) + if len(tc.cb.Messages) != mNum-int(tc.toolMsgs) { + t.Errorf("failed to delete tools msg %v; expected %d, got %d", tc.cb.Messages, mNum-int(tc.toolMsgs), len(tc.cb.Messages)) + } + for _, msg := range tc.cb.Messages { + if strings.Contains(msg.Content, "<think>") { + t.Errorf("msg contains think tag; msg: %s\n", msg.Content) + } + } + }) } } diff --git a/models/deepseek.go b/models/deepseek.go new file mode 100644 index 0000000..8f9868d --- /dev/null +++ b/models/deepseek.go @@ -0,0 +1,144 @@ +package models + +type DSChatReq struct { + Messages []RoleMsg `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream"` + FrequencyPenalty int `json:"frequency_penalty"` + MaxTokens int `json:"max_tokens"` + PresencePenalty int `json:"presence_penalty"` + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` + // ResponseFormat struct { + // Type string `json:"type"` + // } `json:"response_format"` + // Stop any `json:"stop"` + // StreamOptions any `json:"stream_options"` + // Tools any `json:"tools"` + // ToolChoice string `json:"tool_choice"` + // Logprobs bool `json:"logprobs"` + // TopLogprobs any `json:"top_logprobs"` +} + +func NewDSChatReq(cb ChatBody) DSChatReq { + return DSChatReq{ + Messages: cb.Messages, + Model: cb.Model, + Stream: cb.Stream, + MaxTokens: 2048, + PresencePenalty: 0, + FrequencyPenalty: 0, + Temperature: 1.0, + TopP: 1.0, + } +} + +type DSCompletionReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Echo bool `json:"echo"` + FrequencyPenalty int `json:"frequency_penalty"` + // Logprobs int `json:"logprobs"` + MaxTokens int `json:"max_tokens"` + PresencePenalty int `json:"presence_penalty"` + Stop any `json:"stop"` + Stream bool `json:"stream"` + StreamOptions any `json:"stream_options"` + Suffix any `json:"suffix"` + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` +} + +func NewDSCompletionReq(prompt, model string, temp float32, stopSlice []string) DSCompletionReq { + return DSCompletionReq{ + Model: model, + Prompt: prompt, + Temperature: temp, + Stream: true, + Echo: false, + MaxTokens: 2048, + PresencePenalty: 0, + FrequencyPenalty: 0, + TopP: 1.0, + Stop: stopSlice, + } +} + +type DSCompletionResp struct { + ID string `json:"id"` + Choices []struct { + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Logprobs struct { + TextOffset []int `json:"text_offset"` + TokenLogprobs []int `json:"token_logprobs"` + Tokens []string `json:"tokens"` + TopLogprobs []struct { + } `json:"top_logprobs"` + } 
`json:"logprobs"` + Text string `json:"text"` + } `json:"choices"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Object string `json:"object"` + Usage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + PromptCacheHitTokens int `json:"prompt_cache_hit_tokens"` + PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details"` + } `json:"usage"` +} + +type DSChatResp struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + Role any `json:"role"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Logprobs any `json:"logprobs"` + } `json:"choices"` + Created int `json:"created"` + ID string `json:"id"` + Model string `json:"model"` + Object string `json:"object"` + SystemFingerprint string `json:"system_fingerprint"` + Usage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +type DSChatStreamResp struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []struct { + Index int `json:"index"` + Delta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + } `json:"delta"` + Logprobs any `json:"logprobs"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +type DSBalance struct { + IsAvailable bool `json:"is_available"` + BalanceInfos []struct { + Currency string `json:"currency"` + TotalBalance string `json:"total_balance"` + GrantedBalance string `json:"granted_balance"` + ToppedUpBalance string `json:"topped_up_balance"` + } `json:"balance_infos"` +} diff --git a/models/embed.go b/models/embed.go new file mode 100644 index 0000000..078312c --- /dev/null +++ b/models/embed.go @@ -0,0 +1,15 @@ +package models + +type LCPEmbedResp struct { + Model string `json:"model"` + Object string `json:"object"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + Data []struct { + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object"` + } `json:"data"` +} diff --git a/models/extra.go b/models/extra.go new file mode 100644 index 0000000..e1ca80f --- /dev/null +++ b/models/extra.go @@ -0,0 +1,8 @@ +package models + +type AudioFormat string + +const ( + AFWav AudioFormat = "wav" + AFMP3 AudioFormat = "mp3" +) diff --git a/models/models.go b/models/models.go index bb61abf..912f72b 100644 --- a/models/models.go +++ b/models/models.go @@ -1,14 +1,17 @@ package models import ( - "elefant/config" + "encoding/base64" + "encoding/json" "fmt" + "os" "strings" ) type FuncCall struct { - Name string `json:"name"` - Args []string `json:"args"` + ID string `json:"id,omitempty"` + Name string `json:"name"` + Args map[string]string `json:"args"` } type LLMResp struct { @@ -31,13 +34,25 @@ type LLMResp struct { ID string `json:"id"` } +type ToolDeltaFunc struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type ToolDeltaResp struct { + ID string `json:"id,omitempty"` + Index int `json:"index"` + Function ToolDeltaFunc `json:"function"` +} + // 
for streaming type LLMRespChunk struct { Choices []struct { FinishReason string `json:"finish_reason"` Index int `json:"index"` Delta struct { - Content string `json:"content"` + Content string `json:"content"` + ToolCalls []ToolDeltaResp `json:"tool_calls"` } `json:"delta"` } `json:"choices"` Created int `json:"created"` @@ -51,23 +66,269 @@ type LLMRespChunk struct { } `json:"usage"` } +type TextChunk struct { + Chunk string + ToolChunk string + Finished bool + ToolResp bool + FuncName string + ToolID string +} + +type TextContentPart struct { + Type string `json:"type"` + Text string `json:"text"` +} + +type ImageContentPart struct { + Type string `json:"type"` + ImageURL struct { + URL string `json:"url"` + } `json:"image_url"` +} + +// RoleMsg represents a message with content that can be either a simple string or structured content parts type RoleMsg struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"-"` + ContentParts []interface{} `json:"-"` + ToolCallID string `json:"tool_call_id,omitempty"` // For tool response messages + hasContentParts bool // Flag to indicate which content type to marshal +} + +// MarshalJSON implements custom JSON marshaling for RoleMsg +func (m RoleMsg) MarshalJSON() ([]byte, error) { + if m.hasContentParts { + // Use structured content format + aux := struct { + Role string `json:"role"` + Content []interface{} `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{ + Role: m.Role, + Content: m.ContentParts, + ToolCallID: m.ToolCallID, + } + return json.Marshal(aux) + } else { + // Use simple content format + aux := struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{ + Role: m.Role, + Content: m.Content, + ToolCallID: m.ToolCallID, + } + return json.Marshal(aux) + } } -func (m RoleMsg) ToText(i int, cfg *config.Config) string { +// UnmarshalJSON implements custom JSON unmarshaling for RoleMsg +func (m *RoleMsg) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as structured content format + var structured struct { + Role string `json:"role"` + Content []interface{} `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` + } + if err := json.Unmarshal(data, &structured); err == nil && len(structured.Content) > 0 { + m.Role = structured.Role + m.ContentParts = structured.Content + m.ToolCallID = structured.ToolCallID + m.hasContentParts = true + return nil + } + + // Otherwise, unmarshal as simple content format + var simple struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` + } + if err := json.Unmarshal(data, &simple); err != nil { + return err + } + m.Role = simple.Role + m.Content = simple.Content + m.ToolCallID = simple.ToolCallID + m.hasContentParts = false + return nil +} + +func (m RoleMsg) ToText(i int) string { icon := fmt.Sprintf("(%d)", i) + + // Convert content to string representation + contentStr := "" + if !m.hasContentParts { + contentStr = m.Content + } else { + // For structured content, just take the text parts + var textParts []string + for _, part := range m.ContentParts { + if partMap, ok := part.(map[string]interface{}); ok { + if partType, exists := partMap["type"]; exists && partType == "text" { + if textVal, textExists := partMap["text"]; textExists { + if textStr, isStr := textVal.(string); isStr { + textParts = append(textParts, textStr) + } + } + } + } + } + contentStr 
= strings.Join(textParts, " ") + " " + } + // check if already has role annotation (/completion makes them) - if !strings.HasPrefix(m.Content, m.Role+":") { + if !strings.HasPrefix(contentStr, m.Role+":") { icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role) } - textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, m.Content) + textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, contentStr) return strings.ReplaceAll(textMsg, "\n\n", "\n") } func (m RoleMsg) ToPrompt() string { - return strings.ReplaceAll(fmt.Sprintf("%s:\n%s", m.Role, m.Content), "\n\n", "\n") + contentStr := "" + if !m.hasContentParts { + contentStr = m.Content + } else { + // For structured content, just take the text parts + var textParts []string + for _, part := range m.ContentParts { + if partMap, ok := part.(map[string]interface{}); ok { + if partType, exists := partMap["type"]; exists && partType == "text" { + if textVal, textExists := partMap["text"]; textExists { + if textStr, isStr := textVal.(string); isStr { + textParts = append(textParts, textStr) + } + } + } + } + } + contentStr = strings.Join(textParts, " ") + " " + } + return strings.ReplaceAll(fmt.Sprintf("%s:\n%s", m.Role, contentStr), "\n\n", "\n") +} + +// NewRoleMsg creates a simple RoleMsg with string content +func NewRoleMsg(role, content string) RoleMsg { + return RoleMsg{ + Role: role, + Content: content, + hasContentParts: false, + } +} + +// NewMultimodalMsg creates a RoleMsg with structured content parts (text and images) +func NewMultimodalMsg(role string, contentParts []interface{}) RoleMsg { + return RoleMsg{ + Role: role, + ContentParts: contentParts, + hasContentParts: true, + } +} + +// HasContent returns true if the message has either string content or structured content parts +func (m RoleMsg) HasContent() bool { + if m.Content != "" { + return true + } + if m.hasContentParts && len(m.ContentParts) > 0 { + return true + } + return false +} + +// IsContentParts returns true if the message uses structured content parts +func (m RoleMsg) IsContentParts() bool { + return m.hasContentParts +} + +// GetContentParts returns the content parts of the message +func (m RoleMsg) GetContentParts() []interface{} { + return m.ContentParts +} + +// Copy creates a copy of the RoleMsg with all fields +func (m RoleMsg) Copy() RoleMsg { + return RoleMsg{ + Role: m.Role, + Content: m.Content, + ContentParts: m.ContentParts, + ToolCallID: m.ToolCallID, + hasContentParts: m.hasContentParts, + } +} + +// AddTextPart adds a text content part to the message +func (m *RoleMsg) AddTextPart(text string) { + if !m.hasContentParts { + // Convert to content parts format + if m.Content != "" { + m.ContentParts = []interface{}{TextContentPart{Type: "text", Text: m.Content}} + } else { + m.ContentParts = []interface{}{} + } + m.hasContentParts = true + } + + textPart := TextContentPart{Type: "text", Text: text} + m.ContentParts = append(m.ContentParts, textPart) +} + +// AddImagePart adds an image content part to the message +func (m *RoleMsg) AddImagePart(imageURL string) { + if !m.hasContentParts { + // Convert to content parts format + if m.Content != "" { + m.ContentParts = []interface{}{TextContentPart{Type: "text", Text: m.Content}} + } else { + m.ContentParts = []interface{}{} + } + m.hasContentParts = true + } + + imagePart := ImageContentPart{ + Type: "image_url", + ImageURL: struct { + URL string `json:"url"` + }{URL: imageURL}, + } + m.ContentParts = append(m.ContentParts, imagePart) +} + +// CreateImageURLFromPath creates a data URL from an image file path +func 
CreateImageURLFromPath(imagePath string) (string, error) { + // Read the image file + data, err := os.ReadFile(imagePath) + if err != nil { + return "", err + } + + // Determine the image format based on file extension + var mimeType string + switch { + case strings.HasSuffix(strings.ToLower(imagePath), ".png"): + mimeType = "image/png" + case strings.HasSuffix(strings.ToLower(imagePath), ".jpg"): + fallthrough + case strings.HasSuffix(strings.ToLower(imagePath), ".jpeg"): + mimeType = "image/jpeg" + case strings.HasSuffix(strings.ToLower(imagePath), ".gif"): + mimeType = "image/gif" + case strings.HasSuffix(strings.ToLower(imagePath), ".webp"): + mimeType = "image/webp" + default: + mimeType = "image/jpeg" // default + } + + // Encode to base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // Create data URL + return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), nil } type ChatBody struct { @@ -76,31 +337,37 @@ type ChatBody struct { Messages []RoleMsg `json:"messages"` } -type ChatToolsBody struct { - Model string `json:"model"` - Messages []RoleMsg `json:"messages"` - Tools []struct { - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters struct { - Type string `json:"type"` - Properties struct { - Location struct { - Type string `json:"type"` - Description string `json:"description"` - } `json:"location"` - Unit struct { - Type string `json:"type"` - Enum []string `json:"enum"` - } `json:"unit"` - } `json:"properties"` - Required []string `json:"required"` - } `json:"parameters"` - } `json:"function"` - } `json:"tools"` - ToolChoice string `json:"tool_choice"` +func (cb *ChatBody) Rename(oldname, newname string) { + for i, m := range cb.Messages { + cb.Messages[i].Content = strings.ReplaceAll(m.Content, oldname, newname) + cb.Messages[i].Role = strings.ReplaceAll(m.Role, oldname, newname) + } +} + +func (cb *ChatBody) ListRoles() []string { + namesMap := make(map[string]struct{}) + for _, m := range cb.Messages { + namesMap[m.Role] = struct{}{} + } + resp := make([]string, len(namesMap)) + i := 0 + for k := range namesMap { + resp[i] = k + i++ + } + return resp +} + +func (cb *ChatBody) MakeStopSlice() []string { + namesMap := make(map[string]struct{}) + for _, m := range cb.Messages { + namesMap[m.Role] = struct{}{} + } + ss := []string{"<|im_end|>"} + for k := range namesMap { + ss = append(ss, k+":\n") + } + return ss } type EmbeddingResp struct { @@ -122,33 +389,66 @@ type EmbeddingResp struct { // } `json:"data"` // } -type LLMModels struct { - Object string `json:"object"` - Data []struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - OwnedBy string `json:"owned_by"` - Meta struct { - VocabType int `json:"vocab_type"` - NVocab int `json:"n_vocab"` - NCtxTrain int `json:"n_ctx_train"` - NEmbd int `json:"n_embd"` - NParams int64 `json:"n_params"` - Size int64 `json:"size"` - } `json:"meta"` - } `json:"data"` +// === tools models + +type ToolArgProps struct { + Type string `json:"type"` + Description string `json:"description"` +} + +type ToolFuncParams struct { + Type string `json:"type"` + Properties map[string]ToolArgProps `json:"properties"` + Required []string `json:"required"` +} + +type ToolFunc struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters ToolFuncParams `json:"parameters"` +} + +type Tool struct { + Type string `json:"type"` + Function ToolFunc `json:"function"` +} + +type OpenAIReq struct { + 
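+	// Embedding *ChatBody inlines model/messages/stream at the top level of the
+	// marshaled JSON, keeping the request OpenAI-compatible while Tools rides
+	// alongside for /chat/completions tool calling.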
*ChatBody + Tools []Tool `json:"tools"` } +// === + +// type LLMModels struct { +// Object string `json:"object"` +// Data []struct { +// ID string `json:"id"` +// Object string `json:"object"` +// Created int `json:"created"` +// OwnedBy string `json:"owned_by"` +// Meta struct { +// VocabType int `json:"vocab_type"` +// NVocab int `json:"n_vocab"` +// NCtxTrain int `json:"n_ctx_train"` +// NEmbd int `json:"n_embd"` +// NParams int64 `json:"n_params"` +// Size int64 `json:"size"` +// } `json:"meta"` +// } `json:"data"` +// } + type LlamaCPPReq struct { - Stream bool `json:"stream"` - // Messages []RoleMsg `json:"messages"` - Prompt string `json:"prompt"` - Temperature float32 `json:"temperature"` - DryMultiplier float32 `json:"dry_multiplier"` - Stop []string `json:"stop"` - MinP float32 `json:"min_p"` - NPredict int32 `json:"n_predict"` + Model string `json:"model"` + Stream bool `json:"stream"` + // For multimodal requests, prompt should be an object with prompt_string and multimodal_data + // For regular requests, prompt is a string + Prompt interface{} `json:"prompt"` // Can be string or object with prompt_string and multimodal_data + Temperature float32 `json:"temperature"` + DryMultiplier float32 `json:"dry_multiplier"` + Stop []string `json:"stop"` + MinP float32 `json:"min_p"` + NPredict int32 `json:"n_predict"` // MaxTokens int `json:"max_tokens"` // DryBase float64 `json:"dry_base"` // DryAllowedLength int `json:"dry_allowed_length"` @@ -168,21 +468,36 @@ type LlamaCPPReq struct { // Samplers string `json:"samplers"` } -func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) LlamaCPPReq { +type PromptObject struct { + PromptString string `json:"prompt_string"` + MultimodalData []string `json:"multimodal_data,omitempty"` + // Alternative field name used by some llama.cpp implementations + ImageData []string `json:"image_data,omitempty"` // For compatibility +} + +func NewLCPReq(prompt, model string, multimodalData []string, props map[string]float32, stopStrings []string) LlamaCPPReq { + var finalPrompt interface{} + if len(multimodalData) > 0 { + // When multimodal data is present, use the object format as per Python example: + // { "prompt": { "prompt_string": "...", "multimodal_data": [...] 
} } + finalPrompt = PromptObject{ + PromptString: prompt, + MultimodalData: multimodalData, + ImageData: multimodalData, // Also populate for compatibility with different llama.cpp versions + } + } else { + // When no multimodal data, use plain string + finalPrompt = prompt + } return LlamaCPPReq{ - Stream: true, - Prompt: prompt, - // Temperature: 0.8, - // DryMultiplier: 0.5, + Model: model, + Stream: true, + Prompt: finalPrompt, Temperature: props["temperature"], DryMultiplier: props["dry_multiplier"], + Stop: stopStrings, MinP: props["min_p"], NPredict: int32(props["n_predict"]), - Stop: []string{ - cfg.UserRole + ":\n", "<|im_end|>", - cfg.ToolRole + ":\n", - cfg.AssistantRole + ":\n", - }, } } @@ -190,3 +505,27 @@ type LlamaCPPResp struct { Content string `json:"content"` Stop bool `json:"stop"` } + +type LCPModels struct { + Data []struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Created int `json:"created"` + InCache bool `json:"in_cache"` + Path string `json:"path"` + Status struct { + Value string `json:"value"` + Args []string `json:"args"` + } `json:"status"` + } `json:"data"` + Object string `json:"object"` +} + +func (lcp *LCPModels) ListModels() []string { + resp := []string{} + for _, model := range lcp.Data { + resp = append(resp, model.ID) + } + return resp +} diff --git a/models/openrouter.go b/models/openrouter.go new file mode 100644 index 0000000..50f26b6 --- /dev/null +++ b/models/openrouter.go @@ -0,0 +1,157 @@ +package models + +// openrouter +// https://openrouter.ai/docs/api-reference/completion +type OpenRouterCompletionReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` + Temperature float32 `json:"temperature"` + Stop []string `json:"stop"` // not present in docs + MinP float32 `json:"min_p"` + NPredict int32 `json:"max_tokens"` +} + +func NewOpenRouterCompletionReq(model, prompt string, props map[string]float32, stopStrings []string) OpenRouterCompletionReq { + return OpenRouterCompletionReq{ + Stream: true, + Prompt: prompt, + Temperature: props["temperature"], + MinP: props["min_p"], + NPredict: int32(props["n_predict"]), + Stop: stopStrings, + Model: model, + } +} + +type OpenRouterChatReq struct { + Messages []RoleMsg `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream"` + Temperature float32 `json:"temperature"` + MinP float32 `json:"min_p"` + NPredict int32 `json:"max_tokens"` + Tools []Tool `json:"tools"` +} + +func NewOpenRouterChatReq(cb ChatBody, props map[string]float32) OpenRouterChatReq { + return OpenRouterChatReq{ + Messages: cb.Messages, + Model: cb.Model, + Stream: cb.Stream, + Temperature: props["temperature"], + MinP: props["min_p"], + NPredict: int32(props["n_predict"]), + } +} + +type OpenRouterChatRespNonStream struct { + ID string `json:"id"` + Provider string `json:"provider"` + Model string `json:"model"` + Object string `json:"object"` + Created int `json:"created"` + Choices []struct { + Logprobs any `json:"logprobs"` + FinishReason string `json:"finish_reason"` + NativeFinishReason string `json:"native_finish_reason"` + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal any `json:"refusal"` + Reasoning any `json:"reasoning"` + ToolCalls []ToolDeltaResp `json:"tool_calls"` + } `json:"message"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int 
`json:"total_tokens"` + } `json:"usage"` +} + +type OpenRouterChatResp struct { + ID string `json:"id"` + Provider string `json:"provider"` + Model string `json:"model"` + Object string `json:"object"` + Created int `json:"created"` + Choices []struct { + Index int `json:"index"` + Delta struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolDeltaResp `json:"tool_calls"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + NativeFinishReason string `json:"native_finish_reason"` + Logprobs any `json:"logprobs"` + } `json:"choices"` +} + +type OpenRouterCompletionResp struct { + ID string `json:"id"` + Provider string `json:"provider"` + Model string `json:"model"` + Object string `json:"object"` + Created int `json:"created"` + Choices []struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + NativeFinishReason string `json:"native_finish_reason"` + Logprobs any `json:"logprobs"` + } `json:"choices"` +} + +type ORModel struct { + ID string `json:"id"` + CanonicalSlug string `json:"canonical_slug"` + HuggingFaceID string `json:"hugging_face_id"` + Name string `json:"name"` + Created int `json:"created"` + Description string `json:"description"` + ContextLength int `json:"context_length"` + Architecture struct { + Modality string `json:"modality"` + InputModalities []string `json:"input_modalities"` + OutputModalities []string `json:"output_modalities"` + Tokenizer string `json:"tokenizer"` + InstructType any `json:"instruct_type"` + } `json:"architecture"` + Pricing struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` + Request string `json:"request"` + Image string `json:"image"` + Audio string `json:"audio"` + WebSearch string `json:"web_search"` + InternalReasoning string `json:"internal_reasoning"` + } `json:"pricing,omitempty"` + TopProvider struct { + ContextLength int `json:"context_length"` + MaxCompletionTokens int `json:"max_completion_tokens"` + IsModerated bool `json:"is_moderated"` + } `json:"top_provider"` + PerRequestLimits any `json:"per_request_limits"` + SupportedParameters []string `json:"supported_parameters"` +} + +type ORModels struct { + Data []ORModel `json:"data"` +} + +func (orm *ORModels) ListModels(free bool) []string { + resp := []string{} + for _, model := range orm.Data { + if free { + if model.Pricing.Prompt == "0" && model.Pricing.Request == "0" && + model.Pricing.Completion == "0" { + resp = append(resp, model.ID) + } + } else { + resp = append(resp, model.ID) + } + } + return resp +} diff --git a/pngmeta/altwriter.go b/pngmeta/altwriter.go new file mode 100644 index 0000000..206b563 --- /dev/null +++ b/pngmeta/altwriter.go @@ -0,0 +1,133 @@ +package pngmeta + +import ( + "bytes" + "gf-lt/models" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "hash/crc32" + "io" + "os" +) + +const ( + pngHeader = "\x89PNG\r\n\x1a\n" + textChunkType = "tEXt" +) + +// WriteToPng embeds the metadata into the specified PNG file and writes the result to outfile. 
+func WriteToPng(metadata *models.CharCardSpec, sourcePath, outfile string) error { + pngData, err := os.ReadFile(sourcePath) + if err != nil { + return err + } + jsonData, err := json.Marshal(metadata) + if err != nil { + return err + } + base64Data := base64.StdEncoding.EncodeToString(jsonData) + embedData := PngEmbed{ + Key: "gf-lt", // Replace with appropriate key constant + Value: base64Data, + } + var outputBuffer bytes.Buffer + if _, err := outputBuffer.Write([]byte(pngHeader)); err != nil { + return err + } + chunks, iend, err := processChunks(pngData[8:]) + if err != nil { + return err + } + for _, chunk := range chunks { + outputBuffer.Write(chunk) + } + newChunk, err := createTextChunk(embedData) + if err != nil { + return err + } + outputBuffer.Write(newChunk) + outputBuffer.Write(iend) + return os.WriteFile(outfile, outputBuffer.Bytes(), 0666) +} + +// processChunks extracts non-tEXt chunks and locates the IEND chunk +func processChunks(data []byte) ([][]byte, []byte, error) { + var ( + chunks [][]byte + iendChunk []byte + reader = bytes.NewReader(data) + ) + for { + var chunkLength uint32 + if err := binary.Read(reader, binary.BigEndian, &chunkLength); err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, nil, fmt.Errorf("error reading chunk length: %w", err) + } + chunkType := make([]byte, 4) + if _, err := reader.Read(chunkType); err != nil { + return nil, nil, fmt.Errorf("error reading chunk type: %w", err) + } + chunkData := make([]byte, chunkLength) + if _, err := reader.Read(chunkData); err != nil { + return nil, nil, fmt.Errorf("error reading chunk data: %w", err) + } + crc := make([]byte, 4) + if _, err := reader.Read(crc); err != nil { + return nil, nil, fmt.Errorf("error reading CRC: %w", err) + } + fullChunk := bytes.NewBuffer(nil) + if err := binary.Write(fullChunk, binary.BigEndian, chunkLength); err != nil { + return nil, nil, fmt.Errorf("error writing chunk length: %w", err) + } + if _, err := fullChunk.Write(chunkType); err != nil { + return nil, nil, fmt.Errorf("error writing chunk type: %w", err) + } + if _, err := fullChunk.Write(chunkData); err != nil { + return nil, nil, fmt.Errorf("error writing chunk data: %w", err) + } + if _, err := fullChunk.Write(crc); err != nil { + return nil, nil, fmt.Errorf("error writing CRC: %w", err) + } + switch string(chunkType) { + case "IEND": + iendChunk = fullChunk.Bytes() + return chunks, iendChunk, nil + case textChunkType: + continue // Skip existing tEXt chunks + default: + chunks = append(chunks, fullChunk.Bytes()) + } + } + return nil, nil, errors.New("IEND chunk not found") +} + +// createTextChunk generates a valid tEXt chunk with proper CRC +func createTextChunk(embed PngEmbed) ([]byte, error) { + content := bytes.NewBuffer(nil) + content.WriteString(embed.Key) + content.WriteByte(0) // Null separator + content.WriteString(embed.Value) + data := content.Bytes() + crc := crc32.NewIEEE() + crc.Write([]byte(textChunkType)) + crc.Write(data) + chunk := bytes.NewBuffer(nil) + if err := binary.Write(chunk, binary.BigEndian, uint32(len(data))); err != nil { + return nil, fmt.Errorf("error writing chunk length: %w", err) + } + if _, err := chunk.Write([]byte(textChunkType)); err != nil { + return nil, fmt.Errorf("error writing chunk type: %w", err) + } + if _, err := chunk.Write(data); err != nil { + return nil, fmt.Errorf("error writing chunk data: %w", err) + } + if err := binary.Write(chunk, binary.BigEndian, crc.Sum32()); err != nil { + return nil, fmt.Errorf("error writing CRC: %w", err) + } + 
return chunk.Bytes(), nil +} diff --git a/pngmeta/metareader.go b/pngmeta/metareader.go index df2a8d4..7053546 100644 --- a/pngmeta/metareader.go +++ b/pngmeta/metareader.go @@ -2,11 +2,11 @@ package pngmeta import ( "bytes" - "elefant/models" "encoding/base64" "encoding/json" "errors" "fmt" + "gf-lt/models" "io" "log/slog" "os" @@ -22,8 +22,6 @@ const ( writeHeader = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A" ) -var tEXtChunkDataSpecification = "%s\x00%s" - type PngEmbed struct { Key string Value string @@ -97,12 +95,12 @@ func ReadCard(fname, uname string) (*models.CharCard, error) { return nil, err } if charSpec.Name == "" { - return nil, fmt.Errorf("failed to find role; fname %s\n", fname) + return nil, fmt.Errorf("failed to find role; fname %s", fname) } return charSpec.Simplify(uname, fname), nil } -func readCardJson(fname string) (*models.CharCard, error) { +func ReadCardJson(fname string) (*models.CharCard, error) { data, err := os.ReadFile(fname) if err != nil { return nil, err @@ -128,17 +126,17 @@ func ReadDirCards(dirname, uname string, log *slog.Logger) ([]*models.CharCard, fpath := path.Join(dirname, f.Name()) cc, err := ReadCard(fpath, uname) if err != nil { - log.Warn("failed to load card", "error", err) + log.Warn("failed to load card", "error", err, "card", fpath) continue - // return nil, err // better to log and continue } resp = append(resp, cc) } if strings.HasSuffix(f.Name(), ".json") { fpath := path.Join(dirname, f.Name()) - cc, err := readCardJson(fpath) + cc, err := ReadCardJson(fpath) if err != nil { - return nil, err // better to log and continue + log.Warn("failed to load card", "error", err, "card", fpath) + continue } cc.FirstMsg = strings.ReplaceAll(strings.ReplaceAll(cc.FirstMsg, "{{char}}", cc.Role), "{{user}}", uname) cc.SysPrompt = strings.ReplaceAll(strings.ReplaceAll(cc.SysPrompt, "{{char}}", cc.Role), "{{user}}", uname) diff --git a/pngmeta/metareader_test.go b/pngmeta/metareader_test.go index 5d9a0e2..f88de06 100644 --- a/pngmeta/metareader_test.go +++ b/pngmeta/metareader_test.go @@ -1,7 +1,19 @@ package pngmeta import ( + "bytes" + "gf-lt/models" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" "fmt" + "image" + "image/color" + "image/png" + "io" + "os" + "path/filepath" "testing" ) @@ -28,3 +40,155 @@ func TestReadMeta(t *testing.T) { }) } } + +// Test helper: Create a simple PNG image with test shapes +func createTestImage(t *testing.T) string { + img := image.NewRGBA(image.Rect(0, 0, 200, 200)) + // Fill background with white + for y := 0; y < 200; y++ { + for x := 0; x < 200; x++ { + img.Set(x, y, color.White) + } + } + // Draw a red square + for y := 50; y < 150; y++ { + for x := 50; x < 150; x++ { + img.Set(x, y, color.RGBA{R: 255, A: 255}) + } + } + // Draw a blue circle + center := image.Point{100, 100} + radius := 40 + for y := center.Y - radius; y <= center.Y+radius; y++ { + for x := center.X - radius; x <= center.X+radius; x++ { + dx := x - center.X + dy := y - center.Y + if dx*dx+dy*dy <= radius*radius { + img.Set(x, y, color.RGBA{B: 255, A: 255}) + } + } + } + // Create temp file + tmpDir := t.TempDir() + fpath := filepath.Join(tmpDir, "test-image.png") + f, err := os.Create(fpath) + if err != nil { + t.Fatalf("Error creating temp file: %v", err) + } + defer f.Close() + if err := png.Encode(f, img); err != nil { + t.Fatalf("Error encoding PNG: %v", err) + } + return fpath +} + +func TestWriteToPng(t *testing.T) { + // Create test image + srcPath := createTestImage(t) + dstPath := filepath.Join(filepath.Dir(srcPath), 
"output.png") + // dstPath := "test.png" + // Create test metadata + metadata := &models.CharCardSpec{ + Description: "Test image containing a red square and blue circle on white background", + } + // Embed metadata + if err := WriteToPng(metadata, srcPath, dstPath); err != nil { + t.Fatalf("WriteToPng failed: %v", err) + } + // Verify output file exists + if _, err := os.Stat(dstPath); os.IsNotExist(err) { + t.Fatalf("Output file not created: %v", err) + } + // Read and verify metadata + t.Run("VerifyMetadata", func(t *testing.T) { + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("Error reading output file: %v", err) + } + // Verify PNG header + if string(data[:8]) != pngHeader { + t.Errorf("Invalid PNG header") + } + // Extract metadata + embedded := extractMetadata(t, data) + if embedded.Description != metadata.Description { + t.Errorf("Metadata mismatch\nWant: %q\nGot: %q", + metadata.Description, embedded.Description) + } + }) + // Optional: Add cleanup if needed + // t.Cleanup(func() { + // os.Remove(dstPath) + // }) +} + +// Helper to extract embedded metadata from PNG bytes +func extractMetadata(t *testing.T, data []byte) *models.CharCardSpec { + r := bytes.NewReader(data[8:]) // Skip PNG header + for { + var length uint32 + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("Error reading chunk length: %v", err) + } + chunkType := make([]byte, 4) + if _, err := r.Read(chunkType); err != nil { + t.Fatalf("Error reading chunk type: %v", err) + } + // Read chunk data + chunkData := make([]byte, length) + if _, err := r.Read(chunkData); err != nil { + t.Fatalf("Error reading chunk data: %v", err) + } + // Read and discard CRC + if _, err := r.Read(make([]byte, 4)); err != nil { + t.Fatalf("Error reading CRC: %v", err) + } + if string(chunkType) == embType { + parts := bytes.SplitN(chunkData, []byte{0}, 2) + if len(parts) != 2 { + t.Fatalf("Invalid tEXt chunk format") + } + decoded, err := base64.StdEncoding.DecodeString(string(parts[1])) + if err != nil { + t.Fatalf("Base64 decode error: %v", err) + } + var result models.CharCardSpec + if err := json.Unmarshal(decoded, &result); err != nil { + t.Fatalf("JSON unmarshal error: %v", err) + } + return &result + } + } + t.Fatal("Metadata not found in PNG") + return nil +} + +func readTextChunk(t *testing.T, r io.ReadSeeker) *models.CharCardSpec { + var length uint32 + binary.Read(r, binary.BigEndian, &length) + chunkType := make([]byte, 4) + r.Read(chunkType) + data := make([]byte, length) + r.Read(data) + // Read CRC (but skip validation for test purposes) + crc := make([]byte, 4) + r.Read(crc) + parts := bytes.SplitN(data, []byte{0}, 2) // Split key-value pair + if len(parts) != 2 { + t.Fatalf("Invalid tEXt chunk format") + } + // key := string(parts[0]) + value := parts[1] + decoded, err := base64.StdEncoding.DecodeString(string(value)) + if err != nil { + t.Fatalf("Base64 decode error: %v; value: %s", err, string(value)) + } + var result models.CharCardSpec + if err := json.Unmarshal(decoded, &result); err != nil { + t.Fatalf("JSON unmarshal error: %v", err) + } + return &result +} diff --git a/pngmeta/partswriter.go b/pngmeta/partswriter.go index 7c36daf..7282df6 100644 --- a/pngmeta/partswriter.go +++ b/pngmeta/partswriter.go @@ -1,116 +1,112 @@ package pngmeta -import ( - "bytes" - "elefant/models" - "encoding/base64" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "hash/crc32" - "io" - "os" -) +// import ( +// "bytes" +// 
"encoding/binary" +// "errors" +// "fmt" +// "hash/crc32" +// "io" +// ) -type Writer struct { - w io.Writer -} +// type Writer struct { +// w io.Writer +// } -func NewPNGWriter(w io.Writer) (*Writer, error) { - if _, err := io.WriteString(w, writeHeader); err != nil { - return nil, err - } - return &Writer{w}, nil -} +// func NewPNGWriter(w io.Writer) (*Writer, error) { +// if _, err := io.WriteString(w, writeHeader); err != nil { +// return nil, err +// } +// return &Writer{w}, nil +// } -func (w *Writer) WriteChunk(length int32, typ string, r io.Reader) error { - if err := binary.Write(w.w, binary.BigEndian, length); err != nil { - return err - } - if _, err := w.w.Write([]byte(typ)); err != nil { - return err - } - checksummer := crc32.NewIEEE() - checksummer.Write([]byte(typ)) - if _, err := io.CopyN(io.MultiWriter(w.w, checksummer), r, int64(length)); err != nil { - return err - } - if err := binary.Write(w.w, binary.BigEndian, checksummer.Sum32()); err != nil { - return err - } - return nil -} +// func (w *Writer) WriteChunk(length int32, typ string, r io.Reader) error { +// if err := binary.Write(w.w, binary.BigEndian, length); err != nil { +// return err +// } +// if _, err := w.w.Write([]byte(typ)); err != nil { +// return err +// } +// checksummer := crc32.NewIEEE() +// checksummer.Write([]byte(typ)) +// if _, err := io.CopyN(io.MultiWriter(w.w, checksummer), r, int64(length)); err != nil { +// return err +// } +// if err := binary.Write(w.w, binary.BigEndian, checksummer.Sum32()); err != nil { +// return err +// } +// return nil +// } -func WriteToPng(c *models.CharCardSpec, fpath, outfile string) error { - data, err := os.ReadFile(fpath) - if err != nil { - return err - } - jsonData, err := json.Marshal(c) - if err != nil { - return err - } - // Base64 encode the JSON data - base64Data := base64.StdEncoding.EncodeToString(jsonData) - pe := PngEmbed{ - Key: cKey, - Value: base64Data, - } - w, err := WritetEXtToPngBytes(data, pe) - if err != nil { - return err - } - return os.WriteFile(outfile, w.Bytes(), 0666) -} +// func WWriteToPngriteToPng(c *models.CharCardSpec, fpath, outfile string) error { +// data, err := os.ReadFile(fpath) +// if err != nil { +// return err +// } +// jsonData, err := json.Marshal(c) +// if err != nil { +// return err +// } +// // Base64 encode the JSON data +// base64Data := base64.StdEncoding.EncodeToString(jsonData) +// pe := PngEmbed{ +// Key: cKey, +// Value: base64Data, +// } +// w, err := WritetEXtToPngBytes(data, pe) +// if err != nil { +// return err +// } +// return os.WriteFile(outfile, w.Bytes(), 0666) +// } -func WritetEXtToPngBytes(inputBytes []byte, pe PngEmbed) (outputBytes bytes.Buffer, err error) { - if !(string(inputBytes[:8]) == header) { - return outputBytes, errors.New("wrong file format") - } - reader := bytes.NewReader(inputBytes) - pngr, err := NewPNGStepReader(reader) - if err != nil { - return outputBytes, fmt.Errorf("NewReader(): %s", err) - } - pngw, err := NewPNGWriter(&outputBytes) - if err != nil { - return outputBytes, fmt.Errorf("NewWriter(): %s", err) - } - for { - chunk, err := pngr.Next() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return outputBytes, fmt.Errorf("NextChunk(): %s", err) - } - if chunk.Type() != embType { - // IENDChunkType will only appear on the final iteration of a valid PNG - if chunk.Type() == IEND { - // This is where we inject tEXtChunkType as the penultimate chunk with the new value - newtEXtChunk := []byte(fmt.Sprintf(tEXtChunkDataSpecification, pe.Key, pe.Value)) - if err := 
pngw.WriteChunk(int32(len(newtEXtChunk)), embType, bytes.NewBuffer(newtEXtChunk)); err != nil { - return outputBytes, fmt.Errorf("WriteChunk(): %s", err) - } - // Now we end the buffer with IENDChunkType chunk - if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil { - return outputBytes, fmt.Errorf("WriteChunk(): %s", err) - } - } else { - // writes back original chunk to buffer - if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil { - return outputBytes, fmt.Errorf("WriteChunk(): %s", err) - } - } - } else { - if _, err := io.Copy(io.Discard, chunk); err != nil { - return outputBytes, fmt.Errorf("io.Copy(io.Discard, chunk): %s", err) - } - } - if err := chunk.Close(); err != nil { - return outputBytes, fmt.Errorf("chunk.Close(): %s", err) - } - } - return outputBytes, nil -} +// func WritetEXtToPngBytes(inputBytes []byte, pe PngEmbed) (outputBytes bytes.Buffer, err error) { +// if !(string(inputBytes[:8]) == header) { +// return outputBytes, errors.New("wrong file format") +// } +// reader := bytes.NewReader(inputBytes) +// pngr, err := NewPNGStepReader(reader) +// if err != nil { +// return outputBytes, fmt.Errorf("NewReader(): %s", err) +// } +// pngw, err := NewPNGWriter(&outputBytes) +// if err != nil { +// return outputBytes, fmt.Errorf("NewWriter(): %s", err) +// } +// for { +// chunk, err := pngr.Next() +// if err != nil { +// if errors.Is(err, io.EOF) { +// break +// } +// return outputBytes, fmt.Errorf("NextChunk(): %s", err) +// } +// if chunk.Type() != embType { +// // IENDChunkType will only appear on the final iteration of a valid PNG +// if chunk.Type() == IEND { +// // This is where we inject tEXtChunkType as the penultimate chunk with the new value +// newtEXtChunk := []byte(fmt.Sprintf(tEXtChunkDataSpecification, pe.Key, pe.Value)) +// if err := pngw.WriteChunk(int32(len(newtEXtChunk)), embType, bytes.NewBuffer(newtEXtChunk)); err != nil { +// return outputBytes, fmt.Errorf("WriteChunk(): %s", err) +// } +// // Now we end the buffer with IENDChunkType chunk +// if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil { +// return outputBytes, fmt.Errorf("WriteChunk(): %s", err) +// } +// } else { +// // writes back original chunk to buffer +// if err := pngw.WriteChunk(chunk.length, chunk.Type(), chunk); err != nil { +// return outputBytes, fmt.Errorf("WriteChunk(): %s", err) +// } +// } +// } else { +// if _, err := io.Copy(io.Discard, chunk); err != nil { +// return outputBytes, fmt.Errorf("io.Copy(io.Discard, chunk): %s", err) +// } +// } +// if err := chunk.Close(); err != nil { +// return outputBytes, fmt.Errorf("chunk.Close(): %s", err) +// } +// } +// return outputBytes, nil +// } diff --git a/props_table.go b/props_table.go new file mode 100644 index 0000000..7807522 --- /dev/null +++ b/props_table.go @@ -0,0 +1,305 @@ +package main + +import ( + "fmt" + "slices" + "strconv" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" +) + +// Define constants for cell types +const ( + CellTypeCheckbox = "checkbox" + CellTypeDropdown = "dropdown" + CellTypeInput = "input" + CellTypeHeader = "header" + CellTypeListPopup = "listpopup" +) + +// CellData holds additional data for each cell +type CellData struct { + Type string + Options []string + OnChange interface{} +} + +// makePropsTable creates a table-based alternative to the props form +// This allows for better key bindings and immediate effect of changes +func makePropsTable(props map[string]float32) *tview.Table { + // Create a new table + table := 
tview.NewTable(). + SetBorders(true). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)) // Allow row selection but not column selection + table.SetTitle("Properties Configuration (Press 'x' to exit)"). + SetTitleAlign(tview.AlignLeft) + row := 0 + // Add a header or note row + headerCell := tview.NewTableCell("Props for llamacpp completion call"). + SetTextColor(tcell.ColorYellow). + SetAlign(tview.AlignLeft). + SetSelectable(false) + table.SetCell(row, 0, headerCell) + table.SetCell(row, 1, + tview.NewTableCell("press 'x' to exit"). + SetTextColor(tcell.ColorYellow). + SetSelectable(false)) + row++ + // Store cell data for later use in selection functions + cellData := make(map[string]*CellData) + // Helper function to add a checkbox-like row + addCheckboxRow := func(label string, initialValue bool, onChange func(bool)) { + table.SetCell(row, 0, + tview.NewTableCell(label). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignLeft). + SetSelectable(false)) + valueText := "No" + if initialValue { + valueText = "Yes" + } + valueCell := tview.NewTableCell(valueText). + SetTextColor(tcell.ColorYellow). + SetAlign(tview.AlignCenter) + table.SetCell(row, 1, valueCell) + // Store cell data + cellID := fmt.Sprintf("checkbox_%d", row) + cellData[cellID] = &CellData{ + Type: CellTypeCheckbox, + OnChange: onChange, + } + row++ + } + // Helper function to add a dropdown-like row, that opens a list popup + addListPopupRow := func(label string, options []string, initialValue string, onChange func(string)) { + table.SetCell(row, 0, + tview.NewTableCell(label). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignLeft). + SetSelectable(false)) + valueCell := tview.NewTableCell(initialValue). + SetTextColor(tcell.ColorYellow). + SetAlign(tview.AlignCenter) + table.SetCell(row, 1, valueCell) + // Store cell data + cellID := fmt.Sprintf("listpopup_%d", row) + cellData[cellID] = &CellData{ + Type: CellTypeListPopup, + Options: options, + OnChange: onChange, + } + row++ + } + // Helper function to add an input field row + addInputRow := func(label string, initialValue string, onChange func(string)) { + table.SetCell(row, 0, + tview.NewTableCell(label). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignLeft). + SetSelectable(false)) + valueCell := tview.NewTableCell(initialValue). + SetTextColor(tcell.ColorYellow). 
+ SetAlign(tview.AlignCenter) + table.SetCell(row, 1, valueCell) + // Store cell data + cellID := fmt.Sprintf("input_%d", row) + cellData[cellID] = &CellData{ + Type: CellTypeInput, + OnChange: onChange, + } + row++ + } + // Add checkboxes + addCheckboxRow("Insert <think> tag (/completion only)", cfg.ThinkUse, func(checked bool) { + cfg.ThinkUse = checked + }) + addCheckboxRow("RAG use", cfg.RAGEnabled, func(checked bool) { + cfg.RAGEnabled = checked + }) + addCheckboxRow("Inject role", injectRole, func(checked bool) { + injectRole = checked + }) + addCheckboxRow("TTS Enabled", cfg.TTS_ENABLED, func(checked bool) { + cfg.TTS_ENABLED = checked + }) + // Add dropdowns + logLevels := []string{"Debug", "Info", "Warn"} + addListPopupRow("Set log level", logLevels, GetLogLevel(), func(option string) { + setLogLevel(option) + }) + // Prepare API links dropdown - insert current API at the beginning + apiLinks := slices.Insert(cfg.ApiLinks, 0, cfg.CurrentAPI) + addListPopupRow("Select an api", apiLinks, cfg.CurrentAPI, func(option string) { + cfg.CurrentAPI = option + }) + // Prepare model list dropdown + modelList := []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"} + modelList = append(modelList, ORFreeModels...) + addListPopupRow("Select a model", modelList, chatBody.Model, func(option string) { + chatBody.Model = option + }) + // Role selection dropdown + addListPopupRow("Write next message as", listRolesWithUser(), cfg.WriteNextMsgAs, func(option string) { + cfg.WriteNextMsgAs = option + }) + // Add input fields + addInputRow("New char to write msg as", "", func(text string) { + if text != "" { + cfg.WriteNextMsgAs = text + } + }) + addInputRow("Username", cfg.UserRole, func(text string) { + if text != "" { + renameUser(cfg.UserRole, text) + cfg.UserRole = text + } + }) + // Add property fields (the float32 values) + for propName, value := range props { + propName := propName // capture loop variable for closure + propValue := fmt.Sprintf("%v", value) + addInputRow(propName, propValue, func(text string) { + if val, err := strconv.ParseFloat(text, 32); err == nil { + props[propName] = float32(val) + } + }) + } + // Set selection function to handle dropdown-like behavior + table.SetSelectedFunc(func(selectedRow, selectedCol int) { + // Only handle selection on the value column (column 1) + if selectedCol != 1 { + // If user selects the label column, move to the value column + if table.GetRowCount() > selectedRow && table.GetColumnCount() > 1 { + table.Select(selectedRow, 1) + } + return + } + // Get the cell and its corresponding data + cell := table.GetCell(selectedRow, selectedCol) + cellID := fmt.Sprintf("checkbox_%d", selectedRow) + // Check if it's a checkbox + if cellData[cellID] != nil && cellData[cellID].Type == CellTypeCheckbox { + data := cellData[cellID] + if onChange, ok := data.OnChange.(func(bool)); ok { + // Toggle the checkbox value + newValue := cell.Text == "No" + onChange(newValue) + if newValue { + cell.SetText("Yes") + } else { + cell.SetText("No") + } + } + return + } + // Check for dropdown + dropdownCellID := fmt.Sprintf("dropdown_%d", selectedRow) + if cellData[dropdownCellID] != nil && cellData[dropdownCellID].Type == CellTypeDropdown { + data := cellData[dropdownCellID] + if onChange, ok := data.OnChange.(func(string)); ok && data.Options != nil { + // Find current option and cycle to next + currentValue := cell.Text + currentIndex := -1 + for i, opt := range data.Options { + if opt == currentValue { + currentIndex = i + break + } + } + // Move to next 
option (cycle back to 0 if at end) + nextIndex := (currentIndex + 1) % len(data.Options) + newValue := data.Options[nextIndex] + onChange(newValue) + cell.SetText(newValue) + } + return + } + // Check for listpopup + listPopupCellID := fmt.Sprintf("listpopup_%d", selectedRow) + if cellData[listPopupCellID] != nil && cellData[listPopupCellID].Type == CellTypeListPopup { + data := cellData[listPopupCellID] + if onChange, ok := data.OnChange.(func(string)); ok && data.Options != nil { + // Create a list primitive + apiList := tview.NewList().ShowSecondaryText(false). + SetSelectedBackgroundColor(tcell.ColorGray) + apiList.SetTitle("Select an API").SetBorder(true) + for i, api := range data.Options { + if api == cell.Text { + apiList.SetCurrentItem(i) + } + apiList.AddItem(api, "", 0, nil) + } + apiList.SetSelectedFunc(func(index int, mainText string, secondaryText string, shortcut rune) { + onChange(mainText) + cell.SetText(mainText) + pages.RemovePage("apiListPopup") + }) + apiList.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyEscape { + pages.RemovePage("apiListPopup") + return nil + } + return event + }) + modal := func(p tview.Primitive, width, height int) tview.Primitive { + return tview.NewFlex(). + AddItem(nil, 0, 1, false). + AddItem(tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(nil, 0, 1, false). + AddItem(p, height, 1, true). + AddItem(nil, 0, 1, false), width, 1, true). + AddItem(nil, 0, 1, false) + } + // Add modal page and make it visible + pages.AddPage("apiListPopup", modal(apiList, 80, 20), true, true) + app.SetFocus(apiList) + } + return + } + // Handle input fields by creating an input modal on selection + inputCellID := fmt.Sprintf("input_%d", selectedRow) + if cellData[inputCellID] != nil && cellData[inputCellID].Type == CellTypeInput { + data := cellData[inputCellID] + if onChange, ok := data.OnChange.(func(string)); ok { + // Create an input modal + currentValue := cell.Text + inputFld := tview.NewInputField() + inputFld.SetLabel("Edit value: ") + inputFld.SetText(currentValue) + inputFld.SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEnter { + newText := inputFld.GetText() + onChange(newText) + cell.SetText(newText) // Update the table cell + } + pages.RemovePage("editModal") + }) + // Create a simple modal with the input field + modalFlex := tview.NewFlex(). + SetDirection(tview.FlexRow). + AddItem(tview.NewBox(), 0, 1, false). // Spacer + AddItem(tview.NewFlex(). + AddItem(tview.NewBox(), 0, 1, false). // Spacer + AddItem(inputFld, 30, 1, true). // Input field + AddItem(tview.NewBox(), 0, 1, false), // Spacer + 0, 1, true). 
+ AddItem(tview.NewBox(), 0, 1, false) // Spacer + // Add modal page and make it visible + pages.AddPage("editModal", modalFlex, true, true) + } + return + } + }) + // Set input capture to handle 'x' key for exiting + table.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(propsPage) + updateStatusLine() + return nil + } + return event + }) + return table +} diff --git a/rag/embedder.go b/rag/embedder.go new file mode 100644 index 0000000..bed1b41 --- /dev/null +++ b/rag/embedder.go @@ -0,0 +1,145 @@ +package rag + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "gf-lt/config" + "gf-lt/models" + "log/slog" + "net/http" +) + +// Embedder defines the interface for embedding text +type Embedder interface { + Embed(text string) ([]float32, error) + EmbedSlice(lines []string) ([][]float32, error) +} + +// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.) +type APIEmbedder struct { + logger *slog.Logger + client *http.Client + cfg *config.Config +} + +func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder { + return &APIEmbedder{ + logger: l, + client: &http.Client{}, + cfg: cfg, + } +} + +func (a *APIEmbedder) Embed(text string) ([]float32, error) { + payload, err := json.Marshal( + map[string]any{"input": text, "encoding_format": "float"}, + ) + if err != nil { + a.logger.Error("failed to marshal payload", "err", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + a.logger.Error("failed to create new req", "err", err.Error()) + return nil, err + } + if a.cfg.HFToken != "" { + req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) + } + resp, err := a.client.Do(req) + if err != nil { + a.logger.Error("failed to embed text", "err", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) + a.logger.Error(err.Error()) + return nil, err + } + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + a.logger.Error("failed to decode embedding response", "err", err.Error()) + return nil, err + } + if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 { + err = errors.New("empty embedding response") + a.logger.Error("empty embedding response") + return nil, err + } + return embResp.Data[0].Embedding, nil +} + +func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { + payload, err := json.Marshal( + map[string]any{"input": lines, "encoding_format": "float"}, + ) + if err != nil { + a.logger.Error("failed to marshal payload", "err", err.Error()) + return nil, err + } + req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + a.logger.Error("failed to create new req", "err", err.Error()) + return nil, err + } + if a.cfg.HFToken != "" { + req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) + } + resp, err := a.client.Do(req) + if err != nil { + a.logger.Error("failed to embed text", "err", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) + a.logger.Error(err.Error()) + return nil, err + } + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + a.logger.Error("failed to decode embedding response", "err", 
err.Error()) + return nil, err + } + if len(embResp.Data) == 0 { + err = errors.New("empty embedding response") + a.logger.Error("empty embedding response") + return nil, err + } + + // Collect all embeddings from the response + embeddings := make([][]float32, len(embResp.Data)) + for i := range embResp.Data { + if len(embResp.Data[i].Embedding) == 0 { + err = fmt.Errorf("empty embedding at index %d", i) + a.logger.Error("empty embedding", "index", i) + return nil, err + } + embeddings[i] = embResp.Data[i].Embedding + } + + // Sort embeddings by index to match the order of input lines + // API responses may not be in order + for _, data := range embResp.Data { + if data.Index >= len(embeddings) || data.Index < 0 { + err = fmt.Errorf("invalid embedding index %d", data.Index) + a.logger.Error("invalid embedding index", "index", data.Index) + return nil, err + } + embeddings[data.Index] = data.Embedding + } + + return embeddings, nil +} + +// TODO: ONNXEmbedder implementation would go here +// This would require: +// 1. Loading ONNX models locally +// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) +// 3. Converting text to embeddings without external API calls +// +// For now, we'll focus on the API implementation which is already working in the current system, +// and can be extended later when we have ONNX runtime integration diff --git a/rag/main.go b/rag/main.go deleted file mode 100644 index 5f2aa00..0000000 --- a/rag/main.go +++ /dev/null @@ -1,271 +0,0 @@ -package rag - -import ( - "bytes" - "elefant/config" - "elefant/models" - "elefant/storage" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - "os" - "path" - "strings" - "sync" - - "github.com/neurosnap/sentences/english" -) - -var ( - LongJobStatusCh = make(chan string, 1) - // messages - FinishedRAGStatus = "finished loading RAG file; press Enter" - LoadedFileRAGStatus = "loaded file" - ErrRAGStatus = "some error occured; failed to transfer data to vector db" -) - -type RAG struct { - logger *slog.Logger - store storage.FullRepo - cfg *config.Config -} - -func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { - return &RAG{ - logger: l, - store: s, - cfg: cfg, - } -} - -func wordCounter(sentence string) int { - return len(strings.Split(sentence, " ")) -} - -func (r *RAG) LoadRAG(fpath string) error { - data, err := os.ReadFile(fpath) - if err != nil { - return err - } - r.logger.Debug("rag: loaded file", "fp", fpath) - LongJobStatusCh <- LoadedFileRAGStatus - fileText := string(data) - tokenizer, err := english.NewSentenceTokenizer(nil) - if err != nil { - return err - } - sentences := tokenizer.Tokenize(fileText) - sents := make([]string, len(sentences)) - for i, s := range sentences { - sents[i] = s.Text - } - // TODO: maybe better to decide batch size based on sentences len - var ( - // TODO: to config - workers = 5 - batchSize = 100 - maxChSize = 1000 - // - wordLimit = 80 - left = 0 - right = batchSize - batchCh = make(chan map[int][]string, maxChSize) - vectorCh = make(chan []models.VectorRow, maxChSize) - errCh = make(chan error, 1) - doneCh = make(chan bool, 1) - lock = new(sync.Mutex) - ) - defer close(doneCh) - defer close(errCh) - defer close(batchCh) - // group sentences - paragraphs := []string{} - par := strings.Builder{} - for i := 0; i < len(sents); i++ { - par.WriteString(sents[i]) - if wordCounter(par.String()) > wordLimit { - paragraphs = append(paragraphs, par.String()) - par.Reset() - } - } - if len(paragraphs) < batchSize { - batchSize = len(paragraphs) - } - // fill 
input channel - ctn := 0 - for { - if right > len(paragraphs) { - batchCh <- map[int][]string{left: paragraphs[left:]} - break - } - batchCh <- map[int][]string{left: paragraphs[left:right]} - left, right = right, right+batchSize - ctn++ - } - finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", len(batchCh), len(paragraphs), len(sents)) - r.logger.Debug(finishedBatchesMsg) - LongJobStatusCh <- finishedBatchesMsg - for w := 0; w < workers; w++ { - go r.batchToVectorHFAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) - } - // wait for emb to be done - <-doneCh - // write to db - return r.writeVectors(vectorCh) -} - -func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { - for { - for batch := range vectorCh { - for _, vector := range batch { - if err := r.store.WriteVector(&vector); err != nil { - r.logger.Error("failed to write vector", "error", err, "slug", vector.Slug) - LongJobStatusCh <- ErrRAGStatus - continue // a duplicate is not critical - // return err - } - } - r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) - if len(vectorCh) == 0 { - r.logger.Debug("finished writing vectors") - LongJobStatusCh <- FinishedRAGStatus - defer close(vectorCh) - return nil - } - } - } -} - -func (r *RAG) batchToVectorHFAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, - vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { - for { - lock.Lock() - if len(inputCh) == 0 { - if len(doneCh) == 0 { - doneCh <- true - } - lock.Unlock() - return - } - select { - case linesMap := <-inputCh: - for leftI, v := range linesMap { - r.fecthEmbHF(v, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename) - } - lock.Unlock() - case err := <-errCh: - r.logger.Error("got an error", "error", err) - lock.Unlock() - return - } - r.logger.Debug("to vector batches", "batches#", len(inputCh), "worker#", id) - LongJobStatusCh <- fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) - } -} - -func (r *RAG) fecthEmbHF(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) { - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - r.logger.Error("failed to marshal payload", "err:", err.Error()) - errCh <- err - return - } - // nolint - req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) - if err != nil { - r.logger.Error("failed to create new req", "err:", err.Error()) - errCh <- err - return - } - req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - r.logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - r.logger.Error("non 200 resp", "code", resp.StatusCode) - return - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - r.logger.Error("failed to embedd line", "err:", err.Error()) - errCh <- err - return - } - if len(emb) == 0 { - r.logger.Error("empty emb") - err = errors.New("empty emb") - errCh <- err - return - } - vectors := make([]models.VectorRow, len(emb)) - for i, e := range emb { - vector := models.VectorRow{ - Embeddings: e, - RawText: lines[i], - Slug: fmt.Sprintf("%s_%d", slug, i), - FileName: filename, - } - vectors[i] = vector - } - vectorCh <- vectors -} - -func (r 
*RAG) LineToVector(line string) ([]float32, error) { - lines := []string{line} - payload, err := json.Marshal( - map[string]any{"inputs": lines, "options": map[string]bool{"wait_for_model": true}}, - ) - if err != nil { - r.logger.Error("failed to marshal payload", "err:", err.Error()) - return nil, err - } - // nolint - req, err := http.NewRequest("POST", r.cfg.EmbedURL, bytes.NewReader(payload)) - if err != nil { - r.logger.Error("failed to create new req", "err:", err.Error()) - return nil, err - } - req.Header.Add("Authorization", "Bearer "+r.cfg.HFToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - r.logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - err = fmt.Errorf("non 200 resp; code: %v\n", resp.StatusCode) - r.logger.Error(err.Error()) - return nil, err - } - emb := [][]float32{} - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { - r.logger.Error("failed to embedd line", "err:", err.Error()) - return nil, err - } - if len(emb) == 0 || len(emb[0]) == 0 { - r.logger.Error("empty emb") - err = errors.New("empty emb") - return nil, err - } - return emb[0], nil -} - -func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { - return r.store.SearchClosest(emb.Embedding) -} - -func (r *RAG) ListLoaded() ([]string, error) { - return r.store.ListFiles() -} - -func (r *RAG) RemoveFile(filename string) error { - return r.store.RemoveEmbByFileName(filename) -} diff --git a/rag/rag.go b/rag/rag.go new file mode 100644 index 0000000..b29b9eb --- /dev/null +++ b/rag/rag.go @@ -0,0 +1,334 @@ +package rag + +import ( + "errors" + "fmt" + "gf-lt/config" + "gf-lt/models" + "gf-lt/storage" + "log/slog" + "os" + "path" + "strings" + "sync" + + "github.com/neurosnap/sentences/english" +) + +var ( + // Status messages for TUI integration + LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking + FinishedRAGStatus = "finished loading RAG file; press Enter" + LoadedFileRAGStatus = "loaded file" + ErrRAGStatus = "some error occurred; failed to transfer data to vector db" +) + + +type RAG struct { + logger *slog.Logger + store storage.FullRepo + cfg *config.Config + embedder Embedder + storage *VectorStorage +} + +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { + // Initialize with API embedder by default, could be configurable later + embedder := NewAPIEmbedder(l, cfg) + + rag := &RAG{ + logger: l, + store: s, + cfg: cfg, + embedder: embedder, + storage: NewVectorStorage(l, s), + } + + // Note: Vector tables are created via database migrations, not at runtime + + return rag +} + +func wordCounter(sentence string) int { + return len(strings.Split(strings.TrimSpace(sentence), " ")) +} + +func (r *RAG) LoadRAG(fpath string) error { + data, err := os.ReadFile(fpath) + if err != nil { + return err + } + r.logger.Debug("rag: loaded file", "fp", fpath) + select { + case LongJobStatusCh <- LoadedFileRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + + fileText := string(data) + tokenizer, err := english.NewSentenceTokenizer(nil) + if err != nil { + return err + } + sentences := tokenizer.Tokenize(fileText) + sents := make([]string, len(sentences)) + for i, s := range sentences { + sents[i] = s.Text + } + + // Group sentences into paragraphs based on word limit 
+ paragraphs := []string{} + par := strings.Builder{} + for i := 0; i < len(sents); i++ { + // Only add sentences that aren't empty + if strings.TrimSpace(sents[i]) != "" { + if par.Len() > 0 { + par.WriteString(" ") // Add space between sentences + } + par.WriteString(sents[i]) + } + + if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { + paragraph := strings.TrimSpace(par.String()) + if paragraph != "" { + paragraphs = append(paragraphs, paragraph) + } + par.Reset() + } + } + + // Handle any remaining content in the paragraph buffer + if par.Len() > 0 { + paragraph := strings.TrimSpace(par.String()) + if paragraph != "" { + paragraphs = append(paragraphs, paragraph) + } + } + + // Adjust batch size if needed + if len(paragraphs) < int(r.cfg.RAGBatchSize) && len(paragraphs) > 0 { + r.cfg.RAGBatchSize = len(paragraphs) + } + + if len(paragraphs) == 0 { + return errors.New("no valid paragraphs found in file") + } + + var ( + maxChSize = 100 + left = 0 + right = r.cfg.RAGBatchSize + batchCh = make(chan map[int][]string, maxChSize) + vectorCh = make(chan []models.VectorRow, maxChSize) + errCh = make(chan error, 1) + wg = new(sync.WaitGroup) + lock = new(sync.Mutex) + ) + + defer close(errCh) + defer close(batchCh) + + // Fill input channel with batches + ctn := 0 + totalParagraphs := len(paragraphs) + for { + if int(right) > totalParagraphs { + batchCh <- map[int][]string{left: paragraphs[left:]} + break + } + batchCh <- map[int][]string{left: paragraphs[left:right]} + left, right = right, right+r.cfg.RAGBatchSize + ctn++ + } + + finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents)) + r.logger.Debug(finishedBatchesMsg) + select { + case LongJobStatusCh <- finishedBatchesMsg: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg) + // Channel is full or closed, ignore the message to prevent panic + } + + // Start worker goroutines with WaitGroup + wg.Add(int(r.cfg.RAGWorkers)) + for w := 0; w < int(r.cfg.RAGWorkers); w++ { + go func(workerID int) { + defer wg.Done() + r.batchToVectorAsync(lock, workerID, batchCh, vectorCh, errCh, path.Base(fpath)) + }(w) + } + + // Use a goroutine to close vectorCh once all workers have finished + go func() { + wg.Wait() + close(vectorCh) // Close vectorCh when all workers are done + }() + + // Check for errors from workers + // Use a non-blocking check for errors + select { + case err := <-errCh: + if err != nil { + r.logger.Error("error during RAG processing", "error", err) + return err + } + default: + // No immediate error, continue + } + + // Write vectors to storage - this will block until vectorCh is closed + return r.writeVectors(vectorCh) +} + +func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { + for { + for batch := range vectorCh { + for _, vector := range batch { + if err := r.storage.WriteVector(&vector); err != nil { + r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) + select { + case LongJobStatusCh <- ErrRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + return err // Stop the entire RAG operation on DB error + } + } + r.logger.Debug("wrote batch to db", "size", len(batch), "vector_chan_len", len(vectorCh)) + if len(vectorCh) == 0 { + r.logger.Debug("finished writing vectors") + select 
{ + case LongJobStatusCh <- FinishedRAGStatus: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) + // Channel is full or closed, ignore the message to prevent panic + } + return nil + } + } + } +} + +func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, + vectorCh chan<- []models.VectorRow, errCh chan error, filename string) { + var err error + + defer func() { + // For errCh, make sure we only send if there's actually an error and the channel can accept it + if err != nil { + select { + case errCh <- err: + default: + // errCh might be full or closed, log but don't panic + r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err) + } + } + }() + + for { + lock.Lock() + if len(inputCh) == 0 { + lock.Unlock() + return + } + + select { + case linesMap := <-inputCh: + for leftI, lines := range linesMap { + if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil { + r.logger.Error("error fetching embeddings", "error", err, "worker", id) + lock.Unlock() + return + } + } + lock.Unlock() + case err = <-errCh: + r.logger.Error("got an error from error channel", "error", err) + lock.Unlock() + return + default: + lock.Unlock() + } + + r.logger.Debug("processed batch", "batches#", len(inputCh), "worker#", id) + statusMsg := fmt.Sprintf("converted to vector; batches: %d, worker#: %d", len(inputCh), id) + select { + case LongJobStatusCh <- statusMsg: + default: + r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg) + // Channel is full or closed, ignore the message to prevent panic + } + } +} + +func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error { + // Filter out empty lines before sending to embedder + nonEmptyLines := make([]string, 0, len(lines)) + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" { + nonEmptyLines = append(nonEmptyLines, trimmed) + } + } + + // Skip if no non-empty lines + if len(nonEmptyLines) == 0 { + // Send empty result but don't error + vectorCh <- []models.VectorRow{} + return nil + } + + embeddings, err := r.embedder.EmbedSlice(nonEmptyLines) + if err != nil { + r.logger.Error("failed to embed lines", "err", err.Error()) + errCh <- err + return err + } + + if len(embeddings) == 0 { + err := errors.New("no embeddings returned") + r.logger.Error("empty embeddings") + errCh <- err + return err + } + + if len(embeddings) != len(nonEmptyLines) { + err := errors.New("mismatch between number of lines and embeddings returned") + r.logger.Error("embedding mismatch", "err", err.Error()) + errCh <- err + return err + } + + // Create a VectorRow for each line in the batch + vectors := make([]models.VectorRow, len(nonEmptyLines)) + for i, line := range nonEmptyLines { + vectors[i] = models.VectorRow{ + Embeddings: embeddings[i], + RawText: line, + Slug: fmt.Sprintf("%s_%d", slug, i), + FileName: filename, + } + } + + vectorCh <- vectors + return nil +} + +func (r *RAG) LineToVector(line string) ([]float32, error) { + return r.embedder.Embed(line) +} + +func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { + return r.storage.SearchClosest(emb.Embedding) +} + +func (r *RAG) ListLoaded() ([]string, error) { + return r.storage.ListFiles() +} + +func (r *RAG) RemoveFile(filename string) error { + return 
r.storage.RemoveEmbByFileName(filename) +} diff --git a/rag/storage.go b/rag/storage.go new file mode 100644 index 0000000..782c504 --- /dev/null +++ b/rag/storage.go @@ -0,0 +1,278 @@ +package rag + +import ( + "encoding/binary" + "fmt" + "gf-lt/models" + "gf-lt/storage" + "log/slog" + "sort" + "strings" + "unsafe" + + "github.com/jmoiron/sqlx" +) + +// VectorStorage handles storing and retrieving vectors from SQLite +type VectorStorage struct { + logger *slog.Logger + sqlxDB *sqlx.DB + store storage.FullRepo +} + +func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorage { + return &VectorStorage{ + logger: logger, + sqlxDB: store.DB(), // Use the new DB() method + store: store, + } +} + + +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} + +// WriteVector stores an embedding vector in the database +func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { + tableName, err := vs.getTableName(row.Embeddings) + if err != nil { + return err + } + + // Serialize the embeddings to binary + serializedEmbeddings := SerializeVector(row.Embeddings) + + query := fmt.Sprintf( + "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", + tableName, + ) + + if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { + vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) + return err + } + + return nil +} + +// getTableName determines which table to use based on embedding size +func (vs *VectorStorage) getTableName(emb []float32) (string, error) { + size := len(emb) + + // Check if we support this embedding size + supportedSizes := map[int]bool{ + 384: true, + 768: true, + 1024: true, + 1536: true, + 2048: true, + 3072: true, + 4096: true, + 5120: true, + } + + if supportedSizes[size] { + return fmt.Sprintf("embeddings_%d", size), nil + } + + return "", fmt.Errorf("no table for embedding size of %d", size) +} + +// SearchClosest finds vectors closest to the query vector using efficient cosine similarity calculation +func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) { + tableName, err := vs.getTableName(query) + if err != nil { + return nil, err + } + + // For better performance, instead of loading all vectors at once, + // we'll implement batching and potentially add L2 distance-based pre-filtering + // since cosine similarity is related to L2 distance for normalized vectors + + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := vs.sqlxDB.Query(querySQL) + if err != nil { + return nil, err + } + defer rows.Close() + + // Use a min-heap or simple slice to keep track of top 3 closest vectors + type SearchResult struct 
{ + vector models.VectorRow + distance float32 + } + + var topResults []SearchResult + + // Process vectors one by one to avoid loading everything into memory + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + vs.logger.Error("failed to scan row", "error", err) + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(query, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top 3 + topResults = append(topResults, result) + + // Sort and keep only top 3 + sort.Slice(topResults, func(i, j int) bool { + return topResults[i].distance < topResults[j].distance + }) + + if len(topResults) > 3 { + topResults = topResults[:3] // Keep only closest 3 + } + } + + // Convert back to VectorRow slice + results := make([]models.VectorRow, 0, len(topResults)) + for _, result := range topResults { + result.vector.Distance = result.distance + results = append(results, result.vector) + } + + return results, nil +} + +// ListFiles returns a list of all loaded files +func (vs *VectorStorage) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query all supported tables and combine results + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := "SELECT DISTINCT filename FROM " + table + rows, err := vs.sqlxDB.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + + return allFiles, nil +} + +// RemoveEmbByFileName removes all embeddings associated with a specific filename +func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { + var errors []string + + embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} + for _, size := range embeddingSizes { + table := fmt.Sprintf("embeddings_%d", size) + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := vs.sqlxDB.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; ")) + } + + return nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a 
float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + @@ -1,9 +1,9 @@ package main import ( - "elefant/config" "encoding/json" "fmt" + "gf-lt/config" "net/http" "time" ) @@ -61,7 +61,7 @@ out: } func modelHandler(w http.ResponseWriter, req *http.Request) { - llmModel := fetchModelName() + llmModel := fetchLCPModelName() payload, err := json.Marshal(llmModel) if err != nil { logger.Error("model handler", "error", err) @@ -1,12 +1,14 @@ package main import ( - "elefant/models" "encoding/json" "errors" "fmt" + "gf-lt/models" "os" "os/exec" + "path" + "path/filepath" "strings" "time" ) @@ -31,7 +33,33 @@ func exportChat() error { if err != nil { return err } - return os.WriteFile(activeChatName+".json", data, 0666) + // Ensure the export directory exists + if err := os.MkdirAll(exportDir, 0755); err != nil { + return fmt.Errorf("failed to create export directory %s: %w", exportDir, err) + } + fp := path.Join(exportDir, activeChatName+".json") + return os.WriteFile(fp, data, 0666) +} + +func importChat(filename string) error { + data, err := os.ReadFile(filename) + if err != nil { + return err + } + messages := []models.RoleMsg{} + if err := json.Unmarshal(data, &messages); err != nil { + return err + } + activeChatName = filepath.Base(filename) + if _, ok := chatMap[activeChatName]; !ok { + addNewChat(activeChatName) + } + chatBody.Messages = messages + cfg.AssistantRole = messages[1].Role + if cfg.AssistantRole == cfg.UserRole { + cfg.AssistantRole = messages[2].Role + } + return nil } func updateStorageChat(name string, msgs []models.RoleMsg) error { diff --git a/storage/memory.go b/storage/memory.go index c9fc853..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) diff --git a/storage/migrate.go b/storage/migrate.go index b05dddc..decfe9c 100644 --- a/storage/migrate.go +++ b/storage/migrate.go @@ -5,8 +5,6 @@ import ( "fmt" "io/fs" "strings" - - _ "github.com/asg017/sqlite-vec-go-bindings/ncruces" ) //go:embed migrations/* @@ -53,8 +51,8 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err } func (p *ProviderSQL) executeSQL(sqlContent []byte) error { - // Connect to the database (example using a simple connection) - err := p.s3Conn.Exec(string(sqlContent)) + // Execute the migration content using standard database connection + _, err := p.db.Exec(string(sqlContent)) if err != nil { return fmt.Errorf("failed to execute SQL: %w", err) } diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql new file mode 100644 index 0000000..a257b11 --- /dev/null +++ b/storage/migrations/002_add_vector.down.sql @@ -0,0 +1,34 @@ +-- Drop vector storage tables +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_768_filename; +DROP INDEX IF EXISTS idx_embeddings_1024_filename; +DROP INDEX IF EXISTS idx_embeddings_1536_filename; +DROP INDEX IF EXISTS idx_embeddings_2048_filename; +DROP INDEX IF EXISTS idx_embeddings_3072_filename; +DROP INDEX IF EXISTS idx_embeddings_4096_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS 
idx_embeddings_768_slug; +DROP INDEX IF EXISTS idx_embeddings_1024_slug; +DROP INDEX IF EXISTS idx_embeddings_1536_slug; +DROP INDEX IF EXISTS idx_embeddings_2048_slug; +DROP INDEX IF EXISTS idx_embeddings_3072_slug; +DROP INDEX IF EXISTS idx_embeddings_4096_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_768_created_at; +DROP INDEX IF EXISTS idx_embeddings_1024_created_at; +DROP INDEX IF EXISTS idx_embeddings_1536_created_at; +DROP INDEX IF EXISTS idx_embeddings_2048_created_at; +DROP INDEX IF EXISTS idx_embeddings_3072_created_at; +DROP INDEX IF EXISTS idx_embeddings_4096_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_768; +DROP TABLE IF EXISTS embeddings_1024; +DROP TABLE IF EXISTS embeddings_1536; +DROP TABLE IF EXISTS embeddings_2048; +DROP TABLE IF EXISTS embeddings_3072; +DROP TABLE IF EXISTS embeddings_4096; +DROP TABLE IF EXISTS embeddings_5120;
\ No newline at end of file diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql index 2ac4621..baf703d 100644 --- a/storage/migrations/002_add_vector.up.sql +++ b/storage/migrations/002_add_vector.up.sql @@ -1,12 +1,98 @@ ---CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_5120 USING vec0( --- embedding FLOAT[5120], --- slug TEXT NOT NULL, --- raw_text TEXT PRIMARY KEY, ---); +-- Create tables for vector storage (replacing vec0 plugin usage) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_768 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1024 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1536 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); -CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( - embedding FLOAT[384], +CREATE TABLE IF NOT EXISTS embeddings_2048 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, slug TEXT NOT NULL, - raw_text TEXT PRIMARY KEY, - filename TEXT NOT NULL DEFAULT '' + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); + +CREATE TABLE IF NOT EXISTS embeddings_3072 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_4096 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Indexes for better performance +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_filename ON embeddings_768(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_slug ON embeddings_768(slug); +CREATE 
INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_created_at ON embeddings_768(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_created_at ON embeddings_1536(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/storage.go b/storage/storage.go index f759700..a092f8d 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,12 +1,11 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" - "github.com/ncruces/go-sqlite3" ) type FullRepo interface { @@ -28,7 +27,6 @@ type ChatHistory interface { type ProviderSQL struct { db *sqlx.DB - s3Conn *sqlite3.Conn logger *slog.Logger } @@ -97,7 +95,7 @@ func (p ProviderSQL) ChatGetMaxID() (uint32, error) { return id, err } -// opens two connections +// opens database connection func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { @@ -105,11 +103,12 @@ func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { return nil } p := ProviderSQL{db: db, logger: logger} - p.s3Conn, err = sqlite3.Open(dbPath) - if err != nil { - logger.Error("failed to open vecdb connection", "error", err) - return nil - } + p.Migrate() return p } + +// DB returns the underlying database connection +func (p ProviderSQL) DB() *sqlx.DB { + return p.db +} diff --git a/storage/storage_test.go b/storage/storage_test.go index ff3b5e6..a4f2bdd 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,18 +1,15 @@ package storage import ( - "elefant/models" "fmt" - "log" + "gf-lt/models" "log/slog" "os" "testing" "time" - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" - "github.com/ncruces/go-sqlite3" ) func TestMemories(t *testing.T) { @@ -176,88 +173,3 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } - -func TestVecTable(t *testing.T) { - // healthcheck - db, err := sqlite3.Open(":memory:") - if err != nil { - t.Fatal(err) - } - stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) - if err != nil { - t.Fatal(err) - } - stmt.Step() - log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) - stmt.Close() - // migration - err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") - if err != nil { - t.Fatal(err) - } - // data prep and insert - items := map[int][]float32{ - 1: {0.1, 0.1, 0.1, 0.1}, - 2: {0.2, 0.2, 0.2, 0.2}, 
- 3: {0.3, 0.3, 0.3, 0.3}, - 4: {0.4, 0.4, 0.4, 0.4}, - 5: {0.5, 0.5, 0.5, 0.5}, - } - q := []float32{0.28, 0.3, 0.3, 0.3} - stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") - if err != nil { - t.Fatal(err) - } - for id, values := range items { - v, err := sqlite_vec.SerializeFloat32(values) - if err != nil { - t.Fatal(err) - } - stmt.BindInt(1, id) - stmt.BindBlob(2, v) - stmt.BindText(3, "some_chat") - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - stmt.Reset() - } - stmt.Close() - // select | vec search - stmt, _, err = db.Prepare(` - SELECT - rowid, - distance, - embedding - FROM vec_items - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `) - if err != nil { - t.Fatal(err) - } - query, err := sqlite_vec.SerializeFloat32(q) - if err != nil { - t.Fatal(err) - } - stmt.BindBlob(1, query) - for stmt.Step() { - rowid := stmt.ColumnInt64(0) - distance := stmt.ColumnFloat(1) - emb := stmt.ColumnRawText(2) - floats := decodeUnsafe(emb) - log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) - } - if err := stmt.Err(); err != nil { - t.Fatal(err) - } - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - err = db.Close() - if err != nil { - t.Fatal(err) - } -} diff --git a/storage/vector.go b/storage/vector.go index 5e9069c..32b4731 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -1,12 +1,12 @@ package storage import ( - "elefant/models" - "errors" + "encoding/binary" "fmt" + "gf-lt/models" "unsafe" - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" + "github.com/jmoiron/sqlx" ) type VectorRepo interface { @@ -14,19 +14,55 @@ type VectorRepo interface { SearchClosest(q []float32) ([]models.VectorRow, error) ListFiles() ([]string, error) RemoveEmbByFileName(filename string) error + DB() *sqlx.DB } -var ( - vecTableName5120 = "embeddings_5120" - vecTableName384 = "embeddings_384" -) +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} func fetchTableName(emb []float32) (string, error) { switch len(emb) { - case 5120: - return vecTableName5120, nil case 384: - return vecTableName384, nil + return "embeddings_384", nil + case 768: + return "embeddings_768", nil + case 1024: + return "embeddings_1024", nil + case 1536: + return "embeddings_1536", nil + case 2048: + return "embeddings_2048", nil + case 3072: + return "embeddings_3072", nil + case 4096: + return "embeddings_4096", nil + case 5120: + return "embeddings_5120", nil default: return "", fmt.Errorf("no table for the size of %d", len(emb)) } @@ -37,50 +73,13 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { if err != nil { return err } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, 
raw_text, filename) VALUES (?, ?, ?, ?)", tableName)) - if err != nil { - p.logger.Error("failed to prep a stmt", "error", err) - return err - } - defer stmt.Close() - v, err := sqlite_vec.SerializeFloat32(row.Embeddings) - if err != nil { - p.logger.Error("failed to serialize vector", - "emb-len", len(row.Embeddings), "error", err) - return err - } - if v == nil { - err = errors.New("empty vector after serialization") - p.logger.Error("empty vector after serialization", - "emb-len", len(row.Embeddings), "text", row.RawText, "error", err) - return err - } - if err := stmt.BindBlob(1, v); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(2, row.Slug); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(3, row.RawText); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(4, row.FileName); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - err = stmt.Exec() - if err != nil { - return err - } - return nil -} -func decodeUnsafe(bs []byte) []float32 { - return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) + serializedEmbeddings := SerializeVector(row.Embeddings) + + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) + + return err } func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { @@ -88,76 +87,169 @@ func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { if err != nil { return nil, err } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf(`SELECT - distance, - embedding, - slug, - raw_text, - filename - FROM %s - WHERE embedding MATCH ? 
- ORDER BY distance - LIMIT 3 - `, tableName)) + + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := p.db.Query(querySQL) if err != nil { return nil, err } - query, err := sqlite_vec.SerializeFloat32(q[:]) - if err != nil { - return nil, err + defer rows.Close() + + type SearchResult struct { + vector models.VectorRow + distance float32 } - if err := stmt.BindBlob(1, query); err != nil { - p.logger.Error("failed to bind", "error", err) - return nil, err + + var topResults []SearchResult + + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top results + topResults = append(topResults, result) + + // Sort and keep only top results + // We'll keep the top 3 closest vectors + if len(topResults) > 3 { + // Simple sort and truncate to maintain only 3 best matches + for i := 0; i < len(topResults); i++ { + for j := i + 1; j < len(topResults); j++ { + if topResults[i].distance > topResults[j].distance { + topResults[i], topResults[j] = topResults[j], topResults[i] + } + } + } + topResults = topResults[:3] + } } - resp := []models.VectorRow{} - for stmt.Step() { - res := models.VectorRow{} - res.Distance = float32(stmt.ColumnFloat(0)) - emb := stmt.ColumnRawText(1) - res.Embeddings = decodeUnsafe(emb) - res.Slug = stmt.ColumnText(2) - res.RawText = stmt.ColumnText(3) - res.FileName = stmt.ColumnText(4) - resp = append(resp, res) - } - if err := stmt.Err(); err != nil { - return nil, err + + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(topResults)) + for i, result := range topResults { + result.vector.Distance = result.distance + results[i] = result.vector } - err = stmt.Close() - if err != nil { - return nil, err + + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 } - return resp, nil + + return dotProduct / (sqrt(normA) * sqrt(normB)) } -func (p ProviderSQL) ListFiles() ([]string, error) { - q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) - if err != nil { - return nil, err +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 } - defer stmt.Close() - resp := []string{} - for stmt.Step() { - resp = append(resp, stmt.ColumnText(0)) + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 } - if err := stmt.Err(); err != nil { - return nil, err + return guess +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + 
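	// Usage sketch (hypothetical caller; only the VectorRepo interface above is
	// assumed, and the variable name repo is illustrative): with the sqlite-vec
	// virtual tables gone there is no single registry of embedded files, so every
	// fixed-dimension embeddings_* table is scanned for DISTINCT filenames and the
	// deduplicated union is returned.
	//
	//	files, err := repo.ListFiles()
	//	if err != nil {
	//		// handle the error
	//	}
	//	// files now holds each filename that has embeddings in any of the tables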
fileLists := make([][]string, 0) + + // Query all supported tables and combine results + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", + } + for _, table := range tableNames { + query := "SELECT DISTINCT filename FROM " + table + rows, err := p.db.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } } - return resp, nil + + return allFiles, nil }
func (p ProviderSQL) RemoveEmbByFileName(filename string) error { - q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) - if err != nil { - return err + var errors []string + + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", } - defer stmt.Close() - if err := stmt.BindText(1, filename); err != nil { - return err + for _, table := range tableNames { + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := p.db.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %v", errors) } - return stmt.Exec() + + return nil }
diff --git a/sysprompts/cluedo.json b/sysprompts/cluedo.json new file mode 100644 index 0000000..0c90cb5 --- /dev/null +++ b/sysprompts/cluedo.json @@ -0,0 +1,7 @@ +{
+ "sys_prompt": "A game of cluedo. Players are {{user}}, {{char}}, {{char2}};\n\nrooms: hall, lounge, dining room, kitchen, ballroom, conservatory, billiard room, library, study;\nweapons: candlestick, dagger, lead pipe, revolver, rope, spanner;\npeople: miss Scarlett, colonel Mustard, mrs. White, reverend Green, mrs. Peacock, professor Plum;\n\nA murder happened in a mansion with 9 rooms. Victim is dr. Black.\nPlayers' goal is to find out who committed the murder, in what room and with what weapon.\nWeapons, people and rooms not involved in the murder are distributed between players (as cards) by the tool agent.\nThe objective of the game is to deduce the details of the murder. There are six characters, six murder weapons, and nine rooms, leaving the players with 324 possibilities. As soon as a player enters a room, they may make a suggestion as to the details, naming a suspect, the room they are in, and the weapon. For example: \"I suspect Professor Plum, in the Dining Room, with the candlestick\".\nOnce a player makes a suggestion, the others are called upon to disprove it.\nBefore the player's move, the tool agent will remind the players of their cards. There are two types of moves: making a suggestion (suggestion_move) and disproving another player's suggestion (evidence_move);\nIn this version a player wins when the correct details are named in the suggestion_move.\n\n<example_game>\n{{user}}:\nlet's start a game of cluedo!\ntool: cards of {{char}} are 'LEAD PIPE', 'BALLROOM', 'CONSERVATORY', 'STUDY', 'Mrs. White'; suggestion_move;\n{{char}}:\n(putting miss Scarlett into the Hall with the Revolver) \"I suspect miss Scarlett, in the Hall, with the revolver.\"\ntool: cards of {{char2}} are 'SPANNER', 'DAGGER', 'Professor Plum', 'LIBRARY', 'Mrs. Peacock'; evidence_move;\n{{char2}}:\n\"No objections.\" (no cards matching the suspicion of {{char}})\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINING ROOM', 'CANDLESTICK', 'HALL'; evidence_move;\n{{user}}:\n\"I object. Miss Scarlett is innocent.\" (shows card with 'Miss Scarlett')\ntool: cards of {{char2}} are 'SPANNER', 'DAGGER', 'Professor Plum', 'LIBRARY', 'Mrs. Peacock'; suggestion_move;\n{{char2}}:\n*So it was not Miss Scarlett, good to know.*\n(moves Mrs. White to the Billiard Room) \"It might have been Mrs. White, in the Billiard Room, with the Revolver.\"\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINING ROOM', 'CANDLESTICK', 'HALL'; evidence_move;\n{{user}}:\n(no matching cards for the assumption of {{char2}}) \"Sounds possible to me.\"\ntool: cards of {{char}} are 'LEAD PIPE', 'BALLROOM', 'CONSERVATORY', 'STUDY', 'Mrs. White'; evidence_move;\n{{char}}:\n(shows Mrs. White card) \"No. It was not Mrs. White.\"\ntool: cards of {{user}} are 'Colonel Mustard', 'Miss Scarlett', 'DINING ROOM', 'CANDLESTICK', 'HALL'; suggestion_move;\n{{user}}:\n*So not Mrs. White...* (moves Reverend Green into the Billiard Room) \"I suspect Reverend Green, in the Billiard Room, with the Revolver.\"\ntool: Correct. It was Reverend Green in the Billiard Room, with the revolver. {{user}} wins.\n</example_game>", + "role": "CluedoPlayer", + "role2": "CluedoEnjoyer", + "filepath": "sysprompts/cluedo.json", + "first_msg": "Hey guys! Want to play cluedo?" +}
@@ -7,16 +7,16 @@ import ( "strings" "time" - "elefant/models" - "elefant/pngmeta" - "elefant/rag" + "gf-lt/models" + "gf-lt/pngmeta" + "gf-lt/rag" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) func makeChatTable(chatMap map[string]models.Chat) *tview.Table { - actions := []string{"load", "rename", "delete", "update card"} + actions := []string{"load", "rename", "delete", "update card", "move sysprompt onto 1st msg", "new_chat_from_card"} chatList := make([]string, len(chatMap)) i := 0 for name := range chatMap { @@ -26,20 +26,20 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table { rows, cols := len(chatMap), len(actions)+2 chatActTable := tview.NewTable(). SetBorders(true) - // for chatName, chat := range chatMap { for r := 0; r < rows; r++ { - // r := 0 for c := 0; c < cols; c++ { color := tcell.ColorWhite switch c { case 0: chatActTable.SetCell(r, c, tview.NewTableCell(chatList[r]). + SetSelectable(false). SetTextColor(color). SetAlign(tview.AlignCenter)) case 1: chatActTable.SetCell(r, c, tview.NewTableCell(chatMap[chatList[r]].Msgs[len(chatMap[chatList[r]].Msgs)-30:]). + SetSelectable(false). SetTextColor(color). 
SetAlign(tview.AlignCenter)) default: @@ -49,23 +49,18 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table { SetAlign(tview.AlignCenter)) } } - // r++ } - chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEsc || key == tcell.KeyF1 { + chatActTable.Select(0, 0).SetSelectable(true, true).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') { pages.RemovePage(historyPage) return } - if key == tcell.KeyEnter { - chatActTable.SetSelectable(true, true) - } }).SetSelectedFunc(func(row int, column int) { tc := chatActTable.GetCell(row, column) tc.SetTextColor(tcell.ColorRed) chatActTable.SetSelectable(false, false) selectedChat := chatList[row] defer pages.RemovePage(historyPage) - // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text) switch tc.Text { case "load": history, err := loadHistoryChat(selectedChat) @@ -114,12 +109,12 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table { } return } - if chatBody.Messages[0].Role != "system" || chatBody.Messages[1].Role != agentName { - if err := notifyUser("error", "unexpected chat structure; card: "+agentName); err != nil { - logger.Warn("failed ot notify", "error", err) - } - return - } + // if chatBody.Messages[0].Role != "system" || chatBody.Messages[1].Role != agentName { + // if err := notifyUser("error", "unexpected chat structure; card: "+agentName); err != nil { + // logger.Warn("failed ot notify", "error", err) + // } + // return + // } // change sys_prompt + first msg cc.SysPrompt = chatBody.Messages[0].Content cc.FirstMsg = chatBody.Messages[1].Content @@ -128,14 +123,60 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table { "error", err) } return + case "move sysprompt onto 1st msg": + chatBody.Messages[1].Content = chatBody.Messages[0].Content + chatBody.Messages[1].Content + chatBody.Messages[0].Content = rpDefenitionSysMsg + textView.SetText(chatToText(cfg.ShowSys)) + activeChatName = selectedChat + pages.RemovePage(historyPage) + return + case "new_chat_from_card": + // Reread card from file and start fresh chat + fi := strings.Index(selectedChat, "_") + agentName := selectedChat[fi+1:] + cc, ok := sysMap[agentName] + if !ok { + logger.Warn("no such card", "agent", agentName) + if err := notifyUser("error", "no such card: "+agentName); err != nil { + logger.Warn("failed to notify", "error", err) + } + return + } + // Reload card from disk + newCard, err := pngmeta.ReadCard(cc.FilePath, cfg.UserRole) + if err != nil { + logger.Error("failed to reload charcard", "path", cc.FilePath, "error", err) + newCard, err = pngmeta.ReadCardJson(cc.FilePath) + if err != nil { + logger.Error("failed to reload charcard", "path", cc.FilePath, "error", err) + if err := notifyUser("error", "failed to reload card: "+cc.FilePath); err != nil { + logger.Warn("failed to notify", "error", err) + } + return + } + } + // Update sysMap with fresh card data + sysMap[agentName] = newCard + applyCharCard(newCard) + startNewChat() + pages.RemovePage(historyPage) + return default: return } }) + // Add input capture to handle 'x' key for closing the table + chatActTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(historyPage) + return nil + } + return event + }) return chatActTable } -// func makeRAGTable(fileList []string) *tview.Table { +// nolint:unused func makeRAGTable(fileList []string) 
*tview.Flex { actions := []string{"load", "delete"} rows, cols := len(fileList), len(actions)+1 @@ -150,23 +191,46 @@ func makeRAGTable(fileList []string) *tview.Flex { ragflex := tview.NewFlex().SetDirection(tview.FlexRow). AddItem(longStatusView, 0, 10, false). AddItem(fileTable, 0, 60, true) + // Add the exit option as the first row (row 0) + fileTable.SetCell(0, 0, + tview.NewTableCell("Exit RAG manager"). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + fileTable.SetCell(0, 1, + tview.NewTableCell("(Close without action)"). + SetTextColor(tcell.ColorGray). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + fileTable.SetCell(0, 2, + tview.NewTableCell("exit"). + SetTextColor(tcell.ColorGray). + SetAlign(tview.AlignCenter)) + // Add the file rows starting from row 1 for r := 0; r < rows; r++ { for c := 0; c < cols; c++ { color := tcell.ColorWhite if c < 1 { - fileTable.SetCell(r, c, + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 tview.NewTableCell(fileList[r]). SetTextColor(color). - SetAlign(tview.AlignCenter)) - } else { - fileTable.SetCell(r, c, + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } else if c == 1 { // Action description column - not selectable + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 + tview.NewTableCell("(Action)"). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } else { // Action button column - selectable + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 tview.NewTableCell(actions[c-1]). SetTextColor(color). SetAlign(tview.AlignCenter)) } } } - errCh := make(chan error, 1) + errCh := make(chan error, 1) // why? go func() { defer pages.RemovePage(RAGPage) for { @@ -192,20 +256,32 @@ func makeRAGTable(fileList []string) *tview.Flex { } } }() - fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEsc || key == tcell.KeyF1 { - pages.RemovePage(RAGPage) + fileTable.Select(0, 0). + SetFixed(1, 1). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)). 
+ SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX { + pages.RemovePage(RAGPage) + return + } + }).SetSelectedFunc(func(row int, column int) { + // If user selects a non-actionable column (0 or 1), move to first action column (2) + if column <= 1 { + if fileTable.GetColumnCount() > 2 { + fileTable.Select(row, 2) // Select first action column + } return } - if key == tcell.KeyEnter { - fileTable.SetSelectable(true, true) - } - }).SetSelectedFunc(func(row int, column int) { // defer pages.RemovePage(RAGPage) tc := fileTable.GetCell(row, column) - tc.SetTextColor(tcell.ColorRed) - fileTable.SetSelectable(false, false) - fpath := fileList[row] + // Check if the selected row is the exit row (row 0) - do this first to avoid index issues + if row == 0 { + pages.RemovePage(RAGPage) + return + } + // For file rows, get the filename (row index - 1 because of the exit row at index 0) + fpath := fileList[row-1] // -1 to account for the exit row at index 0 // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text) switch tc.Text { case "load": @@ -214,6 +290,7 @@ func makeRAGTable(fileList []string) *tview.Flex { go func() { if err := ragger.LoadRAG(fpath); err != nil { logger.Error("failed to embed file", "chat", fpath, "error", err) + _ = notifyUser("RAG", "failed to embed file; error: "+err.Error()) errCh <- err // pages.RemovePage(RAGPage) return @@ -231,68 +308,132 @@ func makeRAGTable(fileList []string) *tview.Flex { } return default: + pages.RemovePage(RAGPage) return } }) + // Add input capture to the flex container to handle 'x' key for closing + ragflex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(RAGPage) + return nil + } + return event + }) return ragflex } -func makeLoadedRAGTable(fileList []string) *tview.Table { +func makeLoadedRAGTable(fileList []string) *tview.Flex { actions := []string{"delete"} rows, cols := len(fileList), len(actions)+1 + // Add 1 extra row for the "exit" option at the top fileTable := tview.NewTable(). SetBorders(true) + longStatusView := tview.NewTextView() + longStatusView.SetText("Loaded RAG files list") + longStatusView.SetBorder(true).SetTitle("status") + longStatusView.SetChangedFunc(func() { + app.Draw() + }) + ragflex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(longStatusView, 0, 10, false). + AddItem(fileTable, 0, 60, true) + // Add the exit option as the first row (row 0) + fileTable.SetCell(0, 0, + tview.NewTableCell("Exit Loaded Files manager"). + SetTextColor(tcell.ColorWhite). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + fileTable.SetCell(0, 1, + tview.NewTableCell("(Close without action)"). + SetTextColor(tcell.ColorGray). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + fileTable.SetCell(0, 2, + tview.NewTableCell("exit"). + SetTextColor(tcell.ColorGray). + SetAlign(tview.AlignCenter)) + // Add the file rows starting from row 1 for r := 0; r < rows; r++ { for c := 0; c < cols; c++ { color := tcell.ColorWhite if c < 1 { - fileTable.SetCell(r, c, + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 tview.NewTableCell(fileList[r]). SetTextColor(color). - SetAlign(tview.AlignCenter)) - } else { - fileTable.SetCell(r, c, + SetAlign(tview.AlignCenter). 
+ SetSelectable(false)) + } else if c == 1 { // Action description column - not selectable + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 + tview.NewTableCell("(Action)"). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } else { // Action button column - selectable + fileTable.SetCell(r+1, c, // +1 to account for the exit row at index 0 tview.NewTableCell(actions[c-1]). SetTextColor(color). SetAlign(tview.AlignCenter)) } } } - fileTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEsc || key == tcell.KeyF1 { - pages.RemovePage(RAGPage) + fileTable.Select(0, 0). + SetFixed(1, 1). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)). + SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') || key == tcell.KeyCtrlX { + pages.RemovePage(RAGLoadedPage) + return + } + }).SetSelectedFunc(func(row int, column int) { + // If user selects a non-actionable column (0 or 1), move to first action column (2) + if column <= 1 { + if fileTable.GetColumnCount() > 2 { + fileTable.Select(row, 2) // Select first action column + } return } - if key == tcell.KeyEnter { - fileTable.SetSelectable(true, true) - } - }).SetSelectedFunc(func(row int, column int) { - defer pages.RemovePage(RAGPage) + tc := fileTable.GetCell(row, column) - tc.SetTextColor(tcell.ColorRed) - fileTable.SetSelectable(false, false) - fpath := fileList[row] - // notification := fmt.Sprintf("chat: %s; action: %s", fpath, tc.Text) + + // Check if the selected row is the exit row (row 0) - do this first to avoid index issues + if row == 0 { + pages.RemovePage(RAGLoadedPage) + return + } + // For file rows, get the filename (row index - 1 because of the exit row at index 0) + fpath := fileList[row-1] // -1 to account for the exit row at index 0 switch tc.Text { case "delete": if err := ragger.RemoveFile(fpath); err != nil { - logger.Error("failed to delete file", "filename", fpath, "error", err) + logger.Error("failed to delete file from RAG", "filename", fpath, "error", err) + longStatusView.SetText(fmt.Sprintf("Error deleting file: %v", err)) return } - if err := notifyUser("chat deleted", fpath+" was deleted"); err != nil { + if err := notifyUser("RAG file deleted", fpath+" was deleted from RAG system"); err != nil { logger.Error("failed to send notification", "error", err) } + longStatusView.SetText(fpath + " was deleted from RAG system") return default: - // pages.RemovePage(RAGPage) + pages.RemovePage(RAGLoadedPage) return } }) - return fileTable + // Add input capture to the flex container to handle 'x' key for closing + ragflex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(RAGLoadedPage) + return nil + } + return event + }) + return ragflex } func makeAgentTable(agentList []string) *tview.Table { - actions := []string{"load"} + actions := []string{"filepath", "load"} rows, cols := len(agentList), len(actions)+1 chatActTable := tview.NewTable(). SetBorders(true) @@ -303,6 +444,24 @@ func makeAgentTable(agentList []string) *tview.Table { chatActTable.SetCell(r, c, tview.NewTableCell(agentList[r]). SetTextColor(color). + SetAlign(tview.AlignCenter). 
+ SetSelectable(false)) + } else if c == 1 { + if actions[c-1] == "filepath" { + cc, ok := sysMap[agentList[r]] + if !ok { + continue + } + chatActTable.SetCell(r, c, + tview.NewTableCell(cc.FilePath). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + continue + } + chatActTable.SetCell(r, c, + tview.NewTableCell(actions[c-1]). + SetTextColor(color). SetAlign(tview.AlignCenter)) } else { chatActTable.SetCell(r, c, @@ -312,18 +471,25 @@ func makeAgentTable(agentList []string) *tview.Table { } } } - chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEsc || key == tcell.KeyF1 { - pages.RemovePage(agentPage) + chatActTable.Select(0, 0). + SetFixed(1, 1). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)). + SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') { + pages.RemovePage(agentPage) + return + } + }).SetSelectedFunc(func(row int, column int) { + // If user selects a non-actionable column (0 or 1), move to first action column (2) + if column <= 1 { + if chatActTable.GetColumnCount() > 2 { + chatActTable.Select(row, 2) // Select first action column + } return } - if key == tcell.KeyEnter { - chatActTable.SetSelectable(true, true) - } - }).SetSelectedFunc(func(row int, column int) { + tc := chatActTable.GetCell(row, column) - tc.SetTextColor(tcell.ColorRed) - chatActTable.SetSelectable(false, false) selected := agentList[row] // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text) switch tc.Text { @@ -365,6 +531,14 @@ func makeAgentTable(agentList []string) *tview.Table { return } }) + // Add input capture to handle 'x' key for closing the table + chatActTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(agentPage) + return nil + } + return event + }) return chatActTable } @@ -384,7 +558,8 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table { table.SetCell(r, c, tview.NewTableCell(codeBlocks[r][:previewLen]). SetTextColor(color). - SetAlign(tview.AlignCenter)) + SetAlign(tview.AlignCenter). + SetSelectable(false)) } else { table.SetCell(r, c, tview.NewTableCell(actions[c-1]). @@ -393,18 +568,25 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table { } } } - table.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { - if key == tcell.KeyEsc || key == tcell.KeyF1 { - pages.RemovePage(agentPage) + table.Select(0, 0). + SetFixed(1, 1). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)). 
+ SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') { + pages.RemovePage(codeBlockPage) + return + } + }).SetSelectedFunc(func(row int, column int) { + // If user selects a non-actionable column (0), move to first action column (1) + if column == 0 { + if table.GetColumnCount() > 1 { + table.Select(row, 1) // Select first action column + } return } - if key == tcell.KeyEnter { - table.SetSelectable(true, true) - } - }).SetSelectedFunc(func(row int, column int) { + tc := table.GetCell(row, column) - tc.SetTextColor(tcell.ColorRed) - table.SetSelectable(false, false) selected := codeBlocks[row] // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text) switch tc.Text { @@ -425,5 +607,333 @@ func makeCodeBlockTable(codeBlocks []string) *tview.Table { return } }) + // Add input capture to handle 'x' key for closing the table + table.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(codeBlockPage) + return nil + } + return event + }) return table } + +func makeImportChatTable(filenames []string) *tview.Table { + actions := []string{"load"} + rows, cols := len(filenames), len(actions)+1 + chatActTable := tview.NewTable(). + SetBorders(true) + for r := 0; r < rows; r++ { + for c := 0; c < cols; c++ { + color := tcell.ColorWhite + if c < 1 { + chatActTable.SetCell(r, c, + tview.NewTableCell(filenames[r]). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetSelectable(false)) + } else { + chatActTable.SetCell(r, c, + tview.NewTableCell(actions[c-1]). + SetTextColor(color). + SetAlign(tview.AlignCenter)) + } + } + } + chatActTable.Select(0, 0). + SetFixed(1, 1). + SetSelectable(true, false). + SetSelectedStyle(tcell.StyleDefault.Background(tcell.ColorGray).Foreground(tcell.ColorWhite)). 
+ SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') { + pages.RemovePage(historyPage) + return + } + }).SetSelectedFunc(func(row int, column int) { + // If user selects a non-actionable column (0), move to first action column (1) + if column == 0 { + if chatActTable.GetColumnCount() > 1 { + chatActTable.Select(row, 1) // Select first action column + } + return + } + + tc := chatActTable.GetCell(row, column) + selected := filenames[row] + // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text) + switch tc.Text { + case "load": + if err := importChat(selected); err != nil { + logger.Warn("failed to import chat", "filename", selected) + pages.RemovePage(historyPage) + return + } + colorText() + updateStatusLine() + // redraw the text in text area + textView.SetText(chatToText(cfg.ShowSys)) + pages.RemovePage(historyPage) + app.SetFocus(textArea) + return + case "rename": + pages.RemovePage(historyPage) + pages.AddPage(renamePage, renameWindow, true, true) + return + case "delete": + sc, ok := chatMap[selected] + if !ok { + // no chat found + pages.RemovePage(historyPage) + return + } + if err := store.RemoveChat(sc.ID); err != nil { + logger.Error("failed to remove chat from db", "chat_id", sc.ID, "chat_name", sc.Name) + } + if err := notifyUser("chat deleted", selected+" was deleted"); err != nil { + logger.Error("failed to send notification", "error", err) + } + pages.RemovePage(historyPage) + return + default: + pages.RemovePage(historyPage) + return + } + }) + // Add input capture to handle 'x' key for closing the table + chatActTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(historyPage) + return nil + } + return event + }) + return chatActTable +} + +func makeFilePicker() *tview.Flex { + // Initialize with directory from config or current directory + startDir := cfg.FilePickerDir + if startDir == "" { + startDir = "." + } + // If startDir is ".", resolve it to the actual current working directory + if startDir == "." 
{ + wd, err := os.Getwd() + if err == nil { + startDir = wd + } + } + // Track navigation history + dirStack := []string{startDir} + currentStackPos := 0 + // Track selected file + var selectedFile string + // Track currently displayed directory (changes as user navigates) + currentDisplayDir := startDir + // Helper function to check if a file has an allowed extension from config + hasAllowedExtension := func(filename string) bool { + // If no allowed extensions are specified in config, allow all files + if cfg.FilePickerExts == "" { + return true + } + // Split the allowed extensions from the config string + allowedExts := strings.Split(cfg.FilePickerExts, ",") + lowerFilename := strings.ToLower(strings.TrimSpace(filename)) + for _, ext := range allowedExts { + ext = strings.TrimSpace(ext) // Remove any whitespace around the extension + if ext != "" && strings.HasSuffix(lowerFilename, "."+ext) { + return true + } + } + return false + } + // Helper function to check if a file is an image + isImageFile := func(filename string) bool { + imageExtensions := []string{".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".svg"} + lowerFilename := strings.ToLower(filename) + for _, ext := range imageExtensions { + if strings.HasSuffix(lowerFilename, ext) { + return true + } + } + return false + } + // Create UI elements + listView := tview.NewList() + listView.SetBorder(true).SetTitle("Files & Directories").SetTitleAlign(tview.AlignLeft) + // Status view for selected file information + statusView := tview.NewTextView() + statusView.SetBorder(true).SetTitle("Selected File").SetTitleAlign(tview.AlignLeft) + statusView.SetTextColor(tcell.ColorYellow) + // Layout - only include list view and status view + flex := tview.NewFlex().SetDirection(tview.FlexRow) + flex.AddItem(listView, 0, 3, true) + flex.AddItem(statusView, 3, 0, false) + // Refresh the file list + var refreshList func(string) + refreshList = func(dir string) { + listView.Clear() + // Update the current display directory + currentDisplayDir = dir // Update the current display directory + // Add exit option at the top + listView.AddItem("Exit file picker [gray](Close without selecting)[-]", "", 'x', func() { + pages.RemovePage(filePickerPage) + }) + // Add parent directory (..) 
if not at root + if dir != "/" { + parentDir := path.Dir(dir) + // Special handling for edge cases - only return if we're truly at a system root + // For Unix-like systems, path.Dir("/") returns "/" which would cause parentDir == dir + if parentDir == dir && dir == "/" { + // We're at the root ("/") and trying to go up, just don't add the parent item + } else { + listView.AddItem("../ [gray](Parent Directory)[-]", "", 'p', func() { + refreshList(parentDir) + dirStack = append(dirStack, parentDir) + currentStackPos = len(dirStack) - 1 + }) + } + } + // Read directory contents + files, err := os.ReadDir(dir) + if err != nil { + statusView.SetText("Error reading directory: " + err.Error()) + return + } + // Add directories and files to the list + for _, file := range files { + name := file.Name() + // Skip hidden files and directories (those starting with a dot) + if strings.HasPrefix(name, ".") { + continue + } + if file.IsDir() { + // Capture the directory name for the closure to avoid loop variable issues + dirName := name + listView.AddItem(dirName+"/ [gray](Directory)[-]", "", 0, func() { + newDir := path.Join(dir, dirName) + refreshList(newDir) + dirStack = append(dirStack, newDir) + currentStackPos = len(dirStack) - 1 + statusView.SetText("Current: " + newDir) + }) + } else { + // Only show files that have allowed extensions (from config) + if hasAllowedExtension(name) { + // Capture the file name for the closure to avoid loop variable issues + fileName := name + fullFilePath := path.Join(dir, fileName) + listView.AddItem(fileName+" [gray](File)[-]", "", 0, func() { + selectedFile = fullFilePath + statusView.SetText("Selected: " + selectedFile) + // Check if the file is an image + if isImageFile(fileName) { + // For image files, offer to attach to the next LLM message + statusView.SetText("Selected image: " + selectedFile) + } else { + // For non-image files, display as before + statusView.SetText("Selected: " + selectedFile) + } + }) + } + } + } + statusView.SetText("Current: " + dir) + } + // Initialize the file list + refreshList(startDir) + // Set up keyboard navigation + flex.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + switch event.Key() { + case tcell.KeyEsc: + pages.RemovePage(filePickerPage) + return nil + case tcell.KeyBackspace2: // Backspace to go to parent directory + if currentStackPos > 0 { + currentStackPos-- + prevDir := dirStack[currentStackPos] + refreshList(prevDir) + // Trim the stack to current position to avoid deep history + dirStack = dirStack[:currentStackPos+1] + } + return nil + case tcell.KeyEnter: + // Get the currently highlighted item in the list + itemIndex := listView.GetCurrentItem() + if itemIndex >= 0 && itemIndex < listView.GetItemCount() { + // We need to get the text of the currently selected item to determine if it's a directory + // Since we can't directly get the item text, we'll keep track of items differently + // Let's improve the approach by tracking the currently selected item + itemText, _ := listView.GetItemText(itemIndex) + logger.Info("choosing dir", "itemText", itemText) + // Check for the exit option first (should be the first item) + if strings.HasPrefix(itemText, "Exit file picker") { + pages.RemovePage(filePickerPage) + return nil + } + // Extract the actual filename/directory name by removing the type info in brackets + // Format is "name [gray](type)[-]" + actualItemName := itemText + if bracketPos := strings.Index(itemText, " ["); bracketPos != -1 { + actualItemName = itemText[:bracketPos] + } + // Check if 
it's a directory (ends with /) + if strings.HasSuffix(actualItemName, "/") { + // This is a directory, we need to get the full path + // Since the item text ends with "/" and represents a directory + var targetDir string + if strings.HasPrefix(actualItemName, "../") { + // Parent directory - need to go up from current directory + targetDir = path.Dir(currentDisplayDir) + // Avoid going above root - if parent is same as current and it's system root + if targetDir == currentDisplayDir && currentDisplayDir == "/" { + // We're at root, don't navigate + logger.Warn("went to root", "dir", targetDir) + return nil + } + } else { + // Regular subdirectory + dirName := strings.TrimSuffix(actualItemName, "/") + targetDir = path.Join(currentDisplayDir, dirName) + } + // Navigate to the selected directory + logger.Info("going to the dir", "dir", targetDir) + refreshList(targetDir) + dirStack = append(dirStack, targetDir) + currentStackPos = len(dirStack) - 1 + statusView.SetText("Current: " + targetDir) + return nil + } else { + // It's a file - construct the full path from current directory and the actual item name + // We can't rely only on the selectedFile variable since Enter key might be pressed + // without having clicked the file first + filePath := path.Join(currentDisplayDir, actualItemName) + // Verify it's actually a file (not just lacking a directory suffix) + if info, err := os.Stat(filePath); err == nil && !info.IsDir() { + // Check if the file is an image + if isImageFile(actualItemName) { + // For image files, set it as an attachment for the next LLM message + // Use the version without UI updates to avoid hangs in event handlers + logger.Info("setting image", "file", actualItemName) + SetImageAttachment(filePath) + logger.Info("after setting image", "file", actualItemName) + statusView.SetText("Image attached: " + filePath + " (will be sent with next message)") + logger.Info("after setting text", "file", actualItemName) + pages.RemovePage(filePickerPage) + logger.Info("after update drawn", "file", actualItemName) + } else { + // For non-image files, update the text area with file path + textArea.SetText(filePath, true) + app.SetFocus(textArea) + pages.RemovePage(filePickerPage) + } + } + return nil + } + } + return nil + } + return event + }) + return flex +} @@ -1,19 +1,32 @@ package main import ( - "elefant/models" + "context" + "encoding/json" "fmt" + "gf-lt/extra" + "gf-lt/models" + "io" + "os" + "os/exec" "regexp" + "strconv" "strings" "time" ) var ( - toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`) - quotesRE = regexp.MustCompile(`(".*?")`) - starRE = regexp.MustCompile(`(\*.*?\*)`) - thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`) - codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`) + toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`) + quotesRE = regexp.MustCompile(`(".*?")`) + starRE = regexp.MustCompile(`(\*.*?\*)`) + thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`) + codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`) + roleRE = regexp.MustCompile(`^(\w+):`) + rpDefenitionSysMsg = ` +For this roleplay immersion is at most importance. +Every character thinks and acts based on their personality and setting of the roleplay. +Meta discussions outside of roleplay is allowed if clearly labeled as out of character, for example: (ooc: {msg}) or <ooc>{msg}</ooc>. 
+` basicSysMsg = `Large Language Model that helps user with any of his requests.` toolSysMsg = `You can do functions call if needed. Your current tools: @@ -21,18 +34,68 @@ Your current tools: [ { "name":"recall", -"args": "topic", +"args": ["topic"], "when_to_use": "when asked about topic that user previously asked to memorise" }, { "name":"memorise", -"args": ["topic", "info"], -"when_to_use": "when asked to memorise something" +"args": ["topic", "data"], +"when_to_use": "when asked to memorise information under a topic" }, { "name":"recall_topics", -"args": null, +"args": [], "when_to_use": "to see what topics are saved in memory" +}, +{ +"name":"websearch", +"args": ["query", "limit"], +"when_to_use": "when asked to search the web for information; limit is optional (default 3)" +}, +{ +"name":"read_url", +"args": ["url"], +"when_to_use": "when asked to get content for spicific webpage or url" +}, +{ +"name":"file_create", +"args": ["path", "content"], +"when_to_use": "when asked to create a new file with optional content" +}, +{ +"name":"file_read", +"args": ["path"], +"when_to_use": "when asked to read the content of a file" +}, +{ +"name":"file_write", +"args": ["path", "content", "mode"], +"when_to_use": "when asked to write content to a file; mode is optional (overwrite or append, default: overwrite)" +}, +{ +"name":"file_delete", +"args": ["path"], +"when_to_use": "when asked to delete a file" +}, +{ +"name":"file_move", +"args": ["src", "dst"], +"when_to_use": "when asked to move a file from source to destination" +}, +{ +"name":"file_copy", +"args": ["src", "dst"], +"when_to_use": "when asked to copy a file from source to destination" +}, +{ +"name":"file_list", +"args": ["path"], +"when_to_use": "when asked to list files in a directory; path is optional (default: current directory)" +}, +{ +"name":"execute_command", +"args": ["command", "args"], +"when_to_use": "when asked to execute a system command; args is optional; allowed commands: grep, sed, awk, find, cat, head, tail, sort, uniq, wc, ls, echo, cut, tr, cp, mv, rm, mkdir, rmdir, pwd, df, free, ps, top, du, whoami, date, uname" } ] </tools> @@ -41,7 +104,15 @@ To make a function call return a json object within __tool_call__ tags; __tool_call__ { "name":"recall", -"args": "Adam's number" +"args": {"topic": "Adam's number"} +} +__tool_call__ +</example_request> +<example_request> +__tool_call__ +{ +"name":"execute_command", +"args": {"command": "ls", "args": "-la /home"} } __tool_call__ </example_request> @@ -60,17 +131,75 @@ After that you are free to respond to the user. 
Role: "", FilePath: "", } - toolCard = &models.CharCard{ - SysPrompt: toolSysMsg, - FirstMsg: defaultFirstMsg, - Role: "", - FilePath: "", - } + // toolCard = &models.CharCard{ + // SysPrompt: toolSysMsg, + // FirstMsg: defaultFirstMsg, + // Role: "", + // FilePath: "", + // } // sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg} - sysMap = map[string]*models.CharCard{"basic_sys": basicCard, "tool_sys": toolCard} - sysLabels = []string{"basic_sys", "tool_sys"} + sysMap = map[string]*models.CharCard{"basic_sys": basicCard} + sysLabels = []string{"basic_sys"} ) +// web search (depends on extra server) +func websearch(args map[string]string) []byte { + // make http request return bytes + query, ok := args["query"] + if !ok || query == "" { + msg := "query not provided to web_search tool" + logger.Error(msg) + return []byte(msg) + } + limitS, ok := args["limit"] + if !ok || limitS == "" { + limitS = "3" + } + limit, err := strconv.Atoi(limitS) + if err != nil || limit == 0 { + logger.Warn("websearch limit; passed bad value; setting to default (3)", + "limit_arg", limitS, "error", err) + limit = 3 + } + resp, err := extra.WebSearcher.Search(context.Background(), query, limit) + if err != nil { + msg := "search tool failed; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + data, err := json.Marshal(resp) + if err != nil { + msg := "failed to marshal search result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return data +} + +// retrieves url content (text) +func readURL(args map[string]string) []byte { + // make http request return bytes + link, ok := args["url"] + if !ok || link == "" { + msg := "link not provided to read_url tool" + logger.Error(msg) + return []byte(msg) + } + resp, err := extra.WebSearcher.RetrieveFromLink(context.Background(), link) + if err != nil { + msg := "search tool failed; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + data, err := json.Marshal(resp) + if err != nil { + msg := "failed to marshal search result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return data +} + /* consider cases: - append mode (treat it like a journal appendix) @@ -79,7 +208,7 @@ also: - some writing can be done without consideration of previous data; - others do; */ -func memorise(args ...string) []byte { +func memorise(args map[string]string) []byte { agent := cfg.AssistantRole if len(args) < 2 { msg := "not enough args to call memorise tool; need topic and data to remember" @@ -88,35 +217,35 @@ func memorise(args ...string) []byte { } memory := &models.Memory{ Agent: agent, - Topic: args[0], - Mind: args[1], + Topic: args["topic"], + Mind: args["data"], UpdatedAt: time.Now(), } if _, err := store.Memorise(memory); err != nil { logger.Error("failed to save memory", "err", err, "memoory", memory) return []byte("failed to save info") } - msg := "info saved under the topic:" + args[0] + msg := "info saved under the topic:" + args["topic"] return []byte(msg) } -func recall(args ...string) []byte { +func recall(args map[string]string) []byte { agent := cfg.AssistantRole if len(args) < 1 { logger.Warn("not enough args to call recall tool") return nil } - mind, err := store.Recall(agent, args[0]) + mind, err := store.Recall(agent, args["topic"]) if err != nil { msg := fmt.Sprintf("failed to recall; error: %v; args: %v", err, args) logger.Error(msg) return []byte(msg) } - answer := fmt.Sprintf("under the topic: %s is stored:\n%s", args[0], mind) + answer := fmt.Sprintf("under the 
topic: %s is stored:\n%s", args["topic"], mind) return []byte(answer) } -func recallTopics(args ...string) []byte { +func recallTopics(args map[string]string) []byte { agent := cfg.AssistantRole topics, err := store.RecallTopics(agent) if err != nil { @@ -127,12 +256,947 @@ func recallTopics(args ...string) []byte { return []byte(joinedS) } -// func fullMemoryLoad() {} +// File Manipulation Tools + +func fileCreate(args map[string]string) []byte { + path, ok := args["path"] + if !ok || path == "" { + msg := "path not provided to file_create tool" + logger.Error(msg) + return []byte(msg) + } + + content, ok := args["content"] + if !ok { + content = "" + } + + if err := writeStringToFile(path, content); err != nil { + msg := "failed to create file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + msg := "file created successfully at " + path + return []byte(msg) +} + +func fileRead(args map[string]string) []byte { + path, ok := args["path"] + if !ok || path == "" { + msg := "path not provided to file_read tool" + logger.Error(msg) + return []byte(msg) + } + + content, err := readStringFromFile(path) + if err != nil { + msg := "failed to read file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + result := map[string]string{ + "content": content, + "path": path, + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult +} + +func fileWrite(args map[string]string) []byte { + path, ok := args["path"] + if !ok || path == "" { + msg := "path not provided to file_write tool" + logger.Error(msg) + return []byte(msg) + } + + content, ok := args["content"] + if !ok { + content = "" + } + + mode, ok := args["mode"] + if !ok || mode == "" { + mode = "overwrite" + } + + switch mode { + case "overwrite": + if err := writeStringToFile(path, content); err != nil { + msg := "failed to write to file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + case "append": + if err := appendStringToFile(path, content); err != nil { + msg := "failed to append to file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + default: + msg := "invalid mode; use 'overwrite' or 'append'" + logger.Error(msg) + return []byte(msg) + } + + msg := "file written successfully at " + path + return []byte(msg) +} + +func fileDelete(args map[string]string) []byte { + path, ok := args["path"] + if !ok || path == "" { + msg := "path not provided to file_delete tool" + logger.Error(msg) + return []byte(msg) + } + + if err := removeFile(path); err != nil { + msg := "failed to delete file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + msg := "file deleted successfully at " + path + return []byte(msg) +} + +func fileMove(args map[string]string) []byte { + src, ok := args["src"] + if !ok || src == "" { + msg := "source path not provided to file_move tool" + logger.Error(msg) + return []byte(msg) + } + + dst, ok := args["dst"] + if !ok || dst == "" { + msg := "destination path not provided to file_move tool" + logger.Error(msg) + return []byte(msg) + } + + if err := moveFile(src, dst); err != nil { + msg := "failed to move file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + msg := fmt.Sprintf("file moved successfully from %s to %s", src, dst) + return []byte(msg) +} + +func fileCopy(args map[string]string) []byte { + src, ok := args["src"] + if !ok || src == "" { + msg := "source 
path not provided to file_copy tool" + logger.Error(msg) + return []byte(msg) + } + + dst, ok := args["dst"] + if !ok || dst == "" { + msg := "destination path not provided to file_copy tool" + logger.Error(msg) + return []byte(msg) + } + + if err := copyFile(src, dst); err != nil { + msg := "failed to copy file; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + msg := fmt.Sprintf("file copied successfully from %s to %s", src, dst) + return []byte(msg) +} + +func fileList(args map[string]string) []byte { + path, ok := args["path"] + if !ok || path == "" { + path = "." // default to current directory + } + + files, err := listDirectory(path) + if err != nil { + msg := "failed to list directory; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + result := map[string]interface{}{ + "directory": path, + "files": files, + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult +} + +// Helper functions for file operations + +func readStringFromFile(filename string) (string, error) { + data, err := os.ReadFile(filename) + if err != nil { + return "", err + } + return string(data), nil +} + +func writeStringToFile(filename string, data string) error { + return os.WriteFile(filename, []byte(data), 0644) +} + +func appendStringToFile(filename string, data string) error { + file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer file.Close() + + _, err = file.WriteString(data) + return err +} + +func removeFile(filename string) error { + return os.Remove(filename) +} + +func moveFile(src, dst string) error { + // First try with os.Rename (works within same filesystem) + if err := os.Rename(src, dst); err == nil { + return nil + } + // If that fails (e.g., cross-filesystem), copy and delete + return copyAndRemove(src, dst) +} + +func copyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.Create(dst) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.Copy(dstFile, srcFile) + return err +} + +func copyAndRemove(src, dst string) error { + // Copy the file + if err := copyFile(src, dst); err != nil { + return err + } + // Remove the source file + return os.Remove(src) +} + +func listDirectory(path string) ([]string, error) { + entries, err := os.ReadDir(path) + if err != nil { + return nil, err + } + + var files []string + for _, entry := range entries { + if entry.IsDir() { + files = append(files, entry.Name()+"/") // Add "/" to indicate directory + } else { + files = append(files, entry.Name()) + } + } + + return files, nil +} + +// Command Execution Tool + +func executeCommand(args map[string]string) []byte { + command, ok := args["command"] + if !ok || command == "" { + msg := "command not provided to execute_command tool" + logger.Error(msg) + return []byte(msg) + } + + if !isCommandAllowed(command) { + msg := fmt.Sprintf("command '%s' is not allowed", command) + logger.Error(msg) + return []byte(msg) + } + + // Get arguments - handle both single arg and multiple args + var cmdArgs []string + if args["args"] != "" { + // If args is provided as a single string, split by spaces + cmdArgs = strings.Fields(args["args"]) + } else { + // If individual args are provided, collect them + argNum := 1 + for { + argKey := fmt.Sprintf("arg%d", argNum) + if 
argValue, exists := args[argKey]; exists && argValue != "" { + cmdArgs = append(cmdArgs, argValue) + } else { + break + } + argNum++ + } + } + + // Execute with timeout for safety + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, command, cmdArgs...) + + output, err := cmd.CombinedOutput() + if err != nil { + msg := fmt.Sprintf("command '%s' failed; error: %v; output: %s", command, err, string(output)) + logger.Error(msg) + return []byte(msg) + } + + return output +} + +// Helper functions for command execution + +// Todo structure +type TodoItem struct { + ID string `json:"id"` + Task string `json:"task"` + Status string `json:"status"` // "pending", "in_progress", "completed" +} + +type TodoList struct { + Items []TodoItem `json:"items"` +} + +// Global todo list storage +var globalTodoList = TodoList{ + Items: []TodoItem{}, +} + + +// Todo Management Tools +func todoCreate(args map[string]string) []byte { + task, ok := args["task"] + if !ok || task == "" { + msg := "task not provided to todo_create tool" + logger.Error(msg) + return []byte(msg) + } + + // Generate simple ID + id := fmt.Sprintf("todo_%d", len(globalTodoList.Items)+1) + + newItem := TodoItem{ + ID: id, + Task: task, + Status: "pending", + } + + globalTodoList.Items = append(globalTodoList.Items, newItem) + + result := map[string]string{ + "message": "todo created successfully", + "id": id, + "task": task, + "status": "pending", + } + + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult +} + +func todoRead(args map[string]string) []byte { + id, ok := args["id"] + if ok && id != "" { + // Find specific todo by ID + for _, item := range globalTodoList.Items { + if item.ID == id { + result := map[string]interface{}{ + "todo": item, + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return jsonResult + } + } + // ID not found + result := map[string]string{ + "error": fmt.Sprintf("todo with id %s not found", id), + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return jsonResult + } + + // Return all todos if no ID specified + result := map[string]interface{}{ + "todos": globalTodoList.Items, + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult +} + +func todoUpdate(args map[string]string) []byte { + id, ok := args["id"] + if !ok || id == "" { + msg := "id not provided to todo_update tool" + logger.Error(msg) + return []byte(msg) + } -type fnSig func(...string) []byte + task, taskOk := args["task"] + status, statusOk := args["status"] + + if !taskOk && !statusOk { + msg := "neither task nor status provided to todo_update tool" + logger.Error(msg) + return []byte(msg) + } + + // Find and update the todo + for i, item := range globalTodoList.Items { + if item.ID == id { + if taskOk { + globalTodoList.Items[i].Task = task + } + if statusOk { + // Validate status + if status == "pending" || status == "in_progress" || status == "completed" { + globalTodoList.Items[i].Status = status + } else { + result := map[string]string{ + "error": "status must be one of: 
pending, in_progress, completed", + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return jsonResult + } + } + + result := map[string]string{ + "message": "todo updated successfully", + "id": id, + } + + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult + } + } + + // ID not found + result := map[string]string{ + "error": fmt.Sprintf("todo with id %s not found", id), + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return jsonResult +} + +func todoDelete(args map[string]string) []byte { + id, ok := args["id"] + if !ok || id == "" { + msg := "id not provided to todo_delete tool" + logger.Error(msg) + return []byte(msg) + } + + // Find and remove the todo + for i, item := range globalTodoList.Items { + if item.ID == id { + // Remove item from slice + globalTodoList.Items = append(globalTodoList.Items[:i], globalTodoList.Items[i+1:]...) + + result := map[string]string{ + "message": "todo deleted successfully", + "id": id, + } + + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + + return jsonResult + } + } + + // ID not found + result := map[string]string{ + "error": fmt.Sprintf("todo with id %s not found", id), + } + jsonResult, err := json.Marshal(result) + if err != nil { + msg := "failed to marshal result; error: " + err.Error() + logger.Error(msg) + return []byte(msg) + } + return jsonResult +} + +func isCommandAllowed(command string) bool { + allowedCommands := map[string]bool{ + "grep": true, + "sed": true, + "awk": true, + "find": true, + "cat": true, + "head": true, + "tail": true, + "sort": true, + "uniq": true, + "wc": true, + "ls": true, + "echo": true, + "cut": true, + "tr": true, + "cp": true, + "mv": true, + "rm": true, + "mkdir": true, + "rmdir": true, + "pwd": true, + "df": true, + "free": true, + "ps": true, + "top": true, + "du": true, + "whoami": true, + "date": true, + "uname": true, + } + return allowedCommands[command] +} + +type fnSig func(map[string]string) []byte var fnMap = map[string]fnSig{ - "recall": recall, - "recall_topics": recallTopics, - "memorise": memorise, + "recall": recall, + "recall_topics": recallTopics, + "memorise": memorise, + "websearch": websearch, + "read_url": readURL, + "file_create": fileCreate, + "file_read": fileRead, + "file_write": fileWrite, + "file_delete": fileDelete, + "file_move": fileMove, + "file_copy": fileCopy, + "file_list": fileList, + "execute_command": executeCommand, + "todo_create": todoCreate, + "todo_read": todoRead, + "todo_update": todoUpdate, + "todo_delete": todoDelete, +} + +// openai style def +var baseTools = []models.Tool{ + // websearch + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "websearch", + Description: "Search web given query, limit of sources (default 3).", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"query", "limit"}, + Properties: map[string]models.ToolArgProps{ + "query": models.ToolArgProps{ + Type: "string", + Description: "search query", + }, + "limit": models.ToolArgProps{ + Type: "string", + Description: "limit of the website results", + }, + }, + }, + }, + }, 
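	// Illustrative sketch of how a call against the websearch definition above is
	// dispatched once the model returns its arguments as a JSON object (the
	// variable names raw/args/out are hypothetical; fnMap and websearch are
	// defined earlier in this file):
	//
	//	raw := []byte(`{"query": "golang sqlite embeddings", "limit": "3"}`)
	//	var args map[string]string
	//	if err := json.Unmarshal(raw, &args); err == nil {
	//		out := fnMap["websearch"](args) // runs websearch(args) and returns its []byte payload
	//		_ = out
	//	}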
+ // read_url + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "read_url", + Description: "Retrieves text content of given link.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"url"}, + Properties: map[string]models.ToolArgProps{ + "url": models.ToolArgProps{ + Type: "string", + Description: "link to the webpage to read text from", + }, + }, + }, + }, + }, + // memorise + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "memorise", + Description: "Save topic-data in key-value cache. Use when asked to remember something/keep in mind.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"topic", "data"}, + Properties: map[string]models.ToolArgProps{ + "topic": models.ToolArgProps{ + Type: "string", + Description: "topic is the key under which data is saved", + }, + "data": models.ToolArgProps{ + Type: "string", + Description: "data is the value that is saved under the topic-key", + }, + }, + }, + }, + }, + // recall + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "recall", + Description: "Recall topic-data from key-value cache. Use when precise info about the topic is needed.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"topic"}, + Properties: map[string]models.ToolArgProps{ + "topic": models.ToolArgProps{ + Type: "string", + Description: "topic is the key to recall data from", + }, + }, + }, + }, + }, + // recall_topics + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "recall_topics", + Description: "Recall all topics from key-value cache. Use when need to know what topics are currently stored in memory.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{}, + }, + }, + }, + + // file_create + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_create", + Description: "Create a new file with specified content. Use when you need to create a new file.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"path"}, + Properties: map[string]models.ToolArgProps{ + "path": models.ToolArgProps{ + Type: "string", + Description: "path where the file should be created", + }, + "content": models.ToolArgProps{ + Type: "string", + Description: "content to write to the file (optional, defaults to empty string)", + }, + }, + }, + }, + }, + + // file_read + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_read", + Description: "Read the content of a file. Use when you need to see the content of a file.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"path"}, + Properties: map[string]models.ToolArgProps{ + "path": models.ToolArgProps{ + Type: "string", + Description: "path of the file to read", + }, + }, + }, + }, + }, + + // file_write + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_write", + Description: "Write content to a file. 
Use when you want to create or modify a file (overwrite or append).", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"path", "content"}, + Properties: map[string]models.ToolArgProps{ + "path": models.ToolArgProps{ + Type: "string", + Description: "path of the file to write to", + }, + "content": models.ToolArgProps{ + Type: "string", + Description: "content to write to the file", + }, + "mode": models.ToolArgProps{ + Type: "string", + Description: "write mode: 'overwrite' to replace entire file content, 'append' to add to the end (defaults to 'overwrite')", + }, + }, + }, + }, + }, + + // file_delete + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_delete", + Description: "Delete a file. Use when you need to remove a file.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"path"}, + Properties: map[string]models.ToolArgProps{ + "path": models.ToolArgProps{ + Type: "string", + Description: "path of the file to delete", + }, + }, + }, + }, + }, + + // file_move + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_move", + Description: "Move a file from one location to another. Use when you need to relocate a file.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"src", "dst"}, + Properties: map[string]models.ToolArgProps{ + "src": models.ToolArgProps{ + Type: "string", + Description: "source path of the file to move", + }, + "dst": models.ToolArgProps{ + Type: "string", + Description: "destination path where the file should be moved", + }, + }, + }, + }, + }, + + // file_copy + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_copy", + Description: "Copy a file from one location to another. Use when you need to duplicate a file.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"src", "dst"}, + Properties: map[string]models.ToolArgProps{ + "src": models.ToolArgProps{ + Type: "string", + Description: "source path of the file to copy", + }, + "dst": models.ToolArgProps{ + Type: "string", + Description: "destination path where the file should be copied", + }, + }, + }, + }, + }, + + // file_list + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "file_list", + Description: "List files and directories in a directory. Use when you need to see what files are in a directory.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "path": models.ToolArgProps{ + Type: "string", + Description: "path of the directory to list (optional, defaults to current directory)", + }, + }, + }, + }, + }, + + // execute_command + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "execute_command", + Description: "Execute a shell command safely. 
Use when you need to run system commands like grep sed awk find cat head tail sort uniq wc ls echo cut tr cp mv rm mkdir rmdir pwd df free ps top du whoami date uname", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"command"}, + Properties: map[string]models.ToolArgProps{ + "command": models.ToolArgProps{ + Type: "string", + Description: "command to execute (only commands from whitelist are allowed: grep sed awk find cat head tail sort uniq wc ls echo cut tr cp mv rm mkdir rmdir pwd df free ps top du whoami date uname", + }, + "args": models.ToolArgProps{ + Type: "string", + Description: "command arguments as a single string (e.g., '-la {path}')", + }, + }, + }, + }, + }, + // todo_create + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "todo_create", + Description: "Create a new todo item with a task. Returns the created todo with its ID.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"task"}, + Properties: map[string]models.ToolArgProps{ + "task": models.ToolArgProps{ + Type: "string", + Description: "the task description to add to the todo list", + }, + }, + }, + }, + }, + // todo_read + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "todo_read", + Description: "Read todo items. Without ID returns all todos, with ID returns specific todo.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "id": models.ToolArgProps{ + Type: "string", + Description: "optional id of the specific todo item to read", + }, + }, + }, + }, + }, + // todo_update + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "todo_update", + Description: "Update a todo item by ID with new task or status. Status must be one of: pending, in_progress, completed.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"id"}, + Properties: map[string]models.ToolArgProps{ + "id": models.ToolArgProps{ + Type: "string", + Description: "id of the todo item to update", + }, + "task": models.ToolArgProps{ + Type: "string", + Description: "new task description (optional)", + }, + "status": models.ToolArgProps{ + Type: "string", + Description: "new status for the todo: pending, in_progress, or completed (optional)", + }, + }, + }, + }, + }, + // todo_delete + models.Tool{ + Type: "function", + Function: models.ToolFunc{ + Name: "todo_delete", + Description: "Delete a todo item by ID. 
Returns success message.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"id"}, + Properties: map[string]models.ToolArgProps{ + "id": models.ToolArgProps{ + Type: "string", + Description: "id of the todo item to delete", + }, + }, + }, + }, + }, } @@ -1,13 +1,15 @@ package main import ( - "elefant/models" - "elefant/pngmeta" "fmt" + "gf-lt/extra" + "gf-lt/models" "image" _ "image/jpeg" _ "image/png" "os" + "os/exec" + "path" "strconv" "strings" @@ -28,17 +30,26 @@ var ( defaultImage = "sysprompts/llama.png" indexPickWindow *tview.InputField renameWindow *tview.InputField + roleEditWindow *tview.InputField + fullscreenMode bool // pages - historyPage = "historyPage" - agentPage = "agentPage" - editMsgPage = "editMsgPage" - indexPage = "indexPage" - helpPage = "helpPage" - renamePage = "renamePage" - RAGPage = "RAGPage " - propsPage = "propsPage" - codeBlockPage = "codeBlockPage" - imgPage = "imgPage" + historyPage = "historyPage" + agentPage = "agentPage" + editMsgPage = "editMsgPage" + roleEditPage = "roleEditPage" + helpPage = "helpPage" + renamePage = "renamePage" + RAGPage = "RAGPage" + RAGLoadedPage = "RAGLoadedPage" + propsPage = "propsPage" + codeBlockPage = "codeBlockPage" + imgPage = "imgPage" + filePickerPage = "filePicker" + exportDir = "chat_exports" + + // For overlay search functionality + searchField *tview.InputField + searchPageName = "searchOverlay" // help text helpText = ` [yellow]Esc[white]: send msg @@ -47,208 +58,394 @@ var ( [yellow]F2[white]: regen last [yellow]F3[white]: delete last msg [yellow]F4[white]: edit msg -[yellow]F5[white]: toggle system +[yellow]F5[white]: toggle fullscreen for input/chat window [yellow]F6[white]: interrupt bot resp [yellow]F7[white]: copy last msg to clipboard (linux xclip) [yellow]F8[white]: copy n msg to clipboard (linux xclip) [yellow]F9[white]: table to copy from; with all code blocks -[yellow]F10[white]: manage loaded rag files (that already in vector db) -[yellow]F11[white]: switch RAGEnabled boolean +[yellow]F10[white]: switch if LLM will respond on this message (for user to write multiple messages in a row) +[yellow]F11[white]: import json chat file [yellow]F12[white]: show this help page [yellow]Ctrl+w[white]: resume generation on the last msg [yellow]Ctrl+s[white]: load new char/agent [yellow]Ctrl+e[white]: export chat to json file -[yellow]Ctrl+n[white]: start a new chat [yellow]Ctrl+c[white]: close programm +[yellow]Ctrl+n[white]: start a new chat +[yellow]Ctrl+o[white]: open image file picker [yellow]Ctrl+p[white]: props edit form (min-p, dry, etc.) 
[yellow]Ctrl+v[white]: switch between /completion and /chat api (if provided in config) -[yellow]Ctrl+r[white]: menu of files that can be loaded in vector db (RAG) +[yellow]Ctrl+r[white]: start/stop recording from your microphone (needs stt server or whisper binary) [yellow]Ctrl+t[white]: remove thinking (<think>) and tool messages from context (delete from chat) -[yellow]Ctrl+l[white]: update connected model name (llamacpp) +[yellow]Ctrl+l[white]: rotate through free OpenRouter models (if openrouter api) or update connected model name (llamacpp) [yellow]Ctrl+k[white]: switch tool use (recommend tool use to llm after user msg) [yellow]Ctrl+j[white]: if chat agent is char.png will show the image; then any key to return +[yellow]Ctrl+a[white]: interrupt tts (needs tts server) +[yellow]Ctrl+g[white]: open RAG file manager (load files for context retrieval) +[yellow]Ctrl+y[white]: list loaded RAG files (view and manage loaded files) +[yellow]Ctrl+q[white]: cycle through mentioned chars in chat, to pick persona to send next msg as +[yellow]Ctrl+x[white]: cycle through mentioned chars in chat, to pick persona to send next msg as (for llm) +[yellow]Alt+1[white]: toggle shell mode (execute commands locally) +[yellow]Alt+4[white]: edit msg role +[yellow]Alt+5[white]: toggle system and tool messages display + +=== scrolling chat window (some keys similar to vim) === +[yellow]arrows up/down and j/k[white]: scroll up and down +[yellow]gg/G[white]: jump to the begging / end of the chat +[yellow]/[white]: start searching for text +[yellow]n[white]: go to next search result +[yellow]N[white]: go to previous search result -Press Enter to go back +=== tables (chat history, agent pick, file pick, properties) === +[yellow]x[white]: to exit the table page + +=== status line === +%s + +Press <Enter> or 'x' to return ` + colorschemes = map[string]tview.Theme{ + "default": tview.Theme{ + PrimitiveBackgroundColor: tcell.ColorDefault, + ContrastBackgroundColor: tcell.ColorGray, + MoreContrastBackgroundColor: tcell.ColorSteelBlue, + BorderColor: tcell.ColorGray, + TitleColor: tcell.ColorRed, + GraphicsColor: tcell.ColorBlue, + PrimaryTextColor: tcell.ColorLightGray, + SecondaryTextColor: tcell.ColorYellow, + TertiaryTextColor: tcell.ColorOrange, + InverseTextColor: tcell.ColorPurple, + ContrastSecondaryTextColor: tcell.ColorLime, + }, + "gruvbox": tview.Theme{ + PrimitiveBackgroundColor: tcell.ColorBlack, // Matches #1e1e2e + ContrastBackgroundColor: tcell.ColorDarkGoldenrod, // Selected option: warm yellow (#b57614) + MoreContrastBackgroundColor: tcell.ColorDarkSlateGray, // Non-selected options: dark grayish-blue (#32302f) + BorderColor: tcell.ColorLightGray, // Light gray (#a89984) + TitleColor: tcell.ColorRed, // Red (#fb4934) + GraphicsColor: tcell.ColorDarkCyan, // Cyan (#689d6a) + PrimaryTextColor: tcell.ColorLightGray, // Light gray (#d5c4a1) + SecondaryTextColor: tcell.ColorYellow, // Yellow (#fabd2f) + TertiaryTextColor: tcell.ColorOrange, // Orange (#fe8019) + InverseTextColor: tcell.ColorWhite, // White (#f9f5d7) for selected text + ContrastSecondaryTextColor: tcell.ColorLightGreen, // Light green (#b8bb26) + }, + "solarized": tview.Theme{ + PrimitiveBackgroundColor: tcell.NewHexColor(0x1e1e2e), // #1e1e2e for main dropdown box + ContrastBackgroundColor: tcell.ColorDarkCyan, // Selected option: cyan (#2aa198) + MoreContrastBackgroundColor: tcell.ColorDarkSlateGray, // Non-selected options: dark blue (#073642) + BorderColor: tcell.ColorLightBlue, // Light blue (#839496) + TitleColor: tcell.ColorRed, // 
Red (#dc322f) + GraphicsColor: tcell.ColorBlue, // Blue (#268bd2) + PrimaryTextColor: tcell.ColorWhite, // White (#fdf6e3) + SecondaryTextColor: tcell.ColorYellow, // Yellow (#b58900) + TertiaryTextColor: tcell.ColorOrange, // Orange (#cb4b16) + InverseTextColor: tcell.ColorWhite, // White (#eee8d5) for selected text + ContrastSecondaryTextColor: tcell.ColorLightCyan, // Light cyan (#93a1a1) + }, + "dracula": tview.Theme{ + PrimitiveBackgroundColor: tcell.NewHexColor(0x1e1e2e), // #1e1e2e for main dropdown box + ContrastBackgroundColor: tcell.ColorDarkMagenta, // Selected option: magenta (#bd93f9) + MoreContrastBackgroundColor: tcell.ColorDarkGray, // Non-selected options: dark gray (#44475a) + BorderColor: tcell.ColorLightGray, // Light gray (#f8f8f2) + TitleColor: tcell.ColorRed, // Red (#ff5555) + GraphicsColor: tcell.ColorDarkCyan, // Cyan (#8be9fd) + PrimaryTextColor: tcell.ColorWhite, // White (#f8f8f2) + SecondaryTextColor: tcell.ColorYellow, // Yellow (#f1fa8c) + TertiaryTextColor: tcell.ColorOrange, // Orange (#ffb86c) + InverseTextColor: tcell.ColorWhite, // White (#f8f8f2) for selected text + ContrastSecondaryTextColor: tcell.ColorLightGreen, // Light green (#50fa7b) + }, + } ) -func loadImage() { - filepath := defaultImage - cc, ok := sysMap[cfg.AssistantRole] - if ok { - if strings.HasSuffix(cc.FilePath, ".png") { - filepath = cc.FilePath - } +func toggleShellMode() { + shellMode = !shellMode + if shellMode { + // Update input placeholder to indicate shell mode + textArea.SetPlaceholder("SHELL MODE: Enter command and press <Esc> to execute") + } else { + // Reset to normal mode + textArea.SetPlaceholder("input is multiline; press <Enter> to start the next line;\npress <Esc> to send the message. Alt+1 to exit shell mode") } - file, err := os.Open(filepath) - if err != nil { - panic(err) + updateStatusLine() +} + +func executeCommandAndDisplay(cmdText string) { + // Parse the command (split by spaces, but handle quoted arguments) + cmdParts := parseCommand(cmdText) + if len(cmdParts) == 0 { + fmt.Fprintf(textView, "\n[red]Error: No command provided[-:-:-]\n") + textView.ScrollToEnd() + colorText() + return + } + command := cmdParts[0] + args := []string{} + if len(cmdParts) > 1 { + args = cmdParts[1:] } - defer file.Close() - img, _, err := image.Decode(file) + // Create the command execution + cmd := exec.Command(command, args...) 
+ // Execute the command and get output + output, err := cmd.CombinedOutput() + // Add the command being executed to the chat + fmt.Fprintf(textView, "\n[yellow]$ %s[-:-:-]\n", cmdText) if err != nil { - panic(err) + // Include both output and error + fmt.Fprintf(textView, "[red]Error: %s[-:-:-]\n", err.Error()) + if len(output) > 0 { + fmt.Fprintf(textView, "[red]%s[-:-:-]\n", string(output)) + } + } else { + // Only output if successful + if len(output) > 0 { + fmt.Fprintf(textView, "[green]%s[-:-:-]\n", string(output)) + } else { + fmt.Fprintf(textView, "[green]Command executed successfully (no output)[-:-:-]\n") + } } - imgView.SetImage(img) + // Scroll to end and update colors + textView.ScrollToEnd() + colorText() } -func colorText() { - text := textView.GetText(false) - // Step 1: Extract code blocks and replace them with unique placeholders - var codeBlocks []string - placeholder := "__CODE_BLOCK_%d__" - counter := 0 - // thinking - var thinkBlocks []string - placeholderThink := "__THINK_BLOCK_%d__" - counterThink := 0 - // Replace code blocks with placeholders and store their styled versions - text = codeBlockRE.ReplaceAllStringFunc(text, func(match string) string { - // Style the code block and store it - styled := fmt.Sprintf("[red::i]%s[-:-:-]", match) - codeBlocks = append(codeBlocks, styled) - // Generate a unique placeholder (e.g., "__CODE_BLOCK_0__") - id := fmt.Sprintf(placeholder, counter) - counter++ - return id - }) - text = thinkRE.ReplaceAllStringFunc(text, func(match string) string { - // Style the code block and store it - styled := fmt.Sprintf("[red::i]%s[-:-:-]", match) - thinkBlocks = append(codeBlocks, styled) - // Generate a unique placeholder (e.g., "__CODE_BLOCK_0__") - id := fmt.Sprintf(placeholderThink, counterThink) - counter++ - return id - }) - // Step 2: Apply other regex styles to the non-code parts - text = quotesRE.ReplaceAllString(text, `[orange::-]$1[-:-:-]`) - text = starRE.ReplaceAllString(text, `[turquoise::i]$1[-:-:-]`) - // text = thinkRE.ReplaceAllString(text, `[yellow::i]$1[-:-:-]`) - // Step 3: Restore the styled code blocks from placeholders - for i, cb := range codeBlocks { - text = strings.Replace(text, fmt.Sprintf(placeholder, i), cb, 1) +// parseCommand splits command string handling quotes properly +func parseCommand(cmd string) []string { + var args []string + var current string + var inQuotes bool + var quoteChar rune + for _, r := range cmd { + switch r { + case '"', '\'': + if inQuotes { + if r == quoteChar { + inQuotes = false + } else { + current += string(r) + } + } else { + inQuotes = true + quoteChar = r + } + case ' ', '\t': + if inQuotes { + current += string(r) + } else if current != "" { + args = append(args, current) + current = "" + } + default: + current += string(r) + } } - for i, tb := range thinkBlocks { - text = strings.Replace(text, fmt.Sprintf(placeholderThink, i), tb, 1) + if current != "" { + args = append(args, current) } - textView.SetText(text) + return args } -func updateStatusLine() { - position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level())) -} +// Global variables for search state +var searchResults []int +var searchResultLengths []int // To store the length of each match in the formatted string +var searchIndex int +var searchText string +var originalTextForSearch string -func initSysCards() ([]string, error) { - labels := []string{} - labels = append(labels, sysLabels...) 
- cards, err := pngmeta.ReadDirCards(cfg.SysDir, cfg.UserRole, logger) - if err != nil { - logger.Error("failed to read sys dir", "error", err) - return nil, err +// performSearch searches for the given term in the textView content and highlights matches +func performSearch(term string) { + searchText = term + if searchText == "" { + searchResults = nil + searchResultLengths = nil + originalTextForSearch = "" + // Re-render text without highlights + textView.SetText(chatToText(cfg.ShowSys)) + colorText() + return } - for _, cc := range cards { - if cc.Role == "" { - logger.Warn("empty role", "file", cc.FilePath) - continue + // Get formatted text and search directly in it to avoid mapping issues + formattedText := textView.GetText(true) + originalTextForSearch = formattedText + searchTermLower := strings.ToLower(searchText) + formattedTextLower := strings.ToLower(formattedText) + // Find all occurrences of the search term in the formatted text directly + formattedSearchResults := []int{} + searchStart := 0 + for { + pos := strings.Index(formattedTextLower[searchStart:], searchTermLower) + if pos == -1 { + break } - sysMap[cc.Role] = cc - labels = append(labels, cc.Role) + absolutePos := searchStart + pos + formattedSearchResults = append(formattedSearchResults, absolutePos) + searchStart = absolutePos + len(searchText) } - return labels, nil + if len(formattedSearchResults) == 0 { + // No matches found + searchResults = nil + searchResultLengths = nil + notification := "Pattern not found: " + term + if err := notifyUser("search", notification); err != nil { + logger.Error("failed to send notification", "error", err) + } + return + } + // Store the formatted text positions and lengths for accurate highlighting + searchResults = formattedSearchResults + // Create lengths array - all matches have the same length as the search term + searchResultLengths = make([]int, len(formattedSearchResults)) + for i := range searchResultLengths { + searchResultLengths[i] = len(searchText) + } + searchIndex = 0 + highlightCurrentMatch() } -func startNewChat() { - id, err := store.ChatGetMaxID() - if err != nil { - logger.Error("failed to get chat id", "error", err) +// highlightCurrentMatch highlights the current search match and scrolls to it +func highlightCurrentMatch() { + if len(searchResults) == 0 || searchIndex >= len(searchResults) { + return } - if ok := charToStart(cfg.AssistantRole); !ok { - logger.Warn("no such sys msg", "name", cfg.AssistantRole) + // Get the stored formatted text + formattedText := originalTextForSearch + // For tview to properly support highlighting and scrolling, we need to work with its region system + // Instead of just applying highlights, we need to add region tags to the text + highlightedText := addRegionTags(formattedText, searchResults, searchResultLengths, searchIndex, searchText) + // Update the text view with the text that includes region tags + textView.SetText(highlightedText) + // Highlight the current region and scroll to it + // Need to identify which position in the results array corresponds to the current match + // The region ID will be search_<position>_<index> + currentRegion := fmt.Sprintf("search_%d_%d", searchResults[searchIndex], searchIndex) + textView.Highlight(currentRegion).ScrollToHighlight() + // Send notification about which match we're at + notification := fmt.Sprintf("Match %d of %d", searchIndex+1, len(searchResults)) + if err := notifyUser("search", notification); err != nil { + logger.Error("failed to send notification", "error", err) } - 
// set chat body - chatBody.Messages = chatBody.Messages[:2] - textView.SetText(chatToText(cfg.ShowSys)) - newChat := &models.Chat{ - ID: id + 1, - Name: fmt.Sprintf("%d_%s", id+1, cfg.AssistantRole), - Msgs: string(defaultStarterBytes), - Agent: cfg.AssistantRole, +} + +// showSearchBar shows the search input field as an overlay +func showSearchBar() { + // Create a temporary flex to combine search and main content + updatedFlex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(searchField, 3, 0, true). // Search field at top + AddItem(flex, 0, 1, false) // Main flex layout below + + // Add the search overlay as a page + pages.AddPage(searchPageName, updatedFlex, true, true) + app.SetFocus(searchField) +} + +// hideSearchBar hides the search input field +func hideSearchBar() { + pages.RemovePage(searchPageName) + // Return focus to the text view + app.SetFocus(textView) + // Clear the search field + searchField.SetText("") +} + +// Global variables for index overlay functionality +var indexPageName = "indexOverlay" + +// showIndexBar shows the index input field as an overlay at the top +func showIndexBar() { + // Create a temporary flex to combine index input and main content + updatedFlex := tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(indexPickWindow, 3, 0, true). // Index field at top + AddItem(flex, 0, 1, false) // Main flex layout below + + // Add the index overlay as a page + pages.AddPage(indexPageName, updatedFlex, true, true) + app.SetFocus(indexPickWindow) +} + +// hideIndexBar hides the index input field +func hideIndexBar() { + pages.RemovePage(indexPageName) + // Return focus to the text view + app.SetFocus(textView) + // Clear the index field + indexPickWindow.SetText("") +} + +// addRegionTags adds region tags to search matches in the text for tview highlighting +func addRegionTags(text string, positions []int, lengths []int, currentIdx int, searchTerm string) string { + if len(positions) == 0 { + return text } - activeChatName = newChat.Name - chatMap[newChat.Name] = newChat - updateStatusLine() - colorText() + var result strings.Builder + lastEnd := 0 + for i, pos := range positions { + endPos := pos + lengths[i] + // Add text before this match + if pos > lastEnd { + result.WriteString(text[lastEnd:pos]) + } + // The matched text, which may contain its own formatting tags + actualText := text[pos:endPos] + // Add region tag and highlighting for this match + // Use a unique region id that includes the match index to avoid conflicts + regionId := fmt.Sprintf("search_%d_%d", pos, i) // position + index to ensure uniqueness + var highlightStart, highlightEnd string + if i == currentIdx { + // Current match - use different highlighting + highlightStart = fmt.Sprintf(`["%s"][yellow:blue:b]`, regionId) // Current match with region and special highlight + highlightEnd = `[-:-:-][""]` // Reset formatting and close region + } else { + // Other matches - use regular highlighting + highlightStart = fmt.Sprintf(`["%s"][gold:red:u]`, regionId) // Other matches with region and highlight + highlightEnd = `[-:-:-][""]` // Reset formatting and close region + } + result.WriteString(highlightStart) + result.WriteString(actualText) + result.WriteString(highlightEnd) + lastEnd = endPos + } + // Add the rest of the text after the last processed match + if lastEnd < len(text) { + result.WriteString(text[lastEnd:]) + } + return result.String() } -func setLogLevel(sl string) { - switch sl { - case "Debug": - logLevel.Set(-4) - case "Info": - logLevel.Set(0) - case "Warn": - 
logLevel.Set(4) +// searchNext finds the next occurrence of the search term +func searchNext() { + if len(searchResults) == 0 { + if err := notifyUser("search", "No search results to navigate"); err != nil { + logger.Error("failed to send notification", "error", err) + } + return } + searchIndex = (searchIndex + 1) % len(searchResults) + highlightCurrentMatch() } -func makePropsForm(props map[string]float32) *tview.Form { - // https://github.com/rivo/tview/commit/0a18dea458148770d212d348f656988df75ff341 - // no way to close a form by a key press; a shame. - form := tview.NewForm(). - AddTextView("Notes", "Props for llamacpp completion call", 40, 2, true, false). - AddCheckbox("Insert <think> (/completion only)", cfg.ThinkUse, func(checked bool) { - cfg.ThinkUse = checked - }).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1, - func(option string, optionIndex int) { - setLogLevel(option) - }). - AddButton("Quit", func() { - pages.RemovePage(propsPage) - }) - form.AddButton("Save", func() { - defer updateStatusLine() - defer pages.RemovePage(propsPage) - for pn := range props { - propField, ok := form.GetFormItemByLabel(pn).(*tview.InputField) - if !ok { - logger.Warn("failed to convert to inputfield", "prop_name", pn) - continue - } - val, err := strconv.ParseFloat(propField.GetText(), 32) - if err != nil { - logger.Warn("failed parse to float", "value", propField.GetText()) - continue - } - props[pn] = float32(val) +// searchPrev finds the previous occurrence of the search term +func searchPrev() { + if len(searchResults) == 0 { + if err := notifyUser("search", "No search results to navigate"); err != nil { + logger.Error("failed to send notification", "error", err) } - }) - for propName, value := range props { - form.AddInputField(propName, fmt.Sprintf("%v", value), 20, tview.InputFieldFloat, nil) + return + } + if searchIndex == 0 { + searchIndex = len(searchResults) - 1 + } else { + searchIndex-- } - form.SetBorder(true).SetTitle("Enter some data").SetTitleAlign(tview.AlignLeft) - return form + highlightCurrentMatch() } func init() { - theme := tview.Theme{ - PrimitiveBackgroundColor: tcell.ColorDefault, - ContrastBackgroundColor: tcell.ColorGray, - MoreContrastBackgroundColor: tcell.ColorNavy, - BorderColor: tcell.ColorGray, - TitleColor: tcell.ColorRed, - GraphicsColor: tcell.ColorBlue, - PrimaryTextColor: tcell.ColorLightGray, - SecondaryTextColor: tcell.ColorYellow, - TertiaryTextColor: tcell.ColorOrange, - InverseTextColor: tcell.ColorPurple, - ContrastSecondaryTextColor: tcell.ColorLime, - } - tview.Styles = theme + tview.Styles = colorschemes["default"] app = tview.NewApplication() pages = tview.NewPages() textArea = tview.NewTextArea(). - SetPlaceholder("Type your prompt...") + SetPlaceholder("input is multiline; press <Enter> to start the next line;\npress <Esc> to send the message.") textArea.SetBorder(true).SetTitle("input") textView = tview.NewTextView(). SetDynamicColors(true). @@ -256,34 +453,83 @@ func init() { SetChangedFunc(func() { app.Draw() }) + // + flex = tview.NewFlex().SetDirection(tview.FlexRow). + AddItem(textView, 0, 40, false). + AddItem(textArea, 0, 10, true). 
// Restore original height + AddItem(position, 0, 2, false) // textView.SetBorder(true).SetTitle("chat") textView.SetDoneFunc(func(key tcell.Key) { - currentSelection := textView.GetHighlights() if key == tcell.KeyEnter { - if len(currentSelection) > 0 { - textView.Highlight() + if len(searchResults) > 0 { // Check if a search is active + hideSearchBar() // Hide the search bar if visible + searchResults = nil // Clear search results + searchResultLengths = nil // Clear search result lengths + originalTextForSearch = "" + textView.SetText(chatToText(cfg.ShowSys)) // Reset text without search regions + colorText() // Apply normal chat coloring } else { - textView.Highlight("0").ScrollToHighlight() + // Original logic if no search is active + currentSelection := textView.GetHighlights() + if len(currentSelection) > 0 { + textView.Highlight() + } else { + textView.Highlight("0").ScrollToHighlight() + } } } }) + textView.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + // Handle vim-like navigation in TextView + switch event.Key() { + case tcell.KeyRune: + switch event.Rune() { + case 'j': + // For line down + return event + case 'k': + // For line up + return event + case 'g': + // Go to beginning + textView.ScrollToBeginning() + return nil + case 'G': + // Go to end + textView.ScrollToEnd() + return nil + case '/': + // Search functionality - show search bar + showSearchBar() + return nil + case 'n': + // Next search result + searchNext() + return nil + case 'N': + // Previous search result + searchPrev() + return nil + } + } + return event + }) focusSwitcher[textArea] = textView focusSwitcher[textView] = textArea position = tview.NewTextView(). SetDynamicColors(true). SetTextAlign(tview.AlignCenter) - position.SetChangedFunc(func() { - app.Draw() - }) + // Initially set up flex without search bar flex = tview.NewFlex().SetDirection(tview.FlexRow). AddItem(textView, 0, 40, false). - AddItem(textArea, 0, 10, true). + AddItem(textArea, 0, 10, true). // Restore original height AddItem(position, 0, 2, false) editArea = tview.NewTextArea(). SetPlaceholder("Replace msg...") editArea.SetBorder(true).SetTitle("input") editArea.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { - if event.Key() == tcell.KeyEscape && editMode { + // if event.Key() == tcell.KeyEscape && editMode { + if event.Key() == tcell.KeyEscape { defer colorText() editedMsg := editArea.GetText() if editedMsg == "" { @@ -291,7 +537,6 @@ func init() { logger.Error("failed to send notification", "error", err) } pages.RemovePage(editMsgPage) - editMode = false return nil } chatBody.Messages[selectedIndex].Content = editedMsg @@ -308,55 +553,93 @@ func init() { SetFieldWidth(4). SetAcceptanceFunc(tview.InputFieldInteger). SetDoneFunc(func(key tcell.Key) { - defer indexPickWindow.SetText("") - pages.RemovePage(indexPage) - colorText() - updateStatusLine() + hideIndexBar() + // colorText() + // updateStatusLine() + }) + + roleEditWindow = tview.NewInputField(). + SetLabel("Enter new role: "). + SetPlaceholder("e.g., user, assistant, system, tool"). 
+ SetDoneFunc(func(key tcell.Key) { + switch key { + case tcell.KeyEnter: + newRole := roleEditWindow.GetText() + if newRole == "" { + if err := notifyUser("edit", "no role provided"); err != nil { + logger.Error("failed to send notification", "error", err) + } + pages.RemovePage(roleEditPage) + return + } + if selectedIndex >= 0 && selectedIndex < len(chatBody.Messages) { + chatBody.Messages[selectedIndex].Role = newRole + textView.SetText(chatToText(cfg.ShowSys)) + colorText() + pages.RemovePage(roleEditPage) + } + case tcell.KeyEscape: + pages.RemovePage(roleEditPage) + } }) indexPickWindow.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { switch event.Key() { case tcell.KeyBackspace: return event + case tcell.KeyEscape: + // Hide the index overlay when Escape is pressed + hideIndexBar() + return nil case tcell.KeyEnter: si := indexPickWindow.GetText() siInt, err := strconv.Atoi(si) if err != nil { logger.Error("failed to convert provided index", "error", err, "si", si) - if err := notifyUser("cancel", "no index provided"); err != nil { + if err := notifyUser("cancel", "no index provided, copying user input"); err != nil { logger.Error("failed to send notification", "error", err) } - pages.RemovePage(indexPage) - return event + if err := copyToClipboard(textArea.GetText()); err != nil { + logger.Error("failed to copy to clipboard", "error", err) + } + hideIndexBar() // Hide overlay instead of removing page directly + return nil } selectedIndex = siInt - if len(chatBody.Messages)+1 < selectedIndex || selectedIndex < 0 { - msg := "chosen index is out of bounds" + if len(chatBody.Messages)-1 < selectedIndex || selectedIndex < 0 { + msg := "chosen index is out of bounds, will copy user input" logger.Warn(msg, "index", selectedIndex) if err := notifyUser("error", msg); err != nil { logger.Error("failed to send notification", "error", err) } - pages.RemovePage(indexPage) - return event + if err := copyToClipboard(textArea.GetText()); err != nil { + logger.Error("failed to copy to clipboard", "error", err) + } + hideIndexBar() // Hide overlay instead of removing page directly + return nil } m := chatBody.Messages[selectedIndex] - if editMode && event.Key() == tcell.KeyEnter { + if roleEditMode { + hideIndexBar() // Hide overlay first + // Set the current role as the default text in the input field + roleEditWindow.SetText(m.Role) + pages.AddPage(roleEditPage, roleEditWindow, true, true) + roleEditMode = false // Reset the flag + } else if editMode { + hideIndexBar() // Hide overlay first pages.AddPage(editMsgPage, editArea, true, true) editArea.SetText(m.Content, true) - } - if !editMode && event.Key() == tcell.KeyEnter { + } else { if err := copyToClipboard(m.Content); err != nil { logger.Error("failed to copy to clipboard", "error", err) } - previewLen := 30 - if len(m.Content) < 30 { - previewLen = len(m.Content) - } + previewLen := min(30, len(m.Content)) notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen]) if err := notifyUser("copied", notification); err != nil { logger.Error("failed to send notification", "error", err) } + hideIndexBar() // Hide overlay after copying } - return event + return nil default: return event } @@ -392,13 +675,45 @@ func init() { return event }) // - helpView = tview.NewTextView().SetDynamicColors(true).SetText(helpText).SetDoneFunc(func(key tcell.Key) { - pages.RemovePage(helpPage) - }) + searchField = tview.NewInputField(). + SetPlaceholder("Search... (Enter: search)"). 
+ SetDoneFunc(func(key tcell.Key) { + if key == tcell.KeyEnter { + term := searchField.GetText() + if term == "" { + // If the search term is empty, cancel the search + hideSearchBar() + searchResults = nil + searchResultLengths = nil + originalTextForSearch = "" + textView.SetText(chatToText(cfg.ShowSys)) + colorText() + return + } else { + performSearch(term) + // Keep focus on textView after search + app.SetFocus(textView) + hideSearchBar() + } + } + }) + searchField.SetBorder(true).SetTitle("Search") + // Note: Initially hide the search field (handled by not showing it in the layout) + // + helpView = tview.NewTextView().SetDynamicColors(true). + SetText(fmt.Sprintf(helpText, makeStatusLine())). + SetDoneFunc(func(key tcell.Key) { + pages.RemovePage(helpPage) + }) helpView.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { switch event.Key() { - case tcell.KeyEsc, tcell.KeyEnter: + case tcell.KeyEnter: return event + default: + if event.Key() == tcell.KeyRune && event.Rune() == 'x' { + pages.RemovePage(helpPage) + return nil + } } return nil }) @@ -428,6 +743,12 @@ func init() { logger.Error("failed to init sys cards", "error", err) } app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyRune && event.Rune() == '5' && event.Modifiers()&tcell.ModAlt != 0 { + // switch cfg.ShowSys + cfg.ShowSys = !cfg.ShowSys + textView.SetText(chatToText(cfg.ShowSys)) + colorText() + } if event.Key() == tcell.KeyF1 { // chatList, err := loadHistoryChats() chatList, err := store.GetChatByChar(cfg.AssistantRole) @@ -435,6 +756,14 @@ func init() { logger.Error("failed to load chat history", "error", err) return nil } + // Check if there are no chats for this agent + if len(chatList) == 0 { + notification := "no chats found for agent: " + cfg.AssistantRole + if err := notifyUser("info", notification); err != nil { + logger.Error("failed to send notification", "error", err) + } + return nil + } chatMap := make(map[string]models.Chat) // nameList := make([]string, len(chatList)) for _, chat := range chatList { @@ -449,7 +778,15 @@ func init() { } if event.Key() == tcell.KeyF2 { // regen last msg + if len(chatBody.Messages) == 0 { + if err := notifyUser("info", "no messages to regenerate"); err != nil { + logger.Error("failed to send notification", "error", err) + } + return nil + } chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] + // there is no case where user msg is regenerated + // lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role textView.SetText(chatToText(cfg.ShowSys)) go chatRound("", cfg.UserRole, textView, true, false) return nil @@ -465,22 +802,57 @@ func init() { colorText() return nil } + if len(chatBody.Messages) == 0 { + if err := notifyUser("info", "no messages to delete"); err != nil { + logger.Error("failed to send notification", "error", err) + } + return nil + } chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] textView.SetText(chatToText(cfg.ShowSys)) colorText() return nil } if event.Key() == tcell.KeyF4 { - // edit msg + // edit msg - show index input as overlay at top editMode = true - pages.AddPage(indexPage, indexPickWindow, true, true) + showIndexBar() + return nil + } + if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '4' { + // edit msg role - show index input as overlay at top + editMode = false // Reset edit mode to false to handle role editing + showIndexBar() + // Set a flag to indicate we're in role edit mode + roleEditMode = true return nil } 
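The search wiring above (searchField's done handler calling performSearch, plus highlightCurrentMatch and addRegionTags) relies on tview's region mechanism: each match is wrapped in ["id"]...[""] tags, and Highlight(id).ScrollToHighlight() jumps the viewport to the selected one. A self-contained sketch of just that mechanism, with made-up text and region ids (none of these names come from the patch):

package main

import "github.com/rivo/tview"

// Standalone illustration of tview region highlighting: SetRegions(true)
// enables ["id"]...[""] tags, Highlight selects a region, and
// ScrollToHighlight scrolls it into view.
func main() {
	tv := tview.NewTextView().
		SetDynamicColors(true).
		SetRegions(true)
	tv.SetText(`line one
line two with a ["search_0_0"][yellow:blue:b]match[-:-:-][""] in it
line three`)
	tv.Highlight("search_0_0").ScrollToHighlight()
	if err := tview.NewApplication().SetRoot(tv, true).Run(); err != nil {
		panic(err)
	}
}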
if event.Key() == tcell.KeyF5 { - // switch cfg.ShowSys - cfg.ShowSys = !cfg.ShowSys - textView.SetText(chatToText(cfg.ShowSys)) - colorText() + // toggle fullscreen + fullscreenMode = !fullscreenMode + focused := app.GetFocus() + if fullscreenMode { + if focused == textArea || focused == textView { + flex.Clear() + flex.AddItem(focused, 0, 1, true) + } else { + // if focus is not on textarea or textview, cancel fullscreen + fullscreenMode = false + } + } else { + // focused is the fullscreened widget here + flex.Clear(). + AddItem(textView, 0, 40, false). + AddItem(textArea, 0, 10, false). + AddItem(position, 0, 2, false) + + if focused == textView { + app.SetFocus(textView) + } else { // default to textArea + app.SetFocus(textArea) + } + } + return nil } if event.Key() == tcell.KeyF6 { interruptResp = true @@ -494,10 +866,7 @@ func init() { if err := copyToClipboard(m.Content); err != nil { logger.Error("failed to copy to clipboard", "error", err) } - previewLen := 30 - if len(m.Content) < 30 { - previewLen = len(m.Content) - } + previewLen := min(30, len(m.Content)) notification := fmt.Sprintf("msg '%s' was copied to the clipboard", m.Content[:previewLen]) if err := notifyUser("copied", notification); err != nil { logger.Error("failed to send notification", "error", err) @@ -507,7 +876,7 @@ func init() { if event.Key() == tcell.KeyF8 { // copy msg to clipboard editMode = false - pages.AddPage(indexPage, indexPickWindow, true, true) + showIndexBar() return nil } if event.Key() == tcell.KeyF9 { @@ -522,29 +891,32 @@ func init() { } table := makeCodeBlockTable(cb) pages.AddPage(codeBlockPage, table, true, true) - // updateStatusLine() return nil } if event.Key() == tcell.KeyF10 { - // list rag loaded in db - loadedFiles, err := ragger.ListLoaded() + cfg.SkipLLMResp = !cfg.SkipLLMResp + updateStatusLine() + } + if event.Key() == tcell.KeyF11 { + // read files in chat_exports + filelist, err := os.ReadDir(exportDir) if err != nil { - logger.Error("failed to list regfiles in db", "error", err) - return nil - } - if len(loadedFiles) == 0 { - if err := notifyUser("loaded RAG", "no files in db"); err != nil { + if err := notifyUser("failed to load exports", err.Error()); err != nil { logger.Error("failed to send notification", "error", err) } return nil } - dbRAGTable := makeLoadedRAGTable(loadedFiles) - pages.AddPage(RAGPage, dbRAGTable, true, true) - return nil - } - if event.Key() == tcell.KeyF11 { - // xor - cfg.RAGEnabled = !cfg.RAGEnabled + fli := []string{} + for _, f := range filelist { + if f.IsDir() || !strings.HasSuffix(f.Name(), ".json") { + continue + } + fpath := path.Join(exportDir, f.Name()) + fli = append(fli, fpath) + } + // check error + exportsTable := makeImportChatTable(fli) + pages.AddPage(historyPage, exportsTable, true, true) updateStatusLine() return nil } @@ -565,19 +937,41 @@ func init() { return nil } if event.Key() == tcell.KeyCtrlP { - propsForm := makePropsForm(defaultLCPProps) - pages.AddPage(propsPage, propsForm, true, true) + propsTable := makePropsTable(defaultLCPProps) + pages.AddPage(propsPage, propsTable, true, true) return nil } if event.Key() == tcell.KeyCtrlN { startNewChat() return nil } + if event.Key() == tcell.KeyCtrlO { + // open file picker + filePicker := makeFilePicker() + pages.AddPage(filePickerPage, filePicker, true, true) + return nil + } if event.Key() == tcell.KeyCtrlL { - go func() { - fetchModelName() // blocks + // Check if the current API is an OpenRouter API + if strings.Contains(cfg.CurrentAPI, "openrouter.ai/api/v1/") { + // Rotate 
through OpenRouter free models + if len(ORFreeModels) > 0 { + currentORModelIndex = (currentORModelIndex + 1) % len(ORFreeModels) + chatBody.Model = ORFreeModels[currentORModelIndex] + } updateStatusLine() - }() + } else { + if len(LocalModels) > 0 { + currentLocalModelIndex = (currentLocalModelIndex + 1) % len(LocalModels) + chatBody.Model = LocalModels[currentLocalModelIndex] + } + updateStatusLine() + // // For non-OpenRouter APIs, use the old logic + // go func() { + // fetchLCPModelName() // blocks + // updateStatusLine() + // }() + } return nil } if event.Key() == tcell.KeyCtrlT { @@ -589,16 +983,28 @@ func init() { return nil } if event.Key() == tcell.KeyCtrlV { - // switch between /chat and /completion api - prevAPI := cfg.CurrentAPI - newAPI := cfg.APIMap[cfg.CurrentAPI] - if newAPI == "" { - // do not switch + // switch between API links using index-based rotation + if len(cfg.ApiLinks) == 0 { + // No API links to rotate through return nil } - cfg.APIMap[newAPI] = prevAPI - cfg.CurrentAPI = newAPI - initChunkParser() + // Find current API in the list to get the current index + currentIndex := -1 + for i, api := range cfg.ApiLinks { + if api == cfg.CurrentAPI { + currentIndex = i + break + } + } + // If current API is not in the list, start from beginning + // Otherwise, advance to next API in the list (with wrap-around) + if currentIndex == -1 { + currentAPIIndex = 0 + } else { + currentAPIIndex = (currentIndex + 1) % len(cfg.ApiLinks) + } + cfg.CurrentAPI = cfg.ApiLinks[currentAPIIndex] + choseChunkParser() updateStatusLine() return nil } @@ -626,18 +1032,153 @@ func init() { return nil } if event.Key() == tcell.KeyCtrlJ { - // show image - loadImage() + // show image - check for attached image first, then fall back to agent image + if lastImg != "" { + // Load the attached image + file, err := os.Open(lastImg) + if err != nil { + logger.Error("failed to open attached image", "path", lastImg, "error", err) + // Fall back to showing agent image + loadImage() + } else { + defer file.Close() + img, _, err := image.Decode(file) + if err != nil { + logger.Error("failed to decode attached image", "path", lastImg, "error", err) + // Fall back to showing agent image + loadImage() + } else { + imgView.SetImage(img) + } + } + } else { + // No attached image, show agent image as before + loadImage() + } pages.AddPage(imgPage, imgView, true, true) return nil } - if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" { + if event.Key() == tcell.KeyCtrlR && cfg.STT_ENABLED { + defer updateStatusLine() + if asr.IsRecording() { + userSpeech, err := asr.StopRecording() + if err != nil { + msg := "failed to inference user speech; error:" + err.Error() + logger.Error(msg) + if err := notifyUser("stt error", msg); err != nil { + logger.Error("failed to notify user", "error", err) + } + return nil + } + if userSpeech != "" { + // append indtead of replacing + prevText := textArea.GetText() + textArea.SetText(prevText+userSpeech, true) + } else { + logger.Warn("empty user speech") + } + return nil + } + if err := asr.StartRecording(); err != nil { + logger.Error("failed to start recording user speech", "error", err) + return nil + } + } + // I need keybind for tts to shut up + if event.Key() == tcell.KeyCtrlA { + // textArea.SetText("pressed ctrl+A", true) + if cfg.TTS_ENABLED { + // audioStream.TextChan <- chunk + extra.TTSDoneChan <- true + } + } + if event.Key() == tcell.KeyCtrlW { + // INFO: continue bot/text message + // without new role + lastRole := 
chatBody.Messages[len(chatBody.Messages)-1].Role + go chatRound("", lastRole, textView, false, true) + return nil + } + if event.Key() == tcell.KeyCtrlQ { + persona := cfg.UserRole + if cfg.WriteNextMsgAs != "" { + persona = cfg.WriteNextMsgAs + } + roles := listRolesWithUser() + logger.Info("list roles", "roles", roles) + for i, role := range roles { + if strings.EqualFold(role, persona) { + if i == len(roles)-1 { + cfg.WriteNextMsgAs = roles[0] // reached last, get first + break + } + cfg.WriteNextMsgAs = roles[i+1] // get next role + logger.Info("picked role", "roles", roles, "index", i+1) + break + } + } + updateStatusLine() + return nil + } + if event.Key() == tcell.KeyCtrlX { + persona := cfg.AssistantRole + if cfg.WriteNextMsgAsCompletionAgent != "" { + persona = cfg.WriteNextMsgAsCompletionAgent + } + roles := chatBody.ListRoles() + if len(roles) == 0 { + logger.Warn("empty roles in chat") + } + if !strInSlice(cfg.AssistantRole, roles) { + roles = append(roles, cfg.AssistantRole) + } + for i, role := range roles { + if strings.EqualFold(role, persona) { + if i == len(roles)-1 { + cfg.WriteNextMsgAsCompletionAgent = roles[0] // reached last, get first + break + } + cfg.WriteNextMsgAsCompletionAgent = roles[i+1] // get next role + logger.Info("picked role", "roles", roles, "index", i+1) + break + } + } + updateStatusLine() + return nil + } + if event.Key() == tcell.KeyCtrlG { + // cfg.RAGDir is the directory with files to use with RAG // rag load // menu of the text files from defined rag directory files, err := os.ReadDir(cfg.RAGDir) if err != nil { - logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err) - return nil + // Check if the error is because the directory doesn't exist + if os.IsNotExist(err) { + // Create the RAG directory if it doesn't exist + if mkdirErr := os.MkdirAll(cfg.RAGDir, 0755); mkdirErr != nil { + logger.Error("failed to create RAG directory", "dir", cfg.RAGDir, "error", mkdirErr) + if notifyerr := notifyUser("failed to create RAG directory", mkdirErr.Error()); notifyerr != nil { + logger.Error("failed to send notification", "error", notifyerr) + } + return nil + } + // Now try to read the directory again after creating it + files, err = os.ReadDir(cfg.RAGDir) + if err != nil { + logger.Error("failed to read dir after creating it", "dir", cfg.RAGDir, "error", err) + if notifyerr := notifyUser("failed to read RAG directory", err.Error()); notifyerr != nil { + logger.Error("failed to send notification", "error", notifyerr) + } + return nil + } + } else { + // Other error (permissions, etc.) 
+ logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err) + if notifyerr := notifyUser("failed to open RAG files dir", err.Error()); notifyerr != nil { + logger.Error("failed to send notification", "error", notifyerr) + } + return nil + } } fileList := []string{} for _, f := range files { @@ -650,34 +1191,73 @@ func init() { pages.AddPage(RAGPage, chatRAGTable, true, true) return nil } - if event.Key() == tcell.KeyCtrlW { - // INFO: continue bot/text message - // without new role - lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role - go chatRound("", lastRole, textView, false, true) + if event.Key() == tcell.KeyCtrlY { // Use Ctrl+Y to list loaded RAG files + // List files already loaded into the RAG system + fileList, err := ragger.ListLoaded() + if err != nil { + logger.Error("failed to list loaded RAG files", "error", err) + if notifyerr := notifyUser("failed to list RAG files", err.Error()); notifyerr != nil { + logger.Error("failed to send notification", "error", notifyerr) + } + return nil + } + chatLoadedRAGTable := makeLoadedRAGTable(fileList) + pages.AddPage(RAGLoadedPage, chatLoadedRAGTable, true, true) + return nil + } + if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '1' { + // Toggle shell mode: when enabled, commands are executed locally instead of sent to LLM + toggleShellMode() return nil } // cannot send msg in editMode or botRespMode if event.Key() == tcell.KeyEscape && !editMode && !botRespMode { - // read all text into buffer msgText := textArea.GetText() - nl := "\n" - prevText := textView.GetText(true) - // strings.LastIndex() - // newline is not needed is prev msg ends with one - if strings.HasSuffix(prevText, nl) { - nl = "" - } - if msgText != "" { - // add user icon before user msg - fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", - nl, len(chatBody.Messages), cfg.UserRole, msgText) - textArea.SetText("", true) - textView.ScrollToEnd() - colorText() + + if shellMode && msgText != "" { + // In shell mode, execute command instead of sending to LLM + executeCommandAndDisplay(msgText) + textArea.SetText("", true) // Clear the input area + return nil + } else if !shellMode { + // Normal mode - send to LLM + nl := "\n" + prevText := textView.GetText(true) + persona := cfg.UserRole + // strings.LastIndex() + // newline is not needed is prev msg ends with one + if strings.HasSuffix(prevText, nl) { + nl = "" + } + if msgText != "" { + // as what char user sends msg? 
+ if cfg.WriteNextMsgAs != "" { + persona = cfg.WriteNextMsgAs + } + // check if plain text + if !injectRole { + matches := roleRE.FindStringSubmatch(msgText) + if len(matches) > 1 { + persona = matches[1] + msgText = strings.TrimLeft(msgText[len(matches[0]):], " ") + } + } + // add user icon before user msg + fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", + nl, len(chatBody.Messages), persona, msgText) + textArea.SetText("", true) + textView.ScrollToEnd() + colorText() + } + go chatRound(msgText, persona, textView, false, false) + // Also clear any image attachment after sending the message + go func() { + // Wait a short moment for the message to be processed, then clear the image attachment + // This allows the image to be sent with the current message if it was attached + // But clears it for the next message + ClearImageAttachment() + }() } - // update statue line - go chatRound(msgText, cfg.UserRole, textView, false, false) return nil } if event.Key() == tcell.KeyPgUp || event.Key() == tcell.KeyPgDn { @@ -685,6 +1265,7 @@ func init() { app.SetFocus(focusSwitcher[currentF]) return nil } + if isASCII(string(event.Rune())) && !botRespMode { return event } diff --git a/vec0.so b/vec0.so Binary files differ deleted file mode 100755 index bd4c3ca..0000000 --- a/vec0.so +++ /dev/null
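One pattern that recurs in the key handlers above is index-based wrap-around rotation: Ctrl+L cycles ORFreeModels or LocalModels and Ctrl+V cycles cfg.ApiLinks, each advancing by one and wrapping back to the start. A small helper capturing that pattern could look like the sketch below; nextInRing is our name for it and is not defined anywhere in the patch.

// nextInRing returns the element that follows current in items, wrapping to the
// first element at the end; if current is absent or items is empty it falls
// back to the first element (or current). Illustrative only.
func nextInRing(items []string, current string) string {
	if len(items) == 0 {
		return current
	}
	for i, it := range items {
		if it == current {
			return items[(i+1)%len(items)]
		}
	}
	return items[0]
}

With such a helper, the Ctrl+V branch would reduce to something like cfg.CurrentAPI = nextInRing(cfg.ApiLinks, cfg.CurrentAPI) followed by choseChunkParser() and updateStatusLine().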
