
【Prim】support higher order autodiff for dy2static+composite #53171

Merged · 12 commits · May 12, 2023
7 changes: 2 additions & 5 deletions paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
@@ -350,10 +350,7 @@ class CompositeGradOpMakerBase {
framework::VarDesc* SingleOutputGrad(const std::string& name) const {
auto* var = this->SingleForwardOutput(name);
if (!var) {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetSingleOutputGrad for %s_grad faild, if it is Optional input,"
"please use GetOptionalSingleOutputGrad replaced. ",
name));
return nullptr;
}
auto var_name = var->Name();
auto grad_var_name = framework::GradVarName(var_name);
@@ -371,7 +368,7 @@
return StaticCompositeContext::Instance().GetBlock()->FindVar(
grad_var_name);
} else {
return StaticCompositeContext::Instance().GetBlock()->Var(grad_var_name);
return nullptr;
}
}

73 changes: 46 additions & 27 deletions python/paddle/fluid/backward.py
@@ -28,6 +28,8 @@

from collections.abc import Sequence

import re

__all__ = [
'append_backward',
'gradients',
@@ -459,10 +461,14 @@ def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given variable name
e.g. x@GRAD ==> x
x@GRAD@GRAD ==> x
y@GRAD@RENAME@1 ==> y
z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
"""
pos = name.find(core.grad_var_suffix())
new_name = name[:pos] if pos != -1 else name
pos = re.search(f'{core.grad_var_suffix()}$', name) or re.search(
f'{core.grad_var_suffix()}@', name
)
new_name = name[: pos.start()] if pos is not None else name
new_pos = name.rfind('grad/')
return new_name[new_pos + 5 :] if new_pos != -1 else new_name
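
An illustrative, standalone sketch of the new stripping rule (not part of the diff), with the literal '@GRAD' standing in for core.grad_var_suffix():

import re

GRAD_SUFFIX = '@GRAD'  # assumption: the value returned by core.grad_var_suffix()

def strip_grad_suffix(name):
    # Prefer a match at the very end of the name; otherwise fall back to a
    # suffix that is immediately followed by another '@' section, so that
    # 'z@GRAD_slice_0@GRAD' keeps its embedded 'z@GRAD_slice_0' prefix.
    pos = re.search(f'{GRAD_SUFFIX}$', name) or re.search(f'{GRAD_SUFFIX}@', name)
    return name[: pos.start()] if pos is not None else name

assert strip_grad_suffix('x@GRAD') == 'x'
assert strip_grad_suffix('y@GRAD@RENAME@1') == 'y'
assert strip_grad_suffix('z@GRAD_slice_0@GRAD') == 'z@GRAD_slice_0'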

@@ -1343,15 +1349,17 @@ def update_distop_context(

if core._is_bwd_prim_enabled():
composite_block = program.clone().current_block()
# Infer shape for operators whose output haven't been created.
# Create output and infer shape for operators whose output haven't
# been created.
for op in composite_block.ops:
if not all(
tuple(
composite_block._find_var_recursive(arg)
for arg in op.output_arg_names
)
):
infershape_for_composite(composite_block, op.desc)
for name in op.output_arg_names:
if not (
composite_block.desc.has_var_recursive(name.encode())
or name == core.empty_var_name()
):
composite_block.create_var(name=name)
op.desc.infer_var_type(composite_block.desc)
op.desc.infer_shape(composite_block.desc)
Review comment (Contributor):
Why was it changed this way?

Reply (Author):
The current block only contains the forward op list, so we only need to create outputs and infer shapes for the ops whose outputs are missing; there is no need to add the ops to the block again. infershape_for_composite would create duplicate ops, e.g. fill_any_like.

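For context, a minimal sketch (not from the PR) of the create-missing-outputs-then-infer pattern used in the hunk above, exercised on a toy static program; the API calls mirror those in the diff, and the tanh program is only an illustrative stand-in for composite_block:

import paddle
from paddle.fluid import core

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [2, 3], 'float32')
    y = paddle.tanh(x)

block = main.current_block()
for op in block.ops:
    # Create any output variable that does not exist yet (skipping the
    # placeholder empty var), then let the op infer var type and shape.
    for name in op.output_arg_names:
        if not (
            block.desc.has_var_recursive(name.encode())
            or name == core.empty_var_name()
        ):
            block.create_var(name=name)
    op.desc.infer_var_type(block.desc)
    op.desc.infer_shape(block.desc)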

# add grad_op_desc by reversed ops
for op in reversed(ops):
Expand Down Expand Up @@ -1492,27 +1500,36 @@ def find_op_index(block_desc, cur_op_desc):
or name in input_grad_names_set
)
is_append_grad = False

# NOTE: In primitive mode, the intermediate variables generated by
# decomposing a raw grad op do not follow the 'XX@GRAD' naming rule,
# so the current pruning logic would prune them away. For simplicity,
# we treat all primitive operators decomposed from one raw operator as
# a single unit and keep the pruning decision consistent with the
# current logic. The drawback is that some primitive operators may
# escape pruning, which still needs to be fixed.
# FIXME: Optimize the pruning logic from a whole-graph perspective.
input_grad_names = []
for op_desc in grad_op_desc:
input_grad_names += [
name
for name in op_desc.input_arg_names()
if is_grad_name(name)
]

# some code of gradient ops, like increment, are not very
# standard, there is no @GRAD in these ops' inputs.
if len(input_grad_names) == 0:
is_append_grad = True
break

for op_desc in grad_op_desc:

# some code of gradient ops, like increment, are not very
# standard, there is no @GRAD in these ops' inputs.
continue

if _some_in_set_(input_grad_names, input_grad_names_set):
if _some_in_set_(input_grad_names, input_grad_names_set):
is_append_grad = True
for op_desc in grad_op_desc:
grad_op_descs.append(op_desc)
is_append_grad = True
for name in op_desc.output_arg_names():
input_grad_names_set.add(name)

if is_append_grad:
grad_to_var.update(op_grad_to_var)
else:
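
A Paddle-free sketch of the grouped pruning decision described in the NOTE above (illustrative only; the '@GRAD' substring check stands in for is_grad_name):

def keep_decomposed_group(grad_op_inputs, known_grad_names):
    """Decide whether a whole group of decomposed grad ops is appended.

    grad_op_inputs: one list of input names per decomposed op.
    known_grad_names: gradient names already produced by earlier grad ops.
    """
    group_grad_inputs = [
        name
        for inputs in grad_op_inputs
        for name in inputs
        if '@GRAD' in name
    ]
    # Grad ops such as increment carry no '@GRAD' input at all; keep them.
    if not group_grad_inputs:
        return True
    # Keep or drop the decomposed ops as one unit, like a single raw op.
    return any(name in known_grad_names for name in group_grad_inputs)

# The whole group survives because one decomposed op consumes 'y@GRAD'.
print(keep_decomposed_group([['x'], ['y@GRAD', 'tmp_0']], {'y@GRAD'}))  # True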
@@ -1774,17 +1791,19 @@ def infershape_for_composite(block, grad_op_desc):
op_desc.check_attrs()
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)

grad_op_desc.copy_from(op_desc)

# NOTE: Some operator doesn't infer dtype correctly, this patch set the
# grad_var dtype same with corresponding forward variable.
for arg in grad_op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
# NOTE: Some raw fluid grad operators that haven't been decomposed, such
# as elementwise_xx_grad, may not implement the InferVarType method,
# which leaves the dtype or shape of the corresponding cotangent
# incorrect. This patch sets the cotangent dtype and shape to match the
# corresponding forward variable. For primitive operators, PR#52818
# ensures InferVarType is executed correctly, so we skip this patch for
# them.
for arg in grad_op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
Review comment (Contributor):
Could this change be folded into the else branch at line 1784?

Reply (Author):
It could, but this code is a fairly hacky patch (it has existed since Paddle 1.x); keeping it as a separate block makes it easier to notice.
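
A minimal sketch of the gate added in the hunk above, with the Paddle-internal patch helper passed in as a callable (hypothetical parameter name, for illustration only):

from paddle.fluid import framework

def maybe_patch_grad_outputs(grad_op_desc, new_vars, block, patch_fn):
    # Grad ops registered with an OpProto (including primitive operators)
    # already run InferVarType correctly, so they skip the patch; legacy
    # grad ops without a proto get their newly created outputs patched
    # from the corresponding forward variables.
    if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
        return
    for arg in grad_op_desc.output_arg_names():
        if arg in new_vars:
            patch_fn(arg, block)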



def _rename_grad_(
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_calc_gradient.py
@@ -20,6 +20,8 @@
from paddle import fluid
from paddle.fluid.backward import calc_gradient

paddle.enable_static()


class TestCalcGradient(unittest.TestCase):
def test_calc_gradient(self):
23 changes: 23 additions & 0 deletions test/cpp/prim/test_static_prim.cc
@@ -206,6 +206,11 @@ TEST(StaticPrim, TanhBackwardComposite) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
Tensor out_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out_grad_desc =
static_cast<prim::DescTensor*>(out_grad.impl().get())->get_ptr();
target_block->RenameVar(out_grad_desc->Name(), "b@GRAD");
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops =
std::move(framework::OpInfoMap::Instance()
.Get(forward_opdesc->Type())
@@ -288,6 +293,11 @@ TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
Tensor out_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out_grad_desc =
static_cast<prim::DescTensor*>(out_grad.impl().get())->get_ptr();
target_block->RenameVar(out_grad_desc->Name(), "out@GRAD");
auto test = TestCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
@@ -353,6 +363,19 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;

Tensor out1_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out1_grad_desc =
static_cast<prim::DescTensor*>(out1_grad.impl().get())->get_ptr();
target_block->RenameVar(out1_grad_desc->Name(), "out1@GRAD");

Tensor out2_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out2_grad_desc =
static_cast<prim::DescTensor*>(out2_grad.impl().get())->get_ptr();
target_block->RenameVar(out2_grad_desc->Name(), "out2@GRAD");

auto test = TestCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
5 changes: 5 additions & 0 deletions test/prim/test_comp_get_grad_op_desc_prim_enabled.py
@@ -67,6 +67,11 @@ def setUpClass(cls):
for n, vs in cls.outputs.items()
},
)

for _, outs in cls.outputs.items():
for out in outs:
block.create_var(name=out + core.grad_var_suffix())

cls.fwd = block.ops[0].desc

@classmethod
5 changes: 5 additions & 0 deletions test/prim/test_comp_skip_op_set.py
@@ -46,6 +46,11 @@ def setUp(self):
for n, vs in self.outputs.items()
},
)

for _, outs in self.outputs.items():
for out in outs:
block.create_var(name=out + core.grad_var_suffix())

self.fwd = block.ops[0].desc

def tearDown(self):