Skip to content

Commit

Permalink
【PIR API adaptor No.5-6】Migrate paddle.amax/amin into pir (#58546)
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV authored Nov 2, 2023
1 parent 038d4b4 commit 210ca0f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3079,7 +3079,7 @@ def amax(x, axis=None, keepdim=False, name=None):
[[0.50000000, 0.33333333],
[0. , 0. ]]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.amax(x, axis, keepdim)

else:
Expand Down Expand Up @@ -3227,7 +3227,7 @@ def amin(x, axis=None, keepdim=False, name=None):
[[0.50000000, 0.33333333],
[0. , 0. ]]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.amin(x, axis, keepdim)

else:
Expand Down
3 changes: 2 additions & 1 deletion test/legacy_test/test_max_min_amax_amin_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -90,6 +91,7 @@ def _choose_paddle_func(self, func, x):
return out

# We check the output between paddle API and numpy in static graph.
@test_with_pir_api
def test_static_graph(self):
def _test_static_graph(func):
startup_program = base.Program()
Expand All @@ -103,7 +105,6 @@ def _test_static_graph(func):

exe = base.Executor(self.place)
res = exe.run(
base.default_main_program(),
feed={'input': self.x_np},
fetch_list=[out],
)
Expand Down

0 comments on commit 210ca0f

Please sign in to comment.