[Prim] Add exp_double_grad, log_double_grad, abs_triple_grad, min_grad (PaddlePaddle#63245)

* add min_grad composite

* support exp_double_grad, abs_triple_grad, log_double_grad composite ops

* add infermeta items

* enable set_prim_all for TestLogDoubleGradCheck

* change 2022 to 2024 in copyright header

* remove 'log_grad' from vjp_interface_black_list
HydrogenSulfate authored and co63oc committed Apr 10, 2024
1 parent 066ae39 commit 23e6586
Showing 12 changed files with 459 additions and 3 deletions.
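
To put the change in context, here is a minimal eager-mode usage sketch (not part of the commit; it assumes a Paddle build that includes these composite ops). With prim eager mode enabled, the backward of paddle.min is decomposed through the new composite min_grad rule, much like the new test at the end of this diff; exp_double_grad, log_double_grad, and abs_triple_grad play the analogous role for higher-order gradients of exp, log, and abs.

```python
# Illustrative sketch only, assuming a build that includes this change.
import numpy as np
import paddle
from paddle.base import core

core.set_prim_eager_enabled(True)  # route backward passes through composite rules

x = paddle.to_tensor(np.random.rand(4, 8), dtype='float32', stop_gradient=False)
out = paddle.min(x, axis=1)
(dx,) = paddle.grad(out, [x])      # decomposed via the composite min_grad
# dx is nonzero only where x equals its row-wise minimum.

core.set_prim_eager_enabled(False)
```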
3 changes: 3 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -75,6 +75,9 @@
"tanh_triple_grad",
"minimum_double_grad",
"maximum_double_grad",
"abs_triple_grad",
"exp_double_grad",
"log_double_grad",
]

# white ops list whose kernel can automatically do type promotion.
@@ -22,7 +22,13 @@
# remove this file and support Vjp methods
# code gen.

# Operators that only have a composite implementation should be added below.
# For example:
# * `silu_double_grad` only has a composite implementation, so `silu_grad` is added below.
# * `log_double_grad` has both composite and kernel implementations, so `log_grad` should not be added below.

vjp_interface_black_list = [
'silu_grad',
'exp_grad',
'abs_double_grad',
]
52 changes: 52 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1123,6 +1123,58 @@ void max_grad(const Tensor& x,
set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void min_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (!x_grad) {
return;
}
auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
// Recompute reduce_all from axis: reduce over every dimension when axis is
// empty or already covers all of x's dimensions.
reduce_all = axis_size == 0 || axis_size == x_dim_size;
auto x_grad_tmp = Tensor();
if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
auto out_tmp = out.expand(IntArray(x_dim));
// Route the cotangent only to positions where x equals the reduced minimum.
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
// Unsqueeze out_grad/out back to x's rank along the reduced axes so they
// broadcast against x.
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
auto out_ = reshape<T>(out, out_grad_shape);
auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
auto out_tmp = out_.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
}
set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
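As a side note (not part of the diff), the composite min_grad above can be read as the following NumPy recipe: the incoming cotangent is routed only to positions that equal the reduced minimum, and tied minima each receive the full cotangent.

```python
import numpy as np

x = np.array([[3.0, 1.0, 2.0],
              [0.5, 0.5, 4.0]])
out = x.min(axis=1, keepdims=True)   # forward result, kept broadcastable
out_grad = np.ones_like(out)         # cotangent flowing into min()

mask = x == out                                                   # equal<T>(x, out_tmp)
x_grad = np.where(mask, np.broadcast_to(out_grad, x.shape), 0.0)  # where<T>(mask, ...)
print(x_grad)
# [[0. 1. 0.]
#  [1. 1. 0.]]  <- tied minima in the second row both receive the cotangent
```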
@@ -784,5 +784,54 @@ void subtract_double_grad(const Tensor& y,
}
}

template <typename T>
void exp_double_grad(const Tensor& out,
const Tensor& grad_out,
const Tensor& grad_x_grad,
Tensor* out_grad,
Tensor* grad_out_grad) {
// dout = dout_old * ddx
if (out_grad) {
auto out_grad_tmp = grad_out * grad_x_grad;
set_output<T>(out_grad_tmp, out_grad);
}

// ddout = out * ddx
if (grad_out_grad) {
auto grad_out_grad_tmp = out * grad_x_grad;
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
}

template <typename T>
void log_double_grad(const Tensor& x,
const Tensor& grad_out,
const Tensor& grad_x_grad,
Tensor* x_grad,
Tensor* grad_out_grad) {
// dx = -dout/x^2 * ddx
if (x_grad) {
auto x_grad_tmp = -grad_out / (x * x) * grad_x_grad;
set_output<T>(x_grad_tmp, x_grad);
}

// ddout = ddx / x
if (grad_out_grad) {
auto grad_out_grad_tmp = grad_x_grad / x;
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
}

template <typename T>
void abs_triple_grad(const Tensor& x,
const Tensor& grad_out_grad_grad,
Tensor* grad_grad_x_grad) {
// dddx = sign(x) * dddout
if (grad_grad_x_grad) {
auto grad_grad_x_grad_tmp = sign<T>(x) * grad_out_grad_grad;
set_output<T>(grad_grad_x_grad_tmp, grad_grad_x_grad);
}
}

} // namespace prim
} // namespace paddle
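
A quick NumPy sanity check (illustrative only, not from the commit) of the log_double_grad formulas above: differentiating log_grad(x, dout) = dout / x with central finite differences should reproduce the closed forms -dout/x^2 * ddx and ddx / x.

```python
import numpy as np

def log_grad(x, dout):
    return dout / x  # first-order rule for y = log(x)

rng = np.random.default_rng(0)
x = rng.uniform(0.5, 2.0, 8)
dout = rng.uniform(0.5, 2.0, 8)   # grad_out in the C++ code
ddx = rng.uniform(0.5, 2.0, 8)    # grad_x_grad in the C++ code
h = 1e-6

# x_grad = -dout / x^2 * ddx   (VJP of log_grad w.r.t. x)
fd_x = (log_grad(x + h, dout) - log_grad(x - h, dout)) / (2 * h) * ddx
np.testing.assert_allclose(fd_x, -dout / (x * x) * ddx, rtol=1e-6)

# grad_out_grad = ddx / x      (VJP of log_grad w.r.t. dout)
fd_dout = (log_grad(x, dout + h) - log_grad(x, dout - h)) / (2 * h) * ddx
np.testing.assert_allclose(fd_dout, ddx / x, rtol=1e-6)
```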
3 changes: 3 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -117,6 +117,9 @@
'relu_grad',
'sigmoid_grad',
'silu_grad',
'exp_grad',
'log_grad',
'abs_double_grad',
'softmax_grad',
'sqrt_grad',
] # custom vjp list of composite op
27 changes: 27 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -13,6 +13,7 @@
kernel :
func : abs_double_grad
data_type : grad_x_grad
backward : abs_triple_grad

- backward_op : abs_grad
forward : abs (Tensor x) -> Tensor(out)
@@ -27,6 +28,17 @@
composite : abs_grad(x, out_grad, x_grad)
backward : abs_double_grad

- backward_op : abs_triple_grad
forward : abs_double_grad (Tensor x, Tensor grad_x_grad) -> Tensor(grad_out_grad)
args : (Tensor x, Tensor grad_out_grad_grad)
output : Tensor(grad_x_grad_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
data_transform :
support_trans_dtype : x
composite : abs_triple_grad(x, grad_out_grad_grad, grad_x_grad_grad)

- backward_op : acos_grad
forward : acos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
@@ -717,6 +729,16 @@
kernel :
func : erfinv_grad

- backward_op : exp_double_grad
forward : exp_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor grad_out, Tensor grad_x_grad)
output : Tensor(out_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out, out]
composite : exp_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
inplace : (grad_x_grad -> grad_out_grad)

- backward_op : exp_grad
forward : exp (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
@@ -728,6 +750,7 @@
kernel :
func : exp_grad
inplace : (out_grad -> x_grad)
backward : exp_double_grad
composite : exp_grad(out, out_grad, x_grad)

- backward_op : expand_as_grad
@@ -1434,6 +1457,7 @@
param : [x, x]
kernel :
func : log_double_grad
composite : log_double_grad(x, grad_out, grad_x_grad, x_grad, grad_out_grad)
inplace : (grad_x_grad -> grad_out_grad)

- backward_op : log_grad
@@ -2734,6 +2758,9 @@
forward: silu_grad (Tensor x, Tensor out, Tensor grad_out) -> Tensor(grad_x)
args: (Tensor x, Tensor out, Tensor grad_out, Tensor grad_x_grad)
output: Tensor(x_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, x]
composite: silu_double_grad(x, out, grad_out, grad_x_grad, x_grad, grad_out_grad)

- backward_op: unpool3d_grad
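For completeness, a hedged eager-mode sketch (not part of the commit) of what the exp_grad -> exp_double_grad chain registered in this backward.yaml hunk enables: nesting paddle.grad yields the second derivative of exp, which should come back as exp(x) itself. Since the new exp_double_grad entry appears to be composite-only (no kernel line), prim mode is enabled first.

```python
import numpy as np
import paddle
from paddle.base import core

core.set_prim_eager_enabled(True)  # exp_double_grad is decomposed via its composite rule

x = paddle.to_tensor(np.random.uniform(-1.0, 1.0, [8]), dtype='float64',
                     stop_gradient=False)
y = paddle.exp(x)
(dy,) = paddle.grad(y, [x], create_graph=True)  # exp_grad: dy/dx = exp(x)
(d2y,) = paddle.grad(dy, [x])                   # second order, via exp_double_grad
np.testing.assert_allclose(d2y.numpy(), np.exp(x.numpy()), rtol=1e-10)

core.set_prim_eager_enabled(False)
```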
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -411,6 +411,7 @@
param: [x]
kernel :
func : min_grad
composite : min_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)

- backward_op : minimum_grad
forward : minimum(Tensor x, Tensor y) -> Tensor(out)
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -15,7 +15,7 @@
# attrs : [bool is_test = false]

- op : abs
backward : abs_grad
backward : abs_grad, abs_double_grad, abs_triple_grad
inputs :
x : X
outputs :
@@ -1031,7 +1031,7 @@
out : Out

- op : exp
backward : exp_grad
backward : exp_grad, exp_double_grad
inputs :
x : X
outputs :
2 changes: 1 addition & 1 deletion test/legacy_test/gradient_checker.py
@@ -446,7 +446,7 @@ def fail_test(msg):
n = numerical[x_idx][y_idx]
if not np.allclose(a, n, rtol, atol):
msg = (
f'Jacobian mismatch for output {y_idx} in y'
f'Jacobian mismatch for output {y_idx} in y '
f'with respect to input {x_idx} in x on {str(place)},\n'
f'numerical:{n}\nanalytical:{a}\n'
)
2 changes: 2 additions & 0 deletions test/legacy_test/test_activation_nn_grad.py
@@ -467,12 +467,14 @@ def func(self, place):

x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)

core._set_prim_all_enabled(True)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
gradient_checker.double_grad_check_for_dygraph(
self.log_wrapper, [x], y, x_init=x_arr, place=place
)
core._set_prim_all_enabled(False)

def test_grad(self):
paddle.enable_static()
71 changes: 71 additions & 0 deletions test/prim/prim/vjp/eager/test_comp_eager_min_grad.py
@@ -0,0 +1,71 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import parameterized as param

import paddle
from paddle.base import core

core.set_prim_eager_enabled(True)


@param.parameterized_class(
('primal', 'axis', 'cotangent', 'dtype'),
[
(np.random.rand(16, 32), [1], np.random.rand(16, 32), np.float32),
(np.random.rand(16, 32), [0], np.random.rand(16, 32), np.float32),
],
)
class TestMinGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)

def test_min_grad_comp(self):
def actual(primal0, axis):
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.min(x, axis)
res = paddle.grad(out, [x], create_graph=False)
return res[0].numpy()

def desired(primal0, axis):
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.min(x, axis)
res = paddle.grad(out, [x], create_graph=False)
return res[0].numpy()

dx = actual(self.primal, self.axis)

ddx = desired(self.primal, self.axis)

np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
core.set_prim_eager_enabled(False)


if __name__ == '__main__':
unittest.main()