diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index ad7ecbc4a1cd2..6f0acfaedbbbb 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -368,6 +368,28 @@ def __impl__(self, other_var): '__rtruediv__', _binary_creator_('__rtruediv__', paddle.tensor.divide, True, None), ), + ( + '__pow__', + _binary_creator_('__pow__', paddle.tensor.pow, False, None), + ), + ( + '__rpow__', + _binary_creator_('__rpow__', paddle.tensor.pow, True, None), + ), + ( + '__floordiv__', + _binary_creator_( + '__floordiv__', paddle.tensor.floor_divide, False, None + ), + ), + ( + '__mod__', + _binary_creator_('__mod__', paddle.tensor.remainder, False, None), + ), + ( + '__matmul__', + _binary_creator_('__matmul__', paddle.tensor.matmul, False, None), + ), ] global _already_patch_opresult diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index acdd3a35b57d0..36e92aeabb50d 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -941,7 +941,7 @@ def floor_divide(x, y, name=None): [2, 0, 2, 2]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.floor_divide(x, y) else: return _elementwise_op(LayerHelper('elementwise_floordiv', **locals())) diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 1a0254b66df52..9a7ab29ea4451 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -16,12 +16,153 @@ import unittest import warnings +import numpy as np + import paddle +from paddle import base paddle.enable_static() +paddle.device.set_device("cpu") + + +def new_program(): + # TODO(gouzil): Optimize program code + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + place = base.CPUPlace() + exe = base.Executor(place) + return ( + main_program, + exe, + paddle.static.program_guard( + main_program=main_program, startup_program=startup_program + ), + ) class TestMathOpPatchesPir(unittest.TestCase): + def test_pow(self): + # Calculate results in dynamic graphs + paddle.disable_static() + x_np = np.random.random([10, 1024]).astype('float32') + y_np = np.random.random([10, 1024]).astype('float32') + res_np_b = x_np**y_np + res_np_c = paddle.pow(paddle.to_tensor(x_np), 2) + # TODO(gouzil): solve paddle.fill_constant problem + # res_np_d = x_np.__pow__(2) + # res_np_e = x_np.__rpow__(2) + paddle.enable_static() + # Calculate results under pir + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='float32' + ) + b = x**y + c = x.pow(2) + # d = x.__pow__(2) + # e = x.__rpow__(2) + # TODO(gouzil): Why not use `paddle.static.default_main_program()`? + # Because different case do not isolate parameters (This is a known problem) + (b_np, c_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c], + ) + np.testing.assert_allclose(res_np_b, b_np, rtol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, rtol=1e-05) + # np.testing.assert_allclose(res_np_d, d_np, rtol=1e-05) + # np.testing.assert_allclose(res_np_e, e_np, rtol=1e-05) + + def test_mod(self): + paddle.disable_static() + x_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + y_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + res_np_b = x_np % y_np + res_np_c = paddle.mod(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__mod__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x % y + c = x.mod(y) + d = x.__mod__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_matmul(self): + paddle.disable_static() + x_np = np.random.uniform(-1, 1, [2, 3]).astype('float32') + y_np = np.random.uniform(-1, 1, [3, 5]).astype('float32') + res_np_b = x_np @ y_np # __matmul__ + res_np_c = paddle.matmul(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__matmul__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name='x', shape=[2, 3], dtype='float32') + y = paddle.static.data(name='y', shape=[3, 5], dtype='float32') + b = x @ y + c = x.matmul(y) + d = x.__matmul__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_floordiv(self): + paddle.disable_static() + x_np = np.full([10, 1024], 10, np.int64) + y_np = np.full([10, 1024], 2, np.int64) + res_np_b = x_np // y_np + res_np_c = paddle.floor_divide( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__floordiv__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x // y + c = x.floor_divide(y) + d = x.__floordiv__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + def test_item(self): with paddle.pir_utils.IrGuard(): x = paddle.static.data(name='x', shape=[3, 2, 1])