diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index bce89dba8ef44..34744a6567cf0 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -18,6 +18,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_ast_only, test_legacy_and_pt_and_pir, ) @@ -99,14 +100,14 @@ def setUp(self): self.batch_num = 5 def train_static(self): - return self.train(to_static=True) + with enable_to_static_guard(True): + return self.train() def train_dygraph(self): - return self.train(to_static=False) - - def train(self, to_static=False): - paddle.jit.enable_to_static(to_static) + with enable_to_static_guard(False): + return self.train() + def train(self): static_net = paddle.jit.to_static(self.dygraph_class()) adam = paddle.optimizer.Adam( learning_rate=0.001, parameters=static_net.parameters() diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 1d64447fdda33..dc451f949b35c 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -18,6 +18,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_ast_only, test_legacy_and_pir, ) @@ -89,13 +90,13 @@ def init_test_func(self): self.dyfunc = nested_func def get_dygraph_output(self): - paddle.jit.enable_to_static(False) - res = self.dyfunc(self.input).numpy() + with enable_to_static_guard(False): + res = self.dyfunc(self.input).numpy() return res def get_static_output(self): - paddle.jit.enable_to_static(True) - res = self.dyfunc(self.input).numpy() + with enable_to_static_guard(True): + res = self.dyfunc(self.input).numpy() return res @test_legacy_and_pir @@ -180,12 +181,12 @@ def _run(self): return res.numpy() def get_dygraph_output(self): - paddle.jit.enable_to_static(False) - return self._run() + with enable_to_static_guard(False): + return self._run() def get_static_output(self): - paddle.jit.enable_to_static(True) - return self._run() + with enable_to_static_guard(True): + return self._run() def test_transformed_static_result(self): self.set_func() diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py index d03c1cc5cc759..8465d5f93577d 100644 --- a/test/dygraph_to_static/test_cycle_gan.py +++ b/test/dygraph_to_static/test_cycle_gan.py @@ -28,6 +28,7 @@ from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -537,12 +538,11 @@ def optimizer_setting(parameters): return optimizer -def train(args, to_static): +def train(args): place = ( base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace() ) - paddle.jit.enable_to_static(to_static) with base.dygraph.guard(place): max_images_num = args.max_images_num data_shape = [-1] + args.image_shape @@ -687,7 +687,8 @@ def setUp(self): self.args = Args() def train(self, to_static): - out = train(self.args, to_static) + with enable_to_static_guard(to_static): + out = train(self.args) return out @test_legacy_and_pt_and_pir diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py index 07c7c91df1b4d..5e10e51bc354d 100644 --- a/test/dygraph_to_static/test_declarative.py +++ b/test/dygraph_to_static/test_declarative.py @@ -19,6 +19,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_ast_only, ) from test_basic_api_transformation import dyfunc_to_variable @@ -371,18 +372,16 @@ def test_error(self): with self.assertRaises(RuntimeError): func(np.ones(5).astype("int32")) - paddle.jit.enable_to_static(False) - with self.assertRaises(AssertionError): - # AssertionError: We Only support to_variable in imperative mode, - # please use base.dygraph.guard() as context to run it in imperative Mode - func(np.ones(5).astype("int32")) - paddle.jit.enable_to_static(True) + with enable_to_static_guard(False): + with self.assertRaises(AssertionError): + # AssertionError: We Only support to_variable in imperative mode, + # please use base.dygraph.guard() as context to run it in imperative Mode + func(np.ones(5).astype("int32")) class TestDecorateModelDirectly(Dy2StTestBase): def setUp(self): paddle.disable_static() - paddle.jit.enable_to_static(True) self.x = to_variable(np.ones([4, 10]).astype('float32')) @test_ast_only diff --git a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py index f69b112ba9afd..457d0995677db 100644 --- a/test/dygraph_to_static/test_dict.py +++ b/test/dygraph_to_static/test_dict.py @@ -17,6 +17,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -128,13 +129,14 @@ def setUp(self): self.batch_size = self.x.shape[0] def _run_static(self): - return self.train(to_static=True) + with enable_to_static_guard(True): + return self.train() def _run_dygraph(self): - return self.train(to_static=False) + with enable_to_static_guard(False): + return self.train() - def train(self, to_static=False): - paddle.jit.enable_to_static(to_static) + def train(self): with base.dygraph.guard(PLACE): net = paddle.jit.to_static( MainNetWithDict(batch_size=self.batch_size) @@ -190,11 +192,9 @@ def _run_dygraph(self): return self._run(to_static=False) def _run(self, to_static): - paddle.jit.enable_to_static(to_static) - - result = self.dygraph_func(self.input) - - return result.numpy() + with enable_to_static_guard(to_static): + result = self.dygraph_func(self.input) + return result.numpy() @test_legacy_and_pt_and_pir def test_transformed_result(self): @@ -232,8 +232,7 @@ class TestDictPop3(TestNetWithDict): def setUp(self): self.x = np.array([2, 2]).astype('float32') - def train(self, to_static=False): - paddle.jit.enable_to_static(to_static) + def train(self): with base.dygraph.guard(PLACE): net = paddle.jit.to_static(NetWithDictPop()) ret = net(z=0, x=self.x, y=True) diff --git a/test/dygraph_to_static/test_fetch_feed.py b/test/dygraph_to_static/test_fetch_feed.py index 6ee8f295b5696..ab1e83f63843d 100644 --- a/test/dygraph_to_static/test_fetch_feed.py +++ b/test/dygraph_to_static/test_fetch_feed.py @@ -17,6 +17,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -66,8 +67,7 @@ def setUp(self): self.dygraph_class = Pool2D self.data = np.random.random((1, 2, 4, 4)).astype('float32') - def train(self, to_static=False): - paddle.jit.enable_to_static(to_static) + def train(self): dy_layer = paddle.jit.to_static(self.dygraph_class()) x = paddle.to_tensor(self.data) prediction = dy_layer(x) @@ -77,10 +77,12 @@ def train(self, to_static=False): return prediction.numpy() def train_static(self): - return self.train(to_static=True) + with enable_to_static_guard(True): + return self.train() def train_dygraph(self): - return self.train(to_static=False) + with enable_to_static_guard(False): + return self.train() @test_legacy_and_pt_and_pir def test_to_static(self): diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index f4c9ab1ce7774..964c0871303af 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -19,6 +19,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, test_sot_only, ) @@ -345,36 +346,38 @@ def set_test_func(self): "For Enumerate test should implement set_test_func" ) - def _run(self, to_static): - paddle.jit.enable_to_static(to_static) + def _run(self): self.dygraph_func = paddle.jit.to_static(self.dygraph_func) return self.dygraph_func(self.input) def get_dygraph_output(self): - return self._run(to_static=False) + with enable_to_static_guard(False): + return self._run() def get_static_output(self): - return self._run(to_static=True) + with enable_to_static_guard(True): + return self._run() class TestTransform(TestTransformBase): def transformed_result_compare(self): - dy_outs = self.get_dygraph_output() - if not isinstance(dy_outs, (tuple, list)): - dy_outs = (dy_outs,) + with enable_to_static_guard(False): + dy_outs = self.get_dygraph_output() + if not isinstance(dy_outs, (tuple, list)): + dy_outs = (dy_outs,) - self.dygraph_func.eval() - st_outs = self.get_static_output() - if not isinstance(st_outs, (tuple, list)): - st_outs = (st_outs,) + with enable_to_static_guard(True): + self.dygraph_func.eval() + st_outs = self.get_static_output() + if not isinstance(st_outs, (tuple, list)): + st_outs = (st_outs,) for x, y in zip(dy_outs, st_outs): np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05) class TestTransformForOriginalList(TestTransform): - def _run(self, to_static): - paddle.jit.enable_to_static(to_static) + def _run(self): self.dygraph_func = paddle.jit.to_static(self.dygraph_func) return self.dygraph_func() diff --git a/test/dygraph_to_static/test_grad.py b/test/dygraph_to_static/test_grad.py index 6afed83e64c8d..15876ddb3f6a4 100644 --- a/test/dygraph_to_static/test_grad.py +++ b/test/dygraph_to_static/test_grad.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import Dy2StTestBase, enable_to_static_guard import paddle @@ -72,9 +72,8 @@ def setUp(self): self.x.stop_gradient = False def _run(self, func, to_static): - paddle.jit.enable_to_static(to_static) - ret = func(self.x).numpy() - paddle.jit.enable_to_static(True) + with enable_to_static_guard(to_static): + ret = func(self.x).numpy() return ret def test_forward(self): diff --git a/test/dygraph_to_static/test_grid_generator.py b/test/dygraph_to_static/test_grid_generator.py index 75a14bfb89fd4..580d4e710891b 100644 --- a/test/dygraph_to_static/test_grid_generator.py +++ b/test/dygraph_to_static/test_grid_generator.py @@ -17,6 +17,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -134,17 +135,16 @@ def setUp(self): self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32') def _run(self, to_static): - paddle.jit.enable_to_static(to_static) - - net = paddle.jit.to_static( - GridGenerator(40, 20), - input_spec=[ - paddle.static.InputSpec( - shape=[None, 3, 32, 100], dtype='float32' - ), - ], - ) - ret = net(self.x, [32, 100]) + with enable_to_static_guard(to_static): + net = paddle.jit.to_static( + GridGenerator(40, 20), + input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 32, 100], dtype='float32' + ), + ], + ) + ret = net(self.x, [32, 100]) return ret.numpy() @test_legacy_and_pt_and_pir diff --git a/test/dygraph_to_static/test_isinstance.py b/test/dygraph_to_static/test_isinstance.py index 9cac83cdbbe2f..ff8634e4c88c6 100644 --- a/test/dygraph_to_static/test_isinstance.py +++ b/test/dygraph_to_static/test_isinstance.py @@ -28,6 +28,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -77,10 +78,9 @@ def forward(self, x): def train(model, to_static): - paddle.jit.enable_to_static(to_static) - - x = paddle.ones(shape=[2, 3], dtype='int32') - out = model(x) + with enable_to_static_guard(to_static): + x = paddle.ones(shape=[2, 3], dtype='int32') + out = model(x) return out.numpy()