-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support inplace in dygraph eager_fluid state #40400
support inplace in dygraph eager_fluid state #40400
Conversation
✅ This PR's description meets the template requirements! |
Thanks for your contribution! |
… support_partial_grad
… support_partial_grad
… support_partial_grad
… support_partial_grad
…e/Paddle into inplace_in_eager_fluid_state
@@ -94,15 +105,52 @@ class TensorWrapper { | |||
intermidiate_tensor_.set_autograd_meta( | |||
std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>( | |||
p_ab_autograd_meta)); | |||
check_inplace_version(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we're gonna check inplace version anyway, let's move this function "check_inplace_version" out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done in PR #41118
@@ -716,6 +716,15 @@ static PyObject* set_grad_type(TensorObject* self, PyObject* args, | |||
EAGER_CATCH_AND_THROW_RETURN_NULL | |||
} | |||
|
|||
static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
single underscore "_" in function name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its ok if this method indicate _inplace_version
api in python
std::string ins_initializer_with_null = ""; | ||
std::string py_arg = ""; | ||
int arg_idx = 0; | ||
int input_args_num = 0; | ||
std::string ins_cast_str = ""; | ||
std::string view_strategy_str = ""; | ||
if (!inplace_map.empty()) { | ||
// change call_api_str for inplace op | ||
call_api_str = "auto out = " + op_type + "__dygraph_function("; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better add "" at the very end of the function name, like "scale_dygraph_function" for inplaced scale
std::map<std::string, std::string> inplace_map; | ||
// `sum` op has duplicate input. Don't consider adding inplace strategy | ||
// for `sum` in temporary. | ||
if (op_type != "sum" && infer_inplace) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better store hard-coded op name in a static set
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done in PR #41118
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for set_tests_properties(test_inplace_eager_fluid PROPERTIES TIMEOUT 120)
PR types
New features
PR changes
Others
Describe
动态图中间态添加inplace策略。
python-c
实现要点:
动态图层
实现要点:
check_inplace
的检查,check_inplace
检查在执行TraceOp之前比较好。另一方面,check_inplace
需要使用输入auto_grad meta
信息。所以流程变为:先创建输入的auto_grad meta
信息,然后check_inplace
,然后执行TraceOp
生成输出,然后创建输出的auto_grad meta
信息,然后构反向。EagerVariable
替换输出。后续不重新创建输出Tensor,直接使用输入Tensor代替输出。EagerVariable
内不会改变输入Tensor的meta
信息(导致inplacereshape
无法改变ddim
信息),因此新增了ModifyInplaceInput
修改inplace tensor的meta
信息。反向检测
TensorWrapper
中加入snapshot_inplace_version_
快照信息。TensorWrapper
recover
出Tensor
时,进行inplace_version反向检测。比较snapshot_inplace_version_
与Tensor的current_inplace_version_
是否一致。示例