[CMSIS-NN] Pad fusion with QNN Conv2D (#12353)
Pass that fuses nn.pad and qnn.conv2d for CMSIS-NN target.
ashutosh-arm authored Aug 23, 2022
1 parent 383bd41 commit 52779f1
Showing 5 changed files with 886 additions and 35 deletions.
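The new pass is inserted into the partition_for_cmsisnn pipeline (first file below), so it runs automatically during partitioning. A hedged usage sketch, where mod and params are assumed inputs (for example, from an int8 TFLite import) and TVM is built with CMSIS-NN support:

from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn

# `mod` and `params` are placeholders for an int8-quantized Relay model that
# contains nn.pad -> qnn.conv2d sequences.
partitioned_mod = partition_for_cmsisnn(mod, params)
# Afterwards, within each cmsis-nn.qnn_conv2d composite, CMSISNNFusePads has
# folded any legally fusable nn.pad into the qnn.conv2d padding attribute.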
50 changes: 45 additions & 5 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -59,6 +59,7 @@ def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
             transform.AnnotateTarget("cmsis-nn"),
             transform.PartitionGraph(mod_name=mod_name),
             GenerateCMSISNNConstants(),
+            CMSISNNFusePads(),
             ScalarToTensorConstants(),
             ExtractConstantsFromPartitionedFunction(),
             transform.InferType(),
@@ -91,10 +92,18 @@ def check_qnn_softmax(pattern):
             and dequantize_call.args[0].checked_type.dtype == "int8"
         )

-    def qnn_conv2d_pattern():
-        """Create pattern for qnn.conv2D with optional fused relu."""
+    def qnn_conv2d_pattern(with_pad):
+        """Create pattern for qnn.conv2D with optional pad and/or optional fused relu."""
+        conv2d_input = wildcard()
+        if with_pad:
+            conv2d_input = is_op("nn.pad")(wildcard(), is_constant())
         qnn_conv2d = is_op("qnn.conv2d")(
-            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+            conv2d_input,
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
         )
         bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
         req = is_op("qnn.requantize")(
@@ -136,7 +145,7 @@ def check_qnn_conv2d(pattern):
         ):
             is_depthwise = True

-        return (
+        ret = (
             conv2d.attrs.out_dtype == "int32"
             and conv2d_input.checked_type.dtype == "int8"
             and conv2d_weight.checked_type.dtype == "int8"
@@ -145,6 +154,36 @@ def check_qnn_conv2d(pattern):
             and all([zp == 0 for zp in kernel_zp])
             and (not is_depthwise or bias_add is not None)
         )
+        return ret
+
+    def check_qnn_conv2d_pad(pattern):
+        """Check if the Pad followed by Conv2D is supported by CMSIS-NN."""
+        if str(pattern.op.name) == "clip":
+            relu = pattern
+            requantize = relu.args[0]
+        else:
+            requantize = pattern
+        requantize_input = requantize.args[0]
+        if str(requantize_input.op.name) == "nn.bias_add":
+            bias_add = requantize_input
+            conv2d = bias_add.args[0]
+        else:
+            conv2d = requantize_input
+        conv2d_input = conv2d.args[0]
+
+        # check if sum of paddings from pad() and conv2d() satisfies CMSIS-NN constraints
+        can_pad_be_fused = True
+        if isinstance(conv2d_input, tvm.relay.expr.Call) and str(conv2d_input.op.name) == "nn.pad":
+            pad_top, pad_left, pad_bottom, pad_right = GetEffectiveConv2DPadding(
+                conv2d, conv2d_input
+            )
+            # check if the difference in the side paddings is either 0 or 1 along each dimension
+            pad_w_diff = int(pad_right - pad_left)
+            pad_h_diff = int(pad_bottom - pad_top)
+            can_pad_be_fused = pad_w_diff in [0, 1] and pad_h_diff in [0, 1]
+
+        ret = check_qnn_conv2d(pattern) and can_pad_be_fused
+        return ret
+
     def qnn_fully_connected_pattern():
         """Create pattern for qnn.dense with optional Relu."""
@@ -275,7 +314,8 @@ def check_qnn_binary_op(pattern):
         )

     return [
-        ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
+        ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad),
+        ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=False), check_qnn_conv2d),
         ("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
         ("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d),
         ("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d),
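For illustration only (not part of the commit), the constraint that check_qnn_conv2d_pad enforces reduces to simple per-side arithmetic. The sketch below is plain Python with hypothetical helper names; it assumes NHWC layout, conv2d padding ordered (top, left, bottom, right), and pad_width ordered ((0, 0), (top, bottom), (left, right), (0, 0)):

def effective_conv2d_padding(conv2d_padding, pad_width):
    # Per-side paddings of nn.pad and qnn.conv2d simply add up.
    top = conv2d_padding[0] + pad_width[1][0]
    left = conv2d_padding[1] + pad_width[2][0]
    bottom = conv2d_padding[2] + pad_width[1][1]
    right = conv2d_padding[3] + pad_width[2][1]
    return top, left, bottom, right

def can_fuse_pad(conv2d_padding, pad_width):
    top, left, bottom, right = effective_conv2d_padding(conv2d_padding, pad_width)
    # CMSIS-NN accepts the fused padding only if, per dimension, the far side
    # (bottom/right) exceeds the near side (top/left) by exactly 0 or 1.
    return (right - left) in (0, 1) and (bottom - top) in (0, 1)

# A symmetric SAME-style pad fuses; a lopsided pad (2 rows on top only) does not.
assert can_fuse_pad((0, 0, 0, 0), ((0, 0), (1, 1), (1, 1), (0, 0)))
assert not can_fuse_pad((0, 0, 0, 0), ((0, 0), (2, 0), (0, 0), (0, 0)))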
209 changes: 209 additions & 0 deletions src/relay/backend/contrib/cmsisnn/fuse_pads.cc
@@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/backend/contrib/cmsisnn/fuse_pads.cc
* \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions.
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../op/make_op.h"
#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"
#include "convolutions.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

inline IntImm ToIntImm(int32_t value) { return IntImm(DataType::Int(32), value); }

/*!
* \brief From padding attributes of nn.pad and qnn.conv2d, calculates effective padding along H
* and W dimensions.
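 * For example, with NHWC input, an nn.pad with pad_width ((0, 0), (1, 1), (1, 1), (0, 0))
 * on top of a conv2d padding of (0, 0, 0, 0) gives an effective padding of (1, 1, 1, 1).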
*/
Array<IntImm> GetEffectiveConv2DPadding(Expr conv2d, Expr pad) {
// pad_width: ((), (top, bottom), (left, right), ()) for NHWC layout
// conv2d_attrs->padding: (top, left, bottom, right)
auto* conv2d_call = conv2d.as<CallNode>();
auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
std::string data_layout = conv2d_attrs->data_layout.c_str();
int pos_h = data_layout.find("H");
int pos_w = data_layout.find("W");

auto* pad_call = pad.as<CallNode>();
Array<Array<Integer>> pad_width = pad_call->attrs.as<PadAttrs>()->pad_width;
int pad_top =
qnn::get_const_int(conv2d_attrs->padding[0]) + qnn::get_const_int(pad_width[pos_h][0]);
int pad_left =
qnn::get_const_int(conv2d_attrs->padding[1]) + qnn::get_const_int(pad_width[pos_w][0]);
int pad_bottom =
qnn::get_const_int(conv2d_attrs->padding[2]) + qnn::get_const_int(pad_width[pos_h][1]);
int pad_right =
qnn::get_const_int(conv2d_attrs->padding[3]) + qnn::get_const_int(pad_width[pos_w][1]);

return {ToIntImm(pad_top), ToIntImm(pad_left), ToIntImm(pad_bottom), ToIntImm(pad_right)};
}

/*!
* \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
* Then, it will fuse preceding pads with qnn.conv2d.
*/
class FusePadsMutator : public MixedModeMutator {
public:
explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}

private:
/*!
* \brief To eliminate the preceding nn.pad op, its pad_width is folded into the padding
* attribute (Conv2DAttrs) of the convolution. */
void UpdateConv2DPadding(const CallNode* conv2d_call, const CallNode* pad_call,
Attrs* new_attrs) {
Array<IntImm> effective_padding =
GetEffectiveConv2DPadding(GetRef<Call>(conv2d_call), GetRef<Call>(pad_call));
int pad_top = effective_padding[0]->value;
int pad_left = effective_padding[1]->value;
int pad_bottom = effective_padding[2]->value;
int pad_right = effective_padding[3]->value;
int pad_diff_w = pad_right - pad_left;
int pad_diff_h = pad_bottom - pad_top;
bool can_pad_be_fused =
((pad_diff_w == 0 || pad_diff_w == 1) && (pad_diff_h == 0 || pad_diff_h == 1));
std::string error = "Difference on each side of a dimension should be either 0 or 1. ";
error += "Effective padding in this case: (pad_top, pad_left, pad_bottom, pad_right)=(";
error += std::to_string(pad_top);
error += ", ";
error += std::to_string(pad_left);
error += ", ";
error += std::to_string(pad_bottom);
error += ", ";
error += std::to_string(pad_right);
error += ")";
ICHECK(can_pad_be_fused) << error;

// Prepare new attrs as padding has changed
auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(conv2d_attrs->strides);
attrs->dilation = std::move(conv2d_attrs->dilation);
attrs->groups = conv2d_attrs->groups;
attrs->channels = std::move(conv2d_attrs->channels);
attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
attrs->data_layout = std::move(conv2d_attrs->data_layout);
attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout);
attrs->out_layout = std::move(conv2d_attrs->out_layout);
attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
attrs->padding = {pad_top, pad_left, pad_bottom, pad_right};
*new_attrs = tvm::Attrs{attrs};
}

/*!
* \brief Identifies the sequence for qnn.conv2D and fuses the preceding nn.pad present within the
* CMSIS-NN partitioned function. */
Expr FusePadConv2d(const CallNode* conv2d_call) {
// create new paddings for qnn.conv2d
tvm::Attrs new_conv2d_attrs = conv2d_call->attrs;
Expr new_conv2d_input = conv2d_call->args[0];
if (auto* pad_call = conv2d_call->args[0].as<CallNode>()) {
if (auto* pad_call_op = pad_call->op.as<OpNode>()) {
if (pad_call_op->name == "nn.pad") {
new_conv2d_input = pad_call->args[0];
UpdateConv2DPadding(conv2d_call, pad_call, &new_conv2d_attrs);
}
}
}

// Conv2D arguments: pad's input + rest of the original args
auto new_conv2d_args = conv2d_call->args;
new_conv2d_args.erase(new_conv2d_args.begin());
new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input);
Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {});
return std::move(ret_call);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr ret_call = post;
auto* post_call = post.as<CallNode>();

// Fuse nn.pad and qnn.conv2d
if (auto* conv2d_op = post_call->op.as<OpNode>()) {
if (conv2d_op->name == "qnn.conv2d") {
ret_call = FusePadConv2d(post_call);
}
}

// Identify qnn.conv2d partitioned function
if (post_call->op.as<FunctionNode>()) {
auto* func = call->op.as<FunctionNode>();
auto func_name = func->GetAttr<String>(attr::kComposite);
if (func_name.defined() && func_name == "cmsis-nn.qnn_conv2d") {
Expr new_body = VisitExpr(func->body);
Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
ret_call = Call(new_func, post_call->args);
}
}

return ret_call;
}

private:
IRModule mod_;
};

IRModule FusePads(const IRModule& mod) {
for (auto gv : mod->GetGlobalVars()) {
Function func = Downcast<Function>(mod->Lookup(gv));

// only mutate CMSIS-NN partitioned functions
auto compiler_name = func->GetAttr<String>(attr::kCompiler);
if (!compiler_name.defined() || compiler_name != "cmsis-nn") {
continue;
}

auto fuse_pads_mutator = FusePadsMutator(mod);
auto new_func_body = fuse_pads_mutator.VisitExpr(func->body);
if (!new_func_body.same_as(func->body)) {
Function new_func =
Function(func->params, new_func_body, func->ret_type, func->type_params, func->attrs);
mod->Update(gv, new_func);
}
}
return mod;
}

transform::Pass CMSISNNFusePads() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return FusePads(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "CMSISNNFusePads", {});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.CMSISNNFusePads").set_body_typed(CMSISNNFusePads);
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GetEffectiveConv2DPadding")
.set_body_typed(GetEffectiveConv2DPadding);

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
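Because the pass is exposed through TVM_REGISTER_GLOBAL above, it can also be fetched and applied standalone from Python. A hedged sketch, where partitioned_mod is an assumed IRModule that has already been partitioned for CMSIS-NN:

import tvm

# Fetch the pass via the global registered above; calling the packed function
# with no arguments returns the module-level pass object.
fuse_pads = tvm.get_global_func("relay.ext.cmsisnn.transform.CMSISNNFusePads")()
# The pass only rewrites functions whose Compiler attribute is "cmsis-nn".
new_mod = fuse_pads(partitioned_mod)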