diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 73dd6b735775..83972bd08b41 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -215,6 +215,24 @@ def qnn_mul_pattern():
         input_is_right = gen_mul_inputs(is_constant(), wildcard())
         return input_is_left | input_is_right
 
+    def qnn_add_pattern():
+        add_op = is_op("qnn.add")
+        gen_add_inputs = lambda x, y: add_op(
+            x,
+            y,
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+        )
+        two_inputs = gen_add_inputs(wildcard(), wildcard())
+        input_is_left = gen_add_inputs(wildcard(), is_constant())
+        input_is_right = gen_add_inputs(is_constant(), wildcard())
+
+        return input_is_left | input_is_right | two_inputs
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
@@ -289,8 +307,24 @@ def check_resize(extract):
 
         return _ethosn.resize(extract)
 
+    def check_add(extract):
+        """Check if an addition is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+        # Do not support scalar constants for now
+        check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
+        if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
+            return False
+
+        inputs = extract.args[0:2]
+        if any([isinstance(i, tvm.relay.Constant) for i in inputs]):
+            extract = _ethosn.ConvertQnnAdd(extract)
+            return _ethosn.conv2d(extract)
+        return _ethosn.addition(extract)
+
     return [
         ("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
+        ("ethos-n.qnn_add", qnn_add_pattern(), check_add),
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
         ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
         ("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
@@ -332,15 +366,6 @@ def reshape(expr):
     return _ethosn.reshape(expr)
 
 
-@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
-def qnn_add(expr):
-    """Check if an addition is supported by Ethos-N."""
-    if not ethosn_available():
-        return False
-
-    return _ethosn.addition(expr)
-
-
 @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
 def qnn_concatenate(expr):
     """Check if a concatenate is supported by Ethos-N."""
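The new `check_add` predicate routes a matched `qnn.add` one of two ways: if either input is a non-scalar constant, the extract is rewritten by `ConvertQnnAdd` and validated as a depthwise convolution; otherwise it is validated as a plain addition. A minimal sketch of the kind of expression the `ethos-n.qnn_add` pattern is written to match (shapes and quantization parameters here are illustrative only, not from the patch):

```python
import numpy as np
from tvm import relay

# One input is a constant, so check_add would send this expression
# through _ethosn.ConvertQnnAdd and validate it as a depthwise conv2d.
a = relay.var("a", shape=(1, 4, 4, 8), dtype="uint8")
b = relay.const(np.ones((1, 1, 1, 8), dtype="uint8"))
add = relay.qnn.op.add(
    lhs=a,
    rhs=b,
    lhs_scale=relay.const(0.5, "float32"),
    lhs_zero_point=relay.const(0, "int32"),
    rhs_scale=relay.const(0.25, "float32"),
    rhs_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.75, "float32"),
    output_zero_point=relay.const(0, "int32"),
)
```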
"ethos-n.qnn_sigmoid")) { @@ -468,7 +468,7 @@ EthosnError ConstructNetworkVisitor::MakeReshapeLayer(const Call& call, EthosnError ConstructNetworkVisitor::MakeAdditionLayer(const Call& call, sl::TensorAndId* out) { AdditionParams params; - if (auto err = EthosnAPI::Addition(call, ¶ms)) { + if (auto err = EthosnAPI::Addition(call->op.as()->body, ¶ms)) { return err; } diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc index 6b64467047f4..12b5a12afb35 100644 --- a/src/relay/backend/contrib/ethosn/convert_equivalent.cc +++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc @@ -38,6 +38,20 @@ namespace relay { namespace contrib { namespace ethosn { +/*! + * \brief Apply constant folding on an expression. + * + * \param expr The expression to fold. + * \param fold_qnn Whether to fold constants for QNN operations. + * \returns The new folded expression. + */ +Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true) { + auto mod = IRModule::FromExpr(expr); + mod = transform::FoldConstant(fold_qnn)(mod); + auto entry_func = Downcast(mod->Lookup("main")); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + /*! * \brief Converts qnn.mul to mathematically equivalent * qnn.conv2d depthwise operation. @@ -65,7 +79,9 @@ Expr ConvertQnnMultiply(const Expr& expr) { const auto* input_constant = input2.as(); ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey(); - const auto* input_constant_tt = input_constant->checked_type().as(); + Type input_constant_type = input_constant->checked_type(); + const auto* input_constant_tt = input_constant_type.as(); + ICHECK(input_constant) << "Expected TensorTypeNode but got " << input_constant_type->GetTypeKey(); int channels = input_constant_tt->shape.back().as()->value; runtime::NDArray input_data = input_constant->data; @@ -93,6 +109,83 @@ Expr ConvertQnnMultiply(const Expr& expr) { TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnMultiply") .set_body_typed(ConvertQnnMultiply); +/*! + * \brief Converts qnn.add to a mathematically equivalent + * qnn.conv2d depthwise operation. + */ +Expr ConvertQnnAdd(const Expr& expr) { + Call call = Downcast(expr); + + Expr input1 = call->args[0]; + Expr input2 = call->args[1]; + Expr input1_scale = call->args[2]; + Expr input1_zero_point = call->args[3]; + Expr input2_scale = call->args[4]; + Expr input2_zero_point = call->args[5]; + // Reverse the inputs if the constant is first input + if (call->args[0]->IsInstance()) { + input1 = call->args[1]; + input2 = call->args[0]; + input1_scale = call->args[4]; + input1_zero_point = call->args[5]; + input2_scale = call->args[2]; + input2_zero_point = call->args[3]; + } + Expr output_scale = call->args[6]; + Expr output_zero_point = call->args[7]; + + const auto* input_constant = input2.as(); + ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey(); + Type input_constant_type = input_constant->checked_type(); + const auto* input_constant_tt = input_constant_type.as(); + ICHECK(input_constant) << "Expected TensorTypeNode but got " << input_constant_type->GetTypeKey(); + int channels = input_constant_tt->shape.back().as()->value; + + // Create the identity kernel. The kernel data is constructed such that it produces an identity + // operation in the quantized space. Therefore, the input is not scaled in any way which allows + // us to later use the bias to perform the addition. 
diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
index 6b64467047f4..12b5a12afb35 100644
--- a/src/relay/backend/contrib/ethosn/convert_equivalent.cc
+++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
@@ -38,6 +38,20 @@ namespace relay {
 namespace contrib {
 namespace ethosn {
 
+/*!
+ * \brief Apply constant folding on an expression.
+ *
+ * \param expr The expression to fold.
+ * \param fold_qnn Whether to fold constants for QNN operations.
+ * \returns The new folded expression.
+ */
+Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::FoldConstant(fold_qnn)(mod);
+  auto entry_func = Downcast<Function>(mod->Lookup("main"));
+  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+}
+
 /*!
  * \brief Converts qnn.mul to mathematically equivalent
  * qnn.conv2d depthwise operation.
@@ -65,7 +79,9 @@ Expr ConvertQnnMultiply(const Expr& expr) {
 
   const auto* input_constant = input2.as<ConstantNode>();
   ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey();
-  const auto* input_constant_tt = input_constant->checked_type().as<TensorTypeNode>();
+  Type input_constant_type = input_constant->checked_type();
+  const auto* input_constant_tt = input_constant_type.as<TensorTypeNode>();
+  ICHECK(input_constant_tt) << "Expected TensorTypeNode but got "
+                            << input_constant_type->GetTypeKey();
   int channels = input_constant_tt->shape.back().as<IntImmNode>()->value;
 
   runtime::NDArray input_data = input_constant->data;
@@ -93,6 +109,83 @@ Expr ConvertQnnMultiply(const Expr& expr) {
 TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnMultiply")
     .set_body_typed(ConvertQnnMultiply);
 
+/*!
+ * \brief Converts qnn.add to a mathematically equivalent
+ * qnn.conv2d depthwise operation.
+ */
+Expr ConvertQnnAdd(const Expr& expr) {
+  Call call = Downcast<Call>(expr);
+
+  Expr input1 = call->args[0];
+  Expr input2 = call->args[1];
+  Expr input1_scale = call->args[2];
+  Expr input1_zero_point = call->args[3];
+  Expr input2_scale = call->args[4];
+  Expr input2_zero_point = call->args[5];
+  // Reverse the inputs if the constant is the first input
+  if (call->args[0]->IsInstance<ConstantNode>()) {
+    input1 = call->args[1];
+    input2 = call->args[0];
+    input1_scale = call->args[4];
+    input1_zero_point = call->args[5];
+    input2_scale = call->args[2];
+    input2_zero_point = call->args[3];
+  }
+  Expr output_scale = call->args[6];
+  Expr output_zero_point = call->args[7];
+
+  const auto* input_constant = input2.as<ConstantNode>();
+  ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey();
+  Type input_constant_type = input_constant->checked_type();
+  const auto* input_constant_tt = input_constant_type.as<TensorTypeNode>();
+  ICHECK(input_constant_tt) << "Expected TensorTypeNode but got "
+                            << input_constant_type->GetTypeKey();
+  int channels = input_constant_tt->shape.back().as<IntImmNode>()->value;
+
+  // Create the identity kernel. The kernel data is constructed such that it produces an identity
+  // operation in the quantized space. Therefore, the input is not scaled in any way, which allows
+  // us to later use the bias to perform the addition.
+  float input_scale_value = GetScalarFromConstant<float>(input1_scale);
+  float output_scale_value = GetScalarFromConstant<float>(output_scale);
+  float identity_kernel_scale_ub = std::min(output_scale_value / input_scale_value, 1.f);
+  float identity_kernel_scale_lb = (1.f / 255.f);
+  float identity_kernel_scale_target = (identity_kernel_scale_ub + identity_kernel_scale_lb) / 2.f;
+  float identity_kernel_scale_recip_rounded = std::round(1.f / identity_kernel_scale_target);
+  float identity_kernel_scale_value = 1.f / identity_kernel_scale_recip_rounded;
+  Constant identity_kernel_scale =
+      MakeConstantScalar(DataType::Float(32), identity_kernel_scale_value);
+  Constant identity_kernel_zero_point = MakeConstantScalar(DataType::Int(32), 0);
+  float identity_kernel_quantized_data = identity_kernel_scale_recip_rounded;
+  std::vector<uint8_t> identity_kernel_data(channels,
+                                            static_cast<uint8_t>(identity_kernel_quantized_data));
+  Constant identity_kernel =
+      MakeConstantTensor(input_constant_tt->dtype, {1, 1, channels, 1}, identity_kernel_data);
+
+  // Calculate the bias; this is where the addition happens. The bias values are calculated by
+  // scaling the constant input to input_scale * identity_kernel_scale.
+  Constant bias_scale =
+      MakeConstantScalar(DataType::Float(32), input_scale_value * identity_kernel_scale_value);
+  Constant bias_zero_point = MakeConstantScalar(DataType::Int(32), 0);
+  Expr requantize_bias =
+      qnn::MakeRequantize(input2, input2_scale, input2_zero_point, bias_scale, bias_zero_point,
+                          -1, "None", "None", DataType::Int(32));
+  Expr reshape_bias = MakeReshape(requantize_bias, {channels});
+  Constant bias = Downcast<Constant>(FoldConstantExpr(reshape_bias));
+
+  // Make the depthwise conv2d operation
+  Expr conv2d =
+      qnn::MakeQnnConv2D(input1, identity_kernel, input1_zero_point, identity_kernel_zero_point,
+                         input1_scale, identity_kernel_scale, {1, 1}, {0, 0, 0, 0}, {1, 1},
+                         channels, channels, {1, 1}, "NHWC", "HWOI", "NHWC", DataType::Int(32));
+  Expr bias_add = MakeBiasAdd(conv2d, bias, 3);
+  Expr requantize =
+      qnn::MakeRequantize(bias_add, input1_scale, input1_zero_point, output_scale,
+                          output_zero_point, -1, "None", "None", input_constant_tt->dtype);
+
+  return InferType(requantize);
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnAdd").set_body_typed(ConvertQnnAdd);
+
 class ConvertEquivalentsMutator : public MixedModeMutator {
  public:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -108,11 +201,25 @@ class ConvertEquivalentsMutator : public MixedModeMutator {
       Expr new_func_body = ConvertQnnMultiply(func->body);
       new_func = WithFields(func, func->params, new_func_body);
       new_func = WithAttr(std::move(new_func), attr::kComposite, String("ethos-n.qnn_conv2d"));
+    } else if (composite_name == "ethos-n.qnn_add" && CheckCanConvertAdd(func->body)) {
+      Expr new_func_body = ConvertQnnAdd(func->body);
+      new_func = WithFields(func, func->params, new_func_body);
+      new_func = WithAttr(std::move(new_func), attr::kComposite, String("ethos-n.qnn_conv2d"));
     }
 
     Call new_call = WithFields(call, new_func);
     return Downcast<Call>(new_call);
   }
+
+ private:
+  /*!
+   * \brief Check whether add can be converted to depthwise, or whether
+   * it should be offloaded as a normal add operation.
+   */
+  bool CheckCanConvertAdd(const Expr& expr) {
+    Call call = Downcast<Call>(expr);
+    return call->args[0]->IsInstance<ConstantNode>() || call->args[1]->IsInstance<ConstantNode>();
+  }
 };
 
 tvm::transform::Pass ConvertEquivalents() {
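The identity-kernel construction deserves a worked example. The kernel scale is chosen so that its reciprocal is an integer; every kernel element then holds that reciprocal, making the convolution an identity in the quantized space and leaving the addition entirely to the bias. A numeric instance with illustrative scales (values chosen for this sketch, not taken from the patch):

```python
import numpy as np

input_scale, output_scale = 0.5, 0.25  # illustrative values only

ub = min(output_scale / input_scale, 1.0)  # upper bound: 0.5
lb = 1.0 / 255.0                           # lower bound: ~0.0039
target = (ub + lb) / 2.0                   # ~0.2520
recip = np.round(1.0 / target)             # 4.0
kernel_scale = 1.0 / recip                 # 0.25

# Each quantized multiply is value * recip * kernel_scale == value,
# i.e. an identity, so the bias alone performs the addition.
assert recip * kernel_scale == 1.0
```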
diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py
index c227ef5c3aea..a1c8ca0a32d2 100644
--- a/tests/python/contrib/test_ethosn/infrastructure.py
+++ b/tests/python/contrib/test_ethosn/infrastructure.py
@@ -83,7 +83,8 @@ def make_module(func, params):
 
 def make_ethosn_composite(ethosn_expr, name):
     vars = relay.analysis.free_vars(ethosn_expr)
-    func = relay.Function([relay.Var("a")], ethosn_expr)
+    inner_vars = [relay.Var(v.name_hint, v.type_annotation) for v in vars]
+    func = relay.Function(inner_vars, ethosn_expr)
     func = func.with_attr("Composite", name)
     call = relay.Call(func, vars)
     return call
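This infrastructure fix matters for multi-input composites: the helper previously bound the body to a single untyped `relay.Var("a")`, so a two-input `qnn.add` composite could not be constructed with matching parameter types. A hedged usage sketch (the shapes, the `relay.add` body, and the composite name are illustrative; the relative import assumes the snippet runs inside the test package, as the tests do):

```python
import tvm
from tvm import relay
from . import infrastructure as tei  # assumed in-package import

a = relay.var("a", shape=(1, 4, 4, 8), dtype="uint8")
b = relay.var("b", shape=(1, 4, 4, 8), dtype="uint8")
call = tei.make_ethosn_composite(relay.add(a, b), "ethos-n.example")

# Each inner parameter now mirrors the name and type of a free var,
# rather than every argument binding to one untyped var.
for param, arg in zip(call.op.params, call.args):
    assert param.name_hint == arg.name_hint
    assert tvm.ir.structural_equal(param.type_annotation, arg.type_annotation)
```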
diff --git a/tests/python/contrib/test_ethosn/test_addition.py b/tests/python/contrib/test_ethosn/test_addition.py
index cc8e030d372d..72981182e17f 100644
--- a/tests/python/contrib/test_ethosn/test_addition.py
+++ b/tests/python/contrib/test_ethosn/test_addition.py
@@ -25,11 +25,37 @@
 from . import infrastructure as tei
 
 
-def _get_model(input_shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype):
+def _get_model(
+    lhs_shape,
+    rhs_shape,
+    lhs_zp,
+    lhs_sc,
+    rhs_zp,
+    rhs_sc,
+    out_zp,
+    out_sc,
+    dtype,
+    lhs_is_constant=False,
+    rhs_is_constant=False,
+):
     """Return a model and any parameters it may have"""
-    a = relay.var("a", shape=input_shape, dtype=dtype)
-    b = relay.var("b", shape=input_shape, dtype=dtype)
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+
+    if lhs_is_constant:
+        a_data = np.random.randint(data_min, data_max + 1, size=lhs_shape, dtype=dtype)
+        a = relay.const(a_data, dtype=dtype)
+    else:
+        a = relay.var("a", shape=lhs_shape, dtype=dtype)
+
+    if rhs_is_constant:
+        b_data = np.random.randint(data_min, data_max + 1, size=rhs_shape, dtype=dtype)
+        b = relay.const(b_data, dtype=dtype)
+    else:
+        b = relay.var("b", shape=rhs_shape, dtype=dtype)
+
     model = relay.qnn.op.add(
         lhs=a,
         rhs=b,
@@ -43,74 +69,156 @@ def _get_model(input_shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtyp
     return model
 
 
-def _get_addition_qnn_params(dtype, input1_zp, input1_sc, input2_zp, input2_sc):
-    input1_max = input1_sc * (255 - input1_zp)
-    input1_min = -input1_sc * input1_zp
-    input2_max = input2_sc * (255 - input2_zp)
-    input2_min = -input2_sc * input2_zp
+def _get_addition_qnn_params(dtype):
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    lhs_zp = np.random.randint(data_min, data_max)
+    lhs_sc = np.random.random() * 2
+    rhs_zp = np.random.randint(data_min, data_max)
+    rhs_sc = np.random.random() * 2
+
+    input1_max = lhs_sc * (255 - lhs_zp)
+    input1_min = -lhs_sc * lhs_zp
+    input2_max = rhs_sc * (255 - rhs_zp)
+    input2_min = -rhs_sc * rhs_zp
     output_max = input1_max + input2_max
     output_min = input1_min + input2_min
     output_sc = (output_max - output_min) / 255
     output_zp = -int(output_min / output_sc)
-    return output_zp, output_sc
+    return lhs_zp, lhs_sc, rhs_zp, rhs_sc, output_zp, output_sc
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize("shape", [(1, 22, 9, 9), (1, 27, 21, 16)])
+def test_addition(dtype, shape):
+    """Compare Addition output with TVM."""
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype)
+
+    outputs = []
+    inputs = {
+        "a": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype)),
+        "b": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype)),
+    }
+    model = _get_model(shape, shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype)
+    for npu in [False, True]:
+        mod = tei.make_module(model, [])
+        outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize(
+    "lhs_shape,rhs_shape",
+    [
+        ((1, 4, 4, 8), (1, 1, 1, 8)),
+        ((1, 16, 12, 4), (4,)),
+    ],
+)
+def test_addition_to_depthwise_rhs_constant(dtype, lhs_shape, rhs_shape):
+    """Compare addition to depthwise with TVM."""
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype)
+
+    model = _get_model(
+        lhs_shape,
+        rhs_shape,
+        lhs_zp,
+        lhs_sc,
+        rhs_zp,
+        rhs_sc,
+        out_zp,
+        out_sc,
+        dtype,
+        lhs_is_constant=False,
+        rhs_is_constant=True,
+    )
+    inputs = {
+        "a": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=lhs_shape, dtype=dtype))
+    }
+    outputs = []
+    for npu in [False, True]:
+        mod = tei.make_module(model, {})
+        outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
 
 
 @requires_ethosn
 @pytest.mark.parametrize("dtype", ["uint8", "int8"])
-def test_addition(dtype):
-    zp_min = np.iinfo(dtype).min
-    zp_max = np.iinfo(dtype).max
-    trials = [
-        ((1, 22, 9, 9), zp_min + 24, 1.057, zp_max - 3, 0.452),
-        ((1, 27, 21, 16), zp_min + 79, 0.850, 24, 0.380),
-        ((1, 7, 12, 28), zp_min + 125, 1.293, zp_max - 16, 0.320),
-        ((1, 14, 9, 6), zp_min + 14, 0.942, zp_max - 28, 1.562),
-        ((1, 13, 16, 22), zp_min + 15, 0.727, zp_max - 75, 0.461),
-    ]
+@pytest.mark.parametrize(
+    "lhs_shape,rhs_shape",
+    [
+        ((1, 8), (1, 20, 15, 8)),
+    ],
+)
+def test_addition_to_depthwise_lhs_constant(dtype, lhs_shape, rhs_shape):
+    """Compare addition to depthwise with TVM."""
     np.random.seed(0)
-    for shape, rhs_zp, rhs_sc, lhs_zp, lhs_sc in trials:
-        outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.randint(zp_min, zp_max + 1, size=shape, dtype=dtype)),
-            "b": tvm.nd.array(np.random.randint(zp_min, zp_max + 1, size=shape, dtype=dtype)),
-        }
-        out_zp, out_sc = _get_addition_qnn_params(dtype, lhs_zp, lhs_sc, rhs_zp, rhs_sc)
-        model = _get_model(shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype)
-        for npu in [False, True]:
-            mod = tei.make_module(model, [])
-            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
 
-        tei.verify(outputs, dtype, 2)
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype)
+
+    model = _get_model(
+        lhs_shape,
+        rhs_shape,
+        lhs_zp,
+        lhs_sc,
+        rhs_zp,
+        rhs_sc,
+        out_zp,
+        out_sc,
+        dtype,
+        lhs_is_constant=True,
+        rhs_is_constant=False,
+    )
+    inputs = {
+        "b": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=rhs_shape, dtype=dtype))
+    }
+    outputs = []
+    for npu in [False, True]:
+        mod = tei.make_module(model, {})
+        outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
 
 
 @requires_ethosn
-def test_addition_failure():
-    trials = [
+@pytest.mark.parametrize(
+    "dtype,shape,err_msg",
+    [
         (
-            (2, 4, 4, 4),
             "uint8",
-            0,
-            1,
-            0,
-            1,
-            0,
-            1,
+            (2, 4, 4, 4),
             "batch size=2, batch size must = 1; batch size=2, batch size must = 1",
         ),
         (
-            (1, 4, 4, 4),
             "int16",
-            0,
-            1,
-            0,
-            1,
-            0,
-            1,
-            "dtype='int16', dtype must be either uint8, int8 or int32; dtype='int16', dtype must be either uint8, int8 or int32",
+            (1, 4, 4, 4),
+            "dtype='int16', dtype must be either uint8, int8 or int32; dtype='int16', "
+            "dtype must be either uint8, int8 or int32",
         ),
-    ]
+    ],
+)
+def test_addition_failure(dtype, shape, err_msg):
+    """Check addition error messages."""
+    np.random.seed(0)
+
+    lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype)
 
-    for shape, dtype, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, err_msg in trials:
-        model = _get_model(shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype)
-        mod = tei.make_ethosn_partition(model)
-        tei.test_error(mod, {}, err_msg)
+    model = _get_model(shape, shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype)
+    model = tei.make_ethosn_composite(model, "ethos-n.qnn_add")
+    mod = tei.make_ethosn_partition(model)
+    tei.test_error(mod, {}, err_msg)
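`_get_addition_qnn_params` now draws the input quantization parameters randomly and derives output parameters wide enough to cover the sum of both input ranges. A worked instance with illustrative numbers (note the helper keeps the 255-step range even for int8, as in the code above):

```python
lhs_zp, lhs_sc = 10, 0.5  # lhs real range: [-5.0, 122.5]
rhs_zp, rhs_sc = 20, 1.0  # rhs real range: [-20.0, 235.0]

output_max = lhs_sc * (255 - lhs_zp) + rhs_sc * (255 - rhs_zp)  # 357.5
output_min = -lhs_sc * lhs_zp + -rhs_sc * rhs_zp                # -25.0
output_sc = (output_max - output_min) / 255                     # 1.5
output_zp = -int(output_min / output_sc)                        # 16
```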
diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
index 570009422067..fe9b346691b6 100644
--- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py
+++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -24,8 +24,10 @@
 from tvm import relay
 from tvm.testing import requires_ethosn
 from tvm.relay.op.contrib.ethosn import ConvertEquivalents
+from tvm.relay import ExprVisitor
 
 from . import infrastructure as tei
+from .test_addition import _get_addition_qnn_params
 
 
 def _assert_structural_equal(a, b):
@@ -38,35 +40,6 @@ def _assert_structural_equal(a, b):
     assert tvm.ir.structural_equal(a, b), reason
 
 
-def _create_npu_module(inputs, expr, composite_name, ext_func_name):
-    """Wraps an operator as an NPU module."""
-    gen_vars = lambda prefix, vars: [
-        relay.var(
-            prefix + var.name_hint, shape=var.type_annotation.shape, dtype=var.type_annotation.dtype
-        )
-        for var in vars
-    ]
-
-    mod = tvm.ir.IRModule()
-
-    func = relay.Function(relay.analysis.free_vars(expr), expr)
-    func = func.with_attr("Composite", composite_name)
-    inner_vars = gen_vars("inner_", inputs)
-    call = relay.Call(func, inner_vars)
-
-    func2 = relay.Function(relay.analysis.free_vars(call), call)
-    func2 = func2.with_attr("Compiler", "ethos-n")
-    func2 = func2.with_attr("global_symbol", ext_func_name)
-    mod[ext_func_name] = func2
-    mod = relay.transform.InferType()(mod)
-
-    outer_vars = gen_vars("outer_", inputs)
-    out = relay.Call(mod.get_global_var(ext_func_name), outer_vars)
-    mod["main"] = relay.Function(relay.analysis.free_vars(out), out)
-    mod = relay.transform.InferType()(mod)
-    return mod
-
-
 @requires_ethosn
 @pytest.mark.parametrize("dtype", ["uint8", "int8"])
 @pytest.mark.parametrize("shape,channels", [((1, 4, 4, 8), 8), ((1, 16, 12, 4), 4)])
@@ -101,7 +74,8 @@ def before():
             relay.const(output_sc, "float32"),
             relay.const(output_zp, "int32"),
         )
-        return _create_npu_module([x], expr, "ethos-n.qnn_mul", "ext_func")
+        composite = tei.make_ethosn_composite(expr, "ethos-n.qnn_mul")
+        return tei.make_ethosn_partition(composite)
 
     def expected():
         constant_shape_hwoi = (1, 1, channels, 1)
@@ -134,9 +108,70 @@ def expected():
             relay.const(output_zp, "int32"),
             out_dtype=dtype,
         )
-        return _create_npu_module([x], expr, "ethos-n.qnn_conv2d", "ext_func")
+        composite = tei.make_ethosn_composite(expr, "ethos-n.qnn_conv2d")
+        return tei.make_ethosn_partition(composite)
 
     mod = before()
     mod = ConvertEquivalents()(mod)
     expected_mod = expected()
-    _assert_structural_equal(mod["ext_func"], expected_mod["ext_func"])
+    _assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])
+
+
+@requires_ethosn
+@pytest.mark.parametrize("reverse_inputs", [True, False])
+def test_add_to_depthwise(reverse_inputs):
+    """
+    Check that add is converted correctly.
+    """
+    dtype = "uint8"
+    lhs_shape = (1, 2, 4, 8)
+    rhs_shape = (1, 1, 1, 8)
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc = _get_addition_qnn_params(dtype)
+
+    x = relay.var("x", shape=lhs_shape, dtype=dtype)
+    y_data = np.random.randint(data_min, data_max + 1, size=rhs_shape, dtype=dtype)
+
+    def before():
+        y = relay.const(y_data)
+        expr = relay.qnn.op.add(
+            lhs=y if reverse_inputs else x,
+            rhs=x if reverse_inputs else y,
+            lhs_scale=relay.const(lhs_sc, "float32"),
+            lhs_zero_point=relay.const(lhs_zp, "int32"),
+            rhs_scale=relay.const(rhs_sc, "float32"),
+            rhs_zero_point=relay.const(rhs_zp, "int32"),
+            output_scale=relay.const(out_sc, "float32"),
+            output_zero_point=relay.const(out_zp, "int32"),
+        )
+        composite = tei.make_ethosn_composite(expr, "ethos-n.qnn_add")
+        return tei.make_ethosn_partition(composite)
+
+    class ConversionChecker(ExprVisitor):
+        """
+        Pass to check the new composite function is in the expected format.
+        """
+
+        sequence = ["qnn.conv2d", "nn.bias_add", "qnn.requantize"]
+
+        def visit_function(self, fn):
+            composite_name = fn.attrs["Composite"]
+            expected = "ethos-n.qnn_conv2d"
+            assert (
+                composite_name == expected
+            ), f"Expected Composite attribute {expected} but got {composite_name}"
+            super().visit_function(fn)
+
+        def visit_call(self, call):
+            op_name = call.op.name
+            expected_name = self.sequence.pop()
+            assert op_name == expected_name, f"Got operator {op_name} but expected {expected_name}"
+            super().visit_call(call)
+
+    mod = before()
+    mod = ConvertEquivalents()(mod)
+    mod = ConversionChecker().visit(mod["ethos-n_0"].body.op)
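A note on `ConversionChecker`: `ExprVisitor` reaches the outermost call first and then recurses into its arguments, so the operators are observed as requantize, then bias_add, then conv2d. That is why the checker consumes `sequence` with `pop()` from the end. A minimal list-only illustration of that ordering:

```python
sequence = ["qnn.conv2d", "nn.bias_add", "qnn.requantize"]
visit_order = ["qnn.requantize", "nn.bias_add", "qnn.conv2d"]  # outermost first
for op_name in visit_order:
    assert op_name == sequence.pop()
```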
diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py
index abc4d37a7359..d16bf5bf325c 100644
--- a/tests/python/contrib/test_ethosn/test_networks.py
+++ b/tests/python/contrib/test_ethosn/test_networks.py
@@ -143,7 +143,7 @@ def test_resnet_50_int8():
     # codegen, which could come about from either a change in Support Library
     # version or a change in the Ethos-N codegen. To update this requires running
     # on hardware that isn't available in CI.
-    _compile_hash = {"60404ad60fc2bfbb68464d8a14cc0452", "4225fa951c145bb1e48e28cad6a3bdd4"}
+    _compile_hash = {"9245965b2c01e7f3d9b478e38a186eb4", "4225fa951c145bb1e48e28cad6a3bdd4"}
     _test_image_network(
         model_url="https://raw.githubusercontent.com/dmlc/web-data/main/tensorflow/"
         "models/Quantized/resnet_50_quantized.tflite",