diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index 2e2ca87..814038c 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -264,7 +264,9 @@ def create_input_tensor(input_feed) inp = @inputs.find { |i| i[:name] == input_name.to_s } raise Error, "Unknown input: #{input_name}" unless inp - if inp[:type] == "tensor(string)" + if input.is_a?(OrtValue) + input + elsif inp[:type] == "tensor(string)" OrtValue.ortvalue_from_array(input, element_type: :string) elsif (tensor_type = tensor_types[inp[:type]]) OrtValue.ortvalue_from_array(input, element_type: tensor_type) diff --git a/test/inference_session_test.rb b/test/inference_session_test.rb index 045fa68..77d2bf9 100644 --- a/test/inference_session_test.rb +++ b/test/inference_session_test.rb @@ -22,6 +22,13 @@ def test_run_with_ort_values_invalid_type assert_equal "Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(float))", error.message end + def test_run_ort_value_input + sess = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx") + x = OnnxRuntime::OrtValue.ortvalue_from_numo(Numo::SFloat.cast([[5.8, 2.8]])) + output = sess.run(nil, {input: x}) + assert_equal [1], output[0] + end + def test_providers sess = OnnxRuntime::InferenceSession.new("test/support/model.onnx") assert_includes sess.providers, "CPUExecutionProvider"