Skip to content

Commit

Permalink
[Dy2St] remove compare_legacy_with_pt in dygraph_to_static (#59427)
Browse files Browse the repository at this point in the history
  • Loading branch information
ooooo-create authored Nov 28, 2023
1 parent a386366 commit 797c800
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 14 deletions.
4 changes: 1 addition & 3 deletions test/dygraph_to_static/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt
from dygraph_to_static_utils import Dy2StTestBase

import paddle
from paddle import base
Expand Down Expand Up @@ -126,7 +126,6 @@ def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.batch_size = self.x.shape[0]

@compare_legacy_with_pt
def _run_static(self):
return self.train(to_static=True)

Expand Down Expand Up @@ -182,7 +181,6 @@ def setUp(self):
def _set_test_func(self):
self.dygraph_func = test_dic_pop

@compare_legacy_with_pt
def _run_static(self):
return self._run(to_static=True)

Expand Down
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_layer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt
from dygraph_to_static_utils import Dy2StTestBase

import paddle

Expand Down Expand Up @@ -66,7 +66,6 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@compare_legacy_with_pt
def train_net(self, to_static=False):
paddle.seed(2022)
net = SimpleNet()
Expand Down
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_pir_selectedrows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import random
import unittest

from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt
from dygraph_to_static_utils import Dy2StTestBase

import paddle
from paddle.jit.api import to_static
Expand Down Expand Up @@ -77,7 +77,6 @@ def train_dygraph():
return train(net, adam, x)


@compare_legacy_with_pt
def train_static():
paddle.seed(100)
net = IRSelectedRowsTestNet()
Expand Down
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_ptb_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt
from dygraph_to_static_utils import Dy2StTestBase

import paddle
from paddle import base
Expand Down Expand Up @@ -315,7 +315,6 @@ def train_dygraph(place):
return train(place)


@compare_legacy_with_pt
def train_static(place):
paddle.jit.enable_to_static(True)
return train(place)
Expand Down
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_tensor_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
compare_legacy_with_pt,
test_ast_only,
)

Expand Down Expand Up @@ -266,7 +265,6 @@ def _run(self, to_static):
def get_dygraph_output(self):
return self._run(to_static=False)

@compare_legacy_with_pt
def get_static_output(self):
return self._run(to_static=True)

Expand Down
4 changes: 1 addition & 3 deletions test/dygraph_to_static/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np
import transformer_util as util
from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt
from dygraph_to_static_utils import Dy2StTestBase
from transformer_dygraph_model import (
CrossEntropyCriterion,
Transformer,
Expand All @@ -36,7 +36,6 @@
STEP_NUM = 10


@compare_legacy_with_pt
def train_static(args, batch_generator):
paddle.enable_static()
paddle.seed(SEED)
Expand Down Expand Up @@ -419,7 +418,6 @@ def predict_dygraph(args, batch_generator):
return seq_ids, seq_scores


@compare_legacy_with_pt
def predict_static(args, batch_generator):
test_prog = base.Program()
with base.program_guard(test_prog):
Expand Down

0 comments on commit 797c800

Please sign in to comment.