diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 527616c46b2a..d2240b5d5019 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -34,7 +34,6 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"GatherElements", {}}, - {"GatherND", {}}, {"Reshape", {}}, {"Expand", {}}, {"Flatten", {}}, diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 12cda8748fdd..dd5c6a5a79cd 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -111,7 +111,6 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateGatherOpBuilder("Gather", *this); CreateGatherOpBuilder("GatherElements", *this); - CreateGatherOpBuilder("GatherND", *this); } { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index 35dfddfd3f24..992e1243dacf 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -51,54 +52,62 @@ class GatherElementsOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; }; -class GatherNDOpBuilder : public BaseOpBuilder { - public: - GatherNDOpBuilder() : BaseOpBuilder("GatherNDOpBuilder") {} - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherNDOpBuilder); - - protected: - Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, - std::vector& input_names, - bool do_op_validation) const override ORT_MUST_USE_RESULT; - - Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const override ORT_MUST_USE_RESULT; -}; - -// Converts int64 indices to another integer type (typically int32 or uint32). +// Makes negative indices positive and converts int64 indices to another integer type (typically int32 or uint32). // The input and output are both represented as byte arrays. -template -static void ConvertInt64IndicesBytes(const std::vector& onnx_bytes, std::vector& qnn_bytes) { - const size_t num_elems = onnx_bytes.size() / sizeof(uint64_t); - gsl::span onnx_indices{reinterpret_cast(onnx_bytes.data()), num_elems}; +template +static void FixStaticIndices(const std::vector& onnx_bytes, + int64_t input0_axis_dim, + /*out*/ std::vector& qnn_bytes) { + const size_t num_elems = onnx_bytes.size() / sizeof(SrcType); + gsl::span onnx_indices{reinterpret_cast(onnx_bytes.data()), num_elems}; + + qnn_bytes.resize(num_elems * sizeof(DstType)); + DstType* qnn_indices_ptr = reinterpret_cast(qnn_bytes.data()); + + std::transform(onnx_indices.begin(), onnx_indices.end(), qnn_indices_ptr, [&input0_axis_dim](SrcType index) { + if (index < 0) { + index += static_cast(input0_axis_dim); + } + return SafeInt(index); + }); +} + +// Gets the size of input0 on the axis dimension. +static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + int64_t default_axis_value, + /*out*/ int64_t& axis_dim_value) { + const auto& input0 = node_unit.Inputs()[0]; + std::vector input0_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape), + "Cannot get shape for ", node_unit.OpType(), " input[0] ", input0.node_arg.Name()); + + int64_t rank = static_cast(input0_shape.size()); + NodeAttrHelper node_helper(node_unit); + int64_t onnx_axis = node_helper.Get("axis", default_axis_value); + if (onnx_axis < 0) { + onnx_axis += rank; + } + ORT_RETURN_IF_NOT((onnx_axis >= 0 && onnx_axis < static_cast(input0_shape.size())), + "QNN requires axis range [0, rank-1] for ", node_unit.OpType()); - qnn_bytes.resize(num_elems * sizeof(T)); - T* qnn_indices_ptr = reinterpret_cast(qnn_bytes.data()); + axis_dim_value = static_cast(input0_shape[onnx_axis]); - std::transform(onnx_indices.begin(), onnx_indices.end(), qnn_indices_ptr, - [](int64_t index) { return SafeInt(index); }); + return Status::OK(); } // Processes the indices input to Gather operators. // -// Gather ops on the QNN CPU backend require int32 indices, so this function will either add a Cast operator -// to dynamic indices or transform static indices to int32/uint32. +// In general, QNN only supports int32/uint32 indices. QNN EP has to add Cast for dynamic int64 indices or +// convert static int64 indices to int32. // -// The HTP backend does not support int64, so this function returns an error status if dynamic indices are of -// type int64. If the indices are static, then this function will convert them to int32/uint32. +// The HTP backend only supports dynamic int64 indices if they are a graph input. static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& indices_input, - bool int32_type_is_signed, + int64_t input0_axis_dim, const logging::Logger& logger, std::vector& input_names, bool do_op_validation) { - Qnn_DataType_t desired_data_type = int32_type_is_signed ? QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32; - const auto& input_name = indices_input.node_arg.Name(); if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name; @@ -106,62 +115,58 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } - std::string indices_input_name(input_name); - Qnn_DataType_t qnn_data_type = desired_data_type; - const auto* type_proto = indices_input.node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type)); - - std::vector gather_indices; - bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); - - // Gather input 0 is quantized tensor, input 1 (indices) is int64, this is not supported by QNN - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - ORT_RETURN_IF(is_npu_backend && qnn_data_type == QNN_DATATYPE_INT_64 && !is_initializer_input, - "HTP backend doesn't support any int64 data type."); - - if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); - std::vector unpacked_tensor; - - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); - - if (qnn_data_type == QNN_DATATYPE_INT_64) { - if (desired_data_type == QNN_DATATYPE_INT_32) { - ConvertInt64IndicesBytes(unpacked_tensor, gather_indices); - } else { - ConvertInt64IndicesBytes(unpacked_tensor, gather_indices); - } + TensorInfo indices_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info)); + + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name); + ORT_RETURN_IF(is_npu_backend && + (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) && + !(indices_info.is_initializer || is_graph_input), + "HTP backend doesn't support a Gather* op with a dynamic int64 input activation ", + "unless it is a graph input."); + + std::vector qnn_indices_bytes; + + // Get raw bytes for static indices. + // If indices are int64, convert them to int32 and update indices_info.qnn_data_type. + if (indices_info.is_initializer) { + std::vector onnx_indices_bytes; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*indices_info.initializer_tensor, onnx_indices_bytes)); + + if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) { + FixStaticIndices(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes); + indices_info.qnn_data_type = QNN_DATATYPE_INT_32; + } else if (indices_info.qnn_data_type == QNN_DATATYPE_INT_32) { + FixStaticIndices(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes); } else { - gather_indices = std::move(unpacked_tensor); + qnn_indices_bytes = std::move(onnx_indices_bytes); } - qnn_data_type = desired_data_type; } Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name); - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(indices_input.node_arg, input_shape), "Cannot get shape"); - std::vector cast_output_shape(input_shape); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(), - std::move(input_shape), std::move(gather_indices)); + std::vector cast_output_shape(indices_info.shape); + QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(), + std::move(indices_info.shape), std::move(qnn_indices_bytes)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) { - // Insert cast node int64 -> int32/uint32 - if (qnn_data_type == QNN_DATATYPE_INT_64) { - // Add Cast node for indices - indices_input_name = input_name + "_ort_qnn_ep_cast"; - QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, desired_data_type, QnnQuantParamsWrapper(), - std::move(cast_output_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - "Cast", - {input_name}, - {indices_input_name}, - {}, - do_op_validation), - "Failed to add node."); - } + // Insert QNN Cast op to convert dynamic indices from int64 to int32. + std::string indices_input_name(input_name); + if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) { + assert(!indices_info.is_initializer); + + indices_input_name = input_name + "_ort_qnn_ep_cast"; + QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, + QnnQuantParamsWrapper(), std::move(cast_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Cast", + {input_name}, + {indices_input_name}, + {}, + do_op_validation), + "Failed to add node."); } input_names.push_back(indices_input_name); @@ -178,7 +183,10 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF(inputs.size() != 2, "QNN EP: Gather operator must have two inputs"); ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); - return ProcessIndicesInput(qnn_model_wrapper, inputs[1], true, logger, input_names, do_op_validation); + int64_t input0_axis_dim = 0; + ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim)); + + return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation); } Status GatherElementsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -190,19 +198,10 @@ Status GatherElementsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherElements operator must have two inputs"); ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); - return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, input_names, do_op_validation); -} + int64_t input0_axis_dim = 0; + ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim)); -Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, - std::vector& input_names, - bool do_op_validation) const { - const auto& inputs = node_unit.Inputs(); - ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherND operator must have two inputs"); - ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); - - return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, input_names, do_op_validation); + return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation); } Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -319,36 +318,11 @@ Status GatherElementsOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn logger, do_op_validation, GetQnnOpType(node_unit.OpType())); } -Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { - std::vector param_tensor_names; - NodeAttrHelper node_attr_helper(node_unit); - int32_t onnx_batch_dims = node_attr_helper.Get("batch_dims", 0); - - Qnn_Scalar_t qnn_batch_dims_scalar = QNN_SCALAR_INIT; - qnn_batch_dims_scalar.dataType = QNN_DATATYPE_UINT_32; - qnn_batch_dims_scalar.uint32Value = SafeInt(onnx_batch_dims); - - QnnParamWrapper batch_dims_param(node_unit.Index(), node_unit.Name(), QNN_OP_GATHER_ND_PARAM_BATCH_DIMS, - qnn_batch_dims_scalar); - - param_tensor_names.push_back(batch_dims_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(batch_dims_param)); - - return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), - logger, do_op_validation, GetQnnOpType(node_unit.OpType())); -} - void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_type == "Gather") { op_registrations.AddOpBuilder(op_type, std::make_unique()); } else if (op_type == "GatherElements") { op_registrations.AddOpBuilder(op_type, std::make_unique()); - } else if (op_type == "GatherND") { - op_registrations.AddOpBuilder(op_type, std::make_unique()); } } diff --git a/onnxruntime/test/providers/qnn/gather_elems_op_test.cc b/onnxruntime/test/providers/qnn/gather_elems_op_test.cc index f737b688a60b..92037291b0fb 100644 --- a/onnxruntime/test/providers/qnn/gather_elems_op_test.cc +++ b/onnxruntime/test/providers/qnn/gather_elems_op_test.cc @@ -148,10 +148,7 @@ TEST_F(QnnCPUBackendTests, GatherElems_f32_IndicesInt32) { // Test GatherElements op on CPU backend: // positive, dynamic, int32 indices. -// -// TODO: Enable when fix QNN GatherElements bug. -// Expected output: [[ [3], [3] ]], actual (incorrect) output: [[ [2], [2] ]] -TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_IndicesInt32_3D) { +TEST_F(QnnCPUBackendTests, GatherElems_f32_IndicesInt32_3D) { RunCPUGatherElemsOpTest( {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}, {1, 2, 1}, {1, 1}, false, 1, @@ -176,19 +173,9 @@ TEST_F(QnnCPUBackendTests, GatherElems_f32_StaticIndicesInt32) { ExpectedEPNodeAssignment::All); } -// TODO: Enable when QNN CPU backend supports negative indices. -// Test GatherElements op on CPU backend: -// negative, dynamic, int32 indices. -TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_NegIndicesInt32) { - RunCPUGatherElemsOpTest( - {2, 2}, {1.f, 2.f, 3.f, 4.f}, {2, 2}, {-1, 0, -1, 0}, true, 1, - ExpectedEPNodeAssignment::All); -} - -// TODO: Enable when QNN CPU backend supports negative indices. // Test GatherElements op on CPU backend: // negative, static, int32 indices. -TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_StaticNegIndicesInt32) { +TEST_F(QnnCPUBackendTests, GatherElems_f32_StaticNegIndicesInt32) { RunCPUGatherElemsOpTest( {2, 2}, {1.f, 2.f, 3.f, 4.f}, {2, 2}, {-1, 0, -1, 0}, true, 1, ExpectedEPNodeAssignment::All); @@ -238,23 +225,21 @@ TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticIndicesInt64) { ExpectedEPNodeAssignment::All); } -// TODO: Enable when QNN HTP backend supports negative indices. // Test GatherElements op on HTP backend: -// negative, positive, int32 indices. -TEST_F(QnnHTPBackendTests, DISABLED_GatherElems_u8_NegIndicesInt32) { +// negative, static, int32 indices. +TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticNegIndicesInt32) { RunHTPGatherElemsOpTest( {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, - {2, 3}, {1, 2, 0, 2, 0, 0}, false, 1, + {2, 3}, {1, 2, -3, 2, 0, 0}, true, 1, ExpectedEPNodeAssignment::All); } -// TODO: Enable when QNN HTP backend supports negative indices. // Test GatherElements op on HTP backend: -// negative, static, int32 indices. -TEST_F(QnnHTPBackendTests, DISABLED_GatherElems_u8_StaticNegIndicesInt32) { - RunHTPGatherElemsOpTest( +// negative, static, int64 indices. +TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticNegIndicesInt64) { + RunHTPGatherElemsOpTest( {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, - {2, 3}, {1, 2, 0, 2, 0, 0}, true, 1, + {2, 3}, {1, -1, -3, 2, 0, 0}, true, 1, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/gather_nd_op_test.cc b/onnxruntime/test/providers/qnn/gather_nd_op_test.cc deleted file mode 100644 index fa11b94a4369..000000000000 --- a/onnxruntime/test/providers/qnn/gather_nd_op_test.cc +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#if !defined(ORT_MINIMAL_BUILD) - -#include -#include - -#include "test/optimizer/qdq_test_utils.h" -#include "test/providers/qnn/qnn_test_utils.h" - -#include "onnx/onnx_pb.h" - -#include "gtest/gtest.h" - -namespace onnxruntime { -namespace test { - -// Creates a graph with a single GatherND operator. Used for testing CPU backend. -template -static GetTestModelFn BuildGatherNDTestCase(const std::vector& data_shape, - const std::vector& data, - const std::vector& indices_shape, - const std::vector& indices, - bool indices_are_static, - int64_t batch_dims) { - return [data_shape, data, indices_shape, indices, - indices_are_static, batch_dims](ModelTestBuilder& builder) { - auto* data_input = builder.MakeInput(data_shape, data); - auto* indices_input = (indices_are_static ? builder.MakeInitializer(indices_shape, indices) - : builder.MakeInput(indices_shape, indices)); - auto* output = builder.MakeOutput(); - - Node& gather_nd_node = builder.AddNode("GatherND", {data_input, indices_input}, {output}); - gather_nd_node.AddAttribute("batch_dims", batch_dims); - }; -} - -// Creates a graph with a single Q/DQ GatherND operator. Used for testing HTP backend. -template -static GetTestModelFn BuildQDQGatherNDTestCase(const std::vector& data_shape, - const std::vector& data, - const std::vector& indices_shape, - const std::vector& indices, - bool indices_are_static, - int64_t batch_dims) { - return [data_shape, data, indices_shape, indices, - indices_are_static, batch_dims](ModelTestBuilder& builder) { - constexpr float qdq_scale = 0.0038f; - const DataQType zero_point = (std::numeric_limits::max() - std::numeric_limits::min()) / 2; - - auto* data_input = builder.MakeInput(data_shape, data); - auto* indices_input = (indices_are_static ? builder.MakeInitializer(indices_shape, indices) - : builder.MakeInput(indices_shape, indices)); - auto* output = builder.MakeOutput(); - - // data_input -> Q -> DQ -> GatherND - auto* qdq_output = AddQDQNodePair(builder, data_input, qdq_scale, zero_point); - auto* gather_output = builder.MakeIntermediate(); - - Node& gather_nd_node = builder.AddNode("GatherND", {qdq_output, indices_input}, {gather_output}); - gather_nd_node.AddAttribute("batch_dims", batch_dims); - - // -> Q -> DQ -> output - auto* q_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(gather_output, qdq_scale, zero_point, q_output); - builder.AddDequantizeLinearNode(q_output, qdq_scale, zero_point, output); - }; -} - -// Runs an GatherND model on the QNN CPU backend. Checks the graph node assignment, and that inference -// outputs for QNN EP and CPU EP match. -template -static void RunCPUGatherNDOpTest(const std::vector& data_shape, - const std::vector& data, - const std::vector& indices_shape, - const std::vector& indices, - bool indices_are_static, - int64_t batch_dims, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { - ProviderOptions provider_options; - float fp32_abs_err = 1e-5f; // default tolerance - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnCpu.dll"; -#else - provider_options["backend_path"] = "libQnnCpu.so"; -#endif - - RunQnnModelTest(BuildGatherNDTestCase(data_shape, data, indices_shape, indices, - indices_are_static, batch_dims), - provider_options, - opset, - expected_ep_assignment, - fp32_abs_err); -} - -// Runs an GatherND model on the QNN HTP backend. Checks the graph node assignment, and that inference -// outputs for QNN EP and CPU EP match. -template -static void RunHTPGatherNDOpTest(const std::vector& data_shape, - const std::vector& data, - const std::vector& indices_shape, - const std::vector& indices, - bool indices_are_static, - int64_t batch_dims, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { - ProviderOptions provider_options; - float fp32_abs_err = 1e-5f; // default tolerance - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - RunQnnModelTest(BuildQDQGatherNDTestCase(data_shape, data, indices_shape, indices, - indices_are_static, batch_dims), - provider_options, - opset, - expected_ep_assignment, - fp32_abs_err); -} - -// -// CPU tests: -// - -// Test GatherND op on CPU backend: -// positive, dynamic indices. -// QNN EP should support by adding a Cast operator (to int32) after the indices input. -TEST_F(QnnCPUBackendTests, GatherND_f32_DynamicIndices_BatchDim0) { - RunCPUGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1, 2}, // indices_shape - {0, 1, 1, 0}, // indices - false, // indices_are_static - 0, // batch_dims - ExpectedEPNodeAssignment::All); -} - -// Test GatherND op on CPU backend: -// positive, static indices. -// QNN EP should support by converting static weights to int32_t. -TEST_F(QnnCPUBackendTests, GatherND_f32_StaticIndices_BatchDim0) { - RunCPUGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1, 2}, // indices_shape - {0, 1, 1, 0}, // indices - true, // indices_are_static - 0, // batch_dims - ExpectedEPNodeAssignment::All); -} - -// Test GatherND op on CPU backend: -// - positive, dynamic indices. -// - batch_dims = 1 -// QNN EP should support by adding a Cast operator (to int32) after the indices input. -// -// TODO: Enable when QNN fixes GatherNd with batch_dims != 0 -// QNN graph fails to finalized. -TEST_F(QnnCPUBackendTests, DISABLED_GatherND_f32_DynamicIndices_BatchDim1) { - RunCPUGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1}, // indices_shape - {1, 0}, // indices - false, // indices_are_static - 1, // batch_dims - ExpectedEPNodeAssignment::All); -} - -// Test GatherND op on CPU backend: -// - positive, static indices. -// - batch_dims = 1 -// QNN EP should support by converting static weights to int32_t. -// -// TODO: Enable when QNN fixes GatherNd with batch_dims != 0 -// QNN graph fails to finalized. -TEST_F(QnnCPUBackendTests, DISABLED_GatherND_f32_StaticIndices_BatchDim1) { - RunCPUGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1}, // indices_shape - {1, 0}, // indices - true, // indices_are_static - 1, // batch_dims - ExpectedEPNodeAssignment::All); -} - -#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// -// HTP tests: -// - -// Test GatherND op on CPU backend: -// positive, dynamic indices. -// QNN EP's HTP backend does not support int64 data types. -// Thefore, HTP does not support Dynamic int64_t indices at all. -TEST_F(QnnHTPBackendTests, GatherND_f32_DynamicIndices_BatchDim0) { - RunHTPGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1, 2}, // indices_shape - {0, 1, 1, 0}, // indices - false, // indices_are_static - 0, // batch_dims - ExpectedEPNodeAssignment::Some); // QDQ GatherND not assigned to QNN EP. -} - -// Test GatherND op on HTP backend: -// positive, static indices. -// HTP does not support int64, but QNN EP converts static int64 indices into int32. -TEST_F(QnnHTPBackendTests, GatherND_u8_StaticIndices_BatchDim0) { - RunHTPGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1, 2}, // indices_shape - {0, 1, 1, 0}, // indices - true, // indices_are_static - 0, // batch_dims - ExpectedEPNodeAssignment::All); -} - -// Test GatherND op on HTP backend: -// - positive, static indices. -// - batch_dims = 1 -// QNN EP should support by converting static weights to int32_t. -// -// TODO: Enable when QNN fixes GatherNd with batch_dims != 0 -// Expected value: [[0.2, 0.3],[0.4, 0.5]], Actual (incorrect) output: [[0.2, 0.3], [0.0, 0.1]] -TEST_F(QnnHTPBackendTests, DISABLED_GatherND_f32_StaticIndices_BatchDim1) { - RunHTPGatherNDOpTest({2, 2, 2}, // data_shape - {0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}, // data - {2, 1}, // indices_shape - {1, 0}, // indices - true, // indices_are_static - 1, // batch_dims - ExpectedEPNodeAssignment::All); -} - -#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) - -} // namespace test -} // namespace onnxruntime - -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 48cd5ad99540..e3f09e92593d 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -97,13 +97,14 @@ TEST_F(QnnHTPBackendTests, GatherOp_U16_IndicesStaticInt64_Axis0) { true); // Use 'com.microsoft' Q/DQ ops } -// Tests that dynamic int64 indices are not supported on HTP backend. +// Tests that dynamic int64 indices are supported on HTP backend if the indices are a graph input. +// QNN SDK 2.23 added support for Cast from int64 to int32. TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), {utils::MakeAttribute("axis", static_cast(0))}, 13, - ExpectedEPNodeAssignment::None); + ExpectedEPNodeAssignment::All); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all