support dynamic shape for group_norm
zeroRains committed Sep 17, 2024
1 parent 9baee85 commit 864953b
Showing 3 changed files with 318 additions and 109 deletions.
345 changes: 236 additions & 109 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -2414,60 +2414,6 @@ void group_norm_grad(const Tensor& x,
     PADDLE_THROW(common::errors::Unimplemented(
         "Only support NCHW and NHWC format in rank {3, 4, 5}."));
   }
-  int N = x_dims[0];
-  int C;
-  int hw = 1;
-  std::vector<int64_t> reduce_axis;
-
-  if (data_layout_ == DataLayout::kNCHW) {
-    C = x_dims[1];
-    for (int i = 2; i < rank; ++i) {
-      hw *= x_dims[i];
-      reduce_axis.push_back(i);
-    }
-  } else if (data_layout_ == DataLayout::kNHWC) {
-    C = x_dims[rank - 1];
-    for (int i = 1; i < (rank - 1); ++i) {
-      hw *= x_dims[i];
-      reduce_axis.push_back(i);
-    }
-  } else {
-    PADDLE_THROW(common::errors::InvalidArgument(
-        "Unsupported storage order: %s", data_layout));
-  }
-
-  int g_num = C / groups;
-
-  Tensor x_data = x;
-  Tensor out_grad_data = out_grad;
-
-  if (x.dtype() == phi::DataType::FLOAT16 ||
-      x.dtype() == phi::DataType::BFLOAT16) {
-    x_data = cast<T>(x, phi::DataType::FLOAT32);
-  }
-
-  if (out_grad.dtype() == phi::DataType::FLOAT16 ||
-      out_grad.dtype() == phi::DataType::BFLOAT16) {
-    out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
-  }
-
-  auto shape_group = std::vector<int64_t>({N, groups, g_num});
-
-  std::vector<int64_t> whole_group_shape;
-  if (data_layout_ == DataLayout::kNCHW) {
-    whole_group_shape = std::vector<int64_t>({N, groups, g_num, -1});
-  } else {
-    whole_group_shape = std::vector<int64_t>({N, -1, groups, g_num});
-  }
-  auto var_eps = variance + epsilon;
-
-  auto inv_std = rsqrt<T>(var_eps);
-
-  auto inv_std_mul_s = inv_std / hw / g_num;
-  auto dtype = x_data.dtype();
-  auto sum_y_grad_mul_x =
-      sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
-  auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
-
   Tensor scale_data;
   if (scale) {
@@ -2478,73 +2424,254 @@ void group_norm_grad(const Tensor& x,
     bias_data = bias.get();
   }

-  if (x_grad) {
-    Tensor d1;
-    Tensor d2;
-    Tensor p1;
-    if (scale) {
-      if (scale_data.dtype() == phi::DataType::FLOAT16 ||
-          scale_data.dtype() == phi::DataType::BFLOAT16) {
-        scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
-      }
-      d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
-               .sum(std::vector<int64_t>({2}), dtype, false);
-      d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
-               .sum(std::vector<int64_t>({2}), dtype, false);
-      p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
-           reshape<T>(scale_data, std::vector<int64_t>({1, groups, g_num}));
-    } else {
-      d1 = (reshape<T>(sum_y_grad_mul_x, shape_group)).sum({2}, dtype, false);
-      d2 = (reshape<T>(sum_y_grad, shape_group)).sum({2}, dtype, false);
-      p1 = (reshape<T>(inv_std, {N, groups, 1}))
-               .expand(shape_group);  // [n, g, g_n]
-    }
-
-    auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
-    auto p3 = -p2 * mean - d2 * inv_std_mul_s;
-    std::vector<int64_t> first_shape;
-    std::vector<int64_t> second_shape;
-    if (data_layout_ == DataLayout::kNCHW) {
-      first_shape = get_unsqueeze_dims(p1, {3});      // [n, g, g_n, 1]
-      second_shape = get_unsqueeze_dims(p2, {2, 3});  // [n, g, 1, 1]
-    } else {
-      first_shape = get_unsqueeze_dims(p1, {1});      // [n, 1, g, g_n]
-      second_shape = get_unsqueeze_dims(p2, {1, 3});  // [n, 1, g, 1]
-    }
-
-    p1 = reshape<T>(p1, first_shape);
-    p2 = reshape<T>(p2, second_shape);
-    p3 = reshape<T>(p3, second_shape);
-    auto tmp_1 =
-        reshape<T>(out_grad_data, whole_group_shape) * p1;  // [n, hw, g, g_n]
-    auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
-    auto x_grad_data = tmp_1 + tmp_2;
-    x_grad_data = reshape<T>(x_grad_data, x.shape());
-    if (x.dtype() == phi::DataType::FLOAT16 ||
-        x.dtype() == phi::DataType::BFLOAT16) {
-      x_grad_data = cast<T>(x_grad_data, x.dtype());
-    }
-
-    set_output<T>(x_grad_data, x_grad);
-  }
-
-  if (scale_grad) {
-    if (scale) {
-      auto third_shape = get_unsqueeze_dims(mean, {2});
-      auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
-                   reshape<T>(sum_y_grad, shape_group) *
-                       reshape<T>(mean, third_shape)) *
-                  reshape<T>(inv_std, third_shape);
-      auto scale_grad_tmp =
-          reshape<T>(tmp1.sum({0}, scale->dtype(), false), {C});
-      set_output<T>(scale_grad_tmp, scale_grad);
-    }
-  }
-
-  if (bias_grad) {
-    if (bias) {
-      auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
-      set_output<T>(bias_grad_tmp, bias_grad);
-    }
-  }
+  if (has_dynamic_shape(x_dims)) {
+    Tensor x_shape = shape<T>(x);
+    Tensor ones = full<T>({1}, 1, x_shape.dtype());
+    Tensor N = get_slice<T>(x_shape, 0);
+    Tensor C;
+    Tensor hw = ones;
+    std::vector<int64_t> reduce_axis;
+    if (data_layout_ == DataLayout::kNCHW) {
+      C = get_slice<T>(x_shape, 1);
+      for (int i = 2; i < rank; ++i) {
+        hw = hw * get_slice<T>(x_shape, i);
+        reduce_axis.push_back(i);
+      }
+    } else if (data_layout_ == DataLayout::kNHWC) {
+      C = get_slice<T>(x_shape, rank - 1);
+      for (int i = 1; i < (rank - 1); ++i) {
+        hw = hw * get_slice<T>(x_shape, i);
+        reduce_axis.push_back(i);
+      }
+    } else {
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "Unsupported storage order: %s", data_layout));
+    }
+    Tensor group_tensor = full<T>({1}, groups, x_shape.dtype());
+    Tensor g_num = cast<T>(C / group_tensor, x_shape.dtype());
+
+    Tensor x_data = x;
+    Tensor out_grad_data = out_grad;
+
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
+      x_data = cast<T>(x, phi::DataType::FLOAT32);
+    }
+    if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+        out_grad.dtype() == phi::DataType::BFLOAT16) {
+      out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
+    }
+    auto shape_group = concat<T>({N, group_tensor, g_num});
+    Tensor minus_one = full<T>({1}, -1, x_shape.dtype());
+    Tensor whole_group_shape;
+    if (data_layout_ == DataLayout::kNCHW) {
+      whole_group_shape = concat<T>({N, group_tensor, g_num, minus_one});
+    } else {
+      whole_group_shape = concat<T>({N, minus_one, group_tensor, g_num});
+    }
+    auto var_eps =
+        variance + backend::full_with_tensor<T>(
+                       shape<T>(variance), epsilon, variance.dtype());
+    auto inv_std = rsqrt<T>(var_eps);
+    auto inv_std_mul_s = inv_std / cast<T>(hw, inv_std.dtype()) /
+                         cast<T>(g_num, inv_std.dtype());
+    auto dtype = x_data.dtype();
+    auto sum_y_grad_mul_x =
+        sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
+    auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
+    if (x_grad) {
+      Tensor d1, d2, p1;
+      if (scale) {
+        if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+            scale_data.dtype() == phi::DataType::BFLOAT16) {
+          scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
+        }
+        d1 = (backend::reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        d2 = (backend::reshape<T>(sum_y_grad * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        p1 = backend::reshape<T>(inv_std, concat<T>({N, group_tensor, ones})) *
+             backend::reshape<T>(scale_data,
+                                 concat<T>({ones, group_tensor, g_num}));
+      } else {
+        d1 = (backend::reshape<T>(sum_y_grad_mul_x, shape_group))
+                 .sum({2}, dtype, false);
+        d2 = (backend::reshape<T>(sum_y_grad, shape_group))
+                 .sum({2}, dtype, false);
+        p1 = backend::reshape<T>(inv_std, concat<T>({N, group_tensor, ones}));
+        p1 = backend::expand_with_tensor<T>(p1, shape_group);  // [n, g, g_n]
+      }
+
+      auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
+      auto p3 = -p2 * mean - d2 * inv_std_mul_s;
+      Tensor first_shape, second_shape;
+      if (data_layout_ == DataLayout::kNCHW) {
+        first_shape =
+            get_unsqueeze_dims<T>(shape<T>(p1), {3});  // [n, g, g_n, 1]
+        second_shape =
+            get_unsqueeze_dims<T>(shape<T>(p2), {2, 3});  // [n, g, 1, 1]
+      } else {
+        first_shape =
+            get_unsqueeze_dims<T>(shape<T>(p1), {1});  // [n, 1, g, g_n]
+        second_shape =
+            get_unsqueeze_dims<T>(shape<T>(p2), {1, 3});  // [n, 1, g, 1]
+      }
+      p1 = backend::reshape<T>(p1, first_shape);
+      p2 = backend::reshape<T>(p2, second_shape);
+      p3 = backend::reshape<T>(p3, second_shape);
+      auto tmp_1 = backend::reshape<T>(out_grad_data, whole_group_shape) *
+                   p1;  // [n, hw, g, g_n]
+      auto tmp_2 = backend::reshape<T>(x_data, whole_group_shape) * p2 + p3;
+      auto x_grad_data = tmp_1 + tmp_2;
+      x_grad_data = backend::reshape<T>(x_grad_data, x_shape);
+      if (x.dtype() == phi::DataType::FLOAT16 ||
+          x.dtype() == phi::DataType::BFLOAT16) {
+        x_grad_data = cast<T>(x_grad_data, x.dtype());
+      }
+      set_output<T>(x_grad_data, x_grad);
+    }
+    if (scale_grad) {
+      if (scale) {
+        auto third_shape = get_unsqueeze_dims<T>(shape<T>(mean), {2});
+        auto tmp1 = (backend::reshape<T>(sum_y_grad_mul_x, shape_group) -
+                     backend::reshape<T>(sum_y_grad, shape_group) *
+                         backend::reshape<T>(mean, third_shape)) *
+                    backend::reshape<T>(inv_std, third_shape);
+        auto scale_grad_tmp =
+            backend::reshape<T>(tmp1.sum({0}, scale->dtype(), false), C);
+        set_output<T>(scale_grad_tmp, scale_grad);
+      }
+    }
+    if (bias_grad) {
+      if (bias) {
+        auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
+        set_output<T>(bias_grad_tmp, bias_grad);
+      }
+    }
+  } else {
+    int N = x_dims[0];
+    int C;
+    int hw = 1;
+    std::vector<int64_t> reduce_axis;
+
+    if (data_layout_ == DataLayout::kNCHW) {
+      C = x_dims[1];
+      for (int i = 2; i < rank; ++i) {
+        hw *= x_dims[i];
+        reduce_axis.push_back(i);
+      }
+    } else if (data_layout_ == DataLayout::kNHWC) {
+      C = x_dims[rank - 1];
+      for (int i = 1; i < (rank - 1); ++i) {
+        hw *= x_dims[i];
+        reduce_axis.push_back(i);
+      }
+    } else {
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "Unsupported storage order: %s", data_layout));
+    }
+
+    int g_num = C / groups;
+
+    Tensor x_data = x;
+    Tensor out_grad_data = out_grad;
+
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
+      x_data = cast<T>(x, phi::DataType::FLOAT32);
+    }
+
+    if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+        out_grad.dtype() == phi::DataType::BFLOAT16) {
+      out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
+    }
+
+    auto shape_group = std::vector<int64_t>({N, groups, g_num});
+
+    std::vector<int64_t> whole_group_shape;
+    if (data_layout_ == DataLayout::kNCHW) {
+      whole_group_shape = std::vector<int64_t>({N, groups, g_num, -1});
+    } else {
+      whole_group_shape = std::vector<int64_t>({N, -1, groups, g_num});
+    }
+    auto var_eps = variance + epsilon;
+
+    auto inv_std = rsqrt<T>(var_eps);
+
+    auto inv_std_mul_s = inv_std / hw / g_num;
+    auto dtype = x_data.dtype();
+    auto sum_y_grad_mul_x =
+        sum<T>(out_grad_data * x_data, reduce_axis, dtype, false);
+    auto sum_y_grad = sum<T>(out_grad_data, reduce_axis, dtype, false);
+
+    if (x_grad) {
+      Tensor d1;
+      Tensor d2;
+      Tensor p1;
+      if (scale) {
+        if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+            scale_data.dtype() == phi::DataType::BFLOAT16) {
+          scale_data = cast<T>(scale_data, phi::DataType::FLOAT32);
+        }
+        d1 = (reshape<T>(sum_y_grad_mul_x * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        d2 = (reshape<T>(sum_y_grad * scale_data, shape_group))
+                 .sum(std::vector<int64_t>({2}), dtype, false);
+        p1 = reshape<T>(inv_std, std::vector<int64_t>({N, groups, 1})) *
+             reshape<T>(scale_data, std::vector<int64_t>({1, groups, g_num}));
+      } else {
+        d1 = (reshape<T>(sum_y_grad_mul_x, shape_group)).sum({2}, dtype, false);
+        d2 = (reshape<T>(sum_y_grad, shape_group)).sum({2}, dtype, false);
+        p1 = (reshape<T>(inv_std, {N, groups, 1}))
+                 .expand(shape_group);  // [n, g, g_n]
+      }
+
+      auto p2 = (d2 * mean - d1) * (inv_std_mul_s / var_eps);  // [n, g]
+      auto p3 = -p2 * mean - d2 * inv_std_mul_s;
+      std::vector<int64_t> first_shape;
+      std::vector<int64_t> second_shape;
+      if (data_layout_ == DataLayout::kNCHW) {
+        first_shape = get_unsqueeze_dims(p1, {3});      // [n, g, g_n, 1]
+        second_shape = get_unsqueeze_dims(p2, {2, 3});  // [n, g, 1, 1]
+      } else {
+        first_shape = get_unsqueeze_dims(p1, {1});      // [n, 1, g, g_n]
+        second_shape = get_unsqueeze_dims(p2, {1, 3});  // [n, 1, g, 1]
+      }
+
+      p1 = reshape<T>(p1, first_shape);
+      p2 = reshape<T>(p2, second_shape);
+      p3 = reshape<T>(p3, second_shape);
+      auto tmp_1 =
+          reshape<T>(out_grad_data, whole_group_shape) * p1;  // [n, hw, g, g_n]
+      auto tmp_2 = reshape<T>(x_data, whole_group_shape) * p2 + p3;
+      auto x_grad_data = tmp_1 + tmp_2;
+      x_grad_data = reshape<T>(x_grad_data, x.shape());
+      if (x.dtype() == phi::DataType::FLOAT16 ||
+          x.dtype() == phi::DataType::BFLOAT16) {
+        x_grad_data = cast<T>(x_grad_data, x.dtype());
+      }
+
+      set_output<T>(x_grad_data, x_grad);
+    }
+
+    if (scale_grad) {
+      if (scale) {
+        auto third_shape = get_unsqueeze_dims(mean, {2});
+        auto tmp1 = (reshape<T>(sum_y_grad_mul_x, shape_group) -
+                     reshape<T>(sum_y_grad, shape_group) *
+                         reshape<T>(mean, third_shape)) *
+                    reshape<T>(inv_std, third_shape);
+        auto scale_grad_tmp =
+            reshape<T>(tmp1.sum({0}, scale->dtype(), false), {C});
+        set_output<T>(scale_grad_tmp, scale_grad);
+      }
+    }
+
+    if (bias_grad) {
+      if (bias) {
+        auto bias_grad_tmp = sum_y_grad.sum({0}, bias->dtype(), false);
+        set_output<T>(bias_grad_tmp, bias_grad);
+      }
+    }
+  }
 }
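For readers following the rule above: both branches assemble the same x-gradient, the standard group-normalization backward decomposition. As a sketch in the code's notation (with mu the group mean, s = sqrt(variance + epsilon) the per-group standard deviation, m = hw * g_num the number of elements per group, and gamma the optional scale; d1, d2, p1, p2, p3 mirror the temporaries in the diff):

  d_1 = \sum_{\mathrm{group}} \gamma\,\frac{\partial L}{\partial y}\,x,
  \qquad
  d_2 = \sum_{\mathrm{group}} \gamma\,\frac{\partial L}{\partial y},

  p_1 = \frac{\gamma}{s},
  \qquad
  p_2 = \frac{d_2\,\mu - d_1}{m\,s^{3}},
  \qquad
  p_3 = -\,p_2\,\mu - \frac{d_2}{m\,s},

  \frac{\partial L}{\partial x} = p_1\,\frac{\partial L}{\partial y} + p_2\,x + p_3 .

This matches tmp_1 + tmp_2 in both paths; the dynamic-shape branch computes the same quantities but builds shape_group and whole_group_shape as runtime tensors via shape<T>, get_slice<T>, and concat<T> instead of compile-time std::vector<int64_t> constants.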
1 change: 1 addition & 0 deletions python/paddle/autograd/backward_utils.py
@@ -48,6 +48,7 @@
     "pd_op.gather",
     "pd_op.gather_nd",
     "pd_op.gelu",
+    "pd_op.group_norm",
     "pd_op.hardswish",
     "pd_op.leaky_relu",
     "pd_op.log",
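A usage sketch, not part of this commit: the dynamic-shape branch of group_norm_grad is only reachable when the input dims are symbolic, for example a GroupNorm layer converted with paddle.jit.to_static and an InputSpec whose batch and spatial entries are None. Whether the composite (prim) backward rule is actually selected additionally depends on the build and on prim flags, so treat this as illustrative only.

import paddle
from paddle.static import InputSpec

# GroupNorm with a fixed channel count; batch and spatial sizes stay dynamic.
gn = paddle.nn.GroupNorm(num_groups=4, num_channels=8)

def fn(x):
    return gn(x).mean()

# None entries in InputSpec are what make x_dims dynamic for the VJP rule.
static_fn = paddle.jit.to_static(
    fn, input_spec=[InputSpec(shape=[None, 8, None, None], dtype="float32")]
)

x = paddle.randn([2, 8, 16, 16])
x.stop_gradient = False
loss = static_fn(x)
loss.backward()
print(x.grad.shape)  # [2, 8, 16, 16]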