From 12ea7dde800161dd210ad329abba72102f6c39dc Mon Sep 17 00:00:00 2001 From: luq Date: Thu, 14 Dec 2023 08:20:14 +0000 Subject: [PATCH 1/2] add searchsorted --- python/paddle/tensor/search.py | 2 +- test/legacy_test/test_searchsorted_op.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 8b2700615862b..25042b6085d49 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -1136,7 +1136,7 @@ def searchsorted( [1, 3, 4, 5]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.searchsorted(sorted_sequence, values, out_int32, right) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_searchsorted_op.py b/test/legacy_test/test_searchsorted_op.py index c3537fc4a47f4..bf1f371f56024 100644 --- a/test/legacy_test/test_searchsorted_op.py +++ b/test/legacy_test/test_searchsorted_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -42,7 +43,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def init_test_case(self): self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float32") @@ -102,6 +103,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api def test_static_api(self): paddle.enable_static() @@ -154,6 +156,7 @@ def test_out_int32(self): class TestSearchSortedError(unittest.TestCase): + @test_with_pir_api def test_error_api(self): paddle.enable_static() @@ -201,6 +204,7 @@ def test_searchsorted_sortedsequence_size_error(): RuntimeError, test_searchsorted_sortedsequence_size_error ) + def test_check_type_error(self): def test_sortedsequence_values_type_error(): with paddle.static.program_guard(paddle.static.Program()): sorted_sequence = paddle.static.data( From 04971fbf519b389ad8ecc55ab3a9468bf0319cb6 Mon Sep 17 00:00:00 2001 From: luq Date: Thu, 14 Dec 2023 08:31:10 +0000 Subject: [PATCH 2/2] add tensordot --- python/paddle/tensor/manipulation.py | 4 +++- test/legacy_test/test_tensordot.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 91a3887d66387..6f1fcbbcc4ce0 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4723,7 +4723,9 @@ def tensordot(x, y, axes=2, name=None): check_variable_and_dtype(x, 'x', input_dtype, op_type) check_variable_and_dtype(y, 'y', input_dtype, op_type) - check_type(axes, 'axes', (int, tuple, list, Variable), op_type) + check_type( + axes, 'axes', (int, tuple, list, Variable, paddle.pir.Value), op_type + ) def _var_to_list(var): if in_dynamic_mode(): diff --git a/test/legacy_test/test_tensordot.py b/test/legacy_test/test_tensordot.py index 0e41772abd6cb..072a38415a52d 100644 --- a/test/legacy_test/test_tensordot.py +++ b/test/legacy_test/test_tensordot.py @@ -18,6 +18,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(2021) @@ -205,6 +206,7 @@ def test_dygraph(self): np_res = tensordot_np(self.x, self.y, axes) np.testing.assert_allclose(paddle_res, np_res, rtol=1e-6) + @test_with_pir_api def test_static(self): paddle.enable_static() for axes in self.all_axes: @@ -226,6 +228,7 @@ def test_static(self): np_res = tensordot_np(self.x, self.y, axes) np.testing.assert_allclose(paddle_res[0], np_res, rtol=1e-6) + @test_with_pir_api def test_fp16_with_gpu(self): paddle.enable_static() if paddle.base.core.is_compiled_with_cuda():