add group norm backward
Yin Hongyun committed Nov 29, 2024
1 parent 6e71c12 commit fe30ff7
Showing 3 changed files with 115 additions and 0 deletions.
42 changes: 42 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
@@ -5275,6 +5275,48 @@ def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_
GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
return out


def group_norm_GB_backward(
    input,
    grad_outputs,
    num_groups,
    weight=None,
    bias=None,
    eps=1e-05,
    reduced_axes=[2, 3],
    channel_axis=1,
    **kwargs,
):
    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
    # Pop the statistics stashed by the paired group_norm_GB forward call.
    save_mean = GLOBAL_STATE.pop("group_norm_GB_save_mean")
    save_invstd = GLOBAL_STATE.pop("group_norm_GB_save_invstd")
    grad_input = raw_like(input)
    grad_weight = raw_like(weight) if weight is not None else None
    grad_bias = raw_like(bias) if bias is not None else None

    out = {"input": grad_input, "weight": grad_weight, "bias": grad_bias}
    func = check_function("diopiGroupNormGBBackward")
    reduced_axes = Sizes(reduced_axes)
    ret = func(
        input.context(),
        grad_input,
        grad_weight,
        grad_bias,
        grad_outputs[0],
        input,
        weight,
        save_mean,
        save_invstd,
        num_groups,
        reduced_axes,
        channel_axis,
    )
    check_returncode(ret)
    return {k: v for k, v in out.items() if v is not None and v.requires_grad}


def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
    dim = list(input.size().data)
    save_mean = Tensor((dim[0], num_groups), input.get_dtype())
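Reviewer note, not part of the diff: for the default layout (channel_axis=1, reduced_axes=[2, 3]) the gradients this wrapper returns should match torch autograd. A minimal sketch, assuming torch is available; the shapes here are arbitrary examples, not taken from the test suite:

import torch
import torch.nn.functional as F

# Reference gradients for the default NCHW layout via autograd.
x = torch.randn(2, 6, 4, 4, requires_grad=True)
w = torch.randn(6, requires_grad=True)
b = torch.randn(6, requires_grad=True)
out = F.group_norm(x, num_groups=3, weight=w, bias=b, eps=1e-5)
out.backward(torch.ones_like(out))
# x.grad, w.grad and b.grad play the roles of grad_input, grad_weight
# and grad_bias in group_norm_GB_backward above.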
66 changes: 66 additions & 0 deletions impl/torch/functions/functions.cpp
@@ -4251,6 +4251,72 @@ diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out,
    return diopiSuccess;
}

diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
                                      diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
                                      diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes,
                                      const int64_t channel_axis) {
    impl::aten::setCurStream(ctx);
    auto atGradOutput = impl::aten::buildATen(grad_output);
    auto atInput = impl::aten::buildATen(input);
    auto atWeight = impl::aten::buildATen(weight);
    auto atSaveMean = impl::aten::buildATen(mean);
    auto atSaveVar = impl::aten::buildATen(rstd);
    auto atGradWeight = impl::aten::buildATen(grad_weight);
    auto atGradBias = impl::aten::buildATen(grad_bias);

    // Collect the batch-like axes (everything that is neither the channel
    // axis nor a reduced axis); their sizes multiply into N.
    std::vector<int64_t> dims;
    int64_t N = 1;
    for (int64_t i = 0; i < atInput.dim(); i++) {
        if (i == channel_axis) {
            continue;
        }
        bool is_reduced_axis = false;
        for (int64_t m = 0; m < reduced_axes.len; m++) {
            if (i == reduced_axes.data[m]) {
                is_reduced_axis = true;
                break;
            }
        }
        if (!is_reduced_axis) {
            dims.push_back(i);
            N *= atInput.size(i);
        }
    }
    // Append the channel axis and then the reduced axes, so that permuting
    // by `dims` yields the canonical (N, C, HxW) layout.
    dims.push_back(channel_axis);
    int64_t HxW = 1;
    for (int64_t i = 0; i < reduced_axes.len; i++) {
        dims.push_back(reduced_axes.data[i]);
        HxW *= atInput.size(reduced_axes.data[i]);
    }
    auto C = atInput.size(channel_axis);
    auto permutedInput = atInput.permute(dims);
    auto permutedShape = permutedInput.sizes();
    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();

    // reverse_order is the inverse permutation of dims; it maps the
    // gradient back to the original axis order.
    std::vector<int64_t> reverse_order(dims.size());
    for (int64_t i = 0; i < atInput.dim(); i++) {
        reverse_order[dims[i]] = i;
    }

    if (grad_weight && grad_bias) {
        auto atGradInput = impl::aten::buildATen(grad_input).permute(dims).reshape({N, C, HxW, 1});
        at::native_group_norm_backward_out(atGradInput,
                                           atGradWeight,
                                           atGradBias,
                                           atGradOutput.permute(dims).reshape({N, C, HxW, 1}),
                                           reshapedInput,
                                           atSaveMean,
                                           atSaveVar,
                                           atWeight,
                                           N,
                                           C,
                                           HxW,
                                           num_groups,
                                           {true, true, true});
        atGradInput = atGradInput.reshape(permutedShape).permute(reverse_order);
        impl::aten::updateATen2Tensor(ctx, atGradInput, grad_input);
    } else {
        auto atOuts = at::native_group_norm_backward(atGradOutput.permute(dims).reshape({N, C, HxW, 1}),
                                                     reshapedInput,
                                                     atSaveMean,
                                                     atSaveVar,
                                                     atWeight,
                                                     N,
                                                     C,
                                                     HxW,
                                                     num_groups,
                                                     {true, grad_weight != nullptr, grad_bias != nullptr});
        impl::aten::updateATen2Tensor(ctx, std::get<0>(atOuts).reshape(permutedShape).permute(reverse_order), grad_input);
        impl::aten::updateATen2Tensor(ctx, std::get<1>(atOuts), grad_weight);
        impl::aten::updateATen2Tensor(ctx, std::get<2>(atOuts), grad_bias);
    }

    return diopiSuccess;
}

diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                            diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                            double eps) {
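A note on the axis bookkeeping above, not part of the diff: the kernel maps an arbitrary (channel_axis, reduced_axes) layout onto the canonical (N, C, HxW, 1) layout, runs the stock group-norm backward there, and permutes the input gradient back. A minimal Python sketch of the same index arithmetic, assuming torch semantics for permute/reshape; to_nchw_layout and the example shapes are illustrative only:

import torch

def to_nchw_layout(x, channel_axis, reduced_axes):
    # Batch-like axes are everything except the channel and reduced axes.
    batch_axes = [i for i in range(x.dim())
                  if i != channel_axis and i not in reduced_axes]
    dims = batch_axes + [channel_axis] + list(reduced_axes)
    n = 1
    for i in batch_axes:
        n *= x.size(i)
    hxw = 1
    for i in reduced_axes:
        hxw *= x.size(i)
    c = x.size(channel_axis)
    # Inverse permutation: reverse_order[dims[i]] = i.
    reverse_order = [0] * len(dims)
    for i, d in enumerate(dims):
        reverse_order[d] = i
    return x.permute(dims).reshape(n, c, hxw, 1), dims, reverse_order

# e.g. an NHWC tensor: channel_axis=3, reduced_axes=[1, 2]
x = torch.randn(2, 4, 4, 6)
y, dims, rev = to_nchw_layout(x, channel_axis=3, reduced_axes=[1, 2])
assert dims == [0, 3, 1, 2]
assert tuple(y.shape) == (2, 6, 16, 1)
# A gradient computed in this layout is mapped back with
# grad.reshape(x.permute(dims).shape).permute(rev), as in the C++ above.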
7 changes: 7 additions & 0 deletions proto/include/diopi/functions.h
@@ -3607,6 +3607,13 @@ DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHan
                                        diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                        double eps, diopiSize_t reduced_axes, const int64_t channel_axis);

/**
 * @brief Compute the backward pass of diopiGroupNormGB().
 * @param[in] ctx Context environment.
 * @param[out] grad_input the gradient of input.
 * @param[out] grad_weight the gradient of weight.
 * @param[out] grad_bias the gradient of bias.
 * @param[in] grad_output the gradient of the forward output.
 * @param[in] input the input tensor of the forward pass.
 * @param[in] weight the weight tensor of the forward pass.
 * @param[in] mean the mean saved by the forward pass.
 * @param[in] rstd the reciprocal standard deviation saved by the forward pass.
 * @param[in] num_groups number of groups to separate the channels into.
 * @param[in] reduced_axes the axes reduced over when normalizing.
 * @param[in] channel_axis the axis holding the channels.
 */
DIOPI_API diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
                                                diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
                                                diopiConstTensorHandle_t weight, diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd,
                                                int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis);

/**
* @brief Compute the backward pass of diopiGroupNorm().
* @param[in] ctx Context environment.
Expand Down
