Skip to content

Commit

Permalink
fix avgpool ir (#188)
Browse files Browse the repository at this point in the history
* fix avgpool ir

* add ir template

* rename temp to mediate in ir

Co-authored-by: yuqing <yuqxia@microsoft.com>
  • Loading branch information
xiayuqing0622 and xiayuqing0622 authored Dec 22, 2020
1 parent 7fe3c7c commit 61418e9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

REGISTER_OP(AvgPool)
.infershape(nnfusion::op::infershape::unimplemented_and_not_used)
/*
.translate([](std::shared_ptr<graph::GNode> curr) -> std::string {
auto _op = static_pointer_cast<nnfusion::op::AvgPool>(curr->get_op_ptr());
NNFUSION_CHECK_NOT_NULLPTR(_op) << "Node type is not " << curr->get_op_ptr()->get_op_type();
Expand All @@ -24,24 +25,27 @@ REGISTER_OP(AvgPool)
{"stride", vector_to_string(stride)},
{"padding", vector_to_string(padding)}});
})
*/
.translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {

auto _op = static_pointer_cast<nnfusion::op::AvgPool>(curr->get_op_ptr());
NNFUSION_CHECK_NOT_NULLPTR(_op) << "Node type is not " << curr->get_op_ptr()->get_op_type();

auto ir_template =
R"( temp0@output0_layout@ +=! @input0@@input0_layout@@conditions@; @output0@@output0_layout@ = temp0@output0_layout@ * @div@;)";

auto ir_template1 =
R"( mediate0@output0_layout@ +=! @input0@@input0_layout@ @where_condition@; @output0@@output0_layout@ = mediate0@output0_layout@ / @div@;)";
auto ir_template2 =
R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@when_condition@ / ((HO * @stride_h@ + @KH_top@ - @pad_h@).call('min', [@H_top@]) - (HO * @stride_h@ - @pad_h@).call('max', [0])) / ((WO * @stride_w@ + @KW_top@ - @pad_w@).call('min', [@W_top@]) - (WO * @stride_w@ - @pad_w@).call('max', [0])) @where_condition@;)";
const auto& input0_shape = curr->get_input_shape(0);
const auto& output0_shape = curr->get_output_shape(0);
const auto& kernel = _op->get_window_shape();
const auto& stride = _op->get_window_movement_strides();
const auto& padding_below = _op->get_padding_below();
const auto& padding_above = _op->get_padding_above();

std::string input0_layout;
std::string output0_layout;
std::string conditions;
std::string when_condition;
std::string when_condition_template;
std::string where_condition;

NNFUSION_CHECK(input0_shape.size() == output0_shape.size());
Expand All @@ -64,16 +68,18 @@ REGISTER_OP(AvgPool)
if (padding_below[0] > 0)
{
H_in += " - " + to_string(padding_below[0]);
when_condition += (when_condition.empty() ? "" : " , ") + H_in + " >= 0, " + H_in +
" < " + to_string(input0_shape[input0_shape.size() - 2]);
when_condition_template += (when_condition_template.empty() ? "" : " , ") + H_in +
" >= 0, " + H_in + " < " +
to_string(input0_shape[input0_shape.size() - 2]);
}

std::string W_in = "WO * " + to_string(stride[1]) + " + KW";
if (padding_below[1] > 0)
{
W_in += " - " + to_string(padding_below[1]);
when_condition += (when_condition.empty() ? "" : " , ") + W_in + " >= 0, " + W_in +
" < " + to_string(input0_shape[input0_shape.size() - 1]);
when_condition_template += (when_condition_template.empty() ? "" : " , ") + W_in +
" >= 0, " + W_in + " < " +
to_string(input0_shape[input0_shape.size() - 1]);
}

input0_layout += H_in + " , " + W_in + "]";
Expand All @@ -83,18 +89,38 @@ REGISTER_OP(AvgPool)
", " + "KH in " + to_string(kernel[0]) + ", " + "KW in " +
to_string(kernel[0]);

if (!when_condition.empty())
if (!when_condition_template.empty())
{
when_condition = ".when([" + when_condition + "], 0.0)";
when_condition_template = ".when([" + when_condition_template +
"], const(0.0).cast(@input0@@input0_layout@.dtype()))";
}

conditions = when_condition + " " + where_condition;

op::OpConfig::any op_config;
op_config["input0_layout"] = input0_layout;
op_config["output0_layout"] = output0_layout;
op_config["conditions"] = conditions;
op_config["div"] = 1.0 / (kernel[0] * kernel[1]);

return op::create_code_from_template(ir_template, op_config);
op_config["div"] = kernel[0] * kernel[1];
op_config["stride_h"] = stride[0];
op_config["stride_w"] = stride[1];
op_config["H_top"] = input0_shape[input0_shape.size() - 2];
op_config["W_top"] = input0_shape[input0_shape.size() - 1];
op_config["KH_top"] = kernel[0];
op_config["KW_top"] = kernel[1];
op_config["pad_h"] = padding_below[0];
op_config["pad_w"] = padding_below[1];
op_config["where_condition"] = where_condition;
op_config["when_condition"] =
op::create_code_from_template(when_condition_template, op_config);

if (padding_below[0] == 0 && padding_below[1] == 0 && padding_above[0] == 0 &&
padding_above[1] == 0)
{
return op::create_code_from_template(ir_template1, op_config);
}
else
{
// For ir_template2, divide operation goes before add operation, which may
// cause precision issue.
// return op::create_code_from_template(ir_template2, op_config);
return "";
}
});
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ REGISTER_OP(DepthToSpace)
})
.translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
auto expression_template =
R"( temp0@mediate0_layout@ = @input0@@input0_layout@ @cond0@; temp1@mediate1_layout@ = temp0@mediate0_layout@; @output0@@output0_layout@ = temp1@mediate1o_layout@ @cond1@;)";
R"( mediate0@mediate0_layout@ = @input0@@input0_layout@ @cond0@; mediate1@mediate1_layout@ = mediate0@mediate0_layout@; @output0@@output0_layout@ = mediate1@mediate1o_layout@ @cond1@;)";

auto input_shape = curr->get_input_shape(0);
auto _op = std::dynamic_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
Expand Down

0 comments on commit 61418e9

Please sign in to comment.