diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index 360d7b22569df..c196b4ebcc96f 100644 --- a/test/dygraph_to_static/test_return.py +++ b/test/dygraph_to_static/test_return.py @@ -19,6 +19,7 @@ Dy2StTestBase, enable_to_static_guard, test_ast_only, + test_legacy_and_pt_and_pir, test_legacy_only, ) from ifelse_simple_func import dyfunc_with_if_else @@ -26,7 +27,6 @@ import paddle from paddle.base import core from paddle.jit.dy2static.utils import Dygraph2StaticException -from paddle.pir_utils import test_with_pir_api SEED = 2020 np.random.seed(SEED) @@ -279,7 +279,7 @@ def _test_value_impl(self): self.assertEqual(dygraph_res, static_res) @test_ast_only - @test_with_pir_api + @test_legacy_and_pt_and_pir def test_transformed_static_result(self): self.init_dygraph_func() if hasattr(self, "error"): @@ -326,8 +326,7 @@ def _test_value_impl(self): else: self.assertEqual(dygraph_res, static_res) - # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype - @test_legacy_only + @test_legacy_and_pt_and_pir @test_ast_only def test_transformed_static_result(self): self.init_dygraph_func() @@ -429,8 +428,7 @@ def _test_value_impl(self): else: self.assertEqual(dygraph_res, static_res) - # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype - @test_legacy_only + @test_legacy_and_pt_and_pir @test_ast_only def test_transformed_static_result(self): self.init_dygraph_func()