Skip to content

Commit

Permalink
[PIR] refine the build interface for while_op (#59423)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Nov 28, 2023
1 parent e1ceed5 commit d2911ce
Show file tree
Hide file tree
Showing 18 changed files with 71 additions and 59 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT
argument.AddOutput(op.operand(i).type());
}
}
argument.AddRegion()->push_back(block.release());
argument.AddRegion().push_back(block.release());
}

pir::Block *GroupOp::block() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void GetInputIds(pir::Operation* op,
std::unordered_set<pir::Value> GetBlockInnerOutputs(pir::Block* block) {
std::unordered_set<pir::Value> inner_outputs;
for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) {
inner_outputs.insert(block->argument(arg_id));
inner_outputs.insert(block->arg(arg_id));
}
for (auto& op : (*block)) {
VLOG(8) << "GetBlockInnerOutputs of " << op.name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ WhileInstruction::WhileInstruction(size_t id,
parent_exe_info->GetValue2VarName().at(while_op.result(i))));
}

body_block_ = &while_op.body_block();
body_block_ = &while_op.body();

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *parent_exe_info, &inputs);
Expand Down Expand Up @@ -108,7 +108,7 @@ WhileInstruction::WhileInstruction(size_t id,
<< "body_block_arg_";
auto var_name = ss.str() + std::to_string(i);
body_scope->Var(var_name);
body_exe_info->Add(body_block_->argument(i), var_name);
body_exe_info->Add(body_block_->arg(i), var_name);
}
body_inter_ = std::unique_ptr<PirInterpreter>(new PirInterpreter(
place, {}, body_block_, body_scope, body_exe_info, {}));
Expand Down Expand Up @@ -150,7 +150,7 @@ void WhileInstruction::CopyInputsToOutputs() {

void WhileInstruction::PassArgsToBodyBlock() {
for (size_t i = 0; i < body_block_->args_size(); ++i) {
auto block_arg = body_block_->argument(i);
auto block_arg = body_block_->arg(i);
auto var_name = body_inter_->GetNameByValue(block_arg);
auto* inner_var = body_inter_->local_scope()->GetVar(var_name);
inner_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
"equal. but they are %u and 0, respectively",
argument.output_types.size()));
}
argument.AddRegion()->push_back(true_block.release());
argument.AddRegion()->push_back(false_block.release());
argument.AddRegion().push_back(true_block.release());
argument.AddRegion().push_back(false_block.release());
argument.AddInput(cond);
}

