【PIR api adaptor No.233、234】 Migrate paddle.trunc/frac into pir (#58675)
DrRyanHuang authored Nov 6, 2023
1 parent 38e314e commit 3fc1550
Showing 3 changed files with 21 additions and 12 deletions.
9 changes: 7 additions & 2 deletions python/paddle/tensor/math.py
@@ -20,6 +20,7 @@
 
 import paddle
 from paddle import _C_ops, _legacy_C_ops
+from paddle.base.libpaddle import DataType
 from paddle.common_ops_import import VarDesc, dygraph_utils
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
 
@@ -2026,7 +2027,7 @@ def trunc(input, name=None):
             [[ 0., 1.],
             [-0., -2.]])
     '''
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.trunc(input)
     else:
         inputs = {"X": input}
@@ -6061,11 +6062,15 @@ def frac(x, name=None):
         paddle.int64,
         paddle.float32,
         paddle.float64,
+        DataType.INT32,
+        DataType.INT64,
+        DataType.FLOAT32,
+        DataType.FLOAT64,
     ]:
         raise TypeError(
             f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}"
         )
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         y = _C_ops.trunc(x)
         return _C_ops.subtract(x, y)
     else:
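For context, the frac hunk above shows that `paddle.frac` is built on `paddle.trunc`: the fractional part is the input minus its truncated value, which is why both APIs switch to `in_dynamic_or_pir_mode()` together. A minimal eager-mode sketch of that relationship (the input values are made up to roughly match the `trunc` docstring output shown above):

import paddle

x = paddle.to_tensor([[0.5, 1.6], [-0.4, -2.1]])

# trunc keeps the integer part of each element, rounding toward zero
print(paddle.trunc(x))  # approximately [[ 0.,  1.], [-0., -2.]]

# frac is the leftover after truncation, i.e. x - trunc(x)
print(paddle.frac(x))   # approximately [[ 0.5,  0.6], [-0.4, -0.1]]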
12 changes: 6 additions & 6 deletions test/legacy_test/test_frac_api.py
@@ -18,7 +18,8 @@
 
 import paddle
 from paddle import base
-from paddle.base import Program, core, program_guard
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 def ref_frac(x):
@@ -40,15 +41,13 @@ def setUp(self):
             else paddle.CPUPlace()
         )
 
+    @test_with_pir_api
     def test_api_static(self):
         paddle.enable_static()
-        with program_guard(Program()):
+        with paddle.static.program_guard(paddle.static.Program()):
             input = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
             out = paddle.frac(input)
-            place = base.CPUPlace()
-            if base.core.is_compiled_with_cuda():
-                place = base.CUDAPlace(0)
-            exe = base.Executor(place)
+            exe = base.Executor(self.place)
             (res,) = exe.run(feed={'X': self.x_np}, fetch_list=[out])
             out_ref = ref_frac(self.x_np)
             np.testing.assert_allclose(out_ref, res, rtol=1e-05)
@@ -101,6 +100,7 @@ def setUp(self):
             else paddle.CPUPlace()
         )
 
+    @test_with_pir_api
     def test_static_error(self):
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
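The `test_with_pir_api` decorator imported above is what lets a single static-graph test body cover both the legacy program form and the new PIR form; the body is also rewritten to use only public `paddle.static` APIs so it stays valid under either. A hedged sketch of the pattern, with an illustrative class name and shapes that are not taken from the diff:

import unittest

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api


class TestFracStaticSketch(unittest.TestCase):  # hypothetical class, not part of the diff
    @test_with_pir_api  # rerun the body with the PIR program representation as well
    def test_api_static(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('X', [2, 3], 'float64')
            out = paddle.frac(x)
            exe = paddle.static.Executor(paddle.CPUPlace())
            x_np = np.random.uniform(-3, 3, [2, 3]).astype('float64')
            (res,) = exe.run(feed={'X': x_np}, fetch_list=[out])
            np.testing.assert_allclose(res, x_np - np.trunc(x_np), rtol=1e-05)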
12 changes: 8 additions & 4 deletions test/legacy_test/test_trunc_op.py
@@ -19,6 +19,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 paddle.enable_static()
 
@@ -36,10 +37,10 @@ def init_dtype_type(self):
         self.dtype = np.float64
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5)
+        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True)
 
 
 class TestFloatTruncOp(TestTruncOp):
@@ -66,6 +67,7 @@ def setUp(self):
         self.x = np.random.random((20, 20)).astype(np.float32)
         self.place = paddle.CPUPlace()
 
+    @test_with_pir_api
     def test_api_static(self):
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
@@ -114,11 +116,13 @@ def setUp(self):
 
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(place, check_pir=True)
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=1e-5)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True
+        )
 
 
 if __name__ == "__main__":
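In the OpTest-based kernel tests, `check_pir=True` on `check_output`/`check_grad` (and their `_with_place` variants) asks the harness to verify the operator under the new PIR executor in addition to the legacy one. A minimal sketch of the pattern, assuming the usual `OpTest` helper from `test/legacy_test` (class name and shapes below are illustrative):

import numpy as np
import paddle
from op_test import OpTest  # helper that lives alongside these tests in test/legacy_test


class TestTruncOpSketch(OpTest):  # illustrative class mirroring TestTruncOp, not part of the diff
    def setUp(self):
        self.op_type = 'trunc'
        self.python_api = paddle.trunc
        x = np.random.random((16, 16)).astype(np.float64)
        self.inputs = {'X': x}
        self.outputs = {'Out': np.trunc(x)}

    def test_check_output(self):
        # check_pir=True additionally runs the forward check through the PIR executor
        self.check_output(check_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_pir=True)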

