diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 8261fee1517f6..a8c73c3218398 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -112,7 +112,6 @@
 #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h"
-#include "paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/identity_op_clean_pass.h"
 #include "paddle/fluid/pir/transforms/inplace_pass.h"
@@ -800,7 +799,6 @@ bool AnalysisPredictor::PrepareExecutor() {
     gpu_pm.AddPass(::pir::CreateConv2dBnFusePass());
     gpu_pm.AddPass(::pir::CreateConv2dAddActFusePass());
     gpu_pm.AddPass(::pir::CreateConv2dAddFusePass());
-    gpu_pm.AddPass(::pir::CreateFcWithSpecialOpFusePass());
     gpu_pm.AddPass(::pir::CreateFcFusePass());
     gpu_pm.AddPass(::pir::CreateFcElementwiseLayerNormFusePass());
     gpu_pm.AddPass(::pir::CreateMatmulScaleFusePass());
diff --git a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc
deleted file mode 100644
index 74dd21a0828fe..0000000000000
--- a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc
+++ /dev/null
@@ -1,353 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h" - -#include "paddle/common/ddim.h" - -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" - -#include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" -#include "paddle/pir/pass/pass.h" -#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" - -namespace { - -class SqueezeFcFusePattern - : public paddle::drr::DrrPatternBase { - public: - void operator()(paddle::drr::DrrPatternContext *ctx) const override { - paddle::drr::SourcePattern pat = ctx->SourcePattern(); - const auto &squeeze_op = pat.Op(paddle::dialect::SqueezeOp::name()); - const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), - {{"transpose_x", pat.Attr("transpose_x")}, - {"transpose_y", pat.Attr("transpose_y")}}); - const auto &add = pat.Op(paddle::dialect::AddOp::name()); - squeeze_op({&pat.Tensor("x"), &pat.Tensor("axis")}, - {&pat.Tensor("squeeze_out"), &pat.Tensor("xshape")}); - matmul({&pat.Tensor("squeeze_out"), &pat.Tensor("w")}, - {&pat.Tensor("matmul_out")}); - pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias")); - // Constrains the activation is none - pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - auto axis_type = match_ctx.Tensor("axis").Dtype().get(); - if (axis_type.isa() && - axis_type.dyn_cast().size() != 2) { - return false; - } - - if (!axis_type.isa() && - match_ctx.Tensor("axis").Shape().size() > 0 && - match_ctx.Tensor("axis").Shape().at(0) != 2) { - return false; - } - - if (match_ctx.Tensor("x").Shape().size() != 4 || - match_ctx.Tensor("x").Shape().at(2) != 1 || - match_ctx.Tensor("x").Shape().at(3) != 1 || - match_ctx.Attr("transpose_x") == true || - match_ctx.Attr("transpose_y") == true) { - return false; - } - - if (match_ctx.Tensor("w").Shape().size() != 2 || - match_ctx.Tensor("squeeze_out").Shape().size() != 2) { - return false; - } - if (match_ctx.Tensor("squeeze_out").Shape().at(1) != - match_ctx.Tensor("w").Shape().at(0)) { - return false; - } - if (match_ctx.Tensor("bias").Shape().size() == 1) { - return match_ctx.Tensor("bias").Shape().at(0) == - match_ctx.Tensor("w").Shape().at(1); - } - if (match_ctx.Tensor("bias").Shape().size() == 2) { - return match_ctx.Tensor("bias").Shape().at(0) == 1 && - match_ctx.Tensor("bias").Shape().at(1) == - match_ctx.Tensor("w").Shape().at(1); - } - return false; - }); - - paddle::drr::ResultPattern res = pat.ResultPattern(); - - const auto &in_num_col_dims_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return 1; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &fc = - res.Op(paddle::dialect::FcOp::name(), - {{ - {"in_num_col_dims", in_num_col_dims_attr}, - {"activation_type", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::string { return ""; })}, - {"padding_weights", false_attr}, - }}); - fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, - {&res.Tensor("add_out")}); - } -}; - -class ReshapeFcFusePattern - : public paddle::drr::DrrPatternBase { - public: - void operator()(paddle::drr::DrrPatternContext *ctx) const override { - 
-    paddle::drr::SourcePattern pat = ctx->SourcePattern();
-    const auto &reshape_op = pat.Op(paddle::dialect::ReshapeOp::name());
-    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
-                                {{"transpose_x", pat.Attr("transpose_x")},
-                                 {"transpose_y", pat.Attr("transpose_y")}});
-    const auto &add = pat.Op(paddle::dialect::AddOp::name());
-    reshape_op({&pat.Tensor("x"), &pat.Tensor("shape")},
-               {&pat.Tensor("reshape_out"), &pat.Tensor("xshape")});
-    matmul({&pat.Tensor("reshape_out"), &pat.Tensor("w")},
-           {&pat.Tensor("matmul_out")});
-    add({&pat.Tensor("matmul_out"), &pat.Tensor("bias")},
-        {&pat.Tensor("add_out")});
-    pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
-      if (match_ctx.Tensor("w").Shape().size() != 2 ||
-          match_ctx.Attr<bool>("transpose_x") == true ||
-          match_ctx.Attr<bool>("transpose_y") == true) {
-        return false;
-      }
-      if (match_ctx.Tensor("reshape_out").Shape().size() < 2 ||
-          (match_ctx.Tensor("reshape_out").Shape().size() > 0 &&
-           match_ctx.Tensor("reshape_out")
-                   .Shape()
-                   .at(match_ctx.Tensor("reshape_out").Shape().size() - 1) !=
-               match_ctx.Tensor("w").Shape().at(0))) {
-        return false;
-      }
-
-      if (match_ctx.Tensor("bias").Shape().size() == 1 &&
-          match_ctx.Tensor("bias").Shape().at(0) !=
-              match_ctx.Tensor("w").Shape().at(1)) {
-        return false;
-      }
-      if (match_ctx.Tensor("bias").Shape().size() == 2 &&
-          (match_ctx.Tensor("bias").Shape().at(0) != 1 ||
-           match_ctx.Tensor("bias").Shape().at(1) !=
-               match_ctx.Tensor("w").Shape().at(1))) {
-        return false;
-      }
-
-      if (match_ctx.Tensor("x").Shape().size() <
-          match_ctx.Tensor("reshape_out").Shape().size()) {
-        return false;
-      }
-      int i = match_ctx.Tensor("x").Shape().size() - 1;
-      int j = match_ctx.Tensor("reshape_out").Shape().size() - 1;
-      int target = match_ctx.Tensor("reshape_out").Shape().at(j);
-      int mul = match_ctx.Tensor("x").Shape().at(i);
-
-      if (mul > target) {
-        return false;
-      }
-      /*
-      reshape_in:  [2,12,12,128]
-      reshape_out: [2,144,128]
-      shape:       [2,144,128]
-      n = len(reshape_in)
-      m = len(reshape_out)
-
-      request:
-      1. reshape_in[i:] = reshape_in[i]*reshape_in[i+1]*...*reshape_in[n-1]
-         reshape_in[i:] = reshape_out[j]
-      2. reshape_in[:i] = reshape_out[:j]
-      e.g.:
-                                          288(reshape_out[:2])
-           i       shape[2,144,128]        |    |   j(invariable)
-      [2,12,12,128]------------------->[2,144,128]
-        |   |                              |    |
-        |   |   reshape_in[i]=reshape_out[j]    |
-         \ /
-        288(reshape_in[:3])
-
-      */
-      while (target != mul) {
-        if (mul <= 0 || mul > target) {
-          return false;
-        }
-        i--;
-        mul *= match_ctx.Tensor("x").Shape().at(i);
-      }
-
-      int mul1 = 1;
-      int mul2 = 1;
-      i--;
-      j--;
-      while (i >= 0 || j >= 0) {
-        if (i >= 0) {
-          mul1 *= match_ctx.Tensor("x").Shape().at(i);
-          i--;
-        }
-        if (j >= 0) {
-          mul2 *= match_ctx.Tensor("x").Shape().at(j);
-          j--;
-        }
-      }
-      if (mul1 != mul2) {
-        return false;
-      }
-      return true;
-    });
-    paddle::drr::ResultPattern res = pat.ResultPattern();
-
-    const auto &in_num_col_dims_attr =
-        res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
-          int i = match_ctx.Tensor("x").Shape().size() - 1;
-          int target =
-              match_ctx.Tensor("reshape_out")
-                  .Shape()
-                  .at(match_ctx.Tensor("reshape_out").Shape().size() - 1);
-          int mul = match_ctx.Tensor("x").Shape().at(i);
-          while (target != mul) {
-            i--;
-            mul *= match_ctx.Tensor("x").Shape().at(i);
-          }
-          return i;
-        });
-    const auto &false_attr =
-        res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
-          return false;
-        });
-
-    const auto &fc =
-        res.Op(paddle::dialect::FcOp::name(),
-               {{
-                   {"in_num_col_dims", in_num_col_dims_attr},
-                   {"activation_type",
-                    res.Attr([](const paddle::drr::MatchContext &match_ctx)
-                                 -> std::string { return ""; })},
-                   {"padding_weights", false_attr},
-               }});
-    fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
-       {&res.Tensor("add_out")});
-  }
-};
-
-class FlattenFcFusePattern
-    : public paddle::drr::DrrPatternBase<FlattenFcFusePattern> {
- public:
-  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
-    paddle::drr::SourcePattern pat = ctx->SourcePattern();
-    const auto &flatten_op = pat.Op(paddle::dialect::FlattenOp::name(),
-                                    {{"start_axis", pat.Attr("start_axis")},
-                                     {"stop_axis", pat.Attr("stop_axis")}});
-    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
-                                {{"transpose_x", pat.Attr("transpose_x")},
-                                 {"transpose_y", pat.Attr("transpose_y")}});
-    const auto &add = pat.Op(paddle::dialect::AddOp::name());
-    flatten_op({&pat.Tensor("x")},
-               {&pat.Tensor("flatten_out"), &pat.Tensor("xshape")});
-    matmul({&pat.Tensor("flatten_out"), &pat.Tensor("w")},
-           {&pat.Tensor("matmul_out")});
-    pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias"));
-    // Constrains the activation is none
-    pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
-      bool flatten_flag = false;
-
-      if (match_ctx.Tensor("x").Shape().size() == 4 &&
-          match_ctx.Tensor("flatten_out").Shape().size() == 2 &&
-          match_ctx.Attr<int>("start_axis") == 1 &&
-          match_ctx.Attr<int>("stop_axis") == 3 &&
-          match_ctx.Attr<bool>("transpose_x") == false &&
-          match_ctx.Attr<bool>("transpose_y") == false) {
-        flatten_flag = true;
-      }
-
-      if (match_ctx.Tensor("w").Shape().size() != 2 ||
-          match_ctx.Tensor("flatten_out").Shape().size() != 2) {
-        return false;
-      }
-      if (match_ctx.Tensor("flatten_out").Shape().at(1) !=
-          match_ctx.Tensor("w").Shape().at(0)) {
-        return false;
-      }
-      if (match_ctx.Tensor("bias").Shape().size() == 1) {
-        return flatten_flag && match_ctx.Tensor("bias").Shape().at(0) ==
-                                   match_ctx.Tensor("w").Shape().at(1);
-      }
-      if (match_ctx.Tensor("bias").Shape().size() == 2) {
-        return flatten_flag && match_ctx.Tensor("bias").Shape().at(0) == 1 &&
- match_ctx.Tensor("bias").Shape().at(1) == - match_ctx.Tensor("w").Shape().at(1); - } - return false; - }); - - paddle::drr::ResultPattern res = pat.ResultPattern(); - - const auto &in_num_col_dims_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { - return 1; - }); - const auto &false_attr = - res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return false; - }); - - const auto &fc = - res.Op(paddle::dialect::FcOp::name(), - {{ - {"in_num_col_dims", in_num_col_dims_attr}, - {"activation_type", - res.Attr([](const paddle::drr::MatchContext &match_ctx) - -> std::string { return ""; })}, - {"padding_weights", false_attr}, - }}); - fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, - {&res.Tensor("add_out")}); - } -}; - -class FcWithSpecialOpFusePass : public pir::PatternRewritePass { - public: - FcWithSpecialOpFusePass() - : pir::PatternRewritePass("fc_with_special_op_fuse_pass", 2) {} - - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { - pir::RewritePatternSet ps(context); - ps.Add(SqueezeFcFusePattern().Build(context)); - ps.Add(ReshapeFcFusePattern().Build(context)); - ps.Add(FlattenFcFusePattern().Build(context)); - return ps; - } -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateFcWithSpecialOpFusePass() { - return std::make_unique(); -} - -} // namespace pir - -REGISTER_IR_PASS(fc_with_special_op_fuse_pass, FcWithSpecialOpFusePass); diff --git a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h deleted file mode 100644 index ff5b851e3623f..0000000000000 --- a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#pragma once
-
-#include <memory>
-#include "paddle/pir/core/dll_decl.h"
-
-namespace pir {
-
-class Pass;
-
-IR_API std::unique_ptr<Pass> CreateFcWithSpecialOpFusePass();
-
-}  // namespace pir
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 300ff3c7f7ccc..2ca3f1b703cbd 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -49,7 +49,6 @@
 #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h"
