Mish FP32/BF16 kernel, conv and fc fuse passes #38623

Merged · 19 commits · Jan 18, 2022

Changes from all commits:
@@ -230,7 +230,15 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
.IsType<float>()
.End();
}

Conv2DMishFusePass::Conv2DMishFusePass() {
AddOpCompat(OpCompat("mish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
AddOpCompat(OpCompat("hard_sigmoid"))
.AddInput("X")
@@ -311,6 +319,14 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
.LE("conv2d", 1)
.EQ("hard_swish", 0));

REGISTER_PASS(conv_mish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DMishFusePass);
REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("mish", 1));

REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DHardSigmoidFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
@@ -72,6 +72,14 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; }
};
/*
* Fuse Conv and Mish class
*/
class Conv2DMishFusePass : public ConvActivationFusePass {
public:
Conv2DMishFusePass();
std::string activation_type() const { return "mish"; }
};
/*
* Fuse Conv and HardSigmoid class
*/
@@ -148,6 +148,7 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
MainTest("hard_swish");
}
TEST(ConvActivationFusePass, conv_mish_fuse_pass) { MainTest("mish"); }
TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
MainTest("hard_sigmoid");
}
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;

void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid",
std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid", "mish",
"hard_swish"};

for (std::string act_type : act_types) FuseFCAct(graph, act_type);
@@ -99,5 +99,6 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
.LE("fc", 0)
.LE("gelu", 0)
.LE("sigmoid", 0)
.LE("mish", 1)
.LE("hard_swish", 0)
.LE("tanh", 0));
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -27,8 +27,8 @@ namespace ir {
* \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op.
*
* \note Currently only GeLU, hardswish, sigmoid and tanh are supported as an
* activation function.
* \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
* as an activation function.
*/
class FuseFCActOneDNNPass : public FusePassBase {
public:
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
@@ -201,6 +201,36 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
}
}

TEST(FuseFCActOneDNNPass, FuseWithMish) {
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
test::CreateOp(&prog, "fc",
{
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

Graph graph(prog);
constexpr int removed_nodes_count = 2;

EXPECT_TRUE(test::RunPassAndAssert(&graph, "fc_act_mkldnn_fuse_pass", "x",
"act_y", removed_nodes_count));
EXPECT_TRUE(test::AssertOpsCount(graph, {{"fc", 1}, {"mish", 0}}));

for (const auto* node : graph.Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
auto act_type =
BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
EXPECT_EQ(act_type.compare("mish"), 0);
}
}
}

TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
auto prog =
test::BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -252,6 +252,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", //
"conv_mish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", //
// TODO(baoachun) fix int8 accuracy
"conv_gelu_mkldnn_fuse_pass",
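The new conv_mish_mkldnn_fuse_pass only runs when the CPU pass strategy has oneDNN enabled. As a rough usage sketch (assuming the standard Paddle Inference C++ API; the model directory below is a placeholder, not something from this PR), enabling oneDNN plus IR optimization is what makes CpuPassStrategy::EnableMKLDNN() above register the conv/fc activation fuse passes:

#include <memory>
#include "paddle_inference_api.h"

int main() {
  // Placeholder model directory; substitute a real saved inference model.
  paddle_infer::Config config("./my_model_dir");
  config.EnableMKLDNN();       // switch to CpuPassStrategy with the oneDNN fuse passes above
  config.SwitchIrOptim(true);  // run the IR fuse passes while building the predictor
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}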
10 changes: 10 additions & 0 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -235,6 +235,10 @@ template <typename T>
using HardSwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
using MishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;

template <typename T>
using SigmoidMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
@@ -272,6 +276,10 @@ template <typename T>
using HardSwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
using MishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;

template <typename T>
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
@@ -339,6 +347,8 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
MishMKLDNNGradFunctor);

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
4 changes: 4 additions & 0 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -516,6 +516,10 @@ class ConvMKLDNNHandlerT
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_hardswish,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "mish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_mish, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
5 changes: 5 additions & 0 deletions paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -495,6 +495,11 @@ class FCPrimitiveFactory {
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_logistic,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "mish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_mish,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
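In both the conv and fc kernels above, the fused mish is attached as a oneDNN eltwise post-op on the primitive attributes, so the activation is computed on the primitive's output without a separate activation kernel. A minimal standalone sketch of that mechanism (plain oneDNN 2.x API rather than Paddle code; the scale, alpha and beta values are illustrative):

#include "dnnl.hpp"

// Build primitive attributes that append mish as an eltwise post-op.
dnnl::primitive_attr MakeMishPostOpAttr() {
  dnnl::post_ops post_operations;
  // scale = 1.0f; alpha and beta are ignored by eltwise_mish.
  post_operations.append_eltwise(1.0f, dnnl::algorithm::eltwise_mish, 0.0f, 0.0f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_operations);
  return attr;  // pass to a convolution_forward / inner_product_forward primitive_desc
}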
@@ -102,6 +102,15 @@ def set_params(self):
self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'


class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "mish"
self.pass_name = 'conv_mish_mkldnn_fuse_pass'


class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
@@ -134,5 +134,27 @@ def test_check_output(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))


class FCMishOneDnnFusePassTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 128, 768], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=3072, num_flatten_dims=2)
mish_out = fluid.layers.mish(fc_out)

self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}

self.fetch_list = [mish_out]
self.enable_mkldnn = True

def set_params(self):
self.pass_name = "fc_act_mkldnn_fuse_pass"

def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,19 @@ def op_grad(self, dout, x):
return dout


class TestMKLDNNMishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "mish"

def op_forward(self, x):
return x * np.tanh(np.log(1 + np.exp(x)))

def op_grad(self, dout, x):
omega = np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6
) + 4 * (x + 1)
delta = np.exp(2 * x) + 2 * np.exp(x) + 2
return dout * ((np.exp(x) * omega) / delta**2)


if __name__ == '__main__':
unittest.main()
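For reference, the op_forward/op_grad expressions in the BF16 test above match the usual closed form of mish; written out in LaTeX (this only restates the numpy math, it is not extra code from the PR):

\mathrm{mish}(x) = x \tanh\bigl(\operatorname{softplus}(x)\bigr) = x \tanh\bigl(\ln(1 + e^{x})\bigr)

\frac{d}{dx}\,\mathrm{mish}(x) = \frac{e^{x}\,\omega(x)}{\delta(x)^{2}}, \qquad
\omega(x) = e^{3x} + 4e^{2x} + e^{x}(4x + 6) + 4(x + 1), \qquad
\delta(x) = e^{2x} + 2e^{x} + 2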
@@ -315,6 +315,19 @@ def init_dtype(self):
self.dtype = np.float32


class TestMKLDNNMish(TestActivation):
def setUp(self):
self.op_type = "mish"
self.dtype = np.float32

x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype(self.dtype)
out = x * np.tanh(np.log(1 + np.exp(x)))

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}


class TestMKLDNNSigmoidDim4(TestSigmoid):
def setUp(self):
super(TestMKLDNNSigmoidDim4, self).setUp()