Skip to content

Commit

Permalink
Fixed GPU support - fixes #9
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Aug 13, 2023
1 parent b6319a9 commit abf1124
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.8 (unreleased)

- Fixed GPU support

## 0.7.7 (2023-07-24)

- Updated ONNX Runtime to 1.15.1
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ To enable GPU support on Linux and Windows, download the appropriate [GPU releas
OnnxRuntime.ffi_lib = "path/to/lib/libonnxruntime.so" # onnxruntime.dll for Windows
```

and use: [unreleased]

```ruby
model = OnnxRuntime::Model.new("model.onnx", providers: ["CUDAExecutionProvider"])
```

## History

View the [changelog](https://github.com/ankane/onnxruntime-ruby/blob/master/CHANGELOG.md)
Expand Down
10 changes: 5 additions & 5 deletions lib/onnxruntime/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ class Api < ::FFI::Struct
:SetGlobalCustomJoinThreadFn, callback(%i[], :pointer),
:SynchronizeBoundInputs, callback(%i[], :pointer),
:SynchronizeBoundOutputs, callback(%i[], :pointer),
:SessionOptionsAppendExecutionProvider_CUDA_V2, callback(%i[], :pointer),
:CreateCUDAProviderOptions, callback(%i[], :pointer),
:UpdateCUDAProviderOptions, callback(%i[], :pointer),
:GetCUDAProviderOptionsAsString, callback(%i[], :pointer),
:ReleaseCUDAProviderOptions, callback(%i[], :pointer),
:SessionOptionsAppendExecutionProvider_CUDA_V2, callback(%i[pointer pointer], :pointer),
:CreateCUDAProviderOptions, callback(%i[pointer], :pointer),
:UpdateCUDAProviderOptions, callback(%i[pointer pointer pointer size_t], :pointer),
:GetCUDAProviderOptionsAsString, callback(%i[pointer pointer pointer], :pointer),
:ReleaseCUDAProviderOptions, callback(%i[pointer], :void),
:SessionOptionsAppendExecutionProvider_MIGraphX, callback(%i[], :pointer)
end

Expand Down
12 changes: 11 additions & 1 deletion lib/onnxruntime/inference_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OnnxRuntime
class InferenceSession
attr_reader :inputs, :outputs

def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil)
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: [])
# session options
session_options = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateSessionOptions].call(session_options)
Expand Down Expand Up @@ -54,6 +54,16 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr
check_status api[:AddSessionConfigEntry].call(session_options.read_pointer, k.to_s, v.to_s)
end
end
providers.each do |provider|
if provider == "CUDAExecutionProvider"
cuda_options = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CreateCUDAProviderOptions].call(cuda_options)
check_status api[:SessionOptionsAppendExecutionProvider_CUDA_V2].call(session_options.read_pointer, cuda_options.read_pointer)
release :CUDAProviderOptions, cuda_options
else
raise ArgumentError, "Provider not supported: #{provider}"
end
end

@session = load_session(path_or_bytes, session_options)
ObjectSpace.define_finalizer(@session, self.class.finalize(read_pointer.to_i))
Expand Down
8 changes: 8 additions & 0 deletions test/model_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def test_providers
assert_includes sess.providers, "CPUExecutionProvider"
end

def test_providers_cuda
# TODO fallback to CPU without error
error = assert_raises(OnnxRuntime::Error) do
OnnxRuntime::InferenceSession.new("test/support/model.onnx", providers: ["CUDAExecutionProvider"])
end
assert_equal "CUDA execution provider is not enabled in this build.", error.message
end

def test_profiling
sess = OnnxRuntime::InferenceSession.new("test/support/model.onnx", enable_profiling: true)
file = sess.end_profiling
Expand Down

0 comments on commit abf1124

Please sign in to comment.