-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor apply transformer #36899
Refactor apply transformer #36899
Conversation
Thanks for your contribution! |
@@ -216,6 +216,129 @@ void apply_device_guard(const OperatorBase* op_base, | |||
} | |||
} | |||
|
|||
std::vector<OpFuncNode> apply_data_transformer(VariableValueMap& ins_map_temp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apply_data_transformer -> apply_data_transform
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数签名把const放前面,非const放后面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ins_map_temp 函数里修改了么?没有修改的话,这里应该是 const 修饰下?expected_kernel_key 也确认下?
const platform::Place& place) { | ||
auto& op_base = op_func_node.operator_base_; | ||
auto& op = op_base; | ||
auto inputs_names = op->Inputs(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 op_base 和 op是同一个指针,为什么需要两个变量?另外,对于指针类型,这里要PADDLE_ENFORCE_NOT_NULL判断下
no_data_transform_index; // record the no need transform variable index. | ||
std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode. | ||
|
||
for (auto& var_name_item : ins_map_temp) { /*{{{*/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
移除 /{{{/
@@ -272,123 +395,28 @@ void build_op_func_list(const platform::Place& place, | |||
Scope scope; | |||
auto expected_kernel_key = | |||
dynamic_cast<const framework::OperatorWithKernel*>(op_base) | |||
->GetExpectedKernelType( | |||
->GetExpectedKernelType( // for all variables's type and context's | |||
// place |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的注释不是很明确,完善或者移除掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
var_name_item.second[i] = var_scope->Var(new_var_name); | ||
} else if (need_dtype_transform_for_var(kernel_type_for_var, | ||
expected_kernel_key)) { | ||
// TODO(@xiongkun) add dtype judgement here) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO(@xiongkun) add dtype judgement here) | |
PADDLE_THROW("not implemented yet") |
@@ -216,6 +216,171 @@ void apply_device_guard(const OperatorBase* op_base, | |||
} | |||
} | |||
|
|||
// the return value is whether data transformer is needed for this var | |||
bool need_place_transform_for_var(const OpKernelType kernel_type_for_var, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool need_place_transform_for_var(const OpKernelType kernel_type_for_var, | |
bool need_place_transform_for_var(const OpKernelType& kernel_type_for_var, |
|
||
bool need_dtype_transform_for_var(const OpKernelType kernel_type_for_var, | ||
const OpKernelType expected_kernel_key) { | ||
return false; // TODO(@xiongkun) add dtype judgement here) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here) -> here, and I think throw error is better than return false
, since it is not correct yet.
} | ||
|
||
// NOTE(@xiongkun03) | ||
// the different of var_name and outer_name : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
different -> difference
// record no need data transformer input var_id | ||
VLOG(3) << op_base->Type() | ||
<< " found no data_transform var: " << var_name | ||
<< " with id: " << var_name; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
var_name -> var_scope->VarId(var_name)
PR types
Others
PR changes
Others
Describe
Extract mem copy op as a function: apply_data_transformer