From 52779f1273b05d53d8213e23e70d9b0ac82fd0b9 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Tue, 23 Aug 2022 10:00:34 +0100 Subject: [PATCH] [CMSIS-NN] Pad fusion with QNN Conv2D (#12353) Pass that fuses nn.pad and qnn.conv2d for CMSIS-NN target. --- python/tvm/relay/op/contrib/cmsisnn.py | 50 ++- .../backend/contrib/cmsisnn/fuse_pads.cc | 209 +++++++++++ .../contrib/test_cmsisnn/test_conv2d.py | 277 ++++++++++++-- .../contrib/test_cmsisnn/test_fuse_pads.py | 340 ++++++++++++++++++ tests/python/contrib/test_cmsisnn/utils.py | 45 ++- 5 files changed, 886 insertions(+), 35 deletions(-) create mode 100644 src/relay/backend/contrib/cmsisnn/fuse_pads.cc create mode 100644 tests/python/contrib/test_cmsisnn/test_fuse_pads.py diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 8d714b7269d9..b887fafd7e00 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -59,6 +59,7 @@ def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts): transform.AnnotateTarget("cmsis-nn"), transform.PartitionGraph(mod_name=mod_name), GenerateCMSISNNConstants(), + CMSISNNFusePads(), ScalarToTensorConstants(), ExtractConstantsFromPartitionedFunction(), transform.InferType(), @@ -91,10 +92,18 @@ def check_qnn_softmax(pattern): and dequantize_call.args[0].checked_type.dtype == "int8" ) - def qnn_conv2d_pattern(): - """Create pattern for qnn.conv2D with optional fused relu.""" + def qnn_conv2d_pattern(with_pad): + """Create pattern for qnn.conv2D with optional pad and/or optional fused relu.""" + conv2d_input = wildcard() + if with_pad: + conv2d_input = is_op("nn.pad")(wildcard(), is_constant()) qnn_conv2d = is_op("qnn.conv2d")( - wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + conv2d_input, + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), ) bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) req = is_op("qnn.requantize")( @@ -136,7 +145,7 @@ def check_qnn_conv2d(pattern): ): is_depthwise = True - return ( + ret = ( conv2d.attrs.out_dtype == "int32" and conv2d_input.checked_type.dtype == "int8" and conv2d_weight.checked_type.dtype == "int8" @@ -145,6 +154,36 @@ def check_qnn_conv2d(pattern): and all([zp == 0 for zp in kernel_zp]) and (not is_depthwise or bias_add is not None) ) + return ret + + def check_qnn_conv2d_pad(pattern): + """Check if the Pad followed by Conv2D is supported by CMSIS-NN.""" + if str(pattern.op.name) == "clip": + relu = pattern + requantize = relu.args[0] + else: + requantize = pattern + requantize_input = requantize.args[0] + if str(requantize_input.op.name) == "nn.bias_add": + bias_add = requantize_input + conv2d = bias_add.args[0] + else: + conv2d = requantize_input + conv2d_input = conv2d.args[0] + + # check if sum of paddings from pad() and conv2d() satisfies CMSIS-NN constraints + can_pad_be_fused = True + if isinstance(conv2d_input, tvm.relay.expr.Call) and str(conv2d_input.op.name) == "nn.pad": + pad_top, pad_left, pad_bottom, pad_right = GetEffectiveConv2DPadding( + conv2d, conv2d_input + ) + # check if difference in the side paddings is 1 along each dimension + pad_w_diff = int(pad_right - pad_left) + pad_h_diff = int(pad_bottom - pad_top) + can_pad_be_fused = pad_w_diff in [0, 1] and pad_h_diff in [0, 1] + + ret = check_qnn_conv2d(pattern) and can_pad_be_fused + return ret def qnn_fully_connected_pattern(): """Create pattern for qnn.dense with 
optional Relu.""" @@ -275,7 +314,8 @@ def check_qnn_binary_op(pattern): ) return [ - ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d), + ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad), + ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=False), check_qnn_conv2d), ("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected), ("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d), ("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d), diff --git a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc new file mode 100644 index 000000000000..71c31c303588 --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relay/backend/contrib/cmsisnn/fuse_pads.cc + * \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions. + */ + +#include +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" +#include "convolutions.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +inline IntImm ToIntImm(int32_t value) { return IntImm(DataType::Int(32), value); } + +/*! + * \brief From padding attributes of nn.pad and qnn.conv2d, calculates effective padding along H + * and W dimensions. + */ +Array GetEffectiveConv2DPadding(Expr conv2d, Expr pad) { + // pad_width: ((), (top, bottom), (left, right), ()) for NHWC layout + // conv2d_attrs->padding: (top, left, bottom, right) + auto* conv2d_call = conv2d.as(); + auto* conv2d_attrs = conv2d_call->attrs.as(); + std::string data_layout = conv2d_attrs->data_layout.c_str(); + int pos_h = data_layout.find("H"); + int pos_w = data_layout.find("W"); + + auto* pad_call = pad.as(); + Array> pad_width = pad_call->attrs.as()->pad_width; + int pad_top = + qnn::get_const_int(conv2d_attrs->padding[0]) + qnn::get_const_int(pad_width[pos_h][0]); + int pad_left = + qnn::get_const_int(conv2d_attrs->padding[1]) + qnn::get_const_int(pad_width[pos_w][0]); + int pad_bottom = + qnn::get_const_int(conv2d_attrs->padding[2]) + qnn::get_const_int(pad_width[pos_h][1]); + int pad_right = + qnn::get_const_int(conv2d_attrs->padding[3]) + qnn::get_const_int(pad_width[pos_w][1]); + + return {ToIntImm(pad_top), ToIntImm(pad_left), ToIntImm(pad_bottom), ToIntImm(pad_right)}; +} + +/*! + * \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D. + * Then, it will fuse preceding pads with qnn.conv2d. 
+ */ +class FusePadsMutator : public MixedModeMutator { + public: + explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {} + + private: + /*! + * \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto + * convolution layer to update Conv2DAttrs's padding attribute. */ + void UpdateConv2DPadding(const CallNode* conv2d_call, const CallNode* pad_call, + Attrs* new_attrs) { + Array effective_padding = + GetEffectiveConv2DPadding(GetRef(conv2d_call), GetRef(pad_call)); + int pad_top = effective_padding[0]->value; + int pad_left = effective_padding[1]->value; + int pad_bottom = effective_padding[2]->value; + int pad_right = effective_padding[3]->value; + int pad_diff_w = pad_right - pad_left; + int pad_diff_h = pad_bottom - pad_top; + bool can_pad_be_fused = + ((pad_diff_w == 0 || pad_diff_w == 1) && (pad_diff_h == 0 || pad_diff_h == 1)); + std::string error = "Difference on each side of a dimension should be either 0 or 1. "; + error += "Effective padding in this case: (pad_top, pad_left, pad_bottom, pad_right)=("; + error += std::to_string(pad_top); + error += ", "; + error += std::to_string(pad_left); + error += ", "; + error += std::to_string(pad_bottom); + error += ", "; + error += std::to_string(pad_right); + error += ")"; + ICHECK(can_pad_be_fused) << error; + + // Prepare new attrs as padding has changed + auto* conv2d_attrs = conv2d_call->attrs.as(); + auto attrs = make_object(); + attrs->strides = std::move(conv2d_attrs->strides); + attrs->dilation = std::move(conv2d_attrs->dilation); + attrs->groups = conv2d_attrs->groups; + attrs->channels = std::move(conv2d_attrs->channels); + attrs->kernel_size = std::move(conv2d_attrs->kernel_size); + attrs->data_layout = std::move(conv2d_attrs->data_layout); + attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout); + attrs->out_layout = std::move(conv2d_attrs->out_layout); + attrs->out_dtype = std::move(conv2d_attrs->out_dtype); + attrs->padding = {pad_top, pad_left, pad_bottom, pad_right}; + *new_attrs = tvm::Attrs{attrs}; + } + + /*! + * \brief Identifies the sequence for qnn.conv2D and fuses the preceding nn.pad present within the + * CMSIS-NN partitioned function. 
*/ + Expr FusePadConv2d(const CallNode* conv2d_call) { + // create new paddings for qnn.conv2d + tvm::Attrs new_conv2d_attrs = conv2d_call->attrs; + Expr new_conv2d_input = conv2d_call->args[0]; + if (auto* pad_call = conv2d_call->args[0].as()) { + if (auto* pad_call_op = pad_call->op.as()) { + if (pad_call_op->name == "nn.pad") { + new_conv2d_input = pad_call->args[0]; + UpdateConv2DPadding(conv2d_call, pad_call, &new_conv2d_attrs); + } + } + } + + // Conv2D arguments: pad's input + rest of the origin args + auto new_conv2d_args = conv2d_call->args; + new_conv2d_args.erase(new_conv2d_args.begin()); + new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input); + Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}); + return std::move(ret_call); + } + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + Expr ret_call = post; + auto* post_call = post.as(); + + // Fuse nn.pad and qnn.conv2d + if (auto* conv2d_op = post_call->op.as()) { + if (conv2d_op->name == "qnn.conv2d") { + ret_call = FusePadConv2d(post_call); + } + } + + // Identify qnn.conv2d partitioned function + if (post_call->op.as()) { + auto* func = call->op.as(); + auto func_name = func->GetAttr(attr::kComposite); + if (func_name.defined() && func_name == "cmsis-nn.qnn_conv2d") { + Expr new_body = VisitExpr(func->body); + Function new_func = Function(FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); + ret_call = Call(new_func, post_call->args); + } + } + + return ret_call; + } + + private: + IRModule mod_; +}; + +IRModule FusePads(const IRModule& mod) { + for (auto gv : mod->GetGlobalVars()) { + Function func = Downcast(mod->Lookup(gv)); + + // only mutate CMSIS-NN partitioned functions + auto compiler_name = func->GetAttr(attr::kCompiler); + if (!compiler_name.defined() || compiler_name != "cmsis-nn") { + continue; + } + + auto fuse_pads_mutator = FusePadsMutator(mod); + auto new_func_body = fuse_pads_mutator.VisitExpr(func->body); + if (!new_func_body.same_as(func->body)) { + Function new_func = + Function(func->params, new_func_body, func->ret_type, func->type_params, func->attrs); + mod->Update(gv, new_func); + } + } + return mod; +} + +transform::Pass CMSISNNFusePads() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, transform::PassContext pc) { return FusePads(m); }; + return tvm::transform::CreateModulePass(pass_func, 0, "CMSISNNFusePads", {}); +} + +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.CMSISNNFusePads").set_body_typed(CMSISNNFusePads); +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GetEffectiveConv2DPadding") + .set_body_typed(GetEffectiveConv2DPadding); + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 502743387bfa..d33d71261613 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -40,6 +40,7 @@ assert_partitioned_function, assert_no_external_function, create_test_runner, + CheckForPadsWithinCompositeFunc, ) @@ -62,23 +63,21 @@ def make_model( weight_format, enable_bias, relu_type, + input_op=None, ): """Return a model and any parameters it may have""" + if input_op: + op = input_op + else: + op = relay.var("input", shape=shape, dtype=dtype) + h_index = weight_format.index("H") w_index = weight_format.index("W") kernel_h = kernel_shape[h_index] kernel_w = kernel_shape[w_index] - invar = 
relay.var("input", shape=shape, dtype=dtype) p = (0, 0, 0, 0) if padding == "SAME": p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) - invar = relay.nn.pad( - invar, - pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], - pad_value=input_zero_point, - pad_mode="constant", - ) - shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) rng = np.random.default_rng(12321) weight = tvm.nd.array( @@ -92,7 +91,7 @@ def make_model( weight_const = relay.const(weight, kernel_dtype) conv2d_kernel_sc = kernel_scale[0] if out_channels == 1 else kernel_scale conv = relay.qnn.op.conv2d( - invar, + op, weight_const, input_zero_point=relay.const(input_zero_point, "int32"), kernel_zero_point=relay.const(kernel_zero_point, "int32"), @@ -165,9 +164,9 @@ def test_conv2d_number_primfunc_args( input_zero_point, kernel_scale, kernel_zero_point, - dtype, - dtype, - dtype, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, ) model, params = make_model( @@ -265,9 +264,9 @@ def test_conv2d_symmetric_padding_int8( input_zero_point, kernel_scale, kernel_zero_point, - dtype, - dtype, - dtype, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, ) model, params = make_model( @@ -355,9 +354,110 @@ def test_conv2d_asymmetric_padding_int8( input_zero_point, kernel_scale, kernel_zero_point, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, dtype, dtype, - dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod) + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("ifm_shape", [(1, 25, 25, 12), (1, 64, 100, 4)]) +@pytest.mark.parametrize( + "pad_width", + [ + ((0, 0), (0, 1), (1, 2), (0, 0)), + ((0, 0), (1, 1), (1, 1), (0, 0)), + ((0, 0), (2, 2), (3, 4), (0, 0)), + ], +) +def test_pad_conv2d_fusion_int8( + ifm_shape, + pad_width, +): + """Tests QNN Conv2D where the padding is asymmetric on different sides of input""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + ifm_shape = (1, 25, 25, 12) + kernel_size = (5, 5) + strides = (2, 2) + dilation = (1, 1) + padding = "SAME" + dtype = "int8" + enable_bias = True + relu_type = "NONE" + input_zero_point = 10 + input_scale = 0.0128 + kernel_scale = [0.11, 0.22] + out_channels = 2 + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + input_dtype=dtype, + weights_dtype=dtype, + 
output_dtype=dtype, + ) + + invar = relay.var("input", shape=ifm_shape, dtype=dtype) + pad = relay.nn.pad( + invar, + pad_width=pad_width, # ((), (top, bottom), (left, right), ()) + pad_value=input_zero_point, + pad_mode="constant", ) model, params = make_model( @@ -379,12 +479,139 @@ def test_conv2d_asymmetric_padding_int8( weight_format, enable_bias, relu_type, + input_op=pad, ) orig_mod = make_module(model) cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod, False) + + # check pad is not present inside CMSIS-NN partitioned function + cmsisnn_func = None + for var in cmsisnn_mod.get_global_vars(): + if "cmsis_nn_main_0" in var.name_hint: + cmsisnn_func = cmsisnn_mod[var] + pad_verifier = CheckForPadsWithinCompositeFunc() + pad_verifier.visit_function(cmsisnn_func) + pad_verifier.assert_no_pads_within_func() + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize( + "ifm_shape, pad_width, conv2d_padding", + [ + [(1, 25, 25, 12), ((0, 0), (0, 2), (1, 2), (0, 0)), "SAME"], + [(1, 64, 100, 4), ((0, 0), (1, 3), (1, 1), (0, 0)), "VALID"], + [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 5), (0, 0)), "SAME"], + ], +) +def test_invalid_pad_conv2d_fusion_int8( + ifm_shape, + pad_width, + conv2d_padding, +): + """Tests QNN Conv2D where the padding is asymmetric on different sides of input""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + ifm_shape = (1, 25, 25, 12) + kernel_size = (5, 5) + strides = (2, 2) + dilation = (1, 1) + dtype = "int8" + enable_bias = True + relu_type = "NONE" + input_zero_point = 10 + input_scale = 0.0128 + kernel_scale = [0.11, 0.22] + out_channels = 2 + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, + ) + + invar = relay.var("input", shape=ifm_shape, dtype=dtype) + pad = relay.nn.pad( + invar, + pad_width=pad_width, # ((), (top, bottom), (left, right), ()) + pad_value=input_zero_point, + pad_mode="constant", + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + conv2d_padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + input_op=pad, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + # validate pattern matching assert_partitioned_function(orig_mod, cmsisnn_mod) + # check pad is only present inside main function + cmsisnn_func = None + for var in cmsisnn_mod.get_global_vars(): + if "cmsis_nn_main_0" in var.name_hint: + cmsisnn_func = cmsisnn_mod[var] + pad_verifier = 
CheckForPadsWithinCompositeFunc() + pad_verifier.visit_function(cmsisnn_func) + pad_verifier.assert_no_pads_within_func() + else: + main_func = cmsisnn_mod[var] + pad_verifier = CheckForPadsWithinCompositeFunc() + pad_verifier.visit_function(main_func) + pad_verifier.assert_pads_within_func() + # validate the output rng = np.random.default_rng(12345) inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} @@ -506,10 +733,10 @@ def test_depthwise_int8( input_zero_point, kernel_scale, kernel_zero_point, - dtype, - dtype, - dtype, - True, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, + is_depthwise=True, ) model, params = make_model( @@ -611,10 +838,10 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( input_zero_point, kernel_scale, kernel_zero_point, - dtype, - dtype, - dtype, - True, + input_dtype=dtype, + weights_dtype=dtype, + output_dtype=dtype, + is_depthwise=True, ) model, params = make_model( @@ -729,7 +956,7 @@ def test_invalid_parameters( in_dtype, kernel_dtype, in_dtype, - False, + is_depthwise=False, ) model, params = make_model( shape=ifm_shape, diff --git a/tests/python/contrib/test_cmsisnn/test_fuse_pads.py b/tests/python/contrib/test_cmsisnn/test_fuse_pads.py new file mode 100644 index 000000000000..f57dc5cd5bab --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_fuse_pads.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""CMSIS-NN integration tests: fuse_pads pass""" +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import relay +from .utils import CheckForPadsWithinCompositeFunc + +tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__) + + +def set_external_func_attr(func, compiler, ext_symbol): + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", compiler) + func = func.with_attr("global_symbol", ext_symbol) + return func + + +def set_composite_func_attr(func, name): + func = func.with_attr("Composite", name) + return func + + +@pytest.mark.parametrize( + "ifm_shape, pad_width, conv2d_padding, ofm_shape", + [ + [(1, 25, 25, 12), ((0, 0), (0, 2), (1, 2), (0, 0)), (1, 1, 1, 1), (1, 26, 28, 2)], + [(1, 64, 100, 4), ((0, 0), (1, 3), (1, 1), (0, 0)), (0, 0, 0, 0), (1, 64, 100, 2)], + [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 5), (0, 0)), (0, 0, 1, 1), (1, 57, 59, 2)], + ], +) +def test_invalid_padding_for_fusion(ifm_shape, pad_width, conv2d_padding, ofm_shape): + """Negative tests for pads preceding Conv2D that cannot be fused.""" + dtype = "int8" + kernel_size = (3, 3) + ofm_channels = 2 + local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype) + pad = relay.nn.pad( + local_input, + pad_width=pad_width, # ((), (top, bottom), (left, right), ()) + pad_value=10, + pad_mode="constant", + ) + rng = np.random.default_rng(12321) + local_weight = tvm.nd.array( + rng.integers( + np.iinfo(dtype).min, + high=np.iinfo(dtype).max, + size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), + dtype=dtype, + ) + ) + local_weight = relay.const(local_weight, dtype) + conv2d = relay.qnn.op.conv2d( + pad, + local_weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + data_layout="NHWC", + kernel_layout="OHWI", + channels=ofm_channels, + kernel_size=(3, 3), + padding=conv2d_padding, + out_dtype="int32", + ) + requantize = relay.qnn.op.requantize( + conv2d, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=0, + out_dtype=dtype, + ) + local_func = relay.Function(relay.analysis.free_vars(requantize), requantize) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d") + + mod = tvm.IRModule() + ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype) + call_local_func = relay.Call(local_func, [ext_input]) + extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func) + extern_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint) + mod[extern_var] = extern_func + + main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype) + call_extern_func = relay.Call(extern_var, [main_input]) + main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype)) + main_var = relay.GlobalVar("main") + mod[main_var] = main_func + + mod = relay.transform.InferType()(mod) + + error_regex = r"Difference on each side of a dimension should be either 0 or 1" + + with pytest.raises(tvm.TVMError, match=error_regex): + mod = CMSISNNFusePads()(mod) + + +@pytest.mark.parametrize( + "ifm_shape, pad_width, conv2d_padding, ofm_shape", + [ + [(1, 25, 25, 12), ((0, 0), (0, 1), (1, 2), (0, 0)), (1, 1, 1, 1), (1, 26, 28, 2)], + [(1, 64, 100, 4), ((0, 0), (1, 1), (1, 1), (0, 0)), (0, 0, 0, 0), (1, 64, 100, 2)], + [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 2), (0, 0)), (0, 0, 
1, 1), (1, 57, 59, 2)], + ], +) +def test_pad_conv2d_fusion_noncmsisnn_target(ifm_shape, pad_width, conv2d_padding, ofm_shape): + """Tests the pads and conv2d fusion for non-cmsisnn targets. + It is expected that pad will not be fused with Conv2D in this case. + """ + dtype = "int8" + kernel_size = (3, 3) + ofm_channels = 2 + local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype) + pad = relay.nn.pad( + local_input, + pad_width=pad_width, # ((), (top, bottom), (left, right), ()) + pad_value=10, + pad_mode="constant", + ) + rng = np.random.default_rng(12321) + local_weight = tvm.nd.array( + rng.integers( + np.iinfo(dtype).min, + high=np.iinfo(dtype).max, + size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), + dtype=dtype, + ) + ) + local_weight = relay.const(local_weight, dtype) + conv2d = relay.qnn.op.conv2d( + pad, + local_weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + data_layout="NHWC", + kernel_layout="OHWI", + channels=ofm_channels, + kernel_size=(3, 3), + padding=conv2d_padding, + out_dtype="int32", + ) + requantize = relay.qnn.op.requantize( + conv2d, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=0, + out_dtype=dtype, + ) + local_func = relay.Function(relay.analysis.free_vars(requantize), requantize) + local_func = set_composite_func_attr(local_func, "noncmsis-nn.qnn_conv2d") + + mod = tvm.IRModule() + ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype) + call_local_func = relay.Call(local_func, [ext_input]) + extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func) + extern_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "noncmsis-nn", extern_var.name_hint) + mod[extern_var] = extern_func + + main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype) + call_extern_func = relay.Call(extern_var, [main_input]) + main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype)) + main_var = relay.GlobalVar("main") + mod[main_var] = main_func + + mod = relay.transform.InferType()(mod) + + mod = CMSISNNFusePads()(mod) + pad_verifier = CheckForPadsWithinCompositeFunc() + pad_verifier.visit_function(mod[extern_var]) + pad_verifier.assert_pads_within_func() + + +@pytest.mark.parametrize( + "ifm_shape, pad_width, conv2d_padding, ofm_shape", + [ + [(1, 25, 25, 12), ((0, 0), (0, 1), (1, 2), (0, 0)), (1, 1, 1, 1), (1, 26, 28, 2)], + [(1, 64, 100, 4), ((0, 0), (1, 1), (1, 1), (0, 0)), (0, 0, 0, 0), (1, 64, 100, 2)], + [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 2), (0, 0)), (0, 0, 1, 1), (1, 57, 59, 2)], + ], +) +def test_pad_conv2d_fusion(ifm_shape, pad_width, conv2d_padding, ofm_shape): + """Tests the pads and conv2d fusion.""" + dtype = "int8" + kernel_size = (3, 3) + ofm_channels = 2 + local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype) + pad = relay.nn.pad( + local_input, + pad_width=pad_width, # ((), (top, bottom), (left, right), ()) + pad_value=10, + pad_mode="constant", + ) + rng = np.random.default_rng(12321) + local_weight = tvm.nd.array( + rng.integers( + np.iinfo(dtype).min, + high=np.iinfo(dtype).max, + size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]), + dtype=dtype, + ) + ) + local_weight = relay.const(local_weight, dtype) + conv2d = relay.qnn.op.conv2d( + pad, + local_weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, 
"float32"), + relay.const(1, "float32"), + data_layout="NHWC", + kernel_layout="OHWI", + channels=ofm_channels, + kernel_size=(3, 3), + padding=conv2d_padding, + out_dtype="int32", + ) + requantize = relay.qnn.op.requantize( + conv2d, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=0, + out_dtype=dtype, + ) + local_func = relay.Function(relay.analysis.free_vars(requantize), requantize) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d") + + mod = tvm.IRModule() + ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype) + call_local_func = relay.Call(local_func, [ext_input]) + extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func) + extern_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint) + mod[extern_var] = extern_func + + main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype) + call_extern_func = relay.Call(extern_var, [main_input]) + main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype)) + main_var = relay.GlobalVar("main") + mod[main_var] = main_func + + mod = relay.transform.InferType()(mod) + + mod = CMSISNNFusePads()(mod) + pad_verifier = CheckForPadsWithinCompositeFunc() + pad_verifier.visit_function(mod[extern_var]) + pad_verifier.assert_no_pads_within_func() + + +def test_without_preceding_pad(): + """Tests the pass FusePads when padding is not present before qnn.conv2d.""" + dtype = "int8" + ifm_shape = (1, 56, 56, 64) + ofm_shape = (1, 56, 56, 64) + local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype) + rng = np.random.default_rng(12321) + local_weight = tvm.nd.array( + rng.integers( + np.iinfo(dtype).min, + high=np.iinfo(dtype).max, + size=(64, 3, 3, 64), + dtype=dtype, + ) + ) + local_weight = relay.const(local_weight, dtype) + conv2d = relay.qnn.op.conv2d( + local_input, + local_weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + data_layout="NHWC", + kernel_layout="OHWI", + channels=64, + kernel_size=(3, 3), + padding=(1, 1, 1, 1), + out_dtype="int32", + ) + requantize = relay.qnn.op.requantize( + conv2d, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=0, + out_dtype=dtype, + ) + relu = relay.nn.relu(requantize) + local_func = relay.Function(relay.analysis.free_vars(relu), relu) + local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d") + + mod = tvm.IRModule() + ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype) + call_local_func = relay.Call(local_func, [ext_input]) + extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func) + extern_var = relay.GlobalVar("external_function") + extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint) + mod[extern_var] = extern_func + + main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype) + call_extern_func = relay.Call(extern_var, [main_input]) + main_func = relay.Function(relay.analysis.free_vars(call_extern_func), call_extern_func) + main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype)) + main_var = relay.GlobalVar("main") + mod[main_var] = main_func + + mod = relay.transform.InferType()(mod) + + mod = CMSISNNFusePads()(mod) + pad_verifier = CheckForPadsWithinCompositeFunc() + 
pad_verifier.visit_function(mod[extern_var]) + pad_verifier.assert_no_pads_within_func() diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index d36ec4219a0e..9fdb89289aff 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -50,8 +50,19 @@ def visit_call(self, call): return counter.count -def assert_partitioned_function(orig_mod, cmsisnn_mod): - """If kCompiler attribute is missing, this function raises assertion""" +def assert_partitioned_function(orig_mod, cmsisnn_mod, expected_ops_unchanged=True): + """ + if KCompiler attribute is missing, this function raises an assertion. + + Parameters + ---------- + orig_mod : IRModule + Pre-partitioning module + cmsisnn_mod : IRModule + Post-partitioning module + is_num_calls_same: bool + Are number of CallNode(s) before and after partitioning expected to be the same + """ attrs = [ cmsisnn_mod[var.name_hint].attrs for var in cmsisnn_mod.get_global_vars() @@ -64,9 +75,10 @@ def assert_partitioned_function(orig_mod, cmsisnn_mod): ] assert any(compilers), "Module does not contain function for cmsisnn target." - assert count_num_calls(orig_mod) == count_num_calls( - cmsisnn_mod - ), "Number of calls changed during partitioning" + if expected_ops_unchanged: + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" def assert_no_external_function(mod): @@ -228,6 +240,29 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): raise ValueError("Invalid argument provided with fused_activation_fn") +class CheckForPadsWithinCompositeFunc(tvm.relay.ExprVisitor): + """Provides method to test number of pads present inside the function being visited.""" + + def __init__(self): + super().__init__() + self.num_pads_ = 0 + + def visit_call(self, call): + super().visit_call(call) + if ( + isinstance(call, tvm.relay.Call) + and isinstance(call.op, tvm.ir.op.Op) + and call.op.name == "nn.pad" + ): + self.num_pads_ += 1 + + def assert_no_pads_within_func(self): + assert self.num_pads_ == 0, "CMSIS-NN composite function should not have pads." + + def assert_pads_within_func(self): + assert self.num_pads_ > 0, "Composite function should have pads within it." + + def create_test_runner(compiler_cpu="cortex-m55", cpu_flags=""): """ Creates AOT test runner for CMSIS-NN tests.
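
The core of this patch is the padding arithmetic in GetEffectiveConv2DPadding and the matching fusability check in check_qnn_conv2d_pad: the pad_width of the preceding nn.pad is summed with qnn.conv2d's own padding attribute, and the pad may be folded into the convolution only if, per spatial dimension, the bottom/right side exceeds the top/left side by at most one element. The following is a minimal standalone Python sketch of that decision, not part of the patch itself; the helper names are hypothetical and the layout handling is simplified to strings containing "H" and "W".

# Illustrative sketch (not part of the patch): combine nn.pad's pad_width with
# qnn.conv2d's padding attribute and decide whether the pad can be fused away.

def effective_conv2d_padding(pad_width, conv2d_padding, data_layout="NHWC"):
    """Sum the explicit nn.pad widths with the conv2d padding attribute.

    pad_width      : ((0, 0), (top, bottom), (left, right), (0, 0)) for NHWC
    conv2d_padding : (top, left, bottom, right) as stored in Conv2DAttrs
    """
    pos_h = data_layout.index("H")
    pos_w = data_layout.index("W")
    pad_top = conv2d_padding[0] + pad_width[pos_h][0]
    pad_left = conv2d_padding[1] + pad_width[pos_w][0]
    pad_bottom = conv2d_padding[2] + pad_width[pos_h][1]
    pad_right = conv2d_padding[3] + pad_width[pos_w][1]
    return pad_top, pad_left, pad_bottom, pad_right


def can_fuse_pad(pad_width, conv2d_padding, data_layout="NHWC"):
    """Fusable only if the right/bottom side exceeds the left/top side by at
    most one element along each spatial dimension."""
    top, left, bottom, right = effective_conv2d_padding(
        pad_width, conv2d_padding, data_layout
    )
    return (right - left) in (0, 1) and (bottom - top) in (0, 1)


# Example mirroring the first negative test case above: the effective padding
# becomes (top, left, bottom, right) = (1, 2, 3, 3), so bottom - top = 2 and
# the pass is expected to reject the fusion.
assert not can_fuse_pad(
    pad_width=((0, 0), (0, 2), (1, 2), (0, 0)),
    conv2d_padding=(1, 1, 1, 1),
)

The 0-or-1 asymmetry limit matches the padding produced by SAME-style convolutions, which is the shape of asymmetry the CMSIS-NN convolution wrappers are written to handle; pads outside that envelope are left in place (they simply stay outside the CMSIS-NN partitioned function), as exercised by test_invalid_pad_conv2d_fusion_int8 and test_invalid_padding_for_fusion.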