Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.42、43】 Migrate paddle.diff/conj into pir #58676

Merged
merged 9 commits (branch names omitted in this capture)
Nov 21, 2023
6 changes: 3 additions & 3 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4945,7 +4945,7 @@ def conj(x, name=None):
[(4-4j), (5-5j), (6-6j)]])

"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.conj(x)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -5799,7 +5799,7 @@ def _diff_handler(x, n=1, axis=-1, prepend=None, append=None, name=None):
dtype = x.dtype
axes = [axis]
infer_flags = [1 for i in range(len(axes))]
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
has_pend = False
input_list = []
if prepend is not None and append is not None:
Expand Down Expand Up @@ -5836,7 +5836,7 @@ def _diff_handler(x, n=1, axis=-1, prepend=None, append=None, name=None):
new_input, axes, starts_2, ends_2, infer_flags, []
)

if x.dtype == paddle.bool:
if x.dtype == paddle.bool or x.dtype == core.DataType.BOOL:
return _C_ops.logical_xor(input_back, input_front)
else:
return _C_ops.subtract(input_back, input_front)
Expand Down
12 changes: 8 additions & 4 deletions test/legacy_test/test_conj_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import paddle.base.dygraph as dg
from paddle import static
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand All @@ -50,12 +51,13 @@ def init_input_output(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(
['X'],
'Out',
check_pir=True,
)


Expand Down Expand Up @@ -90,6 +92,7 @@ def test_conj_operator(self):
target = np.conj(input)
np.testing.assert_array_equal(result, target)

@test_with_pir_api
def test_conj_static_mode(self):
def init_input_output(dtype):
input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand(
Expand All @@ -110,7 +113,7 @@ def init_input_output(dtype):
out = paddle.conj(x)

exe = static.Executor(place)
out_value = exe.run(feed=input_dict, fetch_list=[out.name])
out_value = exe.run(feed=input_dict, fetch_list=[out])
np.testing.assert_array_equal(np_res, out_value[0])

def test_conj_api_real_number(self):
Expand All @@ -125,6 +128,7 @@ def test_conj_api_real_number(self):


class Testfp16ConjOp(unittest.TestCase):
@test_with_pir_api
def testfp16(self):
input_x = (
np.random.random((12, 14)) + 1j * np.random.random((12, 14))
Expand Down Expand Up @@ -170,11 +174,11 @@ def init_input_output(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions test/legacy_test/test_diff_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np

import paddle
from paddle import base
from paddle import base, static
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestDiffOp(unittest.TestCase):
Expand Down Expand Up @@ -77,13 +78,16 @@ def test_dygraph(self):
self.setUp()
self.func_dygraph()

@test_with_pir_api
def test_static(self):
paddle.enable_static()
places = [base.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(base.CUDAPlace(0))
for place in places:
with base.program_guard(base.Program(), base.Program()):
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(
DrRyanHuang marked this conversation as resolved.
Show resolved Hide resolved
name="input", shape=self.input.shape, dtype=self.input.dtype
)
Expand All @@ -105,12 +109,12 @@ def test_static(self):
dtype=self.append.dtype,
)

exe = base.Executor(place)
exe = static.Executor(place)
out = paddle.diff(
x, n=self.n, axis=self.axis, prepend=prepend, append=append
)

fetches = exe.run(
base.default_main_program(),
DrRyanHuang marked this conversation as resolved.
Show resolved Hide resolved
feed={
"input": self.input,
"prepend": self.prepend,
Expand Down Expand Up @@ -238,6 +242,7 @@ def set_args(self):


class TestDiffOpFp16(TestDiffOp):
@test_with_pir_api
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand All @@ -258,7 +263,6 @@ def test_fp16_with_gpu(self):
append=self.append,
)
fetches = exe.run(
paddle.static.default_main_program(),
feed={
"input": input,
},
Expand Down