Skip to content

Commit

Permalink
Update code to support negative static indices
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Aug 9, 2024
1 parent 0bd42bd commit 21b21dd
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 395 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
return {{"Gather", {}},
{"GatherElements", {}},
{"GatherND", {}},
{"Reshape", {}},
{"Expand", {}},
{"Flatten", {}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreateGatherOpBuilder("Gather", *this);
CreateGatherOpBuilder("GatherElements", *this);
CreateGatherOpBuilder("GatherND", *this);
}

{
Expand Down
218 changes: 96 additions & 122 deletions onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
Expand Down Expand Up @@ -51,117 +52,121 @@ class GatherElementsOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Op builder that lowers an ONNX GatherND node into its QNN counterpart.
// (Declaration only; the Process* implementations appear later in this file.)
class GatherNDOpBuilder : public BaseOpBuilder {
 public:
  GatherNDOpBuilder() : BaseOpBuilder("GatherNDOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherNDOpBuilder);

 protected:
  // Validates and registers the node's inputs (data and indices) with the QNN
  // model, appending their QNN tensor names to input_names.
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger,
                       std::vector<std::string>& input_names,
                       bool do_op_validation) const override ORT_MUST_USE_RESULT;

  // Translates node attributes (e.g., batch_dims) into QNN parameters and
  // creates the QNN op along with its output tensors.
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Converts int64 indices to another integer type (typically int32 or uint32).
// Makes negative indices positive and converts int64 indices to another integer type (typically int32 or uint32).
// The input and output are both represented as byte arrays.
template <typename T>
static void ConvertInt64IndicesBytes(const std::vector<uint8_t>& onnx_bytes, std::vector<uint8_t>& qnn_bytes) {
const size_t num_elems = onnx_bytes.size() / sizeof(uint64_t);
gsl::span<const uint64_t> onnx_indices{reinterpret_cast<const uint64_t*>(onnx_bytes.data()), num_elems};
template <typename SrcType, typename DstType>
static void FixStaticIndices(const std::vector<uint8_t>& onnx_bytes,
int64_t input0_axis_dim,
/*out*/ std::vector<uint8_t>& qnn_bytes) {
const size_t num_elems = onnx_bytes.size() / sizeof(SrcType);
gsl::span<const SrcType> onnx_indices{reinterpret_cast<const SrcType*>(onnx_bytes.data()), num_elems};

qnn_bytes.resize(num_elems * sizeof(DstType));
DstType* qnn_indices_ptr = reinterpret_cast<DstType*>(qnn_bytes.data());

std::transform(onnx_indices.begin(), onnx_indices.end(), qnn_indices_ptr, [&input0_axis_dim](SrcType index) {
if (index < 0) {
index += static_cast<SrcType>(input0_axis_dim);
}
return SafeInt<DstType>(index);
});
}

// Gets the size of input[0] along the (normalized) gather axis dimension.
//
// NOTE(review): the function name is missing the 't' in "Input0"; kept as-is
// because callers in this file reference it by this name.
//
// \param qnn_model_wrapper  Used to query the ONNX shape of input[0].
// \param node_unit          Node whose "axis" attribute is read.
// \param default_axis_value Axis value used when the node lacks an "axis" attribute.
// \param axis_dim_value     Output: dimension size of input[0] along the axis.
static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper,
                                   const NodeUnit& node_unit,
                                   int64_t default_axis_value,
                                   /*out*/ int64_t& axis_dim_value) {
  const auto& input0 = node_unit.Inputs()[0];
  std::vector<uint32_t> input0_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape),
                    "Cannot get shape for ", node_unit.OpType(), " input[0] ", input0.node_arg.Name());

  const int64_t rank = static_cast<int64_t>(input0_shape.size());
  NodeAttrHelper node_helper(node_unit);
  int64_t onnx_axis = node_helper.Get("axis", default_axis_value);
  // ONNX allows a negative axis that counts back from the rank; normalize it.
  if (onnx_axis < 0) {
    onnx_axis += rank;
  }
  ORT_RETURN_IF_NOT((onnx_axis >= 0 && onnx_axis < rank),
                    "QNN requires axis range [0, rank-1] for ", node_unit.OpType());

  axis_dim_value = static_cast<int64_t>(input0_shape[onnx_axis]);

  return Status::OK();
}

// Processes the indices input to Gather operators.
//
// Gather ops on the QNN CPU backend require int32 indices, so this function will either add a Cast operator
// to dynamic indices or transform static indices to int32/uint32.
// In general, QNN only supports int32/uint32 indices. QNN EP has to add Cast for dynamic int64 indices or
// convert static int64 indices to int32.
//
// The HTP backend does not support int64, so this function returns an error status if dynamic indices are of
// type int64. If the indices are static, then this function will convert them to int32/uint32.
// The HTP backend only supports dynamic int64 indices if they are a graph input.
static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
const NodeUnitIODef& indices_input,
bool int32_type_is_signed,
int64_t input0_axis_dim,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) {
Qnn_DataType_t desired_data_type = int32_type_is_signed ? QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32;

const auto& input_name = indices_input.node_arg.Name();
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
input_names.push_back(input_name);
return Status::OK();
}

