【PIR API adaptor No.156、159、180、189】Migrate some ops into pir (Paddle…
zrr1999 authored and SecretXV committed Nov 28, 2023
1 parent 87de81f commit db91d66
Showing 5 changed files with 33 additions and 17 deletions.
8 changes: 4 additions & 4 deletions python/paddle/tensor/math.py
@@ -414,7 +414,7 @@ def multiplex(inputs, index, name=None):
             [3., 4.]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.multiplex(inputs, index)
     else:
         helper = LayerHelper('multiplex', **locals())
@@ -2406,7 +2406,7 @@ def renorm(x, p, axis, max_norm):
                 )
             )
         axis = axis + len(input_shape)
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         out = _C_ops.renorm(x, p, axis, max_norm)
         return out
     else:
@@ -5420,7 +5420,7 @@ def rad2deg(x, name=None):
             57.29578018)
     """
     rad2deg_scale = 180 / np.pi
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         if convert_dtype(x.dtype) in ['int32', 'int64']:
             x = cast(x, dtype="float32")
         return _C_ops.scale(x, rad2deg_scale, 0.0, True)
@@ -6630,7 +6630,7 @@ def nextafter(x, y, name=None):
             Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
             [1.00000012, 1.99999988])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.nextafter(x, y)
     else:
         check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter')
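Every hunk in this file is the same one-line migration: in_dynamic_mode() becomes in_dynamic_or_pir_mode(), so the _C_ops fast path now also serves PIR program building while the else branch keeps the legacy static graph intact. A minimal sketch of the pattern, using a hypothetical op foo (only the branch structure comes from this diff; the fallback mirrors what math.py already does):

from paddle import _C_ops
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper


def foo(x, name=None):
    # True in eager (dynamic) mode and while building a PIR program,
    # so a single call site covers both execution paths.
    if in_dynamic_or_pir_mode():
        return _C_ops.foo(x)  # hypothetical C++ binding, for illustration
    else:
        # Legacy static graph: build the op through LayerHelper.
        helper = LayerHelper('foo', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(type='foo', inputs={'X': x}, outputs={'Out': out})
        return out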
16 changes: 11 additions & 5 deletions test/legacy_test/test_multiplex_op.py
@@ -45,19 +45,25 @@ def setUp(self):
         self.outputs = {'Out': output}

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def test_check_grad(self):
-        self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out')
+        self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out', check_pir=True)

     def test_check_grad_ignore_x1(self):
-        self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'))
+        self.check_grad(
+            ['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'), check_pir=True
+        )

     def test_check_grad_ignore_x1_x2(self):
-        self.check_grad(['x3', 'x4'], 'Out', no_grad_set={'x1', 'x2'})
+        self.check_grad(
+            ['x3', 'x4'], 'Out', no_grad_set={'x1', 'x2'}, check_pir=True
+        )

     def test_check_grad_ignore_x3(self):
-        self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'))
+        self.check_grad(
+            ['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'), check_pir=True
+        )


 class TestMultiplexOpError(unittest.TestCase):
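For context on these test changes: check_pir=True asks the OpTest harness (test/legacy_test/op_test.py) to repeat the same output and gradient checks under the PIR executor in addition to the legacy static graph. A minimal sketch of the shape of such a case, with a hypothetical op_type "foo":

import numpy as np
from op_test import OpTest


class TestFooOp(OpTest):
    def setUp(self):
        self.op_type = "foo"  # hypothetical; real cases name a registered op
        x = np.random.random((2, 3)).astype("float64")
        self.inputs = {'X': x}
        self.outputs = {'Out': x}  # identity output, purely illustrative

    def test_check_output(self):
        # Runs once against the legacy executor and once under PIR.
        self.check_output(check_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_pir=True)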
4 changes: 3 additions & 1 deletion test/legacy_test/test_nextafter_op.py
@@ -18,6 +18,7 @@
 from op_test import OpTest

 import paddle
+from paddle.pir_utils import test_with_pir_api


 def ref_nextafter(x, y):
@@ -39,6 +40,7 @@ def setUp(self):
             else paddle.CPUPlace()
         )

+    @test_with_pir_api
     def test_static_api(self):
         paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
@@ -103,7 +105,7 @@ def setUp(self):
         self.outputs = {'out': out}

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def init_dtype(self):
         self.dtype = np.float64
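test_with_pir_api (imported above from paddle.pir_utils) is the API-level counterpart of check_pir: broadly, it runs the decorated test body twice, once with the legacy static-graph program classes and once with the PIR aliases patched into paddle.static. A self-contained sketch, using paddle.abs purely as a stand-in op:

import unittest

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api


class TestStaticAbs(unittest.TestCase):
    @test_with_pir_api
    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', shape=[3], dtype='float32')
            out = paddle.abs(x)  # stand-in for the op under test
            exe = paddle.static.Executor(paddle.CPUPlace())
            (res,) = exe.run(
                feed={'x': np.array([-1.0, 0.0, 2.0], dtype='float32')},
                fetch_list=[out],
            )
        np.testing.assert_allclose(res, [1.0, 0.0, 2.0], rtol=1e-05)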
15 changes: 10 additions & 5 deletions test/legacy_test/test_rad2deg.py
@@ -19,6 +19,7 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api

 paddle.enable_static()

@@ -32,10 +33,11 @@ def setUp(self):
         self.x_shape = [6]
         self.out_np = np.rad2deg(self.x_np)

+    @test_with_pir_api
     def test_static_graph(self):
-        startup_program = base.Program()
-        train_program = base.Program()
-        with base.program_guard(startup_program, train_program):
+        startup_program = paddle.static.Program()
+        train_program = paddle.static.Program()
+        with paddle.static.program_guard(startup_program, train_program):
             x = paddle.static.data(
                 name='input', dtype=self.x_dtype, shape=self.x_shape
             )
Expand All @@ -48,11 +50,10 @@ def test_static_graph(self):
             )
             exe = base.Executor(place)
             res = exe.run(
-                base.default_main_program(),
                 feed={'input': self.x_np},
                 fetch_list=[out],
             )
-        self.assertTrue((np.array(out[0]) == self.out_np).all())
+        np.testing.assert_allclose(self.out_np, res[0], rtol=1e-05)

def test_dygraph(self):
paddle.disable_static()
@@ -96,3 +97,7 @@ def test_dygraph(self):
         np.testing.assert_allclose(180 / np.pi, result2.numpy(), rtol=1e-05)

         paddle.enable_static()
+
+
+if __name__ == "__main__":
+    unittest.main()
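Two incidental cleanups ride along in this file: exe.run() may omit the program argument because it defaults to the current default main program (inside program_guard, that is train_program), and the bit-exact comparison via assertTrue gives way to np.testing.assert_allclose with a relative tolerance, which absorbs kernel-level rounding differences. A condensed sketch of the resulting flow, assuming CPU execution:

import numpy as np
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name='input', shape=[6], dtype='float32')
    out = paddle.rad2deg(x)
    exe = paddle.static.Executor(paddle.CPUPlace())
    x_np = np.linspace(0.0, np.pi, 6).astype('float32')
    # No program argument: exe.run falls back to the guarded main program.
    (res,) = exe.run(feed={'input': x_np}, fetch_list=[out])
np.testing.assert_allclose(np.rad2deg(x_np), res, rtol=1e-05)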
7 changes: 5 additions & 2 deletions test/legacy_test/test_renorm_op.py
@@ -18,7 +18,7 @@

 import paddle
 from paddle import base
-from paddle.base import Program, program_guard
+from paddle.pir_utils import test_with_pir_api

 paddle.set_device('cpu')

@@ -32,12 +32,15 @@ def input_data(self):
         self.dim = 2
         self.max_norm = 2.05

+    @test_with_pir_api
     def test_renorm_api(self):
         paddle.enable_static()
         self.input_data()

         # case 1:
-        with program_guard(Program(), Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             x = paddle.static.data(name="x", shape=[-1, 2, 3], dtype='float64')
             z = paddle.renorm(x, self.p, self.dim, self.max_norm)
             exe = base.Executor(base.CPUPlace())
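The import swap above is what lets the decorator take effect: base.Program and base.program_guard stay bound to the legacy IR, while the paddle.static aliases are the ones the PIR test guard can patch. For context on the op under test, a small dynamic-mode sketch of what renorm guarantees (illustrative parameters, not the test's exact values):

import paddle

paddle.disable_static()
x = paddle.rand([3, 2, 3], dtype='float64')
y = paddle.renorm(x, p=2.0, axis=2, max_norm=2.05)
# Each slice of x taken along axis 2 whose p-norm exceeded max_norm is
# rescaled so its norm equals max_norm; compliant slices pass through.
for i in range(x.shape[2]):
    norm = float(paddle.linalg.norm(y[:, :, i].flatten(), p=2.0))
    assert norm <= 2.05 + 1e-6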
