From fe30ff7444b8cfb8fe2c490df2d5368de1d1c77e Mon Sep 17 00:00:00 2001
From: Yin Hongyun
Date: Fri, 29 Nov 2024 22:20:42 +0800
Subject: [PATCH] add group norm backward

---
 .../python/conformance/diopi_functions.py | 42 ++++++++++++
 impl/torch/functions/functions.cpp        | 66 +++++++++++++++++++
 proto/include/diopi/functions.h           |  7 ++
 3 files changed, 115 insertions(+)

diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py
index 8cc29cbaa..8c9456c22 100644
--- a/diopi_test/python/conformance/diopi_functions.py
+++ b/diopi_test/python/conformance/diopi_functions.py
@@ -5275,6 +5275,48 @@ def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_
     GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
     return out
 
+
+def group_norm_GB_backward(
+    input,
+    grad_outputs,
+    num_groups,
+    weight=None,
+    bias=None,
+    eps=1e-05,
+    reduced_axes=[2, 3],
+    channel_axis=1,
+    **kwargs,
+) -> Tensor:
+    # Backward of group_norm_GB: consumes the mean/invstd saved by the
+    # forward pass (GLOBAL_STATE) and dispatches to diopiGroupNormGBBackward.
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    save_mean = GLOBAL_STATE.pop("group_norm_GB_save_mean")
+    save_invstd = GLOBAL_STATE.pop("group_norm_GB_save_invstd")
+    grad_input = raw_like(input)
+    # Only allocate gradients for the affine parameters that were actually
+    # provided; raw_like(None) would fail when weight/bias are omitted
+    # (their declared defaults).
+    grad_weight = None if weight is None else raw_like(weight)
+    grad_bias = None if bias is None else raw_like(bias)
+
+    out = {"input": grad_input, "weight": grad_weight, "bias": grad_bias}
+    func = check_function("diopiGroupNormGBBackward")
+    reduced_axes = Sizes(reduced_axes)
+    ret = func(
+        input.context(),
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_outputs[0],
+        input,
+        weight,
+        save_mean,
+        save_invstd,
+        num_groups,
+        reduced_axes,
+        channel_axis,
+    )
+    check_returncode(ret)
+    # Skip absent (None) grads before the requires_grad filter.
+    return {k: v for k, v in out.items() if v is not None and v.requires_grad}
+
+
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     dim = list(input.size().data)
     save_mean = Tensor((dim[0], num_groups), input.get_dtype())
diff --git a/impl/torch/functions/functions.cpp
b/impl/torch/functions/functions.cpp
index b95f36575..b6caade61 100644
--- a/impl/torch/functions/functions.cpp
+++ b/impl/torch/functions/functions.cpp
@@ -4251,6 +4251,72 @@ diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out,
     return diopiSuccess;
 }
 
+// Backward of diopiGroupNormGB: group norm whose channel axis and reduced
+// (normalized) axes are caller-specified.  Tensors are permuted into the
+// canonical (N, C, HxW, 1) layout expected by ATen's native group-norm
+// backward, and the input gradient is permuted back afterwards.
+diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
+                                      diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
+                                      diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes,
+                                      const int64_t channel_axis) {
+    impl::aten::setCurStream(ctx);
+    auto atGradOutput = impl::aten::buildATen(grad_output);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atSaveMean = impl::aten::buildATen(mean);
+    auto atSaveVar = impl::aten::buildATen(rstd);
+    auto atGradWeight = impl::aten::buildATen(grad_weight);
+    auto atGradBias = impl::aten::buildATen(grad_bias);
+    // FIX: element type was missing on both vectors below (`std::vector dims;`
+    // does not compile -- CTAD cannot deduce T from an empty initializer).
+    // Permutation order: batch-like axes first, then channel, then reduced axes.
+    std::vector<int64_t> dims;
+    int64_t N = 1;  // product of the batch-like (non-channel, non-reduced) axes
+    for (int64_t i = 0; i < atInput.dim(); i++) {
+        if (i == channel_axis) {
+            continue;
+        }
+        bool isReducedAxis = false;
+        for (int64_t m = 0; m < reduced_axes.len; m++) {
+            if (i == reduced_axes.data[m]) {
+                isReducedAxis = true;
+                break;
+            }
+        }
+        if (!isReducedAxis) {
+            dims.push_back(i);
+            N *= atInput.size(i);
+        }
+    }
+    dims.push_back(channel_axis);
+    int64_t HxW = 1;  // product of the reduced (normalized) axes
+    for (int64_t i = 0; i < reduced_axes.len; i++) {
+        dims.push_back(reduced_axes.data[i]);
+        HxW *= atInput.size(reduced_axes.data[i]);
+    }
+    auto C = atInput.size(channel_axis);
+    auto permutedInput = atInput.permute(dims);
+    auto permutedShape = permutedInput.sizes();
+    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();
+
+    // Inverse permutation, used to map the computed grad_input back to the
+    // original axis order.
+    std::vector<int64_t> reverseOrder(dims.size());
+    for (int64_t i = 0; i < atInput.dim(); i++) {
+        reverseOrder[dims[i]] = i;
+    }
+
+    if (grad_weight && grad_bias) {
+        auto atGradInput = impl::aten::buildATen(grad_input).permute(dims).reshape({N, C, HxW, 1});
+        at::native_group_norm_backward_out(atGradInput, atGradWeight, atGradBias, atGradOutput.permute(dims).reshape({N, C, HxW, 1}), reshapedInput,
+                                           atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true, true, true});
+        atGradInput = atGradInput.reshape(permutedShape).permute(reverseOrder);
+        impl::aten::updateATen2Tensor(ctx, atGradInput, grad_input);
+    } else {
+        auto atOuts = at::native_group_norm_backward(atGradOutput.permute(dims).reshape({N, C, HxW, 1}), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C,
+                                                     HxW, num_groups, {true, grad_weight != nullptr, grad_bias != nullptr});
+        impl::aten::updateATen2Tensor(ctx, std::get<0>(atOuts).reshape(permutedShape).permute(reverseOrder), grad_input);
+        // FIX: only write back the gradients that were requested.  This branch
+        // is reached when at least one of grad_weight/grad_bias is null, yet
+        // the original updated both handles unconditionally even though their
+        // output-mask slot was disabled.
+        if (grad_weight != nullptr) {
+            impl::aten::updateATen2Tensor(ctx, std::get<1>(atOuts), grad_weight);
+        }
+        if (grad_bias != nullptr) {
+            impl::aten::updateATen2Tensor(ctx, std::get<2>(atOuts), grad_bias);
+        }
+    }
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                             diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h
index 6fbef871b..f7eb3d2f4 100644
--- a/proto/include/diopi/functions.h
+++ b/proto/include/diopi/functions.h
@@ -3607,6 +3607,13 @@ DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHan
                                         diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                         double eps, diopiSize_t reduced_axes, const int64_t channel_axis);
 
+/**
+ * @brief Compute the backward pass of diopiGroupNormGB().
+ * @param[in] ctx Context environment.
+ * @param[out] grad_input Gradient w.r.t. the input tensor.
+ * @param[out] grad_weight Gradient w.r.t. weight; may be null when weight had no grad.
+ * @param[out] grad_bias Gradient w.r.t. bias; may be null when bias had no grad.
+ */
+DIOPI_API diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
+                                                diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
+                                                diopiConstTensorHandle_t weight, diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd,
+                                                int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis);
 /**
  * @brief Compute the backward pass of diopiGroupNorm().
  * @param[in] ctx Context environment.