std::string indices_input_name(input_name);
Qnn_DataType_t qnn_data_type = desired_data_type;
const auto* type_proto = indices_input.node_arg.TypeAsProto();
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));

std::vector<uint8_t> gather_indices;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);

// Gather input 0 is quantized tensor, input 1 (indices) is int64, this is not supported by QNN
bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
ORT_RETURN_IF(is_npu_backend && qnn_data_type == QNN_DATATYPE_INT_64 && !is_initializer_input,
"HTP backend doesn't support any int64 data type.");

if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
std::vector<uint8_t> unpacked_tensor;

ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));

if (qnn_data_type == QNN_DATATYPE_INT_64) {
if (desired_data_type == QNN_DATATYPE_INT_32) {
ConvertInt64IndicesBytes<int32_t>(unpacked_tensor, gather_indices);
} else {
ConvertInt64IndicesBytes<uint32_t>(unpacked_tensor, gather_indices);
}
TensorInfo indices_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));

const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
ORT_RETURN_IF(is_npu_backend &&
(indices_info.qnn_data_type == QNN_DATATYPE_INT_64) &&
!(indices_info.is_initializer || is_graph_input),
"HTP backend doesn't support a Gather* op with a dynamic int64 input activation ",
"unless it is a graph input.");

std::vector<uint8_t> qnn_indices_bytes;

// Get raw bytes for static indices.
// If indices are int64, convert them to int32 and update indices_info.qnn_data_type.
if (indices_info.is_initializer) {
std::vector<uint8_t> onnx_indices_bytes;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*indices_info.initializer_tensor, onnx_indices_bytes));

if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
FixStaticIndices<int64_t, int32_t>(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes);
indices_info.qnn_data_type = QNN_DATATYPE_INT_32;
} else if (indices_info.qnn_data_type == QNN_DATATYPE_INT_32) {
FixStaticIndices<int32_t, int32_t>(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes);
} else {
gather_indices = std::move(unpacked_tensor);
qnn_indices_bytes = std::move(onnx_indices_bytes);
}
qnn_data_type = desired_data_type;
}

Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name);
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(indices_input.node_arg, input_shape), "Cannot get shape");
std::vector<uint32_t> cast_output_shape(input_shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(),
std::move(input_shape), std::move(gather_indices));
std::vector<uint32_t> cast_output_shape(indices_info.shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape), std::move(qnn_indices_bytes));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");

if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) {
// Insert cast node int64 -> int32/uint32
if (qnn_data_type == QNN_DATATYPE_INT_64) {
// Add Cast node for indices
indices_input_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, desired_data_type, QnnQuantParamsWrapper(),
std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{indices_input_name},
{},
do_op_validation),
"Failed to add node.");
}
// Insert QNN Cast op to convert dynamic indices from int64 to int32.
std::string indices_input_name(input_name);
if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
assert(!indices_info.is_initializer);

indices_input_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(), std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{indices_input_name},
{},
do_op_validation),
"Failed to add node.");
}

input_names.push_back(indices_input_name);
Expand All @@ -178,7 +183,10 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: Gather operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], true, logger, input_names, do_op_validation);
int64_t input0_axis_dim = 0;
ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation);
}

Status GatherElementsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
Expand All @@ -190,19 +198,10 @@ Status GatherElementsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherElements operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, input_names, do_op_validation);
}
int64_t input0_axis_dim = 0;
ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim));

Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherND operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, input_names, do_op_validation);
return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation);
}

Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
Expand Down Expand Up @@ -319,36 +318,11 @@ Status GatherElementsOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

