From 125a671e51dc977cdb7a8bfde26af8bc7ee815c7 Mon Sep 17 00:00:00 2001
From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com>
Date: Fri, 12 Jan 2024 19:21:28 +0800
Subject: [PATCH] [Prim][PIR] Recover pir bn (#60689)

* reopen bn prim pir

* fix atol

* decomp support batch_norm_

* fix test case

* fix bug

* fix code
---
 .../decomp_interface_gen_op_list.py              |  1 +
 paddle/fluid/pir/dialect/op_generator/op_gen.py  | 13 ++-
 .../dialect/operator/ir/manual_op_decomp.cc      | 90 ++++++++++++++++++-
 paddle/fluid/primitive/base/decomp_trans.cc      | 12 ++-
 paddle/fluid/primitive/composite/composite.h     |  6 +-
 paddle/phi/infermeta/multiary.cc                 | 10 +++
 .../test_batch_norm_op_prim_nchw.py              |  1 +
 .../test_batch_norm_op_prim_nhwc.py              |  3 +-
 8 files changed, 122 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index 5a1e8f361bd62..364a2a8de7724 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -21,6 +21,7 @@
 decomp_interface_declare_gen_op_list = [
     "add_n",
     "batch_norm",
+    "batch_norm_",
     "dropout",
     "full_like",
     "gelu",
diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py
index e6801ead235a5..5c9c4c97e0e78 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -1177,6 +1177,8 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
             )
             op_interfaces_tmp = op_interfaces
             exclusive_interface_str_tmp = exclusive_interface_str
+            decomp_interface_str = "paddle::dialect::DecompInterface"
+            decomp_interface_declare_str = "\n  static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
 
             # If op has inplace info, we will generate inplace op and non-inplace op.
             for op_name in op_info.op_phi_name:
@@ -1212,10 +1214,13 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
                     and kernel_func_name in decomp_interface_declare_gen_op_list
                     and dialect_name != "onednn_op"
                 ):
-                    op_interfaces = op_interfaces + [
-                        "paddle::dialect::DecompInterface"
-                    ]
-                    exclusive_interface_str += "\n  static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
+                    if decomp_interface_str not in op_interfaces:
+                        op_interfaces = op_interfaces + [decomp_interface_str]
+                    if (
+                        decomp_interface_declare_str
+                        not in exclusive_interface_str
+                    ):
+                        exclusive_interface_str += decomp_interface_declare_str
                 else:
                     op_interfaces = op_interfaces_tmp
                     exclusive_interface_str = exclusive_interface_str_tmp
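Note on the op_gen.py hunks above: for an op with an inplace variant, this
branch can fire once per phi name (after this patch both "batch_norm" and
"batch_norm_" appear in decomp_interface_declare_gen_op_list), so appending
the interface and the Decomp declaration must be idempotent or the generated
op class would pick them up twice. A minimal standalone sketch of the guard
pattern, with illustrative names rather than the real op_gen.py state:

    def add_decomp_interface(op_interfaces, exclusive_interface_str):
        decomp_interface_str = "paddle::dialect::DecompInterface"
        decomp_interface_declare_str = (
            "\n  static std::vector<std::vector<pir::OpResult>>"
            " Decomp(pir::Operation* op);"
        )
        # Append only when absent, so running this once per phi name
        # (e.g. batch_norm and batch_norm_) stays safe.
        if decomp_interface_str not in op_interfaces:
            op_interfaces = op_interfaces + [decomp_interface_str]
        if decomp_interface_declare_str not in exclusive_interface_str:
            exclusive_interface_str += decomp_interface_declare_str
        return op_interfaces, exclusive_interface_str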
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
index beca2f6f9b640..cc8af56c2f481 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
@@ -60,7 +60,7 @@ std::vector<std::vector<pir::OpResult>> BatchNormOp::Decomp(
   float epsilon =
       op->attribute("epsilon").dyn_cast<pir::FloatAttribute>().data();
   const std::string& data_layout =
-      op->attribute("data_layout").dyn_cast<pir::StrAttribute>().AsString();
+      op->attribute("data_format").dyn_cast<pir::StrAttribute>().AsString();
   bool use_global_stats =
       op->attribute("use_global_stats").dyn_cast<pir::BoolAttribute>().data();
   bool trainable_statistics = op->attribute("trainable_statistics")
@@ -117,5 +117,93 @@ std::vector<std::vector<pir::OpResult>> BatchNormOp::Decomp(
   return res;
 }
 
+std::vector<std::vector<pir::OpResult>> BatchNorm_Op::Decomp(
+    pir::Operation* op) {
+  VLOG(4) << "Decomp call batch_norm_'s decomp interface begin";
+  BatchNorm_Op op_obj = op->dyn_cast<BatchNorm_Op>();
+  (void)op_obj;
+
+  FLAGS_tensor_operants_mode = "static";
+
+  VLOG(6) << "Decomp Prepare inputs of batch_norm_";
+
+  Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
+  Tensor mean(std::make_shared<primitive::LazyTensor>(op_obj.mean()));
+  Tensor variance(std::make_shared<primitive::LazyTensor>(op_obj.variance()));
+  paddle::optional<Tensor> scale;
+  if (!IsEmptyValue(op_obj.scale())) {
+    scale = paddle::make_optional<Tensor>(
+        Tensor(std::make_shared<primitive::LazyTensor>(op_obj.scale())));
+  }
+  paddle::optional<Tensor> bias;
+  if (!IsEmptyValue(op_obj.bias())) {
+    bias = paddle::make_optional<Tensor>(
+        Tensor(std::make_shared<primitive::LazyTensor>(op_obj.bias())));
+  }
+
+  VLOG(6) << "Decomp prepare attributes of batch_norm_";
+  bool is_test = op->attribute("is_test").dyn_cast<pir::BoolAttribute>().data();
+  float momentum =
+      op->attribute("momentum").dyn_cast<pir::FloatAttribute>().data();
+  float epsilon =
+      op->attribute("epsilon").dyn_cast<pir::FloatAttribute>().data();
+  const std::string& data_layout =
+      op->attribute("data_format").dyn_cast<pir::StrAttribute>().AsString();
+  bool use_global_stats =
+      op->attribute("use_global_stats").dyn_cast<pir::BoolAttribute>().data();
+  bool trainable_statistics = op->attribute("trainable_statistics")
+                                  .dyn_cast<pir::BoolAttribute>()
+                                  .data();
+
+  VLOG(6) << "Decomp call batch_norm_'s forward composite rule prepare";
+
+  auto org_res = op->results();
+  std::vector<std::vector<pir::OpResult>> res(org_res.size());
+
+  VLOG(6) << "Decomp call batch_norm_'s forward composite rule begin";
+
+  std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> op_res =
+      paddle::primitive::details::batch_norm_decomp<primitive::LazyTensor>(
+          x,
+          mean,
+          variance,
+          scale,
+          bias,
+          is_test,
+          momentum,
+          epsilon,
+          data_layout,
+          use_global_stats,
+          trainable_statistics);
+
+  VLOG(6) << "Decomp call batch_norm_'s forward composite rule end";
+
+  res[0].push_back(std::static_pointer_cast<primitive::LazyTensor>(
+                       std::get<0>(op_res).impl())
+                       ->value()
+                       .dyn_cast<pir::OpResult>());
+  res[1].push_back(std::static_pointer_cast<primitive::LazyTensor>(
+                       std::get<1>(op_res).impl())
+                       ->value()
+                       .dyn_cast<pir::OpResult>());
+  res[2].push_back(std::static_pointer_cast<primitive::LazyTensor>(
+                       std::get<2>(op_res).impl())
+                       ->value()
+                       .dyn_cast<pir::OpResult>());
+  res[3].push_back(std::static_pointer_cast<primitive::LazyTensor>(
+                       std::get<3>(op_res).impl())
+                       ->value()
+                       .dyn_cast<pir::OpResult>());
+  res[4].push_back(std::static_pointer_cast<primitive::LazyTensor>(
+                       std::get<4>(op_res).impl())
+                       ->value()
+                       .dyn_cast<pir::OpResult>());
+  pir::OpResult reserve_space;
+  res[5].push_back(reserve_space);
+
+  VLOG(4) << "Decomp call batch_norm_'s decomp interface end";
+  return res;
+}
+
 }  // namespace dialect
 }  // namespace paddle
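For intuition about the six results packed into res above: batch_norm_
produces y, the updated running mean and variance, the saved batch mean and
saved inverse std, plus reserve_space, which has no composite counterpart and
is therefore returned as a default-constructed (null) pir::OpResult and
whitelisted in decomp_op_contain_none in the next file. A rough NumPy sketch
of the training-mode rule for NCHW input (illustrative only; the real
batch_norm_decomp in composite.h also handles NHWC, dtype casting, and
use_global_stats):

    import numpy as np

    def batch_norm_ref(x, run_mean, run_var, scale, bias, momentum, epsilon):
        axes = (0, 2, 3)  # reduce over all but the channel axis
        batch_mean = x.mean(axis=axes)
        batch_var = x.var(axis=axes)
        inv_std = 1.0 / np.sqrt(batch_var + epsilon)
        shape = (1, -1, 1, 1)
        x_hat = (x - batch_mean.reshape(shape)) * inv_std.reshape(shape)
        y = x_hat * scale.reshape(shape) + bias.reshape(shape)
        # Paddle's running-stat convention: old stats weighted by momentum.
        new_mean = momentum * run_mean + (1 - momentum) * batch_mean
        new_var = momentum * run_var + (1 - momentum) * batch_var
        reserve_space = None  # no composite equivalent
        return y, new_mean, new_var, batch_mean, inv_std, reserve_space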
diff --git a/paddle/fluid/primitive/base/decomp_trans.cc b/paddle/fluid/primitive/base/decomp_trans.cc
index df0111d56f8af..c5d3cd104ef6e 100644
--- a/paddle/fluid/primitive/base/decomp_trans.cc
+++ b/paddle/fluid/primitive/base/decomp_trans.cc
@@ -32,8 +32,11 @@ using Program = pir::Program;
 
 // some outputs like xshape will no longer used after decomp, and those outputs
 // will skip checking.
-std::unordered_set<std::string> decomp_op_contain_none = {
-    "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.flatten", "pd_op.batch_norm"};
+std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
+                                                          "pd_op.unsqueeze",
+                                                          "pd_op.flatten",
+                                                          "pd_op.batch_norm",
+                                                          "pd_op.batch_norm_"};
 
 static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
   if (std::find(vec.begin(), vec.end(), value) != vec.end()) {
@@ -275,14 +278,15 @@ std::vector<std::vector<pir::OpResult>> call_decomp_rule(pir::Operation* op) {
 }
 
 void DecompProgram::decomp_program() {
-  std::ostringstream orig_prog_stream;
   std::unordered_map<pir::Value, int> orig_vars_dict;
   for (size_t i = 0; i < src_vars_.size(); i++) {
     orig_vars_dict[src_vars_[i]] = static_cast<int>(i);
   }
+  std::ostringstream orig_prog_stream;
   program_->Print(orig_prog_stream);
-  VLOG(4) << "[Prim] Origin program bofore decomp :\n"
+  VLOG(4) << "[Prim] Origin program before decomp :\n"
           << orig_prog_stream.str();
+
   if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) {
     return;
   }
diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index cb3366229a92f..e460029205432 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -206,13 +206,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_decomp(
     y = x_hat * reshape(new_scale, stats_shape) +
         reshape(new_bias, stats_shape);
   }
-  if (need_cast) {
-    y = cast(y, org_dtype);
-  }
 
   Tensor reserve_space;
   auto batch_mean_ = assign(batch_mean);
   auto inv_std_ = assign(inv_std);
+  if (need_cast) {
+    y = cast(y, org_dtype);
+  }
   if (!use_run_stat) {
     return std::make_tuple(
         y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space);
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 77af8e5c19f94..36faa875fbe19 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -727,14 +727,24 @@ void BatchNormInferMeta(const MetaTensor& x,
                             C,
                             bias.dims()[0]));
   }
+  auto dtype = x.dtype();
+  if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16 ||
+      dtype == phi::DataType::UINT16) {
+    dtype = phi::DataType::FLOAT32;
+  }
+
   y->set_dims(x_dims);
   mean_out->set_dims({C});
+  mean_out->set_dtype(mean.dtype());
   variance_out->set_dims({C});
+  variance_out->set_dtype(variance.dtype());
   if (saved_mean) {
     saved_mean->set_dims({C});
+    saved_mean->set_dtype(dtype);
   }
   if (saved_variance) {
     saved_variance->set_dims({C});
+    saved_variance->set_dtype(dtype);
   }
   if (reserve_space) {
     reserve_space->set_dims({-1});
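Note on the BatchNormInferMeta change: saved_mean and saved_variance are kept
in float32 whenever x has a low-precision dtype (FLOAT16, BFLOAT16, or
UINT16, phi's storage dtype for bfloat16 in legacy paths), while mean_out and
variance_out simply inherit the dtypes of the incoming running stats. A tiny
Python sketch of the promotion rule as a sanity check (helper name is
hypothetical):

    def saved_stats_dtype(x_dtype: str) -> str:
        # Low-precision inputs keep their saved batch statistics in fp32.
        if x_dtype in ("float16", "bfloat16", "uint16"):
            return "float32"
        return x_dtype

    assert saved_stats_dtype("float16") == "float32"
    assert saved_stats_dtype("float32") == "float32"
    assert saved_stats_dtype("float64") == "float64"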
diff --git a/test/legacy_test/test_batch_norm_op_prim_nchw.py b/test/legacy_test/test_batch_norm_op_prim_nchw.py
index 49a67011f2cf6..639afba1c6585 100644
--- a/test/legacy_test/test_batch_norm_op_prim_nchw.py
+++ b/test/legacy_test/test_batch_norm_op_prim_nchw.py
@@ -278,6 +278,7 @@ def initConfig(self):
         self.epsilon = 1e-05
         self.data_format = "NCHW"
         self.use_global_stats = None
+        self.check_prim_pir = True
 
 
 class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
diff --git a/test/legacy_test/test_batch_norm_op_prim_nhwc.py b/test/legacy_test/test_batch_norm_op_prim_nhwc.py
index 57041857c0042..0ca4812c70540 100644
--- a/test/legacy_test/test_batch_norm_op_prim_nhwc.py
+++ b/test/legacy_test/test_batch_norm_op_prim_nhwc.py
@@ -124,6 +124,7 @@ def initConfig(self):
         self.epsilon = 1e-05
         self.data_format = "NHWC"
         self.use_global_stats = None
+        self.check_prim_pir = True
 
 
 class TestBatchNormOpNHWCFp16(TestBatchNormOp):
@@ -161,8 +162,6 @@ def initConfig(self):
         self.epsilon = 1e-05
         self.data_format = "NHWC"
         self.use_global_stats = None
-        # Todo(CZ): open this
-        self.check_prim_pir = False
 
 
 class TestBatchNormOpNHWCShape2(TestBatchNormOp):
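The test hunks turn check_prim_pir on for one NCHW config and one NHWC
config, and drop the stale override that kept the NHWC FP16 case disabled. A
hypothetical sketch of how a further case would opt in, following the
initConfig pattern these files already use (class name and reuse of the base
settings via super() are illustrative, not from this patch):

    class TestBatchNormOpNHWCShape3(TestBatchNormOp):  # hypothetical case
        def initConfig(self):
            super().initConfig()  # reuse the base NHWC settings
            self.epsilon = 1e-05
            self.data_format = "NHWC"
            self.use_global_stats = None
            self.check_prim_pir = True  # also exercise the PIR decomp path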