diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index af490654b91b4..2a8b43fc09ab5 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -25,6 +25,7 @@
     "relu",
     "softmax",
     "layer_norm",
+    "gelu",
 ]
 
 # come into effect in generated file op_decomp.cc
@@ -36,6 +37,7 @@
     "relu",
     "softmax",
     "layer_norm",
+    "gelu",
 ]
diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index 75b43ccafc5fc..9a352d74d4d3f 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -188,6 +188,35 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
   return std::make_tuple(out, mean_, variance);
 }
 
+template <typename T>
+Tensor gelu_decomp(const Tensor& x, bool approximate) {
+  const double PM_2_SQRTPI = 1.12837916709551257390; /* 2/sqrt(pi) */
+  const double PM_SQRT1_2 = 0.70710678118654752440;  /* 1/sqrt(2) */
+
+  auto org_dtype = x.dtype();
+  auto half = full<T>(phi::vectorize(x.dims()), 0.5, org_dtype);
+  auto one = full<T>(phi::vectorize(x.dims()), 1.0, org_dtype);
+  if (approximate) {
+    // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
+    auto kAlpha =
+        full<T>(phi::vectorize(x.dims()), PM_2_SQRTPI * PM_SQRT1_2, org_dtype);
+    auto GELU_CONSTANT = full<T>(phi::vectorize(x.dims()), 0.044715, org_dtype);
+    auto x_pow3 =
+        elementwise_pow<T>(x, full<T>(phi::vectorize(x.dims()), 3, org_dtype));
+    auto tanh_out = tanh<T>(kAlpha * (x + x_pow3 * GELU_CONSTANT));
+
+    auto res = x * half * (one + tanh_out);
+    return res;
+  } else {
+    // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+    auto M_SQRT1_2T = full<T>(phi::vectorize(x.dims()), PM_SQRT1_2, org_dtype);
+    auto erf_out = one + erf<T>(x * M_SQRT1_2T);
+
+    auto res = x * half * erf_out;
+    return res;
+  }
+}
+
 }  // namespace details
 }  // namespace primitive
 }  // namespace paddle
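Note: the two branches of gelu_decomp are the standard erf and tanh formulations of GELU, and kAlpha = PM_2_SQRTPI * PM_SQRT1_2 reduces to sqrt(2/pi). A minimal NumPy sketch (illustrative only, not part of the patch) cross-checking the two forms:

    import math

    import numpy as np

    def gelu_erf(x):
        # Exact form: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + np.vectorize(math.erf)(x * math.sqrt(0.5)))

    def gelu_tanh(x):
        # Tanh form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
        k_alpha = math.sqrt(2.0 / math.pi)  # == PM_2_SQRTPI * PM_SQRT1_2
        return 0.5 * x * (1.0 + np.tanh(k_alpha * (x + 0.044715 * x**3)))

    x = np.random.uniform(-1, 1, (8, 16))
    # The tanh branch is an approximation: on [-1, 1] the two forms agree
    # only to roughly the 1e-4 level, so exact equality is not expected.
    np.testing.assert_allclose(gelu_erf(x), gelu_tanh(x), atol=1e-3)
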
diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py
index 12905d3101cac..bd8a58fc680a3 100644
--- a/python/paddle/decomposition/rules.py
+++ b/python/paddle/decomposition/rules.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _pir_ops
 from .primitives import *  # noqa: F403
 from .register import register_decomp
 
@@ -37,31 +36,6 @@ def mean(x, axis, keepdim):
     return res
 
 
-@register_decomp('pd_op.gelu')
-def gelu(x, approximate):
-    """define composite rule of op gelu"""
-    M_SQRT1_2 = (
-        0.70710678118654752440  # /* 1/sqrt(2) */ copy from gelu-kernel.cc
-    )
-    M_2_SQRTPI = 1.12837916709551257390  # /* 2/sqrt(pi) */
-    full_shape = x.shape if len(x.shape) == 0 else [1]
-    one = ones(full_shape, x.dtype)
-    half = full(full_shape, 0.5, x.dtype)
-    # Todo(cz): after symbol overload, add and multiply will be replaced by "+" and "*"
-    if approximate:
-        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
-        kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
-        GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
-        tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
-        out = x * half * (one + tanh_out)
-        return out
-    else:
-        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
-        cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
-        out = x * cdf
-        return out
-
-
 @register_decomp('pd_op.sqrt')
 def sqrt(x):
     """
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index d4d120dc2696e..40a11eec11ae7 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -2691,6 +2691,8 @@ def setUp(self):
         self.public_python_api = paddle.nn.functional.gelu
         self.init_dtype()
         self.init_shape()
+        # TODO: under float64 the composite result only matches the kernel to this tolerance; revisit.
+        self.fw_comp_rtol = 1e-7
         approximate = False
         np.random.seed(2048)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -2713,7 +2715,7 @@ def if_enable_cinn(self):
         pass
 
     def test_check_output(self):
-        self.check_output(check_prim=True, check_pir=True, check_prim_pir=False)
+        self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)
 
     def test_check_grad(self):
         if self.dtype == np.float16:
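On the relaxed tolerance above: under float64 the composite (decomposed) forward is only expected to match the gelu kernel to about 1e-7 in the prim test harness, hence fw_comp_rtol = 1e-7. A rough eager-mode sketch of that comparison, using only public paddle APIs (illustrative, not part of the patch):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.uniform(-1, 1, (8, 16)), dtype="float64")

    # Composite formula, mirroring the erf branch of gelu_decomp.
    composite = 0.5 * x * (1.0 + paddle.erf(x * (0.5 ** 0.5)))
    # Library kernel being decomposed.
    kernel = paddle.nn.functional.gelu(x, approximate=False)

    np.testing.assert_allclose(composite.numpy(), kernel.numpy(), rtol=1e-7)
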
diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py
index e9154eba60976..b55ac33c485bd 100644
--- a/test/prim/pir_prim/test_sink_decomp.py
+++ b/test/prim/pir_prim/test_sink_decomp.py
@@ -149,5 +149,50 @@ def test_relu_forward(self):
         np.testing.assert_equal(ref, actual)
 
 
+class TestGeluSink(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [8, 16, 32, 64]
+        self.x = np.random.random(self.shape_x).astype("float32")
+        self.prog = None
+
+    def base_net(self, approximate=True, flag=None):
+        if flag == "forward":
+            core._set_prim_forward_enabled(True)
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            x = paddle.static.data('x', self.shape_x, dtype='float32')
+            x.stop_gradient = False
+            gelu_out = F.gelu(x, approximate=approximate)
+            [new_out] = decompose(main_program, [gelu_out])
+            gradients = grad(new_out, x)
+
+            exe = paddle.static.Executor()
+            [fwd, dx] = exe.run(
+                feed={'x': self.x}, fetch_list=[new_out, gradients]
+            )
+
+        whole_ops = [op.name() for op in main_program.global_block().ops]
+        self.prog = main_program
+        if flag == "forward":
+            core._set_prim_forward_enabled(False)
+            assert 'pd_op.gelu' not in whole_ops
+        else:
+            assert 'pd_op.gelu' in whole_ops
+        return fwd, dx
+
+    def test_gelu_approximate_true(self):
+        res_ref = self.base_net(approximate=True)
+        res = self.base_net(approximate=True, flag="forward")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+    def test_gelu_approximate_false(self):
+        res_ref = self.base_net(approximate=False)
+        res = self.base_net(approximate=False, flag="forward")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()
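TestGeluSink also fetches gradients, which after decompose() are built from the decomposed forward rather than from pd_op.gelu's own backward. A small finite-difference sanity check of the erf-branch gradient (illustrative only; the analytic form below is the textbook gelu derivative, not code from this patch):

    import math

    import numpy as np

    def gelu_erf(x):
        return 0.5 * x * (1.0 + np.vectorize(math.erf)(x * math.sqrt(0.5)))

    def gelu_erf_grad(x):
        # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where Phi/phi are the
        # standard normal CDF/PDF.
        cdf = 0.5 * (1.0 + np.vectorize(math.erf)(x * math.sqrt(0.5)))
        pdf = np.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
        return cdf + x * pdf

    x = np.random.uniform(-1, 1, 64)
    eps = 1e-6
    # Central differences should agree closely with the analytic gradient.
    numeric = (gelu_erf(x + eps) - gelu_erf(x - eps)) / (2.0 * eps)
    np.testing.assert_allclose(gelu_erf_grad(x), numeric, rtol=1e-5, atol=1e-8)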