Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] refine the build interface for while_op #59423

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT
argument.AddOutput(op.operand(i).type());
}
}
argument.AddRegion()->push_back(block.release());
argument.AddRegion().push_back(block.release());
}

pir::Block *GroupOp::block() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void GetInputIds(pir::Operation* op,
std::unordered_set<pir::Value> GetBlockInnerOutputs(pir::Block* block) {
std::unordered_set<pir::Value> inner_outputs;
for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) {
inner_outputs.insert(block->argument(arg_id));
inner_outputs.insert(block->arg(arg_id));
}
for (auto& op : (*block)) {
VLOG(8) << "GetBlockInnerOutputs of " << op.name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ WhileInstruction::WhileInstruction(size_t id,
parent_exe_info->GetValue2VarName().at(while_op.result(i))));
}

body_block_ = &while_op.body_block();
body_block_ = &while_op.body();

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *parent_exe_info, &inputs);
Expand Down Expand Up @@ -108,7 +108,7 @@ WhileInstruction::WhileInstruction(size_t id,
<< "body_block_arg_";
auto var_name = ss.str() + std::to_string(i);
body_scope->Var(var_name);
body_exe_info->Add(body_block_->argument(i), var_name);
body_exe_info->Add(body_block_->arg(i), var_name);
}
body_inter_ = std::unique_ptr<PirInterpreter>(new PirInterpreter(
place, {}, body_block_, body_scope, body_exe_info, {}));
Expand Down Expand Up @@ -150,7 +150,7 @@ void WhileInstruction::CopyInputsToOutputs() {

void WhileInstruction::PassArgsToBodyBlock() {
for (size_t i = 0; i < body_block_->args_size(); ++i) {
auto block_arg = body_block_->argument(i);
auto block_arg = body_block_->arg(i);
auto var_name = body_inter_->GetNameByValue(block_arg);
auto* inner_var = body_inter_->local_scope()->GetVar(var_name);
inner_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
"equal. but they are %u and 0, respectively",
argument.output_types.size()));
}
argument.AddRegion()->push_back(true_block.release());
argument.AddRegion()->push_back(false_block.release());
argument.AddRegion().push_back(true_block.release());
argument.AddRegion().push_back(false_block.release());
argument.AddInput(cond);
}