-#include "paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"
@@ -122,7 +121,6 @@ USE_PIR_PASS(replace_fetch_with_shadow_output_pass);
 USE_PIR_PASS(identity_op_clean_pass);
 USE_PIR_PASS(matmul_scale_fuse_pass);
 USE_PIR_PASS(fc_fuse_pass);
-USE_PIR_PASS(fc_with_special_op_fuse_pass);
 USE_PIR_PASS(fc_elementwise_layernorm_fuse_pass);
 USE_PIR_PASS(conv2d_bn_fuse_pass);
 USE_PIR_PASS(conv2d_add_fuse_pass);
diff --git a/test/ir/pir/fused_pass/test_fc_with_special_op_fuse_pass.py b/test/ir/pir/fused_pass/test_fc_with_special_op_fuse_pass.py
deleted file mode 100644
index f2c8a40f2d565..0000000000000
--- a/test/ir/pir/fused_pass/test_fc_with_special_op_fuse_pass.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-from pass_test import PassTest
-
-import paddle
-from paddle.base import core
-
-paddle.enable_static()
-
-
-class TestSqueezeFcFusePattern(PassTest):
-    r"""
-    squeeze
-       \
-
-    Matmul    Y
-         \   /
-          Add
-           |
-          Relu
-    """
-
-    def is_program_valid(self, program=None):
-        return True
-
-    def sample_program(self):
-        for y_shape in [[128], [1, 128]]:
-            for w_shape in [[128, 128]]:
-                with paddle.pir_utils.IrGuard():
-                    start_prog = paddle.static.Program()
-                    main_prog = paddle.static.Program()
-                    with paddle.pir.core.program_guard(main_prog, start_prog):
-                        x = paddle.static.data(
-                            name='x', shape=[3, 128, 1, 1], dtype='float32'
-                        )
-                        w = paddle.static.data(
-                            name='w', shape=w_shape, dtype='float32'
-                        )
-                        y = paddle.static.data(
-                            name='y', shape=y_shape, dtype='float32'
-                        )
-
-                        out = paddle.add(
-                            paddle.matmul(paddle.squeeze(x, [2, 3]), w), y
-                        )
-                        out = paddle.assign(out)
-                        self.pass_list = ['fc_with_special_op_fuse_pass']
-                        self.feeds = {
-                            "x": np.random.random([3, 128, 1, 1]).astype(
-                                "float32"
-                            ),
-                            "y": np.random.random(y_shape).astype("float32"),
-                            "w": np.random.random(w_shape).astype("float32"),
-                        }
-                        self.fetch_list = [out]
-                        self.valid_op_map = {
-                            "pd_op.add": 0,
-                            "pd_op.squeeze": 0,
-                            "pd_op.matmul": 0,
-                            "pd_op.fc": 1,
-                        }
-
-                        yield [main_prog, start_prog], False
-
-    def setUp(self):
-        self.places.append(paddle.CPUPlace())
-        if core.is_compiled_with_cuda():
-            self.places.append(paddle.CUDAPlace(0))
-
-    def test_check_output(self):
-        self.check_pass_correct()
-
-
-class TestReshapeFcFusePattern(PassTest):
-    r"""
-    reshape
-       \
-
-    Matmul    Y
-         \   /
-          Add
-           |
-          Relu
-    """
-
-    def is_program_valid(self, program=None):
-        return True
-
-    def sample_program(self):
-        for y_shape in [[192], [1, 192]]:
-            with paddle.pir_utils.IrGuard():
-                start_prog = paddle.static.Program()
-                main_prog = paddle.static.Program()
-                with paddle.pir.core.program_guard(main_prog, start_prog):
-                    x = paddle.static.data(
-                        name='x', shape=[3, 144, 6, 32], dtype='float32'
-                    )
-                    w = paddle.static.data(
-                        name='w', shape=[192, 192], dtype='float32'
-                    )
-                    y = paddle.static.data(
-                        name='y', shape=y_shape, dtype='float32'
-                    )
-
-                    out = paddle.add(
-                        paddle.matmul(paddle.reshape(x, [3, 144, -1]), w), y
-                    )
-                    out = paddle.assign(out)
-                    self.pass_list = ['fc_with_special_op_fuse_pass']
-                    self.feeds = {
-                        "x": np.random.random([3, 144, 6, 32]).astype(
-                            "float32"
-                        ),
-                        "w": np.random.random([192, 192]).astype("float32"),
-                        "y": np.random.random(y_shape).astype("float32"),
-                    }
-                    self.fetch_list = [out]
-                    self.valid_op_map = {
-                        "pd_op.add": 0,
-                        "pd_op.reshape": 0,
-                        "pd_op.matmul": 0,
-                        "pd_op.fc": 1,
-                    }
-
-                    yield [main_prog, start_prog], False
-
-    def setUp(self):
-        self.places.append(paddle.CPUPlace())
-        if core.is_compiled_with_cuda():
-            self.places.append(paddle.CUDAPlace(0))
-
-    def test_check_output(self):
-        self.check_pass_correct()
-
-
-class TestFlattenFcFusePattern(PassTest):
-    r"""
-    flatten
-       \
-
-    Matmul    Y
-         \   /
-          Add
-           |
-          Relu
-    """
-
-    def is_program_valid(self, program=None):
-        return True
-
-    def sample_program(self):
-        for y_shape in [[128], [1, 128]]:
-            with paddle.pir_utils.IrGuard():
-                start_prog = paddle.static.Program()
-                main_prog = paddle.static.Program()
-                with paddle.pir.core.program_guard(main_prog, start_prog):
-                    x = paddle.static.data(
-                        name='x', shape=[3, 255, 1, 1], dtype='float32'
-                    )
-                    w = paddle.static.data(
-                        name='w', shape=[255, 128], dtype='float32'
-                    )
-                    y = paddle.static.data(
-                        name='y', shape=y_shape, dtype='float32'
-                    )
-
-                    out = paddle.add(
-                        paddle.matmul(paddle.flatten(x, start_axis=1), w), y
-                    )
-                    out = paddle.assign(out)
-                    self.pass_list = ['fc_with_special_op_fuse_pass']
-                    self.feeds = {
-                        "x": np.random.random([3, 255, 1, 1]).astype("float32"),
-                        "w": np.random.random([255, 128]).astype("float32"),
-                        "y": np.random.random(y_shape).astype("float32"),
-                    }
-                    self.fetch_list = [out]
-                    self.valid_op_map = {
-                        "pd_op.add": 0,
-                        "pd_op.flatten": 0,
-                        "pd_op.matmul": 0,
-                        "pd_op.fc": 1,
-                    }
-
-                    yield [main_prog, start_prog], False
-
-    def setUp(self):
-        self.places.append(paddle.CPUPlace())
-        if core.is_compiled_with_cuda():
-            self.places.append(paddle.CUDAPlace(0))
-
-    def test_check_output(self):
-        self.check_pass_correct()
-
-
-if __name__ == "__main__":
-    unittest.main()