Skip to content

Commit

Permalink
[DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul (#41387)
Browse files Browse the repository at this point in the history
* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues

* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run

* Fixed minor issues

* Fixed issue with backward graph construction logic

* Fixed implementation issues with backward graph reconstruction

* Fixed unittest issue

* Fixed issues

* [DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul

* Fixed issues with phi kernel

* Added triple grad test case

* Fixed minor issue
  • Loading branch information
jim19930609 authored Apr 5, 2022
1 parent 84e8ae7 commit d8a1097
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
########################
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(
["split_grad", "rnn_grad", "matmul_double_grad"])
ops_to_fill_zero_for_empty_grads = set([
"split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
"sigmoid_triple_grad"
])

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
Expand Down Expand Up @@ -171,12 +173,6 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"


def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string


def GetIndent(num):
tab = " "
return "".join([tab for i in range(num)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage, GetIndent


Expand Down Expand Up @@ -483,10 +482,8 @@ def ForwardsValidationCheck(self):
orig_forward_returns_list = self.orig_forward_returns_list

for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]

Expand All @@ -496,11 +493,9 @@ def ForwardsValidationCheck(self):
forward_input_pos, orig_input_pos)

for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
Expand Down Expand Up @@ -1133,11 +1128,20 @@ def __init__(self,
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)

# Record name mapping from forward_api_name to grad_api_names
self.to_next_grad_name_mapping = {} # {name : name}

# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents

def TransformToNextGradName(self, string):
name_mapping = self.to_next_grad_name_mapping
if string in name_mapping.keys():
return name_mapping[string]
return string

def ResetOptionalInputs(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
Expand All @@ -1147,6 +1151,22 @@ def ResetOptionalInputs(self):

self.optional_inputs = base_generator.optional_inputs

def RecordGrad2NextGradNameMapping(self, next_node_generator):
next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
next_orig_returns_list = next_node_generator.orig_forward_returns_list

next_forward_inputs_list = next_node_generator.forward_inputs_list
next_forward_returns_list = next_node_generator.forward_returns_list
for i in range(len(next_orig_inputs_list)):
grad_name = next_orig_inputs_list[i][0]
next_forward_name = next_forward_inputs_list[i][0]
self.to_next_grad_name_mapping[grad_name] = next_forward_name

for i in range(len(next_orig_returns_list)):
grad_ret_name = next_orig_returns_list[i][0]
next_ret_name = next_forward_returns_list[i][0]
self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

def GenerateHigherOrderNodeCreationCode(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
Expand All @@ -1164,6 +1184,8 @@ def GenerateHigherOrderNodeCreationCode(self):
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str

self.RecordGrad2NextGradNameMapping(next_node_generator)

return grad_node_creation_str

def GenerateNodeDeclaration(self):
Expand Down Expand Up @@ -1253,8 +1275,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
Expand All @@ -1274,8 +1295,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
if IsPlainTensorType(ttype):
Expand Down Expand Up @@ -1316,8 +1336,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

if num_outputs == 1:
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
Expand All @@ -1339,8 +1358,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
Expand All @@ -1358,8 +1376,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):

# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
Expand All @@ -1382,8 +1399,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
Expand Down Expand Up @@ -1417,8 +1433,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

# Infer Grad API Return Type
if num_bwd_outputs == 1:
Expand All @@ -1441,6 +1456,9 @@ def GenerateNodeDefinition(self, grad_node_creation_str):

grad_node_name = GetGradNodeName(forward_api_name)

if len(grad_node_creation_str) == 0:
grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"

self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
Expand All @@ -1457,11 +1475,11 @@ def run(self):
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()

# Higher-order GradNode generation
grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()

self.GenerateNodeDeclaration()

self.GenerateNodeDefinition(grad_node_creation_str)


Expand Down
48 changes: 48 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
dz->share_meta(z);
}
}
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
}

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
if (dl) {
dl->share_meta(l);
}
}

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
if (dx) {
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
MetaTensor* dy,
MetaTensor* dz);

void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk);

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl);

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/activation_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,18 @@ void EluDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout);

template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/activation_grad_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout) {
if (dout_new) {
Expand All @@ -262,10 +262,10 @@ void SigmoidDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/activation_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ KernelSignature TanhTripleGradOpArgumentMapping(
KernelSignature SigmoidDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sigmoid_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
"sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
}

KernelSignature SigmoidTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("sigmoid_triple_grad",
{"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
{"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
{},
{"D_OutNew", "D_DOut", "D_DDx"});
}
Expand Down
Loading

0 comments on commit d8a1097

Please sign in to comment.