From 3fc15504c2422f0ccda28da9b1ab40815eaf3e17 Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:38:47 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20api=20adaptor=20No.233=E3=80=812?= =?UTF-8?q?34=E3=80=91=20Migrate=20paddle.trunc/frac=20into=20pir=20(#5867?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/tensor/math.py | 9 +++++++-- test/legacy_test/test_frac_api.py | 12 ++++++------ test/legacy_test/test_trunc_op.py | 12 ++++++++---- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9bb9c3e1fae23..5fb1a8bfc4528 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -20,6 +20,7 @@ import paddle from paddle import _C_ops, _legacy_C_ops +from paddle.base.libpaddle import DataType from paddle.common_ops_import import VarDesc, dygraph_utils from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only @@ -2026,7 +2027,7 @@ def trunc(input, name=None): [[ 0., 1.], [-0., -2.]]) ''' - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.trunc(input) else: inputs = {"X": input} @@ -6061,11 +6062,15 @@ def frac(x, name=None): paddle.int64, paddle.float32, paddle.float64, + DataType.INT32, + DataType.INT64, + DataType.FLOAT32, + DataType.FLOAT64, ]: raise TypeError( f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): y = _C_ops.trunc(x) return _C_ops.subtract(x, y) else: diff --git a/test/legacy_test/test_frac_api.py b/test/legacy_test/test_frac_api.py index 26bc74225e54b..1d401066cee2f 100644 --- a/test/legacy_test/test_frac_api.py +++ b/test/legacy_test/test_frac_api.py @@ -18,7 +18,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def ref_frac(x): @@ -40,15 +41,13 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_api_static(self): paddle.enable_static() - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): input = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) out = paddle.frac(input) - place = base.CPUPlace() - if base.core.is_compiled_with_cuda(): - place = base.CUDAPlace(0) - exe = base.Executor(place) + exe = base.Executor(self.place) (res,) = exe.run(feed={'X': self.x_np}, fetch_list=[out]) out_ref = ref_frac(self.x_np) np.testing.assert_allclose(out_ref, res, rtol=1e-05) @@ -101,6 +100,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_error(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_trunc_op.py b/test/legacy_test/test_trunc_op.py index e67c0d94b78bc..3f157fe879b05 100644 --- a/test/legacy_test/test_trunc_op.py +++ b/test/legacy_test/test_trunc_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -36,10 +37,10 @@ def init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True) class TestFloatTruncOp(TestTruncOp): @@ -66,6 +67,7 @@ def setUp(self): self.x = np.random.random((20, 20)).astype(np.float32) self.place = paddle.CPUPlace() + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -114,11 +116,13 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5) + self.check_grad_with_place( + place, ['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True + ) if __name__ == "__main__":