diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index bc25dc9935830..7ed2196e07d86 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -2414,60 +2414,6 @@ void group_norm_grad(const Tensor& x,
     PADDLE_THROW(common::errors::Unimplemented(
         "Only support NCHW and NHWC format in rank {3, 4, 5}."));
   }
-  int N = x_dims[0];
-  int C;
-  int hw = 1;
-  std::vector<int64_t> reduce_axis;
-
-  if (data_layout_ == DataLayout::kNCHW) {
-    C = x_dims[1];
-    for (int i = 2; i < rank; ++i) {
-      hw *= x_dims[i];
-      reduce_axis.push_back(i);
-    }
-  } else if (data_layout_ == DataLayout::kNHWC) {
-    C = x_dims[rank - 1];
-    for (int i = 1; i < (rank - 1); ++i) {
-      hw *= x_dims[i];
-      reduce_axis.push_back(i);
-    }
-  } else {
-    PADDLE_THROW(common::errors::InvalidArgument(
-        "Unsupported storage order: %s", data_layout));
-  }
-
-  int g_num = C / groups;
-
-  Tensor x_data = x;
-  Tensor out_grad_data = out_grad;
-
-  if (x.dtype() == phi::DataType::FLOAT16 ||
-      x.dtype() == phi::DataType::BFLOAT16) {
-    x_data = cast<T>(x, phi::DataType::FLOAT32);
-  }
-
-  if (out_grad.dtype() == phi::DataType::FLOAT16 ||
-      out_grad.dtype() == phi::DataType::BFLOAT16) {
-    out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
-  }
-
-  auto shape_group = std::vector<int64_t>({N, groups, g_num});
-
-  std::vector<int64_t> whole_group_shape;
-  if (data_layout_ == DataLayout::kNCHW) {
-    whole_group_shape = std::vector<int64_t>({N, groups, g_num, -1});
-  } else {
-    whole_group_shape = std::vector<int64_t>({N, -1, groups, g_num});
-  }
-  auto var_eps = variance + epsilon;
-
-  auto inv_std = rsqrt<T>(var_eps);
-
-  auto inv_std_mul_s = inv_std / hw / g_num;
-  auto dtype = x_data.dtype();
-  auto sum_y_grad_mul_x =
-      sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
-  auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
 
   Tensor scale_data;
   if (scale) {
@@ -2478,73 +2424,254 @@ void group_norm_grad(const Tensor& x,
     bias_data = bias.get();
   }
 
-  if (x_grad) {
-    Tensor d1;
-    Tensor d2;
-    Tensor p1;
-    if (scale) {
-      if (scale_data.dtype() == phi::DataType::FLOAT16 ||
-          scale_data.dtype() == phi::DataType::BFLOAT16) {
-        scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
-      }
-      d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
-               .sum(std::vector<int64_t>({2}), dtype, false);
-      d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
-               .sum(std::vector<int64_t>({2}), dtype, false);
-      p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
-           reshape<T>(scale_data, std::vector<int64_t>({1, groups, g_num}));
+  if (has_dynamic_shape(x_dims)) {
+    Tensor x_shape = shape<T>(x);
+    Tensor ones = full<T>({1}, 1, x_shape.dtype());
+    Tensor N = get_slice<T>(x_shape, 0);
+    Tensor C;
+    Tensor hw = ones;
+    std::vector<int64_t> reduce_axis;
+    if (data_layout_ == DataLayout::kNCHW) {
+      C = get_slice<T>(x_shape, 1);
+      for (int i = 2; i < rank; ++i) {
+        hw = hw * get_slice<T>(x_shape, i);
+        reduce_axis.push_back(i);
+      }
+    } else if (data_layout_ == DataLayout::kNHWC) {
+      C = get_slice<T>(x_shape, rank - 1);
+      for (int i = 1; i < (rank - 1); ++i) {
+        hw = hw * get_slice<T>(x_shape, i);
+        reduce_axis.push_back(i);
+      }
     } else {
-      d1 = (reshape<T>(sum_y_grad_mul_x, shape_group)).sum({2}, dtype, false);
-      d2 = (reshape<T>(sum_y_grad, shape_group)).sum({2}, dtype, false);
-      p1 = (reshape<T>(inv_std, {N, groups, 1}))
-               .expand(shape_group);  // [n, g, g_n]
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "Unsupported storage order: %s", data_layout));
     }
+    Tensor group_tensor = full<T>({1}, groups, x_shape.dtype());
+    Tensor g_num = cast<T>(C / group_tensor, x_shape.dtype());
+
+    Tensor x_data = x;
+    Tensor out_grad_data = out_grad;
-
-    auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
-    auto p3 = -p2 * mean - d2 * inv_std_mul_s;
-    std::vector<int64_t> first_shape;
-    std::vector<int64_t> second_shape;
+
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
+      x_data = cast<T>(x, phi::DataType::FLOAT32);
+    }
+    if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+        out_grad.dtype() == phi::DataType::BFLOAT16) {
+      out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
+    }
+    auto shape_group = concat<T>({N, group_tensor, g_num});
+    Tensor minus_one = full<T>({1}, -1, x_shape.dtype());
+    Tensor whole_group_shape;
     if (data_layout_ == DataLayout::kNCHW) {
-      first_shape = get_unsqueeze_dims(p1, {3});      // [n, g, g_n, 1]
-      second_shape = get_unsqueeze_dims(p2, {2, 3});  // [n, g, 1, 1]
+      whole_group_shape = concat<T>({N, group_tensor, g_num, minus_one});
     } else {
-      first_shape = get_unsqueeze_dims(p1, {1});      // [n, 1, g, g_n]
-      second_shape = get_unsqueeze_dims(p2, {1, 3});  // [n, 1, g, 1]
-    }
-
-    p1 = reshape<T>(p1, first_shape);
-    p2 = reshape<T>(p2, second_shape);
-    p3 = reshape<T>(p3, second_shape);
-    auto tmp_1 =
-        reshape<T>(out_grad_data, whole_group_shape) * p1;  // [n, hw, g, g_n]
-    auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
-    auto x_grad_data = tmp_1 + tmp_2;
-    x_grad_data = reshape<T>(x_grad_data, x.shape());
+      whole_group_shape = concat<T>({N, minus_one, group_tensor, g_num});
+    }
+    auto var_eps =
+        variance + backend::full_with_tensor<T>(
+                       shape<T>(variance), epsilon, variance.dtype());
+    auto inv_std = rsqrt<T>(var_eps);
+    auto inv_std_mul_s = inv_std / cast<T>(hw, inv_std.dtype()) /
+                         cast<T>(g_num, inv_std.dtype());
+    auto dtype = x_data.dtype();
+    auto sum_y_grad_mul_x =
+        sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
+    auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
+    if (x_grad) {
+      Tensor d1, d2, p1;
+      if (scale) {
+        if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+            scale_data.dtype() == phi::DataType::BFLOAT16) {
+          scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
+        }
+        d1 = (backend::reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        d2 = (backend::reshape<T>(sum_y_grad * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        p1 = backend::reshape<T>(inv_std, concat<T>({N, group_tensor, ones})) *
+             backend::reshape<T>(scale_data,
+                                 concat<T>({ones, group_tensor, g_num}));
+      } else {
+        d1 = (backend::reshape<T>(sum_y_grad_mul_x, shape_group))
+                 .sum({2}, dtype, false);
+        d2 = (backend::reshape<T>(sum_y_grad, shape_group))
+                 .sum({2}, dtype, false);
+        p1 = backend::reshape<T>(inv_std, concat<T>({N, group_tensor, ones}));
+        p1 = backend::expand_with_tensor<T>(p1, shape_group);  // [n, g, g_n]
+      }
+
+      auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
+      auto p3 = -p2 * mean - d2 * inv_std_mul_s;
+      Tensor first_shape, second_shape;
+      if (data_layout_ == DataLayout::kNCHW) {
+        first_shape =
+            get_unsqueeze_dims<T>(shape<T>(p1), {3});  // [n, g, g_n, 1]
+        second_shape =
+            get_unsqueeze_dims<T>(shape<T>(p2), {2, 3});  // [n, g, 1, 1]
+      } else {
+        first_shape =
+            get_unsqueeze_dims<T>(shape<T>(p1), {1});  // [n, 1, g, g_n]
+        second_shape =
+            get_unsqueeze_dims<T>(shape<T>(p2), {1, 3});  // [n, 1, g, 1]
+      }
+      p1 = backend::reshape<T>(p1, first_shape);
+      p2 = backend::reshape<T>(p2, second_shape);
+      p3 = backend::reshape<T>(p3, second_shape);
+      auto tmp_1 = backend::reshape<T>(out_grad_data, whole_group_shape) *
+                   p1;  // [n, hw, g, g_n]
+      auto tmp_2 = backend::reshape<T>(x_data, whole_group_shape) * p2 + p3;
+      auto x_grad_data = tmp_1 + tmp_2;
+      x_grad_data = backend::reshape<T>(x_grad_data, x_shape);
+      if (x.dtype() == phi::DataType::FLOAT16 ||
+          x.dtype() == phi::DataType::BFLOAT16) {
+        x_grad_data = cast<T>(x_grad_data, x.dtype());
+      }
+      set_output<T>(x_grad_data, x_grad);
+    }
+    if (scale_grad) {
+      if (scale) {
+        auto third_shape = get_unsqueeze_dims<T>(shape<T>(mean), {2});
+        auto tmp1 = (backend::reshape<T>(sum_y_grad_mul_x, shape_group) -
+                     backend::reshape<T>(sum_y_grad, shape_group) *
+                         backend::reshape<T>(mean, third_shape)) *
+                    backend::reshape<T>(inv_std, third_shape);
+        auto scale_grad_tmp =
+            backend::reshape<T>(tmp1.sum({0}, scale->dtype(), false), C);
+        set_output<T>(scale_grad_tmp, scale_grad);
+      }
+    }
+    if (bias_grad) {
+      if (bias) {
+        auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
+        set_output<T>(bias_grad_tmp, bias_grad);
+      }
+    }
+  } else {
+    int N = x_dims[0];
+    int C;
+    int hw = 1;
+    std::vector<int64_t> reduce_axis;
+
+    if (data_layout_ == DataLayout::kNCHW) {
+      C = x_dims[1];
+      for (int i = 2; i < rank; ++i) {
+        hw *= x_dims[i];
+        reduce_axis.push_back(i);
+      }
+    } else if (data_layout_ == DataLayout::kNHWC) {
+      C = x_dims[rank - 1];
+      for (int i = 1; i < (rank - 1); ++i) {
+        hw *= x_dims[i];
+        reduce_axis.push_back(i);
+      }
+    } else {
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "Unsupported storage order: %s", data_layout));
+    }
+
+    int g_num = C / groups;
+
+    Tensor x_data = x;
+    Tensor out_grad_data = out_grad;
+
     if (x.dtype() == phi::DataType::FLOAT16 ||
         x.dtype() == phi::DataType::BFLOAT16) {
-      x_grad_data = cast<T>(x_grad_data, x.dtype());
+      x_data = cast<T>(x, phi::DataType::FLOAT32);
     }
-    set_output<T>(x_grad_data, x_grad);
-  }
+    if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+        out_grad.dtype() == phi::DataType::BFLOAT16) {
+      out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
+    }
-
-  if (scale_grad) {
-    if (scale) {
-      auto third_shape = get_unsqueeze_dims(mean, {2});
-      auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
-                   reshape<T>(sum_y_grad, shape_group) *
-                       reshape<T>(mean, third_shape)) *
-                  reshape<T>(inv_std, third_shape);
-      auto scale_grad_tmp =
-          reshape<T>(tmp1.sum({0}, scale->dtype(), false), {C});
-      set_output<T>(scale_grad_tmp, scale_grad);
+    auto shape_group = std::vector<int64_t>({N, groups, g_num});
+
+    std::vector<int64_t> whole_group_shape;
+    if (data_layout_ == DataLayout::kNCHW) {
+      whole_group_shape = std::vector<int64_t>({N, groups, g_num, -1});
+    } else {
+      whole_group_shape = std::vector<int64_t>({N, -1, groups, g_num});
+    }
+    auto var_eps = variance + epsilon;
+
+    auto inv_std = rsqrt<T>(var_eps);
+
+    auto inv_std_mul_s = inv_std / hw / g_num;
+    auto dtype = x_data.dtype();
+    auto sum_y_grad_mul_x =
+        sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
+    auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
+
+    if (x_grad) {
+      Tensor d1;
+      Tensor d2;
+      Tensor p1;
+      if (scale) {
+        if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+            scale_data.dtype() == phi::DataType::BFLOAT16) {
+          scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
+        }
+        d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
+             reshape<T>(scale_data, std::vector<int64_t>({1, groups, g_num}));
+      } else {
+        d1 = (reshape<T>(sum_y_grad_mul_x, shape_group)).sum({2}, dtype, false);
+        d2 = (reshape<T>(sum_y_grad, shape_group)).sum({2}, dtype, false);
+        p1 = (reshape<T>(inv_std, {N, groups, 1}))
+                 .expand(shape_group);  // [n, g, g_n]
+      }
+
+      auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
+      auto p3 = -p2 * mean - d2 * inv_std_mul_s;
+      std::vector<int64_t> first_shape;
+      std::vector<int64_t> second_shape;
+      if (data_layout_ == DataLayout::kNCHW) {
+        first_shape = get_unsqueeze_dims(p1, {3});      // [n, g, g_n, 1]
+        second_shape = get_unsqueeze_dims(p2, {2, 3});  // [n, g, 1, 1]
+      } else {
+        first_shape = get_unsqueeze_dims(p1, {1});      // [n, 1, g, g_n]
+        second_shape = get_unsqueeze_dims(p2, {1, 3});  // [n, 1, g, 1]
+      }
+
+      p1 = reshape<T>(p1, first_shape);
+      p2 = reshape<T>(p2, second_shape);
+      p3 = reshape<T>(p3, second_shape);
+      auto tmp_1 =
+          reshape<T>(out_grad_data, whole_group_shape) * p1;  // [n, hw, g, g_n]
+      auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
+      auto x_grad_data = tmp_1 + tmp_2;
+      x_grad_data = reshape<T>(x_grad_data, x.shape());
+      if (x.dtype() == phi::DataType::FLOAT16 ||
+          x.dtype() == phi::DataType::BFLOAT16) {
+        x_grad_data = cast<T>(x_grad_data, x.dtype());
+      }
+
+      set_output<T>(x_grad_data, x_grad);
     }
-  }
 
-  if (bias_grad) {
-    if (bias) {
-      auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
-      set_output<T>(bias_grad_tmp, bias_grad);
+    if (scale_grad) {
+      if (scale) {
+        auto third_shape = get_unsqueeze_dims(mean, {2});
+        auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
+                     reshape<T>(sum_y_grad, shape_group) *
+                         reshape<T>(mean, third_shape)) *
+                    reshape<T>(inv_std, third_shape);
+        auto scale_grad_tmp =
+            reshape<T>(tmp1.sum({0}, scale->dtype(), false), {C});
+        set_output<T>(scale_grad_tmp, scale_grad);
+      }
+    }
+
+    if (bias_grad) {
+      if (bias) {
+        auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
+        set_output<T>(bias_grad_tmp, bias_grad);
+      }
     }
   }
 }
diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index 974ae63bc113c..6befa83a55362 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -48,6 +48,7 @@
     "pd_op.gather",
     "pd_op.gather_nd",
     "pd_op.gelu",
+    "pd_op.group_norm",
     "pd_op.hardswish",
     "pd_op.leaky_relu",
     "pd_op.log",
diff --git a/test/prim/pir_prim/test_prim_sub_graph_fghij_backward_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_fghij_backward_dynamic_shape.py
index 3b056c1c45e72..59644c9612b7b 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_fghij_backward_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_fghij_backward_dynamic_shape.py
@@ -17,6 +17,7 @@
 import numpy as np
 from test_prim_sub_graph_backward_dynamic_shape import (
     TestPrimBaseWithGrad,
+    TestPrimThreeWithGrad,
     TestPrimTwoWithGrad,
     apply_to_static,
 )
@@ -46,6 +47,22 @@ def gelu_net2(x):
     return paddle.nn.functional.gelu(x, approximate=False)
 
 
+def group_norm_net1(x, y, z, epsilon=1e-5, num_groups=10):
+    return paddle._C_ops.group_norm(x, y, z, epsilon, num_groups, "NCHW")
+
+
+def group_norm_net2(x, epsilon=1e-5, num_groups=10):
+    return paddle._C_ops.group_norm(x, None, None, epsilon, num_groups, "NCHW")
+
+
+def group_norm_net3(x, y, z, epsilon=1e-5, num_groups=10):
+    return paddle._C_ops.group_norm(x, y, z, epsilon, num_groups, "NHWC")
+
+
+def group_norm_net4(x, epsilon=1e-5, num_groups=10):
+    return paddle._C_ops.group_norm(x, None, None, epsilon, num_groups, "NHWC")
+
+
 def hardswish_net(x):
     return paddle.nn.functional.hardswish(x)
 
@@ -222,6 +239,70 @@ def test_prim_all_dynamic(self):
         np.testing.assert_allclose(dr, d, rtol=self.rtol, atol=self.atol)
 
 
+class TestPrimGroupNormWithGrad1(TestPrimThreeWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.op_name = "pd_op.group_norm_grad"
+        self.dtype = "float32"
+        self.x_shape = [30, 60, 50, 60]
+        self.init_x_shape = [None, None, None, 60]
+        self.y_shape = [60]
+        self.init_y_shape = [None]
+        self.z_shape = [60]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.z = np.random.random(self.z_shape).astype(self.dtype)
+        self.net = group_norm_net1
+        self.enable_cinn = False
+        self.tol = 7e-4
+
+
+class TestPrimGroupNormWithGrad2(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.op_name = "pd_op.group_norm_grad"
+        self.dtype = "float32"
+        self.x_shape = [30, 60, 50, 60]
+        self.init_x_shape = [None, None, None, 60]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net2
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimGroupNormWithGrad3(TestPrimThreeWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.op_name = "pd_op.group_norm_grad"
+        self.dtype = "float32"
+        self.x_shape = [30, 60, 50, 60]
+        self.init_x_shape = [None, 60, None, None]
+        self.y_shape = [60]
+        self.init_y_shape = [None]
+        self.z_shape = [60]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.z = np.random.random(self.z_shape).astype(self.dtype)
+        self.net = group_norm_net3
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimGroupNormWithGrad4(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.op_name = "pd_op.group_norm_grad"
+        self.dtype = "float32"
+        self.x_shape = [30, 60, 50, 60]
+        self.init_x_shape = [None, 60, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net4
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
 class TestPrimHardswishWithGrad(TestPrimBaseWithGrad):
     def setUp(self):
         np.random.seed(2024)
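For reference, both branches of group_norm_grad above compute the x-gradient from the same d1/d2/p1/p2/p3 decomposition. Below is a minimal NumPy sketch of that decomposition for the NCHW layout, which can serve as an independent check when debugging the dynamic-shape path; the standalone helper name and its argument convention are illustrative only and not part of this patch (mean and var are assumed to be the per-group statistics saved by the forward pass).

import numpy as np


def group_norm_x_grad_reference(x, dy, scale, mean, var, groups, epsilon=1e-5):
    # x, dy: [N, C, H, W]; scale: [C]; mean, var: [N, groups] (forward statistics).
    n, c, h, w = x.shape
    g_n = c // groups                       # channels per group (g_num)
    hw = h * w                              # spatial size
    s = 1.0 / (hw * g_n)
    inv_std = 1.0 / np.sqrt(var + epsilon)  # [N, G]

    # Per-channel reductions over the spatial axes, regrouped to [N, G, g_n].
    dy_x = (dy * x).sum(axis=(2, 3)).reshape(n, groups, g_n)
    dy_sum = dy.sum(axis=(2, 3)).reshape(n, groups, g_n)
    scale_g = scale.reshape(1, groups, g_n)

    d1 = (dy_x * scale_g).sum(axis=2)       # [N, G]
    d2 = (dy_sum * scale_g).sum(axis=2)     # [N, G]
    p1 = inv_std[:, :, None] * scale_g      # [N, G, g_n]
    p2 = (d2 * mean - d1) * inv_std**3 * s  # inv_std_mul_s / var_eps == inv_std**3 * s
    p3 = -p2 * mean - d2 * inv_std * s

    # dx = dy * p1 + x * p2 + p3, broadcast over the spatial axis per group.
    dy_g = dy.reshape(n, groups, g_n, hw)
    x_g = x.reshape(n, groups, g_n, hw)
    dx = dy_g * p1[..., None] + x_g * p2[:, :, None, None] + p3[:, :, None, None]
    return dx.reshape(n, c, h, w)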