From 23e658687570777745f44e567dca61f3a23b402e Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Mon, 8 Apr 2024 12:55:31 +0800
Subject: [PATCH] [Prim] Add exp_double_grad, log_double_grad, abs_triple_grad, min_grad (#63245)

* add min_grad composite

* support exp_double_grad, abs_triple_grad, log_double_grad composite OPs

* add infermeta items

* set_prim_all for TestLogDoubleGradCheck

* change 2022 to 2024 in copyright header

* remove 'log_grad' from vjp_interface_black_list
---
 .../generator/eager_gen.py                    |   3 +
 .../op_generator/vjp_interface_black_list.py  |   6 +
 .../composite_backward_api.h                  |  52 ++++
 .../composite_double_backward_api.h           |  49 ++++
 paddle/fluid/primitive/codegen/gen.py         |   3 +
 paddle/phi/api/yaml/backward.yaml             |  27 ++
 paddle/phi/api/yaml/legacy_backward.yaml      |   1 +
 paddle/phi/api/yaml/op_compat.yaml            |   4 +-
 test/legacy_test/gradient_checker.py          |   2 +-
 test/legacy_test/test_activation_nn_grad.py   |   2 +
 .../vjp/eager/test_comp_eager_min_grad.py     |  71 +++++
 test/prim/prim/vjp/test_comp_high_grad.py     | 242 ++++++++++++++++++
 12 files changed, 459 insertions(+), 3 deletions(-)
 create mode 100644 test/prim/prim/vjp/eager/test_comp_eager_min_grad.py

diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 253cca37dcbfac..c272e09a9579fd 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -75,6 +75,9 @@
     "tanh_triple_grad",
     "minimum_double_grad",
     "maximum_double_grad",
+    "abs_triple_grad",
+    "exp_double_grad",
+    "log_double_grad",
 ]
 
 # white ops list whose kernel can automatically do type promotion.
diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
index 4b2bbc3c549993..c0620d4dbdc43e 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
@@ -22,7 +22,13 @@
 # remove this file and support Vjp methods
 # code gen.
 
+# Operators that only have a composite implementation should be added below.
+# For example:
+# * `silu_double_grad` only has a composite implementation, so `silu_grad` is added below.
+# * `log_double_grad` has both composite and kernel implementations, so `log_grad` should not be added below.
 vjp_interface_black_list = [
     'silu_grad',
+    'exp_grad',
+    'abs_double_grad',
 ]
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 169d41d9763e52..0b08706d24a8d1 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1123,6 +1123,58 @@ void max_grad(const Tensor& x,
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void min_grad(const Tensor& x,
+              const Tensor& out,
+              const Tensor& out_grad,
+              const IntArray& axis,
+              bool keepdim,
+              bool reduce_all,
+              Tensor* x_grad) {
+  if (!x_grad) {
+    return;
+  }
+  auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
+  std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
+  int64_t axis_size = axis.size();
+  int64_t x_dim_size = x_dim.size();
+  reduce_all = false;
+  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+    reduce_all = true;
+  } else {
+    reduce_all = false;
+  }
+  auto x_grad_tmp = Tensor();
+  if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+    auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
+    auto out_tmp = out.expand(IntArray(x_dim));
+    auto mask = equal<T>(x, out_tmp);
+    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+  } else {
+    auto axis_ = std::vector<int64_t>();
+    if (reduce_all) {
+      for (int64_t i = 0; i < x_dim_size; i++) {
+        axis_.push_back(i);
+      }
+    } else {
+      axis_ = axis.GetData();
+      for (int64_t i = 0; i < axis_size; i++) {
+        if (axis[i] < 0) {
+          axis_[i] = axis[i] + x_dim_size;
+        }
+      }
+    }
+    auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
+    auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+    auto out_ = reshape<T>(out, out_grad_shape);
+    auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
+    auto out_tmp = out_.expand(IntArray(x_dim));
+    auto mask = equal<T>(x, out_tmp);
+    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+  }
+  set_output<T>(x_grad_tmp, x_grad);
+}
+
 template <typename T>
 void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
index 7e7ccfaf170b30..67feb640c9f7a4 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
@@ -784,5 +784,54 @@ void subtract_double_grad(const Tensor& y,
   }
 }
 
+template <typename T>
+void exp_double_grad(const Tensor& out,
+                     const Tensor& grad_out,
+                     const Tensor& grad_x_grad,
+                     Tensor* out_grad,
+                     Tensor* grad_out_grad) {
+  // dout = dout_old * ddx
+  if (out_grad) {
+    auto out_grad_tmp = grad_out * grad_x_grad;
+    set_output<T>(out_grad_tmp, out_grad);
+  }
+
+  // ddout = out * ddx
+  if (grad_out_grad) {
+    auto grad_out_grad_tmp = out * grad_x_grad;
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
+template <typename T>
+void log_double_grad(const Tensor& x,
+                     const Tensor& grad_out,
+                     const Tensor& grad_x_grad,
+                     Tensor* x_grad,
+                     Tensor* grad_out_grad) {
+  // dx = -dout/x^2 * ddx
+  if (x_grad) {
+    auto x_grad_tmp = -grad_out / (x * x) * grad_x_grad;
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+
+  // ddout = ddx / x
+  if (grad_out_grad) {
+    auto grad_out_grad_tmp = grad_x_grad / x;
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
+template <typename T>
+void abs_triple_grad(const Tensor& x,
+                     const Tensor& grad_out_grad_grad,
+                     Tensor* grad_grad_x_grad) {
+  // dddx = sign(x) * dddout
+  if (grad_grad_x_grad) {
+    auto grad_grad_x_grad_tmp = sign<T>(x) * grad_out_grad_grad;
+    set_output<T>(grad_grad_x_grad_tmp, grad_grad_x_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index e4d0e50e608778..152773f7e95fe9 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -117,6 +117,9 @@
     'relu_grad',
     'sigmoid_grad',
     'silu_grad',
+    'exp_grad',
+    'log_grad',
+    'abs_double_grad',
     'softmax_grad',
     'sqrt_grad',
 ]  # custom vjp list of composite op
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 603b65c8b4c53a..49f6874982b139 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -13,6 +13,7 @@
   kernel :
     func : abs_double_grad
     data_type : grad_x_grad
+  backward : abs_triple_grad
 
 - backward_op : abs_grad
   forward : abs (Tensor x) -> Tensor(out)
@@ -27,6 +28,17 @@
   composite : abs_grad(x, out_grad, x_grad)
   backward : abs_double_grad
 
+- backward_op : abs_triple_grad
+  forward : abs_double_grad (Tensor x, Tensor grad_x_grad) -> Tensor(grad_out_grad)
+  args : (Tensor x, Tensor grad_out_grad_grad)
+  output : Tensor(grad_x_grad_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  data_transform :
+    support_trans_dtype : x
+  composite : abs_triple_grad(x, grad_out_grad_grad, grad_x_grad_grad)
+
 - backward_op : acos_grad
   forward : acos (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
@@ -717,6 +729,16 @@
   kernel :
     func : erfinv_grad
 
+- backward_op : exp_double_grad
+  forward : exp_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor out, Tensor grad_out, Tensor grad_x_grad)
+  output : Tensor(out_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [out, out]
+  composite : exp_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
+  inplace : (grad_x_grad -> grad_out_grad)
+
 - backward_op : exp_grad
   forward : exp (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -728,6 +750,7 @@
   kernel :
     func : exp_grad
   inplace : (out_grad -> x_grad)
+  backward : exp_double_grad
   composite : exp_grad(out, out_grad, x_grad)
 
 - backward_op : expand_as_grad
@@ -1434,6 +1457,7 @@
     param : [x, x]
   kernel :
     func : log_double_grad
+  composite : log_double_grad(x, grad_out, grad_x_grad, x_grad, grad_out_grad)
   inplace : (grad_x_grad -> grad_out_grad)
 
 - backward_op : log_grad
@@ -2734,6 +2758,9 @@
   forward: silu_grad (Tensor x, Tensor out, Tensor grad_out) -> Tensor(grad_x)
   args: (Tensor x, Tensor out, Tensor grad_out, Tensor grad_x_grad)
   output: Tensor(x_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, x]
   composite: silu_double_grad(x, out, grad_out, grad_x_grad, x_grad, grad_out_grad)
 
 - backward_op: unpool3d_grad
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 8478e3caec98c8..b24b3a20c37eb0 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -411,6 +411,7 @@
     param: [x]
   kernel :
     func : min_grad
+  composite : min_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
 
 - backward_op : minimum_grad
   forward : minimum(Tensor x, Tensor y) -> Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index f1db7cb97191b2..ef6d69c734d156 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -15,7 +15,7 @@
 #   attrs : [bool is_test = false]
 
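Editorial aside, not part of the patch: the closed-form rules registered above for `exp_double_grad` and `log_double_grad` can be sanity-checked against Paddle's autograd from Python. The sketch below is illustrative only; the helper name `second_grad`, the sample shapes, and the tolerance are made up for this example, and only the public `paddle.grad`, `paddle.exp`, and `paddle.log` APIs are assumed.

import numpy as np
import paddle

def second_grad(fn, x_np, v_np, w_np):
    # Contract d fn(x)/dx with v, differentiate again, and contract with w,
    # using two nested calls to paddle.grad.
    x = paddle.to_tensor(x_np, stop_gradient=False)
    v = paddle.to_tensor(v_np)
    w = paddle.to_tensor(w_np)
    out = fn(x)
    (dx,) = paddle.grad([out], [x], grad_outputs=[v], create_graph=True)
    (ddx,) = paddle.grad([dx], [x], grad_outputs=[w])
    return ddx.numpy()

x = np.random.uniform(0.5, 2.0, [3, 4])
v = np.random.rand(3, 4)
w = np.random.rand(3, 4)

# exp: d/dx (v * exp(x)) = v * exp(x), so the second grad is v * w * exp(x).
np.testing.assert_allclose(second_grad(paddle.exp, x, v, w), v * w * np.exp(x), rtol=1e-8)
# log: d/dx (v / x) = -v / x**2, so the second grad is -v * w / x**2.
np.testing.assert_allclose(second_grad(paddle.log, x, v, w), -v * w / x**2, rtol=1e-8)

These identities match the x_grad and grad_out_grad expressions in composite_double_backward_api.h above.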
- op : abs - backward : abs_grad + backward : abs_grad, abs_double_grad, abs_triple_grad inputs : x : X outputs : @@ -1031,7 +1031,7 @@ out : Out - op : exp - backward : exp_grad + backward : exp_grad, exp_double_grad inputs : x : X outputs : diff --git a/test/legacy_test/gradient_checker.py b/test/legacy_test/gradient_checker.py index 00a561bcaa9603..210db283b979a2 100644 --- a/test/legacy_test/gradient_checker.py +++ b/test/legacy_test/gradient_checker.py @@ -446,7 +446,7 @@ def fail_test(msg): n = numerical[x_idx][y_idx] if not np.allclose(a, n, rtol, atol): msg = ( - f'Jacobian mismatch for output {y_idx} in y' + f'Jacobian mismatch for output {y_idx} in y ' f'with respect to input {x_idx} in x on {str(place)},\n' f'numerical:{n}\nanalytical:{a}\n' ) diff --git a/test/legacy_test/test_activation_nn_grad.py b/test/legacy_test/test_activation_nn_grad.py index 7bdcc6fcf30346..56daaac30a3c7f 100644 --- a/test/legacy_test/test_activation_nn_grad.py +++ b/test/legacy_test/test_activation_nn_grad.py @@ -467,12 +467,14 @@ def func(self, place): x_arr = np.random.uniform(0.1, 1, shape).astype(dtype) + core._set_prim_all_enabled(True) gradient_checker.double_grad_check( [x], y, x_init=x_arr, place=place, eps=eps ) gradient_checker.double_grad_check_for_dygraph( self.log_wrapper, [x], y, x_init=x_arr, place=place ) + core._set_prim_all_enabled(False) def test_grad(self): paddle.enable_static() diff --git a/test/prim/prim/vjp/eager/test_comp_eager_min_grad.py b/test/prim/prim/vjp/eager/test_comp_eager_min_grad.py new file mode 100644 index 00000000000000..48408b2d33b6c8 --- /dev/null +++ b/test/prim/prim/vjp/eager/test_comp_eager_min_grad.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
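Editorial aside, not part of the patch: the composite `min_grad` added in composite_backward_api.h routes the incoming gradient to every position equal to the reduced minimum and writes zero elsewhere, which is the behavior the test file below exercises. A minimal NumPy sketch of that rule for the single-axis, keepdim=False case (function name and sample values are illustrative only):

import numpy as np

def min_grad_reference(x, out_grad, axis):
    # Broadcast the reduced result and its gradient back to x's shape,
    # then mask by equality with the minimum (mirrors equal + where above).
    out = x.min(axis=axis, keepdims=True)
    grad = np.expand_dims(out_grad, axis)
    return np.where(x == out, np.broadcast_to(grad, x.shape), 0.0)

x = np.array([[3.0, 1.0, 2.0],
              [0.5, 4.0, 0.5]])
g = np.ones(2)  # cotangent of min(x, axis=1)
print(min_grad_reference(x, g, axis=1))
# [[0. 1. 0.]
#  [1. 0. 1.]]   every tied minimum receives the full incoming gradient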
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.base import core + +core.set_prim_eager_enabled(True) + + +@param.parameterized_class( + ('primal', 'axis', 'cotangent', 'dtype'), + [ + (np.random.rand(16, 32), [1], np.random.rand(16, 32), np.float32), + (np.random.rand(16, 32), [0], np.random.rand(16, 32), np.float32), + ], +) +class TestMinGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal = cls.primal.astype(cls.dtype) + + def test_min_grad_comp(self): + def actual(primal0, axis): + core.set_prim_eager_enabled(True) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + x.stop_gradient = False + out = paddle.min(x, axis) + res = paddle.grad(out, [x], create_graph=False) + return res[0].numpy() + + def desired(primal0, axis): + core.set_prim_eager_enabled(False) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + x.stop_gradient = False + out = paddle.min(x, axis) + res = paddle.grad(out, [x], create_graph=False) + return res[0].numpy() + + dx = actual(self.primal, self.axis) + + ddx = desired(self.primal, self.axis) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + core.set_prim_eager_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index f1f2d02887a369..29c907296addd2 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -411,6 +411,248 @@ def test_high_grad(self): self.func_triple(p) +@param.parameterized_class( + ('shape1'), + [ + ([2],), + ([2, 3],), + ([2, 3, 4],), + ([2, 3, 3, 4],), + ], +) +class TestExpHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + + def exp_wrapper(self, x): + return paddle.exp(x[0]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.exp(x) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + + # exp double grad only has CompositeOpMaker, don't need set prim_flag + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.exp_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.exp(x) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.exp_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for p in places: + self.func_double(p) + 
self.func_triple(p) + + +@param.parameterized_class( + ('shape1'), + [ + ([2],), + ([2, 3],), + ([2, 3, 4],), + ([2, 3, 3, 4],), + ], +) +class TestLogHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + + def log_wrapper(self, x): + return paddle.log(x[0]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.log(x) + x_arr = np.random.uniform(0.0, 10.0, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + + # log double grad only has CompositeOpMaker,don't need set prim_flag + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.log_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.log(x) + x_arr = np.random.uniform(0.0, 10.0, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.log_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + +@param.parameterized_class( + ('shape1'), + [ + ([2],), + ([2, 3],), + ([2, 3, 4],), + ([2, 3, 3, 4],), + ], +) +class TestAbsHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + + def abs_wrapper(self, x): + return paddle.abs(x[0]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.abs(x) + x_arr = np.random.uniform(0.0, 10.0, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.abs_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + x.stop_gradient = False + x.persistable = True + out = paddle.abs(x) + x_arr = np.random.uniform(0.0, 10.0, shape1).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.base import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x], y=out, x_init=[x_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.abs_wrapper, + [x], + y=out, + x_init=[x_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def 
test_high_grad(self): + paddle.enable_static() + places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + @param.parameterized_class( ('shape1', 'shape2'), [