summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--storage/storage_test.go171
-rw-r--r--tables.go13
-rw-r--r--tools.go15
3 files changed, 102 insertions, 97 deletions
diff --git a/storage/storage_test.go b/storage/storage_test.go
index ff3b5e6..07dd3e7 100644
--- a/storage/storage_test.go
+++ b/storage/storage_test.go
@@ -3,16 +3,13 @@ package storage
import (
"elefant/models"
"fmt"
- "log"
"log/slog"
"os"
"testing"
"time"
- sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
- "github.com/ncruces/go-sqlite3"
)
func TestMemories(t *testing.T) {
@@ -177,87 +174,87 @@ func TestChatHistory(t *testing.T) {
}
}
-func TestVecTable(t *testing.T) {
- // healthcheck
- db, err := sqlite3.Open(":memory:")
- if err != nil {
- t.Fatal(err)
- }
- stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
- if err != nil {
- t.Fatal(err)
- }
- stmt.Step()
- log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
- stmt.Close()
- // migration
- err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
- if err != nil {
- t.Fatal(err)
- }
- // data prep and insert
- items := map[int][]float32{
- 1: {0.1, 0.1, 0.1, 0.1},
- 2: {0.2, 0.2, 0.2, 0.2},
- 3: {0.3, 0.3, 0.3, 0.3},
- 4: {0.4, 0.4, 0.4, 0.4},
- 5: {0.5, 0.5, 0.5, 0.5},
- }
- q := []float32{0.28, 0.3, 0.3, 0.3}
- stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
- if err != nil {
- t.Fatal(err)
- }
- for id, values := range items {
- v, err := sqlite_vec.SerializeFloat32(values)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindInt(1, id)
- stmt.BindBlob(2, v)
- stmt.BindText(3, "some_chat")
- err = stmt.Exec()
- if err != nil {
- t.Fatal(err)
- }
- stmt.Reset()
- }
- stmt.Close()
- // select | vec search
- stmt, _, err = db.Prepare(`
- SELECT
- rowid,
- distance,
- embedding
- FROM vec_items
- WHERE embedding MATCH ?
- ORDER BY distance
- LIMIT 3
- `)
- if err != nil {
- t.Fatal(err)
- }
- query, err := sqlite_vec.SerializeFloat32(q)
- if err != nil {
- t.Fatal(err)
- }
- stmt.BindBlob(1, query)
- for stmt.Step() {
- rowid := stmt.ColumnInt64(0)
- distance := stmt.ColumnFloat(1)
- emb := stmt.ColumnRawText(2)
- floats := decodeUnsafe(emb)
- log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
- }
- if err := stmt.Err(); err != nil {
- t.Fatal(err)
- }
- err = stmt.Close()
- if err != nil {
- t.Fatal(err)
- }
- err = db.Close()
- if err != nil {
- t.Fatal(err)
- }
-}
+// func TestVecTable(t *testing.T) {
+// // healthcheck
+// db, err := sqlite3.Open(":memory:")
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Step()
+// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
+// stmt.Close()
+// // migration
+// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// // data prep and insert
+// items := map[int][]float32{
+// 1: {0.1, 0.1, 0.1, 0.1},
+// 2: {0.2, 0.2, 0.2, 0.2},
+// 3: {0.3, 0.3, 0.3, 0.3},
+// 4: {0.4, 0.4, 0.4, 0.4},
+// 5: {0.5, 0.5, 0.5, 0.5},
+// }
+// q := []float32{0.4, 0.3, 0.3, 0.3}
+// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// for id, values := range items {
+// v, err := sqlite_vec.SerializeFloat32(values)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindInt(1, id)
+// stmt.BindBlob(2, v)
+// stmt.BindText(3, "some_chat")
+// err = stmt.Exec()
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Reset()
+// }
+// stmt.Close()
+// // select | vec search
+// stmt, _, err = db.Prepare(`
+// SELECT
+// rowid,
+// distance,
+// embedding
+// FROM vec_items
+// WHERE embedding MATCH ?
+// ORDER BY distance
+// LIMIT 3
+// `)
+// if err != nil {
+// t.Fatal(err)
+// }
+// query, err := sqlite_vec.SerializeFloat32(q)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindBlob(1, query)
+// for stmt.Step() {
+// rowid := stmt.ColumnInt64(0)
+// distance := stmt.ColumnFloat(1)
+// emb := stmt.ColumnRawText(2)
+// floats := decodeUnsafe(emb)
+// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
+// }
+// if err := stmt.Err(); err != nil {
+// t.Fatal(err)
+// }
+// err = stmt.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// err = db.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// }
diff --git a/tables.go b/tables.go
index 4dc36d9..16b90ee 100644
--- a/tables.go
+++ b/tables.go
@@ -16,7 +16,7 @@ import (
)
func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
- actions := []string{"load", "rename", "delete", "update card"}
+ actions := []string{"load", "rename", "delete", "update card", "move sysprompt onto 1st msg"}
chatList := make([]string, len(chatMap))
i := 0
for name := range chatMap {
@@ -26,9 +26,7 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
rows, cols := len(chatMap), len(actions)+2
chatActTable := tview.NewTable().
SetBorders(true)
- // for chatName, chat := range chatMap {
for r := 0; r < rows; r++ {
- // r := 0
for c := 0; c < cols; c++ {
color := tcell.ColorWhite
switch c {
@@ -49,7 +47,6 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
SetAlign(tview.AlignCenter))
}
}
- // r++
}
chatActTable.Select(0, 0).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) {
if key == tcell.KeyEsc || key == tcell.KeyF1 {
@@ -65,7 +62,6 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
chatActTable.SetSelectable(false, false)
selectedChat := chatList[row]
defer pages.RemovePage(historyPage)
- // notification := fmt.Sprintf("chat: %s; action: %s", selectedChat, tc.Text)
switch tc.Text {
case "load":
history, err := loadHistoryChat(selectedChat)
@@ -128,6 +124,13 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table {
"error", err)
}
return
+ case "move sysprompt onto 1st msg":
+ chatBody.Messages[1].Content = chatBody.Messages[0].Content + chatBody.Messages[1].Content
+ chatBody.Messages[0].Content = rpDefenitionSysMsg
+ textView.SetText(chatToText(cfg.ShowSys))
+ activeChatName = selectedChat
+ pages.RemovePage(historyPage)
+ return
default:
return
}
diff --git a/tools.go b/tools.go
index a380bf5..50b3d24 100644
--- a/tools.go
+++ b/tools.go
@@ -9,11 +9,16 @@ import (
)
var (
- toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`)
- quotesRE = regexp.MustCompile(`(".*?")`)
- starRE = regexp.MustCompile(`(\*.*?\*)`)
- thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`)
- codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`)
+ toolCallRE = regexp.MustCompile(`__tool_call__\s*([\s\S]*?)__tool_call__`)
+ quotesRE = regexp.MustCompile(`(".*?")`)
+ starRE = regexp.MustCompile(`(\*.*?\*)`)
+ thinkRE = regexp.MustCompile(`(<think>\s*([\s\S]*?)</think>)`)
+ codeBlockRE = regexp.MustCompile(`(?s)\x60{3}(?:.*?)\n(.*?)\n\s*\x60{3}\s*`)
+ rpDefenitionSysMsg = `
+For this roleplay immersion is at most importance.
+Every character thinks and acts based on their personality and setting of the roleplay.
+Meta discussions outside of roleplay is allowed if clearly labeled as out of character, for example: (ooc: {msg}) or <ooc>{msg}</ooc>.
+`
basicSysMsg = `Large Language Model that helps user with any of his requests.`
toolSysMsg = `You can do functions call if needed.
Your current tools: