From 8d46279afb8b7783b71965c8bf068c4e4044701d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 11 Oct 2022 12:42:58 +0000 Subject: [PATCH 01/11] support generating code of opmaker for backward op invoke forward op --- paddle/fluid/operators/flip_op.cc | 108 ------------------ paddle/fluid/operators/unity_build_rule.cmake | 2 - paddle/phi/api/lib/CMakeLists.txt | 6 +- paddle/phi/api/yaml/backward.yaml | 6 + paddle/phi/api/yaml/generator/generate_op.py | 75 +++++++++--- .../phi/api/yaml/generator/templates/op.c.j2 | 4 +- .../generator/templates/operator_utils.c.j2 | 42 +++++++ paddle/phi/api/yaml/legacy_backward.yaml | 6 - paddle/phi/api/yaml/legacy_ops.yaml | 9 -- paddle/phi/api/yaml/op_compat.yaml | 6 + paddle/phi/api/yaml/op_version.yaml | 10 ++ paddle/phi/api/yaml/ops.yaml | 9 ++ 12 files changed, 140 insertions(+), 143 deletions(-) delete mode 100644 paddle/fluid/operators/flip_op.cc diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc deleted file mode 100644 index 4c14418690a85..0000000000000 --- a/paddle/fluid/operators/flip_op.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -using framework::OpKernelType; - -class FlipOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); - } -}; - -class FlipOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of flip op."); - AddOutput("Out", "(Tensor), The output tensor of flip op."); - AddAttr>("axis", "The axes to flip on."); - AddComment(R"DOC( - Flip Operator. - Reverse the order of a n-D tensor along given axis in axes. 
- )DOC"); - } -}; - -class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map& GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -template -class FlipOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("flip"); - retv->SetInput("X", this->OutputGrad("Out")); - retv->SetOutput("Out", this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -DECLARE_INFER_SHAPE_FUNCTOR(flip, - FlipInferShapeFunctor, - PD_INFER_META(phi::FlipInferMeta)); -REGISTER_OPERATOR(flip, - ops::FlipOp, - ops::FlipOpMaker, - ops::FlipOpInferVarType, - ops::FlipOpGradMaker, - ops::FlipOpGradMaker, - FlipInferShapeFunctor); - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(flip).AddCheckpoint( - R"ROC(Upgrade flip, add new attr [axis] and delete attr [dims].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewAttr("axis", - "The added attr 'axis' doesn't set default value.", - paddle::none) - .DeleteAttr("dims", "The attr 'dims' is deleted.")); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 69e6539feaca3..7cde56121b00c 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -108,7 +108,6 @@ register_unity_group( register_unity_group( cc flatten_op.cc - flip_op.cc fsp_op.cc gather_nd_op.cc gather_op.cc @@ -423,7 +422,6 @@ register_unity_group(cu expand_v2_op.cu fake_dequantize_op.cu fill_any_like_op.cu) register_unity_group( cu - flip_op.cu fsp_op.cu gather_nd_op.cu gather_op.cu diff --git 
a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 3795060d24b98..b8680e9c91add 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -171,9 +171,9 @@ create or remove auto-geneated argument mappings: ${generated_argument_mapping_p execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND - ${PYTHON_EXECUTABLE} generator/generate_op.py --api_yaml_path - ./parsed_apis/api.parsed.yaml --backward_api_yaml_path - ./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path + ${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path + ./parsed_apis/api.parsed.yaml --backward_yaml_path + ./parsed_apis/backward_api.parsed.yaml --op_version_yaml_path op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path "${generated_op_path}.tmp" --output_arg_map_path "${generated_argument_mapping_path}.tmp" diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 9e37919d7b3da..7923edac0c8d4 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -147,6 +147,12 @@ data_type: out_grad no_need_buffer: x +- backward_op : flip_grad + forward : flip (Tensor x, int[] axis) -> Tensor(out) + args : (Tensor out_grad, int[] axis) + output : Tensor(x_grad) + invoke : flip(out_grad, axis) + - backward_op : graph_send_uv_grad forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out) args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD") diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index 5fa3be685e487..c8390ee4aa32a 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -145,6 +145,14 @@ def get_api_and_op_name(api_item): [:-5]] + '_grad' args_item['name'] = args_map[args_item['name']] + if 'invoke' in 
backward_api_item: + backward_api_item['invoke']['args'] = [ + args_map[param.strip()] + if param.strip() in args_map else param.strip() + for param in backward_api_item['invoke']['args'].split(',') + ] + continue + backward_api_item['infer_meta']['param'] = [ args_map[param] if param in args_map else param for param in backward_api_item['infer_meta']['param'] @@ -175,9 +183,9 @@ def get_api_and_op_name(api_item): ] -def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, - api_version_yaml_path, output_op_path, output_arg_map_path): - with open(api_yaml_path, "rt") as f: +def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path, + op_version_yaml_path, output_op_path, output_arg_map_path): + with open(ops_yaml_path, "rt") as f: apis = yaml.safe_load(f) apis = [restruct_io(api) for api in apis] forward_api_dict = to_named_dict(apis) @@ -187,7 +195,7 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, backward_apis = [restruct_io(api) for api in backward_apis] backward_api_dict = to_named_dict(backward_apis) - with open(api_version_yaml_path, "rt") as f: + with open(op_version_yaml_path, "rt") as f: api_versions = yaml.safe_load(f) # add api version info into api for api_version in api_versions: @@ -203,6 +211,45 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, replace_compat_name(api_op_map, forward_api_dict, backward_api_dict) + # prepare for invoke case + for bw_name, bw_api in backward_api_dict.items(): + if 'invoke' in bw_api: + invoke_op = bw_api['invoke']['func'] + args_list = bw_api['invoke']['args'] + args_index = 0 + if invoke_op in forward_api_dict.keys(): + reuse_op = forward_api_dict[invoke_op] + bw_api['invoke']['inputs'] = [] + bw_api['invoke']['attrs'] = [] + bw_api['invoke']['outputs'] = [] + for input_item in reuse_op['inputs']: + bw_api['invoke']['inputs'].append({ + 'name': + input_item['name'], + 'value': + args_list[args_index] + }) + args_index = args_index + 1 + for attr in reuse_op['attrs']: 
+ if args_index < len(args_list): + attr_value = f"this->GetAttr(\"{args_list[args_index]}\")" if args_list[ + args_index] in bw_api['attr_dict'] else args_list[ + args_index] + bw_api['invoke']['attrs'].append({ + 'name': attr['name'], + 'value': attr_value + }) + args_index = args_index + 1 + else: + break + for idx, output_item in enumerate(reuse_op['outputs']): + bw_api['invoke']['outputs'].append({ + 'name': + output_item['name'], + 'value': + bw_api['outputs'][idx]['name'] + }) + # fill backward field for an api if another api claims it as forward for name, backward_api in backward_api_dict.items(): forward_name = backward_api["forward"]["name"] @@ -238,18 +285,18 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate operator file from api yaml.") - parser.add_argument('--api_yaml_path', + parser.add_argument('--ops_yaml_path', type=str, - help="parsed api yaml file.") - parser.add_argument('--backward_api_yaml_path', + help="parsed ops yaml file.") + parser.add_argument('--backward_yaml_path', type=str, - help="parsed backward api yaml file.") + help="parsed backward ops yaml file.") parser.add_argument('--op_compat_yaml_path', type=str, - help="api args compat yaml file.") - parser.add_argument('--api_version_yaml_path', + help="ops args compat yaml file.") + parser.add_argument('--op_version_yaml_path', type=str, - help="api version yaml file.") + help="ops version yaml file.") parser.add_argument("--output_op_path", type=str, help="path to save generated operators.") @@ -259,6 +306,6 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, help="path to save generated argument mapping functions.") args = parser.parse_args() - main(args.api_yaml_path, args.backward_api_yaml_path, - args.op_compat_yaml_path, args.api_version_yaml_path, - args.output_op_path, args.output_arg_map_path) + main(args.ops_yaml_path, args.backward_yaml_path, 
args.op_compat_yaml_path, + args.op_version_yaml_path, args.output_op_path, + args.output_arg_map_path) diff --git a/paddle/phi/api/yaml/generator/templates/op.c.j2 b/paddle/phi/api/yaml/generator/templates/op.c.j2 index 0c2708ce223c7..4799866f993cb 100644 --- a/paddle/phi/api/yaml/generator/templates/op.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/op.c.j2 @@ -1,4 +1,4 @@ -{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %} +{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %} // this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. #include #include "paddle/fluid/framework/infershape_utils.h" @@ -33,6 +33,8 @@ using paddle::framework::GradVarName; {{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{operator(api)}} + {% else %} +{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {% endif %} {% endfor %} } // namespace operators diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 index 3910da99d8ae3..afef14ecaf508 100644 --- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 @@ -352,6 +352,48 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker }; {% endmacro %} +{% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %} + {% set name = bw_op["op_name"] %} + {% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="name") | list %} + {% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="name") | list %} + {% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="name") | list %} + {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %} + {% 
set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %} + {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %} +template +class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("{{invoke_op["func"]}}"); + + {% for input in invoke_op["inputs"] %} + grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( + input["value"], + forward_input_names, + forward_output_names, + forward_input_orig_names, + forward_output_orig_names)}}); + {% endfor %} + + {% for output in invoke_op["outputs"] %} + grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( + output["value"], + forward_input_names, + forward_output_names, + forward_input_orig_names, + forward_output_orig_names)}}); + {% endfor %} + + {% for attr in invoke_op["attrs"] %} + grad_op->SetAttr("{{attr["name"]}}", {{attr["value"]}}); + {% endfor %} + } +}; +{% endmacro %} + {% macro extract_input_from_forward(name, input_names, output_names, diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 2f4eff9835604..2e7d240c5f586 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -839,12 +839,6 @@ layout: out_grad inplace : (out_grad -> x_grad) -- backward_op : flip_grad - forward : flip (Tensor x, int[] axis) -> Tensor(out) - args : (Tensor out_grad, int[] axis) - output : Tensor(x_grad) - invoke : flip(out_grad, axis) - - backward_op : floor_grad forward : floor(Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index bca2aeb58a507..53c908a26694f 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ 
b/paddle/phi/api/yaml/legacy_ops.yaml @@ -962,15 +962,6 @@ intermediate : xshape backward : flatten_grad -- op : flip - args : (Tensor x, int[] axis) - output : Tensor - infer_meta : - func : FlipInferMeta - kernel : - func : flip - backward : flip_grad - - op : floor args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index bbaa942520107..bcb3563a040b4 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -324,6 +324,12 @@ inputs: {x: X} outputs: {out: Out} +- op : flip + inputs : + x : X + outputs : + out : Out + - op : floor backward : floor_grad extra : diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 5702884533a28..3028b927966a2 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -1,3 +1,13 @@ +- op : flip + version : + - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims] + action : + - add_attr : axis + comment : The added attr 'axis' doesn't set default value + default : paddle::none + - delete_attr : dims + comment : The attr 'dims' is deleted. 
+ - op : trace version : - checkpoint : Upgrade trace add a new attribute [axis2] diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 02bcd1f0040a8..10e617bd91243 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -199,3 +199,12 @@ kernel : func : trunc backward : trunc_grad + +- op : flip + args : (Tensor x, int[] axis) + output : Tensor (out) + infer_meta : + func : FlipInferMeta + kernel : + func : flip + backward : flip_grad From 9cd25df41f193042dfa70f61b2ac4f1a34bc78ee Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 11 Oct 2022 13:40:31 +0000 Subject: [PATCH 02/11] gsupport code-gen of opmaker for sparse op --- .gitignore | 2 + paddle/phi/api/lib/CMakeLists.txt | 106 +++++++++++--- .../api/yaml/generator/generate_sparse_op.py | 131 ++++++++++++++++++ .../yaml/generator/templates/sparse_op.c.j2 | 44 ++++++ paddle/phi/api/yaml/sparse_backward.yaml | 10 +- paddle/phi/api/yaml/sparse_ops.yaml | 2 +- 6 files changed, 270 insertions(+), 25 deletions(-) create mode 100644 paddle/phi/api/yaml/generator/generate_sparse_op.py create mode 100644 paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 diff --git a/.gitignore b/.gitignore index 4ed5ca8bcd3de..ee8acc3186f38 100644 --- a/.gitignore +++ b/.gitignore @@ -72,7 +72,9 @@ tools/nvcc_lazy # these files (directories) are generated before build system generation paddle/fluid/operators/generated_op.cc +paddle/fluid/operators/generated_sparse_op.cc paddle/phi/ops/compat/generated_sig.cc +paddle/phi/ops/compat/generated_sparse_sig.cc paddle/phi/api/yaml/parsed_apis/ python/paddle/utils/code_gen/ paddle/fluid/pybind/tmp_eager_op_function_impl.h diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 3795060d24b98..4a04272e10e8d 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -118,8 +118,13 @@ endif() set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis) set(generated_op_path 
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc) +set(generated_sparse_ops_path + ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc) set(generated_argument_mapping_path ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc) +set(generated_sparse_argument_mapping_path + ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc) + message( "parse api yamls: - ${api_yaml_file} @@ -130,16 +135,22 @@ execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir} COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml - --output_path ./parsed_apis/api.parsed.yaml + --output_path ./parsed_apis/ops.parsed.yaml COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path - ./legacy_ops.yaml --output_path ./parsed_apis/legacy_api.parsed.yaml + ./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml - --output_path ./parsed_apis/backward_api.parsed.yaml --backward + --output_path ./parsed_apis/backward_ops.parsed.yaml --backward COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./legacy_backward.yaml --output_path - ./parsed_apis/legacy_backward_api.parsed.yaml --backward RESULTS_VARIABLE + ./parsed_apis/legacy_backward_ops.parsed.yaml --backward + COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path + ./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml + COMMAND + ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path + ./sparse_backward.yaml --output_path + ./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE _results) foreach(_result in ${_results}) if(${_result}) @@ -149,19 +160,25 @@ endforeach() # validation of api yamls message("validate api yaml: -- ${parsed_api_dir}/api.parsed.yaml -- ${parsed_api_dir}/backward_api.parsed.yaml") +- 
${parsed_api_dir}/ops.parsed.yaml +- ${parsed_api_dir}/backward_ops.parsed.yaml") execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND ${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths - ./parsed_apis/api.parsed.yaml ./parsed_apis/legacy_api.parsed.yaml - --backward_yaml_paths ./parsed_apis/backward_api.parsed.yaml - ./parsed_apis/legacy_backward_api.parsed.yaml - RESULT_VARIABLE _result) -if(${_result}) - message(FATAL_ERROR "api validation failed, exiting.") -endif() + ./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml + --backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml + ./parsed_apis/legacy_backward_ops.parsed.yaml + COMMAND + ${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths + ./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths + ./parsed_apis/sparse_backward.parsed.yaml + RESULT_VARIABLE _results) +foreach(_result in ${_results}) + if(${_result}) + message(FATAL_ERROR "ops validation failed, exiting.") + endif() +endforeach() # code generation for op, op makers, and argument mapping functions message( @@ -172,15 +189,23 @@ execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND ${PYTHON_EXECUTABLE} generator/generate_op.py --api_yaml_path - ./parsed_apis/api.parsed.yaml --backward_api_yaml_path - ./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path + ./parsed_apis/ops.parsed.yaml --backward_api_yaml_path + ./parsed_apis/backward_ops.parsed.yaml --api_version_yaml_path op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path "${generated_op_path}.tmp" --output_arg_map_path "${generated_argument_mapping_path}.tmp" - RESULT_VARIABLE _result) -if(${_result}) - message(FATAL_ERROR "operator codegen failed, exiting.") -endif() + COMMAND + ${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path + ./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path + ./parsed_apis/sparse_backward.parsed.yaml 
--output_op_path + "${generated_sparse_ops_path}.tmp" --output_arg_map_path + "${generated_sparse_argument_mapping_path}.tmp" + RESULT_VARIABLE _results) +foreach(_result in ${_results}) + if(${_result}) + message(FATAL_ERROR "operator codegen failed, exiting.") + endif() +endforeach() if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}") execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different @@ -195,6 +220,25 @@ else() message("remove ${generated_op_path}") endif() +if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS + "${generated_sparse_ops_path}") + execute_process( + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}") + message( + "copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}" + ) +elseif(EXISTS "${generated_sparse_ops_path}.tmp") + execute_process( + COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp" + "${generated_sparse_ops_path}") + message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}") +else() + execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f + "${generated_sparse_ops_path}") + message("remove ${generated_sparse_ops_path}") +endif() + if(EXISTS "${generated_argument_mapping_path}.tmp" AND EXISTS "${generated_argument_mapping_path}") execute_process( @@ -218,6 +262,30 @@ else() message("remove ${generated_argument_mapping_path}") endif() +if(EXISTS "${generated_sparse_argument_mapping_path}.tmp" + AND EXISTS "${generated_sparse_argument_mapping_path}") + execute_process( + COMMAND + ${CMAKE_COMMAND} -E copy_if_different + "${generated_sparse_argument_mapping_path}.tmp" + "${generated_sparse_argument_mapping_path}") + message( + "copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}" + ) +elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp") + execute_process( + COMMAND + ${CMAKE_COMMAND} -E copy 
"${generated_sparse_argument_mapping_path}.tmp" + "${generated_sparse_argument_mapping_path}") + message( + "copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}" + ) +else() + execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f + "${generated_sparse_argument_mapping_path}") + message("remove ${generated_sparse_argument_mapping_path}") +endif() + # generate ops extra info execute_process( COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py new file mode 100644 index 0000000000000..3e35e1094ff88 --- /dev/null +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -0,0 +1,131 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import re +from itertools import chain +from pathlib import Path + +import yaml +from jinja2 import Environment, FileSystemLoader, StrictUndefined + +from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case +from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer +from filters import to_input_name, cartesian_prod_mapping +from parse_utils import to_named_dict + +file_loader = FileSystemLoader(Path(__file__).parent / "templates") +env = Environment(loader=file_loader, + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + undefined=StrictUndefined, + extensions=['jinja2.ext.do']) +env.filters["to_op_attr_type"] = to_op_attr_type +env.filters["to_opmaker_name"] = to_opmaker_name +env.filters["to_pascal_case"] = to_pascal_case +env.filters["to_input_name"] = to_input_name +env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr +env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping +env.tests["base_api"] = is_base_api +env.tests["vec"] = is_vec +env.tests["scalar"] = is_scalar +env.tests["initializer_list"] = is_initializer_list +env.tests["supports_inplace"] = supports_inplace +env.tests["supports_no_need_buffer"] = supports_no_need_buffer + + +def restruct_io(api): + api["input_dict"] = to_named_dict(api["inputs"]) + api["attr_dict"] = to_named_dict(api["attrs"]) + api["output_dict"] = to_named_dict(api["outputs"]) + return api + + +SPARSE_OP_PREFIX = 'sparse_' + + +def main(api_yaml_path, backward_yaml_path, output_op_path, + output_arg_map_path): + with open(api_yaml_path, "rt") as f: + apis = yaml.safe_load(f) + apis = [restruct_io(api) for api in apis] + forward_api_dict = to_named_dict(apis) + + with open(backward_yaml_path, "rt") as f: + backward_apis = yaml.safe_load(f) + backward_apis = [restruct_io(api) for api in backward_apis] + backward_api_dict = to_named_dict(backward_apis) + + for api in apis: + 
api['op_name'] = SPARSE_OP_PREFIX + api['name'] + if api["backward"] is not None: + api["backward"] = SPARSE_OP_PREFIX + api["backward"] + for bw_api in backward_apis: + bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name'] + + # fill backward field for an api if another api claims it as forward + for name, backward_api in backward_api_dict.items(): + forward_name = backward_api["forward"]["name"] + if forward_name in backward_api_dict: + forward_api = backward_api_dict[forward_name] + if forward_api["backward"] is None: + forward_api["backward"] = name + forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"] + + api_dict = {} + api_dict.update(forward_api_dict) + api_dict.update(backward_api_dict) + + if len(apis) == 0 and len(backward_apis) == 0: + if os.path.isfile(output_op_path): + os.remove(output_op_path) + if os.path.isfile(output_arg_map_path): + os.remove(output_arg_map_path) + return + + op_template = env.get_template('sparse_op.c.j2') + with open(output_op_path, "wt") as f: + msg = op_template.render(apis=apis, + backward_apis=backward_apis, + api_dict=api_dict) + f.write(msg) + + ks_template = env.get_template('ks.c.j2') + with open(output_arg_map_path, 'wt') as f: + msg = ks_template.render(apis=apis, backward_apis=backward_apis) + f.write(msg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate operator file from api yaml.") + parser.add_argument('--ops_yaml_path', + type=str, + help="parsed sparse ops yaml file.") + parser.add_argument('--backward_ops_yaml_path', + type=str, + help="parsed backward sparse ops yaml file.") + parser.add_argument("--output_op_path", + type=str, + help="path to save generated operators.") + parser.add_argument( + "--output_arg_map_path", + type=str, + help="path to save generated argument mapping functions.") + + args = parser.parse_args() + main(args.ops_yaml_path, args.backward_ops_yaml_path, args.output_op_path, + args.output_arg_map_path) diff --git 
a/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 new file mode 100644 index 0000000000000..434518446b574 --- /dev/null +++ b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 @@ -0,0 +1,44 @@ +{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %} +// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +using paddle::framework::GradVarName; + +{% for api in apis %} + {% if api is base_api %} + +{{op_maker(api)}} + +{{operator(api)}} + {% endif %} +{% endfor %} + +{% for api in backward_apis %} + {% if api is base_api %} + +{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} + +{{operator(api)}} + {% endif %} +{% endfor %} +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +{% for api in apis + backward_apis %} +{% if api is base_api %} +{{register_op_with_components(api)}} +{% endif %} +{% endfor %} diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 5850acb3c37fd..5b104122fccab 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -1,5 +1,5 @@ - backward_op : abs_grad - forward : tanh(Tensor x) -> Tensor(out) + forward : abs(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) infer_meta : @@ -124,8 +124,8 @@ cast_csr_grad {sparse_csr, sparse_csr -> 
sparse_csr} data_type : out_grad -- backward_op : conv3d_coo_grad - forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) +- backward_op : conv3d_grad + forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) output : Tensor(x_grad), Tensor(kernel_grad) infer_meta : @@ -421,7 +421,7 @@ transpose_csr_grad {sparse_csr -> sparse_csr} - backward_op : values_grad - forward : values_coo(Tensor x) -> Tensor(out) + forward : values(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) infer_meta : @@ -431,7 +431,7 @@ func : values_coo_grad{sparse_coo, dense-> sparse_coo} - backward_op: fused_attention_grad - forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax) + forward : fused_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax) args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad) output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad) infer_meta : diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 43f8688fb81a0..93d7e47617546 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -119,7 +119,7 @@ func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense} layout : x intermediate: rulebook, counter - backward : conv3d_coo_grad + backward : conv3d_grad - op : divide args : (Tensor x, Tensor y) From 
35ba8a931761e56a281191b7e4910cfaa257555b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 12 Oct 2022 04:12:05 +0000 Subject: [PATCH 03/11] refine logic of choosing phi kernel --- paddle/phi/core/compat/op_utils.h | 2 +- paddle/phi/core/kernel_factory.cc | 12 ++++++++++++ paddle/phi/core/kernel_factory.h | 4 +--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b578afa7c2b85..10b859fdac260 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -40,7 +40,7 @@ const std::unordered_set standard_kernel_suffixs({ * after 2.0, and can no longer be occupied by the previously abandoned ops. * They are marked here uniformly. */ -const std::unordered_set deprecated_op_names( +static const std::unordered_set deprecated_op_names( {"diag", "flatten", "flatten_grad", diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 71256bdabaa67..d97314e70a78b 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif +#include "paddle/phi/core/compat/op_utils.h" DECLARE_bool(enable_api_kernel_fallback); @@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() { return g_op_kernel_factory; } +bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { + if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) { + if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type)) { + return true; + } else if (kernels_.find(op_type) != kernels_.end()) { + return true; + } + } + return false; +} + const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 
8e98c276646d9..ed9280fa475bf 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -272,9 +272,7 @@ class KernelFactory { KernelNameMap& kernels() { return kernels_; } - bool HasCompatiblePhiKernel(const std::string& op_type) const { - return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end(); - } + bool HasCompatiblePhiKernel(const std::string& op_type) const; KernelResult SelectKernelOrThrowError(const std::string& kernel_name, const KernelKey& kernel_key, From 55a0b719182cb3baa9fd15c982559a3e9428893f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 13 Oct 2022 06:49:34 +0000 Subject: [PATCH 04/11] fix complie budg --- paddle/phi/core/kernel_factory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index d97314e70a78b..480882550dbca 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -48,7 +48,7 @@ KernelFactory& KernelFactory::Instance() { bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) { - if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type)) { + if (phi::OpUtilsMap::Instance().Contains(op_type)) { return true; } else if (kernels_.find(op_type) != kernels_.end()) { return true; From 74a5abcc7d4dd083f32c680b1525879d84c02117 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 13 Oct 2022 08:42:38 +0000 Subject: [PATCH 05/11] fix code_gen bug --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/sparse_manual_op.cc | 273 ------------------ paddle/phi/api/yaml/generator/generate_op.py | 66 +++-- .../api/yaml/generator/generate_sparse_op.py | 10 + .../generator/templates/operator_utils.c.j2 | 8 +- .../yaml/generator/templates/sparse_op.c.j2 | 9 +- paddle/phi/ops/compat/sparse_manual_op_sig.cc | 104 ------- 7 files changed, 59 insertions(+), 413 deletions(-) delete mode 100644 
paddle/fluid/operators/sparse_manual_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d985baf8c9088..2aab6e9d50cb7 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -101,7 +101,7 @@ else() cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) endif() -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta) +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/sparse_manual_op.cc b/paddle/fluid/operators/sparse_manual_op.cc deleted file mode 100644 index 327e03af80506..0000000000000 --- a/paddle/fluid/operators/sparse_manual_op.cc +++ /dev/null @@ -1,273 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/binary.h" -#include "paddle/phi/infermeta/multiary.h" -#include "paddle/phi/infermeta/sparse/binary.h" -#include "paddle/phi/infermeta/sparse/unary.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class SparseSparseCooTensorOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("values", "(Tensor), input 0 of sparse_coo_tensor op."); - AddInput("indices", "(Tensor), input 1 of sparse_coo_tensor op."); - AddOutput("out", "(Tensor), output 0 of sparse_coo_tensor op."); - AddAttr>( - "dense_shape", "(vector), attribute 0 for sparse_coo_tensor op."); - AddComment(R"DOC( -TODO: Documentation of sparse_coo_tensor op. -)DOC"); - } -}; - -class SparseSparseCooTensorOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR( - sparse_sparse_coo_tensor, - SparseSparseCooTensorInferShapeFunctor, - PD_INFER_META(phi::sparse::SparseCooTensorInferMeta)); - -class SparseValuesOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_values op."); - AddOutput("out", "(Tensor), output 0 of sparse_values op."); - AddComment(R"DOC( -TODO: Documentation of sparse_values op. 
-)DOC"); - } -}; - -class SparseValuesOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_values, - SparseValuesInferShapeFunctor, - PD_INFER_META(phi::sparse::ValuesInferMeta)); - -class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_indices op."); - AddOutput("out", "(Tensor), output 0 of sparse_indices op."); - AddComment(R"DOC( -TODO: Documentation of sparse_indices op. -)DOC"); - } -}; - -class SparseIndicesOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices, - SparseIndicesInferShapeFunctor, - PD_INFER_META(phi::sparse::IndicesInferMeta)); - -class SparseToDenseOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_to_dense op."); - AddOutput("out", "(Tensor), output 0 of sparse_to_dense op."); - AddComment(R"DOC( -TODO: Documentation of sparse_to_dense op. -)DOC"); - } -}; - -class SparseToDenseOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_to_dense, - SparseToDenseInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseReluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_relu op."); - AddOutput("out", "(Tensor), output 0 of sparse_relu op."); - AddComment(R"DOC( -TODO: Documentation of sparse_relu op. 
-)DOC"); - } -}; - -class SparseReluOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_relu, - SparseReluInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseConv3dOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_conv3d op."); - AddInput("kernel", "(Tensor), input 1 of sparse_conv3d op."); - AddOutput("out", "(Tensor), output 0 of sparse_conv3d op."); - AddOutput("rulebook", "(Tensor), output 1 of sparse_conv3d op."); - AddOutput("counter", "(Tensor), output 2 of sparse_conv3d op."); - AddAttr>( - "paddings", "(vector), attribute 0 for sparse_conv3d op."); - AddAttr>( - "dilations", "(vector), attribute 1 for sparse_conv3d op."); - AddAttr>( - "strides", "(vector), attribute 2 for sparse_conv3d op."); - AddAttr("groups", "(int), attribute 3 for sparse_conv3d op."); - AddAttr("subm", "(bool), attribute 4 for conv3d_coo op."); - AddAttr("key", "(string), attribute 5 for sparse_conv3d op.") - .SetDefault(""); - AddComment(R"DOC( -TODO: Documentation of sparse_conv3d op. -)DOC"); - } -}; - -class SparseConv3dOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_conv3d, - SparseConv3dInferShapeFunctor, - PD_INFER_META(phi::sparse::Conv3dInferMeta)); - -class SparseAddOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_add op."); - AddInput("y", "(Tensor), input 1 of sparse_add op."); - AddOutput("out", "(Tensor), output 0 of sparse_add op."); - AddComment(R"DOC( -TODO: Documentation of sparse_add op. 
-)DOC"); - } -}; - -class SparseAddOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_add, - SparseAddInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_batch_norm op."); - AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op."); - AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op."); - AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op."); - AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op."); - AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op."); - AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op."); - AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op."); - AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op."); - AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op."); - AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op."); - AddAttr("momentum", - "(float), attribute 0 for sparse_batch_norm op."); - AddAttr("epsilon", "(float), attribute 1 for sparse_batch_norm op."); - AddAttr("data_layout", - "(string), attribute 2 for sparse_batch_norm op."); - AddAttr("is_test", "(bool), attribute 3 for sparse_batch_norm op."); - AddAttr("use_global_stats", - "(bool), attribute 4 for sparse_batch_norm op."); - AddAttr("trainable_statistics", - "(bool), attribute 4 for sparse_batch_norm op."); - AddAttr("fuse_with_relu", - "(bool), attribute 4 for sparse_batch_norm op."); - AddComment(R"DOC( -TODO: Documentation of sparse_batch_norm op. 
-)DOC"); - } -}; - -class SparseBatchNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm, - SparseBatchNormInferShapeFunctor, - PD_INFER_META(phi::BatchNormInferMeta)); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(sparse_sparse_coo_tensor, - ops::SparseSparseCooTensorOp, - ops::SparseSparseCooTensorOpMaker, - ops::SparseSparseCooTensorInferShapeFunctor); - -REGISTER_OPERATOR(sparse_values, - ops::SparseValuesOp, - ops::SparseValuesOpMaker, - ops::SparseValuesInferShapeFunctor); - -REGISTER_OPERATOR(sparse_indices, - ops::SparseIndicesOp, - ops::SparseIndicesOpMaker, - ops::SparseIndicesInferShapeFunctor); - -REGISTER_OPERATOR(sparse_to_dense, - ops::SparseToDenseOp, - ops::SparseToDenseOpMaker, - ops::SparseToDenseInferShapeFunctor); - -REGISTER_OPERATOR(sparse_relu, - ops::SparseReluOp, - ops::SparseReluOpMaker, - ops::SparseReluInferShapeFunctor); - -REGISTER_OPERATOR(sparse_conv3d, - ops::SparseConv3dOp, - ops::SparseConv3dOpMaker, - ops::SparseConv3dInferShapeFunctor); - -REGISTER_OPERATOR(sparse_add, - ops::SparseAddOp, - ops::SparseAddOpMaker, - ops::SparseAddInferShapeFunctor); - -REGISTER_OPERATOR(sparse_batch_norm, - ops::SparseBatchNormOp, - ops::SparseBatchNormOpMaker, - ops::SparseBatchNormInferShapeFunctor); diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index c8390ee4aa32a..65e7b4958136d 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -183,41 +183,13 @@ def get_api_and_op_name(api_item): ] -def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path, - op_version_yaml_path, output_op_path, output_arg_map_path): - with open(ops_yaml_path, "rt") as f: - apis = yaml.safe_load(f) - apis = [restruct_io(api) for api in apis] - forward_api_dict = 
to_named_dict(apis) - - with open(backward_yaml_path, "rt") as f: - backward_apis = yaml.safe_load(f) - backward_apis = [restruct_io(api) for api in backward_apis] - backward_api_dict = to_named_dict(backward_apis) - - with open(op_version_yaml_path, "rt") as f: - api_versions = yaml.safe_load(f) - # add api version info into api - for api_version in api_versions: - forward_api_dict[api_version['op']]['version'] = api_version['version'] - - with open(op_compat_yaml_path, "rt") as f: - api_op_map = yaml.safe_load(f) - - for api in apis: - api['op_name'] = api['name'] - for bw_api in backward_apis: - bw_api['op_name'] = bw_api['name'] - - replace_compat_name(api_op_map, forward_api_dict, backward_api_dict) - - # prepare for invoke case - for bw_name, bw_api in backward_api_dict.items(): +def process_invoke_op(forward_api_dict, backward_api_dict): + for bw_api in backward_api_dict.values(): if 'invoke' in bw_api: invoke_op = bw_api['invoke']['func'] args_list = bw_api['invoke']['args'] args_index = 0 - if invoke_op in forward_api_dict.keys(): + if invoke_op in forward_api_dict: reuse_op = forward_api_dict[invoke_op] bw_api['invoke']['inputs'] = [] bw_api['invoke']['attrs'] = [] @@ -250,6 +222,38 @@ def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path, bw_api['outputs'][idx]['name'] }) + +def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path, + op_version_yaml_path, output_op_path, output_arg_map_path): + with open(ops_yaml_path, "rt") as f: + apis = yaml.safe_load(f) + apis = [restruct_io(api) for api in apis] + forward_api_dict = to_named_dict(apis) + + with open(backward_yaml_path, "rt") as f: + backward_apis = yaml.safe_load(f) + backward_apis = [restruct_io(api) for api in backward_apis] + backward_api_dict = to_named_dict(backward_apis) + + with open(op_version_yaml_path, "rt") as f: + api_versions = yaml.safe_load(f) + # add api version info into api + for api_version in api_versions: + forward_api_dict[api_version['op']]['version'] = 
api_version['version'] + + with open(op_compat_yaml_path, "rt") as f: + api_op_map = yaml.safe_load(f) + + for api in apis: + api['op_name'] = api['name'] + for bw_api in backward_apis: + bw_api['op_name'] = bw_api['name'] + + replace_compat_name(api_op_map, forward_api_dict, backward_api_dict) + + # prepare for invoke case + process_invoke_op(forward_api_dict, backward_api_dict) + # fill backward field for an api if another api claims it as forward for name, backward_api in backward_api_dict.items(): forward_name = backward_api["forward"]["name"] diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py index 3e35e1094ff88..4d9e2c1d54d01 100644 --- a/paddle/phi/api/yaml/generator/generate_sparse_op.py +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -25,6 +25,7 @@ from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer from filters import to_input_name, cartesian_prod_mapping from parse_utils import to_named_dict +from generate_op import process_invoke_op file_loader = FileSystemLoader(Path(__file__).parent / "templates") env = Environment(loader=file_loader, @@ -71,10 +72,19 @@ def main(api_yaml_path, backward_yaml_path, output_op_path, for api in apis: api['op_name'] = SPARSE_OP_PREFIX + api['name'] + api['name'] = api['op_name'] if api["backward"] is not None: api["backward"] = SPARSE_OP_PREFIX + api["backward"] for bw_api in backward_apis: bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name'] + bw_api['name'] = bw_api['op_name'] + if 'invoke' in bw_api: + bw_api['invoke']['args'] = [ + param.strip() for param in bw_api['invoke']['args'].split(',') + ] + + # prepare for invoke case + process_invoke_op(forward_api_dict, backward_api_dict) # fill backward field for an api if another api claims it as forward for name, backward_api in backward_api_dict.items(): diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 
b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 index afef14ecaf508..f51b91413e578 100644 --- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 @@ -81,7 +81,11 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty {% set default_value = attr["default_value"] %} {% set typename = attr["typename"] %} {% if typename == "DataType" %}{# convert back to VarType #} + {% if default_value == "DataType::UNDEFINED" %} +-1 + {%- else %} static_cast(framework::TransToProtoVarType(experimental::{{default_value}})) + {%- endif %} {%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#} static_cast(experimental::{{default_value}}) {%- elif typename == "Place" %}{# construct a Place to get the type #} @@ -94,7 +98,7 @@ static_cast(phi::Place({{"phi::" if not default_value is initializer_list}} {# --------------------------------------- name mapping ---------------------------------------------- #} {% macro name_map(api) %} -KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { +KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { {% set kernel_args = api["kernel"]["param"] %} {{get_input_list(api["inputs"], kernel_args)}}; paddle::small_vector attrs; @@ -129,7 +133,7 @@ PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}}); {%- endmacro %} {% macro register_name_map(api) %} -PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping); {%- endmacro %} {% macro get_input_list(inputs, kernel_args) %}{# inline #} diff --git a/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 index 434518446b574..15d887e589e70 100644 --- 
a/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 @@ -1,14 +1,17 @@ -{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %} +{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %} // this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. #include #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/sparse/backward.h" #include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/multiary.h" #include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/unary.h" namespace paddle { @@ -31,6 +34,8 @@ using paddle::framework::GradVarName; {{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{operator(api)}} + {% else %} +{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} {% endif %} {% endfor %} } // namespace operators diff --git a/paddle/phi/ops/compat/sparse_manual_op_sig.cc b/paddle/phi/ops/compat/sparse_manual_op_sig.cc index 6c2a2bc9f451f..94985bc3dfbff 100644 --- a/paddle/phi/ops/compat/sparse_manual_op_sig.cc +++ b/paddle/phi/ops/compat/sparse_manual_op_sig.cc @@ -14,122 +14,18 @@ #include "paddle/phi/core/compat/op_utils.h" -namespace phi { - -// TODO(zhangkaihuo): add csr op - -KernelSignature SparseSparseCooTensorOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "sparse_coo_tensor", {"values", "indices"}, {"dense_shape"}, {"out"}); -} - 
-KernelSignature SparseValuesOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return KernelSignature("values_coo", {"x"}, {}, {"out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseIndicesOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return KernelSignature("indices_coo", {"x"}, {}, {"out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseToDenseOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return KernelSignature("coo_to_dense", {"x"}, {}, {"out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseReluOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return KernelSignature("relu_coo", {"x"}, {}, {"out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseConv3dOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return KernelSignature( - "conv3d_coo", - {"x", "kernel"}, - {"paddings", "dilations", "strides", "groups", "subm", "key"}, - {"out", "rulebook", "counter"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) { - return KernelSignature("add_coo_coo", {"x", "y"}, {}, {"out"}); - } else if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) { - return KernelSignature("add_coo_dense", {"x", "y"}, {}, {"out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -KernelSignature SparseBatchNormOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.IsSparseCooTensorInput("x")) { - return 
KernelSignature("batch_norm_coo", - {"x", "scale", "bias", "mean", "variance"}, - {"momentum", - "epsilon", - "data_layout", - "is_test", - "use_global_stats", - "trainable_statistics", - "fuse_with_relu"}, - {"y", - "mean_out", - "variance_out", - "saved_mean", - "saved_variance", - "reserve_space"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -} // namespace phi - PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor); -PD_REGISTER_ARG_MAPPING_FN(sparse_sparse_coo_tensor, - phi::SparseSparseCooTensorOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_values, phi::SparseValuesOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense); -PD_REGISTER_ARG_MAPPING_FN(sparse_to_dense, - phi::SparseToDenseOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_relu, phi::SparseReluOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo); -PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm, - phi::SparseBatchNormOpArgumentMapping); From 4aea85ed069b39b485ad7e54395abfd142153003 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 13 Oct 2022 17:14:09 +0000 Subject: [PATCH 06/11] fix bug --- .../auto_code_generator/eager_generator.cc | 10 +++++- .../eager_legacy_op_function_generator.cc | 5 +++ .../api/yaml/generator/generate_sparse_op.py | 7 +++-- paddle/phi/ops/compat/sparse_manual_op_sig.cc | 31 ------------------- 4 files changed, 19 insertions(+), 34 deletions(-) delete 
mode 100644 paddle/phi/ops/compat/sparse_manual_op_sig.cc diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 38ad5be09d8fa..daade88581c71 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -55,7 +55,9 @@ static std::unordered_set black_ops_list = {"run_program", "fused_gate_attention", "fused_feedforward", "fused_attention", - "fused_gemm_epilogue"}; + "fused_gemm_epilogue", + "sparse_divide_scalar", + "sparse_scale"}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; @@ -3161,6 +3163,12 @@ static void DygraphCodeGeneration(const std::string& output_dir, continue; } + // Skip the sparse op + if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" && + op_type != "sparse_attention") { + continue; + } + GradNodeGenerationInfo bwd_info; bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info); diff --git a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc index 1d27d45beb736..fff811e84ba6f 100644 --- a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc @@ -416,6 +416,11 @@ GenerateOpFunctions() { if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) { continue; } + // Skip the sparse op + if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" && + op_type != "sparse_attention") { + continue; + } // Skip operator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. 
// if the phi lib contains op kernel, we still generate ops method diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py index 4d9e2c1d54d01..b0c67f1d9df0a 100644 --- a/paddle/phi/api/yaml/generator/generate_sparse_op.py +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -14,8 +14,6 @@ import argparse import os -import re -from itertools import chain from pathlib import Path import yaml @@ -85,6 +83,11 @@ def main(api_yaml_path, backward_yaml_path, output_op_path, # prepare for invoke case process_invoke_op(forward_api_dict, backward_api_dict) + for bw_api in backward_apis: + if 'invoke' in bw_api: + if bw_api['invoke']['func'] in forward_api_dict: + bw_api['invoke'][ + 'func'] = SPARSE_OP_PREFIX + bw_api['invoke']['func'] # fill backward field for an api if another api claims it as forward for name, backward_api in backward_api_dict.items(): diff --git a/paddle/phi/ops/compat/sparse_manual_op_sig.cc b/paddle/phi/ops/compat/sparse_manual_op_sig.cc deleted file mode 100644 index 94985bc3dfbff..0000000000000 --- a/paddle/phi/ops/compat/sparse_manual_op_sig.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo); - -PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo); From 3d3f0f7ff36d9b664f052eff8e84d18aff25de77 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 14 Oct 2022 07:11:52 +0000 Subject: [PATCH 07/11] fix kernel signature code-gen --- paddle/fluid/framework/infershape_utils.cc | 5 ++ paddle/fluid/framework/operator.h | 5 ++ paddle/fluid/framework/tensor.h | 1 + paddle/fluid/framework/var_type_traits.h | 2 + .../tensorrt/plugin_arg_mapping_context.cc | 4 ++ .../tensorrt/plugin_arg_mapping_context.h | 2 + .../api/yaml/generator/generate_sparse_op.py | 2 +- paddle/phi/api/yaml/generator/parse_utils.py | 35 +++++++++++-- .../generator/templates/operator_utils.c.j2 | 52 +++++++++++++++++++ .../yaml/generator/templates/sparse_ks.c.j2 | 24 +++++++++ paddle/phi/core/compat/arg_map_context.h | 1 + paddle/phi/tests/ops/test_op_signature.h | 4 ++ 12 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index de9f6a4745fd0..f8400e89b2b99 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -117,6 +117,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { return var_type == proto::VarType::SPARSE_COO; } + bool IsSparseCsrTensorInput(const std::string& name) const override { + auto var_type = ctx_.GetInputVarType(name); + return var_type == 
proto::VarType::SPARSE_CSR; + } + bool IsDenseTensorOutput(const std::string& name) const override { auto var_types = ctx_.GetOutputsVarType(name); return std::all_of(var_types.begin(), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e649bf2fc7e95..7f99f7dbf7643 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -537,6 +537,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { return var->IsType(); } + bool IsSparseCsrTensorInput(const std::string& name) const override { + const auto* var = ctx_.InputVar(name); + return var->IsType(); + } + bool IsDenseTensorOutput(const std::string& name) const override { auto vars = ctx_.MultiOutputVar(name); return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 3d3334b2c06ee..451a2309892f8 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index f5ae14bbf6109..6b234db2a3647 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -55,6 +55,7 @@ namespace phi { class DenseTensor; class SelectedRows; class SparseCooTensor; +class SparseCsrTensor; } // namespace phi // Users should add forward declarations here @@ -182,6 +183,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< phi::DenseTensor, phi::SelectedRows, phi::SparseCooTensor, + phi::SparseCsrTensor, std::vector, LoDRankTable, Strings, diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc index 9b5ab945d77fa..f7667c6df9eda 100644 --- a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc @@ -112,6 +112,10 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput( const std::string& name) const { return false; } +bool PluginArgumentMappingContext::IsSparseCsrTensorInput( + const std::string& name) const { + return false; +} bool PluginArgumentMappingContext::IsDenseTensorVectorInput( const std::string& name) const { return false; diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h index f004040af3b1a..088e966a0cca7 100644 --- a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h @@ -50,6 +50,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext { bool IsSparseCooTensorInput(const std::string& name) const override; + bool 
IsSparseCsrTensorInput(const std::string& name) const override; + bool IsDenseTensorVectorInput(const std::string& name) const override; bool IsDenseTensorOutput(const std::string& name) const override; diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py index b0c67f1d9df0a..e4d7f44856bfd 100644 --- a/paddle/phi/api/yaml/generator/generate_sparse_op.py +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -116,7 +116,7 @@ def main(api_yaml_path, backward_yaml_path, output_op_path, api_dict=api_dict) f.write(msg) - ks_template = env.get_template('ks.c.j2') + ks_template = env.get_template('sparse_ks.c.j2') with open(output_arg_map_path, 'wt') as f: msg = ks_template.render(apis=apis, backward_apis=backward_apis) f.write(msg) diff --git a/paddle/phi/api/yaml/generator/parse_utils.py b/paddle/phi/api/yaml/generator/parse_utils.py index fed5683b3383b..ed8068b40e827 100644 --- a/paddle/phi/api/yaml/generator/parse_utils.py +++ b/paddle/phi/api/yaml/generator/parse_utils.py @@ -156,14 +156,15 @@ def parse_kernel(api_name: str, kernel_config: Dict[str, # backend : str, the names of param to choose the kernel backend, default is None # layout : str, the names of param to choose the kernel layout, default is None # data_type : str, the names of param to choose the kernel data_type, default is None + # dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)}) kernel = { - 'func': None, # up to 2 function names + 'func': [], # up to 2 function names 'param': None, 'backend': None, 'layout': None, - 'data_type': None + 'data_type': None, + 'dispatch': {} } - kernel['func'] = parse_plain_list(kernel_config['func']) if 'param' in kernel_config: kernel['param'] = kernel_config['param'] @@ -175,6 +176,34 @@ def parse_kernel(api_name: str, kernel_config: Dict[str, if 'data_type' in kernel_config: 
kernel['data_type'] = parse_candidates(kernel_config["data_type"]) + + kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall( + kernel_config['func']) + + def parse_kernel_in_out_type(in_out_str): + if len(in_out_str) == 0: + return None + tmp_in_out_list = in_out_str[1:-1].split('->') + inputs = [item.strip() for item in tmp_in_out_list[0].split(',')] + outputs = [item.strip() for item in tmp_in_out_list[1].split(',')] + + # check the tensor type + for item in inputs: + assert item in [ + 'dense', 'selected_rows', 'sparse_coo', 'sparse_csr' + ], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'." + for item in outputs: + assert item in [ + 'dense', 'selected_rows', 'sparse_coo', 'sparse_csr' + ], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'." + + return (inputs, outputs) + + for func_item in kernel_funcs: + kernel['func'].append(func_item[0]) + kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type( + func_item[1]) + return kernel diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 index f51b91413e578..78d206b8d0bcb 100644 --- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 @@ -128,6 +128,58 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg */ {% endmacro %} +{% macro get_kernel_dispatch(inputs, kernel_config) %}{# inline #} +{%- for kernel_func in kernel_config["func"] %} + {% set input_idx = 0 %} + {% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %} + + if ( {%- for input in inputs %} + {%- if input["name"] in kernel_config["param"] %} + {%- if kernel_in_type_list[input_idx] == "dense" %} +ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}} + {%- 
elif kernel_in_type_list[input_idx] == "selected_rows" %} +ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}} + {%- elif kernel_in_type_list[input_idx] == "sparse_coo" %} +ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}} + {%- elif kernel_in_type_list[input_idx] == "sparse_csr" %} +ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}} + {%- endif %} + {% set input_idx = input_idx + 1 %} + {%- endif %} + {%- endfor %}) { + kernel_name = "{{kernel_func}}"; + } +{%- endfor %} +{%- endmacro %} + +{% macro sparse_op_name_map(api) %} +KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { + {% set kernel_args = api["kernel"]["param"] %} + {{get_input_list(api["inputs"], kernel_args)}}; + paddle::small_vector attrs; + {% for attr in api["attrs"]%} + {% filter indent(2)%} + {{get_an_attr(attr)}}; + {% endfilter %} + {% endfor %} + {{get_output_list(api["outputs"], kernel_args)}}; + + const char* kernel_name = "unregistered"; +{{get_kernel_dispatch(api["inputs"], api["kernel"])}} + KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); + return sig; +} + +/* +****************************************************************** +NOTE: The following codes are for 'get_compat_kernel_signature.py' +All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: + +{{api | cartesian_prod_mapping}} +****************************************************************** +*/ +{% endmacro %} + {% macro register_base_kernel_name(api) %} PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}}); {%- endmacro %} diff --git a/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 b/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 new file mode 100644 index 0000000000000..1af54ca866083 --- /dev/null +++ b/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 @@ -0,0 +1,24 @@ +{% from 
"operator_utils.c.j2" import sparse_op_name_map, register_name_map, register_base_kernel_name %} +// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +{% for api in apis %} + {% if api is base_api %} +{{sparse_op_name_map(api)}} + {% endif %} +{% endfor %} +{% for api in backward_apis %} + {% if api is base_api %} +{{sparse_op_name_map(api)}} + {% endif %} +{% endfor %} +} // namespace phi + +{% for api in apis + backward_apis %} + {% if api is base_api %} +{{register_name_map(api)}} + {% endif %} +{% endfor %} diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 3ec35c40c8e5b..4e3447fe9eb22 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -110,6 +110,7 @@ class ArgumentMappingContext { virtual bool IsSelectedRowsInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInputs(const std::string& name) const = 0; virtual bool IsSparseCooTensorInput(const std::string& name) const = 0; + virtual bool IsSparseCsrTensorInput(const std::string& name) const = 0; // For compatibility with LoDTensorArray virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0; diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h index 1b067c0aa17e5..7f89fb34994fc 100644 --- a/paddle/phi/tests/ops/test_op_signature.h +++ b/paddle/phi/tests/ops/test_op_signature.h @@ -90,6 +90,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext { return false; } + bool IsSparseCsrTensorInput(const std::string& name) const override { + return false; + } + bool IsDenseTensorOutput(const std::string& name) const override { return dense_tensor_outputs.count(name) > 0; } From f714ed7ab8c997c0baf9f077a9aa1fd859c51672 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 14 Oct 2022 
08:09:44 +0000 Subject: [PATCH 08/11] fix complie bug of VarType --- paddle/fluid/framework/framework.proto | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 344c0ac0ccb6f..83ba52f4ec655 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -156,6 +156,8 @@ message VarType { PSTRING = 29; // the data type of phi::SparseCooTensor SPARSE_COO = 30; + // the data type of phi::SparseCsrTensor + SPARSE_CSR = 31; } required Type type = 1; @@ -189,6 +191,7 @@ message VarType { optional TensorDesc strings = 9; optional TensorDesc vocab = 10; optional TensorDesc sparse_coo = 11; + optional TensorDesc sparse_csr = 12; } message VarDesc { From 103348abf61c1f0277224ac26f1e38f2463c20e6 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 14 Oct 2022 08:58:03 +0000 Subject: [PATCH 09/11] fix complie bug of VarType --- paddle/fluid/framework/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 2507919e6fc7f..c009cfcd92ac0 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -237,7 +237,7 @@ cc_test( cc_library( var_type_traits SRCS var_type_traits.cc - DEPS framework_proto scope tensor_array sparse_coo_tensor) + DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor) if(WITH_GPU) target_link_libraries(var_type_traits dynload_cuda) endif() From 514476ec912e35c2a7d543073f2ec46efa4394e7 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 14 Oct 2022 12:40:45 +0000 Subject: [PATCH 10/11] fix test_sparse_conv_op --- paddle/fluid/operators/sparse_manual_op.cc | 59 +++++++++++++++++++ .../generator/templates/operator_utils.c.j2 | 12 ++-- paddle/phi/api/yaml/sparse_ops.yaml | 2 +- paddle/phi/ops/compat/sparse_manual_op_sig.cc | 30 ++++++++++ 4 files changed, 96 insertions(+), 7 
deletions(-) create mode 100644 paddle/fluid/operators/sparse_manual_op.cc create mode 100644 paddle/phi/ops/compat/sparse_manual_op_sig.cc diff --git a/paddle/fluid/operators/sparse_manual_op.cc b/paddle/fluid/operators/sparse_manual_op.cc new file mode 100644 index 0000000000000..f95d5250c62f6 --- /dev/null +++ b/paddle/fluid/operators/sparse_manual_op.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("x", "(Tensor), input 0 of sparse_indices op."); + AddOutput("out", "(Tensor), output 0 of sparse_indices op."); + AddComment(R"DOC( +TODO: Documentation of sparse_indices op. 
+)DOC"); + } +}; + +class SparseIndicesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices, + SparseIndicesInferShapeFunctor, + PD_INFER_META(phi::sparse::IndicesInferMeta)); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sparse_indices, + ops::SparseIndicesOp, + ops::SparseIndicesOpMaker, + ops::SparseIndicesInferShapeFunctor); diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 index 78d206b8d0bcb..da497e2b3bd00 100644 --- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 @@ -130,21 +130,21 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg {% macro get_kernel_dispatch(inputs, kernel_config) %}{# inline #} {%- for kernel_func in kernel_config["func"] %} - {% set input_idx = 0 %} + {% set input_idx = namespace(idx=0) %} {% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %} if ( {%- for input in inputs %} {%- if input["name"] in kernel_config["param"] %} - {%- if kernel_in_type_list[input_idx] == "dense" %} + {%- if kernel_in_type_list[input_idx.idx] == "dense" %} ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}} - {%- elif kernel_in_type_list[input_idx] == "selected_rows" %} + {%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %} ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}} - {%- elif kernel_in_type_list[input_idx] == "sparse_coo" %} + {%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %} ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}} - {%- elif kernel_in_type_list[input_idx] == "sparse_csr" %} + {%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %} 
ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}} {%- endif %} - {% set input_idx = input_idx + 1 %} + {% set input_idx.idx = input_idx.idx + 1 %} {%- endif %} {%- endfor %}) { kernel_name = "{{kernel_func}}"; diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index e1083ae3a65e9..4e4874cb96736 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -111,7 +111,7 @@ backward : cast_grad - op : conv3d - args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) + args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key="") output : Tensor(out), Tensor(rulebook), Tensor(counter) infer_meta : func : sparse::Conv3dInferMeta diff --git a/paddle/phi/ops/compat/sparse_manual_op_sig.cc b/paddle/phi/ops/compat/sparse_manual_op_sig.cc new file mode 100644 index 0000000000000..6e520cbdd96cd --- /dev/null +++ b/paddle/phi/ops/compat/sparse_manual_op_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SparseIndicesOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsSparseCooTensorInput("x")) { + return KernelSignature("indices_coo", {"x"}, {}, {"out"}); + } else { + return KernelSignature("unregistered", {}, {}, {}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping); From dc296ddbd0438282b292ae130671458431b75597 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 14 Oct 2022 14:52:48 +0000 Subject: [PATCH 11/11] fix test_sparse_norm_op --- python/paddle/incubate/sparse/nn/layer/norm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/sparse/nn/layer/norm.py b/python/paddle/incubate/sparse/nn/layer/norm.py index 51dedaf5d9e6e..440a08953951f 100644 --- a/python/paddle/incubate/sparse/nn/layer/norm.py +++ b/python/paddle/incubate/sparse/nn/layer/norm.py @@ -170,9 +170,9 @@ def forward(self, input): dtype=dtype, stop_gradient=True) reserve_space = helper.create_variable_for_type_inference( dtype=dtype, stop_gradient=True) - y = helper.create_sparse_variable_for_type_inference(dtype) + out = helper.create_sparse_variable_for_type_inference(dtype) outputs = { - "y": y, + "out": out, "mean_out": mean_out, "variance_out": variance_out, "saved_mean": saved_mean, @@ -183,7 +183,7 @@ def forward(self, input): inputs=inputs, outputs=outputs, attrs=attrs) - return y + return out class SyncBatchNorm(paddle.nn.SyncBatchNorm):