Skip to content

Commit

Permalink
[Refactor] refactored eager_gen.py PR #2 (PaddlePaddle#40907)
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 authored Mar 25, 2022
1 parent 5f6038f commit f027b2a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
#############################
### File Reader Helpers ###
#############################
def AssertMessage(lhs_str, rhs_str):
return f"lhs: {lhs_str}, rhs: {rhs_str}"


def ReadFwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
Expand All @@ -62,10 +66,10 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {}
for content in contents:
assert 'backward_api' in content.keys(), AssertMessage('backward_api',
content.keys())
if 'backward_api' in content.keys():
api_name = content['backward_api']
else:
assert False

ret[api_name] = content
f.close()
Expand Down Expand Up @@ -225,7 +229,7 @@ def ParseYamlReturns(string):
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]

assert "Tensor" in ret_type
assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type)
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import argparse
import os
import logging
from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile
Expand All @@ -30,6 +31,7 @@
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage


###########
Expand Down Expand Up @@ -398,14 +400,21 @@ def DygraphYamlValidationCheck(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents

assert 'api' in forward_api_contents.keys()
assert 'args' in forward_api_contents.keys()
assert 'output' in forward_api_contents.keys()
assert 'backward' in forward_api_contents.keys()

assert 'args' in grad_api_contents.keys()
assert 'output' in grad_api_contents.keys()
assert 'forward' in grad_api_contents.keys()
assert 'api' in forward_api_contents.keys(
), "Unable to find \"api\" in api.yaml"
assert 'args' in forward_api_contents.keys(
), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys(
), "Unable to find \"output\" in api.yaml"
assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml"

assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"

def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list
Expand All @@ -424,8 +433,10 @@ def ForwardsValidationCheck(self):
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]

assert forward_input_type == orig_input_type
assert forward_input_pos == orig_input_pos
assert forward_input_type == orig_input_type, AssertMessage(
forward_input_type, orig_input_type)
assert forward_input_pos == orig_input_pos, AssertMessage(
forward_input_pos, orig_input_pos)

for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
Expand All @@ -436,18 +447,23 @@ def ForwardsValidationCheck(self):
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
assert orig_attr_type == forward_attr_type, AssertMessage(
orig_attr_type, forward_attr_type)
assert orig_attr_default == forward_attr_default, AssertMessage(
orig_attr_default, forward_attr_default)
assert orig_attr_pos == forward_attr_pos, AssertMessage(
orig_attr_pos, forward_attr_pos)

for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]

assert orig_return_type == forward_return_type
assert orig_return_pos == forward_return_pos
assert orig_return_type == forward_return_type, AssertMessage(
orig_return_type, forward_return_type)
assert orig_return_pos == forward_return_pos, AssertMessage(
orig_return_pos, forward_return_pos)

# Check Order: Inputs, Attributes
max_input_position = -1
Expand All @@ -456,7 +472,8 @@ def ForwardsValidationCheck(self):

max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
max_attr_position = max(max_attr_position, pos)

def BackwardValidationCheck(self):
Expand All @@ -471,12 +488,14 @@ def BackwardValidationCheck(self):

max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_inputs_map.items():
assert pos > max_fwd_input_position
assert pos > max_fwd_input_position, AssertMessage(
pos, max_grad_tensor_position)
max_grad_tensor_position = max(max_grad_tensor_position, pos)

max_attr_position = -1
for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position
assert pos > max_grad_tensor_position, AssertMessage(
pos, max_grad_tensor_position)
max_attr_position = max(max_attr_position, pos)

def IntermediateValidationCheck(self):
Expand All @@ -491,7 +510,8 @@ def IntermediateValidationCheck(self):
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
assert pos in intermediate_positions, AssertMessage(
pos, intermediate_positions)

def CollectBackwardInfo(self):
forward_api_contents = self.forward_api_contents
Expand All @@ -505,9 +525,12 @@ def CollectBackwardInfo(self):

self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
backward_args_str, backward_returns_str)
print("Parsed Backward Inputs List: ", self.backward_inputs_list)
print("Prased Backward Attrs List: ", self.backward_attrs_list)
print("Parsed Backward Returns List: ", self.backward_returns_list)

logging.info(
f"Parsed Backward Inputs List: {self.backward_inputs_list}")
logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}")
logging.info(
f"Parsed Backward Returns List: {self.backward_returns_list}")

def CollectForwardInfoFromBackwardContents(self):

Expand All @@ -530,7 +553,9 @@ def SlotNameMatching(self):
backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name:
# Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys()
assert backward_fwd_name in forward_outputs_position_map.keys(
), AssertMessage(backward_fwd_name,
forward_outputs_position_map.keys())
matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[
Expand All @@ -556,17 +581,18 @@ def SlotNameMatching(self):
backward_input_type, False, backward_input_pos
]
else:
assert False, backward_input_name
assert False, f"Cannot find {backward_input_name} in forward position map"

for backward_output in backward_returns_list:
backward_output_name = backward_output[0]
backward_output_type = backward_output[1]
backward_output_pos = backward_output[2]

backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None
assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None"
assert backward_fwd_name in forward_inputs_position_map.keys(
), f"Unable to find {backward_fwd_name} in forward inputs"
), AssertMessage(backward_fwd_name,
forward_inputs_position_map.keys())

matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0]
Expand All @@ -577,12 +603,15 @@ def SlotNameMatching(self):
backward_output_type, matched_forward_input_pos,
backward_output_pos
]
print("Generated Backward Fwd Input Map: ",
self.backward_forward_inputs_map)
print("Generated Backward Grad Input Map: ",
self.backward_grad_inputs_map)
print("Generated Backward Grad Output Map: ",
self.backward_grad_outputs_map)
logging.info(
f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}"
)
logging.info(
f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}"
)
logging.info(
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)

def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
Expand Down Expand Up @@ -642,7 +671,7 @@ def GenerateNodeDeclaration(self):
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)

print("Generated Node Declaration: ", self.node_declaration_str)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")

def GenerateNodeDefinition(self):
namespace = self.namespace
Expand Down Expand Up @@ -710,7 +739,7 @@ def GenerateNodeDefinition(self):
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)

print("Generated Node Definition: ", self.node_definition_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")

def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace
Expand Down Expand Up @@ -813,8 +842,10 @@ def GenerateForwardDefinition(self, is_inplaced):
dygraph_event_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

print("Generated Forward Definition: ", self.forward_definition_str)
print("Generated Forward Declaration: ", self.forward_declaration_str)
logging.info(
f"Generated Forward Definition: {self.forward_definition_str}")
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")

def GenerateNodeCreationCodes(self, forward_call_str):
forward_api_name = self.forward_api_name
Expand Down Expand Up @@ -921,7 +952,8 @@ def GenerateNodeCreationCodes(self, forward_call_str):
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
Expand Down Expand Up @@ -1114,7 +1146,8 @@ def GetBackwardAPIContents(self, forward_api_contents):
if 'backward' not in forward_api_contents.keys(): return None

backward_api_name = forward_api_contents['backward']
assert backward_api_name in grad_api_dict.keys()
assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name]

return backward_api_contents
Expand Down

0 comments on commit f027b2a

Please sign in to comment.