Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【prim】New layer_norm grad #51750

Merged
merged 68 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
ba39d24
Add flatten composite rule
xysheng-baidu Feb 20, 2023
a80f705
get the right xshape and pass func test
xysheng-baidu Feb 21, 2023
4943544
add cinn unit test
xysheng-baidu Feb 21, 2023
d9ffe5e
Remove cinn test, wait for it to be added after repair
xysheng-baidu Feb 21, 2023
d3f8af7
add comp test to test_flatten_contiguous_range_op.py
xysheng-baidu Feb 22, 2023
83deca7
Merge branch 'PaddlePaddle:develop' into composite_rule_flatten
xysheng-baidu Feb 22, 2023
4e43a73
remove func test on composite_ops
xysheng-baidu Feb 22, 2023
3c906bb
Add comments to maybe_wrap_dim func
xysheng-baidu Feb 22, 2023
c569f59
remove commented code
xysheng-baidu Feb 22, 2023
48547ab
fix the problem with 0D tensor case
xysheng-baidu Feb 22, 2023
d384698
add flatten split rule comment
xysheng-baidu Feb 22, 2023
c11d340
fix conflicts
xysheng-baidu Feb 22, 2023
28af93d
Merge branch 'PaddlePaddle:develop' into composite_rule_flatten
xysheng-baidu Feb 23, 2023
e09e5f1
fix syntax issues
xysheng-baidu Feb 23, 2023
bb4c836
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xysheng-baidu Feb 24, 2023
70d7453
block flatten on resnet_prim_cinn
xysheng-baidu Feb 24, 2023
593d1a2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Feb 27, 2023
55e9f2a
init change
xiaoguoguo626807 Feb 28, 2023
9697474
fix_conflict
xiaoguoguo626807 Feb 28, 2023
3e4e1cf
tmp commit
xiaoguoguo626807 Mar 2, 2023
e4aaa86
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Mar 2, 2023
1b4f8a1
add layer_norm InferMeta check
xiaoguoguo626807 Mar 2, 2023
3d48002
cast type modify
xiaoguoguo626807 Mar 2, 2023
81fbf68
[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
Aurelius84 Feb 24, 2023
6f71cd9
[prim] enable dygraph_to_static to support custom_vjp
cxxly Feb 24, 2023
99e7bd8
Pr 50885 (#7)
2742195759 Feb 28, 2023
28f8d74
[prim] enable dygraph_to_static to support custom_vjp
cxxly Feb 24, 2023
342abb2
fix cast prim and vjp dtype mapping error bug
cxxly Mar 2, 2023
26fc165
recover
xiaoguoguo626807 Mar 2, 2023
3ad5919
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Mar 2, 2023
ef1a424
fix_conflict
xiaoguoguo626807 Mar 2, 2023
120867e
big tol
xiaoguoguo626807 Mar 3, 2023
42c64de
[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
Aurelius84 Feb 24, 2023
baa82d4
[prim] enable dygraph_to_static to support custom_vjp
cxxly Feb 24, 2023
9e584bd
Pr 50885 (#7)
2742195759 Feb 28, 2023
8d50dd9
[prim] enable dygraph_to_static to support custom_vjp
cxxly Feb 24, 2023
80c4ee9
fix cast prim and vjp dtype mapping error bug
cxxly Mar 2, 2023
d44eb19
Cxx prim custom vjp (#8)
2742195759 Mar 2, 2023
74fd37a
[Prim] enable whitelist and blacklist for custom_vjp
cxxly Mar 5, 2023
f037d4b
fix_conflict
xiaoguoguo626807 Mar 6, 2023
907647d
fix_conflict
xiaoguoguo626807 Mar 6, 2023
ccaa51f
fix_conflict with 50885
xiaoguoguo626807 Mar 6, 2023
97066c3
debug log
xiaoguoguo626807 Mar 7, 2023
a5c60a4
clear log
xiaoguoguo626807 Mar 7, 2023
1e79731
Merge branch 'develop' into layer_norm_grad
xiaoguoguo626807 Mar 8, 2023
3ca656c
fix_conflict
xiaoguoguo626807 Mar 8, 2023
ccce35a
fix
xiaoguoguo626807 Mar 8, 2023
f6e017c
Merge branch 'layer_norm_grad' of https://github.com/xiaoguoguo626807…
xiaoguoguo626807 Mar 8, 2023
ade18e4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Mar 9, 2023
6f5e584
nothing
xiaoguoguo626807 Mar 10, 2023
18cd8c0
conflict
xiaoguoguo626807 Mar 10, 2023
3b1f6b4
less memory
xiaoguoguo626807 Mar 10, 2023
abab32b
recover utils
xiaoguoguo626807 Mar 13, 2023
e05dd44
Merge branch 'layer_norm_grad' of https://github.com/xiaoguoguo626807…
xiaoguoguo626807 Mar 13, 2023
2d55f26
conflict
xiaoguoguo626807 Mar 14, 2023
c7166ea
fix
xiaoguoguo626807 Mar 14, 2023
983931d
Merge branch 'develop' into layer_norm_grad
xiaoguoguo626807 Mar 15, 2023
00b7f54
modify threshold value
xiaoguoguo626807 Mar 15, 2023
b2a16b6
Merge branch 'layer_norm_grad' of https://github.com/xiaoguoguo626807…
xiaoguoguo626807 Mar 15, 2023
595c6fd
skip layer_norm for test_bert
xiaoguoguo626807 Mar 16, 2023
7c968ab
Merge commit 'refs/pull/51135/head' of https://github.com/PaddlePaddl…
xiaoguoguo626807 Mar 16, 2023
9c1d5f2
back to bert success state
xiaoguoguo626807 Mar 16, 2023
9065511
add epsion
xiaoguoguo626807 Mar 16, 2023
78ec3dc
delete unnecessary compute
xiaoguoguo626807 Mar 16, 2023
afbf4d2
modify amp dtype
xiaoguoguo626807 Mar 17, 2023
9ceb78d
modify * order
xiaoguoguo626807 Mar 17, 2023
dc72787
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Mar 20, 2023
f43e433
delete sqrt check and fp16
xiaoguoguo626807 Mar 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions paddle/fluid/operators/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ limitations under the License. */
#include <memory>
#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -253,15 +259,78 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
"Bias");

class LayerNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

public:
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor mean = this->GetSingleForwardOutput("Mean");
paddle::Tensor var = this->GetSingleForwardOutput("Variance");
paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
paddle::optional<paddle::Tensor> scale =
this->GetOptionalSingleForwardInput("Scale");
paddle::optional<paddle::Tensor> bias =
this->GetOptionalSingleForwardInput("Bias");

// get Attrs
auto epsilon = this->Attr<float>("epsilon");
auto begin_norm_axis = this->Attr<int>("begin_norm_axis");

// get outputs
paddle::Tensor x_grad = this->GetSingleInputGrad("X");
paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");

auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dscale_ptr = this->GetOutputPtr(&scale_grad);
std::string dscale_name = this->GetOutputName(scale_grad);
auto dbias_ptr = this->GetOutputPtr(&bias_grad);
std::string dbias_name = this->GetOutputName(bias_grad);

VLOG(6) << "Runing layer_norm_grad composite func";
prim::layer_norm_grad<prim::DescTensor>(x,
scale,
bias,
mean,
var,
y_grad,
epsilon,
begin_norm_axis,
dx_ptr,
dscale_ptr,
dbias_ptr);

this->RecoverOutputName(x_grad, dx_name);
this->RecoverOutputName(scale_grad, dscale_name);
this->RecoverOutputName(bias_grad, dbias_name);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(layer_norm,
LayerNormInferShapeFunctor,
PD_INFER_META(phi::LayerNormInferMeta));

REGISTER_OPERATOR(layer_norm,
ops::LayerNormOp,
ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>,
ops::LayerNormCompositeGradOpMaker,
LayerNormInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(layer_norm_grad,
LayerNormGradInferShapeFunctor,
PD_INFER_META(phi::LayerNormGradInferMeta));

REGISTER_OPERATOR(layer_norm_grad,
ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer);
ops::LayerNormGradNoNeedBufferVarInferer,
LayerNormGradInferShapeFunctor);
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
- tile
- transpose
- pad
- sqrt
- cumsum
- put_along_axis
- greater_than
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,101 @@ void slice_grad(const Tensor& input,
}
}

template <typename T>
void layer_norm_grad(const Tensor& x,
const paddle::optional<Tensor>& scale,
const paddle::optional<Tensor>& bias,
const Tensor& mean,
const Tensor& variance,
const Tensor& out_grad,
float epsilon,
int begin_norm_axis,
Tensor* x_grad,
Tensor* scale_grad,
Tensor* bias_grad) {
auto x_dims = x.dims();
auto shape_1 = 1; // front part
auto shape_2 = 1; // back part
for (int i = 0; i < begin_norm_axis; ++i) {
shape_1 *= x_dims[i];
}
for (int i = begin_norm_axis; i < x.dims().size(); ++i) {
shape_2 *= x_dims[i];
}
auto scale_ptr = scale.get_ptr();
auto bias_ptr = bias.get_ptr();

// cast dtype to float32 if dtype =float16
Tensor x_cast = x;
Tensor out_grad_cast = out_grad;
Tensor scale_cast;
if (scale_ptr) {
scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
}
if (x.dtype() == phi::DataType::FLOAT16) {
x_cast = cast<T>(x, phi::DataType::FLOAT32);
out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
if (scale_ptr) {
scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
}
}

x_cast = reshape<T>(x_cast, std::vector<int64_t>({shape_1, shape_2}));
out_grad_cast =
reshape<T>(out_grad_cast, std::vector<int64_t>({shape_1, shape_2}));
auto mean_ = reshape<T>(mean, std::vector<int64_t>({shape_1, 1}));
auto variance_ = reshape<T>(variance, std::vector<int64_t>({shape_1, 1}));
if (bias_grad) {
if (bias_ptr) {
auto bias_grad_tmp =
out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
set_output<T>(bias_grad_tmp, bias_grad);
} else {
bias_grad = nullptr;
}
}
auto x_sub_mean = x_cast - mean_;
auto tmp = (1.0 / (variance_ + epsilon));
auto sqrt_var_1 = sqrt<T>(tmp);
if (scale_grad) {
if (scale_ptr) {
auto scale_grad_tmp =
(x_sub_mean * sqrt_var_1 * out_grad_cast)
.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
set_output<T>(scale_grad_tmp, scale_grad);
} else {
scale_grad = nullptr;
}
}

if (x_grad) {
if (!scale_ptr) {
scale_cast =
full<T>(std::vector<int64_t>({1, shape_2}), 1.0, x_cast.dtype());
}
auto out_grad_scale = out_grad_cast * scale_cast;
auto dx_end = (sqrt_var_1 * out_grad_scale);
auto d_mean_0 =
(-dx_end).sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_mean = (1.0 / shape_2) * d_mean_0;
auto d_std_1 = (-tmp * x_sub_mean * out_grad_scale)
.sum(std::vector<int64_t>({1}), x_cast.dtype(), true);
auto d_std_2 = (1.0 / shape_2) * sqrt_var_1;
d_std_2 = reshape<T>(d_std_2, std::vector<int64_t>({shape_1, 1}));
d_std_2 = d_std_2 * x_sub_mean;
auto d_std = d_std_1 * d_std_2;

auto x_grad_tmp = dx_end + d_mean + d_std;
x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
}
set_output<T>(x_grad_tmp, x_grad);
}
}

template <typename T>
void cumsum_grad(const Tensor& x,
const Tensor& out_grad,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@
kernel :
func : layer_norm_grad
data_type : out_grad
composite : layer_norm_grad(x, scale, bias, mean,varience, out_grad, epsilon, begin_norm_axis, x_grad, scale_grad, bias_grad)
no_need_buffer : bias
optional : scale, bias

Expand Down
11 changes: 10 additions & 1 deletion paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,23 @@ void LayerNormInferMeta(const MetaTensor& x,
right));
}

phi::DataType x_dtype = x.dtype();
out->set_dims(x_dim);
out->set_dtype(x_dtype);
out->share_lod(x);

phi::DataType param_type =
(x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
? phi::DataType::FLOAT32
: x_dtype;
if (mean) {
mean->set_dims({left});
mean->set_dtype(param_type);
}
if (variance) {
variance->set_dims({left});
variance->set_dtype(param_type);
}
out->share_lod(x);
}

void LayerNormGradInferMeta(const MetaTensor& x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ def test_train(self):

def test_train_composite(self):
core._set_prim_backward_enabled(True)
# core._add_skip_comp_ops("layer_norm")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modify in next PR

static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
core._set_prim_backward_enabled(False)
# core._add_skip_comp_ops("layer_norm")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
Expand Down
Loading