From dbf84afb16be0db07ef79c8159e8a1b6d5769778 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 3 Sep 2020 13:09:34 +0000 Subject: [PATCH 1/4] Add naming rule if not specific InputSpec.name --- .../dygraph_to_static/function_spec.py | 63 +++++++++++++++++++ .../dygraph_to_static/test_declarative.py | 36 ++++++++--- 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 90e38bd98863f..9a8b8f7114145 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -135,6 +135,11 @@ def args_to_input_spec(self, args, kwargs): input_with_spec = pack_sequence_as(args, input_with_spec) + # If without specificing name in input_spec, add default name + # according by argument name from decorated function. + input_with_spec = replace_spec_empty_name(self._arg_names, + input_with_spec) + return input_with_spec @switch_to_static_graph @@ -309,3 +314,61 @@ def check_type_and_len(input, spec, check_length=False): raise TypeError( "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.". type_name(input_spec)) + + +def replace_spec_empty_name(args_name, input_with_spec): + """ + Adds default name according to argument name from decorated function + if without specificing inputSpec.name + + The naming rule are as followed: + 1. If InputSpec.name is not None, do nothing. + 2. If each argument `x` corresponds to an InputSpec, using the argument name like `x` + 3. If the arguments `inputs` corresponds to a list(InputSpec), using name like `inputs_0`, `inputs_1` + 4. If the arguments `input_dic` corresponds to a dict(InputSpec), using key as name. + + For example: + + # case 1: foo(x, y) + foo = to_static(foo, input_spec=[InputSpec([None, 10]), InputSpec([None])]) + print([in_var.name for in_var in foo.inputs]) # [x, y] + + # case 2: foo(inputs) where inputs is a list + foo = to_static(foo, input_spec=[[InputSpec([None, 10]), InputSpec([None])]]) + print([in_var.name for in_var in foo.inputs]) # [inputs_0, inputs_1] + + # case 3: foo(inputs) where inputs is a dict + foo = to_static(foo, input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None])}]) + print([in_var.name for in_var in foo.inputs]) # [x, y] + """ + input_with_spec = list(input_with_spec) + candidate_arg_names = args_name[:len(input_with_spec)] + + for i, arg_name in enumerate(candidate_arg_names): + input_spec = input_with_spec[i] + input_with_spec[i] = _replace_spec_name(arg_name, input_spec) + + return input_with_spec + + +def _replace_spec_name(name, input_spec): + """ + Replaces InputSpec.name with given `name` while not specificing it. + """ + if isinstance(input_spec, paddle.static.InputSpec): + if input_spec.name is None: + input_spec.name = name + return input_spec + elif isinstance(input_spec, (list, tuple)): + processed_specs = [] + for i, spec in enumerate(input_spec): + new_name = "{}_{}".format(name, i) + processed_specs.append(fill_name(new_name, spec)) + return processed_specs + elif isinstance(input_spec, dict): + processed_specs = {} + for key, spec in six.iteritems(input_spec): + processed_specs[key] = fill_name(key, spec) + return processed_specs + else: + return input_spec diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index 949286f63efb3..333e367c10602 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -47,8 +47,8 @@ def add_func(self, x, y): return z @declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]]) - def func_with_list(self, l): - x, y, int_val = l + def func_with_list(self, l, int_val=1): + x, y = l z = x + y z = z + int_val return z @@ -60,10 +60,7 @@ def func_with_list(self, l): def func_with_dict(self, d): x = d['x'] y = d['y'] - int_val = d['int_val'] - z = x + y - z = z + int_val return z @@ -114,10 +111,10 @@ def test_with_input_spec(self): self.assertTrue(len(net.add_func.program_cache) == 1) # 5. test input with list - out = net.func_with_list([x, y, int_val]) + out = net.func_with_list([x, y], int_val) # 6. test input with dict - out = net.func_with_dict({'x': x, 'y': y, 'int_val': int_val}) + out = net.func_with_dict({'x': x, 'y': y}) # 7. test input with lits contains dict int_np = np.ones([1]).astype('float32') @@ -277,6 +274,31 @@ def test_concrete_program(self): foo_3.concrete_program +class TestInputDefaultName(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.net = SimpleNet() + + def assert_default_name(self, func_name, input_names): + decorated_func = getattr(self.net, func_name) + + spec_names = [x.name for x in decorated_func.inputs] + print(spec_names) + self.assertListEqual(spec_names, input_names) + + def test_common_input(self): + self.assert_default_name('forward', ['x']) + + def test_list_input(self): + self.assert_default_name('func_with_list', ['l_0', 'l_1']) + + def test_dict_input(self): + self.assert_default_name('func_with_dict', ['x', 'y']) + + def test_nest_input(self): + self.assert_default_name('func_with_list_dict', ['dl_0', 'x', 'y']) + + class TestDeclarativeAPI(unittest.TestCase): def test_error(self): func = declarative(dyfunc_to_variable) From 2aef62a250d406f77eb247f1bca78d45fd322c14 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 3 Sep 2020 13:55:54 +0000 Subject: [PATCH 2/4] fix function name typo --- .../paddle/fluid/dygraph/dygraph_to_static/function_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 9a8b8f7114145..c9594a55f0bc0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -363,12 +363,12 @@ def _replace_spec_name(name, input_spec): processed_specs = [] for i, spec in enumerate(input_spec): new_name = "{}_{}".format(name, i) - processed_specs.append(fill_name(new_name, spec)) + processed_specs.append(_replace_spec_name(new_name, spec)) return processed_specs elif isinstance(input_spec, dict): processed_specs = {} for key, spec in six.iteritems(input_spec): - processed_specs[key] = fill_name(key, spec) + processed_specs[key] = _replace_spec_name(key, spec) return processed_specs else: return input_spec From 7dcf7be55ac339616d516e0fa31f1c1fd671a32c Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 3 Sep 2020 13:58:42 +0000 Subject: [PATCH 3/4] refine comment --- .../paddle/fluid/dygraph/dygraph_to_static/function_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index c9594a55f0bc0..37ce8b0a152ff 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -136,7 +136,7 @@ def args_to_input_spec(self, args, kwargs): input_with_spec = pack_sequence_as(args, input_with_spec) # If without specificing name in input_spec, add default name - # according by argument name from decorated function. + # according to argument name from decorated function. input_with_spec = replace_spec_empty_name(self._arg_names, input_with_spec) @@ -319,7 +319,7 @@ def check_type_and_len(input, spec, check_length=False): def replace_spec_empty_name(args_name, input_with_spec): """ Adds default name according to argument name from decorated function - if without specificing inputSpec.name + if without specificing InputSpec.name The naming rule are as followed: 1. If InputSpec.name is not None, do nothing. From c840651f095aad83d24d5cd6fa8927060fdbca7f Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 7 Sep 2020 11:20:20 +0000 Subject: [PATCH 4/4] remove print statement --- .../fluid/tests/unittests/dygraph_to_static/test_declarative.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index 333e367c10602..3c3728507453e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -283,7 +283,6 @@ def assert_default_name(self, func_name, input_names): decorated_func = getattr(self.net, func_name) spec_names = [x.name for x in decorated_func.inputs] - print(spec_names) self.assertListEqual(spec_names, input_names) def test_common_input(self):