From 210ca0fa841de2c980b37b3dbaccef3ea977979c Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Thu, 2 Nov 2023 14:14:27 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.5-6=E3=80=91M?= =?UTF-8?q?igrate=20paddle.amax/amin=20into=20pir=20(#58546)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/tensor/math.py | 4 ++-- test/legacy_test/test_max_min_amax_amin_op.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ae3d6121061a6..86c7ec13db055 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3079,7 +3079,7 @@ def amax(x, axis=None, keepdim=False, name=None): [[0.50000000, 0.33333333], [0. , 0. ]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.amax(x, axis, keepdim) else: @@ -3227,7 +3227,7 @@ def amin(x, axis=None, keepdim=False, name=None): [[0.50000000, 0.33333333], [0. , 0. ]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.amin(x, axis, keepdim) else: diff --git a/test/legacy_test/test_max_min_amax_amin_op.py b/test/legacy_test/test_max_min_amax_amin_op.py index b5184bd3acd20..4c07869f6f988 100644 --- a/test/legacy_test/test_max_min_amax_amin_op.py +++ b/test/legacy_test/test_max_min_amax_amin_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -90,6 +91,7 @@ def _choose_paddle_func(self, func, x): return out # We check the output between paddle API and numpy in static graph. + @test_with_pir_api def test_static_graph(self): def _test_static_graph(func): startup_program = base.Program() @@ -103,7 +105,6 @@ def _test_static_graph(func): exe = base.Executor(self.place) res = exe.run( - base.default_main_program(), feed={'input': self.x_np}, fetch_list=[out], )