[Dygraph API] Fix merged_momentum, provide actual inplace operations … #59161
Your PR was submitted successfully. Thank you for contributing to this open-source project!
❌ The PR was not created using the PR template. You can refer to this Demo.
…after falling back to CPU
LGTM for const_cast
{code_indent} kernel_out_{output_idx}[i] = const_cast<phi::DenseTensor*>({PREFIX_TENSOR_NAME}{self.inplace_map[self.outputs['names'][output_idx]]}->at(i));
{code_indent} }}
This op is already inplace, so why is the input pointer assigned to out again here?
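Because if a fallback to CPU happens, the input tensor is copied from XPU to CPU. Since the operator is inplace, the input and output pointers should be the same, so kernel_out has to point at the data copied to CPU rather than at the data on XPU.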
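The reply above can be illustrated with a toy model. This is only a sketch, not Paddle code: `ToyTensor`, `Place`, `PrepareOnCpu` and `BindInplaceOutputs` are hypothetical stand-ins for `phi::DenseTensor` and the generated API code. It shows one idea only: after the fallback, the output pointers are rebound to the CPU copies of the inputs, so an inplace kernel sees the same address for input and output.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins; the real code works on phi::DenseTensor.
enum class Place { kXPU, kCPU };

struct ToyTensor {
  Place place;
  std::vector<float> data;
};

// Models the fallback step: every device tensor is copied into a new
// tensor that lives on the CPU. The originals are left untouched.
std::vector<ToyTensor> PrepareOnCpu(const std::vector<ToyTensor>& device_inputs) {
  std::vector<ToyTensor> cpu_inputs;
  cpu_inputs.reserve(device_inputs.size());
  for (const ToyTensor& t : device_inputs) {
    cpu_inputs.push_back(ToyTensor{Place::kCPU, t.data});
  }
  return cpu_inputs;
}

// Models the loop under review: kernel_out[i] is pointed at the i-th
// prepared (CPU) input. The const_cast mirrors the one in the template;
// writing through it is only valid because the underlying objects are
// not themselves declared const.
std::vector<ToyTensor*> BindInplaceOutputs(
    const std::vector<ToyTensor>& prepared_inputs) {
  std::vector<ToyTensor*> kernel_out(prepared_inputs.size());
  for (std::size_t i = 0; i < prepared_inputs.size(); ++i) {
    kernel_out[i] = const_cast<ToyTensor*>(&prepared_inputs.at(i));
  }
  return kernel_out;
}
```

With this binding, `kernel_out[i]` and the prepared input at index `i` are the same object, and neither is the original XPU tensor.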
Aren't the tensors on the left and right of the equals sign here the same one?
{code_indent} auto target_ptr = static_cast<phi::DenseTensor*>({target_input}->at(i).impl().get());
{code_indent} *target_ptr = *{kernel_out}.at(i);
Isn't this inplace? Why is the output assigned back to the input?
Because after falling back to CPU, out lives on the CPU, so the value of out has to be written back to the input on XPU.
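A toy version of this copy-back step might look like the following. Again this is a sketch under assumptions: `ToyTensor`, `Place`, `ScaleInplaceOnCpu` and `CopyBack` are hypothetical names, the "kernel" is a simple scaling rather than merged_momentum, and a plain vector assignment stands in for whatever CPU-to-XPU transfer the real code performs.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins; the real code works on phi::DenseTensor.
enum class Place { kXPU, kCPU };

struct ToyTensor {
  Place place;
  std::vector<float> data;
};

// Toy inplace "kernel": updates each CPU tensor through its output pointer.
void ScaleInplaceOnCpu(const std::vector<ToyTensor*>& kernel_out, float factor) {
  for (ToyTensor* t : kernel_out) {
    for (float& v : t->data) {
      v *= factor;
    }
  }
}

// Models the copy-back before return: the result computed on the CPU is
// written into the original device tensor, which keeps its device place.
void CopyBack(const std::vector<ToyTensor*>& kernel_out,
              std::vector<ToyTensor>* device_inputs) {
  for (std::size_t i = 0; i < kernel_out.size(); ++i) {
    ToyTensor* target_ptr = &device_inputs->at(i);
    target_ptr->data = kernel_out.at(i)->data;  // stands in for a CPU-to-XPU copy
  }
}
```

Until `CopyBack` runs, the original device tensor still holds the old values, which is why the write-back is needed even though the op is inplace from the caller's point of view.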
LGTM
…after falling back to CPU (PaddlePaddle#59161)
PR types
Bug fixes
PR changes
APIs
Description
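Fixes the issues left over from #58204 while preserving its functionality, so that after falling back to CPU, the inplace inputs and outputs of the merged_momentum kernel have the same address on the CPU.
Taking merged_momentum as an example:
Output part:
Before the fix:
After the fix:
Copy-back part before return:
Before the fix:
After the fix: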