diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 0d66d8d96a9b4..b3657a9894f82 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include @@ -27,69 +26,21 @@ #include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/string/string_helper.h" -DEFINE_bool(generate_all, false, - "Generate all operators currently registered in Paddle"); +namespace paddle { +namespace framework { static std::unordered_map operators_with_attrs = {}; static std::unordered_set operators_to_skip = { - "pull_sparse", "pull_box_extended_sparse", "pull_sparse_v2", - "pull_box_sparse", "fused_attention", "diag_v2", - "c_split"}; + "chunk_eval", // Stupid tensor name + "minus", "pull_sparse", "pull_box_extended_sparse", + "pull_sparse_v2", "pull_box_sparse", "fused_attention", + "diag_v2", "c_split"}; static std::unordered_set operators_to_codegen = {}; static std::unordered_set skipped_operators = {}; -static void PrepareAttrMapForOps() { - // Handle "fused_elemwise_add_activation" - std::vector functor_list = {"a", "b"}; - operators_with_attrs["fused_elemwise_add_activation"] = {}; - operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] = - functor_list; - - // Handle "fused_elemwise_activation" - operators_with_attrs["fused_elemwise_activation"] = {}; - operators_with_attrs["fused_elemwise_activation"]["functor_list"] = - functor_list; - - // Handle "reverse" - std::vector axis = {0}; - operators_with_attrs["reverse"] = {}; - operators_with_attrs["reverse"]["axis"] = axis; - - // Handle "flip" - operators_with_attrs["flip"] = {}; - operators_with_attrs["flip"]["axis"] = axis; - - // Handle "cast" - operators_with_attrs["cast"] = {}; - operators_with_attrs["cast"]["out_dtype"] = 5; - operators_with_attrs["cast"]["in_dtype"] = 5; - - // Handle "transfer_dtype" - operators_with_attrs["transfer_dtype"] = {}; - operators_with_attrs["transfer_dtype"]["out_dtype"] = 5; - operators_with_attrs["transfer_dtype"]["in_dtype"] = 5; -} - -static void CollectOperatorsToCodeGen(const std::string& op_list_path) { - std::string line; - std::ifstream op_list_file(op_list_path); - if (op_list_file.is_open()) { - while (getline(op_list_file, line)) { - operators_to_codegen.insert(line); - } - op_list_file.close(); - } else { - PADDLE_THROW( - paddle::platform::errors::Fatal("Unable to open op_list.txt file")); - } -} - -namespace paddle { -namespace framework { - static std::string AttrTypeToString(const proto::AttrType& type) { std::string ret; switch (type) { @@ -392,10 +343,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) { // Only handle matmul_v2 for now VLOG(1) << "------ Analyzing Op ------: " << op_type; - if (!FLAGS_generate_all) { - if (!operators_to_codegen.count(op_type)) return false; - } - + if (!operators_to_codegen.count(op_type)) return false; if (operators_to_skip.count(op_type)) return false; return true; @@ -404,21 +352,12 @@ static bool CheckOpProto(proto::OpProto* op_proto) { /* --------------------------------------- */ /* --------- Preprocess Ins/Outs --------- */ /* --------------------------------------- */ -static void PurifyOpProto( +static void PurifyForwardOpProto( const proto::OpProto& op_proto, std::unordered_map* fwd_inputs_name_pos_map, std::unordered_map* fwd_outputs_name_pos_map, - std::map* grad_outs_slotname_map, - std::map* grad_ins_fwd_slotname_map, - std::map* grad_ins_grad_slotname_map, std::vector* in_vars, - std::vector* out_vars, - std::map>>* - grad_ins, - std::map>>* - grad_outs) { + std::vector* out_vars) { // Op Name const std::string op_name = op_proto.type(); @@ -440,6 +379,72 @@ static void PurifyOpProto( } } in_vars->erase(iter); + } + } + } + + for (const proto::OpProto::Var& output : op_proto.outputs()) { + std::string output_name = output.name(); + + // Delete dispensable tensor unless specified in op_outs_map + if (output.dispensable()) { + if (!op_outs_map.count(op_name) || + !op_outs_map[op_name].count(output_name)) { + VLOG(6) << "Removing Dispensable Output: " << output_name; + + // out_vars + auto iter = out_vars->begin(); + for (iter = out_vars->begin(); iter != out_vars->end(); iter++) { + if (iter->name() == output_name) { + break; + } + } + out_vars->erase(iter); + } + } + } + + /* ------ Maping forward slot name to fwd position ------ */ + size_t in_pos = 0; + for (const auto& var : *in_vars) { + VLOG(6) << "Mapping input tensor: " << var.name() + << " To position: " << in_pos; + (*fwd_inputs_name_pos_map)[var.name()] = in_pos; + in_pos++; + } + + size_t out_pos = 0; + for (const auto& var : *out_vars) { + VLOG(6) << "Mapping output tensor: " << var.name() + << " To position: " << out_pos; + (*fwd_outputs_name_pos_map)[var.name()] = out_pos; + out_pos++; + } +} + +static void PurifyGradOpProto( + const proto::OpProto& op_proto, + std::map* grad_outs_slotname_map, + std::map* grad_ins_fwd_slotname_map, + std::map* grad_ins_grad_slotname_map, + std::map>>* + grad_ins, + std::map>>* + grad_outs) { + // Op Name + const std::string op_name = op_proto.type(); + + // Handle dispensable inputs + for (const proto::OpProto::Var& input : op_proto.inputs()) { + std::string input_name = input.name(); + + // Delete dispensable tensor unless specified in op_ins_map + if (input.dispensable()) { + if (!op_ins_map.count(op_name) || + !op_ins_map[op_name].count(input_name)) { + VLOG(6) << "Removing Dispensable Input: " << input_name; // grad_outs_slotname_map auto grad_outs_slotname_map_purified = *grad_outs_slotname_map; @@ -478,15 +483,6 @@ static void PurifyOpProto( !op_outs_map[op_name].count(output_name)) { VLOG(6) << "Removing Dispensable Output: " << output_name; - // out_vars - auto iter = out_vars->begin(); - for (iter = out_vars->begin(); iter != out_vars->end(); iter++) { - if (iter->name() == output_name) { - break; - } - } - out_vars->erase(iter); - // grad_ins_grad_slotname_map auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map; for (const auto& iter : *grad_ins_grad_slotname_map) { @@ -514,52 +510,40 @@ static void PurifyOpProto( } } } - - /* ------ Maping forward slot name to fwd position ------ */ - size_t in_pos = 0; - for (const auto& var : *in_vars) { - VLOG(6) << "Mapping input tensor: " << var.name() - << " To position: " << in_pos; - (*fwd_inputs_name_pos_map)[var.name()] = in_pos; - in_pos++; - } - - size_t out_pos = 0; - for (const auto& var : *out_vars) { - VLOG(6) << "Mapping output tensor: " << var.name() - << " To position: " << out_pos; - (*fwd_outputs_name_pos_map)[var.name()] = out_pos; - out_pos++; - } } /* -------------------------------- */ /* --------- Collect Info --------- */ /* -------------------------------- */ -static bool CollectInformationFromOpInfo( +static void CollectForwardInformationFromOpInfo( const paddle::framework::OpInfo& op_info, - std::vector* grad_op_types, - std::map* grad_outs_slotname_map, - std::map* grad_ins_fwd_slotname_map, - std::map* grad_ins_grad_slotname_map, std::vector* in_vars, - std::vector* out_vars, - std::map>>* - grad_ins, - std::map>>* - grad_outs) { + std::vector* out_vars) { const proto::OpProto& op_proto = *op_info.proto_; - const std::string& op_type = op_proto.type(); - std::vector dims = {1, 1, 1, 1}; - for (const proto::OpProto::Var& input : op_proto.inputs()) { in_vars->push_back(input); } for (const proto::OpProto::Var& output : op_proto.outputs()) { out_vars->push_back(output); } +} + +static bool CollectGradInformationFromOpInfo( + const paddle::framework::OpInfo& op_info, bool* generate_forward_only, + std::vector* grad_op_types, // grad + std::map* grad_outs_slotname_map, // grad + std::map* grad_ins_fwd_slotname_map, // grad + std::map* grad_ins_grad_slotname_map, // grad + std::map>>* + grad_ins, // grad + std::map>>* + grad_outs // grad + ) { + const proto::OpProto& op_proto = *op_info.proto_; + const std::string& op_type = op_proto.type(); + std::vector dims = {1, 1, 1, 1}; /* ------ Prepare "ins" ------ */ std::mapsize() > 1) { // Backward attributes can be super complicated VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type; skipped_operators.insert(op_type); return false; } + */ VLOG(6) << "Prepared GradOpNode"; @@ -901,6 +885,7 @@ static std::string GenerateGradNodeCreationContent( /* --------- CodeGen: Forward ----- */ /* -------------------------------- */ static std::pair GenerateForwardFunctionContents( + bool generate_forward_only, const std::unordered_map& fwd_inputs_name_pos_map, const std::unordered_map& fwd_outputs_name_pos_map, const std::map& grad_ins_fwd_slotname_map, @@ -1044,7 +1029,6 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Attrs dygraph_function_args_str += ", const paddle::framework::AttributeMap& attr_map"; - generated_function_body += "\n"; // [Generation] Get TraceOp const char* FWD_TRACE_OP_TEMPLATE = @@ -1092,16 +1076,18 @@ static std::pair GenerateForwardFunctionContents( VLOG(6) << "Converted Output VarBase to EagerTensor(s)"; // [Generation] ComputeRequireGrad -> GradNodeCreation - std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( - fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, - grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); - generated_function_body += grad_node_creation_body_str; - generated_function_body += "\n"; - VLOG(6) << "Generated GradNode Creation codes"; + if (!generate_forward_only) { + std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( + fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, + grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); + generated_function_body += grad_node_creation_body_str; + generated_function_body += "\n"; + VLOG(6) << "Generated GradNode Creation codes"; + } // [Generation] Handle return: Tuple/Vector/Tensor generated_function_body += "\n"; - std::string return_str; + std::string return_str = ""; std::string return_type_str = ""; std::string function_proto_return_type_str = ""; if (return_contents.size() > 1) { @@ -1124,14 +1110,20 @@ static std::pair GenerateForwardFunctionContents( const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>"; function_proto_return_type_str = paddle::string::Sprintf( FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str); - } else { + + } else if (return_contents.size() == 1) { // Return vector or Tensor return_type_str = return_types[0]; const char* FWD_TENSOR_RETURN_TEMPLATE = " return %s;"; return_str = paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]); function_proto_return_type_str = return_type_str; + + } else { + return_str = "return nullptr;"; + function_proto_return_type_str = "void*"; } + generated_function_body += return_str; generated_function_body += "\n"; VLOG(6) << "Generated return codes"; @@ -1139,6 +1131,11 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Full Function std::string function_name = op_type + "_dygraph_function"; + if (dygraph_function_args_str.size() > 0) { + auto iter = dygraph_function_args_str.begin(); + if ((*iter) == ',') dygraph_function_args_str.erase(iter); + } + const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n"; std::string fwd_function_str = paddle::string::Sprintf( FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name, @@ -1601,11 +1598,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) { /* ---- Collect Information ---- */ /* ----------------------------- */ std::vector grad_op_types; + std::vector in_vars; + std::vector out_vars; std::map grad_outs_slotname_map; std::map grad_ins_fwd_slotname_map; std::map grad_ins_grad_slotname_map; - std::vector in_vars; - std::vector out_vars; std::map>> grad_ins; @@ -1614,20 +1611,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) { grad_outs; VLOG(6) << "-------- CollectInformationFromOpInfo -------"; - bool is_available = CollectInformationFromOpInfo( - op_info, &grad_op_types, &grad_outs_slotname_map, - &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, &in_vars, - &out_vars, &grad_ins, &grad_outs); - if (!is_available) continue; + CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars); + + bool generate_forward_only = false; + bool is_available = CollectGradInformationFromOpInfo( + op_info, &generate_forward_only, &grad_op_types, + &grad_outs_slotname_map, &grad_ins_fwd_slotname_map, + &grad_ins_grad_slotname_map, &grad_ins, &grad_outs); + + if (!is_available && !generate_forward_only) { + VLOG(6) << "Skipped operator: " << op_type; + continue; + } VLOG(6) << "-------- PurifyOpProto -------"; std::unordered_map fwd_inputs_name_pos_map; std::unordered_map fwd_outputs_name_pos_map; - PurifyOpProto(*op_proto, &fwd_inputs_name_pos_map, - &fwd_outputs_name_pos_map, &grad_outs_slotname_map, - &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, - &in_vars, &out_vars, &grad_ins, &grad_outs); + PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map, + &fwd_outputs_name_pos_map, &in_vars, &out_vars); + + if (!generate_forward_only) { + PurifyGradOpProto(*op_proto, &grad_outs_slotname_map, + &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, + &grad_ins, &grad_outs); + } /* --------------------------- */ /* --------- CodeGen --------- */ @@ -1636,16 +1644,19 @@ static void DygraphCodeGeneration(const std::string& output_dir) { VLOG(6) << "-------- GenerateForwardFunctionContents -------"; std::pair body_and_declaration = GenerateForwardFunctionContents( - fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, - grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, - grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, - out_vars); + generate_forward_only, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, + grad_outs, op_type, in_vars, out_vars); + fwd_function_str += body_and_declaration.first + "\n"; /* ---- dygraph_forward_api.h ---- */ std::string fwd_function_declare_str = body_and_declaration.second; dygraph_forward_api_str += fwd_function_declare_str; + if (generate_forward_only) continue; + /* ---- nodes.h ---- */ VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; grad_node_h_str += @@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) { GenerateNodeCCFile(output_dir, grad_node_cc_str); } +static void PrepareAttrMapForOps() { + // Handle "fused_elemwise_add_activation" + std::vector functor_list = {"a", "b"}; + operators_with_attrs["fused_elemwise_add_activation"] = {}; + operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] = + functor_list; + + // Handle "fused_elemwise_activation" + operators_with_attrs["fused_elemwise_activation"] = {}; + operators_with_attrs["fused_elemwise_activation"]["functor_list"] = + functor_list; + + // Handle "reverse" + std::vector axis = {0}; + operators_with_attrs["reverse"] = {}; + operators_with_attrs["reverse"]["axis"] = axis; + + // Handle "flip" + operators_with_attrs["flip"] = {}; + operators_with_attrs["flip"]["axis"] = axis; + + // Handle "cast" + operators_with_attrs["cast"] = {}; + operators_with_attrs["cast"]["out_dtype"] = 5; + operators_with_attrs["cast"]["in_dtype"] = 5; + + // Handle "transfer_dtype" + operators_with_attrs["transfer_dtype"] = {}; + operators_with_attrs["transfer_dtype"]["out_dtype"] = 5; + operators_with_attrs["transfer_dtype"]["in_dtype"] = 5; +} + +static void CollectOperatorsToCodeGen(const std::string& op_list_path) { + std::string line; + std::ifstream op_list_file(op_list_path); + if (op_list_file.is_open()) { + while (getline(op_list_file, line)) { + operators_to_codegen.insert(line); + } + op_list_file.close(); + } else { + PADDLE_THROW( + paddle::platform::errors::Fatal("Unable to open op_list.txt file")); + } +} + } // namespace framework } // namespace paddle @@ -1693,8 +1750,8 @@ int main(int argc, char* argv[]) { std::string eager_root = argv[1]; std::string op_list_path = argv[2]; - CollectOperatorsToCodeGen(op_list_path); - PrepareAttrMapForOps(); + paddle::framework::CollectOperatorsToCodeGen(op_list_path); + paddle::framework::PrepareAttrMapForOps(); paddle::framework::DygraphCodeGeneration(eager_root); diff --git a/paddle/fluid/eager/auto_code_generator/op_list.txt b/paddle/fluid/eager/auto_code_generator/op_list.txt index 6bfba753633f3..2456a7a1846d1 100644 --- a/paddle/fluid/eager/auto_code_generator/op_list.txt +++ b/paddle/fluid/eager/auto_code_generator/op_list.txt @@ -215,7 +215,6 @@ spp floor gelu retinanet_detection_output -minus push_dense silu sequence_erase