Skip to content

Commit

Permalink
【prim】New layer_norm grad (#51750)
Browse files Browse the repository at this point in the history
* Add flatten composite rule

* get the right xshape and pass func test

* add cinn unit test

* Remove cinn test, wait for it to be added after repair

* add comp test to test_flatten_contiguous_range_op.py

* remove func test on composite_ops

* Add comments to maybe_wrap_dim func

* remove commented code

* fix the problem with 0D tensor case

* add flatten split rule comment

* fix syntax issues

* block flatten on resnet_prim_cinn

* init change

* tmp commit

* add layer_norm InferMeta check

* cast type modify

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* recover

* big tol

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [Prim] enable whitelist and blacklist for custom_vjp

* debug log

* clear log

* fix

* nothing

* less memory

* recover utils

* fix

* modify threshold value

* skip layer_norm for test_bert

* back to bert success state

* add epsion

* delete unnecessary compute

* modify amp dtype

* modify * order

* delete sqrt check and fp16

---------

Co-authored-by: xuyongsheng <xuyongsheng@baidu.com>
Co-authored-by: xysheng-baidu <121540080+xysheng-baidu@users.noreply.github.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: xiongkun <807377414@qq.com>
  • Loading branch information
7 people committed Mar 20, 2023
1 parent b81188f commit 802a81d
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 49 deletions.
73 changes: 71 additions & 2 deletions paddle/fluid/operators/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ limitations under the License. */
#include <memory>
#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -253,15 +259,78 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
"Bias");

class LayerNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

public:
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
paddle::Tensor var = this->GetSingleForwardOutput("Variance");
paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
paddle::optional<paddle::Tensor> scale =
this->GetOptionalSingleForwardInput("Scale");
paddle::optional<paddle::Tensor> bias =
this->GetOptionalSingleForwardInput("Bias");

// get Attrs
auto epsilon = this->Attr<float>("epsilon");
auto begin_norm_axis = this->Attr<int>("begin_norm_axis");

// get outputs
paddle::Tensor x_grad = this->GetSingleInputGrad("X");
paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");

auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dscale_ptr = this->GetOutputPtr(&scale_grad);
std::string dscale_name = this->GetOutputName(scale_grad);
auto dbias_ptr = this->GetOutputPtr(&bias_grad);
std::string dbias_name = this->GetOutputName(bias_grad);

VLOG(6) << "Runing layer_norm_grad composite func";
prim::layer_norm_grad<prim::DescTensor>(x,
scale,
bias,
mean,
var,
y_grad,
epsilon,
begin_norm_axis,
dx_ptr,
dscale_ptr,
dbias_ptr);

this->RecoverOutputName(x_grad, dx_name);
this->RecoverOutputName(scale_grad, dscale_name);
this->RecoverOutputName(bias_grad, dbias_name);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(layer_norm,
LayerNormInferShapeFunctor,
PD_INFER_META(phi::LayerNormInferMeta));

REGISTER_OPERATOR(layer_norm,
ops::LayerNormOp,
ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>,
ops::LayerNormCompositeGradOpMaker,
LayerNormInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(layer_norm_grad,
LayerNormGradInferShapeFunctor,
PD_INFER_META(phi::LayerNormGradInferMeta));

REGISTER_OPERATOR(layer_norm_grad,
ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer);
ops::LayerNormGradNoNeedBufferVarInferer,
LayerNormGradInferShapeFunctor);
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- tile
- transpose
- pad
- sqrt
- cumsum
- put_along_axis
- greater_than
Expand Down
95 changes: 95 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,101 @@ void slice_grad(const Tensor& input,
}
}

template <typename T>
void layer_norm_grad(const Tensor& x,
const paddle::optional<Tensor>& scale,
const paddle::optional<Tensor>& bias,
const Tensor& mean,
const Tensor& variance,
const Tensor& out_grad,
float epsilon,
int begin_norm_axis,
Tensor* x_grad,
Tensor* scale_grad,
Tensor* bias_grad) {
auto x_dims = x.dims();
auto shape_1 = 1; // front part
auto shape_2 = 1; // back part
for (int i = 0; i < begin_norm_axis; ++i) {
shape_1 *= x_dims[i];
}
for (int i = begin_norm_axis; i < x.dims().size(); ++i) {
shape_2 *= x_dims[i];
}
auto scale_ptr = scale.get_ptr();
auto bias_ptr = bias.get_ptr();

// cast dtype to float32 if dtype =float16
Tensor x_cast = x;
Tensor out_grad_cast = out_grad;
Tensor scale_cast;
if (scale_ptr) {
scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
}
if (x.dtype() == phi::DataType::FLOAT16) {
x_cast = cast<T>(x, phi::DataType::FLOAT32);
out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
if (scale_ptr) {
scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
}
}

x_cast = reshape<T>(x_cast, std::vector<int64_t>({shape_1, shape_2}));
out_grad_cast =
reshape<T>(out_grad_cast, std::vector<int64_t>({shape_1, shape_2}));
auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
if (bias_grad) {
if (bias_ptr) {
auto bias_grad_tmp =
out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
set_output<T>(bias_grad_tmp, bias_grad);
} else {
bias_grad = nullptr;
}
}
auto x_sub_mean = x_cast - mean_;
auto tmp = (1.0 / (variance_ + epsilon));
auto sqrt_var_1 = sqrt<T>(tmp);
if (scale_grad) {
if (scale_ptr) {
auto scale_grad_tmp =
(x_sub_mean * sqrt_var_1 * out_grad_cast)
.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
set_output<T>(scale_grad_tmp, scale_grad);
} else {
scale_grad = nullptr;
}
}

if (x_grad) {
if (!scale_ptr) {
scale_cast =
full<T>(std::vector<int64_t>({1, shape_2}), 1.0, x_cast.dtype());
}
auto out_grad_scale = out_grad_cast * scale_cast;
auto dx_end = (sqrt_var_1 * out_grad_scale);
auto d_mean_0 =
(-dx_end).sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_mean = (1.0 / shape_2) * d_mean_0;
auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
.sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
d_std_2 = reshape<T>(d_std_2, std::vector<int64_t>({shape_1, 1}));
d_std_2 = d_std_2 * x_sub_mean;
auto d_std = d_std_1 * d_std_2;

auto x_grad_tmp = dx_end + d_mean + d_std;
x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
}
set_output<T>(x_grad_tmp, x_grad);
}
}

template <typename T>
void cumsum_grad(const Tensor& x,
const Tensor& out_grad,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@
kernel :
func : layer_norm_grad
data_type : out_grad
composite : layer_norm_grad(x, scale, bias, mean,varience, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad)
no_need_buffer : bias
optional : scale, bias

Expand Down
11 changes: 10 additions & 1 deletion paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,23 @@ void LayerNormInferMeta(const MetaTensor& x,
right));
}

phi::DataType x_dtype = x.dtype();
out->set_dims(x_dim);
out->set_dtype(x_dtype);
out->share_lod(x);

phi::DataType param_type =
(x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
? phi::DataType::FLOAT32
: x_dtype;
if (mean) {
mean->set_dims({left});
mean->set_dtype(param_type);
}
if (variance) {
variance->set_dims({left});
variance->set_dtype(param_type);
}
out->share_lod(x);
}

void LayerNormGradInferMeta(const MetaTensor& x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ def test_train(self):

def test_train_composite(self):
core._set_prim_backward_enabled(True)
# core._add_skip_comp_ops("layer_norm")
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
core._set_prim_backward_enabled(False)
# core._add_skip_comp_ops("layer_norm")
dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
Expand Down
Loading

0 comments on commit 802a81d

Please sign in to comment.