From fa990565a77ea368a914eff16ecdb67ce46d246d Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Thu, 9 Dec 2021 11:19:25 +0800 Subject: [PATCH] Revert "Adjusted Eager AutoCodeGen to Support Operators with Multiple OpBases & Enable Passing Output Tensor as Input Argument (#37943)" This reverts commit 9aed9ea072d69cb6843e0150136171564684ef3b. --- .../auto_code_generator/eager_generator.cc | 1081 +++++++---------- .../eager/auto_code_generator/op_list.txt | 1 - 2 files changed, 417 insertions(+), 665 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 2a5b158d315c3..fe29792b6e75c 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -29,11 +29,15 @@ namespace paddle { namespace framework { -/* --- Static maps to handle corner cases --- */ static std::unordered_map operators_with_attrs = {}; +static std::unordered_set operators_to_skip = { + "minus", +}; + static std::unordered_set operators_to_codegen = {}; +static std::unordered_set skipped_operators = {}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; @@ -41,132 +45,6 @@ static std::string LegalizeVariableName(const std::string& var_name) { return ret; } -/* --- Helper Objects --- */ -class ForwardGenerationInfo { - public: - const std::string& GetOpType() const { return op_type_; } - void SetOpType(const std::string& op_type) { op_type_ = op_type; } - - const std::unordered_map& GetFwdInputsNamePosMap() - const { - return fwd_inputs_name_pos_map_; - } - std::unordered_map* GetMutableFwdInputsNamePosMap() { - return &fwd_inputs_name_pos_map_; - } - - const std::unordered_map& GetFwdOutputsNamePosMap() - const { - return fwd_outputs_name_pos_map_; - } - std::unordered_map* GetMutableFwdOutputsNamePosMap() { - return &fwd_outputs_name_pos_map_; - } - - const std::vector& GetInVars() const { return in_vars_; } - std::vector* GetMutableInVars() { return &in_vars_; } - - const std::vector& GetOutVars() const { - return out_vars_; - } - std::vector* GetMutableOutVars() { return &out_vars_; } - - private: - std::string op_type_; - std::unordered_map fwd_inputs_name_pos_map_; - std::unordered_map fwd_outputs_name_pos_map_; - std::vector in_vars_; - std::vector out_vars_; -}; - -class GradNodeGenerationInfo { - class OpBaseGenerationInfo { - public: - const std::string& GetOpBaseType() const { return op_base_type_; } - void SetOpBaseType(const std::string& op_type) { op_base_type_ = op_type; } - - const std::map& GetGradOutsSlotnameMap() const { - return grad_outs_slotname_map_; - } - std::map* GetMutableGradOutsSlotnameMap() { - return &grad_outs_slotname_map_; - } - - const std::map& GetGradInsFwdSlotnameMap() const { - return grad_ins_fwd_slotname_map_; - } - std::map* GetMutableGradInsFwdSlotnameMap() { - return &grad_ins_fwd_slotname_map_; - } - - const std::map& GetGradInsGradSlotnameMap() - const { - return grad_ins_grad_slotname_map_; - } - std::map* GetMutableGradInsGradSlotnameMap() { - return &grad_ins_grad_slotname_map_; - } - - const std::map< - std::string, - std::vector>>& - GetGradIns() const { - return grad_ins_; - } - std::map>>* - GetMutableGradIns() { - return &grad_ins_; - } - - const std::map< - std::string, - std::vector>>& - GetGradOuts() const { - return grad_outs_; - } - std::map>>* - GetMutableGradOuts() { - return &grad_outs_; - } - - private: - std::string op_base_type_; - std::map grad_outs_slotname_map_; - std::map grad_ins_fwd_slotname_map_; - std::map grad_ins_grad_slotname_map_; - std::map>> - grad_ins_; - std::map>> - grad_outs_; - }; - - public: - const std::string& GetFwdOpType() const { return fwd_op_type_; } - void SetFwdOpType(const std::string& op_type) { fwd_op_type_ = op_type; } - - bool GenerateForwardOnly() const { return generate_forward_only_; } - void SetGenerateForwardOnly(bool generate_forward_only) { - generate_forward_only_ = generate_forward_only; - } - - const std::vector& GetOpBaseInfos() const { - return op_base_infos_; - } - std::vector* GetMutableOpBaseInfos() { - return &op_base_infos_; - } - - private: - std::string fwd_op_type_; - bool generate_forward_only_ = false; - std::vector op_base_infos_; -}; - -/* --- Helper Functions --- */ static std::string AttrTypeToString(const proto::AttrType& type) { std::string ret; switch (type) { @@ -470,6 +348,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) { VLOG(1) << "------ Analyzing Op ------: " << op_type; if (!operators_to_codegen.count(op_type)) return false; + if (operators_to_skip.count(op_type)) return false; return true; } @@ -477,16 +356,15 @@ static bool CheckOpProto(proto::OpProto* op_proto) { /* --------------------------------------- */ /* --------- Preprocess Ins/Outs --------- */ /* --------------------------------------- */ -static void PurifyForwardOpProto(const proto::OpProto& op_proto, - ForwardGenerationInfo* fwd_info) { +static void PurifyForwardOpProto( + const proto::OpProto& op_proto, + std::unordered_map* fwd_inputs_name_pos_map, + std::unordered_map* fwd_outputs_name_pos_map, + std::vector* in_vars, + std::vector* out_vars) { // Op Name const std::string op_name = op_proto.type(); - auto* in_vars = fwd_info->GetMutableInVars(); - auto* out_vars = fwd_info->GetMutableOutVars(); - auto* fwd_inputs_name_pos_map = fwd_info->GetMutableFwdInputsNamePosMap(); - auto* fwd_outputs_name_pos_map = fwd_info->GetMutableFwdOutputsNamePosMap(); - // Handle dispensable inputs for (const proto::OpProto::Var& input : op_proto.inputs()) { std::string input_name = input.name(); @@ -548,104 +426,6 @@ static void PurifyForwardOpProto(const proto::OpProto& op_proto, } } -static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto, - GradNodeGenerationInfo* bwd_info) { - auto* op_base_infos = bwd_info->GetMutableOpBaseInfos(); - for (auto& iter : *op_base_infos) { - std::map* grad_outs_slotname_map = - iter.GetMutableGradOutsSlotnameMap(); - std::map* grad_ins_fwd_slotname_map = - iter.GetMutableGradInsFwdSlotnameMap(); - std::map* grad_ins_grad_slotname_map = - iter.GetMutableGradInsGradSlotnameMap(); - std::map>>* - grad_ins = iter.GetMutableGradIns(); - std::map>>* - grad_outs = iter.GetMutableGradOuts(); - - // Op Name - const std::string op_name = op_proto.type(); - - // Handle dispensable inputs - for (const proto::OpProto::Var& input : op_proto.inputs()) { - std::string input_name = input.name(); - - // Delete dispensable tensor unless specified in op_ins_map - if (input.dispensable()) { - if (!op_ins_map.count(op_name) || - !op_ins_map[op_name].count(input_name)) { - VLOG(6) << "Removing Dispensable Input: " << input_name; - - // grad_outs_slotname_map - auto grad_outs_slotname_map_purified = *grad_outs_slotname_map; - for (const auto& iter : *grad_outs_slotname_map) { - const std::string& grad_output_name = iter.first; - const std::string& matched_input_name = iter.second; - if (matched_input_name == input_name) { - grad_outs_slotname_map_purified.erase(grad_output_name); - - PADDLE_ENFORCE( - grad_outs->count(grad_output_name) > 0, - paddle::platform::errors::Fatal( - "Unable to find gradient output name in grad_outs.")); - // grad_outs - grad_outs->erase(grad_output_name); - } - } - *grad_outs_slotname_map = grad_outs_slotname_map_purified; - - // grad_ins_fwd_slotname_map: output as tensorwrapper - if (grad_ins_fwd_slotname_map->count(input_name)) - grad_ins_fwd_slotname_map->erase(input_name); - - // grad_ins: output as tensorwrapper - if (grad_ins->count(input_name)) grad_ins->erase(input_name); - } - } - } - - for (const proto::OpProto::Var& output : op_proto.outputs()) { - std::string output_name = output.name(); - - // Delete dispensable tensor unless specified in op_outs_map - if (output.dispensable()) { - if (!op_outs_map.count(op_name) || - !op_outs_map[op_name].count(output_name)) { - VLOG(6) << "Removing Dispensable Output: " << output_name; - - // grad_ins_grad_slotname_map - auto grad_ins_grad_slotname_map_purified = - *grad_ins_grad_slotname_map; - for (const auto& iter : *grad_ins_grad_slotname_map) { - const std::string& grad_input_name = iter.first; - const std::string& matched_output_name = iter.second; - if (matched_output_name == output_name) { - grad_ins_grad_slotname_map_purified.erase(grad_input_name); - - PADDLE_ENFORCE( - grad_ins->count(grad_input_name) > 0, - paddle::platform::errors::Fatal( - "Unable to find gradient input name in grad_ins.")); - // grad_ins - grad_ins->erase(grad_input_name); - } - } - *grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified; - - // grad_ins_fwd_slotname_map: output as tensorwrapper - if (grad_ins_fwd_slotname_map->count(output_name)) - grad_ins_fwd_slotname_map->erase(output_name); - - // grad_ins: output as tensorwrapper - if (grad_ins->count(output_name)) grad_ins->erase(output_name); - } - } - } - } -} - static void PurifyGradOpProto( const proto::OpProto& op_proto, std::map* grad_outs_slotname_map, @@ -740,22 +520,31 @@ static void PurifyGradOpProto( /* --------- Collect Info --------- */ /* -------------------------------- */ static void CollectForwardInformationFromOpInfo( - const paddle::framework::OpInfo& op_info, ForwardGenerationInfo* fwd_info) { + const paddle::framework::OpInfo& op_info, + std::vector* in_vars, + std::vector* out_vars) { const proto::OpProto& op_proto = *op_info.proto_; - - fwd_info->SetOpType(op_proto.type()); - for (const proto::OpProto::Var& input : op_proto.inputs()) { - fwd_info->GetMutableInVars()->push_back(input); + in_vars->push_back(input); } for (const proto::OpProto::Var& output : op_proto.outputs()) { - fwd_info->GetMutableOutVars()->push_back(output); + out_vars->push_back(output); } } static bool CollectGradInformationFromOpInfo( - const paddle::framework::OpInfo& op_info, - GradNodeGenerationInfo* bwd_info) { + const paddle::framework::OpInfo& op_info, bool* generate_forward_only, + std::vector* grad_op_types, // grad + std::map* grad_outs_slotname_map, // grad + std::map* grad_ins_fwd_slotname_map, // grad + std::map* grad_ins_grad_slotname_map, // grad + std::map>>* + grad_ins, // grad + std::map>>* + grad_outs // grad + ) { const proto::OpProto& op_proto = *op_info.proto_; const std::string& op_type = op_proto.type(); std::vector dims = {1, 1, 1, 1}; @@ -856,7 +645,7 @@ static bool CollectGradInformationFromOpInfo( /* ------ Run GradOpMaker ------ */ if (!op_info.dygraph_grad_op_maker_) { VLOG(6) << op_type << " has no GradOpMaker"; - bwd_info->SetGenerateForwardOnly(true); + *generate_forward_only = true; return false; } @@ -867,31 +656,32 @@ static bool CollectGradInformationFromOpInfo( if (!grad_node) { VLOG(6) << "Got nullptr GradOpNode for " << op_type << " likely registered EmptyGradOpMaker"; - bwd_info->SetGenerateForwardOnly(true); + *generate_forward_only = true; return false; } + /* + if (grad_node->size() > 1) { + // Backward attributes can be super complicated + VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type; + skipped_operators.insert(op_type); + return false; + } + */ + VLOG(6) << "Prepared GradOpNode"; - /* ---- Collect OpBase's op_types ---- */ - bwd_info->SetFwdOpType(op_type); - auto* op_base_infos = bwd_info->GetMutableOpBaseInfos(); - op_base_infos->resize(grad_node->size()); + /* ---- Collect Default Attr Map ---- */ for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { // Each OpBase - int index = std::distance(grad_node->begin(), iter); paddle::imperative::OpBase& op_base = *iter; - (*op_base_infos)[index].SetOpBaseType(op_base.Type()); + grad_op_types->push_back(op_base.Type()); } /* ------ Get Grad ins/outs ---- */ // In case of multiple OpBase, stitch all the respective ins/outs into one VLOG(6) << "In function size: " << grad_node->size(); for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { - int index = std::distance(grad_node->begin(), iter); - auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns(); - auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts(); - const paddle::imperative::OpBase& op_base = *iter; const std::map& g_ins = op_base.GetInsMap(); @@ -899,47 +689,34 @@ static bool CollectGradInformationFromOpInfo( g_outs = op_base.GetOutsMap(); for (const auto& it : g_ins) { - if (!op_base_grad_ins->count(it.first)) - (*op_base_grad_ins)[it.first] = {}; - + if (!grad_ins->count(it.first)) (*grad_ins)[it.first] = {}; for (auto vw_iter = it.second.begin(); vw_iter != it.second.end(); vw_iter++) { std::shared_ptr vw = *vw_iter; - - (*op_base_grad_ins)[it.first].push_back(vw); - - VLOG(6) << "GradIns Name: " << it.first; + (*grad_ins)[it.first].push_back(vw); } } for (const auto& it : g_outs) { - if (!op_base_grad_outs->count(it.first)) - (*op_base_grad_outs)[it.first] = {}; - + if (!grad_outs->count(it.first)) (*grad_outs)[it.first] = {}; for (auto vw_iter = it.second.begin(); vw_iter != it.second.end(); vw_iter++) { std::shared_ptr vw = *vw_iter; - - (*op_base_grad_outs)[it.first].push_back(vw); - - VLOG(6) << "GradOuts Name: " << it.first; + (*grad_outs)[it.first].push_back(vw); } } } /* ------ Slot Name Matching ---- */ - for (auto& iter : *op_base_infos) { - // grad_ins -> fwd_ins, fwd_outs - SlotNameMatching(iter.GetGradIns(), fwd_ins, fwd_outs, - iter.GetMutableGradInsFwdSlotnameMap(), - iter.GetMutableGradInsGradSlotnameMap()); - - // grad_outs -> fwd_ins, fwd_outs - SlotNameMatching(iter.GetGradOuts(), fwd_ins, fwd_outs, - iter.GetMutableGradOutsSlotnameMap(), - iter.GetMutableGradOutsSlotnameMap()); - } - VLOG(6) << "Finished Slotname Matching"; + // grad_ins -> fwd_ins, fwd_outs + SlotNameMatching(*grad_ins, fwd_ins, fwd_outs, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map); + VLOG(6) << "Finished Slotname Matching for Grad_Ins"; + + // grad_outs -> fwd_ins, fwd_outs + SlotNameMatching(*grad_outs, fwd_ins, fwd_outs, grad_outs_slotname_map, + grad_outs_slotname_map); + VLOG(6) << "Finished Slotname Matching for Grad_Outs"; return true; } @@ -948,20 +725,13 @@ static bool CollectGradInformationFromOpInfo( /* --------- CodeGen: Forward GradNode Creation ------ */ /* --------------------------------------------------- */ static std::string GenerateGradNodeCreationContent( - const ForwardGenerationInfo& fwd_info, - const GradNodeGenerationInfo& bwd_info) { + const std::unordered_map& fwd_inputs_name_pos_map, + const std::unordered_map& fwd_outputs_name_pos_map, + const std::map& grad_ins_fwd_slotname_map, + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating GradNode Creation codes"; - const std::string& op_type = fwd_info.GetOpType(); - const std::unordered_map& fwd_inputs_name_pos_map = - fwd_info.GetFwdInputsNamePosMap(); - const std::unordered_map& fwd_outputs_name_pos_map = - fwd_info.GetFwdOutputsNamePosMap(); - const std::vector& in_vars = fwd_info.GetInVars(); - const std::vector& out_vars = fwd_info.GetOutVars(); - - const auto& op_base_infos = bwd_info.GetOpBaseInfos(); - // [Generation] Construct GradOpNode // Run ComputeRequiredGrad @@ -1047,17 +817,12 @@ static std::string GenerateGradNodeCreationContent( // [GradOpNode] Set TensorWrappers grad_node_creation_str += " // Set Tensor Wrappers\n"; - for (const auto& iter : op_base_infos) { - const std::map& grad_ins_fwd_slotname_map = - iter.GetGradInsFwdSlotnameMap(); - for (auto& kv : grad_ins_fwd_slotname_map) { - const std::string& tensor_wrapper_name = kv.second; - const char* SET_TENSOR_WRAPPER_TEMPLATE = - " grad_node->SetTensorWrapper%s(%s);\n"; - grad_node_creation_str += - paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE, - tensor_wrapper_name, tensor_wrapper_name); - } + for (auto& kv : grad_ins_fwd_slotname_map) { + const std::string& tensor_wrapper_name = kv.second; + const char* SET_TENSOR_WRAPPER_TEMPLATE = + " grad_node->SetTensorWrapper%s(%s);\n"; + grad_node_creation_str += paddle::string::Sprintf( + SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, tensor_wrapper_name); } grad_node_creation_str += "\n"; VLOG(6) << "Generated SetTensorWrapper"; @@ -1127,17 +892,22 @@ static std::string GenerateGradNodeCreationContent( /* --------- CodeGen: Forward ----- */ /* -------------------------------- */ static std::pair GenerateForwardFunctionContents( - const ForwardGenerationInfo& fwd_info, - const GradNodeGenerationInfo& bwd_info) { - /* --- Process Forward Info ---*/ - const std::string& op_type = fwd_info.GetOpType(); - const std::unordered_map& fwd_inputs_name_pos_map = - fwd_info.GetFwdInputsNamePosMap(); - const std::unordered_map& fwd_outputs_name_pos_map = - fwd_info.GetFwdOutputsNamePosMap(); - const std::vector& in_vars = fwd_info.GetInVars(); - const std::vector& out_vars = fwd_info.GetOutVars(); - + bool generate_forward_only, + const std::unordered_map& fwd_inputs_name_pos_map, + const std::unordered_map& fwd_outputs_name_pos_map, + const std::map& grad_ins_fwd_slotname_map, + const std::map& grad_ins_grad_slotname_map, + const std::map& grad_outs_slotname_map, + const std::map< + std::string, + std::vector>>& + grad_ins, + const std::map< + std::string, + std::vector>>& + grad_outs, + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { /* // Forward Function Example: std::tuple, Tensor, vector> @@ -1229,53 +999,24 @@ static std::pair GenerateForwardFunctionContents( for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); std::string outnum = "1"; - if (op_passing_outs_map[op_type].count(output_name)) { - const std::string output_var_name = output_name + "Var"; - - // Pass Output from function argument, - // in form of shared_ptr/vector> - if (output.duplicable()) { - const char* FWD_NUM_ARG_TEMPLATE = - ", std::vector>& %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); - dygraph_function_args_str += arg_str; - - const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; - outs_contents_str += paddle::string::Sprintf( - FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); - } else { - const char* FWD_NUM_ARG_TEMPLATE = - ", std::shared_ptr& %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); - dygraph_function_args_str += arg_str; - - const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", {%s} },"; - outs_contents_str += paddle::string::Sprintf( - FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); - } - + if (output.duplicable()) { + outnum = output_name + "Num"; + + const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s"; + std::string arg_str = + paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); + dygraph_function_args_str += arg_str; + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },"; + outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, + output_name, outnum); } else { - if (output.duplicable()) { - outnum = output_name + "Num"; - - const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); - dygraph_function_args_str += arg_str; - const char* FWD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },"; - outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, - output_name, outnum); - } else { - const char* FWD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", " - "{std::make_shared(egr::Controller::Instance()." - "GenerateUniqueName())}},"; - outs_contents_str += - paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name); - } + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += + paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name); } } if (outs_contents_str.size() > 0) @@ -1343,9 +1084,10 @@ static std::pair GenerateForwardFunctionContents( VLOG(6) << "Converted Output VarBase to EagerTensor(s)"; // [Generation] ComputeRequireGrad -> GradNodeCreation - if (!bwd_info.GenerateForwardOnly()) { - std::string grad_node_creation_body_str = - GenerateGradNodeCreationContent(fwd_info, bwd_info); + if (!generate_forward_only) { + std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( + fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, + grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); generated_function_body += grad_node_creation_body_str; generated_function_body += "\n"; VLOG(6) << "Generated GradNode Creation codes"; @@ -1420,16 +1162,22 @@ static std::pair GenerateForwardFunctionContents( /* --------- CodeGen: GradNode::operator() ------ */ /* ---------------------------------------------- */ static std::string GenerateGradNodeCCContents( - const ForwardGenerationInfo& fwd_info, - const GradNodeGenerationInfo& bwd_info) { - /* --- Process Forward Info --- */ - const std::string& fwd_op_type = fwd_info.GetOpType(); - const std::unordered_map& fwd_inputs_name_pos_map = - fwd_info.GetFwdInputsNamePosMap(); - const std::unordered_map& fwd_outputs_name_pos_map = - fwd_info.GetFwdOutputsNamePosMap(); - const std::vector& in_vars = fwd_info.GetInVars(); - + const std::vector& grad_op_types, + const std::unordered_map& fwd_inputs_name_pos_map, + const std::unordered_map& fwd_outputs_name_pos_map, + const std::map& grad_ins_fwd_slotname_map, + const std::map& grad_ins_grad_slotname_map, + const std::map& grad_outs_slotname_map, + const std::map< + std::string, + std::vector>>& + grad_ins, + const std::map< + std::string, + std::vector>>& + grad_outs, + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating Grad Node CC"; /* [Outline] @@ -1476,247 +1224,227 @@ static std::string GenerateGradNodeCCContents( */ std::string generated_grad_function_body = ""; - size_t outs_size = 0; - const auto& op_base_infos = bwd_info.GetOpBaseInfos(); - for (size_t i = 0; i < op_base_infos.size(); i++) { - const auto& op_base_info = op_base_infos[i]; - - const auto& grad_ins_fwd_slotname_map = - op_base_info.GetGradInsFwdSlotnameMap(); - const auto& grad_ins_grad_slotname_map = - op_base_info.GetGradInsGradSlotnameMap(); - const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap(); - const auto& grad_ins = op_base_info.GetGradIns(); - const auto& grad_outs = op_base_info.GetGradOuts(); - - const std::string& op_base_type = op_base_info.GetOpBaseType(); - const std::string& ins_name = "ins" + std::to_string(i); - const std::string& outs_name = "outs" + std::to_string(i); - - outs_size += grad_outs.size(); - - // [Generation] Get Ins Map - std::string ins_contents_str = ""; - for (auto iter : grad_ins) { - const std::string& grad_input_name = iter.first; - - if (grad_ins_fwd_slotname_map.count(grad_input_name)) { - // Fwd Tensor - std::string struct_fwd_input_name = - grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; - const char* GRAD_INS_FWD_CONTENT_TEMPLATE = - "{ \"%s\", " - "egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(" - "&" - "this->%s, " - "nullptr)) },"; - ins_contents_str += - paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, - grad_input_name, struct_fwd_input_name); - - } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { - // Fwd Tensor's Grad - size_t fwd_output_position = fwd_outputs_name_pos_map.at( - grad_ins_grad_slotname_map.at(grad_input_name)); - const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },"; - ins_contents_str += - paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE, - grad_input_name, fwd_output_position); - } else { - PADDLE_THROW(platform::errors::Fatal( - "Detected mismatched slot names." - "Unable to find forward slot name that matches %s", - grad_input_name)); - } - } - if (ins_contents_str.size() > 0) - ins_contents_str.pop_back(); // // Remove trailing "," - - const char* BWD_INS_MAP_TEMPLATE = - " std::map>> %s = { " - "%s };\n"; - std::string ins_map_str = paddle::string::Sprintf( - BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str); - generated_grad_function_body += ins_map_str; - - VLOG(6) << "Generated Ins Map"; - - // [Generation] Get Outs Map - std::unordered_set duplicable_input_name_set; - for (const auto& in : in_vars) { - if (in.duplicable()) duplicable_input_name_set.insert(in.name()); + // [Generation] Get Tracer + generated_grad_function_body += "\n"; + generated_grad_function_body += "\n"; + + // [Generation] Get Ins Map + std::string ins_contents_str = ""; + for (auto iter : grad_ins) { + const std::string& grad_input_name = iter.first; + + if (grad_ins_fwd_slotname_map.count(grad_input_name)) { + // Fwd Tensor + std::string struct_fwd_input_name = + grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; + const char* GRAD_INS_FWD_CONTENT_TEMPLATE = + "{ \"%s\", " + "egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&" + "this->%s, " + "nullptr)) },"; + ins_contents_str += + paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, + grad_input_name, struct_fwd_input_name); + + } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { + // Fwd Tensor's Grad + size_t fwd_output_position = fwd_outputs_name_pos_map.at( + grad_ins_grad_slotname_map.at(grad_input_name)); + const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },"; + ins_contents_str += paddle::string::Sprintf( + GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); + + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_input_name)); } + } + if (ins_contents_str.size() > 0) + ins_contents_str.pop_back(); // // Remove trailing "," - std::string outs_contents_str = ""; - for (auto iter : grad_outs) { - const std::string& grad_output_name = iter.first; - - if (grad_outs_slotname_map.count(grad_output_name)) { - // Fwd Tensor - const std::string& fwd_name = - grad_outs_slotname_map.at(grad_output_name); - - /* Handle Special Case: "PullSparseOp", etc - - Forward: - - Ids W - | | - PullSparseOp - | - Out - - Backward: - - Ids GradOut W - | | | - PullSparseGradOp - | - GradOut - - Its grad output "GradOut" corresponds to forward output "Out", - where there is a hiden inplace involved. So we find "GradOut"'s - index - in - grads, and perform the inplace operation by constructing outs = - {{"Out", grads[i]}} - - GradOut -> Out -> fwd_output_pos -> grads position -> grads[i] - outs = {{"Out", grads[i]}} - - For returns, append "GradOut" to the very end of return list. - */ - if (!fwd_inputs_name_pos_map.count(fwd_name)) { - PADDLE_ENFORCE( - fwd_outputs_name_pos_map.count(fwd_name), - paddle::platform::errors::Fatal( - "fwd_name not found in fwd_inputs_name_pos_map nor " - "fwd_outputs_name_pos_map")); - - size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); - std::string grad_ptr_name = fwd_name + "_ptrs"; - const char* GET_GRADS_PTR_TEMPLATE = - " std::vector> %s;\n" - " for(const auto& t : grads[%d]) {\n " - "%s.emplace_back(std::move(std::make_shared(t))" - ");" - "\n }\n"; - std::string grads_ptr_str = - paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, - grads_position, grad_ptr_name); - generated_grad_function_body += grads_ptr_str; - generated_grad_function_body += "\n"; - - const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; - outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name); + const char* BWD_INS_MAP_TEMPLATE = + " std::map>> ins = { " + "%s };\n"; + std::string ins_map_str = + paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_contents_str); + generated_grad_function_body += ins_map_str; + + VLOG(6) << "Generated Ins Map"; + + // [Generation] Get Outs Map + std::unordered_set duplicable_input_name_set; + for (const auto& in : in_vars) { + if (in.duplicable()) duplicable_input_name_set.insert(in.name()); + } + + std::string outs_contents_str = ""; + for (auto iter : grad_outs) { + const std::string& grad_output_name = iter.first; + + if (grad_outs_slotname_map.count(grad_output_name)) { + // Fwd Tensor + const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name); + + /* Handle Special Case: "PullSparseOp", etc + + Forward: + + Ids W + | | + PullSparseOp + | + Out + + Backward: + + Ids GradOut W + | | | + PullSparseGradOp + | + GradOut + + Its grad output "GradOut" corresponds to forward output "Out", + where there is a hiden inplace involved. So we find "GradOut"'s index + in + grads, and perform the inplace operation by constructing outs = + {{"Out", grads[i]}} + + GradOut -> Out -> fwd_output_pos -> grads position -> grads[i] + outs = {{"Out", grads[i]}} + For returns, append "GradOut" to the very end of return list. + */ + if (!fwd_inputs_name_pos_map.count(fwd_name)) { + PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); + + size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); + std::string grad_ptr_name = fwd_name + "_ptrs"; + const char* GET_GRADS_PTR_TEMPLATE = + " std::vector> %s;\n" + " for(const auto& t : grads[%d]) {\n " + "%s.emplace_back(std::move(std::make_shared(t)));" + "\n }\n"; + std::string grads_ptr_str = + paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, + grads_position, grad_ptr_name); + generated_grad_function_body += grads_ptr_str; + generated_grad_function_body += "\n"; + + const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name); + + } else { + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + if (duplicable_input_name_set.count(fwd_name)) { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " + "this->OutputMeta()[%d].Size() ) },"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); } else { - size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); - if (duplicable_input_name_set.count(fwd_name)) { - const char* GRAD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " - "this->OutputMeta()[%d].Size() ) },"; - outs_contents_str += - paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE, - grad_output_name, fwd_input_position); - } else { - const char* GRAD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", " - "{std::make_shared(egr::Controller::Instance(" - ")." - "GenerateUniqueName())}},"; - outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); - } + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); } - } else { - PADDLE_THROW(platform::errors::Fatal( - "Detected mismatched slot names." - "Unable to find forward slot name that matches %s", - grad_output_name)); } + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_output_name)); } - if (outs_contents_str.size() > 0) - outs_contents_str.pop_back(); // // Remove trailing "," + } + if (outs_contents_str.size() > 0) + outs_contents_str.pop_back(); // // Remove trailing "," - const char* BWD_OUTS_MAP_TEMPLATE = - " std::map>> %s = { " - "%s };\n"; - std::string outs_map_str = paddle::string::Sprintf( - BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str); - generated_grad_function_body += outs_map_str; - generated_grad_function_body += "\n"; + const char* BWD_OUTS_MAP_TEMPLATE = + " std::map>> outs = { " + "%s };\n"; + std::string outs_map_str = + paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str); + generated_grad_function_body += outs_map_str; + generated_grad_function_body += "\n"; + + VLOG(6) << "Generated Outs Map"; - VLOG(6) << "Generated Outs Map"; + // [Generation] Get Attrs Map + std::string trace_opbase_str = ""; + for (size_t i = 0; i < grad_op_types.size(); i++) { + const std::string& op_base_type = grad_op_types[i]; - // [Generation] Get Attrs Map const char* TRACE_OP_TEMPLATE = " // Pass the entire attribute map to TraceOp\n" " // The underlying kernel will pickup whatever attribute they need " "at runtime\n" - " egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n" + " egr::legacy::RunOp(\"%s\", ins, outs, this->attr_map_,\n" " egr::Controller::Instance().GetExpectedPlace(),\n" " &this->default_attr_map_, false, {});\n"; - std::string trace_opbase_str = paddle::string::Sprintf( - TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name); + trace_opbase_str = paddle::string::Sprintf(TRACE_OP_TEMPLATE, op_base_type); + } - generated_grad_function_body += trace_opbase_str; + generated_grad_function_body += trace_opbase_str; - VLOG(6) << "Generated Attrs Map"; + VLOG(6) << "Generated Attrs Map"; - // [Generation] Get Return - std::string outputs_str = ""; - size_t num_appended_outputs = 0; - for (auto iter : grad_outs) { - const std::string& grad_out_name = iter.first; - const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); + // [Generation] Get Return + std::string outputs_str = ""; + size_t num_appended_outputs = 0; + for (auto iter : grad_outs) { + const std::string& grad_out_name = iter.first; + const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); - if (fwd_inputs_name_pos_map.count(fwd_name)) { - size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); - const char* BWD_OUTPUT_TEMPLATE = - " outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; - outputs_str += paddle::string::Sprintf( - BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name); - num_appended_outputs++; - } else { - PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), - paddle::platform::errors::Fatal( - "fwd_name not found in fwd_inputs_name_pos_map nor " - "fwd_outputs_name_pos_map")); - } + if (fwd_inputs_name_pos_map.count(fwd_name)) { + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; + outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, + fwd_input_position, grad_out_name); + num_appended_outputs++; + } else { + PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); } + } - /* Handle Special Case: "PullSparseOp", etc - For returns, append "GradOut" to the very end of return list. */ - for (auto iter : grad_outs) { - const std::string& grad_out_name = iter.first; - const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); - - if (fwd_outputs_name_pos_map.count(fwd_name)) { - const char* BWD_OUTPUT_TEMPLATE = - " outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; - outputs_str += - paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs, - outs_name, grad_out_name); - num_appended_outputs++; - } - } + /* Handle Special Case: "PullSparseOp", etc + For returns, append "GradOut" to the very end of return list. */ + for (auto iter : grad_outs) { + const std::string& grad_out_name = iter.first; + const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); - generated_grad_function_body += outputs_str; - generated_grad_function_body += "\n"; + if (fwd_outputs_name_pos_map.count(fwd_name)) { + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; + outputs_str += paddle::string::Sprintf( + BWD_OUTPUT_TEMPLATE, num_appended_outputs, grad_out_name); + num_appended_outputs++; + } } const char* BWD_RETURN_TEMPLATE = - " std::vector> outputs(%d);\n" - " %s\n" - " return outputs;\n"; - generated_grad_function_body = paddle::string::Sprintf( - BWD_RETURN_TEMPLATE, outs_size, generated_grad_function_body); + " std::vector> " + "outputs(outs.size());\n%s\n " + "return outputs;"; + std::string return_str = + paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str); + + generated_grad_function_body += "\n"; + generated_grad_function_body += return_str; // [Generation] Get Full Grad Function const char* GRAD_FUNCTION_TEMPLATE = @@ -1724,7 +1452,7 @@ static std::string GenerateGradNodeCCContents( "GradNode%s::operator()(const " "std::vector>& grads) {\n%s\n}"; std::string grad_function_str = paddle::string::Sprintf( - GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body); + GRAD_FUNCTION_TEMPLATE, op_type, generated_grad_function_body); VLOG(6) << "Generated returns"; @@ -1735,14 +1463,9 @@ static std::string GenerateGradNodeCCContents( /* --------- CodeGen: GradNode Header ------ */ /* ----------------------------------------- */ static std::string GenerateGradNodeHeaderContents( - const ForwardGenerationInfo& fwd_info, - const GradNodeGenerationInfo& bwd_info) { - const std::string& op_type = fwd_info.GetOpType(); - const std::vector& in_vars = fwd_info.GetInVars(); - const std::vector& out_vars = fwd_info.GetOutVars(); - - const auto& op_base_infos = bwd_info.GetOpBaseInfos(); - + const std::map& grad_ins_fwd_slotname_map, + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating Grad Node Header"; const char* GRAD_NODE_TEMPLATE = @@ -1799,60 +1522,55 @@ static std::string GenerateGradNodeHeaderContents( std::string set_tensor_wrappers_str = ""; std::string tensor_wrapper_members_str = ""; - for (const auto& iter : op_base_infos) { - const std::map& grad_ins_fwd_slotname_map = - iter.GetGradInsFwdSlotnameMap(); - - for (const auto& kv : grad_ins_fwd_slotname_map) { - const std::string& tensor_wrapper_name = kv.second; - const std::string& struct_tensor_wrapper_name = kv.second + "_"; - - std::string tensor_wrapper_arg_str; - std::string tensor_wrapper_body_str; - if (duplicable_tensors.count(tensor_wrapper_name)) { - const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = - "const std::vector& %s"; - tensor_wrapper_arg_str = paddle::string::Sprintf( - ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); - - const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = - " std::vector %s;\n"; - tensor_wrapper_members_str += paddle::string::Sprintf( - TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); - - const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = - "for(const auto& eager_tensor : %s) {\n" - " %s.emplace_back( egr::TensorWrapper(eager_tensor, true " - "/*full_reserved*/) );\n" - " }\n"; - tensor_wrapper_body_str = paddle::string::Sprintf( - SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, - struct_tensor_wrapper_name); - - } else { - const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = - "const egr::EagerTensor& %s"; - tensor_wrapper_arg_str = paddle::string::Sprintf( - ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); - - const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = - " egr::TensorWrapper %s;\n"; - tensor_wrapper_members_str += paddle::string::Sprintf( - TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); - - const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = - "%s = egr::TensorWrapper(%s, true /*full_reserved*/);"; - tensor_wrapper_body_str = paddle::string::Sprintf( - SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, - tensor_wrapper_name); - } + for (const auto& kv : grad_ins_fwd_slotname_map) { + const std::string& tensor_wrapper_name = kv.second; + const std::string& struct_tensor_wrapper_name = kv.second + "_"; + + std::string tensor_wrapper_arg_str; + std::string tensor_wrapper_body_str; + if (duplicable_tensors.count(tensor_wrapper_name)) { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const std::vector& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " std::vector %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "for(const auto& eager_tensor : %s) {\n" + " %s.emplace_back( egr::TensorWrapper(eager_tensor, true " + "/*full_reserved*/) );\n" + " }\n"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, + struct_tensor_wrapper_name); - const char* SET_TENSOR_WRAPPER_TEMPLATE = - " void SetTensorWrapper%s(%s) {\n %s\n }\n"; - set_tensor_wrappers_str += paddle::string::Sprintf( - SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, - tensor_wrapper_arg_str, tensor_wrapper_body_str); - } + } else { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const egr::EagerTensor& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " egr::TensorWrapper %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "%s = egr::TensorWrapper(%s, true /*full_reserved*/);"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, + tensor_wrapper_name); + } + + const char* SET_TENSOR_WRAPPER_TEMPLATE = + " void SetTensorWrapper%s(%s) {\n %s\n }\n"; + set_tensor_wrappers_str += paddle::string::Sprintf( + SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, + tensor_wrapper_arg_str, tensor_wrapper_body_str); } VLOG(6) << "Generated TensorWrapper"; @@ -1964,62 +1682,97 @@ static void DygraphCodeGeneration(const std::string& output_dir) { /* ----------------------------- */ /* ---- Collect Information ---- */ /* ----------------------------- */ - - ForwardGenerationInfo fwd_info; - GradNodeGenerationInfo bwd_info; + std::vector grad_op_types; + std::vector in_vars; + std::vector out_vars; + std::map grad_outs_slotname_map; + std::map grad_ins_fwd_slotname_map; + std::map grad_ins_grad_slotname_map; + std::map>> + grad_ins; + std::map>> + grad_outs; VLOG(6) << "-------- CollectInformationFromOpInfo -------"; - CollectForwardInformationFromOpInfo(op_info, &fwd_info); + CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars); - bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info); + bool generate_forward_only = false; + bool is_available = CollectGradInformationFromOpInfo( + op_info, &generate_forward_only, &grad_op_types, + &grad_outs_slotname_map, &grad_ins_fwd_slotname_map, + &grad_ins_grad_slotname_map, &grad_ins, &grad_outs); - if (!is_available && !bwd_info.GenerateForwardOnly()) { + if (!is_available && !generate_forward_only) { VLOG(6) << "Skipped operator: " << op_type; continue; } VLOG(6) << "-------- PurifyOpProto -------"; - PurifyForwardOpProto(*op_proto, &fwd_info); - if (!bwd_info.GenerateForwardOnly()) { - PurifyGradNodeGenerationInfo(*op_proto, &bwd_info); + std::unordered_map fwd_inputs_name_pos_map; + std::unordered_map fwd_outputs_name_pos_map; + PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map, + &fwd_outputs_name_pos_map, &in_vars, &out_vars); + + if (!generate_forward_only) { + PurifyGradOpProto(*op_proto, &grad_outs_slotname_map, + &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, + &grad_ins, &grad_outs); } /* --------------------------- */ /* --------- CodeGen --------- */ /* --------------------------- */ + /* ---- forward_dygraph_functions.cc ---- */ VLOG(6) << "-------- GenerateForwardFunctionContents -------"; std::pair body_and_declaration = - GenerateForwardFunctionContents(fwd_info, bwd_info); + GenerateForwardFunctionContents( + generate_forward_only, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, + grad_outs, op_type, in_vars, out_vars); fwd_function_str += body_and_declaration.first + "\n"; - VLOG(6) << "-------- GenerateDygraphForwardAPIContents -------"; + /* ---- dygraph_forward_api.h ---- */ std::string fwd_function_declare_str = body_and_declaration.second; dygraph_forward_api_str += fwd_function_declare_str; - if (bwd_info.GenerateForwardOnly()) continue; + if (generate_forward_only) continue; + /* ---- nodes.h ---- */ VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; - grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info); - grad_node_h_str += "\n"; + grad_node_h_str += + GenerateGradNodeHeaderContents(grad_ins_fwd_slotname_map, op_type, + in_vars, out_vars) + + "\n"; + /* ---- nodes.cc ---- */ VLOG(6) << "-------- GenerateGradNodeCCContents -------"; - grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info); - grad_node_cc_str += "\n"; + grad_node_cc_str += GenerateGradNodeCCContents( + grad_op_types, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map, grad_outs_slotname_map, + grad_ins, grad_outs, op_type, in_vars, out_vars) + + "\n"; VLOG(6) << op_type << ": Finished Generating Op: " << op_type; } - + /* ---- dygraph_forward_function.cc ---- */ VLOG(6) << "-------- GenerateDygraphForwardCCFile -------"; GenerateForwardDygraphFile(output_dir, fwd_function_str); + /* ---- dygraph_forward_api.h ---- */ VLOG(6) << "-------- GenerateForwardHFile -------"; GenerateForwardHFile(output_dir, dygraph_forward_api_str); + /* ---- nodes.h ---- */ VLOG(6) << "-------- GenerateNodeHFile -------"; GenerateNodeHFile(output_dir, grad_node_h_str); + /* ---- nodes.cc ---- */ VLOG(6) << "-------- GenerateNodeCCFile -------"; GenerateNodeCCFile(output_dir, grad_node_cc_str); } diff --git a/paddle/fluid/eager/auto_code_generator/op_list.txt b/paddle/fluid/eager/auto_code_generator/op_list.txt index d3e835a1d0355..699a84169d700 100644 --- a/paddle/fluid/eager/auto_code_generator/op_list.txt +++ b/paddle/fluid/eager/auto_code_generator/op_list.txt @@ -237,7 +237,6 @@ spp floor gelu retinanet_detection_output -minus push_dense silu sequence_erase