summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bot.go7
-rw-r--r--config.example.toml3
-rw-r--r--config/config.go7
-rw-r--r--rag/embedder.go23
-rw-r--r--rag/rag.go16
5 files changed, 43 insertions, 13 deletions
diff --git a/bot.go b/bot.go
index 13ee074..5463800 100644
--- a/bot.go
+++ b/bot.go
@@ -1393,12 +1393,13 @@ func updateModelLists() {
}
}
// if llama.cpp started after gf-lt?
- localModelsMu.Lock()
- LocalModels, err = fetchLCPModelsWithLoadStatus()
- localModelsMu.Unlock()
+ ml, err := fetchLCPModelsWithLoadStatus()
if err != nil {
logger.Warn("failed to fetch llama.cpp models", "error", err)
}
+ localModelsMu.Lock()
+ LocalModels = ml
+ localModelsMu.Unlock()
// set already loaded model in llama.cpp
if strings.Contains(cfg.CurrentAPI, "localhost") || strings.Contains(cfg.CurrentAPI, "127.0.0.1") {
localModelsMu.Lock()
diff --git a/config.example.toml b/config.example.toml
index 39a730b..f5820da 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -13,6 +13,9 @@ OpenRouterChatAPI = "https://openrouter.ai/api/v1/chat/completions"
# embeddings
EmbedURL = "http://localhost:8082/v1/embeddings"
HFToken = ""
+EmbedModelPath = "onnx/embedgemma/model_q4.onnx"
+EmbedTokenizerPath = "onnx/embedgemma/tokenizer.json"
+EmbedDims = 768
#
ShowSys = true
LogFile = "log.txt"
diff --git a/config/config.go b/config/config.go
index 412eaaa..84ec480 100644
--- a/config/config.go
+++ b/config/config.go
@@ -34,8 +34,11 @@ type Config struct {
ImagePreview bool `toml:"ImagePreview"`
EnableMouse bool `toml:"EnableMouse"`
// embeddings
- EmbedURL string `toml:"EmbedURL"`
- HFToken string `toml:"HFToken"`
+ EmbedURL string `toml:"EmbedURL"`
+ HFToken string `toml:"HFToken"`
+ EmbedModelPath string `toml:"EmbedModelPath"`
+ EmbedTokenizerPath string `toml:"EmbedTokenizerPath"`
+ EmbedDims int `toml:"EmbedDims"`
// rag settings
RAGEnabled bool `toml:"RAGEnabled"`
RAGDir string `toml:"RAGDir"`
diff --git a/rag/embedder.go b/rag/embedder.go
index 6903a5d..b0a3226 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -9,6 +9,7 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
+ "sync"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
@@ -148,7 +149,17 @@ type ONNXEmbedder struct {
logger *slog.Logger
}
+var onnxInitOnce sync.Once
+
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
+ // Initialize ONNX runtime environment once
+ onnxInitOnce.Do(func() {
+ onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
+ err := onnxruntime_go.InitializeEnvironment()
+ if err != nil {
+ logger.Error("failed to initialize ONNX runtime", "error", err)
+ }
+ })
// Load tokenizer using sugarme/tokenizer
tok, err := pretrained.FromFile(tokenizerPath)
if err != nil {
@@ -195,7 +206,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err != nil {
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
}
- defer inputIDsTensor.Destroy()
+ defer func() { _ = inputIDsTensor.Destroy() }()
maskTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(1, seqLen),
attentionMask,
@@ -203,7 +214,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err != nil {
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
}
- defer maskTensor.Destroy()
+ defer func() { _ = maskTensor.Destroy() }()
// 4. Create output tensor
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(1, int64(e.dims)),
@@ -211,7 +222,7 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err != nil {
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
- defer outputTensor.Destroy()
+ defer func() { _ = outputTensor.Destroy() }()
// 5. Run inference
err = e.session.Run(
[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
@@ -257,16 +268,16 @@ func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
inputIDs,
)
- defer inputTensor.Destroy()
+ defer func() { _ = inputTensor.Destroy() }()
maskTensor, _ := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
attentionMask,
)
- defer maskTensor.Destroy()
+ defer func() { _ = maskTensor.Destroy() }()
outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
)
- defer outputTensor.Destroy()
+ defer func() { _ = outputTensor.Destroy() }()
err := e.session.Run(
[]onnxruntime_go.Value{inputTensor, maskTensor},
[]onnxruntime_go.Value{outputTensor},
diff --git a/rag/rag.go b/rag/rag.go
index 3d0f38f..654afde 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -34,8 +34,20 @@ type RAG struct {
}
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
- // Initialize with API embedder by default, could be configurable later
- embedder := NewAPIEmbedder(l, cfg)
+ var embedder Embedder
+ if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
+ emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
+ if err != nil {
+ l.Error("failed to create ONNX embedder, falling back to API", "error", err)
+ embedder = NewAPIEmbedder(l, cfg)
+ } else {
+ embedder = emb
+ l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims)
+ }
+ } else {
+ embedder = NewAPIEmbedder(l, cfg)
+ l.Info("using API embedder", "url", cfg.EmbedURL)
+ }
rag := &RAG{
logger: l,
store: s,