From f124c86fdddee953ad6300083fba4771da8df56b Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Wed, 15 Mar 2023 17:58:30 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Custom=20softmax=20grad?=
 =?UTF-8?q?=20(#51474)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
* add unittest
* fix typo
* fix typo
* fix map.at
* fix find
* fix test
* fix cinn cache key structure realize
* using ordered map for attributes
* add test by review advice
---------
Co-authored-by: jiangcheng
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* Cxx prim custom vjp (#8)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
---------
Co-authored-by: jiangcheng
* [prim] enable dygraph_to_static to support custom_vjp
* Pr 50885 (#7)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes
---------
Co-authored-by: jiangcheng
* [prim] enable dygraph_to_static to support custom_vjp
* fix code in a dy2static-friendly way.
* [dystatic] add hooker for prim
---------
Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly
* [prim] enable dygraph_to_static to support custom_vjp
* fix cast prim and vjp dtype mapping error bug
* [dy2static-ci] fix dy2static ci errors.
---------
Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly
* [Prim] enable whitelist and blacklist for custom_vjp
* support softmax grad
* remove additional code
* add test back
---------
Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly
Co-authored-by: xiongkun <807377414@qq.com>
---
 paddle/fluid/operators/softmax_op.cc          |  21 ++
 .../composite_backward_api.h                  |  29 +++
 paddle/phi/api/yaml/legacy_backward.yaml      |   1 +
 .../test_composite_softmax_custom_vjp.py      | 200 ++++++++++++++++++
 4 files changed, 251 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_custom_vjp.py

diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 99383363e65eb..ca523f084c2d9 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -19,6 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -156,6 +159,23 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class SoftmaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor out = this->GetSingleForwardOutput("Out");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor dx = this->GetSingleInputGrad("X");
+    auto* dx_ptr = this->GetOutputPtr(&dx);
+    std::string dx_name = this->GetOutputName(dx);
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    VLOG(6) << "Running softmax_grad composite func";
+    prim::softmax_grad<prim::DescTensor>(out, out_grad, axis, dx_ptr);
+    this->RecoverOutputName(dx, dx_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
 
 }  // namespace operators
@@ -172,6 +192,7 @@ REGISTER_OPERATOR(softmax,
                   ops::SoftmaxOpInferVarType,
                   ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
                   ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
+                  ops::SoftmaxCompositeGradOpMaker,
                   ops::SoftmaxInplaceInferer,
                   SoftmaxInferShapeFunctor);
 DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index bcd6f459b8dc3..7afd190069d53 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -30,6 +30,35 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
 // This function should have as same signature as phi, which defined in
 // paddle/phi/api/backward/backward_api.h
+template <typename T>
+void softmax_grad(const Tensor& out,
+                  const Tensor& out_grad,
+                  int axis,
+                  Tensor* x_grad) {
+  if (x_grad) {
+    if (out_grad.dims().size() > 0) {
+      if (axis >= 0) {
+        auto new_out_grad = out_grad * out;
+        auto tmp_x_grad = new_out_grad -
+                          out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
+        set_output<T>(tmp_x_grad, x_grad);
+      } else {
+        auto new_out_grad = out_grad * out;
+        auto tmp_x_grad =
+            new_out_grad - out * sum<T>(new_out_grad,
+                                        {out.dims().size() + axis},
+                                        out.dtype(),
+                                        true);
+        set_output<T>(tmp_x_grad, x_grad);
+      }
+    } else {
+      set_output<T>(
+          full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
+          x_grad);
+    }
+  }
+}
+
 template <typename T>
 void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
   if (x_grad) {
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 53cf1945f2f1e..8eb3095933a45 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1144,6 +1144,7 @@
     param : [out]
   kernel :
     func : softmax_grad
+  composite : softmax_grad(out, out_grad, axis, x_grad)
 
 - backward_op : spectral_norm_grad
   forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_custom_vjp.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_custom_vjp.py
new file mode 100644
index 0000000000000..9a7c1c3619e6b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_custom_vjp.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from utils import TOLERANCE
+
+import paddle
+import paddle.nn.functional as F
+from paddle.fluid import core
+
+
+def generate_data(shape, dtype="float32"):
+    np_data = np.random.random(shape).astype(dtype)
+    return np_data
+
+
+class Attr:
+    def __init__(self) -> None:
+        self.dtype = None
+        self.axis = -1
+        self.shape = None
+
+    def set_dtype(self, dtype) -> None:
+        self.dtype = dtype
+        return
+
+    def set_axis(self, axis) -> None:
+        self.axis = axis
+        return
+
+    def set_shape(self, shape) -> None:
+        self.shape = shape
+        return
+
+    def get_rtol(self, flag):
+        rtol = TOLERANCE[self.dtype][flag].get("rtol")
+        return rtol
+
+    def get_atol(self, flag):
+        atol = TOLERANCE[self.dtype][flag].get("atol")
+        return atol
+
+
+attrs = Attr()
+
+
+def fn(x):
+    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)
+
+
+def expect_grad(inputs):
+    paddle.disable_static()
+    inputs.stop_gradient = False
+    res = fn(inputs)
+
+    gradients = paddle.grad(res, inputs)
+    return gradients
+
+
+class TestCompositeSoftmax(unittest.TestCase):
+    def setUp(self):
+        self.dtypes = ["float32", "float64"]
+        self.shapes = [[2, 3, 4], [2, 3]]
+        self.axes = [-1, 0, 1]
+
+    def cal_composite_grad(self, inputs):
+        paddle.enable_static()
+        core._set_prim_forward_enabled(True)
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                'x', shape=inputs.shape, dtype=str(inputs.dtype)
+            )
+            x.stop_gradient = False
+            y = fn(x)
+            blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is in the original block
+            self.assertTrue('softmax' in fwd_ops)
+
+            paddle.incubate.autograd.primapi.to_prim(blocks)
+
+            fwd_ops_new = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is split into smaller ops
+            self.assertTrue('softmax' not in fwd_ops_new)
+
+            z = paddle.static.gradients([y], x)
+            fwd_ops_grad = [op.type for op in blocks[0].ops]
+            # Ensure that softmax_grad is not in the grad block
+
+            self.assertTrue('softmax_grad' not in fwd_ops_grad)
+
+        exe = paddle.static.Executor()
+        exe.run(startup_program)
+        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
+        paddle.disable_static()
+        core._set_prim_forward_enabled(False)
+        return res
+
+    def compare_backward(self):
+        np_data = generate_data(attrs.shape)
+        tensor_data = paddle.to_tensor(np_data)
+
+        expect = expect_grad(tensor_data)[0].numpy()
+        actual = self.cal_composite_grad(np_data)[0]
+
+        assert expect.dtype == actual.dtype
+        np.testing.assert_allclose(
+            expect,
+            actual,
rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + def test_backward(self): + for i in self.axes: + for j in self.dtypes: + for t in self.shapes: + attrs.set_axis(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + +class TestCompositeSoftmaxPrimBackward(unittest.TestCase): + "test composite softmax and prim backward" + + def setUp(self): + core._set_prim_backward_enabled(True) + self.dtypes = ["float32", "float64"] + self.shapes = [[], [2, 3, 4], [2, 3]] + self.axes = [-1, 0, 1] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + core._set_prim_all_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + z = paddle.static.gradients([y], x) + paddle.incubate.autograd.primapi.to_prim(blocks) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res + + def compare_backward(self): + if not attrs.shape and attrs.axis not in [-1, 0]: + # op softmax does not support both case + return + np_data = generate_data(attrs.shape) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_rtol("prim_backward"), + ) + + def test_prim_backward(self): + for i in self.axes: + for j in self.dtypes: + for t in self.shapes: + attrs.set_axis(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + +if __name__ == '__main__': + unittest.main()