
Added support for BF16 datatype for all oneDNN activation kernels #40721

Merged 5 commits on Mar 23, 2022
Changes from all commits
28 changes: 26 additions & 2 deletions paddle/fluid/operators/abs_op.cc
@@ -30,6 +30,21 @@ namespace operators {
class AbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -72,8 +87,17 @@ class AbsGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

43 changes: 16 additions & 27 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -315,15 +315,7 @@ using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<

namespace ops = paddle::operators;

#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor, \
grad_functor) \
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL( \
act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>, \
@@ -339,30 +331,27 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
__macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor); \
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(mish, MishMKLDNNFunctor, MishMKLDNNGradFunctor); \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradUseOutFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradUseOutFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);

REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
MishMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
softplus, MKLDNN, paddle::platform::CPUPlace,
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>,
ops::MKLDNNActivationKernel<
ops::SoftplusMKLDNNFunctor<paddle::platform::bfloat16>>);
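
To make the consolidated macro concrete, here is a rough sketch of what REGISTER_ACTIVATION_MKLDNN_KERNEL now expands to for one entry of FOR_EACH_MKLDNN_KERNEL_FUNCTOR (abs is used as the example). The grad half of the macro body is collapsed in the diff above, so its exact form is an assumption; the point is that every listed activation is registered with both a float and a bfloat16 oneDNN kernel, for the forward and the grad op.

// Hypothetical expansion of REGISTER_ACTIVATION_MKLDNN_KERNEL(abs, AbsMKLDNNFunctor,
//                                                             AbsMKLDNNGradFunctor);
// the grad registration is assumed to mirror the forward one shown in the diff.
REGISTER_OP_KERNEL(
    abs, MKLDNN, ::paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::AbsMKLDNNFunctor<float>>,
    ops::MKLDNNActivationKernel<
        ops::AbsMKLDNNFunctor<paddle::platform::bfloat16>>);
REGISTER_OP_KERNEL(
    abs_grad, MKLDNN, ::paddle::platform::CPUPlace,
    ops::MKLDNNActivationGradKernel<ops::AbsMKLDNNGradFunctor<float>>,
    ops::MKLDNNActivationGradKernel<
        ops::AbsMKLDNNGradFunctor<paddle::platform::bfloat16>>);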
@@ -50,11 +50,11 @@ def setUp(self):
self.dtype = np.uint16
self.init_data()
self.config()
self.set_attrs()
self.out = self.op_forward(self.x)

self.inputs = {'X': convert_float_to_uint16(self.x)}
self.outputs = {'Out': self.out}
self.set_attrs()

def calculate_grads(self):
self.dx = self.op_grad(self.out, self.x)
@@ -162,5 +162,110 @@ def op_grad(self, dout, x):
return dout * ((np.exp(x) * omega) / delta**2)


class TestMKLDNNRelu6BF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "relu6"

def op_forward(self, x):
return np.clip(x, 0, 6)

def op_grad(self, dout, x):
return np.where((x > 0) & (x <= 6), dout, 0)


class TestMKLDNNLeakyReluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "leaky_relu"

def op_forward(self, x):
return np.where(x > 0, x, self.alpha * x)

def op_grad(self, dout, x):
return np.where(x > 0, dout, self.alpha * dout)

def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}


class TestMKLDNNSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "swish"

def expit(self, val):
return 1 / (1 + np.exp(-self.beta * val))

def op_forward(self, x):
return x * self.expit(x)

def op_grad(self, dout, x):
return dout * self.expit(x) * (1 + self.beta * x * (1 - self.expit(x)))

def set_attrs(self):
self.beta = 0.2
self.attrs = {"use_mkldnn": True, "beta": self.beta}


class TestMKLDNNHardSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "hard_swish"

def op_forward(self, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, result, result * (result + 3) / 6)

def op_grad(self, dout, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, dout, dout * (2 * x + 3) / 6)


class TestMKLDNNTanhBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "tanh"

def op_forward(self, x):
return np.tanh(x)

def op_grad(self, dout, x):
return dout * (1 - np.tanh(x)**2)


class TestMKLDNNAbsBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "abs"

def op_forward(self, x):
return np.absolute(x)

def op_grad(self, dout, x):
return dout * np.sign(x)


class TestMKLDNNEluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "elu"

def op_forward(self, x):
return np.where(x > 0, x, self.alpha * (np.exp(x) - 1))

def op_grad(self, dout, x):
return np.where(x > 0, dout, dout * self.alpha * np.exp(x))

def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}


class TestMKLDNNExpBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "exp"

def op_forward(self, x):
return np.exp(x)

def op_grad(self, dout, x):
return dout * np.exp(x)


if __name__ == '__main__':
unittest.main()
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -30,23 +30,32 @@ def ref_softplus(x, beta, threshold):
return out


@OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)),
"GPU is not supported")
@OpTestTool.skip_if_not_cpu_bf16()
class TestSoftplusOneDNNOp(OpTest):
def setUp(self):
self.op_type = "softplus"
self.beta = 1
self.threshold = 20
self.config()
self.set_dtype()
self.attrs = {'use_mkldnn': True, 'beta': self.beta}
self.inputs = {'X': np.random.random(self.x_shape).astype(np.float32)}
self.x = np.random.random(self.x_shape)
self.out = ref_softplus(self.x, self.beta, self.threshold)

if self.dtype != np.float32:
self.x = convert_float_to_uint16(self.x)

self.inputs = {'X': self.out}
self.outputs = {
'Out': ref_softplus(self.inputs['X'], self.beta, self.threshold)
'Out': ref_softplus(self.out, self.beta, self.threshold)
}

def config(self):
self.x_shape = (10, 10)

def set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()

@@ -73,6 +82,27 @@ def config(self):
self.beta = 0.4


class TestSoftplusBF16OneDNNOp(TestSoftplusOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus4DBF16OneDNNOp(TestSoftplus4DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus6DBF16OneDNNOp(TestSoftplus6DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


class TestSoftplus3DExtendedFunctorBF16OneDNNOp(
TestSoftplus3DExtendedFunctorOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16


if __name__ == "__main__":
paddle.enable_static()
unittest.main()