diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
index 9a09d91ae5d0..438495870166 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
@@ -36,7 +36,8 @@ static inline bool SupportMKLDNNFCEltwiseFusion(const std::string op_name) {
       op_name == "sqrt" ||
       op_name == "exp" ||
       op_name == "abs" ||
-      op_name == "clip") {
+      op_name == "clip" ||
+      op_name == "LeakyReLU") {
     return true;
   } else {
     return false;
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index e2b1807b6559..fd1e156fcd5e 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -286,8 +286,14 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     if (fuse_requantize || mkldnn_param.enable_float_output) {
       float tmp_scale_ = 1.0f;
       if (fuse_requantize) {
-        tmp_scale_ =
-          GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_) / data_scale_;
+        if (mkldnn_param.with_eltwise) {
+          tmp_scale_ = 1.0 / data_scale_;
+          full_param_.eltwise_param.scale =
+            GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+        } else {
+          tmp_scale_ =
+            GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_) / data_scale_;
+        }
       } else {
         tmp_scale_ = 1.0 / data_scale_;
       }
@@ -405,6 +411,10 @@ static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) {
       if (op_name == "Activation") {
         const ActivationParam act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
         full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
+      } else if (op_name == "LeakyReLU") {
+        const auto act_param = nnvm::get<LeakyReLUParam>(node->attrs.parsed);
+        full_param.eltwise_param.alpha = act_param.slope;
+        full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param);
       } else if (op_name == "clip") {
         const ClipParam clip_param = nnvm::get<ClipParam>(node->attrs.parsed);
         full_param.eltwise_param.alg = mkldnn::algorithm::eltwise_bounded_relu;
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index aecb3a7a8477..432772d36298 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -102,6 +102,16 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
           return true;
         }
       }
+      if (new_node.op() == Op::Get("LeakyReLU")) {
+        const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(new_node.attrs.parsed);
+        if (param.act_type == leakyrelu::kLeakyReLU ||
+            param.act_type == leakyrelu::kELU ||
+            param.act_type == leakyrelu::kGELU) {
+          matched_list_.push_back(&new_node);
+          status_ = kSuccess;
+          return true;
+        }
+      }
       if (!quantized_ && (new_node.op() == Op::Get("square") ||
           new_node.op() == Op::Get("sqrt") ||
           new_node.op() == Op::Get("exp"))) {
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 65b73e438ea6..f4d421c1df39 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -49,7 +49,7 @@
 }
 
 DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)]
-fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu',
+fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'gelu',
                   'square', 'square_root', 'abs', 'exp', 'bounded_relu']
 
 def check_qsym_calibrated(qsym, out_type, name='conv'):
@@ -654,6 +654,17 @@ def single_fc(no_bias, data_shape, flatten=True):
                                 no_bias=no_bias, flatten=flatten)
   return fc, attr
 
+def fc_gelu(no_bias, data_shape, flatten):
+  attrs = {'fc': {'with_eltwise': 'true'}}
+  data, weight = head_symbol(data_shape)
+  weight_2 = mx.symbol.Variable('2nd_weight', dtype='float32')
+  fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=6,
+                                no_bias=no_bias, flatten=flatten)
+  sym = mx.symbol.LeakyReLU(data=fc, act_type='gelu')
+  sym = mx.symbol.FullyConnected(name='fc2', data=sym, weight=weight_2, num_hidden=6,
+                                 no_bias=no_bias, flatten=flatten)
+  return sym, attrs
+
 # fc + eltwise fusion case
 def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):
   assert alg in fc_post_ops_list
@@ -664,6 +675,8 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):
                                 no_bias=no_bias, flatten=flatten)
   if alg in ['relu', 'sigmoid', 'tanh', 'softrelu']:
     sym = mx.symbol.Activation(data=fc, name='act', act_type=alg)
+  elif alg == "gelu":
+    sym = mx.symbol.LeakyReLU(data=fc, act_type='gelu')
   elif alg == 'square':
     sym = mx.symbol.square(data=fc, name='square')
   elif alg == 'square_root':
@@ -865,6 +878,12 @@ def test_fc_eltwise():
     else:
      check_fusion(syms, dshape, attrs, check_quantization=False)
 
+@with_seed()
+def test_fc_gelu():
+  for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
+    sym, attrs = fc_gelu(no_bias, dshape, flatten)
+    check_fusion(sym, dshape, attrs, check_quantization=True)
+
 @with_seed()
 def test_neg_fc_relu():
   for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
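Not part of the patch: a minimal sketch of the graph shape this change targets, mirroring the fc_gelu test above. Variable names, shapes, and num_hidden are illustrative only; it assumes an MXNet build with MKL-DNN and the subgraph backend named "MKLDNN". With this patch applied, partitioning is expected to fold the FullyConnected + LeakyReLU(gelu) pair into a single fused FC node (registered as _sg_mkldnn_fully_connected) with with_eltwise set.

```python
# Illustrative sketch only -- not the patch itself.
import mxnet as mx

data = mx.symbol.Variable('data', dtype='float32')
weight = mx.symbol.Variable('weight', dtype='float32')

# FullyConnected followed by a gelu LeakyReLU, the pattern the selector now matches.
fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight,
                              num_hidden=6, no_bias=True, flatten=True)
out = mx.symbol.LeakyReLU(data=fc, act_type='gelu')

# Partition for the MKLDNN subgraph backend; the FC + gelu pair should be
# collapsed into one fused FC op once this change is in place.
fused = out.get_backend_symbol('MKLDNN')
print(fused.get_internals().list_outputs())
```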