Skip to content

Commit

Permalink
[Dy2St] enable_to_static_guard 推全 6-15 (#59691)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRyanHuang authored Dec 6, 2023
1 parent 2d78a6b commit ac0fe2c
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 70 deletions.
11 changes: 6 additions & 5 deletions test/dygraph_to_static/test_cache_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
)
Expand Down Expand Up @@ -99,14 +100,14 @@ def setUp(self):
self.batch_num = 5

def train_static(self):
return self.train(to_static=True)
with enable_to_static_guard(True):
return self.train()

def train_dygraph(self):
return self.train(to_static=False)

def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
with enable_to_static_guard(False):
return self.train()

def train(self):
static_net = paddle.jit.to_static(self.dygraph_class())
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=static_net.parameters()
Expand Down
17 changes: 9 additions & 8 deletions test/dygraph_to_static/test_convert_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pir,
)
Expand Down Expand Up @@ -89,13 +90,13 @@ def init_test_func(self):
self.dyfunc = nested_func

def get_dygraph_output(self):
paddle.jit.enable_to_static(False)
res = self.dyfunc(self.input).numpy()
with enable_to_static_guard(False):
res = self.dyfunc(self.input).numpy()
return res

def get_static_output(self):
paddle.jit.enable_to_static(True)
res = self.dyfunc(self.input).numpy()
with enable_to_static_guard(True):
res = self.dyfunc(self.input).numpy()
return res

@test_legacy_and_pir
Expand Down Expand Up @@ -180,12 +181,12 @@ def _run(self):
return res.numpy()

def get_dygraph_output(self):
paddle.jit.enable_to_static(False)
return self._run()
with enable_to_static_guard(False):
return self._run()

def get_static_output(self):
paddle.jit.enable_to_static(True)
return self._run()
with enable_to_static_guard(True):
return self._run()

def test_transformed_static_result(self):
self.set_func()
Expand Down
7 changes: 4 additions & 3 deletions test/dygraph_to_static/test_cycle_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -537,12 +538,11 @@ def optimizer_setting(parameters):
return optimizer


def train(args, to_static):
def train(args):
place = (
base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
)

paddle.jit.enable_to_static(to_static)
with base.dygraph.guard(place):
max_images_num = args.max_images_num
data_shape = [-1] + args.image_shape
Expand Down Expand Up @@ -687,7 +687,8 @@ def setUp(self):
self.args = Args()

def train(self, to_static):
out = train(self.args, to_static)
with enable_to_static_guard(to_static):
out = train(self.args)
return out

@test_legacy_and_pt_and_pir
Expand Down
13 changes: 6 additions & 7 deletions test/dygraph_to_static/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
)
from test_basic_api_transformation import dyfunc_to_variable
Expand Down Expand Up @@ -371,18 +372,16 @@ def test_error(self):
with self.assertRaises(RuntimeError):
func(np.ones(5).astype("int32"))

paddle.jit.enable_to_static(False)
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use base.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))
paddle.jit.enable_to_static(True)
with enable_to_static_guard(False):
with self.assertRaises(AssertionError):
# AssertionError: We Only support to_variable in imperative mode,
# please use base.dygraph.guard() as context to run it in imperative Mode
func(np.ones(5).astype("int32"))


class TestDecorateModelDirectly(Dy2StTestBase):
def setUp(self):
paddle.disable_static()
paddle.jit.enable_to_static(True)
self.x = to_variable(np.ones([4, 10]).astype('float32'))

@test_ast_only
Expand Down
21 changes: 10 additions & 11 deletions test/dygraph_to_static/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -128,13 +129,14 @@ def setUp(self):
self.batch_size = self.x.shape[0]

def _run_static(self):
return self.train(to_static=True)
with enable_to_static_guard(True):
return self.train()

def _run_dygraph(self):
return self.train(to_static=False)
with enable_to_static_guard(False):
return self.train()

