From 9da63111bbf41628b108afa6549963371daa08fc Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Sat, 15 Jul 2023 16:35:23 +0800 Subject: [PATCH 01/12] [NewIR]Remove compatible logic of ProgramTranslator --- .../ir_adaptor/translator/op_compat_info.h | 6 +-- paddle/fluid/ir_adaptor/translator/utils.h | 42 ------------------- test/legacy_test/eager_op_test.py | 4 -- 3 files changed, 2 insertions(+), 50 deletions(-) delete mode 100644 paddle/fluid/ir_adaptor/translator/utils.h diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h index 5feb2c6c76b07..02276d0ee26bc 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h @@ -19,8 +19,6 @@ #include "glog/logging.h" -#include "paddle/fluid/ir_adaptor/translator/utils.h" - #pragma once namespace paddle { @@ -106,11 +104,11 @@ class OpNameNormalizer { return legacy_name; } if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { - return UnderscoreToCamelCase(arg_name); + return arg_name; } auto& arg_mappings = op_arg_name_mappings[op_type]; if (arg_mappings.find(arg_name) == arg_mappings.end()) { - return UnderscoreToCamelCase(arg_name); + return arg_name; } return arg_mappings.at(arg_name); } diff --git a/paddle/fluid/ir_adaptor/translator/utils.h b/paddle/fluid/ir_adaptor/translator/utils.h deleted file mode 100644 index 7065f46992c6a..0000000000000 --- a/paddle/fluid/ir_adaptor/translator/utils.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -namespace paddle { -namespace translator { - -static std::string UnderscoreToCamelCase(std::string str) { - std::string camel_case; - bool next_upper = true; - for (char c : str) { - if (c == '_') { - next_upper = true; - } else { - if (next_upper) { - camel_case += toupper(c); - next_upper = false; - } else { - camel_case += c; - } - } - } - return camel_case; -} - -} // namespace translator -} // namespace paddle diff --git a/test/legacy_test/eager_op_test.py b/test/legacy_test/eager_op_test.py index e4d1ff08d99af..f80786a522283 100644 --- a/test/legacy_test/eager_op_test.py +++ b/test/legacy_test/eager_op_test.py @@ -1203,8 +1203,6 @@ def _calc_dygraph_output( def _check_ir_output(self, place, program, feed_map, fetch_list, outs): if os.getenv("FLAGS_NEW_IR_OPTEST") is None: return - if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: - return if self.check_prim: return if self._check_cinn: @@ -2877,8 +2875,6 @@ def _check_ir_grad_output( ): if os.getenv("FLAGS_NEW_IR_OPTEST") is None: return - if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: - return if self.check_prim: return if self._check_cinn: From 326414b397dd0b75c27c84042db769504f1d3be2 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 17 Jul 2023 10:54:08 +0800 Subject: [PATCH 02/12] enable flag new ir opset white list --- test/legacy_test/eager_op_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/legacy_test/eager_op_test.py b/test/legacy_test/eager_op_test.py index f80786a522283..e4d1ff08d99af 100644 --- a/test/legacy_test/eager_op_test.py +++ b/test/legacy_test/eager_op_test.py @@ -1203,6 +1203,8 @@ def _calc_dygraph_output( def _check_ir_output(self, place, program, feed_map, fetch_list, outs): if os.getenv("FLAGS_NEW_IR_OPTEST") is None: return + if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: + return if self.check_prim: return if self._check_cinn: @@ -2875,6 +2877,8 @@ def _check_ir_grad_output( ): if os.getenv("FLAGS_NEW_IR_OPTEST") is None: return + if os.getenv("FLAGS_NEW_IR_OPTEST_WHITE_LIST") is None: + return if self.check_prim: return if self._check_cinn: From 7813103c92e6076c9435289fd1540cb3add4d8a1 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Tue, 18 Jul 2023 11:01:16 +0800 Subject: [PATCH 03/12] add fetch_v2 in op_compat.yaml --- paddle/fluid/ir_adaptor/translator/op_compat_gen.py | 3 --- paddle/phi/api/yaml/op_compat.yaml | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 913d8f26f9cd1..b69c49771033c 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -126,9 +126,6 @@ def insert_new_mutable_attributes( backward_op, op_compat_item["scalar"] ) - # special op mappings - op_name_mappings["fetch_v2"] = "fetch" - op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index aa19a4a1027c0..05cd78452b93c 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3009,7 +3009,7 @@ yolo_loss : GetYoloLossExpectedKernelType yolo_loss_grad : GetYoloLossExpectedKernelType -- op: fetch +- op: fetch (fetch_v2) inputs: {x: X} outputs: {out: Out} From f87e94423c71827d5ec6bb28ca0a46fb5e66e93d Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 20 Jul 2023 17:26:05 +0800 Subject: [PATCH 04/12] try to fix --- .../ir_adaptor/translator/op_compat_gen.py | 7 +++++- paddle/phi/api/yaml/op_compat.yaml | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index b69c49771033c..c3cc70bfb188d 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -28,6 +28,7 @@ undefined=StrictUndefined, extensions=['jinja2.ext.do'], ) +grad_suffix = '_grad' def OpNameNormalizerInitialization( @@ -86,7 +87,7 @@ def insert_new_mutable_attributes( attribute_name ].append(v) - _, legacy_name = insert_new_mappings(op_compat_item["op"]) + phi_name, legacy_name = insert_new_mappings(op_compat_item["op"]) legacy_backward_op_names = [] if "backward" in op_compat_item: backward_op_name_mapping_paris = op_compat_item["backward"].split( @@ -95,6 +96,10 @@ def insert_new_mutable_attributes( for pair in backward_op_name_mapping_paris: _, legacy_backward_op_name = insert_new_mappings(pair) legacy_backward_op_names.append(legacy_backward_op_name) + elif phi_name != legacy_name: + legacy_backward_op_name = legacy_name + grad_suffix + legacy_backward_op_names.append(legacy_backward_op_name) + op_name_mappings[legacy_backward_op_name] = phi_name + grad_suffix if "inputs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index cd3527500ed2e..b040c02469044 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3027,11 +3027,27 @@ yolo_loss : GetYoloLossExpectedKernelType yolo_loss_grad : GetYoloLossExpectedKernelType +- op: channel_shuffle + inputs: + {x: X} + outputs: + {out: Out} + - op: fetch (fetch_v2) inputs: {x: X} outputs: {out: Out} - op: full_batch_size_like (fill_constant_batch_size_like) + inputs: + {input: Input} + outputs: + {out: Out} + +- op: logspace + inputs: + {start: Start, stop: Stop, num: Num, base: Base} + outputs: + {out: Out} - op: lu backward: lu_grad @@ -3048,6 +3064,12 @@ outputs : {reindex_src : Reindex_Src, reindex_dst : Reindex_Dst, out_nodes : Out_Nodes} +- op: rrelu + inputs: + {x: X} + outputs: + {out: Out, Noise: noise} + - op: sigmoid_cross_entropy_with_logits backward: sigmoid_cross_entropy_with_logits_grad inputs : From cbe321b15ed3c3c67d80fa165e1d45fe6ea10175 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 20 Jul 2023 18:37:04 +0800 Subject: [PATCH 05/12] fix --- paddle/phi/api/yaml/op_compat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index b040c02469044..a0c39f4acc3f4 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3068,7 +3068,7 @@ inputs: {x: X} outputs: - {out: Out, Noise: noise} + {out: Out, noise: Noise} - op: sigmoid_cross_entropy_with_logits backward: sigmoid_cross_entropy_with_logits_grad From d8ad9ee6d17f6e9dfcf6f62d95f8f2f236cbcf51 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 21 Jul 2023 15:52:38 +0800 Subject: [PATCH 06/12] fix bilinear_tensor_product --- paddle/phi/api/yaml/op_compat.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index a0c39f4acc3f4..e101bb42b90c7 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -350,6 +350,7 @@ attrs : [bool use_mkldnn = false] - op : bilinear (bilinear_tensor_product) + backward: bilinear_grad (bilinear_tensor_product_grad) inputs : {x : X, y : Y,weight: Weight, bias: Bias} outputs : From 05cd0cf101e9cf822a2d1a695c0d327cb06cd067 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 21 Jul 2023 16:27:52 +0800 Subject: [PATCH 07/12] remove --- paddle/fluid/ir_adaptor/translator/op_compat_gen.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index c3cc70bfb188d..b69c49771033c 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -28,7 +28,6 @@ undefined=StrictUndefined, extensions=['jinja2.ext.do'], ) -grad_suffix = '_grad' def OpNameNormalizerInitialization( @@ -87,7 +86,7 @@ def insert_new_mutable_attributes( attribute_name ].append(v) - phi_name, legacy_name = insert_new_mappings(op_compat_item["op"]) + _, legacy_name = insert_new_mappings(op_compat_item["op"]) legacy_backward_op_names = [] if "backward" in op_compat_item: backward_op_name_mapping_paris = op_compat_item["backward"].split( @@ -96,10 +95,6 @@ def insert_new_mutable_attributes( for pair in backward_op_name_mapping_paris: _, legacy_backward_op_name = insert_new_mappings(pair) legacy_backward_op_names.append(legacy_backward_op_name) - elif phi_name != legacy_name: - legacy_backward_op_name = legacy_name + grad_suffix - legacy_backward_op_names.append(legacy_backward_op_name) - op_name_mappings[legacy_backward_op_name] = phi_name + grad_suffix if "inputs" in op_compat_item: insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) From 8455c75f138874c0746f97750084e520282fba20 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 24 Jul 2023 14:26:56 +0800 Subject: [PATCH 08/12] fix merged_momentum --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 2 +- paddle/phi/api/yaml/op_compat.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index dd59141d2d319..6cb90f09052d0 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -300,7 +300,7 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { std::string target_op_name = kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); - if (IsInplace(op_desc)) { + if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to " diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index e101bb42b90c7..4e1feb685b6ed 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1832,7 +1832,7 @@ data_type : float support_tensor : true -- op : merged_momentum_ +- op : merged_momentum_ (merged_momentum) inputs : {param : Param, grad : Grad, velocity : Velocity, learning_rate : LearningRate, master_param : MasterParam} outputs : From 0920620eef5aa2621e310a6ae0e6e390c3048573 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 24 Jul 2023 21:50:54 +0800 Subject: [PATCH 09/12] fix ci coverage --- test/ir/new_ir/test_special_op_translator.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index 2ab4819d88a10..1f74141a92af3 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -159,5 +159,19 @@ def test_normal_attribute(self): _ = paddle.fluid.core.translate_newirprogram(main_program.desc) +class TestArgmaxOpTranscriber(unittest.TestCase): + def test_op(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + x = paddle.to_tensor([2, 3, 4], 'float64') + y = paddle.argmax(x) + + _ = paddle.fluid.core.translate_newirprogram(main_program.desc) + + if __name__ == "__main__": unittest.main() From bf9e49c53df500d67e093d8f2f9b2a11596d9ea8 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Tue, 25 Jul 2023 10:00:45 +0800 Subject: [PATCH 10/12] fix ci coverage --- test/ir/new_ir/test_special_op_translator.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index 1f74141a92af3..a13dedb5bd430 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -159,7 +159,7 @@ def test_normal_attribute(self): _ = paddle.fluid.core.translate_newirprogram(main_program.desc) -class TestArgmaxOpTranscriber(unittest.TestCase): +class TestRMSNormOpTranscriber(unittest.TestCase): def test_op(self): place = core.Place() place.set_place(paddle.CPUPlace()) @@ -167,8 +167,12 @@ def test_op(self): main_program = paddle.static.Program() with paddle.static.scope_guard(new_scope): with paddle.static.program_guard(main_program): - x = paddle.to_tensor([2, 3, 4], 'float64') - y = paddle.argmax(x) + x = paddle.randn([2, 8]) + gamma = paddle.randn([8]) + beta = paddle.randn([8]) + y = paddle.incubate.nn.functional.rms_norm( + x, gamma, beta, 1e-6, begin_norm_axis=1 + ) _ = paddle.fluid.core.translate_newirprogram(main_program.desc) From b3f11731f9e4bc04113c057ae746e8d670704dc6 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Tue, 25 Jul 2023 14:28:23 +0800 Subject: [PATCH 11/12] fix ci coverage --- test/ir/new_ir/test_special_op_translator.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index a13dedb5bd430..e2fe582d6c6a9 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -159,7 +159,7 @@ def test_normal_attribute(self): _ = paddle.fluid.core.translate_newirprogram(main_program.desc) -class TestRMSNormOpTranscriber(unittest.TestCase): +class TestIndexPutOpTranscriber(unittest.TestCase): def test_op(self): place = core.Place() place.set_place(paddle.CPUPlace()) @@ -167,12 +167,10 @@ def test_op(self): main_program = paddle.static.Program() with paddle.static.scope_guard(new_scope): with paddle.static.program_guard(main_program): - x = paddle.randn([2, 8]) - gamma = paddle.randn([8]) - beta = paddle.randn([8]) - y = paddle.incubate.nn.functional.rms_norm( - x, gamma, beta, 1e-6, begin_norm_axis=1 - ) + x = paddle.randn([2, 3]) + indices = [paddle.randint(0, 2, [2]), paddle.randint(0, 1, [2])] + value = paddle.randn([2]) + y = paddle.index_put(x, indices, value, False) _ = paddle.fluid.core.translate_newirprogram(main_program.desc) From c9513e4e9d48bc331686de406e4ea9064a4d69ed Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 26 Jul 2023 17:43:07 +0800 Subject: [PATCH 12/12] change to ir.translate_to_new_ir --- test/ir/new_ir/test_special_op_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index 3a50101d144a8..22b0a82c2ce57 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -207,7 +207,7 @@ def test_op(self): value = paddle.randn([2]) y = paddle.index_put(x, indices, value, False) - _ = paddle.fluid.core.translate_newirprogram(main_program.desc) + _ = ir.translate_to_new_ir(main_program.desc) if __name__ == "__main__":