[Inference Zero-Dim] Support trt 0dim of gelu, hard_swish, hard_sigmoid and leaky_relu #53714

Merged 3 commits on May 11, 2023
70 changes: 12 additions & 58 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -105,15 +105,15 @@ struct SimpleOpTypeSetTeller : public Teller {
"erf", "floor", "round",
"sign", "silu", "logical_not",
"reciprocal", "tanh_shrink", "logsigmoid",
"rsqrt", "swish"};
"rsqrt", "swish", "hard_sigmoid",
"hard_swish", "leaky_relu"};
std::unordered_set<std::string> unary_list = {
"exp", "log", "sqrt", "abs", "sin",
"cos", "tan", "tanh", "sinh", "cosh",
"asin", "acos", "atan", "asinh", "acosh",
"atanh", "ceil", "celu", "floor", "round",
"sign", "silu", "logical_not", "reciprocal", "tanh_shrink",
"logsigmoid", "erf", "bitwise_not", "equal", "not_equal",
"rsqrt"};
"exp", "log", "sqrt", "abs", "sin",
"cos", "tan", "tanh", "sinh", "cosh",
"asin", "acos", "atan", "asinh", "acosh",
"atanh", "ceil", "celu", "floor", "round",
"sign", "logical_not", "reciprocal", "tanh_shrink", "logsigmoid",
"erf", "bitwise_not", "equal", "not_equal", "rsqrt"};

// Static shape does not support 0 or 1 dim's input.
if (!with_dynamic_shape) {
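
Editor's note: the three newly whitelisted activations are plain elementwise functions, so they fit the same generic activation teller path as the existing unary ops and preserve a 0-dim shape under dynamic shape. A minimal NumPy sketch of the reference math, assuming Paddle's documented attribute defaults (threshold=6, scale=6, offset=3 for hard_swish; slope=1/6, offset=0.5 for hard_sigmoid); the helper names and the leaky_relu alpha value below are illustrative, not taken from this diff:

import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # x * clip(x + offset, 0, threshold) / scale
    return x * np.clip(x + offset, 0.0, threshold) / scale

def hard_sigmoid(x, slope=1.0 / 6.0, offset=0.5):
    # clip(slope * x + offset, 0, 1)
    return np.clip(slope * x + offset, 0.0, 1.0)

def leaky_relu(x, alpha=0.02):
    # x where x >= 0, alpha * x elsewhere
    return np.where(x >= 0.0, x, alpha * x)

x = np.full([], 1.5, dtype=np.float32)   # rank-0 (scalar) input, shape ()
print(hard_swish(x), hard_sigmoid(x), leaky_relu(x))
# each result is itself 0-dim, so the elementwise TRT layer keeps the rank
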
@@ -962,20 +962,6 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "hard_swish") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "HardSwish op has only 1 input, but got "
<< desc.Input("X").size();
return false;
}

if (desc.Output("Out").size() != 1) {
VLOG(3) << "HardSwish op has only 1 output, but got "
<< desc.Output("Out").size();
return false;
}
}

if (op_type == "squeeze2") {
// If Attribute is Variable(s), HasAttr() will return False
if (!desc.HasAttr("axes", /*with_attr_var=*/false)) {
@@ -1642,8 +1628,10 @@ struct SimpleOpTypeSetTeller : public Teller {
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "gelu op does not support input's dim is 1 in tensorrt.";
if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
VLOG(3) << op_type
<< " op does not support 0-dim or 1-dim input in tensorrt "
"static shape mode.";
return false;
}
}
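
Editor's note: the guard relies on the rank reported by the block's VarDesc, where a 0-dim tensor has an empty shape and `x_shape.size()` is 0. A quick NumPy analogue of the ranks involved (illustrative only):

import numpy as np

scalar = np.ones([])        # shape (), rank 0 -- rejected only in static-shape mode
vector = np.ones([32])      # shape (32,), rank 1 -- likewise rejected without dynamic shape
matrix = np.ones([3, 32])   # shape (3, 32), rank 2 -- accepted either way

for t in (scalar, vector, matrix):
    print(t.shape, t.ndim)  # t.ndim plays the role of x_shape.size() in the teller
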
@@ -1733,20 +1721,6 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "leaky_relu") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "Invalid number of TRT leaky_relu op converter "
"inputs. Expected 1, but received "
<< desc.Input("X").size();
return false;
}
if (desc.Output("Out").size() != 1) {
VLOG(3) << "output of leaky_relu op converter should be 1, got "
<< desc.Output("Out").size();
return false;
}
}

if (op_type == "pad") {
if (!desc.HasAttr("pad_value") || !desc.HasAttr("paddings")) return false;
const float pad_value =
@@ -2388,26 +2362,6 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "hard_sigmoid") {
if (!with_dynamic_shape) {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "Hard sigmoid does not support 1-dimensional input in "
"tensorrt";
return false;
}
}
}

if (op_type == "cast") {
// trt 6015 result in Windows ppyolo_mbv3 TRT fp32 diff
#if !IS_TRT_VERSION_GE(7000)
15 changes: 15 additions & 0 deletions test/ir/inference/test_trt_convert_activation.py
@@ -60,6 +60,9 @@ def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
"logsigmoid",
"tanh_shrink",
"softplus",
"hard_swish",
"hard_sigmoid",
"leaky_relu",
]:
# few samples to reduce time
# for beta in [-0.2, 0.5, 0.67, 3]:
@@ -80,6 +83,18 @@ def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
dics = [{"threshold": alpha}]
if op_type == "softplus":
dics = [{"beta": beta}]
if op_type == "hard_swish":
dics = [
{
"threshold": 6.0,
"scale": 6.0,
"offset": 3.0,
}
]
if op_type == "hard_sigmoid":
dics = [{"slope": beta, "offset": alpha}]
if op_type == "leaky_relu":
dics = [{"alpha": alpha}]

ops_config = [
{
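
Editor's note: the same activations can also be exercised eagerly on a 0-dim tensor through the paddle.nn.functional API; a small sketch, assuming a Paddle build with 0-dim tensor support (2.5+). The attribute values mirror the dics above and are illustrative only:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(1.5)                     # 0-dim tensor, shape []
outs = {
    "gelu": F.gelu(x, approximate=False),
    "hard_swish": F.hardswish(x),             # fixed threshold=6, scale=6, offset=3
    "hard_sigmoid": F.hardsigmoid(x),         # default slope/offset
    "leaky_relu": F.leaky_relu(x, negative_slope=0.5),
}
for name, out in outs.items():
    print(name, out.shape)                    # every output keeps the empty shape []
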
14 changes: 10 additions & 4 deletions test/ir/inference/test_trt_convert_gelu.py
@@ -29,7 +29,9 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:

def sample_program_configs(self):
def generate_input1(dims, attrs: List[Dict[str, Any]]):
if dims == 1:
if dims == 0:
return np.ones([]).astype(np.float32)
elif dims == 1:
return np.ones([32]).astype(np.float32)
elif dims == 2:
return np.ones([3, 32]).astype(np.float32)
@@ -38,7 +40,7 @@ def generate_input1(dims, attrs: List[Dict[str, Any]]):
else:
return np.ones([1, 3, 32, 32]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for approximate in [True, False]:
self.dims = dims
dics = [{"approximate": approximate}]
@@ -70,7 +72,11 @@ def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 1:
if self.dims == 0:
self.dynamic_shape.min_input_shape = {"input_data": []}
self.dynamic_shape.max_input_shape = {"input_data": []}
self.dynamic_shape.opt_input_shape = {"input_data": []}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [1]}
self.dynamic_shape.max_input_shape = {"input_data": [64]}
self.dynamic_shape.opt_input_shape = {"input_data": [32]}
@@ -104,7 +110,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
runtime_version = paddle_infer.get_trt_runtime_version()
self.assertTrue(compile_version == runtime_version)
# Under static shape, 0-dim and 1-dim inputs only run on Paddle OPs
if self.dims == 1:
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3
if compile_version >= valid_version:
return 1, 2
Expand Down