summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-05 20:02:46 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-05 20:02:46 +0300
commitefc92d884c36498220e2b8d5ad9e02f84e42d953 (patch)
tree6361de7107d077e39c1aeb312013041d40f73167 /rag
parentac8c8bb0558a00cf0d025ab8522aaa57b8cba7de (diff)
Chore: onnx library lookup
Diffstat (limited to 'rag')
-rw-r--r--rag/embedder.go113
-rw-r--r--rag/rag.go39
2 files changed, 113 insertions, 39 deletions
diff --git a/rag/embedder.go b/rag/embedder.go
index b0a3226..59dbfd2 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -9,6 +9,7 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
+ "os"
"sync"
"github.com/sugarme/tokenizer"
@@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
type ONNXEmbedder struct {
- session *onnxruntime_go.DynamicAdvancedSession
- tokenizer *tokenizer.Tokenizer
- dims int // embedding dimension (e.g., 768)
- logger *slog.Logger
+ session *onnxruntime_go.DynamicAdvancedSession
+ tokenizer *tokenizer.Tokenizer
+ tokenizerPath string
+ dims int
+ logger *slog.Logger
+ mu sync.Mutex
+ modelPath string
}
var onnxInitOnce sync.Once
+var onnxReady bool
+var onnxLibPath string
+
+var onnxLibPaths = []string{
+ "/usr/lib/libonnxruntime.so",
+ "/usr/local/lib/libonnxruntime.so",
+ "/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
+ "/opt/onnxruntime/lib/libonnxruntime.so",
+}
+
+func findONNXLibrary() string {
+ for _, path := range onnxLibPaths {
+ if _, err := os.Stat(path); err == nil {
+ return path
+ }
+ }
+ return ""
+}
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
- // Initialize ONNX runtime environment once
- onnxInitOnce.Do(func() {
- onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so")
- err := onnxruntime_go.InitializeEnvironment()
+ // Check if model and tokenizer files exist
+ if _, err := os.Stat(modelPath); err != nil {
+ return nil, fmt.Errorf("ONNX model not found: %w", err)
+ }
+ if _, err := os.Stat(tokenizerPath); err != nil {
+ return nil, fmt.Errorf("tokenizer not found: %w", err)
+ }
+
+ // Find ONNX library
+ onnxLibPath = findONNXLibrary()
+ if onnxLibPath == "" {
+ return nil, errors.New("ONNX runtime library not found in standard locations")
+ }
+
+ emb := &ONNXEmbedder{
+ tokenizerPath: tokenizerPath,
+ dims: dims,
+ logger: logger,
+ modelPath: modelPath,
+ }
+ return emb, nil
+}
+
+func (e *ONNXEmbedder) ensureInitialized() error {
+ if e.session != nil {
+ return nil
+ }
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.session != nil {
+ return nil
+ }
+
+ // Load tokenizer lazily
+ if e.tokenizer == nil {
+ tok, err := pretrained.FromFile(e.tokenizerPath)
if err != nil {
- logger.Error("failed to initialize ONNX runtime", "error", err)
+ return fmt.Errorf("failed to load tokenizer: %w", err)
+ }
+ e.tokenizer = tok
+ }
+
+ onnxInitOnce.Do(func() {
+ onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
+ if err := onnxruntime_go.InitializeEnvironment(); err != nil {
+ e.logger.Error("failed to initialize ONNX runtime", "error", err)
+ onnxReady = false
+ return
}
+ onnxReady = true
})
- // Load tokenizer using sugarme/tokenizer
- tok, err := pretrained.FromFile(tokenizerPath)
- if err != nil {
- return nil, fmt.Errorf("failed to load tokenizer: %w", err)
+ if !onnxReady {
+ return errors.New("ONNX runtime not ready")
}
- // Create ONNX session
session, err := onnxruntime_go.NewDynamicAdvancedSession(
- modelPath, // onnx/embedgemma/model_q4.onnx
+ e.getModelPath(),
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
- nil, // optional options
+ nil,
)
if err != nil {
- return nil, fmt.Errorf("failed to create ONNX session: %w", err)
- }
- return &ONNXEmbedder{
- session: session,
- tokenizer: tok,
- dims: dims,
- logger: logger,
- }, nil
+ return fmt.Errorf("failed to create ONNX session: %w", err)
+ }
+ e.session = session
+ return nil
+}
+
+func (e *ONNXEmbedder) getModelPath() string {
+ return e.modelPath
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
+ if err := e.ensureInitialized(); err != nil {
+ return nil, err
+ }
// 1. Tokenize
encoding, err := e.tokenizer.EncodeSingle(text)
if err != nil {
diff --git a/rag/rag.go b/rag/rag.go
index 654afde..fa30303 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -25,20 +25,23 @@ var (
)
type RAG struct {
- logger *slog.Logger
- store storage.FullRepo
- cfg *config.Config
- embedder Embedder
- storage *VectorStorage
- mu sync.Mutex
+ logger *slog.Logger
+ store storage.FullRepo
+ cfg *config.Config
+ embedder Embedder
+ storage *VectorStorage
+ mu sync.Mutex
+ fallbackMsg string
}
-func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
+func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
var embedder Embedder
+ var fallbackMsg string
if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
if err != nil {
l.Error("failed to create ONNX embedder, falling back to API", "error", err)
+ fallbackMsg = err.Error()
embedder = NewAPIEmbedder(l, cfg)
} else {
embedder = emb
@@ -49,16 +52,17 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
l.Info("using API embedder", "url", cfg.EmbedURL)
}
rag := &RAG{
- logger: l,
- store: s,
- cfg: cfg,
- embedder: embedder,
- storage: NewVectorStorage(l, s),
+ logger: l,
+ store: s,
+ cfg: cfg,
+ embedder: embedder,
+ storage: NewVectorStorage(l, s),
+ fallbackMsg: fallbackMsg,
}
// Note: Vector tables are created via database migrations, not at runtime
- return rag
+ return rag, nil
}
func wordCounter(sentence string) int {
@@ -449,14 +453,19 @@ var (
ragOnce sync.Once
)
+func (r *RAG) FallbackMessage() string {
+ return r.fallbackMsg
+}
+
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
+ var err error
ragOnce.Do(func() {
if c == nil || l == nil || s == nil {
return
}
- ragInstance = New(l, s, c)
+ ragInstance, err = New(l, s, c)
})
- return nil
+ return err
}
func GetInstance() *RAG {