// Translates the ONNX GatherND "batch_dims" attribute into the equivalent QNN
// scalar parameter, then creates the QNN op and its outputs.
Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                      const NodeUnit& node_unit,
                                                      std::vector<std::string>&& input_names,
                                                      const logging::Logger& logger,
                                                      bool do_op_validation) const {
  NodeAttrHelper attr_helper(node_unit);
  const int32_t onnx_batch_dims = attr_helper.Get("batch_dims", 0);

  // QNN expects batch_dims as an unsigned 32-bit scalar; SafeInt guards the
  // signed-to-unsigned conversion.
  Qnn_Scalar_t batch_dims_scalar = QNN_SCALAR_INIT;
  batch_dims_scalar.dataType = QNN_DATATYPE_UINT_32;
  batch_dims_scalar.uint32Value = SafeInt<uint32_t>(onnx_batch_dims);

  QnnParamWrapper batch_dims_param(node_unit.Index(), node_unit.Name(), QNN_OP_GATHER_ND_PARAM_BATCH_DIMS,
                                   batch_dims_scalar);

  std::vector<std::string> param_tensor_names;
  param_tensor_names.push_back(batch_dims_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(batch_dims_param));

  return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names),
                        logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

// Registers the op builder that handles the given Gather-family op type.
// Unrecognized op types are ignored.
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  if (op_type == "Gather") {
    op_registrations.AddOpBuilder(op_type, std::make_unique<GatherOpBuilder>());
    return;
  }
  if (op_type == "GatherElements") {
    op_registrations.AddOpBuilder(op_type, std::make_unique<GatherElementsOpBuilder>());
    return;
  }
  if (op_type == "GatherND") {
    op_registrations.AddOpBuilder(op_type, std::make_unique<GatherNDOpBuilder>());
  }
}

Expand Down
33 changes: 9 additions & 24 deletions onnxruntime/test/providers/qnn/gather_elems_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,7 @@ TEST_F(QnnCPUBackendTests, GatherElems_f32_IndicesInt32) {

// Test GatherElements op on CPU backend:
// positive, dynamic, int32 indices.
//
// TODO: Enable when fix QNN GatherElements bug.
// Expected output: [[ [3], [3] ]], actual (incorrect) output: [[ [2], [2] ]]
TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_IndicesInt32_3D) {
TEST_F(QnnCPUBackendTests, GatherElems_f32_IndicesInt32_3D) {
RunCPUGatherElemsOpTest<float, int32_t>(
{2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f},
{1, 2, 1}, {1, 1}, false, 1,
Expand All @@ -176,19 +173,9 @@ TEST_F(QnnCPUBackendTests, GatherElems_f32_StaticIndicesInt32) {
ExpectedEPNodeAssignment::All);
}

// TODO: Enable when QNN CPU backend supports negative indices.
// Test GatherElements op on CPU backend:
// negative, dynamic, int32 indices.
// NOTE(review): the test name says "IndicesInt32" but the indices template
// argument below is int64_t — confirm whether the name or the type is the
// intended one before re-enabling this test.
TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_NegIndicesInt32) {
  RunCPUGatherElemsOpTest<float, int64_t>(
      {2, 2}, {1.f, 2.f, 3.f, 4.f}, {2, 2}, {-1, 0, -1, 0}, true, 1,
      ExpectedEPNodeAssignment::All);
}

// TODO: Enable when QNN CPU backend supports negative indices.
// Test GatherElements op on CPU backend:
// negative, static, int32 indices.
TEST_F(QnnCPUBackendTests, DISABLED_GatherElems_f32_StaticNegIndicesInt32) {
TEST_F(QnnCPUBackendTests, GatherElems_f32_StaticNegIndicesInt32) {
RunCPUGatherElemsOpTest<float, int32_t>(
{2, 2}, {1.f, 2.f, 3.f, 4.f}, {2, 2}, {-1, 0, -1, 0}, true, 1,
ExpectedEPNodeAssignment::All);
Expand Down Expand Up @@ -238,23 +225,21 @@ TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticIndicesInt64) {
ExpectedEPNodeAssignment::All);
}

// Test GatherElements op on HTP backend:
// negative, static, int32 indices.
TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticNegIndicesInt32) {
  RunHTPGatherElemsOpTest<float, uint8_t, int32_t>(
      {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f},
      {2, 3}, {1, 2, -3, 2, 0, 0}, true, 1,
      ExpectedEPNodeAssignment::All);
}

// Test GatherElements op on HTP backend:
// negative, static, int64 indices.
TEST_F(QnnHTPBackendTests, GatherElems_u8_StaticNegIndicesInt64) {
  RunHTPGatherElemsOpTest<float, uint8_t, int64_t>(
      {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f},
      {2, 3}, {1, -1, -3, 2, 0, 0}, true, 1,
      ExpectedEPNodeAssignment::All);
}

Expand Down
Loading

0 comments on commit 21b21dd

Please sign in to comment.