Expand Down Expand Up @@ -232,12 +232,13 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT
const std::vector<pir::Value> &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
auto &body = argument.AddRegion().emplace_back();
for (auto val : inputs) {
argument.AddOutput(val.type());
body.AddArgument(val.type());
}
argument.AddRegion(nullptr);
}
pir::Block &WhileOp::body_block() {
pir::Block &WhileOp::body() {
pir::Region &body_region = (*this)->region(0);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
Expand All @@ -259,11 +260,11 @@ void WhileOp::Print(pir::IrPrinter &printer) {
[&]() { os << ", "; });
os << "] { \n ^";
pir::PrintInterleave(
body_block().args_begin(),
body_block().args_end(),
body().args_begin(),
body().args_end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
for (auto &item : body_block()) {
for (auto &item : body()) {
os << "\n ";
printer.PrintOperation(&item);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class WhileOp : public pir::Op<WhileOp> {
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs);
pir::Block &body_block();
pir::Block &body();
pir::Value cond();
void Print(pir::IrPrinter &printer); // NOLINT
void VerifySig() {}
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -926,15 +926,15 @@ void HandleForWhileOp(
pir::Builder builder(ctx, block);
auto base_while_op = op_item->dyn_cast<WhileOp>();
auto new_while_op = builder.Build<WhileOp>(cond_val, vec_in);
pir::Block& body_block = new_while_op.body_block();
pir::Block& body_block = new_while_op.body();
for (size_t i = 0; i < vec_in.size(); ++i) {
auto block_arg = body_block.AddArgument(vec_in[i].type());
(*map_value_pair)[base_while_op.body_block().argument(i)] = block_arg;
auto block_arg = body_block.arg(i);
(*map_value_pair)[base_while_op.body().arg(i)] = block_arg;
}

// process body block
ProcessBlock(place,
&base_while_op.body_block(),
&base_while_op.body(),
&body_block,
ctx,
map_op_pair,
Expand Down
36 changes: 24 additions & 12 deletions paddle/fluid/pybind/control_flow_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
namespace py = pybind11;
using paddle::dialect::ApiBuilder;
using paddle::dialect::IfOp;
using paddle::dialect::WhileOp;
using pir::Block;
using pir::Builder;
using pir::Operation;
Expand All @@ -51,19 +52,10 @@ using pybind11::return_value_policy;
using paddle::pybind::PyIfOp;
namespace {

PyIfOp BuildPyIfOp(Value cond) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
}

void BindIfOp(py::module* m) {
m->def("build_if_op", BuildPyIfOp);
m->def("cf_yield", [](py::list inputs) {
std::vector<Value> input_values;
for (auto input : inputs) {
input_values.push_back(input.cast<Value>());
}
ApiBuilder::Instance().GetBuilder()->Build<YieldOp>(input_values);
m->def("build_if_op", [](Value cond) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
});
py::class_<PyIfOp> if_op(*m, "IfOp", R"DOC(
The PyIfOp is a encapsulation of IfOp. Compared with ifOp, it provides an additional 'update_output' interface.
Expand All @@ -83,6 +75,17 @@ void BindIfOp(py::module* m) {
});
}

void BindWhileOp(py::module* m) {
m->def("build_while_op", [](Value cond, py::list loop_vars) {
std::vector<Value> loop_values;
for (auto var : loop_vars) {
loop_values.push_back(var.cast<Value>());
}
return ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond,
loop_values);
});
}

void GetUsedExternalValueImpl(
std::unordered_set<Value>& defined_values, // NOLINT
std::vector<Value>& used_values, // NOLINT
Expand Down Expand Up @@ -185,7 +188,16 @@ void PyIfOp::UpdateOutput() {
void BindControlFlowApi(py::module* m) {
m->def("get_used_external_value", GetUsedExternalValue);
m->def("build_pipe_for_block", BuildPipeForBlock);
m->def("cf_yield", [](py::list inputs) {
std::vector<Value> input_values;
for (auto input : inputs) {
input_values.push_back(input.cast<Value>());
}
ApiBuilder::Instance().GetBuilder()->Build<YieldOp>(input_values);
});

BindIfOp(m);
BindWhileOp(m);
}
} // namespace pybind
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class IR_API Block {
bool args_empty() const { return arguments_.empty(); }
uint32_t args_size() const { return arguments_.size(); }
const BlockArgListType &args() const { return arguments_; }
BlockArgument argument(uint32_t index) { return arguments_[index]; }
Type argument_type(uint32_t index) const { return arguments_[index].type(); }
BlockArgument arg(uint32_t index) { return arguments_[index]; }
Type arg_type(uint32_t index) const { return arguments_[index].type(); }
void ClearArguments();
BlockArgument AddArgument(Type type);
template <class TypeIter>
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Block *BlockArgument::owner() const {
return IMPL_->owner_;
}

uint32_t BlockArgument::arg_index() const {
uint32_t BlockArgument::index() const {
CHECK_NULL_IMPL(arg_index);
return IMPL_->index_;
}
Expand All @@ -79,7 +79,7 @@ void BlockArgument::Destroy() {
}
}

void BlockArgument::set_arg_index(uint32_t index) {
void BlockArgument::set_index(uint32_t index) {
CHECK_NULL_IMPL(set_arg_number);
IMPL_->index_ = index;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class IR_API BlockArgument : public Value {
public:
BlockArgument() = default;
Block *owner() const;
uint32_t arg_index() const;
uint32_t index() const;

private:
/// constructor
Expand All @@ -42,7 +42,7 @@ class IR_API BlockArgument : public Value {
/// Destroy the argument.
void Destroy();
/// set the position in the block argument list.
void set_arg_index(uint32_t index);
void set_index(uint32_t index);
// Access create annd destroy.
friend Block;

Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/operation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ OperationArgument::OperationArgument(IrContext* ir_context,
info = ir_context->GetRegisteredOpInfo(name);
}

Region* OperationArgument::AddRegion() {
Region& OperationArgument::AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
return *regions.back();
}

/// Take a region that should be attached to the Operation.
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct OperationArgument {
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
Region* AddRegion();
Region& AddRegion();

/// Take a region that should be attached to the Operation. The body of the
/// region will be transferred when the Operation is created. If the
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ Region::~Region() { clear(); }

void Region::push_back(Block *block) { insert(blocks_.end(), block); }

Block *Region::emplace_back() {
Block &Region::emplace_back() {
auto block = new Block;
insert(blocks_.end(), block);
return block;
return *block;
}

void Region::push_front(Block *block) { insert(blocks_.begin(), block); }
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class IR_API Region {
const Block &back() const { return *blocks_.back(); }

void push_back(Block *block);
Block *emplace_back();
Block &emplace_back();
void push_front(Block *block);
Iterator insert(ConstIterator position, Block *block);
Iterator erase(ConstIterator position);
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/new_executor/standalone_executor_pir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ TEST(StandaloneExecutor, while_op) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
pir::Block& body_block = while_op.body();
auto body_i_argument = body_block.arg(0);
auto body_ten_argument = body_block.arg(1);
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
Expand Down
18 changes: 9 additions & 9 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ TEST(while_op_test, base) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
pir::Block& body_block = while_op.body();
auto body_i_argument = body_block.arg(0);
auto body_ten_argument = body_block.arg(1);
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
Expand Down Expand Up @@ -104,11 +104,11 @@ TEST(while_op_test, network_with_backward) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, x});

// { return i + 1, x + y}
auto& body_block = while_op.body_block();
auto& body_block = while_op.body();
builder.SetInsertionPointToStart(&body_block);

auto body_i_argument = body_block.AddArgument(i.type());
auto body_x_argument = body_block.AddArgument(x.type());
auto body_i_argument = body_block.arg(0);
auto body_x_argument = body_block.arg(1);

auto new_i = builder.Build<AddOp>(body_i_argument, one).out();
auto new_x = builder.Build<AddOp>(body_x_argument, y).out();
Expand Down Expand Up @@ -141,10 +141,10 @@ TEST(while_op_test, network_with_backward) {
auto bwd_cond = builder.Build<pir::HasElementsOp>(stack).out();
auto while_grad = builder.Build<WhileOp>(
bwd_cond, std::vector<pir::Value>{x_out_grad, zero});
pir::Block& bwd_body_block = while_grad.body_block();
pir::Block& bwd_body_block = while_grad.body();
builder.SetInsertionPointToStart(&bwd_body_block);
auto local_x_out_grad_arg = bwd_body_block.AddArgument(x.type());
auto local_y_grad_arg = bwd_body_block.AddArgument(y.type());
auto local_x_out_grad_arg = bwd_body_block.arg(0);
auto local_y_grad_arg = bwd_body_block.arg(1);

auto pop_op = builder.Build<pir::TuplePopOp>(outlet);
auto bwd_body_x_argument = pop_op.outlet_element(0);
Expand Down
6 changes: 3 additions & 3 deletions test/cpp/pir/core/block_argument_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ TEST(block_argument_test, base) {

uint32_t index = 0;
for (auto iter = block->args_begin(); iter != block->args_end(); ++iter) {
EXPECT_EQ(iter->arg_index(), index++);
EXPECT_EQ(iter->index(), index++);
}

pir::Value value = block->argument(0);
pir::Value value = block->arg(0);
pir::BlockArgument argument = value.dyn_cast<pir::BlockArgument>();
EXPECT_TRUE(argument);
EXPECT_EQ(argument.owner(), block);
EXPECT_EQ(block->argument_type(0), types[0]);
EXPECT_EQ(block->arg_type(0), types[0]);
pir::OpResult op_result = value.dyn_cast<pir::OpResult>();
EXPECT_FALSE(op_result);

Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ TEST(OperatorDialectTest, WhileOpProgram) {
EXPECT_TRUE(op.isa<paddle::dialect::WhileOp>());
EXPECT_EQ(op.num_regions(), 1u);
// body block
pir::Block &body_block =
op.dyn_cast<paddle::dialect::WhileOp>().body_block();
pir::Block &body_block = op.dyn_cast<paddle::dialect::WhileOp>().body();
size_t body_id = 0;
for (auto &op1 : body_block) {
if (body_id == 0) {
Expand All @@ -308,7 +307,7 @@ TEST(OperatorDialectTest, WhileOpProgram) {
}
if (body_id == 3) {
pir::Block &body_body_block =
op1.dyn_cast<paddle::dialect::WhileOp>().body_block();
op1.dyn_cast<paddle::dialect::WhileOp>().body();
size_t body_body_id = 0;
for (auto &op2 : body_body_block) {
if (body_body_id == 0) {
Expand Down