def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
def train(self):
with base.dygraph.guard(PLACE):
net = paddle.jit.to_static(
MainNetWithDict(batch_size=self.batch_size)
Expand Down Expand Up @@ -190,11 +192,9 @@ def _run_dygraph(self):
return self._run(to_static=False)

def _run(self, to_static):
paddle.jit.enable_to_static(to_static)

result = self.dygraph_func(self.input)

return result.numpy()
with enable_to_static_guard(to_static):
result = self.dygraph_func(self.input)
return result.numpy()

@test_legacy_and_pt_and_pir
def test_transformed_result(self):
Expand Down Expand Up @@ -232,8 +232,7 @@ class TestDictPop3(TestNetWithDict):
def setUp(self):
self.x = np.array([2, 2]).astype('float32')

def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
def train(self):
with base.dygraph.guard(PLACE):
net = paddle.jit.to_static(NetWithDictPop())
ret = net(z=0, x=self.x, y=True)
Expand Down
10 changes: 6 additions & 4 deletions test/dygraph_to_static/test_fetch_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -66,8 +67,7 @@ def setUp(self):
self.dygraph_class = Pool2D
self.data = np.random.random((1, 2, 4, 4)).astype('float32')

def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
def train(self):
dy_layer = paddle.jit.to_static(self.dygraph_class())
x = paddle.to_tensor(self.data)
prediction = dy_layer(x)
Expand All @@ -77,10 +77,12 @@ def train(self, to_static=False):
return prediction.numpy()

def train_static(self):
return self.train(to_static=True)
with enable_to_static_guard(True):
return self.train()

def train_dygraph(self):
return self.train(to_static=False)
with enable_to_static_guard(False):
return self.train()

@test_legacy_and_pt_and_pir
def test_to_static(self):
Expand Down
29 changes: 16 additions & 13 deletions test/dygraph_to_static/test_for_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
test_sot_only,
)
Expand Down Expand Up @@ -345,36 +346,38 @@ def set_test_func(self):
"For Enumerate test should implement set_test_func"
)

def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
def _run(self):
self.dygraph_func = paddle.jit.to_static(self.dygraph_func)
return self.dygraph_func(self.input)

def get_dygraph_output(self):
return self._run(to_static=False)
with enable_to_static_guard(False):
return self._run()

def get_static_output(self):
return self._run(to_static=True)
with enable_to_static_guard(True):
return self._run()


class TestTransform(TestTransformBase):
def transformed_result_compare(self):
dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, (tuple, list)):
dy_outs = (dy_outs,)
with enable_to_static_guard(False):
dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, (tuple, list)):
dy_outs = (dy_outs,)

self.dygraph_func.eval()
st_outs = self.get_static_output()
if not isinstance(st_outs, (tuple, list)):
st_outs = (st_outs,)
with enable_to_static_guard(True):
self.dygraph_func.eval()
st_outs = self.get_static_output()
if not isinstance(st_outs, (tuple, list)):
st_outs = (st_outs,)

for x, y in zip(dy_outs, st_outs):
np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05)


class TestTransformForOriginalList(TestTransform):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
def _run(self):
self.dygraph_func = paddle.jit.to_static(self.dygraph_func)
return self.dygraph_func()

Expand Down
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, enable_to_static_guard

import paddle

Expand Down Expand Up @@ -72,9 +72,8 @@ def setUp(self):
self.x.stop_gradient = False

def _run(self, func, to_static):
paddle.jit.enable_to_static(to_static)
ret = func(self.x).numpy()
paddle.jit.enable_to_static(True)
with enable_to_static_guard(to_static):
ret = func(self.x).numpy()
return ret

def test_forward(self):
Expand Down
22 changes: 11 additions & 11 deletions test/dygraph_to_static/test_grid_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -134,17 +135,16 @@ def setUp(self):
self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32')

def _run(self, to_static):
paddle.jit.enable_to_static(to_static)

net = paddle.jit.to_static(
GridGenerator(40, 20),
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 32, 100], dtype='float32'
),
],
)
ret = net(self.x, [32, 100])
with enable_to_static_guard(to_static):
net = paddle.jit.to_static(
GridGenerator(40, 20),
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 32, 100], dtype='float32'
),
],
)
ret = net(self.x, [32, 100])
return ret.numpy()

@test_legacy_and_pt_and_pir
Expand Down
8 changes: 4 additions & 4 deletions test/dygraph_to_static/test_isinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
)

Expand Down Expand Up @@ -77,10 +78,9 @@ def forward(self, x):


def train(model, to_static):
paddle.jit.enable_to_static(to_static)

x = paddle.ones(shape=[2, 3], dtype='int32')
out = model(x)
with enable_to_static_guard(to_static):
x = paddle.ones(shape=[2, 3], dtype='int32')
out = model(x)

return out.numpy()

Expand Down

0 comments on commit ac0fe2c

Please sign in to comment.