From d73930502e4e7f7d28697df6946b84cfa81c4dcf Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Wed, 29 Nov 2023 19:41:32 +0800 Subject: [PATCH 1/3] add `test_legacy_and_pir_exe_and_pir_api` --- .../test_convert_call_generator.py | 5 +- .../dygraph_to_static/test_full_name_usage.py | 42 ++--- test/dygraph_to_static/test_slice.py | 48 ++++-- test/dygraph_to_static/test_spec_names.py | 2 + test/dygraph_to_static/test_tsm.py | 156 ++++++++---------- 5 files changed, 133 insertions(+), 120 deletions(-) diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py index bdd9c6364c241..16e03d616d0ad 100644 --- a/test/dygraph_to_static/test_convert_call_generator.py +++ b/test/dygraph_to_static/test_convert_call_generator.py @@ -17,10 +17,10 @@ from dygraph_to_static_utils import ( Dy2StTestBase, test_ast_only, + test_legacy_and_pt_and_pir, ) import paddle -from paddle.jit import to_static from paddle.jit.dy2static.convert_call_func import translator_logger @@ -38,12 +38,13 @@ def main_func(): class TestConvertGenerator(Dy2StTestBase): # fallback will ok. @test_ast_only + @test_legacy_and_pt_and_pir def test_raise_error(self): translator_logger.verbosity_level = 1 with self.assertLogs( translator_logger.logger_name, level='WARNING' ) as cm: - to_static(main_func)() + paddle.jit.to_static(main_func)() self.assertRegex( cm.output[0], "Your function:`dyfunc_generator` doesn't support " diff --git a/test/dygraph_to_static/test_full_name_usage.py b/test/dygraph_to_static/test_full_name_usage.py index ed48bb457fece..339a87727fb5e 100644 --- a/test/dygraph_to_static/test_full_name_usage.py +++ b/test/dygraph_to_static/test_full_name_usage.py @@ -15,15 +15,16 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, +) import paddle -from paddle import base -@paddle.jit.to_static(full_graph=True) def dygraph_decorated_func(x): - x = base.dygraph.to_variable(x) + x = paddle.to_tensor(x) if paddle.mean(x) > 0: x_v = x - 1 else: @@ -33,7 +34,7 @@ def dygraph_decorated_func(x): @paddle.jit.to_static(full_graph=True) def jit_decorated_func(x): - x = base.dygraph.to_variable(x) + x = paddle.to_tensor(x) if paddle.mean(x) > 0: x_v = x - 1 else: @@ -50,7 +51,7 @@ class DoubleDecorated: @classmethod @paddle.jit.to_static(full_graph=True) def double_decorated_func1(self, x): - return dygraph_decorated_func(x) + return paddle.jit.to_static(dygraph_decorated_func)(x) @classmethod @paddle.jit.to_static(full_graph=True) @@ -63,20 +64,21 @@ class TestFullNameDecorator(Dy2StTestBase): def test_run_success(self): x = np.ones([1, 2]).astype("float32") answer = np.zeros([1, 2]).astype("float32") - with base.dygraph.guard(): - np.testing.assert_allclose( - dygraph_decorated_func(x).numpy(), answer, rtol=1e-05 - ) - np.testing.assert_allclose( - jit_decorated_func(x).numpy(), answer, rtol=1e-05 - ) - np.testing.assert_allclose( - decorated_call_decorated(x).numpy(), answer, rtol=1e-05 - ) - with self.assertRaises((NotImplementedError, TypeError)): - DoubleDecorated().double_decorated_func1(x) - with self.assertRaises((NotImplementedError, TypeError)): - DoubleDecorated().double_decorated_func2(x) + np.testing.assert_allclose( + paddle.jit.to_static(dygraph_decorated_func)(x).numpy(), + answer, + rtol=1e-05, + ) + np.testing.assert_allclose( + jit_decorated_func(x).numpy(), answer, rtol=1e-05 + ) + np.testing.assert_allclose( + 
decorated_call_decorated(x).numpy(), answer, rtol=1e-05 + ) + with self.assertRaises((NotImplementedError, TypeError)): + DoubleDecorated().double_decorated_func1(x) + with self.assertRaises((NotImplementedError, TypeError)): + DoubleDecorated().double_decorated_func2(x) if __name__ == '__main__': diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index d50e288d3dfd1..e0e64776d778c 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pt_and_pir, +) import paddle from paddle.static import InputSpec @@ -26,7 +30,6 @@ np.random.seed(SEED) -@paddle.jit.to_static def test_slice_without_control_flow(x): # Python slice will not be transformed. x = paddle.to_tensor(x) @@ -35,7 +38,6 @@ def test_slice_without_control_flow(x): return a[0] -@paddle.jit.to_static def test_slice_in_if(x): x = paddle.to_tensor(x) a = [] @@ -70,7 +72,6 @@ def test_slice_in_while_loop(x, iter_num=3): return out[0] -@paddle.jit.to_static def test_slice_in_for_loop(x, iter_num=3): x = paddle.to_tensor(x) a = [] @@ -88,7 +89,6 @@ def test_slice_in_for_loop(x, iter_num=3): return out -@paddle.jit.to_static def test_set_value(x): x = paddle.to_tensor(x) x[0] = paddle.full(shape=[1], fill_value=2, dtype="float32") @@ -101,14 +101,13 @@ def __init__(self, input_dim, hidden): super().__init__() self.linear = paddle.nn.Linear(input_dim, hidden) - @paddle.jit.to_static def forward(self, x): x = self.linear(x) x[0] = 1 return x -class TestSliceWithoutControlFlow(Dy2StTestBase): +class TestSliceBase(Dy2StTestBase): def setUp(self): self.init_input() self.place = ( @@ -116,14 +115,16 @@ def setUp(self): if paddle.is_compiled_with_cuda() else paddle.CPUPlace() ) - self.init_dygraph_func() + self.dygraph_func = None paddle.disable_static() def init_input(self): self.input = np.random.random(3).astype('int32') def init_dygraph_func(self): - self.dygraph_func = test_slice_without_control_flow + raise NotImplementedError( + "For Enumerate test should implement set_test_func" + ) def run_dygraph_mode(self): return self._run(to_static=False) @@ -140,28 +141,41 @@ def _run(self, to_static): def run_static_mode(self): return self._run(to_static=True) + +class TestSliceWithoutControlFlow(TestSliceBase): + def init_dygraph_func(self): + self.dygraph_func = test_slice_without_control_flow + + @test_legacy_and_pt_and_pir def test_transformed_static_result(self): + self.init_dygraph_func() static_res = self.run_static_mode() dygraph_res = self.run_dygraph_mode() np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) -class TestSliceInIf(TestSliceWithoutControlFlow): +class TestSliceInIf(TestSliceBase): def init_dygraph_func(self): self.dygraph_func = test_slice_in_if + def test_transformed_static_result(self): + self.init_dygraph_func() + static_res = self.run_static_mode() + dygraph_res = self.run_dygraph_mode() + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + -class TestSliceInWhileLoop(TestSliceWithoutControlFlow): +class TestSliceInWhileLoop(TestSliceInIf): def init_dygraph_func(self): - self.dygraph_func = paddle.jit.to_static(test_slice_in_while_loop) + self.dygraph_func = test_slice_in_while_loop -class TestSliceInForLoop(TestSliceWithoutControlFlow): +class TestSliceInForLoop(TestSliceInIf): def init_dygraph_func(self): 
self.dygraph_func = test_slice_in_for_loop -class TestSetValue(TestSliceWithoutControlFlow): +class TestSetValue(TestSliceInIf): def init_input(self): self.input = np.full([3, 4, 5], 5).astype('float32') @@ -182,7 +196,7 @@ def tearDown(self): @test_ast_only def test_set_value_with_save(self): paddle.jit.enable_to_static(True) - model = LayerWithSetValue(input_dim=10, hidden=1) + model = paddle.jit.to_static(LayerWithSetValue(input_dim=10, hidden=1)) x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32") paddle.jit.save( layer=model, path=self.model_path, input_spec=[x], output_spec=None @@ -191,6 +205,7 @@ def test_set_value_with_save(self): class TestSliceSupplementSpecialCase(Dy2StTestBase): # unittest for slice index which abs(step)>0. eg: x[::2] + @test_legacy_and_pt_and_pir def test_static_slice_step(self): paddle.enable_static() array = np.arange(4**3).reshape((4, 4, 4)).astype('int64') @@ -209,6 +224,7 @@ def test_static_slice_step(self): np.testing.assert_array_equal(out[0], array[::2]) np.testing.assert_array_equal(out[1], array[::-2]) + @test_legacy_and_pt_and_pir def test_static_slice_step_dygraph2static(self): paddle.disable_static() @@ -233,6 +249,7 @@ def func(inps): class TestPaddleStridedSlice(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_compare_paddle_strided_slice_with_numpy(self): paddle.disable_static() array = np.arange(5) @@ -294,6 +311,7 @@ def slice_zero_shape_tensor(x): class TestSliceZeroShapeTensor(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_slice(self): paddle.disable_static() x = paddle.ones([0, 0, 0, 0]) diff --git a/test/dygraph_to_static/test_spec_names.py b/test/dygraph_to_static/test_spec_names.py index 7225f42b5941c..e367881110923 100644 --- a/test/dygraph_to_static/test_spec_names.py +++ b/test/dygraph_to_static/test_spec_names.py @@ -17,6 +17,7 @@ from dygraph_to_static_utils import ( Dy2StTestBase, test_ast_only, + test_legacy_and_pt_and_pir, ) import paddle @@ -48,6 +49,7 @@ def read_from_dataset(self): self.n = paddle.randn([4, 2, 8]) @test_ast_only + @test_legacy_and_pt_and_pir def test_spec_name_hash(self): net = Net() net = paddle.jit.to_static(net) diff --git a/test/dygraph_to_static/test_tsm.py b/test/dygraph_to_static/test_tsm.py index 7601345a296d9..31e98bab3141f 100644 --- a/test/dygraph_to_static/test_tsm.py +++ b/test/dygraph_to_static/test_tsm.py @@ -19,12 +19,13 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_default_mode_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_default_mode_only, +) from tsm_config_utils import merge_configs, parse_config, print_configs import paddle -from paddle import base -from paddle.base.dygraph import to_variable from paddle.nn import BatchNorm, Linear random.seed(0) @@ -42,7 +43,7 @@ def parse_args(): parser.add_argument( '--use_gpu', type=bool, - default=base.is_compiled_with_cuda(), + default=paddle.is_compiled_with_cuda(), help='default use gpu.', ) args = parser.parse_args( @@ -70,15 +71,15 @@ def __init__( stride=stride, padding=(filter_size - 1) // 2, groups=1, - weight_attr=base.param_attr.ParamAttr(), + weight_attr=paddle.ParamAttr(), bias_attr=False, ) self._batch_norm = BatchNorm( num_filters, act=act, - param_attr=base.param_attr.ParamAttr(), - bias_attr=base.param_attr.ParamAttr(), + param_attr=paddle.ParamAttr(), + bias_attr=paddle.ParamAttr(), ) def forward(self, inputs): @@ -299,96 +300,85 @@ def train(args, fake_data_reader, to_static): valid_config = merge_configs(config, 'valid', vars(args)) 
print_configs(train_config, 'Train') - place = base.CUDAPlace(0) if args.use_gpu else base.CPUPlace() - random.seed(0) np.random.seed(0) - with base.dygraph.guard(place): - paddle.seed(1000) - paddle.framework.random._manual_program_seed(1000) - - video_model = paddle.jit.to_static( - TSM_ResNet("TSM", train_config, 'Train') - ) - - optimizer = create_optimizer( - train_config.TRAIN, video_model.parameters() - ) + paddle.seed(1000) + paddle.framework.random._manual_program_seed(1000) + + video_model = paddle.jit.to_static(TSM_ResNet("TSM", train_config, 'Train')) + + optimizer = create_optimizer(train_config.TRAIN, video_model.parameters()) + + train_reader = fake_data_reader.create_reader() + + ret = [] + for epoch in range(train_config.TRAIN.epoch): + video_model.train() + total_loss = 0.0 + total_acc1 = 0.0 + total_acc5 = 0.0 + total_sample = 0 + for batch_id, data in enumerate(train_reader()): + x_data = np.array([item[0] for item in data]) + y_data = np.array([item[1] for item in data]).reshape([-1, 1]) + + imgs = paddle.to_tensor(x_data) + labels = paddle.to_tensor(y_data) + labels.stop_gradient = True + outputs = video_model(imgs) + loss = paddle.nn.functional.cross_entropy( + input=outputs, + label=labels, + ignore_index=-1, + reduction='none', + use_softmax=False, + ) + avg_loss = paddle.mean(loss) + acc_top1 = paddle.static.accuracy(input=outputs, label=labels, k=1) + acc_top5 = paddle.static.accuracy(input=outputs, label=labels, k=5) - train_reader = fake_data_reader.create_reader() - - ret = [] - for epoch in range(train_config.TRAIN.epoch): - video_model.train() - total_loss = 0.0 - total_acc1 = 0.0 - total_acc5 = 0.0 - total_sample = 0 - for batch_id, data in enumerate(train_reader()): - x_data = np.array([item[0] for item in data]) - y_data = np.array([item[1] for item in data]).reshape([-1, 1]) - - imgs = to_variable(x_data) - labels = to_variable(y_data) - labels.stop_gradient = True - outputs = video_model(imgs) - loss = paddle.nn.functional.cross_entropy( - input=outputs, - label=labels, - ignore_index=-1, - reduction='none', - use_softmax=False, - ) - avg_loss = paddle.mean(loss) - acc_top1 = paddle.static.accuracy( - input=outputs, label=labels, k=1 - ) - acc_top5 = paddle.static.accuracy( - input=outputs, label=labels, k=5 - ) + avg_loss.backward() + optimizer.minimize(avg_loss) + video_model.clear_gradients() - avg_loss.backward() - optimizer.minimize(avg_loss) - video_model.clear_gradients() - - total_loss += float(avg_loss) - total_acc1 += float(acc_top1) - total_acc5 += float(acc_top5) - total_sample += 1 - - print( - 'TRAIN Epoch {}, iter {}, loss = {}, acc1 {}, acc5 {}'.format( - epoch, - batch_id, - float(avg_loss), - float(acc_top1), - float(acc_top5), - ) - ) - ret.extend( - [ - float(avg_loss), - float(acc_top1), - float(acc_top5), - ] - ) + total_loss += float(avg_loss) + total_acc1 += float(acc_top1) + total_acc5 += float(acc_top5) + total_sample += 1 print( - 'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}'.format( + 'TRAIN Epoch {}, iter {}, loss = {}, acc1 {}, acc5 {}'.format( epoch, - total_loss / total_sample, - total_acc1 / total_sample, - total_acc5 / total_sample, + batch_id, + float(avg_loss), + float(acc_top1), + float(acc_top5), ) ) - return ret + ret.extend( + [ + float(avg_loss), + float(acc_top1), + float(acc_top5), + ] + ) + + print( + 'TRAIN End, Epoch {}, avg_loss= {}, avg_acc1= {}, avg_acc5= {}'.format( + epoch, + total_loss / total_sample, + total_acc1 / total_sample, + total_acc5 / total_sample, + ) + ) + return ret class 
TestTsm(Dy2StTestBase): @test_default_mode_only def test_dygraph_static_same_loss(self): - if base.is_compiled_with_cuda(): - base.set_flags({"FLAGS_cudnn_deterministic": True}) + if paddle.is_compiled_with_cuda(): + paddle.set_flags({"FLAGS_cudnn_deterministic": True}) args = parse_args() fake_data_reader = FakeDataReader("train", parse_config(args.config)) dygraph_loss = train(args, fake_data_reader, to_static=False) From 4bcc4ccb0783b5607ddebeeb0d97d4f623dfc1e7 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Thu, 30 Nov 2023 22:31:25 +0800 Subject: [PATCH 2/3] clean code --- test/dygraph_to_static/test_slice.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index e0e64776d778c..7a221647af22d 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -110,11 +110,6 @@ def forward(self, x): class TestSliceBase(Dy2StTestBase): def setUp(self): self.init_input() - self.place = ( - paddle.CUDAPlace(0) - if paddle.is_compiled_with_cuda() - else paddle.CPUPlace() - ) self.dygraph_func = None paddle.disable_static() From a041bec387636e01b93000ea89bd0b9f4e499af6 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 4 Dec 2023 07:08:13 +0000 Subject: [PATCH 3/3] update tsm --- python/paddle/optimizer/optimizer.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 771cf337f58e1..311bbfd5abc95 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -31,6 +31,7 @@ in_dynamic_or_pir_mode, in_pir_mode, name_scope, + use_pir_api, ) from paddle.regularizer import L2Decay @@ -788,12 +789,19 @@ def _create_param_lr(self, param_and_grad): if param_lr == 1.0: return self._global_learning_rate() else: - with paddle.static.default_main_program()._lr_schedule_guard( - is_with_opt=True - ), framework.name_scope( - 'scale_with_param_lr' - ): - return self._global_learning_rate() * param_lr + if not use_pir_api(): + with paddle.static.default_main_program()._lr_schedule_guard( + is_with_opt=True + ), framework.name_scope( + 'scale_with_param_lr' + ): + return self._global_learning_rate() * param_lr + else: + # TODO(dev): Currently there has not equivalent of op_role in PIR + # mode, so we simply remove _lr_schedule_guard here, this should + # be fixed in the future. + with framework.name_scope('scale_with_param_lr'): + return self._global_learning_rate() * param_lr else: return self._global_learning_rate()
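
Note (not part of the patch series): the first commit consistently moves `paddle.jit.to_static` from a decorator on the test functions to the call site, so each test can run the same function eagerly and statically and compare the results. A minimal, standalone sketch of that pattern is below; it assumes a working Paddle install and mirrors `dygraph_decorated_func` from test_full_name_usage.py, with the hypothetical name `shift_by_mean_sign` used only for illustration.

    import numpy as np

    import paddle


    def shift_by_mean_sign(x):
        # Plain dygraph function: no @paddle.jit.to_static decorator on the
        # definition, so the test controls when conversion happens.
        x = paddle.to_tensor(x)
        if paddle.mean(x) > 0:
            return x - 1
        return x + 1


    x = np.ones([1, 2]).astype("float32")
    # Eager (dygraph) execution.
    dygraph_out = shift_by_mean_sign(x)
    # Dynamic-to-static conversion applied at the call site, as in the patch.
    static_out = paddle.jit.to_static(shift_by_mean_sign)(x)
    np.testing.assert_allclose(
        dygraph_out.numpy(), static_out.numpy(), rtol=1e-05
    )

The same call-site pattern is what lets the new `test_legacy_and_pt_and_pir` decorator re-run a single test body under the legacy, PT, and PIR execution modes without baking a specific conversion into the function definition.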