
Commit

fix code format error
Silv3S committed Dec 30, 2021
1 parent 3f4e54d commit 8a6c259
Showing 8 changed files with 15 additions and 12 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -25,8 +25,8 @@ namespace ir {
using string::PrettyLogDetail;

void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
- std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid",
- "mish", "hard_swish"};
+ std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid", "mish",
+ "hard_swish"};

for (std::string act_type : act_types) FuseFCAct(graph, act_type);
}
6 changes: 3 additions & 3 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -27,8 +27,8 @@ namespace ir {
* \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op.
*
- * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported as an
- * activation function.
+ * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
+ * as an activation function.
*/
class FuseFCActOneDNNPass : public FusePassBase {
public:
@@ -42,4 +42,4 @@ class FuseFCActOneDNNPass : public FusePassBase {

} // namespace ir
} // namespace framework
- } // namespace paddlea
+ } // namespace paddle
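
The header comment above describes the effect of this pass: the standalone activation op is folded into the oneDNN FC primitive as a post-op, so a single fused kernel produces what fc followed by the activation would. As a rough illustration of that semantics, here is a minimal numpy sketch (the helper names fused_fc_act and mish below are assumptions for illustration, not Paddle or oneDNN API), using mish as the fused activation:

# Illustrative sketch only: fused_fc_act is a hypothetical helper,
# not Paddle or oneDNN API; it shows the math the fused op computes.
import numpy as np

def mish(x):
    # Same reference formula the Paddle unit tests use: x * tanh(softplus(x))
    return x * np.tanh(np.log(1 + np.exp(x)))

def fused_fc_act(x, weights, bias, act=mish):
    # Unfused graph: fc -> activation (two ops, extra memory traffic).
    # After FuseFCActOneDNNPass the activation runs as a post-op of the FC.
    return act(x @ weights + bias)

x = np.random.rand(2, 8).astype(np.float32)
w = np.random.rand(8, 4).astype(np.float32)
b = np.random.rand(4).astype(np.float32)
out = fused_fc_act(x, w, b)  # numerically matches fc followed by mish
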
@@ -209,8 +209,7 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
{"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
},
{{"Out", "fc_y"}});
- test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}},
- false);
+ test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

Graph graph(prog);
constexpr int removed_nodes_count = 2;
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -201,7 +201,7 @@ using HardSwishMKLDNNFunctor =

template <typename T>
using MishMKLDNNFunctor =
- MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
+ MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;

template <typename T>
using SigmoidMKLDNNFunctor =
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -507,7 +507,7 @@ class ConvMKLDNNHandlerT
fuse_alpha, fuse_beta);
} else if (fuse_activation == "mish") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_mish,
- fuse_alpha, fuse_beta);
+ fuse_alpha, fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_linear,
fuse_alpha, fuse_beta);
@@ -155,5 +155,6 @@ def test_check_output(self):
self.check_output()
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))

+
if __name__ == "__main__":
unittest.main()
@@ -142,9 +142,10 @@ def op_forward(self, x):
return x * np.tanh(np.log(1 + np.exp(x)))

def op_grad(self, dout, x):
- omega = np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6) + 4 * (x + 1)
+ omega = np.exp(3 * x) + 4 * np.exp(2 * x) + np.exp(x) * (4 * x + 6
+ ) + 4 * (x + 1)
delta = np.exp(2 * x) + 2 * np.exp(x) + 2
- return dout * ((np.exp(x) * omega) / delta ** 2)
+ return dout * ((np.exp(x) * omega) / delta**2)


if __name__ == '__main__':
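
For reference, the closed form that op_forward and op_grad in the hunk above implement (the commit only re-wraps these lines; the math itself is unchanged) is, in LaTeX:

\[ \operatorname{mish}(x) = x \tanh\big(\ln(1 + e^{x})\big) \]

\[ \frac{d}{dx}\operatorname{mish}(x) = \frac{e^{x}\,\omega}{\delta^{2}}, \qquad
   \omega = e^{3x} + 4 e^{2x} + e^{x}(4x + 6) + 4(x + 1), \qquad
   \delta = e^{2x} + 2 e^{x} + 2 \]

so op_grad returns dout multiplied by exactly this derivative.
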
@@ -23,6 +23,7 @@
from paddle.fluid.tests.unittests.test_gelu_op import gelu
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd

+
class TestMKLDNNReluDim2(TestRelu):
def setUp(self):
super(TestMKLDNNReluDim2, self).setUp()
@@ -313,11 +314,12 @@ def setUp(self):
def init_dtype(self):
self.dtype = np.float32

+
class TestMKLDNNMish(TestActivation):
def setUp(self):
self.op_type = "mish"
self.dtype = np.float32

x = np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype(self.dtype)
out = x * np.tanh(np.log(1 + np.exp(x)))

