Skip to content

Commit

Permalink
adaptive pool2d pass fix (#39600)
Browse files Browse the repository at this point in the history
* first commit

* teller fix

* bug fix

* enable for pool2d only

* fix global_pooling issue

* pooling_type

* fix test
  • Loading branch information
b3602sss authored Feb 17, 2022
1 parent db43b54 commit c1c5c1f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
13 changes: 12 additions & 1 deletion paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,18 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->HasAttr("adaptive") && op->HasAttr("ksize")) {
if (op->Type() == "pool2d" && op->HasAttr("adaptive") &&
op->HasAttr("ksize")) {
if (op->HasAttr("global_pooling")) {
bool global_pooling =
BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
if (global_pooling) return;
}
if (!op->HasAttr("pooling_type")) return;
std::string type =
BOOST_GET_CONST(std::string, op->GetAttr("pooling_type"));
// adaptive has no effect on max pooling
if (type == "max") return;
bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ TEST(AdaptivePool2dConvertGlobalPass, basic) {
AttributeMap attrs;
attrs["adaptive"] = true;
attrs["ksize"] = std::vector<int>{1, 1};
attrs["pooling_type"] =
std::string("avg"); // adaptive has no effect on max pooling
layers.pool2d(x, false, &attrs);

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< desc.Output("Out").size();
return false;
}
if (desc.HasAttr("data_format")) {
std::string data_format =
BOOST_GET_CONST(std::string, desc.GetAttr("data_format"));
if (data_format == "NHWC" || data_format == "NDHWC") {
return false;
}
}
if (!desc.HasAttr("pooling_type")) {
return false;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ def sample_program_config(self, draw):
st.integers(
min_value=1, max_value=4), min_size=2, max_size=2))

paddings = [0, 0] # only 0 0 is right
paddings = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=2, max_size=2))

ceil_mode = draw(st.booleans())
exclusive = draw(st.booleans())
global_pooling = False #only false is right
global_pooling = draw(st.booleans())
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VAILD"]))

pool_op = OpConfig(
Expand Down Expand Up @@ -83,29 +87,6 @@ def sample_predictor_configs(self, program_config):
use_calib_mode=False)
yield config, ['pool2d'], (1e-5, 1e-5)

def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs["pooling_type"] == "max":
x_shape = list(program_config.inputs["input_data"].shape)
if x_shape[-1] != 1 or x_shape[-2] != 1:
return True
return False

def teller2(program_config, predictor_config):
if program_config.ops[0].attrs["padding_algorithm"] == "SAME":
return True
return False

self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"max pooling has diff if H or W is not equals to 1", )
self.add_ignore_check_case(
teller2,
IgnoreReasons.PASS_ACCURACY_ERROR,
"output has wrong result if padding_algorithm equals to SAME", )

def test(self):
self.run_and_statis(
quant=False,
Expand Down

0 comments on commit c1c5c1f

Please sign in to comment.