From d5fae112355e0607f6146d3d4a24f857110106ba Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Sun, 29 Oct 2023 10:34:54 +0000 Subject: [PATCH 1/2] forward decomp interface code gen --- .../decomp_interface_gen_op_list.py | 3 + .../fluid/pir/dialect/op_generator/op_gen.py | 106 +++--- .../pir/dialect/operator/ir/CMakeLists.txt | 4 +- .../dialect/operator/ir/manual_op_decomp.cc | 32 -- paddle/fluid/primitive/codegen/CMakeLists.txt | 18 + paddle/fluid/primitive/codegen/decomp_gen.py | 335 ++++++++++++++++++ .../templates/decomp/generated_decomp.j2 | 113 ++++++ 7 files changed, 530 insertions(+), 81 deletions(-) create mode 100644 paddle/fluid/primitive/codegen/decomp_gen.py create mode 100644 paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 2c559330eec99..334d410e7dab3 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -17,6 +17,9 @@ # ===================================== +# come into effect in generated file pd_op.h decomp_interface_declare_gen_op_list = ['mean'] +# come into effect in generated file op_decomp.cc +# manual decomp interface implementation are located in manual_op_decomp.cc decomp_interface_implementation_gen_op_list = ["mean"] diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 185a874615d39..70ab49e17a509 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -206,6 +206,55 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ } +attr_types_map = { + 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], + 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], + 'Scalar(int)': ['pir::Int32Attribute', 'int'], + 'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'], + 'Scalar(float)': ['pir::FloatAttribute', 'float'], + 'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'], + 'Scalar[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'int': ['pir::Int32Attribute', 'int'], + 'int32_t': ['pir::Int32Attribute', 'int32_t'], + 'int64_t': ['pir::Int64Attribute', 'int64_t'], + 'long': ['pir::LongAttribute', 'long'], + 'size_t': ['pir::Size_tAttribute', 'size_t'], + 'float': ['pir::FloatAttribute', 'float'], + 'float[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'double': ['pir::DoubleAttribute', 'double'], + 'bool': ['pir::BoolAttribute', 'bool'], + 'bool[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'str': ['pir::StrAttribute', 'const std::string&'], + 'str[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], + 'DataLayout': [ + 'paddle::dialect::DataLayoutAttribute', + 'DataLayout', + ], + 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], + 'int64_t[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], + 'int[]': [ + 'pir::ArrayAttribute', + 'const std::vector&', + ], +} + + def to_phi_and_fluid_op_name(op_item): # Templat: - op : phi_name (fluid_name) names = op_item.split('(') @@ -287,53 +336,7 @@ def __init__(self, op_yaml_item, op_compat_item): ) # parse attributes - self.attr_types_map = { - 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], - 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], - 'Scalar(int)': ['pir::Int32Attribute', 'int'], - 'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'], - 'Scalar(float)': ['pir::FloatAttribute', 'float'], - 'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'], - 'Scalar[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'int': ['pir::Int32Attribute', 'int'], - 'int32_t': ['pir::Int32Attribute', 'int32_t'], - 'int64_t': ['pir::Int64Attribute', 'int64_t'], - 'long': ['pir::LongAttribute', 'long'], - 'size_t': ['pir::Size_tAttribute', 'size_t'], - 'float': ['pir::FloatAttribute', 'float'], - 'float[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'double': ['pir::DoubleAttribute', 'double'], - 'bool': ['pir::BoolAttribute', 'bool'], - 'bool[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'str': ['pir::StrAttribute', 'const std::string&'], - 'str[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], - 'DataLayout': [ - 'paddle::dialect::DataLayoutAttribute', - 'DataLayout', - ], - 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], - 'int64_t[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - 'int[]': [ - 'pir::ArrayAttribute', - 'const std::vector&', - ], - } + self.attr_types_map = attr_types_map self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() self.attribute_build_arg_type_list = ( @@ -1051,12 +1054,19 @@ def OpGenerator( mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic( op_info, op_info_items ) + op_interfaces_tmp = op_interfaces + exclusive_interface_str_tmp = exclusive_interface_str # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: if op_name in decomp_interface_declare_gen_op_list: - op_interfaces += ["paddle::dialect::DecompInterface"] + op_interfaces = op_interfaces + [ + "paddle::dialect::DecompInterface" + ] exclusive_interface_str += "\n static std::vector> Decomp(pir::Operation* op);" + else: + op_interfaces = op_interfaces_tmp + exclusive_interface_str = exclusive_interface_str_tmp if op_name in PD_MANUAL_OP_LIST: continue if op_kernel_map is None: diff --git a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt index 26343fa724968..fac2c40dc1ce4 100644 --- a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt @@ -50,6 +50,7 @@ set(op_header_file_tmp ${op_header_file}.tmp) set(op_source_file_tmp ${op_source_file}.tmp) set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc) +set(op_decomp_source_file ${PD_DIALECT_BINARY_DIR}/op_decomp.cc) set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp) execute_process( @@ -202,6 +203,7 @@ target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR}) cc_library( pd_op_dialect - SRCS op_dialect.cc manual_op_decomp.cc manual_op_vjp.cc ${op_vjp_source_file} + SRCS op_dialect.cc manual_op_decomp.cc ${op_decomp_source_file} + manual_op_vjp.cc ${op_vjp_source_file} DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental pd_op_dialect_utils op_yaml_info_parser) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc index e6c84ca293477..43ffd4657c1fb 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc @@ -29,37 +29,5 @@ namespace paddle { namespace dialect { using IntArray = paddle::experimental::IntArray; -std::vector> MeanOp::Decomp(pir::Operation* op) { - MeanOp op_obj = op->dyn_cast(); - (void)op_obj; - - VLOG(4) << "Decomp Prepare inputs of mean"; - - Tensor x(std::make_shared(op_obj.x())); - - VLOG(4) << "Decomp prepare attributes of mean"; - - IntArray axis = op->attribute("axis") - .dyn_cast() - .data(); - - bool keepdim = op->attribute("keepdim").dyn_cast().data(); - VLOG(4) << "Decomp mean keep_dim " << keepdim; - - VLOG(4) << "Decomp prepare call mean's decomp interface"; - - Tensor op_res = - paddle::primitive::details::mean_decomp( - x, axis, keepdim); - - auto org_res = op->results(); - std::vector> res(org_res.size()); - res[0].push_back( - std::static_pointer_cast(op_res.impl()) - ->value() - .dyn_cast()); - return res; -} - } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/primitive/codegen/CMakeLists.txt b/paddle/fluid/primitive/codegen/CMakeLists.txt index e081da5b5dfe0..d4d52c8299d63 100644 --- a/paddle/fluid/primitive/codegen/CMakeLists.txt +++ b/paddle/fluid/primitive/codegen/CMakeLists.txt @@ -33,3 +33,21 @@ if(${_result}) "Automatic code generation for paddle/fluid/primitive failed, exiting.") endif() message("Automatic code generation for paddle/fluid/primitive succeed.") + +execute_process( + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/primitive/codegen + COMMAND + ${PYTHON_EXECUTABLE} + ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/decomp_gen.py --fwd_path + ${fwd_path} --fwd_legacy_path ${fwd_legacy_path} --fwd_pd_op_path + ${fwd_pd_op_path} --templates_dir ${templates_dir} --compat_path + ${compat_path} --destination_dir + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc + RESULT_VARIABLE _result) +if(${_result}) + message( + FATAL_ERROR + "Automatic code generation for build/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc failed." + ) +endif() +message("Automatic code generation for decomp interface succeed.") diff --git a/paddle/fluid/primitive/codegen/decomp_gen.py b/paddle/fluid/primitive/codegen/decomp_gen.py new file mode 100644 index 0000000000000..4375092cad2cc --- /dev/null +++ b/paddle/fluid/primitive/codegen/decomp_gen.py @@ -0,0 +1,335 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import pathlib +import sys + +import jinja2 +import yaml + +# fmt: off +# import from paddle/fluid/operators/generator +sys.path.append( + str(pathlib.Path(__file__).resolve().parents[2] / 'operators/generator') +) +import filters as op_gen_filters +import tests_utils as op_gen_tests +from parse_utils import to_named_dict +from type_mapping import output_type_map + +# import from paddle/fluid/pir/dialect/op_generator/api_gen.py +sys.path.append( + str(pathlib.Path(__file__).resolve().parents[2] / 'pir/dialect/op_generator') +) + +from decomp_interface_gen_op_list import ( + decomp_interface_implementation_gen_op_list, +) +from op_gen import attr_types_map, to_pascal_case + +# fmt: on + + +def load(path: pathlib.Path): + """Load config from yaml file. + + Args: + path (pathlib.Path): The path of yaml config. + + Returns: + dict: The config info. + + """ + with open(path, 'rt') as f: + return yaml.safe_load(f) + + +def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs): + """Render and save Jinja2 templates to the destination directory. + + Args: + src_dir (pathlib.Path): The source directory containing Jinja2 templates. + dst_dir (pathlib.Path): The destination directory to save rendered files. + *args: Additional positional arguments passed to the `render` function. + **kwargs: Additional keyword arguments passed to the `render` function. + + Returns: + None + """ + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(src_dir), + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + extensions=['jinja2.ext.do'], + ) + env.filters.update( + { + 'to_paddle_attr_type': op_gen_filters.to_paddle_attr_type, + 'to_paddle_input_type': op_gen_filters.to_paddle_input_type, + 'to_paddle_output_type': op_gen_filters.to_paddle_output_type, + 'trip_intermediate': op_gen_filters.filter_intermediate, + } + ) + env.tests.update( + { + 'scalar': op_gen_tests.is_scalar, + 'intarray': op_gen_tests.is_intarray, + 'datatype': op_gen_tests.is_datatype, + 'exist_mutable_attribute': op_gen_tests.exist_mutable_attribute, + 'mutable_attribute': op_gen_tests.is_mutable_attribute, + 'only_composite_op': op_gen_tests.is_only_composite_op, + } + ) + + decomp_temp = "decomp/generated_decomp.j2" + save( + env.get_template(decomp_temp).render(*args, **kwargs), + pathlib.Path(dst_dir), + ) + + +def save(content: str, path: pathlib.Path): + """Saves the given string contents to a file in the specified path. + + Args: + content (str): The string content that needs to be saved. + path (pathlib.Path): The path to save the file, a Pathlib path object + + Returns: + None + """ + path.parent.mkdir(parents=True, exist_ok=True) + + dst_content = '' + if path.is_file(): + with open(path, 'r') as f: + dst_content = f.read() + + if ( + hashlib.md5(content.encode("UTF-8")).hexdigest() + != hashlib.md5(dst_content.encode("UTF-8")).hexdigest() + ): + with open(path, 'w') as f: + f.write(content) + print(f"Generate or cover source file {path}") + + +def filter_compat_info(items): + for item in items: + item['op'] = item['op'].split('(')[0].strip() + if 'backward' in item: + item_backwards = item['backward'].split(',') + for idx, item_backward in enumerate(item_backwards): + item_backward = item_backward.split('(')[0].strip() + item_backwards[idx] = item_backward + item['backward'] = ( + ','.join(item_backwards) + if len(item_backwards) > 0 + else item_backwards[0] + ) + + +def extend_compat_info(apis, compats): + for api in apis: + attrs = api["attrs"] + for attr in attrs: + if op_gen_tests.is_scalar( + attr['typename'] + ) or op_gen_tests.is_intarray(attr['typename']): + attr["support_tensor"] = False + apis_dict = to_named_dict(apis) + for compat_item in compats: + fwd_op_name = compat_item["op"] + if fwd_op_name not in apis_dict: + continue + fwd_api = apis_dict[fwd_op_name] + backward_op_names = [] + while fwd_op_name is not None and fwd_op_name in apis_dict: + backward_op_names.append(apis_dict[fwd_op_name]['backward']) + fwd_op_name = apis_dict[fwd_op_name]['backward'] + backward_apis = [] + for backward_op_name in backward_op_names: + if backward_op_name in apis_dict: + backward_apis.append(apis_dict[backward_op_name]) + support_tensor_attrs_names = [] + compat_attrs_data_type = {} + if 'scalar' in compat_item and compat_item['op'] != "pow": + for attr_name, attr_info in compat_item['scalar'].items(): + if ( + 'support_tensor' in attr_info + and attr_info['support_tensor'] is True + or 'tensor_name' in attr_info + ): + support_tensor_attrs_names.append(attr_name) + if 'data_type' in attr_info: + compat_attrs_data_type.update( + {attr_name: attr_info['data_type']} + ) + if 'int_array' in compat_item: + for attr_name, attr_info in compat_item['int_array'].items(): + if ( + 'support_tensor' in attr_info + and attr_info['support_tensor'] is True + or 'tensor_name' in attr_info + or 'tensors_name' in attr_info + ): + support_tensor_attrs_names.append(attr_name) + if len(support_tensor_attrs_names) > 0: + for api in [fwd_api] + backward_apis: + attrs = api["attrs"] + for attr in attrs: + if attr['name'] in support_tensor_attrs_names: + attr['support_tensor'] = True + for api in [fwd_api] + backward_apis: + attrs = api["attrs"] + for attr in attrs: + if attr['name'] in compat_attrs_data_type: + attr['data_type'] = compat_attrs_data_type[attr['name']] + return apis + + +def process_optional_output_info(apis): + for api in apis: + inputs_dict = to_named_dict(api['inputs']) + for output in api['outputs']: + if ( + api.get("inplace", None) + and output['name'] in api['inplace'] + and inputs_dict[api['inplace'][output['name']]]['optional'] + ): + output['optional'] = True + else: + output['optional'] = False + + +def gen( + fwd_path: pathlib.Path, + fwd_legacy_path: pathlib.Path, + compat_path: pathlib.Path, + fwd_pd_op_path: pathlib.Path, + templates_dir: pathlib.Path, + destination_dir: pathlib.Path, +): + """The `gen` load jinja2 templates and relative config info, use jinja2 + templating engine to generate c++ code, and save the code into destination. + + Args: + prim_path (pathlib.Path): The YAML file path of the primitive API. + fwd_path (pathlib.Path): The YAML file path of the forwad API. + fwd_legacy_path (pathlib.Path): The YAML file path of the legacy + forwad API. + rev_path (pathlib.Path): The YAML file path of the backward API. + rev_legacy_path (pathlib.Path): The YAML file path of the legacy + backward API. + compat_path: (pathlib.Path): The YAML file path of the ops compat. + fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API. + rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API. + templates_dir (pathlib.Path): The directory of the templates. + destination_dir (pathlib.Path): The Directory of the generated file. + + Returns: + None + """ + ( + fwds, + legacy_fwds, + compats, + ir_fwds, + ) = ( + load(fwd_path), + load(fwd_legacy_path), + load(compat_path), + load(fwd_pd_op_path), + ) + filter_compat_info(compats) + apis = [ + {**api, **{'class_name': to_pascal_case(api["name"]) + "Op"}} + for api in fwds + legacy_fwds + ir_fwds + ] + + apis = extend_compat_info(apis, compats) + + process_optional_output_info(apis) + + for item in apis: + for attr_item in item["attrs"]: + if attr_item["typename"] not in attr_types_map.keys(): + raise TypeError + attr_item["mapped_type"] = attr_types_map[attr_item["typename"]][0] + for out_item in item["outputs"]: + if out_item["typename"] not in output_type_map.keys(): + name = out_item["typename"] + raise TypeError(f"err type {name}") + if out_item["optional"]: + out_item["mapped_type"] = ( + "paddle::optional<" + + output_type_map[out_item["typename"]] + + ">" + ) + else: + out_item["mapped_type"] = output_type_map[out_item["typename"]] + render( + templates_dir, + destination_dir, + apis=apis, + decomp_white_list=decomp_interface_implementation_gen_op_list, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Generate Static Primitive API' + ) + parser.add_argument( + '--fwd_path', type=str, help='The parsed ops yaml file.' + ) + parser.add_argument( + '--fwd_legacy_path', + type=str, + help='The parsed ops yaml file.', + ) + parser.add_argument( + '--compat_path', + type=str, + help='The parsed ops compat yaml file.', + ) + parser.add_argument( + '--fwd_pd_op_path', + type=str, + help='The ir forward ops parsed yaml file.', + ) + parser.add_argument( + '--templates_dir', + type=str, + help='JinJa2 templates base directory.', + ) + parser.add_argument( + '--destination_dir', + type=str, + help='Destination base directory for generated file.', + ) + args = parser.parse_args() + + gen( + pathlib.Path(args.fwd_path), + pathlib.Path(args.fwd_legacy_path), + pathlib.Path(args.compat_path), + pathlib.Path(args.fwd_pd_op_path), + pathlib.Path(args.templates_dir), + pathlib.Path(args.destination_dir), + ) diff --git a/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 new file mode 100644 index 0000000000000..601a59802d431 --- /dev/null +++ b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 @@ -0,0 +1,113 @@ +{% import "common.j2" as common %} +// Auto Generated by decomp_gen.py, DO NOT EDIT! + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/primitive/composite/composite.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/op_base.h" + + +namespace paddle { +namespace dialect { +using IntArray = paddle::experimental::IntArray; + +{% macro sig(fwd_name, class_name, inputs, attrs, outputs) %} +{%- set input_names=[] -%} +{%- set attr_names=[] -%} +{%- set output_names=[] -%} +{%- set output_types=[] -%} +std::vector> {{class_name}}::Decomp(pir::Operation* op) { + {{class_name}} op_obj = op->dyn_cast<{{class_name}}>(); + (void)op_obj; + + VLOG(4) << "Decomp Prepare inputs of {{fwd_name}}"; + + {% for item in inputs %} + {% do input_names.append(item.name) %} + {% if item.typename == "Tensor" %} {#- Tensor or Tensor[] #} + {% if item.optional %} + paddle::optional {{item.name}}; + if (!IsEmptyValue(op_obj.{{item.name}}())){ + {{item.name}} = paddle::make_optional(Tensor(std::make_shared(op_obj.{{item.name}}()))); + } + {% else %} + {{item.typename}} {{item.name}}(std::make_shared(op_obj.{{item.name}}())); + {% endif %} + {% elif item.typename == "Tensor[]" %} + {% if item.optional %} + + paddle::optional> {{item.name}}; + if (!IsEmptyValue(op_obj.{{item.name}}())){ + pir::CombineOp combine_op_obj = + op_obj.{{item.name}}().dyn_cast().owner()->dyn_cast(); + std::vector optional_{{item.name}}; + for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + optional_{{item.name}}.emplace_back( + std::make_shared(combine_op_obj.inputs()[idx])); + } + {{item.name}} = paddle::make_optional>(optional_{{item.name}}); + } + + {% else %} + pir::CombineOp combine_op_obj_{{item.name}} = + op_obj.{{item.name}}().dyn_cast().owner()->dyn_cast(); + std::vector {{item.name}}; + for (size_t idx = 0; idx < combine_op_obj_{{item.name}}.inputs().size(); idx++) { + {{item.name}}.emplace_back( + std::make_shared(combine_op_obj_{{item.name}}.inputs()[idx])); + {% endif %} + {% endif %} + {% endfor %} + + VLOG(4) << "Decomp prepare attributes of {{fwd_name}}"; + + {% if attrs %} + {% for item in attrs %} + {% do attr_names.append(item.name) %} + {{item.typename}} {{item.name}} = op->attribute("{{item.name}}").dyn_cast<{{item.mapped_type}}>().data(); + {% endfor %} + {% endif %} + + VLOG(4) << "Decomp prepare call {{fwd_name}}'s decomp interface"; + + auto org_res = op->results(); + std::vector> res(org_res.size()); + + {% if outputs|length == 1 %} + Tensor op_res = paddle::primitive::details::{{fwd_name}}_decomp({{common.args(input_names, attr_names)}}); + res[0].push_back( + std::static_pointer_cast(op_res.impl()) + ->value() + .dyn_cast()); + + {%- else -%} + {% for item in outputs %} + {% do output_names.append(item.name) %} + {% do output_types.append(item.mapped_type) %} + {% endfor %} + std::tuple<{{common.sequence('', '', ', ', output_types)}}> op_res = paddle::primitive::details::{{fwd_name}}_decomp( + {{common.args(input_names, attr_names)}}); + for (size_t i = 0; i < org_res.size(); ++i) { + res[i].push_back(std::static_pointer_cast(std::get(op_res).impl())->value().dyn_cast()); + } + {% endif %} + + return res; + +} + +{% endmacro %} + +{% for api in apis %} + {%- if api.name in decomp_white_list -%} + {{sig(api.name, api.class_name, api.inputs, api.attrs, api.outputs)}} + {%- else -%} {#- render nothing -#} + {% endif %} +{% endfor %} +} // namespace dialect +} // namespace paddle From 514869fc684f8e933d4b4b247273fb2878e142ed Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Sun, 29 Oct 2023 10:36:46 +0000 Subject: [PATCH 2/2] polish code --- paddle/fluid/primitive/codegen/decomp_gen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/primitive/codegen/decomp_gen.py b/paddle/fluid/primitive/codegen/decomp_gen.py index 4375092cad2cc..ee183cbdacee6 100644 --- a/paddle/fluid/primitive/codegen/decomp_gen.py +++ b/paddle/fluid/primitive/codegen/decomp_gen.py @@ -126,7 +126,6 @@ def save(content: str, path: pathlib.Path): ): with open(path, 'w') as f: f.write(content) - print(f"Generate or cover source file {path}") def filter_compat_info(items):