[bf16] add bf16 kernel: layer_norm p_norm reduce_sum #39843

Merged · 14 commits · Mar 1, 2022
4 changes: 0 additions & 4 deletions paddle/fluid/operators/cast_op.cu
@@ -29,9 +29,5 @@ using CUDA = paddle::platform::CUDADeviceContext;
ops::CastOpKernel<CUDA, plat::complex<float>>, \
ops::CastOpKernel<CUDA, plat::complex<double>>, ##__VA_ARGS__);

#if !defined(PADDLE_WITH_HIP)
// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel<CUDA, plat::bfloat16>)
#else
REGISTER_CAST_CUDA_BASE(transfer_dtype)
#endif
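With the HIP guard removed, the bfloat16 cast kernel behind transfer_dtype is now registered unconditionally. A minimal dynamic-graph sketch of the cast round trip this enables, borrowing the BF16 cast pattern from TestBF16ScaleBiasLayerNorm further down (assumes a CUDA build; the values are illustrative):

import paddle

paddle.disable_static()
x = paddle.randn([4, 8], dtype='float32')
# Cast to bfloat16 the same way the layer_norm test below does, then back.
x_bf16 = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
x_back = x_bf16.cast('float32')
# The round-trip error is bounded by bf16's rounding step.
print(float((x - x_back).abs().max()))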
6 changes: 3 additions & 3 deletions paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -474,11 +474,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U x_tmp = x[it][jt];
U x_tmp = static_cast<U>(x[it][jt]);
U y_tmp = var_cur_row * (x_tmp - mean_cur_row);
U dy_tmp = static_cast<U>(gamma[it][jt]) *
static_cast<U>(dout[it][jt]); // scale * dy
U dout_tmp = dout[it][jt]; // dy
static_cast<U>(dout[it][jt]); // scale * dy
U dout_tmp = static_cast<U>(dout[it][jt]); // dy

// used for get dx (row reduction)
sum_loss1 += dy_tmp; // scale * dy, sum_1
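The hunk above makes the promotions explicit: the vectorized loads of x and dout are converted with static_cast<U> to the compute type U before feeding the row reductions. A small numpy sketch of why reducing in the wider type matters; bf16() is a hypothetical truncation-based emulation used only for illustration, not Paddle's conversion routine:

import numpy as np

def bf16(x):
    # Hypothetical helper: emulate bfloat16 storage by keeping only the
    # upper 16 bits of each float32 value (truncation, illustration only).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

dout = bf16(np.random.uniform(-1.0, 1.0, 4096).astype(np.float32))

# Promote once and accumulate in float32 (the role of static_cast<U>(...)).
sum_fp32 = np.sum(dout, dtype=np.float32)

# Round every running sum back to bf16 precision instead.
acc = np.zeros(1, dtype=np.float32)
for v in dout:
    acc = bf16(acc + v)

print(sum_fp32, acc[0])  # the bf16-rounded accumulator drifts away from the fp32 sum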
15 changes: 15 additions & 0 deletions paddle/fluid/operators/layer_norm_op.cu
@@ -259,6 +259,21 @@ REGISTER_OP_CUDA_KERNEL(
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
#elif CUDNN_VERSION_MIN(8, 1, 0)
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>);
#else
REGISTER_OP_CUDA_KERNEL(
layer_norm,
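Because the bfloat16 layer_norm kernels sit behind the CUDNN_VERSION_MIN(8, 1, 0) branch, they are only available with cuDNN 8.1 or newer, which is also the gate the Python test below applies (core.cudnn_version() < 8100). A dynamic-graph usage sketch that mirrors TestBF16ScaleBiasLayerNorm.check_main, assuming such a build:

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor(np.random.random([10, 20]).astype('float32'))
weight = paddle.to_tensor(np.random.random([20]).astype('float32'))
bias = paddle.to_tensor(np.random.random([20]).astype('float32'))

# Run layer_norm with a bf16 input but fp32 scale/bias, as the test does.
x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False

y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
print(y.cast('float32').numpy().dtype)  # float32 view of the bf16 result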
12 changes: 12 additions & 0 deletions paddle/fluid/operators/p_norm_op.cu
@@ -39,6 +39,11 @@ __device__ __forceinline__ int sgn(T val) {
__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) {
return static_cast<platform::float16>(abs(static_cast<float>(x)));
}

__device__ __forceinline__ platform::bfloat16 inline_abs(platform::bfloat16 x) {
return static_cast<platform::bfloat16>(abs(static_cast<float>(x)));
}

__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }

@@ -53,6 +58,11 @@ __device__ __forceinline__ platform::float16 inline_pow(
return static_cast<platform::float16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ platform::bfloat16 inline_pow(
platform::bfloat16 base, platform::bfloat16 exponent) {
return static_cast<platform::bfloat16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
return pow(base, exponent);
}
@@ -202,9 +212,11 @@ using CUDA = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(p_norm,
ops::PnormCUDAKernel<CUDA, paddle::platform::float16>,
ops::PnormCUDAKernel<CUDA, paddle::platform::bfloat16>,
ops::PnormCUDAKernel<CUDA, float>,
ops::PnormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(
p_norm_grad, ops::PnormGradCUDAKernel<CUDA, paddle::platform::float16>,
ops::PnormGradCUDAKernel<CUDA, paddle::platform::bfloat16>,
ops::PnormGradCUDAKernel<CUDA, float>,
ops::PnormGradCUDAKernel<CUDA, double>);
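Since inline_abs and inline_pow upcast bfloat16 operands to float, do the arithmetic there and cast back, the new p_norm / p_norm_grad registrations can reuse the existing kernel bodies unchanged. A hedged usage sketch (assumes a CUDA build and that the Python front end accepts bf16 inputs for this op):

import paddle

paddle.disable_static()
x = paddle.rand([2, 3, 4, 5], dtype='float32') + 0.5
x_bf16 = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
x_bf16.stop_gradient = False

# porder=2 along axis 1, matching TestPnormBF16Op below.
out = paddle.norm(x_bf16, p=2, axis=1)
out.sum().backward()
print(out.cast('float32').numpy().shape)          # (2, 4, 5)
print(x_bf16.grad.cast('float32').numpy().shape)  # (2, 3, 4, 5)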
1 change: 1 addition & 0 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
@@ -23,6 +23,7 @@ REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<paddle::platform::float16>,
CUDAReduceSumGradKernel<paddle::platform::bfloat16>,
CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
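Together with the bfloat16 entries added to the phi sum / sum_raw registrations below, this gives reduce_sum a matching backward kernel, so a bf16 sum can be differentiated end to end. A sketch of that path (hypothetical usage, assumes a CUDA build):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.uniform(0, 0.1, (2, 5, 10)).astype('float32'))
x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
x.stop_gradient = False

out = paddle.sum(x, axis=[0, 1, 2])  # lowers to reduce_sum with dim=[0, 1, 2]
out.backward()
# The gradient of a full sum is all ones, as in TestSumOp_bf16.calc_gradient.
print(x.grad.cast('float32').numpy())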
4 changes: 0 additions & 4 deletions paddle/phi/kernels/gpu/cast_kernel.cu
@@ -80,8 +80,4 @@ void CastKernel(const Context& dev_ctx,
paddle::experimental::DataType::UNDEFINED); \
}

#if !defined(PADDLE_WITH_HIP)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16)
#else
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/math_kernel.cu
@@ -155,6 +155,7 @@ PD_REGISTER_KERNEL(sum_raw,
float,
double,
float16,
bfloat16,
int16_t,
int,
int64_t,
1 change: 1 addition & 0 deletions paddle/phi/kernels/math_kernel.cc
@@ -165,6 +165,7 @@ PD_REGISTER_KERNEL(sum,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int16_t,
int,
int64_t,
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/op_test.py
@@ -1135,7 +1135,7 @@ def check_output_with_place(self,
else:
atol = 2
else:
atol = 1e-2
atol = 1e-1

if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
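Loosening atol from 1e-2 to 1e-1 in this branch is consistent with bfloat16's 8-bit significand: a single rounding already costs about 2^-8 in relative terms, and reductions accumulate it. A rough numpy illustration using a truncation-based bf16 emulation (illustrative only, not the convert_float_to_uint16 helper the tests use):

import numpy as np

def bf16_round_trip(x):
    # Hypothetical truncation-based bf16 emulation (keeps the top 16 bits).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

x = np.random.random((10, 20)).astype(np.float32) * 10.0
err = np.abs(bf16_round_trip(x) - x)
print(err.max())  # up to about x.max() * 2**-7 for truncation, i.e. several 1e-2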
47 changes: 47 additions & 0 deletions python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -375,6 +375,53 @@ def assert_equal(x, y):
assert_equal(b_g_np_1, b_g_np_2)


class TestBF16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()

x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)

if dtype == "bfloat16":
x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)

x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False

y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])

y_np = y.cast('float32').numpy()
x_g_np = x_g.cast('float32').numpy()
w_g_np = w_g.cast('float32').numpy()
b_g_np = b_g.cast('float32').numpy()

paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np

def test_main(self):
if (not core.is_compiled_with_cuda()) or (core.cudnn_version() < 8100):
return
x_np = np.random.random([10, 20]).astype('float32')
weight_np = np.random.random([20]).astype('float32')
bias_np = np.random.random([20]).astype('float32')

y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float32')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'bfloat16')

def assert_equal(x, y):
self.assertTrue(np.allclose(x, y, atol=1.e-1))

assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)


class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
def test_main(self):
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
76 changes: 75 additions & 1 deletion python/paddle/fluid/tests/unittests/test_norm_all.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -282,6 +282,80 @@ def init_test_case(self):
self.asvector = True


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestPnormBF16Op(OpTest):
def setUp(self):
self.op_type = "p_norm"
self.init_test_case()
self.x = (np.random.random(self.shape) + 0.5).astype(np.float32)
self.norm = p_norm(self.x, self.axis, self.porder, self.keepdim,
self.asvector)
self.gradient = self.calc_gradient()
self.inputs = {'X': convert_float_to_uint16(self.x)}
self.attrs = {
'epsilon': self.epsilon,
'axis': self.axis,
'keepdim': self.keepdim,
'porder': float(self.porder),
'asvector': self.asvector
}
self.outputs = {'Out': convert_float_to_uint16(self.norm)}

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', user_defined_grads=self.gradient)

def init_test_case(self):
self.shape = [2, 3, 4, 5]
self.axis = 1
self.epsilon = 1e-12
self.porder = 2.0
self.keepdim = False
self.dtype = np.uint16
self.asvector = False

def calc_gradient(self):
self.attrs = {
'epsilon': self.epsilon,
'axis': self.axis,
'keepdim': self.keepdim,
'porder': float(self.porder),
'asvector': self.asvector
}
x = self.x
porder = self.attrs["porder"]
axis = self.attrs["axis"]
asvector = self.attrs["asvector"]
x_dtype = x.dtype
x = x.astype(np.float32) if x.dtype == np.float16 else x
if porder == 0:
grad = np.zeros(x.shape).astype(x.dtype)
elif porder in [float("inf"), float("-inf")]:
norm = p_norm(
x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector)
x_abs = np.abs(x)
grad = np.sign(x)
grad[x_abs != norm] = 0.0
else:
norm = p_norm(
x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector)
grad = np.power(norm, 1 - porder) * np.power(
np.abs(x), porder - 1) * np.sign(x)

numel = 1
for s in x.shape:
numel *= s
divisor = numel if asvector else x.shape[axis]
numel /= divisor
return [grad.astype(x_dtype) * 1 / numel]


def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype)
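For reference, TestPnormBF16Op.calc_gradient above implements the analytic derivative of the p-norm for finite, nonzero p; in LaTeX (my notation, with the reduction taken over the chosen axis):

\frac{\partial \lVert x \rVert_p}{\partial x_i}
    = \lVert x \rVert_p^{\,1-p} \, \lvert x_i \rvert^{\,p-1} \, \operatorname{sign}(x_i)

This is exactly np.power(norm, 1 - porder) * np.power(np.abs(x), porder - 1) * np.sign(x); the trailing 1 / numel factor appears to rescale the analytic gradient to the mean-of-output loss that OpTest's numeric gradient check differentiates.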
33 changes: 32 additions & 1 deletion python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
@@ -61,6 +61,37 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSumOp_bf16(OpTest):
def setUp(self):
np.random.seed(100)
self.op_type = "reduce_sum"
self.dtype = np.uint16
self.x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32)
self.attrs = {'dim': [0, 1, 2]}
self.out = self.x.sum(axis=tuple(self.attrs['dim']))
self.gradient = self.calc_gradient()

self.inputs = {'X': convert_float_to_uint16(self.x)}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
self.gradient = self.calc_gradient()

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', user_defined_grads=self.gradient)

def calc_gradient(self):
x = self.x
grad = np.ones(x.shape, dtype=x.dtype)
return [grad]


class TestSumOp_fp16_withInt(OpTest):
def setUp(self):
self.op_type = "reduce_sum"