From 4932b991012d38de1071ab290e4699d9112196f6 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Mon, 6 Nov 2023 13:28:14 +0000 Subject: [PATCH 1/9] prim gelu op sink --- .../decomp_interface_gen_op_list.py | 9 +++- paddle/fluid/primitive/composite/composite.h | 30 +++++++++++++ python/paddle/decomposition/rules.py | 2 +- test/prim/pir_prim/test_sink_decomp.py | 42 +++++++++++++++++++ 4 files changed, 80 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 2068d0917e299..496e2bbbcb29e 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -19,8 +19,13 @@ # come into effect in generated file pd_op.h # manual decomp interface declare are located in manual_op.h -decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n"] +decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n", "gelu"] # come into effect in generated file op_decomp.cc # manual decomp interface implementation are located in manual_op_decomp.cc -decomp_interface_implementation_gen_op_list = ["mean", "squeeze", "add_n"] +decomp_interface_implementation_gen_op_list = [ + "mean", + "squeeze", + "add_n", + "gelu", +] diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 36e2745b086da..aee0d1be648c8 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -81,6 +81,36 @@ Tensor add_n_decomp(const std::vector& x) { return res; } +// template +// Tensor pow_decomp(const Tensor& x, const bool approximate) { + +// } + +template +Tensor gelu_decomp(const Tensor& x, bool approximate) { + auto org_dtype = x.dtype(); + + auto half = full(phi::vectorize(x.dims()), 0.5, org_dtype); + auto one = full(phi::vectorize(x.dims()), 1, org_dtype); + if (approximate) { + // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) + auto kAlpha = + full(phi::vectorize(x.dims()), M_2_SQRTPI * M_SQRT1_2, org_dtype); + auto GELU_CONSTANT = full(phi::vectorize(x.dims()), 0.044715, org_dtype); + auto tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)); + + auto res = x * half * (one + tanh_out); + return res; + } else { + // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + auto M_SQRT1_2T = full(phi::vectorize(x.dims()), M_SQRT1_2, org_dtype); + auto res = x * full(phi::vectorize(x.dims()), 0.5, org_dtype) * + (one + erf(x * M_SQRT1_2T)); + + return res; + } +} + } // namespace details } // namespace primitive diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index d64cba8d657ba..89280cf83b826 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -37,7 +37,7 @@ def mean(x, axis, keepdim): return res -@register_decomp('pd_op.gelu') +# @register_decomp('pd_op.gelu') def gelu(x, approximate): """define composite rule of op gelu""" M_SQRT1_2 = ( diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py index d1a14987123ee..d29fba2995a83 100644 --- a/test/prim/pir_prim/test_sink_decomp.py +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -17,6 +17,7 @@ import numpy as np import paddle +import paddle.nn.functional as F from paddle.autograd.ir_backward import grad from paddle.base import core from paddle.decomposition import decompose @@ -109,5 +110,46 @@ def test_has_decomp(self): self.assertEqual(core.has_decomp(op), True) +class TestGeluSink(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.prog = None + + def base_net(self, flag=None): + if flag == "forward": + core._set_prim_forward_enabled(True) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + x.stop_gradient = False + sum_out = F.gelu(x) + print(type(sum_out)) + [new_out] = decompose(main_program, [sum_out]) + print(type(new_out)) + gradients = grad(new_out, x) + + exe = paddle.static.Executor() + [fwd, dx] = exe.run( + feed={'x': self.x}, fetch_list=[new_out, gradients] + ) + + whole_ops = [op.name() for op in main_program.global_block().ops] + self.prog = main_program + if flag == "forward": + core._set_prim_forward_enabled(False) + assert 'pd_op.gelu' not in whole_ops + else: + assert 'pd_op.gelu' in whole_ops + return fwd, dx + + def test_relu_forward(self): + res_ref = self.base_net() + res = self.base_net("forward") + for ref, actual in zip(res_ref, res): + np.testing.assert_equal(ref, actual) + + if __name__ == "__main__": unittest.main() From b12762a3c0f65b04b0071731bec9eba5ac2c6694 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Tue, 7 Nov 2023 12:00:57 +0000 Subject: [PATCH 2/9] prim gelu op sink --- paddle/fluid/primitive/composite/composite.h | 11 +++-- test/prim/pir_prim/test_sink_decomp.py | 45 ++++++++++++++++++-- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index aee0d1be648c8..c5d4792ebdfd1 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -97,16 +97,19 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { auto kAlpha = full(phi::vectorize(x.dims()), M_2_SQRTPI * M_SQRT1_2, org_dtype); auto GELU_CONSTANT = full(phi::vectorize(x.dims()), 0.044715, org_dtype); - auto tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)); + auto x_pow3 = + elementwise_pow(x, full(phi::vectorize(x.dims()), 3, org_dtype)); + auto multi_out = multiply(x_pow3, GELU_CONSTANT); + auto tanh_out = tanh(multiply(kAlpha, x + multi_out)); - auto res = x * half * (one + tanh_out); + auto res = multiply(multiply(x, half), one + tanh_out); return res; } else { // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) auto M_SQRT1_2T = full(phi::vectorize(x.dims()), M_SQRT1_2, org_dtype); - auto res = x * full(phi::vectorize(x.dims()), 0.5, org_dtype) * - (one + erf(x * M_SQRT1_2T)); + auto erf_out = add(one, erf(multiply(x, M_SQRT1_2T))); + auto res = multiply(multiply(x, half), erf_out); return res; } } diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py index d29fba2995a83..601ad0ee5f471 100644 --- a/test/prim/pir_prim/test_sink_decomp.py +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -124,10 +124,8 @@ def base_net(self, flag=None): with paddle.static.program_guard(main_program): x = paddle.static.data('x', self.shape_x, dtype='float32') x.stop_gradient = False - sum_out = F.gelu(x) - print(type(sum_out)) + sum_out = F.gelu(x, approximate=True) [new_out] = decompose(main_program, [sum_out]) - print(type(new_out)) gradients = grad(new_out, x) exe = paddle.static.Executor() @@ -151,5 +149,46 @@ def test_relu_forward(self): np.testing.assert_equal(ref, actual) +class TestSumMax: + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.shape_y = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.y = np.random.random(self.shape_y).astype("float32") + self.prog = None + + def base_net(self, flag=None): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + y = paddle.static.data('y', self.shape_y, dtype='float32') + x.stop_gradient = False + y.stop_gradient = False + sum_out = paddle.sum(x, y) # 静态图下 测试输出, 含有axis的算子 max() + [new_out] = decompose(main_program, [sum_out]) + gradients = grad(new_out, (x, y)) + + exe = paddle.static.Executor() + [fwd, dx, dy] = exe.run( + feed={'x': self.x, 'y': self.y}, fetch_list=[new_out, gradients] + ) + + return fwd, dx, dy + + def test_relu_forward(self): + res_ref = self.base_net() + # res = self.base_net("forward") + print(res_ref) + print("-----------------------------------") + # print(res) + # for ref, actual in zip(res_ref, res): + # np.testing.assert_equal(ref, actual) + + if __name__ == "__main__": unittest.main() + + # summax = TestSumMax() + # summax.setUp() + # summax.test_relu_forward() From 242b62cfcd757e265c92bc8a902101086955b485 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Tue, 14 Nov 2023 14:32:31 +0000 Subject: [PATCH 3/9] update code --- paddle/fluid/primitive/composite/composite.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 417c0c82765d5..b90f604aa0d2d 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -121,7 +121,7 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { auto org_dtype = x.dtype(); auto half = full(phi::vectorize(x.dims()), 0.5, org_dtype); - auto one = full(phi::vectorize(x.dims()), 1, org_dtype); + auto one = full(phi::vectorize(x.dims()), 1.0, org_dtype); if (approximate) { // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) auto kAlpha = From 80962d664f2dec47c2a7d16621be51fc3f74c6b4 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Wed, 15 Nov 2023 06:47:49 +0000 Subject: [PATCH 4/9] pir gelu sink c++ --- paddle/fluid/primitive/composite/composite.h | 2 +- python/paddle/decomposition/rules.py | 26 -------------------- test/prim/pir_prim/test_sink_decomp.py | 2 -- 3 files changed, 1 insertion(+), 29 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index b90f604aa0d2d..623288d231dc0 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -129,7 +129,7 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { auto GELU_CONSTANT = full(phi::vectorize(x.dims()), 0.044715, org_dtype); auto x_pow3 = elementwise_pow(x, full(phi::vectorize(x.dims()), 3, org_dtype)); - auto tanh_out = tanh(kAlpha * (x + x * x * x * GELU_CONSTANT)); + auto tanh_out = tanh(kAlpha * (x + x_pow3 * GELU_CONSTANT)); auto res = x * half * (one + tanh_out); return res; diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index 79641af2024ae..85da591991d89 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _pir_ops from .primitives import * # noqa: F403 from .register import register_decomp @@ -37,31 +36,6 @@ def mean(x, axis, keepdim): return res -# @register_decomp('pd_op.gelu') -def gelu(x, approximate): - """define composite rule of op gelu""" - M_SQRT1_2 = ( - 0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc - ) - M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */ - full_shape = x.shape if len(x.shape) == 0 else [1] - one = ones(full_shape, x.dtype) - half = full(full_shape, 0.5, x.dtype) - # Todo(cz): after symbol overload, add and multiply will be replaced by "+" and "*" - if approximate: - # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) - kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype) - GELU_CONSTANT = full(full_shape, 0.044715, x.dtype) - tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) - out = x * half * (one + tanh_out) - return out - else: - # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype))) - out = x * cdf - return out - - @register_decomp('pd_op.sqrt') def sqrt(x): """ diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py index 0b0e44ddcafe1..0e5394e32dc9f 100644 --- a/test/prim/pir_prim/test_sink_decomp.py +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -184,14 +184,12 @@ def base_net(self, approximate=True, flag=None): def test_gelu_forward_true(self): res_ref = self.base_net(approximate=True) res = self.base_net(approximate=True, flag="forward") - print("---------------gelu_true-----------------") for ref, actual in zip(res_ref, res): np.testing.assert_equal(ref, actual) def test_gelu_approximate_false(self): res_ref = self.base_net(approximate=False) res = self.base_net(approximate=False, flag="forward") - print("---------------gelu_false-----------------") for ref, actual in zip(res_ref, res): np.testing.assert_equal(ref, actual) From 8943b48f6fef3fbc65250f987c47ca19aec7f18d Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Wed, 15 Nov 2023 06:49:11 +0000 Subject: [PATCH 5/9] pir gelu sink c++ --- paddle/fluid/primitive/composite/composite.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 623288d231dc0..aaecf1fa7180a 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -111,11 +111,6 @@ Tensor add_n_decomp(const std::vector& x) { return res; } -// template -// Tensor pow_decomp(const Tensor& x, const bool approximate) { - -// } - template Tensor gelu_decomp(const Tensor& x, bool approximate) { auto org_dtype = x.dtype(); From 10ec34ae5db61739b843a57c4b502f769d0ff5c6 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Thu, 16 Nov 2023 13:10:32 +0000 Subject: [PATCH 6/9] process accuracy --- test/legacy_test/test_activation_op.py | 4 +++- test/prim/pir_prim/test_sink_decomp.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index ace22b1388040..6fc52a753c2f8 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -2691,6 +2691,8 @@ def setUp(self): self.public_python_api = paddle.nn.functional.gelu self.init_dtype() self.init_shape() + # Todo: Under float64, only this accuracy is currently supported, for further processing + self.fw_comp_rtol = 1e-7 approximate = False np.random.seed(2048) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) @@ -2713,7 +2715,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_pir=True, check_prim_pir=False) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py index 0e5394e32dc9f..b55ac33c485bd 100644 --- a/test/prim/pir_prim/test_sink_decomp.py +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -185,13 +185,13 @@ def test_gelu_forward_true(self): res_ref = self.base_net(approximate=True) res = self.base_net(approximate=True, flag="forward") for ref, actual in zip(res_ref, res): - np.testing.assert_equal(ref, actual) + np.testing.assert_allclose(ref, actual, rtol=1e-6) def test_gelu_approximate_false(self): res_ref = self.base_net(approximate=False) res = self.base_net(approximate=False, flag="forward") for ref, actual in zip(res_ref, res): - np.testing.assert_equal(ref, actual) + np.testing.assert_allclose(ref, actual, rtol=1e-6) if __name__ == "__main__": From f5bd2c8857346c80f338d6312189af467df57085 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Fri, 17 Nov 2023 03:28:45 +0000 Subject: [PATCH 7/9] adapter windows --- paddle/fluid/primitive/composite/composite.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 31e545b1783a7..13e73d289e9fc 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -186,6 +186,9 @@ std::tuple layer_norm_decomp( template Tensor gelu_decomp(const Tensor& x, bool approximate) { + using PM_2_SQRTPI = 1.12837916709551257390; /* 2/sqrt(pi) */ + using PM_SQRT1_2 = 0.70710678118654752440; /* 1/sqrt(2) */ + auto org_dtype = x.dtype(); auto half = full(phi::vectorize(x.dims()), 0.5, org_dtype); @@ -193,7 +196,7 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { if (approximate) { // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) auto kAlpha = - full(phi::vectorize(x.dims()), M_2_SQRTPI * M_SQRT1_2, org_dtype); + full(phi::vectorize(x.dims()), PM_2_SQRTPI * PM_SQRT1_2, org_dtype); auto GELU_CONSTANT = full(phi::vectorize(x.dims()), 0.044715, org_dtype); auto x_pow3 = elementwise_pow(x, full(phi::vectorize(x.dims()), 3, org_dtype)); From 59e960d2d48274df10a9f99763a732c2f78ae972 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Fri, 17 Nov 2023 05:50:13 +0000 Subject: [PATCH 8/9] adapter windows --- paddle/fluid/primitive/composite/composite.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 13e73d289e9fc..2be5e1c91b71b 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -186,11 +186,10 @@ std::tuple layer_norm_decomp( template Tensor gelu_decomp(const Tensor& x, bool approximate) { - using PM_2_SQRTPI = 1.12837916709551257390; /* 2/sqrt(pi) */ - using PM_SQRT1_2 = 0.70710678118654752440; /* 1/sqrt(2) */ + const double PM_2_SQRTPI = 1.12837916709551257390; /* 2/sqrt(pi) */ + const double PM_SQRT1_2 = 0.70710678118654752440; /* 1/sqrt(2) */ auto org_dtype = x.dtype(); - auto half = full(phi::vectorize(x.dims()), 0.5, org_dtype); auto one = full(phi::vectorize(x.dims()), 1.0, org_dtype); if (approximate) { From 8310c3f4b5edf88db945970a48d46e0fd5b305f7 Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Fri, 17 Nov 2023 06:35:24 +0000 Subject: [PATCH 9/9] adapter windows --- paddle/fluid/primitive/composite/composite.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 2be5e1c91b71b..00e2a6218c260 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -205,7 +205,7 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { return res; } else { // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - auto M_SQRT1_2T = full(phi::vectorize(x.dims()), M_SQRT1_2, org_dtype); + auto M_SQRT1_2T = full(phi::vectorize(x.dims()), PM_SQRT1_2, org_dtype); auto erf_out = one + erf(x * M_SQRT1_2T); auto res = x * half * erf_out;