summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-06 09:11:25 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-06 09:11:25 +0300
commitd2caebdb4fd3ad148aad20866503b7d46d546404 (patch)
tree8f59ef20824764b471d4633f044049b779df0e9c /rag
parente1f2a8cd7be487a3b4284ca70cc5a2a64b50f5d1 (diff)
Enhance (onnx): use GPU
Diffstat (limited to 'rag')
-rw-r--r--rag/embedder.go68
1 file changed, 67 insertions, 1 deletion
diff --git a/rag/embedder.go b/rag/embedder.go
index 59dbfd2..13f6a6e 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -156,14 +156,22 @@ type ONNXEmbedder struct {
var onnxInitOnce sync.Once
var onnxReady bool
var onnxLibPath string
+var cudaLibPath string
var onnxLibPaths = []string{
"/usr/lib/libonnxruntime.so",
+ "/usr/lib/libonnxruntime.so.1.24.2",
"/usr/local/lib/libonnxruntime.so",
"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
"/opt/onnxruntime/lib/libonnxruntime.so",
}
+var cudaLibPaths = []string{
+ "/usr/lib/libonnxruntime_providers_cuda.so",
+ "/usr/local/lib/libonnxruntime_providers_cuda.so",
+ "/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so",
+}
+
func findONNXLibrary() string {
for _, path := range onnxLibPaths {
if _, err := os.Stat(path); err == nil {
@@ -173,6 +181,15 @@ func findONNXLibrary() string {
return ""
}
+func findCUDALibrary() string {
+ for _, path := range cudaLibPaths {
+ if _, err := os.Stat(path); err == nil {
+ return path
+ }
+ }
+ return ""
+}
+
func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
// Check if model and tokenizer files exist
if _, err := os.Stat(modelPath); err != nil {
@@ -188,6 +205,12 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
return nil, errors.New("ONNX runtime library not found in standard locations")
}
+ // Find CUDA provider library (optional)
+ cudaLibPath = findCUDALibrary()
+ if cudaLibPath == "" {
+ fmt.Println("WARNING: CUDA provider library not found, will use CPU")
+ }
+
emb := &ONNXEmbedder{
tokenizerPath: tokenizerPath,
dims: dims,
@@ -223,16 +246,56 @@ func (e *ONNXEmbedder) ensureInitialized() error {
onnxReady = false
return
}
+ // Register CUDA provider if available
+ if cudaLibPath != "" {
+ if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil {
+ e.logger.Warn("failed to register CUDA provider", "error", err)
+ }
+ }
onnxReady = true
})
if !onnxReady {
return errors.New("ONNX runtime not ready")
}
+
+ // Create session options
+ opts, err := onnxruntime_go.NewSessionOptions()
+ if err != nil {
+ return fmt.Errorf("failed to create session options: %w", err)
+ }
+ defer opts.Destroy()
+
+ // Try to add CUDA provider
+ useCUDA := cudaLibPath != ""
+ if useCUDA {
+ cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions()
+ if err != nil {
+ e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
+ useCUDA = false
+ } else {
+ defer cudaOpts.Destroy()
+ if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
+ e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
+ useCUDA = false
+ } else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil {
+ e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err)
+ useCUDA = false
+ }
+ }
+ }
+
+ if useCUDA {
+ e.logger.Info("Using CUDA for ONNX inference")
+ } else {
+ e.logger.Info("Using CPU for ONNX inference")
+ }
+
+ // Create session with options
session, err := onnxruntime_go.NewDynamicAdvancedSession(
e.getModelPath(),
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
- nil,
+ opts,
)
if err != nil {
return fmt.Errorf("failed to create ONNX session: %w", err)
@@ -304,6 +367,9 @@ func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
}
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+ if err := e.ensureInitialized(); err != nil {
+ return nil, err
+ }
encodings := make([]*tokenizer.Encoding, len(texts))
maxLen := 0
for i, txt := range texts {