diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-03-06 09:11:25 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-03-06 09:11:25 +0300 |
| commit | d2caebdb4fd3ad148aad20866503b7d46d546404 (patch) | |
| tree | 8f59ef20824764b471d4633f044049b779df0e9c /rag/embedder.go | |
| parent | e1f2a8cd7be487a3b4284ca70cc5a2a64b50f5d1 (diff) | |
Enhance (onnx): use gpu
Diffstat (limited to 'rag/embedder.go')
| -rw-r--r-- | rag/embedder.go | 68 |
1 file changed, 67 insertions(+), 1 deletion(-)
diff --git a/rag/embedder.go b/rag/embedder.go index 59dbfd2..13f6a6e 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -156,14 +156,22 @@ type ONNXEmbedder struct { var onnxInitOnce sync.Once var onnxReady bool var onnxLibPath string +var cudaLibPath string var onnxLibPaths = []string{ "/usr/lib/libonnxruntime.so", + "/usr/lib/libonnxruntime.so.1.24.2", "/usr/local/lib/libonnxruntime.so", "/usr/lib/x86_64-linux-gnu/libonnxruntime.so", "/opt/onnxruntime/lib/libonnxruntime.so", } +var cudaLibPaths = []string{ + "/usr/lib/libonnxruntime_providers_cuda.so", + "/usr/local/lib/libonnxruntime_providers_cuda.so", + "/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so", +} + func findONNXLibrary() string { for _, path := range onnxLibPaths { if _, err := os.Stat(path); err == nil { @@ -173,6 +181,15 @@ func findONNXLibrary() string { return "" } +func findCUDALibrary() string { + for _, path := range cudaLibPaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + return "" +} + func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { // Check if model and tokenizer files exist if _, err := os.Stat(modelPath); err != nil { @@ -188,6 +205,12 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log return nil, errors.New("ONNX runtime library not found in standard locations") } + // Find CUDA provider library (optional) + cudaLibPath = findCUDALibrary() + if cudaLibPath == "" { + fmt.Println("WARNING: CUDA provider library not found, will use CPU") + } + emb := &ONNXEmbedder{ tokenizerPath: tokenizerPath, dims: dims, @@ -223,16 +246,56 @@ func (e *ONNXEmbedder) ensureInitialized() error { onnxReady = false return } + // Register CUDA provider if available + if cudaLibPath != "" { + if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil { + e.logger.Warn("failed to register CUDA provider", "error", err) + } + } onnxReady = true }) if 
!onnxReady { return errors.New("ONNX runtime not ready") } + + // Create session options + opts, err := onnxruntime_go.NewSessionOptions() + if err != nil { + return fmt.Errorf("failed to create session options: %w", err) + } + defer opts.Destroy() + + // Try to add CUDA provider + useCUDA := cudaLibPath != "" + if useCUDA { + cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions() + if err != nil { + e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err) + useCUDA = false + } else { + defer cudaOpts.Destroy() + if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil { + e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err) + useCUDA = false + } else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil { + e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err) + useCUDA = false + } + } + } + + if useCUDA { + e.logger.Info("Using CUDA for ONNX inference") + } else { + e.logger.Info("Using CPU for ONNX inference") + } + + // Create session with options session, err := onnxruntime_go.NewDynamicAdvancedSession( e.getModelPath(), []string{"input_ids", "attention_mask"}, []string{"sentence_embedding"}, - nil, + opts, ) if err != nil { return fmt.Errorf("failed to create ONNX session: %w", err) @@ -304,6 +367,9 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { } func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { + if err := e.ensureInitialized(); err != nil { + return nil, err + } encodings := make([]*tokenizer.Encoding, len(texts)) maxLen := 0 for i, txt := range texts { |
