-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NewIR]support new ir load combine #56101
Changes from all commits
db06fab
2fb4ec2
0823f7b
353f0b6
be5a2b1
e36a260
7c88193
c585da7
57f8493
9b4fe88
c2b9a41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -598,17 +598,39 @@ void BuildRuntimeContext( | |||||||||
PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), | ||||||||||
phi::errors::PreconditionNotMet( | ||||||||||
"can not find var[%s] in scope", in_var_name)); | ||||||||||
|
||||||||||
auto var = inner_scope->FindVar(in_var_name); | ||||||||||
std::vector<paddle::framework::Variable*> vec_tmp = {var}; | ||||||||||
auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); | ||||||||||
runtime_ctx->outputs[legacy_attr_name] = vec_tmp; | ||||||||||
|
||||||||||
auto type = ptr.type(); | ||||||||||
auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); | ||||||||||
if (type.isa<paddle::dialect::AllocatedDenseTensorType>() || | ||||||||||
type.isa<paddle::dialect::AllocatedSelectedRowsType>()) { | ||||||||||
std::vector<paddle::framework::Variable*> vec_tmp = {var}; | ||||||||||
|
||||||||||
runtime_ctx->outputs[legacy_arg_name] = vec_tmp; | ||||||||||
Comment on lines
+608
to
+610
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
} else if (type.isa<ir::VectorType>()) { | ||||||||||
auto var_ref = var->Get<paddle::framework::VariableRefArray>(); | ||||||||||
std::vector<paddle::framework::Variable*> vec_tmp; | ||||||||||
vec_tmp.reserve(var_ref.size()); | ||||||||||
for (size_t k = 0; k < var_ref.size(); ++k) { | ||||||||||
vec_tmp.push_back(const_cast<paddle::framework::Variable*>(var_ref[k])); | ||||||||||
} | ||||||||||
runtime_ctx->outputs[legacy_arg_name] = vec_tmp; | ||||||||||
} else { | ||||||||||
PADDLE_THROW(phi::errors::Unimplemented( | ||||||||||
"only support AllocatedDenseTensor, AllocatedSelectedRowsType and " | ||||||||||
"ir::vector type")); | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase( | ||||||||||
ir::Operation* op, | ||||||||||
const std::unordered_map<ir::Value, std::string>& name_map, | ||||||||||
const paddle::dialect::OpYamlInfoParser& op_yaml_info) { | ||||||||||
const paddle::dialect::OpYamlInfoParser& op_yaml_info, | ||||||||||
const std::unordered_map<const paddle::framework::Variable*, std::string>& | ||||||||||
variable_2_var_name, | ||||||||||
const paddle::framework::Scope* scope) { | ||||||||||
paddle::framework::VariableNameMap in_name_map; | ||||||||||
paddle::framework::VariableNameMap out_name_map; | ||||||||||
paddle::framework::AttributeMap attr_map; | ||||||||||
|
@@ -637,15 +659,57 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase( | |||||||||
} | ||||||||||
|
||||||||||
// build attribute | ||||||||||
auto& op_attr_map = op->attributes(); | ||||||||||
auto attr_name_list = op_yaml_info.AttrParams(true); | ||||||||||
for (auto& name : attr_name_list) { | ||||||||||
auto& val = op_attr_map.at(name); | ||||||||||
|
||||||||||
if (val.isa<ir::StrAttribute>()) { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里的Attribute解析数据的逻辑在 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 我确认了下 GetAttributeData 返回的是 phi的attribute, 这里需要的是framework::attribute, 定义还是有些差别,不能直接用同一个 |
||||||||||
attr_map[name] = val.dyn_cast<ir::StrAttribute>().AsString(); | ||||||||||
} else if (val.isa<ir::Int32Attribute>()) { | ||||||||||
attr_map[name] = val.dyn_cast<ir::Int32Attribute>().data(); | ||||||||||
} else if (val.isa<ir::BoolAttribute>()) { | ||||||||||
attr_map[name] = val.dyn_cast<ir::BoolAttribute>().data(); | ||||||||||
} else if (val.isa<ir::FloatAttribute>()) { | ||||||||||
attr_map[name] = val.dyn_cast<ir::FloatAttribute>().data(); | ||||||||||
} else if (val.isa<ir::DoubleAttribute>()) { | ||||||||||
attr_map[name] = val.dyn_cast<ir::DoubleAttribute>().data(); | ||||||||||
} else if (val.isa<ir::Int64Attribute>()) { | ||||||||||
attr_map[name] = val.dyn_cast<ir::Int64Attribute>().data(); | ||||||||||
} else { | ||||||||||
std::stringstream ss; | ||||||||||
val.Print(ss); | ||||||||||
VLOG(1) << "type not support " << ss.str() << std::endl; | ||||||||||
PADDLE_THROW("Type[%s] in attribute map not support yet", ss.str()); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
auto& output_name_list = op_yaml_info.OutputNames(); | ||||||||||
for (size_t i = 0; i < output_name_list.size(); ++i) { | ||||||||||
auto name = output_name_list[i]; | ||||||||||
ir::Value ptr = op->result(i); | ||||||||||
|
||||||||||
auto out_var_name = name_map.at(ptr); | ||||||||||
auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); | ||||||||||
out_name_map[legacy_attr_name].push_back(out_var_name); | ||||||||||
|
||||||||||
auto type = ptr.type(); | ||||||||||
auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); | ||||||||||
if (type.isa<paddle::dialect::AllocatedDenseTensorType>() || | ||||||||||
type.isa<paddle::dialect::AllocatedSelectedRowsType>()) { | ||||||||||
out_name_map[legacy_arg_name].push_back(out_var_name); | ||||||||||
} else if (type.isa<ir::VectorType>()) { | ||||||||||
auto var = scope->FindVar(out_var_name); | ||||||||||
auto var_ref = var->Get<paddle::framework::VariableRefArray>(); | ||||||||||
for (size_t k = 0; k < var_ref.size(); ++k) { | ||||||||||
PADDLE_ENFORCE(variable_2_var_name.count(var_ref[k]), | ||||||||||
"Variable MUST in variable_2_var_name map"); | ||||||||||
out_name_map[legacy_arg_name].push_back( | ||||||||||
variable_2_var_name.at(var_ref[k])); | ||||||||||
} | ||||||||||
} else { | ||||||||||
PADDLE_THROW(phi::errors::Unimplemented( | ||||||||||
"only support AllocatedDenseTensor, AllocatedSelectedRowsType and " | ||||||||||
"ir::vector type")); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
auto& op_info = paddle::framework::OpInfoMap::Instance().Get(fluid_op_name); | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,10 +59,6 @@ const std::unordered_set<std::string> UnchangeOutputOps = { | |
"builtin.get_parameter", | ||
"pd.shadow_output"}; | ||
|
||
const std::unordered_set<std::string> LegacyOpList = { | ||
"pd.fused_softmax_mask_upper_triangle", | ||
"pd.fused_softmax_mask_upper_triangle_grad"}; | ||
|
||
bool NeedFallBackCpu(const ir::Operation* op, | ||
const std::string& kernel_fn_name, | ||
const phi::KernelKey& kernel_key) { | ||
|
@@ -553,6 +549,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog, | |
GetKernelKey(op_item, place, map_value_pair, op_info_parser.get()); | ||
VLOG(6) << "kernel type " << kernel_key; | ||
|
||
if (op_item->name() == "pd.load_combine") { | ||
kernel_key.set_dtype(phi::DataType::FLOAT32); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里为什么直接设置 kernel_key 为FP32了?我看load_combine 是支持很多其他数据类型的?是先验证了FP32,还是后面逻辑会自适应的修改? 若是前者,这里是否需要加一下 TODO 标记下? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. load_combine 是一个没有输入的op,无法从输入的参数表里面推导输出的类型,但是kernel key不是有一个类型 |
||
} | ||
if (NeedFallBackCpu((op_item), kernel_fn_str, kernel_key)) { | ||
kernel_key.set_backend(phi::Backend::CPU); | ||
} | ||
|
@@ -571,7 +570,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog, | |
auto args_def = phi_kernel.args_def(); | ||
auto output_defs = args_def.output_defs(); | ||
if (!UnchangeOutputOps.count(op_item->name()) && | ||
!LegacyOpList.count(op_item->name())) { | ||
!IsLegacyOp(op_item->name())) { | ||
PADDLE_ENFORCE_EQ( | ||
op_item->num_results(), | ||
output_defs.size(), | ||
|
@@ -583,7 +582,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog, | |
for (size_t i = 0; i < op_item->num_results(); ++i) { | ||
phi::Place out_place; | ||
if ((!UnchangeOutputOps.count(op_item->name())) && | ||
(!LegacyOpList.count(op_item->name())) && phi_kernel.IsValid()) { | ||
(!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { | ||
out_place = phi::TransToPhiPlace(output_defs[i].backend); | ||
} else { | ||
out_place = phi::TransToPhiPlace(kernel_key.backend()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
随着执行器接入的完善,BuildOpFuncList函数后续应该没有什么用处了。TODO:后续随执行器代码清理 pr 一并清理