Skip to content

Commit

Permalink
【PIR API adaptor No.42、43】 Migrate paddle.diff/conj into pir (PaddleP…
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRyanHuang authored and SecretXV committed Nov 28, 2023
1 parent d9a1cbc commit e10d964
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
6 changes: 3 additions & 3 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4958,7 +4958,7 @@ def conj(x, name=None):
[(4-4j), (5-5j), (6-6j)]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.conj(x)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -5863,7 +5863,7 @@ def _diff_handler(x, n=1, axis=-1, prepend=None, append=None, name=None):
dtype = x.dtype
axes = [axis]
infer_flags = [1 for i in range(len(axes))]
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
has_pend = False
input_list = []
if prepend is not None and append is not None:
Expand Down Expand Up @@ -5900,7 +5900,7 @@ def _diff_handler(x, n=1, axis=-1, prepend=None, append=None, name=None):
new_input, axes, starts_2, ends_2, infer_flags, []
)

if x.dtype == paddle.bool:
if x.dtype == paddle.bool or x.dtype == core.DataType.BOOL:
return _C_ops.logical_xor(input_back, input_front)
else:
return _C_ops.subtract(input_back, input_front)
Expand Down
12 changes: 8 additions & 4 deletions test/legacy_test/test_conj_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import paddle.base.dygraph as dg
from paddle import static
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand All @@ -50,12 +51,13 @@ def init_input_output(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(
['X'],
'Out',
check_pir=True,
)


Expand Down Expand Up @@ -90,6 +92,7 @@ def test_conj_operator(self):
target = np.conj(input)
np.testing.assert_array_equal(result, target)

@test_with_pir_api
def test_conj_static_mode(self):
def init_input_output(dtype):
input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand(
Expand All @@ -110,7 +113,7 @@ def init_input_output(dtype):
out = paddle.conj(x)

exe = static.Executor(place)
out_value = exe.run(feed=input_dict, fetch_list=[out.name])
out_value = exe.run(feed=input_dict, fetch_list=[out])
np.testing.assert_array_equal(np_res, out_value[0])

def test_conj_api_real_number(self):
Expand All @@ -125,6 +128,7 @@ def test_conj_api_real_number(self):


class Testfp16ConjOp(unittest.TestCase):
@test_with_pir_api
def testfp16(self):
input_x = (
np.random.random((12, 14)) + 1j * np.random.random((12, 14))
Expand Down Expand Up @@ -170,11 +174,11 @@ def init_input_output(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions test/legacy_test/test_diff_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np

import paddle
from paddle import base
from paddle import base, static
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestDiffOp(unittest.TestCase):
Expand Down Expand Up @@ -77,13 +78,16 @@ def test_dygraph(self):
self.setUp()
self.func_dygraph()

@test_with_pir_api
def test_static(self):
paddle.enable_static()
places = [base.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(base.CUDAPlace(0))
for place in places:
with base.program_guard(base.Program(), base.Program()):
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(
name="input", shape=self.input.shape, dtype=self.input.dtype
)
Expand All @@ -105,12 +109,12 @@ def test_static(self):
dtype=self.append.dtype,
)

exe = base.Executor(place)
exe = static.Executor(place)
out = paddle.diff(
x, n=self.n, axis=self.axis, prepend=prepend, append=append
)

fetches = exe.run(
base.default_main_program(),
feed={
"input": self.input,
"prepend": self.prepend,
Expand Down Expand Up @@ -238,6 +242,7 @@ def set_args(self):


class TestDiffOpFp16(TestDiffOp):
@test_with_pir_api
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand All @@ -258,7 +263,6 @@ def test_fp16_with_gpu(self):
append=self.append,
)
fetches = exe.run(
paddle.static.default_main_program(),
feed={
"input": input,
},
Expand Down

0 comments on commit e10d964

Please sign in to comment.