From 6d5e15f6d0456a36305ebaeb62fdfa4bef6f7995 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 18 Oct 2023 11:32:52 +0800 Subject: [PATCH] [DRR] C++ DRR (Declarative Rewrite Rule) of Paddle (#55859) * fix cudnn 8.7+ bug on cudnnConvolutionBiasActivationForward * add drr_rewrite_pattern.h and remove_redundent_reshape demo * add drr_context and pattern_graph class * add test case * fix cmake file * fix compile bug * fix runtime bug and refine code * add MatchContext * update code * add impl of tensor_interface * fix compile bug * change smart ptr to pointer * change smart to pointer * change smart to pointer * Replace 'weak_ptr' with pointer * modify weak_ptr use count==0 judgment logic * change smart to pointer change smart to pointer Replace 'weak_ptr' with pointer modify weak_ptr use count==0 judgment logic Replace the declaration and call of weak_ptr with pointer * add match * add match * remove OperationInterface * update * Add Rewrite impl of DrrRewritePattern * refine code * rename ir_value to get in IrValue * fix header include * add CreateOperation template demo * Add GraphTopo class in pattern_graph * Reimplementing the GraphTopo class using queue * Reimplementing the GraphTopo class using queue * Optimize the access method of visited tensor * Considering that the inputs of opcall may be empty * Overloading the operator() method of Op, supporting dual tensor inputs * support attr * 1. Add Op class support for multi input and multi output function. 2. Add DRR duplicate TransposeOp merge testing code * 1. 
Add transferOP in createOption func * fix bug * fix NotifyOperationRemoved * refine code * Fix axis bug in perm * update shared_ptr * update * refine drr_test ut * Modify according to review * modify reshape_op * format code * support vector for attr * fix drr test * refine code * Resolve compilation loop dependencies * add RequireNativeCall * support native_call in drr api * temp tensor prefix fix * refine code * support Tensor Assign API in ResultPattern * refine test code * refactor the drr_pattern class * refine test case * rename DrrPatternBuilder to DrrPatternBase * fix compile bug * adjust include * Add log info in DrrRewritePattern * use ir::get_type_name * use ir::get_type_name * support compute attribute in drr pattern * refine code * Add fusion testing code for fullOp and expandOp * Standardize code format * Replace IR_THROW() with PADDLE_THROW() * refine code * add attention fuse demo * update * fix compile error * add multihead_matmul fuse pattern * fix multihead_matmul * Update drr_attention_fuse_test.cc add buildprogram * fix drr_attention_fuse_test compile * add fused_gemm_epilogue in drr * attr support std::vector * add debug log * update * fix some bug * fix conflict * support subgraph replace in source pattern graph for drr * Improve the implementation of Drr and multihead_matmul_fuse_pass * add ReorderBlockOpsPass * fix drr_attention_fuse_pass * update * update reorder_block_ops_pass * revert fusedgemm * update * Add Bottom2UpMatch() func * merge code * fix bug * add log & fix bug * refine cpp type trait * using oprand() & num_oprand() replace oprands() * fix conflict * fix compile * fix pd.xxx to pd_op.xxx * fix bug of delete op in drr * add PatternGraphMatchV2 & FindOutputOp func * refactor ir operation creator * fix include pir * fix ir * merging * Split out dfsvisitor func from FindOutputOp func * fix bug * fix output op in source pattern bug * Debugging drr_test drr_attention_fuse_test passed! * Debugging drr_fuse_linear_test passed! 
* Optimize the PatternGraphMatchV2 function interface and overload the operator= method in MatchContextImpl * Modify comments and function names * auto code-gen for creating ir operation in drr * delete debug log * optimize the interface of MatchFromOutputToInput() * Optimize SourcePatternGraph::OutputNodes judgment logic * polish code * using default operator=() in MatchContextImpl * fix merge conflict * create test case: drr_same_name_test * fix duplicate binding of ir op bug * Rename drr_same_name_test to drr_same_type_binding_test & Add graphical notes * refactor logic of insert point for creating new operation in drr * update * fix compile error * fix some bug * fix codestyle * fix bug * Update anchor node judgment logic * fix bug of link pir * fix codestyle * self review v1 * refine code format * set thread_local for count in op class * fix compile on mac * remove unused .h in value.cc * fix compile --------- Co-authored-by: zyfncg Co-authored-by: gongshaotian Co-authored-by: gongshaotian <> Co-authored-by: gongshaotian <141618702+gongshaotian@users.noreply.github.com> --- paddle/fluid/pir/CMakeLists.txt | 1 + .../op_generator/op_creator_drr_gen.py | 166 +++++ paddle/fluid/pir/drr/CMakeLists.txt | 65 ++ paddle/fluid/pir/drr/api/drr_pattern_base.h | 41 ++ .../fluid/pir/drr/api/drr_pattern_context.cc | 154 +++++ .../fluid/pir/drr/api/drr_pattern_context.h | 334 ++++++++++ paddle/fluid/pir/drr/api/match_context.cc | 49 ++ paddle/fluid/pir/drr/api/match_context.h | 43 ++ paddle/fluid/pir/drr/api/tensor_interface.cc | 34 ++ paddle/fluid/pir/drr/api/tensor_interface.h | 61 ++ paddle/fluid/pir/drr/attr_type_uilts.h | 116 ++++ paddle/fluid/pir/drr/drr_rewrite_pattern.h | 568 ++++++++++++++++++ paddle/fluid/pir/drr/ir_operation.h | 33 + paddle/fluid/pir/drr/ir_operation_factory.cc | 166 +++++ paddle/fluid/pir/drr/ir_operation_factory.h | 73 +++ paddle/fluid/pir/drr/ir_value.h | 82 +++ paddle/fluid/pir/drr/match_context_impl.h | 124 ++++ 
paddle/fluid/pir/drr/pattern_graph.cc | 223 +++++++ paddle/fluid/pir/drr/pattern_graph.h | 108 ++++ paddle/pir/pass/ir_printing.cc | 8 +- .../pattern_rewrite/pattern_rewrite_driver.cc | 1 + test/cpp/pir/pattern_rewrite/CMakeLists.txt | 41 ++ .../drr_attention_fuse_test.cc | 380 ++++++++++++ .../pattern_rewrite/drr_fuse_linear_test.cc | 399 ++++++++++++ .../drr_same_type_binding_test.cc | 332 ++++++++++ test/cpp/pir/pattern_rewrite/drr_test.cc | 232 +++++++ .../pattern_rewrite/pattern_rewrite_test.cc | 7 +- 27 files changed, 3833 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py create mode 100644 paddle/fluid/pir/drr/CMakeLists.txt create mode 100644 paddle/fluid/pir/drr/api/drr_pattern_base.h create mode 100644 paddle/fluid/pir/drr/api/drr_pattern_context.cc create mode 100644 paddle/fluid/pir/drr/api/drr_pattern_context.h create mode 100644 paddle/fluid/pir/drr/api/match_context.cc create mode 100644 paddle/fluid/pir/drr/api/match_context.h create mode 100644 paddle/fluid/pir/drr/api/tensor_interface.cc create mode 100644 paddle/fluid/pir/drr/api/tensor_interface.h create mode 100644 paddle/fluid/pir/drr/attr_type_uilts.h create mode 100644 paddle/fluid/pir/drr/drr_rewrite_pattern.h create mode 100644 paddle/fluid/pir/drr/ir_operation.h create mode 100644 paddle/fluid/pir/drr/ir_operation_factory.cc create mode 100644 paddle/fluid/pir/drr/ir_operation_factory.h create mode 100644 paddle/fluid/pir/drr/ir_value.h create mode 100644 paddle/fluid/pir/drr/match_context_impl.h create mode 100644 paddle/fluid/pir/drr/pattern_graph.cc create mode 100644 paddle/fluid/pir/drr/pattern_graph.h create mode 100644 test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc create mode 100644 test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc create mode 100644 test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc create mode 100644 test/cpp/pir/pattern_rewrite/drr_test.cc diff --git a/paddle/fluid/pir/CMakeLists.txt 
b/paddle/fluid/pir/CMakeLists.txt index 1ff77c6d7187e0..24f5e2892de8e2 100644 --- a/paddle/fluid/pir/CMakeLists.txt +++ b/paddle/fluid/pir/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(dialect) add_subdirectory(transforms) +add_subdirectory(drr) diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py new file mode 100644 index 00000000000000..c760d7fb85b84e --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +import yaml +from op_gen import OpCompatParser, OpInfoParser, to_pascal_case + +CPP_FILE_TEMPLATE = """ +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" + +namespace pir {{ +namespace drr {{ + +void OperationFactory::RegisterGeneratedOpCreator() {{ +{body} +}} + +}} // namespace drr +}} // namespace pir + +""" + +NORMAL_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + return rewriter.Build( + {params_code}); + }}); +""" + +MUTABLE_ATTR_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + // mutable_attr is tensor + if (inputs.size() > {inputs_num}) {{ + return rewriter.Build( + {params_code_with_mutable_attr}); + }} else {{ + return rewriter.Build( + {params_code_no_mutable_attr}); + }} + }}); +""" + + +class OpCreatorCodeGen: + def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name): + self.op_info_items = self.parse_yaml(op_yaml_files, op_compat_yaml_file) + self.dialect_name = dialect_name + + def parse_yaml(self, op_yaml_files, op_compat_yaml_file): + op_compat_parser = OpCompatParser(op_compat_yaml_file) + + op_yaml_items = [] + for yaml_file in op_yaml_files: + with open(yaml_file, "r") as f: + ops = yaml.safe_load(f) + op_yaml_items = op_yaml_items + ops + op_info_items = [] + for op in op_yaml_items: + op_compat_item = op_compat_parser.get_compat(op['name']) + if ( + op_compat_item is not None + and op_compat_item['op'] == "pow" + and 'scalar' in op_compat_item + ): + op_compat_item = op_compat_item.pop('scalar') + op_info_items.append(OpInfoParser(op, op_compat_item)) + return op_info_items + + def gen_cpp_file_code(self, cpp_file_path): + body_code = "" + for 
op_info_item in self.op_info_items: + if op_info_item.infer_meta_map is None: + continue + for phi_op_name in op_info_item.op_phi_name: + ir_op_name = self.dialect_name + "." + phi_op_name + params_no_mutable_attr = [] + for i in range(len(op_info_item.input_name_list)): + params_no_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > 0: + params_no_mutable_attr.append("attrs") + + if len(op_info_item.mutable_attribute_name_list) == 0: + body_code += NORMAL_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code=", ".join(params_no_mutable_attr), + ) + else: + params_with_mutable_attr = [] + for i in range( + len(op_info_item.input_name_list) + + len(op_info_item.mutable_attribute_name_list) + ): + params_with_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > len( + op_info_item.mutable_attribute_name_list + ): + # TODO(zyfncg): Currently Op::Build Interface doesn't support this case. 
+ continue + # params_with_mutable_attr.append("attrs") + + body_code += MUTABLE_ATTR_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + inputs_num=len(op_info_item.input_name_list), + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code_with_mutable_attr=",".join( + params_with_mutable_attr + ), + params_code_no_mutable_attr=", ".join( + params_no_mutable_attr + ), + ) + + with open(cpp_file_path, 'w') as f: + f.write(CPP_FILE_TEMPLATE.format(body=body_code)) + + +def ParseArguments(): + parser = argparse.ArgumentParser( + description='Generate Op Creator Files By Yaml' + ) + parser.add_argument('--op_yaml_files', type=str) + parser.add_argument('--op_compat_yaml_file', type=str) + parser.add_argument('--dialect_name', type=str) + parser.add_argument('--op_creator_file', type=str) + return parser.parse_args() + + +if __name__ == '__main__': + args = ParseArguments() + op_yaml_files = args.op_yaml_files.split(",") + op_compat_yaml_file = args.op_compat_yaml_file + op_creator_file = args.op_creator_file + dialect_name = args.dialect_name + + code_gen = OpCreatorCodeGen( + op_yaml_files, op_compat_yaml_file, dialect_name + ) + code_gen.gen_cpp_file_code(op_creator_file) diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt new file mode 100644 index 00000000000000..c1b524dda69a6a --- /dev/null +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -0,0 +1,65 @@ +file(GLOB DRR_SRCS "*.cc" "api/*.cc") + +set(op_creator_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +) +set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) +set(op_forward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml +) +set(op_forward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml +) +set(op_backward_yaml_file1 + 
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml +) +set(op_backward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml +) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) + +set(parsed_op_dir + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) + +set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) +set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) + +set(op_yaml_files + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4} +) + +set(op_creator_file + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/ir_op_factory_generated.cc) +set(op_creator_file_tmp ${op_creator_file}.tmp) + +set(dialect_name pd_op) + +add_custom_command( + OUTPUT ${op_creator_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --dialect_name ${dialect_name} + --op_creator_file ${op_creator_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_creator_file_tmp} + ${op_creator_file} + COMMENT "copy_if_different ${op_creator_file}" + DEPENDS ${op_creator_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + pd_op_dialect_op + VERBATIM) + +cc_library( + drr + SRCS ${DRR_SRCS} ${op_creator_file} + DEPS pd_op_dialect pir) diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h new file mode 100644 index 00000000000000..d5f19ff3e6e9be --- /dev/null +++ 
b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" + +namespace pir { +namespace drr { + +template +class DrrPatternBase { + public: + virtual ~DrrPatternBase() = default; + + // Define the Drr Pattern. + virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; + + std::unique_ptr> Build( + pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { + DrrPatternContext drr_context; + this->operator()(&drr_context); + return std::make_unique>( + drr_context, ir_context, benefit); + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/api/drr_pattern_context.cc new file mode 100644 index 00000000000000..5f74b986f1a5e7 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" + +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/pir/core/enforce.h" + +namespace pir { +namespace drr { + +DrrPatternContext::DrrPatternContext() { + source_pattern_graph_ = std::make_shared(); + result_pattern_graph_ = std::make_shared(); +} + +drr::SourcePattern DrrPatternContext::SourcePattern() { + return drr::SourcePattern(this); +} +const Op& DrrPatternContext::SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, source_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +const drr::Tensor& DrrPatternContext::SourceTensorPattern( + const std::string& name) { + return source_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, source_pattern_graph_.get()))); +} + +const Op& DrrPatternContext::ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, result_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +drr::Tensor& DrrPatternContext::ResultTensorPattern(const std::string& name) { + return result_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, result_pattern_graph_.get()))); +} + +std::vector DrrPatternContext::constraints() const { + return constraints_; +} + +// void DrrPatternContext::RequireEqual(const Attribute& first, const Attribute& +// second) { +// auto constrain_fn = [&](const 
MatchContext& match_context) { +// return match_context.Attr(first.id()) == match_context.Attr(second.id()); +// }; +// constraints_.emplace_back(constrain_fn); +// } + +void DrrPatternContext::RequireEqual(const TensorShape& first, + const TensorShape& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Shape() == + match_context.Tensor(second.tensor_name()).Shape(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireEqual(const TensorDataType& first, + const TensorDataType& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Dtype() == + match_context.Tensor(second.tensor_name()).Dtype(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireNativeCall( + const std::function& custom_fn) { + constraints_.emplace_back(custom_fn); +} + +void Op::operator()(const Tensor& arg, const Tensor* out) const { + std::vector inputs{&arg}; + std::vector outputs{out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); +} + +void Op::operator()(const std::vector& args, + const std::vector& outputs) const { + pattern_graph_->AddOpCall(std::make_shared(this, args, outputs)); +} + +Tensor& Op::operator()(const Tensor& arg) const { + std::vector inputs{&arg}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const { + std::vector 
inputs{&arg1, &arg2}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()() const { + std::vector inputs{}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +thread_local int64_t Op::count = 0; +const char* Op::prefix = "@drr_temp@_"; + +const char Tensor::NONE_TENSOR_NAME[] = "__@none_tensor@__"; + +void Tensor::Assign(const Tensor& other) { + dynamic_cast(pattern_graph_)->AssignTensor(*this, other); +} + +void Tensor::operator=(const Tensor& other) const { // NOLINT + // The two tensor must be in the same pattern graph. + IR_ENFORCE(this->pattern_graph_ == other.pattern_graph_); + if (other.name_.find(Op::prefix) == 0 && + name_.find(Op::prefix) == std::string::npos) { + other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_); + } +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h new file mode 100644 index 00000000000000..b4156bd54bf414 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -0,0 +1,334 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/match_context.h" + +namespace pir { +namespace drr { + +class Op; +class Tensor; +class OpCall; +class SourcePattern; +class ResultPattern; +class PatternGraph; +class SourcePatternGraph; +class ResultPatternGraph; + +class NormalAttribute { + public: + explicit NormalAttribute(const std::string& name) : attr_name_(name) {} + + const std::string& name() const { return attr_name_; } + + private: + std::string attr_name_; +}; + +using AttrComputeFunc = std::function; + +class ComputeAttribute { + public: + explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func) + : attr_compute_func_(attr_compute_func) {} + + const AttrComputeFunc& attr_compute_func() const { + return attr_compute_func_; + } + + private: + AttrComputeFunc attr_compute_func_; +}; + +using Attribute = std::variant; + +class TensorShape { + public: + explicit TensorShape(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class TensorDataType { + public: + explicit TensorDataType(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class Constraint { + public: + explicit Constraint( + const std::function& constrain_fn) + : IsContextMatchConstraint_(constrain_fn) {} + bool operator()(const 
MatchContext& match_context) const { + return IsContextMatchConstraint_(match_context); + } + + private: + std::function IsContextMatchConstraint_; +}; + +class DrrPatternContext { + public: + DrrPatternContext(); + ~DrrPatternContext() = default; + + drr::SourcePattern SourcePattern(); + + std::shared_ptr source_pattern_graph() const { + return source_pattern_graph_; + } + + std::vector constraints() const; + + std::shared_ptr result_pattern_graph() const { + return result_pattern_graph_; + } + + private: + friend class drr::SourcePattern; + friend class drr::ResultPattern; + + const Op& SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + const drr::Tensor& SourceTensorPattern(const std::string& name); + + const Op& ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + drr::Tensor& ResultTensorPattern(const std::string& name); + + // void RequireEqual(const Attribute& first, const Attribute& second); + void RequireEqual(const TensorShape& first, const TensorShape& second); + void RequireEqual(const TensorDataType& first, const TensorDataType& second); + void RequireNativeCall( + const std::function& custom_fn); + + std::shared_ptr source_pattern_graph_; + std::vector constraints_; + std::shared_ptr result_pattern_graph_; + + std::vector> owned_ops_; +}; + +class Op { + public: + const std::string& name() const { return op_type_name_; } + + void operator()(const Tensor& arg, const Tensor* out) const; + + Tensor& operator()() const; + + Tensor& operator()(const Tensor& arg) const; + Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; + void operator()(const std::vector& args, + const std::vector& outputs) const; + // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const + // Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const + // Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& + // operator()(const Tensor& 
arg0, const Tensor& arg1, const Tensor& arg2, + // const Tensor& arg3, const Tensor& arg4) const; + + static const char* prefix; + + private: + friend class DrrPatternContext; + friend class OpCall; + + Op(const std::string& op_type_name, + const std::unordered_map& attributes, + PatternGraph* pattern_graph) + : op_type_name_(op_type_name), + attributes_(attributes), + pattern_graph_(pattern_graph) {} + + const std::unordered_map& attributes() const { + return attributes_; + } + + thread_local static int64_t count; + + std::string op_type_name_; + std::unordered_map attributes_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class Tensor { + public: + static const char NONE_TENSOR_NAME[]; + + const std::string& DebugName() const; + + TensorShape shape() const { return TensorShape(name()); } + + TensorDataType dtype() const { return TensorDataType(name()); } + + bool is_none() const { return name_ == NONE_TENSOR_NAME; } + + void Assign(const Tensor& other); + + void operator=(const Tensor& other) const; // NOLINT + + const std::string& name() const { return name_; } + + void set_name(const std::string& name) { name_ = name; } + + OpCall* producer() const { return producer_; } + + void set_producer(OpCall* producer) { producer_ = producer; } + + const std::vector& consumers() const { return consumers_; } + + void set_consumables(const std::vector& consumers) { + consumers_ = consumers; + } + + void AddConsumer(const OpCall* consumer) { consumers_.push_back(consumer); } + + private: + friend class DrrPatternContext; + friend class Op; + + Tensor(const std::string& name, PatternGraph* pattern_graph) + : name_(name), pattern_graph_(pattern_graph) {} + + std::string name_; + OpCall* producer_{nullptr}; + std::vector consumers_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class OpCall { + public: + OpCall(const Op* op, + const std::vector& inputs, + const std::vector& outputs) + : op_name_(op->op_type_name_), + inputs_(inputs), + outputs_(outputs), + 
attributes_(op->attributes_) {} + + const std::string& name() const { return op_name_; } + + const std::vector& inputs() const { return inputs_; } + + const std::vector& outputs() const { return outputs_; } + + const std::unordered_map& attributes() const { + return attributes_; + } + + private: + std::string op_name_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map attributes_; +}; + +class ResultPattern { + public: + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->ResultOpPattern(op_type, attributes); + } + + drr::Tensor& Tensor(const std::string& name) { + return ctx_->ResultTensorPattern(name); + } + + // Represents an input tensor that is none. + // Example: + // instance_norm has the following input tensors: (x, scale, bias); scale + // and bias are optional (i.e. they may be none). + // When scale is none, we can write an instance_norm op in drr as follows: + // res.Op("instance_norm")(res.Tensor("x"), res.NoneTensor(), + // res.Tensor("bias")); + drr::Tensor& NoneTensor() { + return ctx_->ResultTensorPattern(Tensor::NONE_TENSOR_NAME); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + Attribute Attr(const AttrComputeFunc& attr_compute_func) const { + return ComputeAttribute(attr_compute_func); + } + + private: + friend class SourcePattern; + + explicit ResultPattern(DrrPatternContext* ctx) : ctx_(ctx) {} + + DrrPatternContext* ctx_{nullptr}; +}; + +class SourcePattern { + public: + drr::ResultPattern ResultPattern() const { return drr::ResultPattern(ctx_); } + + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->SourceOpPattern(op_type, attributes); + } + + const drr::Tensor& Tensor(const std::string& name) { + return ctx_->SourceTensorPattern(name); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + + void 
RequireEqual(const TensorShape& first, const TensorShape& second) { + ctx_->RequireEqual(first, second); + } + void RequireEqual(const TensorDataType& first, const TensorDataType& second) { + ctx_->RequireEqual(first, second); + } + + void RequireNativeCall( + const std::function& custom_fn) { + ctx_->RequireNativeCall(custom_fn); + } + + private: + friend class DrrPatternContext; + explicit SourcePattern(DrrPatternContext* ctx) : ctx_(ctx) {} + DrrPatternContext* ctx_{nullptr}; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/api/match_context.cc new file mode 100644 index 00000000000000..35b28db13254ed --- /dev/null +++ b/paddle/fluid/pir/drr/api/match_context.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/drr/api/match_context.h" + +#include + +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" + +namespace pir { +namespace drr { + +MatchContext::MatchContext(std::shared_ptr impl) + : impl_(impl) {} + +const TensorInterface& MatchContext::Tensor( + const std::string& tensor_name) const { + return impl_->Tensor(tensor_name); +} + +template +T MatchContext::Attr(const std::string& attr_name) const { + return impl_->Attr(attr_name); +} + +template bool MatchContext::Attr(const std::string&) const; +template int32_t MatchContext::Attr(const std::string&) const; +template int64_t MatchContext::Attr(const std::string&) const; +template float MatchContext::Attr(const std::string&) const; +template std::string MatchContext::Attr(const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/api/match_context.h new file mode 100644 index 00000000000000..a1699ccb5bddf6 --- /dev/null +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -0,0 +1,43 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_operation.h" + +namespace pir { +namespace drr { + +class TensorInterface; +class MatchContextImpl; + +class MatchContext final { + public: + MatchContext(std::shared_ptr impl); + + const TensorInterface& Tensor(const std::string& tensor_name) const; + + template + T Attr(const std::string& attr_name) const; + + private: + std::shared_ptr impl_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc new file mode 100644 index 00000000000000..1b81b3a5672117 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_value.h" + +namespace pir { +namespace drr { + +bool ShapeInterface::operator==(const ShapeInterface& other) const { + return *shape_ == *other.shape_; +} + +int ShapeInterface::size() const { return shape_->size(); } + +int64_t ShapeInterface::at(int idx) const { return shape_->at(idx); } + +bool DtypeInterface::operator==(const DtypeInterface& other) const { + return *dtype_ == *other.dtype_; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h new file mode 100644 index 00000000000000..7629857591bf33 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace pir { +namespace drr { + +class IrValue; +class IrShape; +class IrDtype; + +class ShapeInterface final { + public: + bool operator==(const ShapeInterface& other) const; + + int size() const; + + int64_t at(int idx) const; + + private: + explicit ShapeInterface(const IrShape* shape) : shape_(shape) {} + + friend class IrValue; + + const IrShape* shape_; +}; + +class DtypeInterface final { + public: + bool operator==(const DtypeInterface& other) const; + + private: + explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {} + + friend class IrValue; + + const IrDtype* dtype_; +}; + +class TensorInterface { + public: + virtual ShapeInterface Shape() const = 0; + virtual DtypeInterface Dtype() const = 0; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h new file mode 100644 index 00000000000000..fb989fe063b771 --- /dev/null +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -0,0 +1,116 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +template +struct CppTypeToIrAttribute; + +#define PD_SPECIALIZE_CppTypeToIrAttribute(cpp_type, ir_attr_type) \ + template <> \ + struct CppTypeToIrAttribute< \ + std::remove_const_t>> { \ + using type = ir_attr_type; \ + }; + +PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, + paddle::dialect::DataTypeAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, + paddle::dialect::IntArrayAttribute); + +template +struct IrAttrbuteCreator { + typename CppTypeToIrAttribute::type operator()(T obj) const { + return CppTypeToIrAttribute::type::template get( + pir::IrContext::Instance(), obj); + } +}; + +template <> +struct IrAttrbuteCreator> { + pir::ArrayAttribute operator()(std::vector obj) const { + std::vector attr_vec; + attr_vec.reserve(obj.size()); + for (int32_t x : obj) { + attr_vec.push_back(Int32Attribute::get(pir::IrContext::Instance(), x)); + } + return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); + } +}; + +template +struct IrAttrTypeCast { + static T To(const pir::Attribute& attr) { + return attr.dyn_cast::type>().data(); + } +}; + +template <> +struct IrAttrTypeCast { + static std::string To(const pir::Attribute& attr) { + return attr.dyn_cast::type>() + .AsString(); + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + auto array_attr 
= attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back(array_attr.at(i).dyn_cast().data()); + } + return result; + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + if (attr.dyn_cast()) { + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back( + array_attr.at(i).dyn_cast().data()); + } + } else if (attr.dyn_cast()) { + result = + attr.dyn_cast().data().GetData(); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "Dynamic cast failed for IR attribute vector")); + } + return result; + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h new file mode 100644 index 00000000000000..c17feb0eaad052 --- /dev/null +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -0,0 +1,568 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_operation_factory.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/type_name.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +template +class DrrRewritePattern : public pir::RewritePattern { + public: + explicit DrrRewritePattern(const DrrPatternContext& drr_context, + pir::IrContext* context, + pir::PatternBenefit benefit = 1) + : pir::RewritePattern( + drr_context.source_pattern_graph()->AnchorNode()->name(), + benefit, + context, + {}), + source_pattern_graph_(drr_context.source_pattern_graph()), + constraints_(drr_context.constraints()), + result_pattern_graph_(drr_context.result_pattern_graph()) { + IR_ENFORCE(!source_pattern_graph_->owned_op_call().empty(), + "source_pattern_graph is empty, please check the drr pattern " + "define code."); + } + + bool MatchAndRewrite(pir::Operation* op, + PatternRewriter& rewriter) const override { // NOLINT + std::shared_ptr src_match_ctx = + std::make_shared(); + if (PatternGraphMatch(op, src_match_ctx.get())) { + VLOG(4) << "DRR pattern (" << pir::get_type_name() + << ") is matched in program."; + PatternGraphRewrite(*src_match_ctx, rewriter); + return true; + } + return false; + } + + private: + bool PatternGraphMatch(pir::Operation* op, + MatchContextImpl* source_pattern_match_ctx) const { + VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; + const OpCall* anchor = source_pattern_graph_->AnchorNode(); + std::unordered_map> + bind_map = + FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); + if (bind_map.empty()) { + return 
false; + } + std::vector drr_output_sequence; + std::vector ir_output_sequence; + std::unordered_map output_op_map; + for (auto pair : bind_map) { + drr_output_sequence.push_back(pair.first); + } + // using dfs to obtain the arrangement of all candidate ir ops + auto permute = [&](auto&& permute, size_t index) -> bool { + if (index == drr_output_sequence.size()) { + // avoiding duplicate binding of ir op + std::unordered_set ir_output_set; + for (Operation* op : ir_output_sequence) { + auto pr = ir_output_set.insert(op); + if (pr.second == false) { + return false; + } + } + // new match_ctx + std::shared_ptr match_ctx = + std::make_shared(); + std::transform(drr_output_sequence.begin(), + drr_output_sequence.end(), + ir_output_sequence.begin(), + std::inserter(output_op_map, output_op_map.end()), + [](const OpCall* drr_op, Operation* ir_op) { + return std::make_pair(drr_op, ir_op); + }); + if (MatchFromOutputToInput( + output_op_map, *(source_pattern_graph_.get()), match_ctx)) { + *source_pattern_match_ctx = *match_ctx; + return true; + } + return false; + } + for (auto* ir_op : bind_map[drr_output_sequence[index]]) { + ir_output_sequence.push_back(ir_op); + if (permute(permute, index + 1)) { + return true; + } + ir_output_sequence.pop_back(); + } + return false; + }; + + return permute(permute, 0); + } + + std::unordered_map> + FindCandidateIrOutputOp( + pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const { + // get source pattern output op + std::unordered_set drr_output_op_set = + source_pattern_graph.OutputNodes(); + std::unordered_map> + output_op_bind_map{{anchor, {op}}}; + if (drr_output_op_set.size() == 1) { + return output_op_bind_map; + } + std::unordered_set drr_visited_ops{anchor}; + DfsVisitor( + anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map); + if (output_op_bind_map.size() != drr_output_op_set.size()) { + return {}; + } + return output_op_bind_map; + } + + void DfsVisitor( + const 
OpCall* drr_op, + pir::Operation* ir_op, + const std::unordered_set& drr_output_op_set, + std::unordered_set* drr_visited_ops, + std::unordered_map>* + output_op_bind_map) const { + VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")" + << "ir op(" << ir_op->name() << ")"; + if (drr_op->name() != ir_op->name()) { + return; + } + // check op input's size + const auto& drr_op_input_tensors = drr_op->inputs(); + auto ir_op_input_value_size = ir_op->num_operands(); + if (drr_op_input_tensors.size() != ir_op_input_value_size) { + return; + } + // check op output's size + const auto& drr_op_output_tensors = drr_op->outputs(); + auto ir_op_output_value_size = ir_op->num_results(); + if (drr_op_output_tensors.size() != ir_op_output_value_size) { + return; + } + // check producer op + for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) { + // case 1: drr_op_input_tensor is the input tensor of source pattern + if (drr_op_input_tensors[i]->producer() == nullptr) { + // dfs source pattern input tensor other child op + auto ir_input_tensor = ir_op->operand(i).source(); + for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) { + if (drr_visited_ops->count(drr_bro_op)) { + continue; + } + for (auto it = ir_input_tensor.use_begin(); + it != ir_input_tensor.use_end(); + ++it) { + auto* ir_bro_op = it.owner(); + if (drr_bro_op->name() == ir_bro_op->name()) { + drr_visited_ops->insert(drr_bro_op); + DfsVisitor(drr_bro_op, + ir_bro_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_bro_op); + } + } + } + continue; + } + // case 2: have producer op + const auto& drr_producer_op = drr_op_input_tensors[i]->producer(); + if (drr_visited_ops->count(drr_producer_op)) { + continue; + } + auto ir_operand_value = ir_op->operand(i).source(); + if (drr_op_input_tensors[i]->consumers().size() != + ir_operand_value.use_count()) { + return; + } + auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); + 
drr_visited_ops->insert(drr_producer_op); + DfsVisitor(drr_producer_op, + ir_producer_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_producer_op); + } + if (drr_output_op_set.count(drr_op)) { + (*output_op_bind_map)[drr_op].insert(ir_op); + return; + } + // check child ops + for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) { + const auto& drr_child_ops = drr_op_output_tensors[i]->consumers(); + auto ir_output_value = ir_op->result(i); + if (drr_child_ops.size() != ir_output_value.use_count()) { + return; + } + for (auto* drr_child_op : drr_child_ops) { + for (auto it = ir_output_value.use_begin(); + it != ir_output_value.use_end(); + ++it) { + auto* ir_child_op = it.owner(); + if (drr_child_op->name() == ir_child_op->name()) { + if (drr_visited_ops->count(drr_child_op)) { + continue; + } + drr_visited_ops->insert(drr_child_op); + DfsVisitor(drr_child_op, + ir_child_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_child_op); + } + } + } + } // check child ops + return; + } + + bool MatchFromOutputToInput( + std::unordered_map output_op_map, + const SourcePatternGraph& source_pattern_graph, + const std::shared_ptr& source_pattern_match_ctx) const { + VLOG(6) << "MatchFromOutputToInput Start"; + std::unordered_set drr_visited; + std::unordered_set ir_visited; + std::queue drr_q; + std::queue ir_q; + bool matched = true; + size_t step = 0; + for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { + VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" + << it->second << ") in source_pattern_graph "; + drr_q.push(it->first); + drr_visited.insert(it->first); + ir_q.push(it->second); + ir_visited.insert(it->second); + } + while (!drr_q.empty()) { + if (!matched) break; + auto* drr_node = drr_q.front(); + auto* ir_node = ir_q.front(); + drr_q.pop(); + ir_q.pop(); + if (drr_node->name() != ir_node->name()) { + matched = false; + break; 
+ } + const auto& drr_input_tensors = drr_node->inputs(); + auto ir_input_value_size = ir_node->num_operands(); + if (drr_input_tensors.size() != ir_input_value_size) { + matched = false; + break; + } + if (drr_node->outputs().size() != ir_node->num_results()) { + matched = false; + break; + } + source_pattern_match_ctx->BindIrOperation( + drr_node, std::make_shared(ir_node)); + // binding input_tensor of current_op + for (size_t i = 0; i < drr_input_tensors.size(); ++i) { + source_pattern_match_ctx->BindIrValue( + drr_input_tensors[i]->name(), + std::make_shared(ir_node->operand(i).source())); + auto* drr_producer_op = drr_input_tensors[i]->producer(); + if (drr_producer_op == nullptr) { + continue; + } + auto* ir_producer_op = + ir_node->operand(i).source().dyn_cast().owner(); + if (drr_input_tensors[i]->consumers().size() != + ir_node->operand(i).source().use_count()) { + matched = false; + break; + } + // bfs producer_op of current_op + if (!drr_visited.count(drr_producer_op)) { + drr_q.push(drr_producer_op); + ir_q.push(ir_producer_op); + drr_visited.insert(drr_producer_op); + ir_visited.insert(ir_producer_op); + } + } + // binding output tensor of current_op + auto drr_op_output_tensor = drr_node->outputs(); + for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { + source_pattern_match_ctx->BindIrValue( + drr_op_output_tensor[j]->name(), + std::make_shared(ir_node->result(j))); + } + ++step; + } + + if (matched) { + IR_ENFORCE(step == source_pattern_graph.CountOfOpCalls()); + } else { + return matched; + } + + MatchContext match_context{source_pattern_match_ctx}; + for (const auto& constraint : constraints_) { + matched = constraint(match_context); + if (!matched) break; + } + + return matched; + } + + void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + VLOG(6) << "Create Operations in result_pattern_graph"; + MatchContextImpl res_match_ctx = 
CreateOperations(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + VLOG(6) << "Process Assign Tensor"; + RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx); + VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values " + "in result_pattern_graph"; + ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); + VLOG(6) << "Delete Operations in source_pattern_graph"; + DeleteSourcePatternOp(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + } + + private: + MatchContextImpl CreateOperations( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + MatchContextImpl res_match_ctx; + // add input tensors info for res_match_ctx + for (const auto& in_tensor : result_pattern_graph.input_tensors()) { + IR_ENFORCE(result_pattern_graph.id2owend_tensor().count(in_tensor), + "Drr input tensor [%s] must exists in result pattern graph.", + in_tensor); + if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + res_match_ctx.BindIrValue( + in_tensor, + std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + } + } + + if (result_pattern_graph.CountOfOpCalls() == 1) { + CreateOperation(*result_pattern_graph.owned_op_call()[0], + src_match_ctx, + rewriter, + &res_match_ctx); + return res_match_ctx; + } + + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; + for (Operation* op : *rewriter.block()) { + op_2_temp_program_index[op] = temp_program.size(); + temp_program.push_back({op}); + } + + // topo order visit result_pattern_graph + GraphTopo graph_topo_visit(&result_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { + // set insert point + size_t max_input_op_index = 0; + Operation* max_index_op = nullptr; + for (const Tensor* input : 
op_call.inputs()) { + if (input->is_none()) { + continue; + } + Value ir_val = res_match_ctx.GetIrValue(input->name()).get(); + if (ir_val) { + Operation* ir_input_op = ir_val.dyn_cast().owner(); + if (max_input_op_index < op_2_temp_program_index[ir_input_op]) { + max_input_op_index = op_2_temp_program_index[ir_input_op]; + max_index_op = ir_input_op; + } else if (max_input_op_index == + op_2_temp_program_index[ir_input_op]) { + const auto& ops_vec = temp_program[max_input_op_index]; + for (auto it = ops_vec.rbegin(); it != ops_vec.rend(); it++) { + if (*it == max_index_op) { + break; + } else if (*it == ir_input_op) { + max_index_op = ir_input_op; + break; + } else { + // do nothing + } + } + } else { + // do nothing + } + } + } + if (max_input_op_index == 0UL) { + VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; + Operation* source_patter_first_op = + src_match_ctx + .Operation(source_pattern_graph.owned_op_call()[0].get()) + .get(); + max_input_op_index = op_2_temp_program_index[source_patter_first_op]; + rewriter.SetInsertionPoint(source_patter_first_op); + } else { + rewriter.SetInsertionPointAfter(max_index_op); + } + + Operation* new_op = + CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); + op_2_temp_program_index[new_op] = max_input_op_index + 1; + temp_program[max_input_op_index + 1].push_back(new_op); + }); + + return res_match_ctx; + } + + void RebindIrTensorForAssignTensor( + const ResultPatternGraph& result_pattern_graph, + MatchContextImpl* res_match_ctx) const { + const auto& tensor_assign_map = result_pattern_graph.tensor_assign_map(); + for (const auto& kv : tensor_assign_map) { + const auto& src_tensor_name = kv.first; + const auto& dst_tensor_name = kv.second; + res_match_ctx->BindIrValue( + src_tensor_name, + std::make_shared( + res_match_ctx->GetIrValue(dst_tensor_name))); + } + } + + void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx, + const MatchContextImpl& res_match_ctx, + 
pir::PatternRewriter& rewriter) const { // NOLINT + for (const auto& output_name : result_pattern_graph_->output_tensors()) { + if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); + const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); + rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + } else { + LOG(WARNING) << "The output tensor (" << output_name + << ") in the result_pattern_graph is not the tensor" + " in source_pattern_graph."; + } + } + } + + void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + std::vector topo_order_ops; + GraphTopo graph_topo_visit(&source_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder( + [&topo_order_ops](const OpCall& op_call) { + topo_order_ops.push_back(&op_call); + }); + + // Filter the operations which are replaced by result pattern + // 1. Filter operations by forward walk + std::unordered_set forward_visited_tensor_set( + result_pattern_graph.input_tensors()); + std::unordered_set forward_deleted_ops; + std::for_each(topo_order_ops.begin(), + topo_order_ops.end(), + [&forward_deleted_ops, + &forward_visited_tensor_set](const OpCall* op_call) { + if (op_call->inputs().empty()) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + } + for (const auto* input : op_call->inputs()) { + if (forward_visited_tensor_set.count(input->name())) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + break; + } + } + }); + // 2. 
Filter operations by backward walk and merge the forward result + std::unordered_set backward_visited_tensor_set( + result_pattern_graph.output_tensors()); + std::vector deleted_ops; + std::unordered_set deleted_ops_set; + std::for_each(topo_order_ops.rbegin(), + topo_order_ops.rend(), + [&deleted_ops, + &deleted_ops_set, + &backward_visited_tensor_set, + &forward_deleted_ops](const OpCall* op_call) { + bool all_comsumer_deleted = true; + bool from_backward_visited_tensor = false; + for (const auto* output : op_call->outputs()) { + if (backward_visited_tensor_set.count(output->name())) { + from_backward_visited_tensor = true; + } else if (output->consumers().empty()) { + continue; + } else { + all_comsumer_deleted = false; + } + } + if (all_comsumer_deleted && from_backward_visited_tensor && + forward_deleted_ops.count(op_call)) { + deleted_ops_set.insert(op_call); + deleted_ops.push_back(op_call); + for (const auto* input : op_call->inputs()) { + backward_visited_tensor_set.insert(input->name()); + } + } + }); + + // Delete Operation with topo order from output tensors. + for (const auto* op_call : deleted_ops) { + IR_ENFORCE(src_match_ctx.operation_map().count(op_call), + "Drr OpCall [%s] must exists in match context.", + op_call->name()); + auto* op = src_match_ctx.operation_map().at(op_call)->get(); + VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op + << ") in source_pattern_graph "; + rewriter.EraseOp(op); + } + } + + private: + const std::shared_ptr source_pattern_graph_; + const std::vector constraints_; + const std::shared_ptr result_pattern_graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/ir_operation.h new file mode 100644 index 00000000000000..2764bc92454170 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/operation.h" + +namespace pir { +namespace drr { + +class IrOperation { + public: + explicit IrOperation(pir::Operation* op) : op_(op) {} + + pir::Operation* get() const { return op_; } + + private: + pir::Operation* op_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc new file mode 100644 index 00000000000000..5355a8977e8c53 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +void OperationFactory::RegisterManualOpCreator() { + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + attrs); + }); + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue_grad", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + inputs[3].dyn_cast(), + attrs); + }); + RegisterOperationCreator("builtin.combine", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build(inputs); + }); +} + +static pir::Attribute CreateIrAttribute(const std::any& obj) { + if (obj.type() == typeid(bool)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int32_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int64_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(float)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::string)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(const char*)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(phi::DataType)) { + return IrAttrbuteCreator()( + std::any_cast(obj)); + } 
else if (obj.type() == typeid(phi::Place)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Type error. CreateIrAttribute for type(%s) " + "is unimplemented CreateInCurrently.", + obj.type().name())); + } +} + +pir::AttributeMap CreateAttributeMap(const OpCall& op_call, + const MatchContextImpl& src_match_ctx) { + pir::AttributeMap attr_map; + for (const auto& kv : op_call.attributes()) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name()); + } + if constexpr (std::is_same_v, + ComputeAttribute>) { + MatchContext ctx(std::make_shared(src_match_ctx)); + attr_map[kv.first] = + CreateIrAttribute(arg.attr_compute_func()(ctx)); + } + }, + kv.second); + } + return attr_map; +} + +Value GetIrValueByDrrTensor(const Tensor& tensor, + const MatchContextImpl& res_match_ctx) { + if (tensor.is_none()) { + return Value{}; + } + return res_match_ctx.GetIrValue(tensor.name()).get(); +} + +std::vector GetIrValuesByDrrTensors( + const std::vector& tensors, + const MatchContextImpl& res_match_ctx) { + std::vector ir_values; + ir_values.reserve(tensors.size()); + for (const auto* tensor : tensors) { + ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx)); + } + return ir_values; +} + +void BindIrOutputs(const OpCall& op_call, + pir::Operation* op, + MatchContextImpl* match_ctx) { + for (size_t i = 0; i < op_call.outputs().size(); ++i) { + std::shared_ptr ir_value = nullptr; + if (op->result(i)) { + ir_value = std::make_shared(op->result(i)); + } + match_ctx->BindIrValue(op_call.outputs()[i]->name(), ir_value); + } +} + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, 
+ pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx) { + VLOG(6) << "Drr create [" << op_call.name() << "] op..."; + const auto& inputs = op_call.inputs(); + std::vector ir_values = + GetIrValuesByDrrTensors(inputs, *res_match_ctx); + pir::Operation* op = OperationFactory::Instance().CreateOperation( + op_call.name(), + ir_values, + CreateAttributeMap(op_call, src_match_ctx), + rewriter); + BindIrOutputs(op_call, op, res_match_ctx); + VLOG(6) << "Drr create [" << op_call.name() << "] op done."; + return op; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h new file mode 100644 index 00000000000000..b38b5cd6a12b32 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +class OperationFactory { + public: + static OperationFactory& Instance() { + static OperationFactory operation_factory; + return operation_factory; + } + + using operation_create_fn = + std::function&, + const pir::AttributeMap&, + pir::PatternRewriter&)>; + + void RegisterOperationCreator(const std::string& op_name, + const operation_create_fn& create_fn) { + op_creator_map.emplace(op_name, create_fn); + } + + pir::Operation* CreateOperation( + const std::string& op_name, + const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) const { // NOLINT + auto iter = op_creator_map.find(op_name); + IR_ENFORCE(iter != op_creator_map.end(), + "The create function for op: (%s) is not found.", + op_name); + return iter->second(inputs, attrs, rewriter); + } + + private: + OperationFactory() { + RegisterGeneratedOpCreator(); + RegisterManualOpCreator(); + } + + void RegisterManualOpCreator(); + void RegisterGeneratedOpCreator(); + + std::unordered_map op_creator_map; +}; + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx); + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h new file mode 100644 index 00000000000000..907df9dfd24ebc --- /dev/null +++ b/paddle/fluid/pir/drr/ir_value.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +class IrShape { + public: + explicit IrShape(const phi::DDim& dims) : dims_(dims) {} + + bool operator==(const IrShape& other) const { return dims_ == other.dims_; } + + int size() const { return dims_.size(); } + + int64_t at(int idx) const { return dims_.at(idx); } + + private: + const phi::DDim dims_; +}; + +class IrDtype { + public: + explicit IrDtype(pir::Type dtype) : dtype_(dtype) {} + + bool operator==(IrDtype other) const { return dtype_ == other.dtype_; } + + private: + const pir::Type dtype_; +}; + +class IrValue : public TensorInterface { + public: + explicit IrValue(const pir::Value& value) + : value_(value), + shape_((value && value.type() && + value.type().dyn_cast()) + ? value.type() + .dyn_cast() + .dims() + : phi::DDim{}), + dtype_((value && value.type() && + value.type().dyn_cast()) + ? 
value.type() + .dyn_cast() + .dtype() + : pir::Type{}) {} + + ShapeInterface Shape() const override { return ShapeInterface(&shape_); } + DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); } + + const Value& get() const { return value_; } + + private: + const Value value_; + const IrShape shape_; + const IrDtype dtype_; +}; + +class IrAttr; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h new file mode 100644 index 00000000000000..a04efbbfaf444b --- /dev/null +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -0,0 +1,124 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_value.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +class MatchContextImpl final { + public: + MatchContextImpl() = default; + ~MatchContextImpl() = default; + + const TensorInterface& Tensor(const std::string& tensor_name) const { + IR_ENFORCE(tensor_map_.count(tensor_name), + "Drr tensor [%s] must exists in pattern graph.", + tensor_name); + return *tensor_map_.at(tensor_name); + } + + const IrOperation& Operation(const OpCall* op_call) const { + IR_ENFORCE(operation_map_.count(op_call), + "Drr operation [%s] must exists in pattern graph.", + op_call->name()); + return *operation_map_.at(op_call); + } + + template + T Attr(const std::string& attr_name) const { + return IrAttrTypeCast::To(GetIrAttr(attr_name)); + } + + const IrValue& GetIrValue(const std::string& tensor_name) const { + auto iter = tensor_map_.find(tensor_name); + PADDLE_ENFORCE_NE( + iter, + tensor_map_.end(), + phi::errors::OutOfRange( + "the drr tensor(%s) is not found in the map to ir value.", + tensor_name)); + return *iter->second; + } + + pir::Attribute GetIrAttr(const std::string& attr_name) const { + auto iter = attr_map_.find(attr_name); + PADDLE_ENFORCE_NE( + iter, + attr_map_.end(), + phi::errors::OutOfRange( + "the drr attr(%s) is not found in the map to ir attribute.", + attr_name)); + return iter->second; + } + + const std::unordered_map>& + operation_map() const { + return operation_map_; + } + + const std::unordered_map& attr_map() const { + return attr_map_; + } + + const std::unordered_map>& tensor_map() + const { + return tensor_map_; + } + + void BindIrValue(const std::string& value_name, 
+ const std::shared_ptr& value) { + tensor_map_.emplace(value_name, value); + } + + void BindIrOperation(const OpCall* op_call, + const std::shared_ptr& op) { + operation_map_.emplace(op_call, op); + const auto& attrs = op_call->attributes(); + for (const auto& kv : attrs) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + } + }, + kv.second); + } + } + + private: + void BindIrAttr(const std::string& attr_name, pir::Attribute attr) { + attr_map_.emplace(attr_name, attr); + } + + std::unordered_map> tensor_map_; + std::unordered_map> + operation_map_; + std::unordered_map attr_map_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc new file mode 100644 index 00000000000000..0b63f398a790bd --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -0,0 +1,223 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/drr/pattern_graph.h" + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/pir/core/enforce.h" + +namespace pir { +namespace drr { + +const drr::OpCall &PatternGraph::AddOpCall( + const std::shared_ptr &op_call) { + owned_op_call_.push_back(op_call); + for (const auto *input : op_call->inputs()) { + const auto &tensor_name = input->name(); + IR_ENFORCE(id2owned_tensor_.count(tensor_name), + "intput tensor [%s] not exist.", + tensor_name); + id2owned_tensor_.at(tensor_name)->AddConsumer(op_call.get()); + + if (input->producer() == nullptr) { + input_tensors_.insert(tensor_name); + } + if (output_tensors_.find(tensor_name) != output_tensors_.end()) { + output_tensors_.erase(tensor_name); + } + } + for (auto &output : op_call->outputs()) { + const auto &out_tensor_name = output->name(); + IR_ENFORCE(id2owned_tensor_.count(out_tensor_name)); + id2owned_tensor_[output->name()]->set_producer(op_call.get()); + } + return *owned_op_call_.back(); +} + +drr::Tensor &PatternGraph::AddTensor( + const std::shared_ptr &tensor) { + if (id2owned_tensor_.find(tensor->name()) == id2owned_tensor_.end()) { + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + } + return *id2owned_tensor_[tensor->name()]; +} + +drr::Tensor &PatternGraph::AddTmpTensor( + const std::shared_ptr &tensor) { + IR_ENFORCE(id2owned_tensor_.count(tensor->name()) == 0); + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + return *id2owned_tensor_[tensor->name()]; +} + +void PatternGraph::UpdateTmpTensor(const std::string &tmp_tensor_name, + const std::string &new_tensor_name) { + if (input_tensors_.count(tmp_tensor_name)) { + input_tensors_.erase(tmp_tensor_name); + input_tensors_.insert(new_tensor_name); + } + + output_tensors_.erase(new_tensor_name); + if (output_tensors_.count(tmp_tensor_name)) { + output_tensors_.erase(tmp_tensor_name); + 
output_tensors_.insert(new_tensor_name); + } + + auto tmp_tensor = id2owned_tensor_[tmp_tensor_name]; + id2owned_tensor_.erase(tmp_tensor_name); + tmp_tensor->set_name(new_tensor_name); + id2owned_tensor_[new_tensor_name] = tmp_tensor; +} + +size_t PatternGraph::CountOfOpCalls() const { return owned_op_call_.size(); } + +OpCall *SourcePatternGraph::AnchorNode() const { + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + return output_op_candidate; + } + IR_THROW("Unable to find a valid anchor"); +} + +std::unordered_set SourcePatternGraph::OutputNodes() const { + std::unordered_set output_op_set; + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + output_op_set.insert(output_op_candidate); + } + return output_op_set; +} + +void ResultPatternGraph::AssignTensor(const Tensor &from, const Tensor &to) { + if (to.producer() == nullptr) { + input_tensors_.insert(to.name()); + } + output_tensors_.erase(to.name()); + IR_ENFORCE(output_tensors_.count(from.name()) == 1, + "The Tensor (%s) which be assigned must be the output of result " + "pattern graph.", + from.name()); + tensor_assign_map_[from.name()] = to.name(); +} + +void GraphTopo::WalkGraphNodesTopoOrder( + const std::function &VisitNode) const { + // graph data + const std::unordered_set &inputs_tensor = + graph_->input_tensors(); + const std::unordered_map> + &id2owned_tensor = graph_->id2owend_tensor(); + const std::vector> &owend_opcall = + 
graph_->owned_op_call(); + + std::queue opcall_queue; + std::unordered_map> + opcall_dependent; + + // init opcall_dependent + for (const std::shared_ptr &opcall_sptr : owend_opcall) { + if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty + opcall_queue.push(opcall_sptr.get()); + } else { + for (const auto &pre_depd_tensor : opcall_sptr.get()->inputs()) { + opcall_dependent[opcall_sptr.get()].insert(pre_depd_tensor->name()); + } + } + } + + // init queue + for (const auto &tensor_name : inputs_tensor) { + IR_ENFORCE(id2owned_tensor.count(tensor_name), + "Drr input tensor [%s] must exists in pattern graph.", + tensor_name); + for (const auto &tensor_comsumer : + id2owned_tensor.at(tensor_name).get()->consumers()) { + opcall_dependent[tensor_comsumer].erase(tensor_name); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + + while (!opcall_queue.empty()) { + const OpCall *opcall = opcall_queue.front(); + opcall_queue.pop(); + VisitNode(*opcall); + + // update opcall_dependent + for (const auto &output_tensor : opcall->outputs()) { + for (const auto &tensor_comsumer : output_tensor->consumers()) { + opcall_dependent[tensor_comsumer].erase(output_tensor->name()); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + } +} + +std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { + os << "\nAll Tensors:\n"; + for (const auto &kv : pattern_graph.id2owend_tensor()) { + os << " " << kv.first; + } + os << "\n\n"; + + os << "Input Tensors:\n"; + for (const auto &tensor_name : pattern_graph.input_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + os << "Output Tensors:\n"; + for (const auto &tensor_name : pattern_graph.output_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + for (const auto &op_call : pattern_graph.owned_op_call()) { + os << " " << op_call->name() << " : "; + os << "inputs[ "; + for (const 
auto *input : op_call->inputs()) { + os << input->name() << " "; + } + os << "], "; + + os << "outputs[ "; + for (const auto &output : op_call->outputs()) { + os << output->name() << " "; + } + os << "]\n"; + } + os << "\n"; + return os; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.h b/paddle/fluid/pir/drr/pattern_graph.h new file mode 100644 index 00000000000000..63bd60eadf17f3 --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.h @@ -0,0 +1,108 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pir { +namespace drr { + +class Constraint; +class MatchContext; +class OpCall; +class Tensor; + +class PatternGraph { + public: + virtual ~PatternGraph() {} + + const drr::OpCall& AddOpCall(const std::shared_ptr& op_call); + + drr::Tensor& AddTensor(const std::shared_ptr& tensor); + + drr::Tensor& AddTmpTensor(const std::shared_ptr& tensor); + + void UpdateTmpTensor(const std::string& tmp_tensor_name, + const std::string& new_tensor_name); + + const std::unordered_set& input_tensors() const { + return input_tensors_; + } + + const std::unordered_set& output_tensors() const { + return output_tensors_; + } + + size_t CountOfOpCalls() const; + + const std::vector>& owned_op_call() const { + return owned_op_call_; + } + + const std::unordered_map>& + id2owend_tensor() const { + return id2owned_tensor_; + } + + protected: + std::unordered_map> id2owned_tensor_; + std::vector> owned_op_call_; + std::unordered_set input_tensors_; + std::unordered_set output_tensors_; +}; + +std::ostream& operator<<(std::ostream& os, const PatternGraph& pattern_graph); + +class SourcePatternGraph : public PatternGraph { + public: + OpCall* AnchorNode() const; + + std::unordered_set OutputNodes() const; + + private: + friend class DrrPatternContext; +}; + +class ResultPatternGraph : public PatternGraph { + public: + void AssignTensor(const Tensor& from, const Tensor& to); + + const std::unordered_map& tensor_assign_map() + const { + return tensor_assign_map_; + } + + private: + std::unordered_map tensor_assign_map_; +}; + +class GraphTopo { + public: + explicit GraphTopo(const PatternGraph* graph) : graph_(graph) {} + + void WalkGraphNodesTopoOrder( + const std::function& VisitNode) const; + + private: + const PatternGraph* graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/pir/pass/ir_printing.cc b/paddle/pir/pass/ir_printing.cc index 6171b71c090fcf..901c8bdd89da78 100644 --- 
a/paddle/pir/pass/ir_printing.cc +++ b/paddle/pir/pass/ir_printing.cc @@ -31,12 +31,8 @@ void PrintIR(Operation *op, bool print_module, std::ostream &os) { return; } - // Find the top-level operation. - auto *top_op = op; - while (auto *parent_op = top_op->GetParentOp()) { - top_op = parent_op; - } - top_op->Print(os); + auto *program = op->GetParentProgram(); + program->Print(os); } } // namespace diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index 00d6cb2f4d3064..ff75f86d6da55a 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,6 +131,7 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter { for (uint32_t i = 0; i < op->num_operands(); ++i) { AddOperandToWorklist(op->operand_source(i)); } + if (op->num_regions() == 0) { RemoveFromWorklist(op); } else { diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 7edd32531be34d..a4c4de5419928d 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -8,3 +8,44 @@ endif() cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ${PATTERN_REWRITE_TEST_DEPS}) + +cc_test_old( + drr_test + SRCS + drr_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_fuse_linear_test + SRCS + drr_fuse_linear_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_same_type_binding_test + SRCS + drr_same_type_binding_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_attention_fuse_test + SRCS + drr_attention_fuse_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) + +set_tests_properties( + pattern_rewrite_test PROPERTIES ENVIRONMENT + "FLAGS_enable_new_ir_in_executor=true") diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc new 
file mode 100644 index 00000000000000..22252e52beb394 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -0,0 +1,380 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +class MultiHeadMatmulFusePattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + pir::drr::SourcePattern src = ctx->SourcePattern(); + // The first path to matmul with scale (q). 
+ const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = + matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2")); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2")); + const auto &full_int_array_1 = + src.Op("pd_op.full_int_array", + {{"value", src.Attr("full_int_array_1_value")}}); + const auto &reshape_1 = src.Op("pd_op.reshape"); + reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()}, + {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")}); + const auto &transpose_1 = src.Op("pd_op.transpose"); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out")); + const auto &full_1 = + src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}}); + const auto &scale = src.Op("pd_op.scale"); + src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1()); + + // The second path to matmul (k). + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_transpose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2")); + const auto &full_int_array_2 = src.Op("pd_op.full_int_array"); + const auto &reshape_2 = src.Op("pd_op.reshape"); + reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()}, + {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")}); + const auto &transpose_2 = src.Op("pd_op.transpose"); + src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out")); + + // The third path to matmul (v). 
+ const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2")); + const auto &add_3 = src.Op("pd_op.add"); + src.Tensor("add_3_out") = + add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2")); + const auto &full_int_array_3 = src.Op("pd_op.full_int_array"); + const auto &reshape_3 = src.Op("pd_op.reshape"); + reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()}, + {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")}); + const auto &transpose_3 = src.Op("pd_op.transpose"); + src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out")); + + // softmax(qk)v + const auto &matmul_4 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_4_transpose_x")}, + {"transpose_y", src.Attr("matmul_4_transpose_y")}}); + src.Tensor("matmul_4_out") = + matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out")); + const auto &add_4 = src.Op("pd_op.add"); + src.Tensor("add_4_out") = + add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2")); + const auto &softmax = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); + src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out")); + const auto &matmul_5 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_5_transpose_x")}, + {"transpose_y", src.Attr("matmul_5_transpose_y")}}); + src.Tensor("matmul_5_out") = + matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out")); + const auto &transpose_4 = src.Op("pd_op.transpose"); + src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out")); + const auto &full_int_array_4 = src.Op("pd_op.full_int_array"); + const auto &reshape_4 = src.Op("pd_op.reshape"); + reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()}, + {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")}); + + // + // 
Constraints. + // + src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool matmul_1_transpose_x = match_ctx.Attr("matmul_1_transpose_x"); + bool matmul_1_transpose_y = match_ctx.Attr("matmul_1_transpose_y"); + if (matmul_1_transpose_x || matmul_1_transpose_y) return false; + + bool matmul_2_transpose_x = match_ctx.Attr("matmul_2_transpose_x"); + bool matmul_2_transpose_y = match_ctx.Attr("matmul_2_transpose_y"); + if (matmul_2_transpose_x || matmul_2_transpose_y) return false; + + bool matmul_3_transpose_x = match_ctx.Attr("matmul_3_transpose_x"); + bool matmul_3_transpose_y = match_ctx.Attr("matmul_3_transpose_y"); + if (matmul_3_transpose_x || matmul_3_transpose_y) return false; + + bool matmul_4_transpose_x = match_ctx.Attr("matmul_4_transpose_x"); + bool matmul_4_transpose_y = match_ctx.Attr("matmul_4_transpose_y"); + if (matmul_4_transpose_x || !matmul_4_transpose_y) return false; + + bool matmul_5_transpose_x = match_ctx.Attr("matmul_5_transpose_x"); + bool matmul_5_transpose_y = match_ctx.Attr("matmul_5_transpose_y"); + if (matmul_5_transpose_x || matmul_5_transpose_y) return false; + + return true; + }); + + // + // Result Pattern. + // + pir::drr::ResultPattern res = src.ResultPattern(); + // W combine. 
+ const auto &combine_1 = res.Op("builtin.combine"); + combine_1({&res.Tensor("matmul_1_in_2"), + &res.Tensor("matmul_2_in_2"), + &res.Tensor("matmul_3_in_2")}, + {&res.Tensor("combine_1_out")}); + const auto &concat_axis = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> int { return 0; }); + const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); + const auto &reshape_5_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector { + auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); + return {-1, 3, matmul_1_in_2.at(1)}; + }); + const auto &reshape_5 = + res.Op("pd_op.reshape", {{"shape", reshape_5_shape}}); + reshape_5({&res.Tensor("concat_1_out")}, + {&res.Tensor("reshape_5_out"), &res.NoneTensor()}); + + // Bias combine. + const auto &combine_2 = res.Op("builtin.combine"); + combine_2({&res.Tensor("add_1_in_2"), + &res.Tensor("add_2_in_2"), + &res.Tensor("add_3_in_2")}, + {&res.Tensor("combine_2_out")}); + const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); + const auto &reshape_6_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector { + return {3, -1}; + }); + const auto &reshape_6 = + res.Op("pd_op.reshape", {{"shape", reshape_6_shape}}); + reshape_6({&res.Tensor("concat_2_out")}, + {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); + + const auto &head_number = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + const auto &full_int_array_1_value = + match_ctx.Attr>("full_int_array_1_value"); + return full_int_array_1_value.at(2); + }); + const auto &alpha = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + return match_ctx.Attr("full_1_value"); + }); + const auto &multihead_matmul = res.Op( + "pd_op.multihead_matmul", + {{"transpose_q", res.Attr([](const pir::drr::MatchContext 
&match_ctx) { + return false; + })}, + {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); + multihead_matmul({&res.Tensor("matmul_1_in_1"), + &res.Tensor("reshape_5_out"), + &res.Tensor("reshape_6_out"), + &res.Tensor("add_4_in_2")}, + {&res.Tensor("reshape_4_out")}); + } +}; + +class AttentionFusePass : public pir::Pass { + public: + AttentionFusePass() : pir::Pass("AttentionFusePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(MultiHeadMatmulFusePattern().Build(context)); + // Add other attention variant fuse pattern. + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +namespace pir { +std::unique_ptr CreateAttentionFusePass() { + return std::make_unique(); +} +} // namespace pir + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp matmul_1_in_1 = + builder.Build(std::vector{1, 300, 256}, + 0.9, + phi::DataType::FLOAT32, + phi::CPUPlace()); + // The first path to matmul with scale (q). 
+ paddle::dialect::FullOp matmul_1_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_1 = builder.Build( + matmul_1_in_1.out(), matmul_1_in_2.out(), false, false); + + paddle::dialect::FullOp add_1_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_1 = + builder.Build(matmul_1.out(), add_1_in_2.out()); + + paddle::dialect::ReshapeOp reshape_1 = + builder.Build( + add_1.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_1 = + builder.Build(reshape_1.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ScaleOp scale_op = builder.Build( + transpose_1.out(), 0.1767766922712326, 0.0, true); + + // The second path to matmul (k). + paddle::dialect::FullOp matmul_2_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_2 = builder.Build( + matmul_1_in_1.out(), matmul_2_in_2.out(), false, false); + + paddle::dialect::FullOp add_2_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + paddle::dialect::AddOp add_op2 = + builder.Build(matmul_2.out(), add_2_in_2.out()); + + paddle::dialect::ReshapeOp reshape_2 = + builder.Build( + add_op2.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_2 = + builder.Build(reshape_2.out(), + std::vector{0, 2, 1, 3}); + + // The third path to matmul (v). 
+ paddle::dialect::FullOp matmul_3_in_2 = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::MatmulOp matmul_3 = builder.Build( + matmul_1_in_1.out(), matmul_3_in_2.out(), false, false); + + paddle::dialect::FullOp add_3_in_2 = builder.Build( + std::vector{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_3 = + builder.Build(matmul_3.out(), add_3_in_2.out()); + + paddle::dialect::ReshapeOp reshape_3 = + builder.Build( + add_3.out(), std::vector{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_3 = + builder.Build(reshape_3.out(), + std::vector{0, 2, 1, 3}); + + // softmax(qk)v + paddle::dialect::MatmulOp matmul_4 = builder.Build( + scale_op.out(), transpose_2.out(), false, true); + + paddle::dialect::FullOp add_4_in_2 = builder.Build( + std::vector{1, 8, 300, 300}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_4 = + builder.Build(matmul_4.out(), add_4_in_2.out()); + + paddle::dialect::SoftmaxOp softmax_op = + builder.Build(add_4.out(), -1); + paddle::dialect::MatmulOp matmul_5 = builder.Build( + softmax_op.out(), transpose_3.out(), false, false); + + paddle::dialect::TransposeOp transpose_4 = + builder.Build(matmul_5.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ReshapeOp reshape_4 = + builder.Build( + transpose_4.out(), std::vector{0, 0, 256}); + + builder.Build(reshape_4.out(), "out", 0); +} + +TEST(DrrTest, AttentionFuse) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + EXPECT_EQ(program.block()->size(), 33u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateAttentionFusePass()); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 20u); +} diff --git 
a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc new file mode 100644 index 00000000000000..ac28f535785958 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc @@ -0,0 +1,399 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +class FusedLinearPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op("pd_op.matmul", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op("pd_op.add"); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = 
res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + } +}; + +class FusedLinearGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op("pd_op.matmul", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &matmul_grad = pat.Op("pd_op.matmul_grad", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op("pd_op.add"); + const auto &add_grad = pat.Op("pd_op.add_grad"); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + add_grad({&pat.Tensor("tmp"), &pat.Tensor("bias"), &pat.Tensor("out_grad")}, + {&pat.Tensor("tmp_grad"), &pat.Tensor("bias_grad")}); + matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("tmp_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + const auto &fused_gemm_epilogue_grad = + res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation_grad", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + fused_gemm_epilogue_grad({&res.Tensor("x"), + &res.Tensor("w"), + &res.NoneTensor(), + &res.Tensor("out_grad")}, + 
{&res.Tensor("x_grad"), + &res.Tensor("w_grad"), + &res.Tensor("bias_grad")}); + } +}; + +class FusedLinearGeluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + // TODO(gst): don't have reserve_space + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = pat.Op("pd_op.gelu")(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("gelu_dx") = pat.Op("pd_op.gelu_grad")(pat.Tensor("fuse_out"), + pat.Tensor("x1_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act2") == "none"; + }); + + // Result patterns:要替换为的子图 + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu_grad"; + }); + const auto &fused_gemm_epilogue_grad_new = + 
res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space2")}); + fused_gemm_epilogue_grad_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space2"), + &res.Tensor("out_grad")}, + {&res.Tensor("gelu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedLinearReluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x3")}, + {"trans_y", pat.Attr("trans_y3")}, + {"activation_grad", pat.Attr("act3")}}}); + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = pat.Op("pd_op.relu")(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space2"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("relu_dx") = + pat.Op("pd_op.relu_grad")(pat.Tensor("x1"), pat.Tensor("x1_grad")); + fused_gemm_epilogue_grad({&pat.Tensor("x"), + &pat.Tensor("w"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("relu_dx")}, + 
{&pat.Tensor("x_grad"), + &pat.Tensor("w_grad"), + &pat.Tensor("bias_grad")}); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act3") == "none"; + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu_grad"; + }); + const auto &fused_gemm_epilogue_grad1_new = + res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space3")}); + fused_gemm_epilogue_grad1_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space3"), + &res.Tensor("out_grad")}, + {&res.Tensor("relu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedLinearPass : public pir::Pass { + public: + FusedLinearPass() : pir::Pass("FusedLinearPass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedLinearGradPattern().Build(context)); + ps.Add(FusedLinearPattern().Build(context)); + ps.Add(FusedLinearGeluGradPattern().Build(context)); + ps.Add(FusedLinearReluGradPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } 
+ + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{1, 512, 64}, + 1.5); + // linear 1 + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + // linear 2 + paddle::dialect::FullOp full_weight_op2 = + builder.Build(std::vector{64, 128}, + 1.5); + paddle::dialect::FullOp full_bias_op2 = + builder.Build(std::vector{128}, 1.0); + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(add_op1.out(), + full_weight_op2.out()); + paddle::dialect::AddOp add_op2 = builder.Build( + matmul_op2.out(), full_bias_op2.out()); + paddle::dialect::ReluOp relu_op = + builder.Build(add_op2.out()); + // linear 3 + paddle::dialect::FullOp full_weight_op3 = + builder.Build(std::vector{128, 64}, + 1.5); + paddle::dialect::FullOp full_bias_op3 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(relu_op.out(), + full_weight_op3.out()); + paddle::dialect::AddOp add_op3 = builder.Build( + matmul_op3.out(), full_bias_op3.out()); + paddle::dialect::GeluOp gelu_op1 = + builder.Build(add_op3.out()); + // linear 4 + paddle::dialect::FullOp full_weight_op4 = + builder.Build(std::vector{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op4 = + builder.Build(std::vector{64}, 1.0); + paddle::dialect::MatmulOp matmul_op4 = + builder.Build(gelu_op1.out(), + full_weight_op4.out()); + paddle::dialect::AddOp add_op4 = builder.Build( + matmul_op4.out(), full_bias_op4.out()); + paddle::dialect::GeluOp 
gelu_op2 = + builder.Build(add_op4.out()); + + // backward + paddle::dialect::FullOp full_grad_op = builder.Build( + std::vector{1, 512, 64}, 1.0); + + paddle::dialect::GeluGradOp gelu_op2_grad = + builder.Build( + add_op4.out(), full_grad_op.out(), false); + // backward linear 4 + paddle::dialect::AddGradOp add_op4_grad = + builder.Build( + matmul_op4.out(), full_bias_op4.out(), gelu_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op4_grad = + builder.Build( + gelu_op1.out(), full_weight_op4.out(), add_op4_grad.x_grad()); + + paddle::dialect::GeluGradOp gelu_op1_grad = + builder.Build( + add_op3.out(), matmul_op4_grad.x_grad(), false); + // backward linear 3 + paddle::dialect::AddGradOp add_op3_grad = + builder.Build( + matmul_op3.out(), full_bias_op3.out(), gelu_op1_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op3_grad = + builder.Build( + relu_op.out(), full_weight_op3.out(), add_op3_grad.x_grad()); + + paddle::dialect::ReluGradOp relu_op_grad = + builder.Build(relu_op.out(), + matmul_op3_grad.x_grad()); + // backward linear 2 + paddle::dialect::AddGradOp add_op2_grad = + builder.Build( + matmul_op2.out(), full_bias_op2.out(), relu_op_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op2_grad = + builder.Build( + add_op1.out(), full_weight_op2.out(), add_op2_grad.x_grad()); + // backward linear 1 + paddle::dialect::AddGradOp add_op1_grad = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), matmul_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op1_grad = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), add_op1_grad.x_grad()); + + builder.Build(gelu_op2.out(), "out", 0); + builder.Build(matmul_op1_grad.x_grad(), "dx", 1); +} + +TEST(DrrTest, FusedLinear) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + 
EXPECT_EQ(program.block()->size(), 34u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + // pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 22u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc new file mode 100644 index 00000000000000..cb4c6e4b0b92f6 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" + +/* Source pattern: + input1 + / | \ \ \ + / | \ \ \ + full / | | \ \ full_tmp + / | transpos1 | trans2 trans3 \ / | + / | / | | | | \ / | + softmax1 | / | | | | \ / | + \ | / softmax2 | | | add1 | + \ | / \ | \ / | | + layernorm matmul2 matmul1 \ | + / | \ | | \ | + / | \ \ / \ | + / | \ matmul3 add2 + | | | / | \ | + | | | / | \ | + | | | / | \ | + | | | trans4 trans5 trans6 | + | | | | | | | + | | | relu1 softmax3 softmax4 relu2 + | | | | | | | + output0 output1 output2 output3 output4 output5 output6 +*/ + +class SameTypeBindingTestPattern + // This class is for test cases of the same type of OP. 
+ // (without considering the computational logic between OPs, + // only focusing on the process of matching and replacing) + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern src = ctx->SourcePattern(); + + // path 1 + const auto &transpose_1 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_1")}}); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("input_1")); + const auto &softmax_2 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_2_axis")}}); + src.Tensor("softmax_2_out") = softmax_2(src.Tensor("transpose_1_out")); + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_tradnspose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("softmax_2_out"), src.Tensor("input_1")); + + // path 2 + const auto &full_1 = src.Op("pd_op.full", + {{"shape", src.Attr("shape_1")}, + {"value", src.Attr("value_1")}, + {"dtype", src.Attr("dtype_1")}, + {"place", src.Attr("place_1")}}); + src.Tensor("full_1_out") = full_1(); + const auto &softmax_1 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_1_axis")}}); + src.Tensor("softmax_1_out") = softmax_1(src.Tensor("full_1_out")); + const auto &layernorm_1 = + src.Op("pd_op.layer_norm", + {{"epsilon", src.Attr("layernorm_epsilon")}, + {"begin_norm_axis", src.Attr("layernorm_begin_norm_axis")}}); + layernorm_1({&src.Tensor("transpose_1_out"), + &src.Tensor("full_1_out"), + &src.Tensor("softmax_1_out")}, + {&src.Tensor("output0"), + &src.Tensor("output1"), + &src.Tensor("output2")}); + + // path 3 + const auto &transpose_2 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_2")}}); + const auto &transpose_3 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_3")}}); + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + 
src.Tensor("matmul_1_out") = matmul_1(transpose_2(src.Tensor("input_1")), + transpose_3(src.Tensor("input_1"))); + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_2_out"), src.Tensor("matmul_1_out")); + const auto &transpose_4 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + const auto &transpose_5 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_5")}}); + const auto &transpose_6 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_6")}}); + const auto &relu_1 = src.Op("pd_op.relu"); + const auto &softmax_3 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_3_axis")}}); + const auto &softmax_4 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_4_axis")}}); + src.Tensor("output3") = relu_1(transpose_4(src.Tensor("matmul_3_out"))); + src.Tensor("output4") = softmax_3(transpose_5(src.Tensor("matmul_3_out"))); + src.Tensor("output5") = softmax_4(transpose_6(src.Tensor("matmul_3_out"))); + + // path 4 + const auto &full_tmp = src.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + src.Tensor("full_tmp_out") = full_tmp(); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("input_1"), src.Tensor("full_tmp_out")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("add_1_out"), src.Tensor("full_tmp_out")); + const auto &relu_2 = src.Op("pd_op.relu"); + src.Tensor("output6") = relu_2(src.Tensor("add_2_out")); + + pir::drr::ResultPattern res = src.ResultPattern(); + const auto &transpose_7 = + res.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + res.Tensor("output0") = transpose_7(res.Tensor("input_1")); + const auto &transpose_8 = + res.Op("pd_op.transpose", {{"perm", 
src.Attr("perm_5")}}); + res.Tensor("output1") = transpose_8(res.Tensor("input_1")); + const auto &full_2 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_3 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_4 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_5 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_6 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + res.Tensor("output2") = full_2(); + res.Tensor("output3") = full_3(); + res.Tensor("output4") = full_4(); + res.Tensor("output5") = full_5(); + res.Tensor("output6") = full_6(); + } +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + // path 1 + paddle::dialect::TransposeOp transpose_op1 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op2 = + builder.Build(transpose_op1.out(), -1); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(softmax_op2.out(), + full_input_op1.out()); + + // path 2 + paddle::dialect::FullOp full_op_scale = + builder.Build(std::vector{48}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::SoftmaxOp softmax_op_bias = + builder.Build(full_op_scale.out(), -1); + paddle::dialect::LayerNormOp 
layernorm_op1 = + builder.Build( + transpose_op1.out(), full_op_scale.out(), softmax_op_bias.out()); + + // path 3 + paddle::dialect::TransposeOp transpose_op2 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::TransposeOp transpose_op3 = + builder.Build(full_input_op1.out(), + std::vector{0, 1, 2}); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(transpose_op2.out(), + transpose_op3.out()); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(matmul_op2.out(), + matmul_op1.out()); + + paddle::dialect::TransposeOp transpose_op4 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::ReluOp relu_op1 = + builder.Build(transpose_op4.out()); + + paddle::dialect::TransposeOp transpose_op5 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op3 = + builder.Build(transpose_op5.out(), -1); + + paddle::dialect::TransposeOp transpose_op6 = + builder.Build(matmul_op3.out(), + std::vector{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op4 = + builder.Build(transpose_op6.out(), -1); + + // path 4 + paddle::dialect::FullOp full_input_op2 = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_op1 = builder.Build( + full_input_op1.out(), full_input_op2.out()); + + paddle::dialect::AddOp add_op2 = builder.Build( + add_op1.out(), full_input_op2.out()); + + paddle::dialect::ReluOp relu_op2 = + builder.Build(add_op2.out()); + + // tail + paddle::dialect::MatmulOp matmul_op4 = + builder.Build(layernorm_op1.variance(), + layernorm_op1.mean()); + + paddle::dialect::MatmulOp matmul_op5 = + builder.Build(relu_op1.out(), + softmax_op3.out()); + + paddle::dialect::MatmulOp matmul_op6 = + builder.Build(softmax_op4.out(), + relu_op2.out()); + + builder.Build(matmul_op4.out(), "out1", 0); + builder.Build(matmul_op5.out(), "out2", 1); + builder.Build(matmul_op6.out(), "out3", 2); +} + +class 
DrrPatternRewritePass : public pir::Pass { + public: + DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(SameTypeBindingTestPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +TEST(DrrTest, drr_demo) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 27u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 13u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc new file mode 100644 index 00000000000000..f607fa5a083260 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -0,0 +1,232 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" + +class RemoveRedundentReshapePattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source patterns + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &reshape1 = pat.Op("pd_op.reshape"); + const auto &reshape2 = pat.Op("pd_op.reshape"); + + reshape1({&pat.Tensor("arg0"), &pat.Tensor("shape0")}, + {&pat.Tensor("out1"), &pat.Tensor("xshape_0")}); + reshape2({&pat.Tensor("out1"), &pat.Tensor("shape1")}, + {&pat.Tensor("ret"), &pat.Tensor("xshape_1")}); + + // Result patterns + pir::drr::ResultPattern res = pat.ResultPattern(); + res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, + {&res.Tensor("ret"), &res.Tensor("xshape_1")}); + } +}; + +class FoldExpandToConstantPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &full1 = pat.Op("pd_op.full", + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", 
pat.Attr("dtype_1")},
+                {"place", pat.Attr("place_1")}});
+    const auto &full_int_array1 =
+        pat.Op("pd_op.full_int_array",
+               {{"value", pat.Attr("expand_shape_value")},
+                {"dtype", pat.Attr("dtype_2")},
+                {"place", pat.Attr("place_2")}});
+    const auto &expand = pat.Op("pd_op.expand");
+    pat.Tensor("ret") = expand(full1(), full_int_array1());
+
+    // Result patterns
+    pir::drr::ResultPattern res = pat.ResultPattern();
+    const auto &full2 = res.Op("pd_op.full",
+                               {{"shape", pat.Attr("expand_shape_value")},
+                                {"value", pat.Attr("value_1")},
+                                {"dtype", pat.Attr("dtype_1")},
+                                {"place", pat.Attr("place_1")}});
+    res.Tensor("ret") = full2();
+  }
+};
+
+// Fuses two back-to-back pd_op.transpose ops into a single pd_op.transpose.
+//
+// Source pattern: ret = transpose(transpose(arg_transpose, perm_1), perm_2).
+// Result pattern: ret = transpose(arg_transpose, new_perm), where
+//                 new_perm[i] = perm_1[perm_2[i]].
+class RemoveRedundentTransposePattern
+    : public pir::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    // Source pattern
+    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &transpose1 =
+        pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}});
+    const auto &transpose2 =
+        pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_2")}});
+
+    pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose")));
+
+    // Result pattern
+    pir::drr::ResultPattern res = pat.ResultPattern();
+    // The fused permutation is computed at match time from the two matched
+    // "perm" attributes: every entry of perm_2 is used as an index into perm_1.
+    const auto &new_perm_attr = res.Attr(
+        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int> {
+          const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1");
+          const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2");
+          std::vector<int> new_perm;
+          new_perm.reserve(perm2.size());
+          for (int v : perm2) {
+            new_perm.emplace_back(perm1[v]);
+          }
+          return new_perm;
+        });
+    const auto &transpose_continuous =
+        res.Op("pd_op.transpose", {{"perm", new_perm_attr}});
+
+    res.Tensor("ret") = transpose_continuous(res.Tensor("arg_transpose"));
+  }
+};
+
+// Collapses two chained pd_op.cast ops into one.
+//
+// Source pattern: tmp = cast(arg0, dtype1); ret = cast(tmp, dtype2).
+// Result pattern: ret = cast(arg0, dtype2).
+class RemoveRedundentCastPattern
+    : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    // Source pattern
+    auto pat = ctx->SourcePattern();
+    pat.Tensor("tmp") = pat.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0"));
+    pat.Tensor("ret") = pat.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));
+
+    // Result pattern: only the second cast's dtype survives.
+    auto res = pat.ResultPattern();
+    res.Tensor("ret") = res.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
+  }
+};
+
+// Removes a pd_op.cast whose output dtype is required to equal its input
+// dtype: the result pattern assigns "ret" directly to "arg0".
+class RemoveUselessCastPattern
+    : public pir::drr::DrrPatternBase<RemoveUselessCastPattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    // Source pattern, constrained to casts that do not change the dtype.
+    auto pat = ctx->SourcePattern();
+    pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0"));
+    pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype());
+
+    // Result pattern
+    auto res = pat.ResultPattern();
+    res.Tensor("ret").Assign(res.Tensor("arg0"));
+  }
+};
+
+// Builds the program the rewrite pass is exercised on:
+//   full -> expand -> reshape -> reshape -> relu -> cast(FLOAT64)
+//        -> cast(FLOAT32) -> transpose -> transpose -> relu -> "out"
+//
+// NOTE(review): the element types of the std::vector literals and the op
+// type of the final Build call are not spelled out in the text this was
+// recovered from; they are restored as int64_t shapes, int perms and
+// paddle::dialect::FetchOp -- confirm against the op Build signatures.
+void BuildProgram(pir::Builder &builder) {  // NOLINT
+  paddle::dialect::FullOp full_input_op =
+      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16},
+                                             1.5,
+                                             phi::DataType::FLOAT32,
+                                             phi::CPUPlace());
+
+  paddle::dialect::FullIntArrayOp full_int_array_op =
+      builder.Build<paddle::dialect::FullIntArrayOp>(
+          std::vector<int64_t>{4, 3, 16, 16},
+          phi::DataType::FLOAT32,
+          phi::CPUPlace());
+
+  paddle::dialect::ExpandOp expand_op =
+      builder.Build<paddle::dialect::ExpandOp>(full_input_op.out(),
+                                               full_int_array_op.out());
+
+  paddle::dialect::ReshapeOp reshape_op1 =
+      builder.Build<paddle::dialect::ReshapeOp>(
+          expand_op.out(), std::vector<int64_t>{16, 3, 4, 16});
+
+  paddle::dialect::ReshapeOp reshape_op2 =
+      builder.Build<paddle::dialect::ReshapeOp>(
+          reshape_op1.out(), std::vector<int64_t>{16, 3, 4, 16});
+
+  paddle::dialect::ReluOp relu_op =
+      builder.Build<paddle::dialect::ReluOp>(reshape_op2.out());
+
+  paddle::dialect::CastOp cast_op1 = builder.Build<paddle::dialect::CastOp>(
+      relu_op.out(), phi::DataType::FLOAT64);
+
+  paddle::dialect::CastOp cast_op2 = builder.Build<paddle::dialect::CastOp>(
+      cast_op1.out(), phi::DataType::FLOAT32);
+
+  paddle::dialect::TransposeOp transpose_op1 =
+      builder.Build<paddle::dialect::TransposeOp>(cast_op2.out(),
+                                                  std::vector<int>{0, 2, 1, 3});
+
+  paddle::dialect::TransposeOp transpose_op2 =
+      builder.Build<paddle::dialect::TransposeOp>(transpose_op1.out(),
+                                                  std::vector<int>{1, 0, 2, 3});
+
+  paddle::dialect::ReluOp relu_op_second =
+      builder.Build<paddle::dialect::ReluOp>(transpose_op2.out());
+
+  builder.Build<paddle::dialect::FetchOp>(relu_op_second.out(), "out", 0);
+}
+
+// Pass that registers the DRR patterns of this file and applies them greedily
+// to region 0 of a "builtin.module" operation.
+class DrrPatternRewritePass : public pir::Pass {
+ public:
+  DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {}
+
+  // Builds every pattern once and freezes the resulting set into patterns_.
+  // Always returns true.
+  bool Initialize(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add(RemoveRedundentReshapePattern().Build(context));
+    ps.Add(RemoveRedundentTransposePattern().Build(context));
+    ps.Add(RemoveRedundentCastPattern().Build(context));
+    ps.Add(RemoveUselessCastPattern().Build(context));
+    ps.Add(FoldExpandToConstantPattern().Build(context));
+
+    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
+    return true;
+  }
+
+  // Applies the frozen patterns with top-down traversal, capped at 10
+  // iterations.
+  void Run(pir::Operation *op) override {
+    pir::GreedyRewriteConfig cfg;
+    cfg.use_top_down_traversal = true;
+    cfg.max_iterations = 10;
+    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+  }
+
+  // Only operations named "builtin.module" that own at least one region are
+  // accepted.
+  bool CanApplyOn(pir::Operation *op) const override {
+    return op->name() == "builtin.module" && op->num_regions() > 0;
+  }
+
+ private:
+  pir::FrozenRewritePatternSet patterns_;
+};
+
+// Runs DrrPatternRewritePass followed by dead code elimination on the program
+// from BuildProgram and checks that the block shrinks from 14 to 7 operations.
+//
+// NOTE(review): the two dialect types were restored as
+// paddle::dialect::OperatorDialect and pir::BuiltinDialect -- confirm.
+TEST(DrrTest, drr_demo) {
+  pir::IrContext *ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
+  ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
+  pir::Program program(ctx);
+  pir::Builder builder = pir::Builder(ctx, program.block());
+  BuildProgram(builder);
+
+  EXPECT_EQ(program.block()->size(), 14u);
+
+  pir::PassManager pm(ctx);
+  pm.AddPass(std::make_unique<DrrPatternRewritePass>());
+  pm.AddPass(pir::CreateDeadCodeEliminationPass());
+  // pm.EnablePassTiming();
+  pm.EnableIRPrinting();
+
+  // A failed run is reported as a test failure instead of aborting the
+  // whole test binary.
+  ASSERT_TRUE(pm.Run(&program));
+  EXPECT_EQ(program.block()->size(), 7u);
+}
diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
index 18644c08e21b7a..1499ba161bb09d 100644
--- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
@@ -1111,9 +1111,12 @@ void BuildProgram(pir::Builder &builder) { // NOLINT
 // TODO(wilber): Add a normal test.
TEST(pattern_rewrite, Patterns) { pir::IrContext *ctx = pir::IrContext::Instance(); + + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); auto *test_dialect = ctx->GetOrRegisterDialect(); test_dialect->RegisterOp(); - ctx->GetOrRegisterDialect(); + pir::Program program(ctx); pir::Builder builder = pir::Builder(ctx, program.block()); BuildProgram(builder); @@ -1122,7 +1125,7 @@ TEST(pattern_rewrite, Patterns) { pir::PassManager pm(ctx); pm.AddPass(std::make_unique()); - // pm.AddPass(ir::CreateConstantFoldingPass()); + // pm.AddPass(pir::CreateConstantFoldingPass()); pm.AddPass(pir::CreateDeadCodeEliminationPass()); pm.AddPass(pir::CreateReorderBlockOpsPass()); pm.EnablePassTiming();