From d2911ce67eaa1d2ff312ede10d9eb803111f495d Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Tue, 28 Nov 2023 19:48:38 +0800 Subject: [PATCH] [PIR] refine the build interface for while_op (#59423) --- .../hlir/dialect/operator/ir/manual_op.cc | 2 +- .../instruction/instruction_util.cc | 2 +- .../instruction/while_instruction.cc | 6 ++-- .../dialect/operator/ir/control_flow_op.cc | 15 ++++---- .../pir/dialect/operator/ir/control_flow_op.h | 2 +- .../pir/transforms/pd_op_to_kernel_pass.cc | 8 ++--- paddle/fluid/pybind/control_flow_api.cc | 36 ++++++++++++------- paddle/pir/core/block.h | 4 +-- paddle/pir/core/block_argument.cc | 4 +-- paddle/pir/core/block_argument.h | 4 +-- paddle/pir/core/operation_utils.cc | 4 +-- paddle/pir/core/operation_utils.h | 2 +- paddle/pir/core/region.cc | 4 +-- paddle/pir/core/region.h | 2 +- .../standalone_executor_pir_test.cc | 6 ++-- .../pir/control_flow_dialect/while_op_test.cc | 18 +++++----- test/cpp/pir/core/block_argument_test.cc | 6 ++-- test/cpp/pir/core/program_translator_test.cc | 5 ++- 18 files changed, 71 insertions(+), 59 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 79808c7db61f0..33c8bbe1b8624 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -48,7 +48,7 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT argument.AddOutput(op.operand(i).type()); } } - argument.AddRegion()->push_back(block.release()); + argument.AddRegion().push_back(block.release()); } pir::Block *GroupOp::block() { diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index 517a91e3d4bc3..71d70a4d56a22 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -226,7 +226,7 @@ void GetInputIds(pir::Operation* op, std::unordered_set GetBlockInnerOutputs(pir::Block* block) { std::unordered_set inner_outputs; for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) { - inner_outputs.insert(block->argument(arg_id)); + inner_outputs.insert(block->arg(arg_id)); } for (auto& op : (*block)) { VLOG(8) << "GetBlockInnerOutputs of " << op.name(); diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index 8d244c7692096..49b3878edf88b 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -74,7 +74,7 @@ WhileInstruction::WhileInstruction(size_t id, parent_exe_info->GetValue2VarName().at(while_op.result(i)))); } - body_block_ = &while_op.body_block(); + body_block_ = &while_op.body(); std::unordered_map> inputs; GetInputIds(op, *parent_exe_info, &inputs); @@ -108,7 +108,7 @@ WhileInstruction::WhileInstruction(size_t id, << "body_block_arg_"; auto var_name = ss.str() + std::to_string(i); body_scope->Var(var_name); - body_exe_info->Add(body_block_->argument(i), var_name); + body_exe_info->Add(body_block_->arg(i), var_name); } body_inter_ = std::unique_ptr(new PirInterpreter( place, {}, body_block_, body_scope, body_exe_info, {})); @@ -150,7 +150,7 @@ void WhileInstruction::CopyInputsToOutputs() { void WhileInstruction::PassArgsToBodyBlock() { for (size_t i = 0; i < body_block_->args_size(); ++i) { - auto block_arg = body_block_->argument(i); + auto block_arg = body_block_->arg(i); auto var_name = body_inter_->GetNameByValue(block_arg); auto* inner_var = body_inter_->local_scope()->GetVar(var_name); inner_var->GetMutable()->ShareDataWith( diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 17f009875d31d..807e355d9355c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -80,8 +80,8 @@ void IfOp::Build(pir::Builder &builder, // NOLINT "equal. but they are %u and 0, respectively", argument.output_types.size())); } - argument.AddRegion()->push_back(true_block.release()); - argument.AddRegion()->push_back(false_block.release()); + argument.AddRegion().push_back(true_block.release()); + argument.AddRegion().push_back(false_block.release()); argument.AddInput(cond); } @@ -232,12 +232,13 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT const std::vector &inputs) { argument.AddInput(cond); argument.AddInputs(inputs); + auto &body = argument.AddRegion().emplace_back(); for (auto val : inputs) { argument.AddOutput(val.type()); + body.AddArgument(val.type()); } - argument.AddRegion(nullptr); } -pir::Block &WhileOp::body_block() { +pir::Block &WhileOp::body() { pir::Region &body_region = (*this)->region(0); if (body_region.empty()) body_region.emplace_back(); return body_region.front(); @@ -259,11 +260,11 @@ void WhileOp::Print(pir::IrPrinter &printer) { [&]() { os << ", "; }); os << "] { \n ^"; pir::PrintInterleave( - body_block().args_begin(), - body_block().args_end(), + body().args_begin(), + body().args_end(), [&](pir::Value v) { printer.PrintValue(v); }, [&]() { os << ", "; }); - for (auto &item : body_block()) { + for (auto &item : body()) { os << "\n "; printer.PrintOperation(&item); } diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index addc5496e4868..0198b57cce82f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -76,7 +76,7 @@ class WhileOp : public pir::Op { pir::OperationArgument &argument, // NOLINT pir::Value cond, const std::vector &inputs); - pir::Block &body_block(); + pir::Block &body(); pir::Value cond(); void Print(pir::IrPrinter &printer); // NOLINT void VerifySig() {} diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 8dafa1161eadf..c0ffe3e19d4fe 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -926,15 +926,15 @@ void HandleForWhileOp( pir::Builder builder(ctx, block); auto base_while_op = op_item->dyn_cast(); auto new_while_op = builder.Build(cond_val, vec_in); - pir::Block& body_block = new_while_op.body_block(); + pir::Block& body_block = new_while_op.body(); for (size_t i = 0; i < vec_in.size(); ++i) { - auto block_arg = body_block.AddArgument(vec_in[i].type()); - (*map_value_pair)[base_while_op.body_block().argument(i)] = block_arg; + auto block_arg = body_block.arg(i); + (*map_value_pair)[base_while_op.body().arg(i)] = block_arg; } // process body block ProcessBlock(place, - &base_while_op.body_block(), + &base_while_op.body(), &body_block, ctx, map_op_pair, diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index 1a7cfab5a6bf8..f0360322ef404 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -36,6 +36,7 @@ namespace py = pybind11; using paddle::dialect::ApiBuilder; using paddle::dialect::IfOp; +using paddle::dialect::WhileOp; using pir::Block; using pir::Builder; using pir::Operation; @@ -51,19 +52,10 @@ using pybind11::return_value_policy; using paddle::pybind::PyIfOp; namespace { -PyIfOp BuildPyIfOp(Value cond) { - return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build( - cond, std::vector{})); -} - void BindIfOp(py::module* m) { - m->def("build_if_op", BuildPyIfOp); - m->def("cf_yield", [](py::list inputs) { - std::vector input_values; - for (auto input : inputs) { - input_values.push_back(input.cast()); - } - ApiBuilder::Instance().GetBuilder()->Build(input_values); + m->def("build_if_op", [](Value cond) { + return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build( + cond, std::vector{})); }); py::class_ if_op(*m, "IfOp", R"DOC( The PyIfOp is a encapsulation of IfOp. Compared with ifOp, it provides an additional 'update_output' interface. @@ -83,6 +75,17 @@ void BindIfOp(py::module* m) { }); } +void BindWhileOp(py::module* m) { + m->def("build_while_op", [](Value cond, py::list loop_vars) { + std::vector loop_values; + for (auto var : loop_vars) { + loop_values.push_back(var.cast()); + } + return ApiBuilder::Instance().GetBuilder()->Build(cond, + loop_values); + }); +} + void GetUsedExternalValueImpl( std::unordered_set& defined_values, // NOLINT std::vector& used_values, // NOLINT @@ -185,7 +188,16 @@ void PyIfOp::UpdateOutput() { void BindControlFlowApi(py::module* m) { m->def("get_used_external_value", GetUsedExternalValue); m->def("build_pipe_for_block", BuildPipeForBlock); + m->def("cf_yield", [](py::list inputs) { + std::vector input_values; + for (auto input : inputs) { + input_values.push_back(input.cast()); + } + ApiBuilder::Instance().GetBuilder()->Build(input_values); + }); + BindIfOp(m); + BindWhileOp(m); } } // namespace pybind } // namespace paddle diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index dbee5f8b13544..12fc66294627c 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -106,8 +106,8 @@ class IR_API Block { bool args_empty() const { return arguments_.empty(); } uint32_t args_size() const { return arguments_.size(); } const BlockArgListType &args() const { return arguments_; } - BlockArgument argument(uint32_t index) { return arguments_[index]; } - Type argument_type(uint32_t index) const { return arguments_[index].type(); } + BlockArgument arg(uint32_t index) { return arguments_[index]; } + Type arg_type(uint32_t index) const { return arguments_[index].type(); } void ClearArguments(); BlockArgument AddArgument(Type type); template diff --git a/paddle/pir/core/block_argument.cc b/paddle/pir/core/block_argument.cc index 3b851c054b85e..a0da7fbc16b2a 100644 --- a/paddle/pir/core/block_argument.cc +++ b/paddle/pir/core/block_argument.cc @@ -62,7 +62,7 @@ Block *BlockArgument::owner() const { return IMPL_->owner_; } -uint32_t BlockArgument::arg_index() const { +uint32_t BlockArgument::index() const { CHECK_NULL_IMPL(arg_index); return IMPL_->index_; } @@ -79,7 +79,7 @@ void BlockArgument::Destroy() { } } -void BlockArgument::set_arg_index(uint32_t index) { +void BlockArgument::set_index(uint32_t index) { CHECK_NULL_IMPL(set_arg_number); IMPL_->index_ = index; } diff --git a/paddle/pir/core/block_argument.h b/paddle/pir/core/block_argument.h index 27f1779650ef1..890e37234b131 100644 --- a/paddle/pir/core/block_argument.h +++ b/paddle/pir/core/block_argument.h @@ -31,7 +31,7 @@ class IR_API BlockArgument : public Value { public: BlockArgument() = default; Block *owner() const; - uint32_t arg_index() const; + uint32_t index() const; private: /// constructor @@ -42,7 +42,7 @@ class IR_API BlockArgument : public Value { /// Destroy the argument. void Destroy(); /// set the position in the block argument list. - void set_arg_index(uint32_t index); + void set_index(uint32_t index); // Access create annd destroy. friend Block; diff --git a/paddle/pir/core/operation_utils.cc b/paddle/pir/core/operation_utils.cc index 9bca7c2756d28..bb7b4c04a17c2 100644 --- a/paddle/pir/core/operation_utils.cc +++ b/paddle/pir/core/operation_utils.cc @@ -22,9 +22,9 @@ OperationArgument::OperationArgument(IrContext* ir_context, info = ir_context->GetRegisteredOpInfo(name); } -Region* OperationArgument::AddRegion() { +Region& OperationArgument::AddRegion() { regions.emplace_back(new Region); - return regions.back().get(); + return *regions.back(); } /// Take a region that should be attached to the Operation. diff --git a/paddle/pir/core/operation_utils.h b/paddle/pir/core/operation_utils.h index 77a64a358365d..62e69354cd939 100644 --- a/paddle/pir/core/operation_utils.h +++ b/paddle/pir/core/operation_utils.h @@ -104,7 +104,7 @@ struct OperationArgument { /// Create a region that should be attached to the operation. These regions /// can be filled in immediately without waiting for Operation to be /// created. When it is, the region bodies will be transferred. - Region* AddRegion(); + Region& AddRegion(); /// Take a region that should be attached to the Operation. The body of the /// region will be transferred when the Operation is created. If the diff --git a/paddle/pir/core/region.cc b/paddle/pir/core/region.cc index ba53e638e3f66..dfb3b45aef3e9 100644 --- a/paddle/pir/core/region.cc +++ b/paddle/pir/core/region.cc @@ -22,10 +22,10 @@ Region::~Region() { clear(); } void Region::push_back(Block *block) { insert(blocks_.end(), block); } -Block *Region::emplace_back() { +Block &Region::emplace_back() { auto block = new Block; insert(blocks_.end(), block); - return block; + return *block; } void Region::push_front(Block *block) { insert(blocks_.begin(), block); } diff --git a/paddle/pir/core/region.h b/paddle/pir/core/region.h index 7c9efd699291b..cd8107c2f5e25 100644 --- a/paddle/pir/core/region.h +++ b/paddle/pir/core/region.h @@ -59,7 +59,7 @@ class IR_API Region { const Block &back() const { return *blocks_.back(); } void push_back(Block *block); - Block *emplace_back(); + Block &emplace_back(); void push_front(Block *block); Iterator insert(ConstIterator position, Block *block); Iterator erase(ConstIterator position); diff --git a/test/cpp/new_executor/standalone_executor_pir_test.cc b/test/cpp/new_executor/standalone_executor_pir_test.cc index d5f36bd681648..b377eaee15e5e 100644 --- a/test/cpp/new_executor/standalone_executor_pir_test.cc +++ b/test/cpp/new_executor/standalone_executor_pir_test.cc @@ -303,9 +303,9 @@ TEST(StandaloneExecutor, while_op) { builder.Build(cond_value, std::vector{i, ten}); // { i = i + 1} - pir::Block& body_block = while_op.body_block(); - auto body_i_argument = body_block.AddArgument(i.type()); - auto body_ten_argument = body_block.AddArgument(ten.type()); + pir::Block& body_block = while_op.body(); + auto body_i_argument = body_block.arg(0); + auto body_ten_argument = body_block.arg(1); builder.SetInsertionPointToStart(&body_block); auto one = builder.Build(std::vector{1}, 1, phi::DataType::INT32) diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index d68b4fe7a5b4a..45416cf74732f 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -50,9 +50,9 @@ TEST(while_op_test, base) { builder.Build(cond_value, std::vector{i, ten}); // { i = i + 1} - pir::Block& body_block = while_op.body_block(); - auto body_i_argument = body_block.AddArgument(i.type()); - auto body_ten_argument = body_block.AddArgument(ten.type()); + pir::Block& body_block = while_op.body(); + auto body_i_argument = body_block.arg(0); + auto body_ten_argument = body_block.arg(1); builder.SetInsertionPointToStart(&body_block); auto one = builder.Build(std::vector{1}, 1, phi::DataType::INT32) @@ -104,11 +104,11 @@ TEST(while_op_test, network_with_backward) { builder.Build(cond_value, std::vector{i, x}); // { return i + 1, x + y} - auto& body_block = while_op.body_block(); + auto& body_block = while_op.body(); builder.SetInsertionPointToStart(&body_block); - auto body_i_argument = body_block.AddArgument(i.type()); - auto body_x_argument = body_block.AddArgument(x.type()); + auto body_i_argument = body_block.arg(0); + auto body_x_argument = body_block.arg(1); auto new_i = builder.Build(body_i_argument, one).out(); auto new_x = builder.Build(body_x_argument, y).out(); @@ -141,10 +141,10 @@ TEST(while_op_test, network_with_backward) { auto bwd_cond = builder.Build(stack).out(); auto while_grad = builder.Build( bwd_cond, std::vector{x_out_grad, zero}); - pir::Block& bwd_body_block = while_grad.body_block(); + pir::Block& bwd_body_block = while_grad.body(); builder.SetInsertionPointToStart(&bwd_body_block); - auto local_x_out_grad_arg = bwd_body_block.AddArgument(x.type()); - auto local_y_grad_arg = bwd_body_block.AddArgument(y.type()); + auto local_x_out_grad_arg = bwd_body_block.arg(0); + auto local_y_grad_arg = bwd_body_block.arg(1); auto pop_op = builder.Build(outlet); auto bwd_body_x_argument = pop_op.outlet_element(0); diff --git a/test/cpp/pir/core/block_argument_test.cc b/test/cpp/pir/core/block_argument_test.cc index 4bfba843aff80..6cc522976cf2e 100644 --- a/test/cpp/pir/core/block_argument_test.cc +++ b/test/cpp/pir/core/block_argument_test.cc @@ -37,14 +37,14 @@ TEST(block_argument_test, base) { uint32_t index = 0; for (auto iter = block->args_begin(); iter != block->args_end(); ++iter) { - EXPECT_EQ(iter->arg_index(), index++); + EXPECT_EQ(iter->index(), index++); } - pir::Value value = block->argument(0); + pir::Value value = block->arg(0); pir::BlockArgument argument = value.dyn_cast(); EXPECT_TRUE(argument); EXPECT_EQ(argument.owner(), block); - EXPECT_EQ(block->argument_type(0), types[0]); + EXPECT_EQ(block->arg_type(0), types[0]); pir::OpResult op_result = value.dyn_cast(); EXPECT_FALSE(op_result); diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index a79b4a6a8be45..a7dc7845bdca7 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -293,8 +293,7 @@ TEST(OperatorDialectTest, WhileOpProgram) { EXPECT_TRUE(op.isa()); EXPECT_EQ(op.num_regions(), 1u); // body block - pir::Block &body_block = - op.dyn_cast().body_block(); + pir::Block &body_block = op.dyn_cast().body(); size_t body_id = 0; for (auto &op1 : body_block) { if (body_id == 0) { @@ -308,7 +307,7 @@ TEST(OperatorDialectTest, WhileOpProgram) { } if (body_id == 3) { pir::Block &body_body_block = - op1.dyn_cast().body_block(); + op1.dyn_cast().body(); size_t body_body_id = 0; for (auto &op2 : body_body_block) { if (body_body_id == 0) {