Expand Down Expand Up @@ -232,12 +232,13 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT
const std::vector<pir::Value> &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
auto &body = argument.AddRegion().emplace_back();
for (auto val : inputs) {
argument.AddOutput(val.type());
body.AddArgument(val.type());
}
argument.AddRegion(nullptr);
}
pir::Block &WhileOp::body_block() {
pir::Block &WhileOp::body() {
pir::Region &body_region = (*this)->region(0);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
Expand All @@ -259,11 +260,11 @@ void WhileOp::Print(pir::IrPrinter &printer) {
[&]() { os << ", "; });
os << "] { \n ^";
pir::PrintInterleave(
body_block().args_begin(),
body_block().args_end(),
body().args_begin(),
body().args_end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
for (auto &item : body_block()) {
for (auto &item : body()) {
os << "\n ";
printer.PrintOperation(&item);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class WhileOp : public pir::Op<WhileOp> {
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs);
pir::Block &body_block();
pir::Block &body();
pir::Value cond();
void Print(pir::IrPrinter &printer); // NOLINT
void VerifySig() {}
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -926,15 +926,15 @@ void HandleForWhileOp(
pir::Builder builder(ctx, block);
auto base_while_op = op_item->dyn_cast<WhileOp>();
auto new_while_op = builder.Build<WhileOp>(cond_val, vec_in);
pir::Block& body_block = new_while_op.body_block();
pir::Block& body_block = new_while_op.body();
for (size_t i = 0; i < vec_in.size(); ++i) {
auto block_arg = body_block.AddArgument(vec_in[i].type());
(*map_value_pair)[base_while_op.body_block().argument(i)] = block_arg;
auto block_arg = body_block.arg(i);
(*map_value_pair)[base_while_op.body().arg(i)] = block_arg;
}

// process body block
ProcessBlock(place,
&base_while_op.body_block(),
&base_while_op.body(),
&body_block,
ctx,
map_op_pair,
Expand Down
36 changes: 24 additions & 12 deletions paddle/fluid/pybind/control_flow_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
namespace py = pybind11;
using paddle::dialect::ApiBuilder;
using paddle::dialect::IfOp;
using paddle::dialect::WhileOp;
using pir::Block;
using pir::Builder;
using pir::Operation;
Expand All @@ -51,19 +52,10 @@ using pybind11::return_value_policy;
using paddle::pybind::PyIfOp;
namespace {

PyIfOp BuildPyIfOp(Value cond) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
}

void BindIfOp(py::module* m) {
m->def("build_if_op", BuildPyIfOp);
m->def("cf_yield", [](py::list inputs) {
std::vector<Value> input_values;
for (auto input : inputs) {
input_values.push_back(input.cast<Value>());
}
ApiBuilder::Instance().GetBuilder()->Build<YieldOp>(input_values);
m->def("build_if_op", [](Value cond) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
});
py::class_<PyIfOp> if_op(*m, "IfOp", R"DOC(
The PyIfOp is a encapsulation of IfOp. Compared with ifOp, it provides an additional 'update_output' interface.
Expand All @@ -83,6 +75,17 @@ void BindIfOp(py::module* m) {
});
}

void BindWhileOp(py::module* m) {
m->def("build_while_op", [](Value cond, py::list loop_vars) {
std::vector<Value> loop_values;
for (auto var : loop_vars) {
loop_values.push_back(var.cast<Value>());
}
return ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond,
loop_values);
});
}

void GetUsedExternalValueImpl(
std::unordered_set<Value>& defined_values, // NOLINT
std::vector<Value>& used_values, // NOLINT
Expand Down Expand Up @@ -185,7 +188,16 @@ void PyIfOp::UpdateOutput() {
void BindControlFlowApi(py::module* m) {
m->def("get_used_external_value", GetUsedExternalValue);
m->def("build_pipe_for_block", BuildPipeForBlock);
m->def("cf_yield", [](py::list inputs) {
std::vector<Value> input_values;
for (auto input : inputs) {
input_values.push_back(input.cast<Value>());
}
ApiBuilder::Instance().GetBuilder()->Build<YieldOp>(input_values);
});

BindIfOp(m);
BindWhileOp(m);
}
} // namespace pybind
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class IR_API Block {
bool args_empty() const { return arguments_.empty(); }
uint32_t args_size() const { return arguments_.size(); }
const BlockArgListType &args() const { return arguments_; }
BlockArgument argument(uint32_t index) { return arguments_[index]; }
Type argument_type(uint32_t index) const { return arguments_[index].type(); }
BlockArgument arg(uint32_t index) { return arguments_[index]; }
Type arg_type(uint32_t index) const { return arguments_[index].type(); }
void ClearArguments();
BlockArgument AddArgument(Type type);
template <class TypeIter>
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Block *BlockArgument::owner() const {
return IMPL_->owner_;
}

uint32_t BlockArgument::arg_index() const {
uint32_t BlockArgument::index() const {
CHECK_NULL_IMPL(arg_index);
return IMPL_->index_;
}
Expand All @@ -79,7 +79,7 @@ void BlockArgument::Destroy() {
}
}

void BlockArgument::set_arg_index(uint32_t index) {
void BlockArgument::set_index(uint32_t index) {
CHECK_NULL_IMPL(set_arg_number);
IMPL_->index_ = index;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class IR_API BlockArgument : public Value {
public:
BlockArgument() = default;
Block *owner() const;
uint32_t arg_index() const;
uint32_t index() const;

private:
/// constructor
Expand All @@ -42,7 +42,7 @@ class IR_API BlockArgument : public Value {
/// Destroy the argument.
void Destroy();
/// set the position in the block argument list.
void set_arg_index(uint32_t index);
void set_index(uint32_t index);
// Access create annd destroy.
friend Block;

Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/operation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ OperationArgument::OperationArgument(IrContext* ir_context,
info = ir_context->GetRegisteredOpInfo(name);
}

Region* OperationArgument::AddRegion() {
Region& OperationArgument::AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
return *regions.back();
}

/// Take a region that should be attached to the Operation.
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct OperationArgument {
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
Region* AddRegion();
Region& AddRegion();

/// Take a region that should be attached to the Operation. The body of the
/// region will be transferred when the Operation is created. If the
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ Region::~Region() { clear(); }

void Region::push_back(Block *block) { insert(blocks_.end(), block); }

Block *Region::emplace_back() {
Block &Region::emplace_back() {
auto block = new Block;
insert(blocks_.end(), block);
return block;
return *block;
}

void Region::push_front(Block *block) { insert(blocks_.begin(), block); }
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class IR_API Region {
const Block &back() const { return *blocks_.back(); }

void push_back(Block *block);
Block *emplace_back();
Block &emplace_back();
void push_front(Block *block);
Iterator insert(ConstIterator position, Block *block);
Iterator erase(ConstIterator position);
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/new_executor/standalone_executor_pir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ TEST(StandaloneExecutor, while_op) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
pir::Block& body_block = while_op.body();
auto body_i_argument = body_block.arg(0);
auto body_ten_argument = body_block.arg(1);
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
Expand Down
18 changes: 9 additions & 9 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ TEST(while_op_test, base) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
pir::Block& body_block = while_op.body();
auto body_i_argument = body_block.arg(0);
auto body_ten_argument = body_block.arg(1);
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
Expand Down Expand Up @@ -104,11 +104,11 @@ TEST(while_op_test, network_with_backward) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, x});

// { return i + 1, x + y}
auto& body_block = while_op.body_block();
auto& body_block = while_op.body();
builder.SetInsertionPointToStart(&body_block);

auto body_i_argument = body_block.AddArgument(i.type());
auto body_x_argument = body_block.AddArgument(x.type());
auto body_i_argument = body_block.arg(0);
auto body_x_argument = body_block.arg(1);

auto new_i = builder.Build<AddOp>(body_i_argument, one).out();
auto new_x = builder.Build<AddOp>(body_x_argument, y).out();
Expand Down Expand Up @@ -141,10 +141,10 @@ TEST(while_op_test, network_with_backward) {
auto bwd_cond = builder.Build<pir::HasElementsOp>(stack).out();
auto while_grad = builder.Build<WhileOp>(
bwd_cond, std::vector<pir::Value>{x_out_grad, zero});
pir::Block& bwd_body_block = while_grad.body_block();
pir::Block& bwd_body_block = while_grad.body();
builder.SetInsertionPointToStart(&bwd_body_block);
auto local_x_out_grad_arg = bwd_body_block.AddArgument(x.type());
auto local_y_grad_arg = bwd_body_block.AddArgument(y.type());
auto local_x_out_grad_arg = bwd_body_block.arg(0);
auto local_y_grad_arg = bwd_body_block.arg(1);

auto pop_op = builder.Build<pir::TuplePopOp>(outlet);
auto bwd_body_x_argument = pop_op.outlet_element(0);
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/pir/core/block_argument_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ TEST(block_argument_test, base) {

uint32_t index = 0;
for (auto iter = block->args_begin(); iter != block->args_end(); ++iter) {
EXPECT_EQ(iter->arg_index(), index++);
EXPECT_EQ(iter->index(), index++);
}

pir::Value value = block->argument(0);
pir::Value value = block->arg(0);
pir::BlockArgument argument = value.dyn_cast<pir::BlockArgument>();
EXPECT_TRUE(argument);
EXPECT_EQ(argument.owner(), block);
EXPECT_EQ(block->argument_type(0), types[0]);
EXPECT_EQ(block->arg_type(0), types[0]);
pir::OpResult op_result = value.dyn_cast<pir::OpResult>();
EXPECT_FALSE(op_result);

Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ TEST(OperatorDialectTest, WhileOpProgram) {
EXPECT_TRUE(op.isa<paddle::dialect::WhileOp>());
EXPECT_EQ(op.num_regions(), 1u);
// body block
pir::Block &body_block =
op.dyn_cast<paddle::dialect::WhileOp>().body_block();
pir::Block &body_block = op.dyn_cast<paddle::dialect::WhileOp>().body();
size_t body_id = 0;
for (auto &op1 : body_block) {
if (body_id == 0) {
Expand All @@ -308,7 +307,7 @@ TEST(OperatorDialectTest, WhileOpProgram) {
}
if (body_id == 3) {
pir::Block &body_body_block =
op1.dyn_cast<paddle::dialect::WhileOp>().body_block();
op1.dyn_cast<paddle::dialect::WhileOp>().body();
size_t body_body_id = 0;
for (auto &op2 : body_body_block) {
if (body_body_id == 0) {
Expand Down

0 comments on commit d2911ce

Please sign in to comment.