【Prim】support higher order autodiff for dy2static+composite #53171
Merged: cxxly merged 12 commits into PaddlePaddle:develop from cxxly:prim_higher_order_autodiff on May 12, 2023.
Commits (12, changes shown from all commits):
d950162  [Dy2St]Fix x grad names when high order gradient (0x45f)
cbad143  Polish error msg (0x45f)
8bd76d7  Add inputs var to backward in dy2st (0x45f)
573c65e  Fix error (0x45f)
65f5cf9  Get grad names for backward API (0x45f)
ad69937  Fix save load (0x45f)
8620903  Polish code (0x45f)
56ba3c3  Add ut (0x45f)
98e36fc  [prim] fix not support optional grad bugs in higher order autodiff (cxxly)
5ab7fdd  [prim] remove duplicate fill_any_like caused by infershape_for_composite (cxxly)
1167429  fix _strip_grad_suffix_ bugs in higher-order autodiff (cxxly)
aa74cfe  [prim] create output for test_static_prim.cc (cxxly)
```diff
@@ -28,6 +28,8 @@
 from collections.abc import Sequence
 
+import re
+
 __all__ = [
     'append_backward',
     'gradients',
```
```diff
@@ -459,10 +461,14 @@ def _strip_grad_suffix_(name):
     """
     Strip the grad suffix from the given variable name
     e.g. x@GRAD ==> x
+         x@GRAD@GRAD ==> x
          y@GRAD@RENAME@1 ==> y
+         z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
     """
-    pos = name.find(core.grad_var_suffix())
-    new_name = name[:pos] if pos != -1 else name
+    pos = re.search(f'{core.grad_var_suffix()}$', name) or re.search(
+        f'{core.grad_var_suffix()}@', name
+    )
+    new_name = name[: pos.start()] if pos is not None else name
     new_pos = name.rfind('grad/')
     return new_name[new_pos + 5 :] if new_pos != -1 else new_name
```
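For context, a minimal self-contained sketch of this fix (plain Python; the names `GRAD_SUFFIX`, `strip_old` and `strip_new` are illustrative stand-ins, the suffix is hard-coded as `@GRAD` in place of `core.grad_var_suffix()`, and the `grad/` prefix handling is omitted). It exercises the docstring case that the old `find`-based stripping got wrong:

```python
import re

GRAD_SUFFIX = '@GRAD'  # stand-in for core.grad_var_suffix()


def strip_old(name):
    # Pre-PR behavior: cut at the first occurrence of the suffix.
    pos = name.find(GRAD_SUFFIX)
    return name[:pos] if pos != -1 else name


def strip_new(name):
    # Patched behavior from the diff above: prefer a suffix at the very end
    # of the name, otherwise one that is followed by another '@...' tag.
    pos = re.search(f'{GRAD_SUFFIX}$', name) or re.search(f'{GRAD_SUFFIX}@', name)
    return name[: pos.start()] if pos is not None else name


# A sliced cotangent name: the old logic over-strips it down to 'z'.
print(strip_old('z@GRAD_slice_0@GRAD'))  # z
print(strip_new('z@GRAD_slice_0@GRAD'))  # z@GRAD_slice_0
# Renamed gradients still resolve to the forward variable name.
print(strip_new('y@GRAD@RENAME@1'))      # y
```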
```diff
@@ -1343,15 +1349,17 @@ def update_distop_context(
 
     if core._is_bwd_prim_enabled():
         composite_block = program.clone().current_block()
-        # Infer shape for operators whose output haven't been created.
+        # Create output and infer shape for operators whose output haven't
+        # been created.
         for op in composite_block.ops:
-            if not all(
-                tuple(
-                    composite_block._find_var_recursive(arg)
-                    for arg in op.output_arg_names
-                )
-            ):
-                infershape_for_composite(composite_block, op.desc)
+            for name in op.output_arg_names:
+                if not (
+                    composite_block.desc.has_var_recursive(name.encode())
+                    or name == core.empty_var_name()
+                ):
+                    composite_block.create_var(name=name)
+            op.desc.infer_var_type(composite_block.desc)
+            op.desc.infer_shape(composite_block.desc)
 
     # add grad_op_desc by reversed ops
     for op in reversed(ops):
```
```
@@ -1492,27 +1500,36 @@ def find_op_index(block_desc, cur_op_desc):
                    or name in input_grad_names_set
                )
                is_append_grad = False

                # NOTE: In primitive mode, the intermediate variable generated by
                # decompositing raw grad op are not satisfied the rule of 'XX@GRAD',
                # which will cause it be pruned according to current pruning logic.
                # For simplicity, we treate all prmitive operators as one raw
                # operator, and keep the pruning logic consistent with currently
                # logic. The drawback of this solution is may lead to some primitive
                # operators are not pruned, which is needed to fixed.
                # FIXME: Optimize pruning logic from the perspective of whole graph.
                input_grad_names = []
                for op_desc in grad_op_desc:
                    input_grad_names += [
                        name
                        for name in op_desc.input_arg_names()
                        if is_grad_name(name)
                    ]

                # some code of gradient ops, like increment, are not very
                # standard, there is no @GRAD in these ops' inputs.
                if len(input_grad_names) == 0:
                    is_append_grad = True
                    break

                for op_desc in grad_op_desc:
                    # some code of gradient ops, like increment, are not very
                    # standard, there is no @GRAD in these ops' inputs.
                    continue

                    if _some_in_set_(input_grad_names, input_grad_names_set):
                if _some_in_set_(input_grad_names, input_grad_names_set):
                    is_append_grad = True
                    for op_desc in grad_op_desc:
                        grad_op_descs.append(op_desc)
                        is_append_grad = True
                        for name in op_desc.output_arg_names():
                            input_grad_names_set.add(name)

                if is_append_grad:
                    grad_to_var.update(op_grad_to_var)
                else:
```
```diff
@@ -1774,17 +1791,19 @@ def infershape_for_composite(block, grad_op_desc):
         op_desc.check_attrs()
         op_desc.infer_var_type(block.desc)
         op_desc.infer_shape(block.desc)
-        for arg in op_desc.output_arg_names():
-            if arg in new_vars:
-                _infer_var_data_type_shape_(arg, block)
 
         grad_op_desc.copy_from(op_desc)
 
-    # NOTE: Some operator doesn't infer dtype correctly, this patch set the
-    # grad_var dtype same with corresponding forward variable.
-    for arg in grad_op_desc.output_arg_names():
-        if arg in new_vars:
-            _infer_var_data_type_shape_(arg, block)
+    if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
+        # NOTE: Some raw fluid grad operators which hadn't been decomposed may not
+        # implement InferVarType method, such as elementwise_xx_grad, and it will
+        # cause the dtype or shape of corresponding cotangent incorrect. This
+        # patch set the cotangent dtype and shape same with corresponding
+        # forward variable. For primitive operators, we have ensure all
+        # InferVarType method to be executed correctly in PR#52818, we skip
+        # this patch for primitive operators.
+        for arg in grad_op_desc.output_arg_names():
+            if arg in new_vars:
+                _infer_var_data_type_shape_(arg, block)
 
 
 def _rename_grad_(
```
Review comment: Could this change be merged into the else branch at line 1784?
Reply: It could, but this code is a very hacky patch (it has existed since Paddle 1.x); keeping the logic as a separate block makes it easier to notice.
Review comment: Why was it changed this way?
Reply: The current block only contains the list of forward ops, so we only need to create outputs and infer shapes for ops whose outputs have not been created yet; there is no need to append the ops to the block again. infershape_for_composite would create duplicate ops, such as fill_any_like.
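A toy illustration of that point (plain Python stand-ins, not Paddle's real Block/OpDesc API): routing shape inference through a helper that appends the op leaves a second copy of fill_any_like in the block, while only creating the missing output variables keeps the op list unchanged.

```python
class ToyBlock:
    """Minimal stand-in for a program block: an op list plus known var names."""

    def __init__(self):
        self.ops = []
        self.vars = set()


class ToyOp:
    def __init__(self, op_type, outputs):
        self.type = op_type
        self.output_arg_names = outputs


def infershape_by_appending(block, op):
    # Old path: the helper appends a fresh copy of the op before inferring
    # shapes, so the op ends up in the block twice.
    block.ops.append(ToyOp(op.type, op.output_arg_names))
    block.vars.update(op.output_arg_names)


def create_missing_outputs(block, op):
    # New path: only create output vars that do not exist yet; the op that is
    # already in the block is left alone.
    for name in op.output_arg_names:
        if name not in block.vars:
            block.vars.add(name)


fill = ToyOp('fill_any_like', ['x@GRAD'])

old_style = ToyBlock()
old_style.ops.append(fill)
infershape_by_appending(old_style, fill)
print([op.type for op in old_style.ops])  # ['fill_any_like', 'fill_any_like']

new_style = ToyBlock()
new_style.ops.append(fill)
create_missing_outputs(new_style, fill)
print([op.type for op in new_style.ops])  # ['fill_any_like']
```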