Skip to content

Commit

Permalink
[PIR] Refine IfOp translate (PaddlePaddle#58088)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored and Frida-a committed Oct 14, 2023
1 parent daf947c commit 90a3361
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 41 deletions.
123 changes: 86 additions & 37 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,36 @@ const std::unordered_set<std::string> ProgramTranslator::unsupported_ops = {
static std::vector<uint64_t> GetCondOpIds(const BlockDesc& src_block,
uint64_t first_id) {
std::vector<uint64_t> op_list = {first_id};
if (src_block.Op(static_cast<int>(first_id + 1))->Type() == "logical_not") {
if (((first_id + 1) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 1))->Type() == "logical_not")) {
op_list.emplace_back(first_id + 1);
}
if (src_block.Op(static_cast<int>(first_id + 2))->Type() ==
"conditional_block") {
if (((first_id + 2) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 2))->Type() ==
"conditional_block")) {
op_list.emplace_back(first_id + 2);
}
if (src_block.Op(static_cast<int>(first_id + 3))->Type() == "cast") {
if (((first_id + 3) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 3))->Type() == "cast")) {
op_list.emplace_back(first_id + 3);
}
size_t output_size =
src_block.Op(static_cast<int>(first_id))->Output("Out").size();
// Note(zhangbo): Some output variables are input, without select_input op.
std::vector<std::string> output_names =
src_block.Op(static_cast<int>(first_id))->Output("Out");
std::vector<std::string> input_names =
src_block.Op(static_cast<int>(first_id))->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
size_t output_size = diffs.size();
for (size_t i = 0; i < output_size; i++) {
if (src_block.Op(static_cast<int>(first_id + 4 + i))->Type() ==
"select_input") {
if (((first_id + 4 + i) < src_block.OpSize()) &&
(src_block.Op(static_cast<int>(first_id + 4 + i))->Type() ==
"select_input")) {
op_list.emplace_back(first_id + 4 + i);
}
}
Expand All @@ -101,7 +116,16 @@ const std::string& ConditionBlockCombination::CondVarName() const {
}

size_t ConditionBlockCombination::OutputSize() const {
return op_list_[0]->Output("Out").size();
std::vector<std::string> output_names = op_list_[0]->Output("Out");
std::vector<std::string> input_names = op_list_[0]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs.size();
}

std::vector<::paddle::framework::VarDesc*>
Expand All @@ -116,23 +140,41 @@ ConditionBlockCombination::OutputVars() const {
return outputs;
}

const std::vector<std::string>&
ConditionBlockCombination::TrueBlockOutputVarNames() const {
return op_list_[0]->Output("Out");
}

int ConditionBlockCombination::TrueBlockId() const {
return op_list_[0]->GetBlockAttrId("sub_block");
std::vector<std::string> ConditionBlockCombination::TrueBlockOutputVarNames()
const {
std::vector<std::string> output_names = op_list_[0]->Output("Out");
std::vector<std::string> input_names = op_list_[0]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs;
}

std::vector<std::string> ConditionBlockCombination::FalseBlockOutputVarNames()
const {
if (op_list_.size() > 1) {
return op_list_[2]->Output("Out");
std::vector<std::string> output_names = op_list_[2]->Output("Out");
std::vector<std::string> input_names = op_list_[2]->Input("Input");
std::vector<std::string> diffs(output_names.size());
auto iter = std::set_difference(output_names.begin(),
output_names.end(),
input_names.begin(),
input_names.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
return diffs;
}
return {""};
}

int ConditionBlockCombination::TrueBlockId() const {
return op_list_[0]->GetBlockAttrId("sub_block");
}

int ConditionBlockCombination::FalseBlockId() const {
if (op_list_.size() > 1) {
return op_list_[2]->GetBlockAttrId("sub_block");
Expand All @@ -147,9 +189,6 @@ bool ConditionBlockCombination::Verify(
if (op_list[id]->Type() != "conditional_block") {
return false;
}
if (op_list.size() == 1 && op_list[id]->Output("Out").size() != 0) {
return false;
}
} else if (id == 1) {
if (op_list[id]->Type() != "logical_not") {
return false;
Expand Down Expand Up @@ -248,11 +287,13 @@ void ProgramTranslator::Translate() {
}
}

void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block) {
void ProgramTranslator::TranslateBlock(
const BlockDesc& src_block,
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block,
std::vector<std::string> skip_cond_assign) {
VLOG(8) << "=============>start to translate a block";
PADDLE_ENFORCE(
(src_block.OpSize() >= end_id) && (start_id <= end_id),
Expand All @@ -264,10 +305,12 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
src_block.OpSize()));

std::unordered_map<uint64_t, bool> translate_completed;
std::vector<std::string> assign_inputs;
for (uint64_t op_id = start_id; op_id < end_id; op_id++) {
if (translate_completed.count(op_id) && translate_completed.at(op_id)) {
continue;
}

auto op = src_block.Op(static_cast<int>(op_id));
VLOG(8) << "=============>start to translate a op: " << op->Type();

Expand All @@ -287,20 +330,24 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block,
}
VLOG(10) << "[op translated][conditional_block]" << if_op;
} else {
TranslateGeneralOperation(op, dest_block);
translate_completed[op_id] = true;
if (for_cond_block && op->Type() == "assign" &&
std::count(skip_cond_assign.begin(),
skip_cond_assign.end(),
op->Output("Out")[0])) {
assign_inputs.push_back(op->Input("X")[0]);
translate_completed[op_id] = true;
} else {
TranslateGeneralOperation(op, dest_block);
translate_completed[op_id] = true;
}
}
}
// NOTE(zhangbo): If conditional_block operator has output, the cf.yeild
// operator needs to be inserted
if (for_cond_block) {
std::vector<pir::Value> yeild_inputs;
for (size_t id = end_id; id < src_block.OpSize(); id++) {
PADDLE_ENFORCE(
src_block.Op(id)->Type() == "assign",
"The operator at the end of the sub block needs to be assign");
yeild_inputs.emplace_back(
param_map_[src_block.Op(static_cast<int>(id))->Input("X")[0]].value);
for (size_t id = 0; id < assign_inputs.size(); id++) {
yeild_inputs.emplace_back(param_map_[assign_inputs[id]].value);
}
pir::AttributeMap attribute_map;
auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
Expand Down Expand Up @@ -349,9 +396,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
if (true_region.empty()) true_region.emplace_back();
TranslateBlock(true_sub_block,
0,
true_sub_block.OpSize() - cond_ops.OutputSize(),
true_sub_block.OpSize(),
true_region.front(),
true);
true,
cond_ops.TrueBlockOutputVarNames());
}
VLOG(4) << "[general op][conditional_block] IfOp true block translate end.";

Expand All @@ -362,9 +410,10 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
if (false_region.empty()) false_region.emplace_back();
TranslateBlock(false_sub_block,
0,
false_sub_block.OpSize() - cond_ops.OutputSize(),
false_sub_block.OpSize(),
false_region.front(),
true);
true,
cond_ops.FalseBlockOutputVarNames());
}
VLOG(4) << "[general op][conditional_block] IfOp false block translate end.";

Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ class ConditionBlockCombination {
ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block,
const std::vector<uint64_t>& op_ids);
const std::string& CondVarName() const;
int TrueBlockId() const;
int FalseBlockId() const;
size_t OutputSize() const;
std::vector<::paddle::framework::VarDesc*> OutputVars() const;
const std::vector<std::string>& TrueBlockOutputVarNames() const;
int TrueBlockId() const;
std::vector<std::string> TrueBlockOutputVarNames() const;
std::vector<std::string> FalseBlockOutputVarNames() const;
int FalseBlockId() const;

private:
bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list);
Expand Down Expand Up @@ -124,7 +124,8 @@ class ProgramTranslator {
uint64_t start_id,
uint64_t end_id,
pir::Block* dest_block,
bool for_cond_block = false);
bool for_cond_block = false,
std::vector<std::string> skip_cond_assign = {});
void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block);
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
Expand Down

0 comments on commit 90a3361

Please sign in to comment.