From c775eaf484701b0d6b074719dd16974e8aa46409 Mon Sep 17 00:00:00 2001
From: lizexu <2694294196@qq.com>
Date: Thu, 11 Jul 2024 03:04:53 +0000
Subject: [PATCH 1/2] conv2d+relu+pool2d+add+softmax+reshape unit test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../fused_pass/test_pir_trt_op_marker_pass.py | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py

diff --git a/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
new file mode 100644
index 0000000000000..e43daed7ccdb9
--- /dev/null
+++ b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+from paddle.base import core
+from paddle.pir.core import create_parameter
+
+
+class TestTRTPattern(PassTest):
+    def is_program_valid(self, program):
+        return True
+
+    def build_ir_program(self):
+        for bias_shape in [[1, 32, 1, 1], [32, 1, 1], [32]]:
+            with paddle.pir_utils.IrGuard():
+                main_prog = paddle.static.Program()
+                start_prog = paddle.static.Program()
+                with paddle.pir.core.program_guard(main_prog, start_prog):
+                    x = paddle.static.data(
+                        name='x', shape=[3, 1, 28, 28], dtype='float32'
+                    )
+                    conv2d = paddle.nn.Conv2D(
+                        in_channels=1,
+                        out_channels=32,
+                        kernel_size=3,
+                        padding=1,
+                        data_format='NCHW',
+                        bias_attr=False,
+                    )
+
+                    y = create_parameter(
+                        name="y",
+                        shape=bias_shape,
+                        dtype='float32',
+                        initializer=paddle.nn.initializer.Assign(
+                            np.random.random(bias_shape).astype("float32")
+                        ),
+                    )
+                    act_op = paddle.nn.ReLU()
+                    act_out = act_op(paddle.add(conv2d(x), y))
+                    pool2d = paddle.nn.MaxPool2D(
+                        kernel_size=2, stride=2, padding=0
+                    )
+                    padding_out = pool2d(act_out)
+                    softmax = paddle.nn.Softmax()
+                    softmax_out = softmax(padding_out)
+                    reshaped_out = paddle.reshape(
+                        softmax_out, [softmax_out.shape[0], -1]
+                    )
+                    out = paddle.assign(reshaped_out)
+                    self.pass_attr_list = [{'trt_op_marker_pass': {}}]
+                    self.feeds = {
+                        "x": np.random.random((3, 1, 28, 28)).astype("float32"),
+                    }
+                    self.fetch_list = [out]
+                    self.valid_op_map = {
+                        "pd_op.fused_conv2d_add_act": 0,
+                    }
+                    return [main_prog, start_prog]
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def sample_program(self):
+        yield self.build_ir_program(), False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+if __name__ == "__main__":
+    unittest.main()

From 03b777b5742dbfd1a2d1dbb74cd9fb57f9a4fc46 Mon Sep 17 00:00:00 2001
From: lizexu <2694294196@qq.com>
Date: Thu, 11 Jul 2024 07:04:09 +0000
Subject: [PATCH 2/2] Add silu, modify GroupNorm, add Matmul and Scale unit
 tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../transforms/tensorrt/trt_op_marker_pass.cc | 123 ++++++++--------
 .../fused_pass/test_pir_trt_op_marker_pass.py | 138 ++++++++++++++++++
 2 files changed, 197 insertions(+), 64 deletions(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 9002ce4257f7f..c1c67255395af 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -29,22 +29,19 @@
 namespace {
 inline auto kCanRunTrtAttr = paddle::dialect::kCanRunTrtAttr;
 
-#define DEFINE_GENERAL_PATTERN(OpName, OpType) \
-  class OpName##OpPattern : public pir::OpRewritePattern<OpType> { \
-   public: \
-    using pir::OpRewritePattern<OpType>::OpRewritePattern; \
-    bool MatchAndRewrite(OpType op, \
-                         pir::PatternRewriter &rewriter) const override { \
-      if (op->HasAttribute(kCanRunTrtAttr) && \
-          op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { \
-        VLOG(3) \
-            << "Op " << #OpName \
-            << "already has kCanRunTrtAttr set to true. Skipping rewrite."; \
-        return false; \
-      } \
-      op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); \
-      return true; \
-    } \
+#define DEFINE_GENERAL_PATTERN(OpName, OpType) \
+  class OpName##OpPattern : public pir::OpRewritePattern<OpType> { \
+   public: \
+    using pir::OpRewritePattern<OpType>::OpRewritePattern; \
+    bool MatchAndRewrite(OpType op, \
+                         pir::PatternRewriter &rewriter) const override { \
+      if (op->HasAttribute(kCanRunTrtAttr) && \
+          op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { \
+        return false; \
+      } \
+      op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); \
+      return true; \
+    } \
   };
 
 DEFINE_GENERAL_PATTERN(Matmul, paddle::dialect::MatmulOp)
@@ -55,14 +52,16 @@ DEFINE_GENERAL_PATTERN(Relu, paddle::dialect::ReluOp)
 DEFINE_GENERAL_PATTERN(FullIntArray, paddle::dialect::FullIntArrayOp)
 DEFINE_GENERAL_PATTERN(Reshape, paddle::dialect::ReshapeOp)
 DEFINE_GENERAL_PATTERN(Dropout, paddle::dialect::DropoutOp)
-DEFINE_GENERAL_PATTERN(bmm, paddle::dialect::BmmOp)
-DEFINE_GENERAL_PATTERN(concat, paddle::dialect::ConcatOp)
-DEFINE_GENERAL_PATTERN(flatten, paddle::dialect::FlattenOp)
-DEFINE_GENERAL_PATTERN(fused_gemm_epilogue, paddle::dialect::FusedGemmEpilogueOp)
-DEFINE_GENERAL_PATTERN(layer_norm, paddle::dialect::LayerNormOp)
-DEFINE_GENERAL_PATTERN(add, paddle::dialect::AddOp)
-DEFINE_GENERAL_PATTERN(full, paddle::dialect::FullOp)
-DEFINE_GENERAL_PATTERN(scale, paddle::dialect::ScaleOp)
+DEFINE_GENERAL_PATTERN(Bmm, paddle::dialect::BmmOp)
+DEFINE_GENERAL_PATTERN(Concat, paddle::dialect::ConcatOp)
+DEFINE_GENERAL_PATTERN(Flatten, paddle::dialect::FlattenOp)
+DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
+                       paddle::dialect::FusedGemmEpilogueOp)
+DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
+DEFINE_GENERAL_PATTERN(Add, paddle::dialect::AddOp)
+DEFINE_GENERAL_PATTERN(Full, paddle::dialect::FullOp)
+DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
+
 #undef DEFINE_GENERAL_PATTERN
 
 class Pool2dOpPattern
@@ -73,11 +72,9 @@ class Pool2dOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "Pool2d already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
-    auto padding_attr = op->attribute<pir::ArrayAttribute>("padding");
+    auto padding_attr = op->attribute<pir::ArrayAttribute>("paddings");
     std::vector<int32_t> paddings;
     for (const auto &attr : padding_attr.AsVector()) {
       paddings.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
@@ -153,8 +150,6 @@ class Conv2dOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "conv2d already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(7000)
@@ -192,8 +187,6 @@ class Conv2dTransposeOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "conv2d_transpose already has kCanRunTrtAttr set to true. "
-                 "Skipping rewrite.";
       return false;
     }
     if (!op->HasAttribute("dilations")) {
@@ -226,8 +219,6 @@ class FusedConv2dAddActOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "fused_conv2d_add_act already has kCanRunTrtAttr set to true. "
-                 "Skipping rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(7000)
@@ -265,8 +256,6 @@ class DepthwiseConv2dOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "depthwise_conv2d already has kCanRunTrtAttr set to true. "
-                 "Skipping rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(7000)
@@ -305,8 +294,6 @@ class DepthwiseConv2dTransposeOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "depthwise_conv2d_transpose already has kCanRunTrtAttr set to "
-                 "true. Skipping rewrite.";
       return false;
     }
     if (!op->HasAttribute("dilations")) {
@@ -341,8 +328,6 @@ class DeformableConvOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "deformable_conv already has kCanRunTrtAttr set to "
-                 "true. Skipping rewrite.";
       return false;
     }
     if (!op->HasAttribute("groups") || !op->HasAttribute("strides") ||
@@ -402,8 +387,6 @@ class ArangeOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "arange already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(8400)
@@ -427,8 +410,6 @@ class SignOpPattern : public pir::OpRewritePattern<paddle::dialect::SignOp> {
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "sign already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(8200)
@@ -448,8 +429,6 @@ class LogicalNotOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "logical_not already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(8400)
@@ -470,18 +449,16 @@ class GroupNormOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "group_norm already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
     if (!op->HasAttribute("epsilon") || !op->HasAttribute("groups") ||
-        !op->HasAttribute("data_layout")) {
-      VLOG(3) << "In group_norm, epsilon or groups or data_layout attributes "
+        !op->HasAttribute("data_format")) {
+      VLOG(3) << "In group_norm, epsilon or groups or data_format attributes "
                  "do not exist";
       return false;
     }
     std::string layout_str =
-        op->attribute<pir::StrAttribute>("data_layout").AsString();
+        op->attribute<pir::StrAttribute>("data_format").AsString();
     if (layout_str != "NCHW") {
       VLOG(3) << "Group norm trt plugin only support NCHW layout, but got "
               << layout_str;
@@ -500,8 +477,6 @@ class TransposeOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "transpose already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
     pir::Value x = op.operand_source(0);
@@ -542,8 +517,6 @@ class GatherOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "gather already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
     pir::Value axis = op.operand_source(2);
@@ -573,8 +546,6 @@ class GatherNdOpPattern
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      VLOG(3) << "gather_nd already has kCanRunTrtAttr set to true. Skipping "
-                 "rewrite.";
       return false;
     }
 #if IS_TRT_VERSION_LT(8200)
@@ -602,6 +573,29 @@ class GatherNdOpPattern
   }
 };
 
+class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::ScaleOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::ScaleOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    pir::Value x = op.operand_source(0);
+    auto x_dtype = pir::GetDataTypeFromValue(x);
+    if (!(x_dtype.isa<pir::Float32Type>() || x_dtype.isa<pir::Float16Type>() ||
+          x_dtype.isa<pir::Float64Type>() || x_dtype.isa<pir::Int32Type>() ||
+          x_dtype.isa<pir::Int64Type>())) {
+      VLOG(3) << "At present, ScaleOp only support float32 or float16 or "
+                 "float64 or int32 or int64 into trt.";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
 class TrtOpMarkerPass : public pir::PatternRewritePass {
  public:
   TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -619,14 +613,14 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ADD_PATTERN(FullIntArray)
     ADD_PATTERN(Reshape)
     ADD_PATTERN(Dropout)
-    ADD_PATTERN(bmm)
-    ADD_PATTERN(concat)
-    ADD_PATTERN(flatten)
-    ADD_PATTERN(full)
-    ADD_PATTERN(fused_gemm_epilogue)
-    ADD_PATTERN(add)
-    ADD_PATTERN(layer_norm)
-
+    ADD_PATTERN(Bmm)
+    ADD_PATTERN(Concat)
+    ADD_PATTERN(Flatten)
+    ADD_PATTERN(Full)
+    ADD_PATTERN(Fused_gemm_epilogue)
+    ADD_PATTERN(Add)
+    ADD_PATTERN(Layer_norm)
+    ADD_PATTERN(Silu)
 #undef ADD_PATTERN
 
     ps.Add(std::make_unique<Pool2dOpPattern>(context));
@@ -643,6 +637,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique<TransposeOpPattern>(context));
     ps.Add(std::make_unique<GatherOpPattern>(context));
     ps.Add(std::make_unique<GatherNdOpPattern>(context));
+    ps.Add(std::make_unique<ScaleOpPattern>(context));
     return ps;
   }
 };
diff --git a/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
index e43daed7ccdb9..19419544e7a06 100644
--- a/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
+++ b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -85,5 +85,143 @@ def test_check_output(self):
         self.check_pass_correct()
 
 
+class TestMatmulScaleTRTPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for x_shape in [[3, 2]]:
+            for w_shape in [[2, 3]]:
+                for scale_bias in [1e-7]:
+                    for scale_value in [2.0]:
+                        for bias_after_scale in [True]:
+                            with paddle.pir_utils.IrGuard():
+                                main_prog = paddle.static.Program()
+                                start_prog = paddle.static.Program()
+                                with paddle.static.program_guard(
+                                    main_prog, start_prog
+                                ):
+                                    x = paddle.static.data(
+                                        name='x', shape=x_shape, dtype='float32'
+                                    )
+                                    w = paddle.static.data(
+                                        name='w', shape=w_shape, dtype='float32'
+                                    )
+                                    out = paddle.scale(
+                                        paddle.matmul(x, w),
+                                        scale=scale_value,
+                                        bias=scale_bias,
+                                        bias_after_scale=bias_after_scale,
+                                    )
+                                    out = paddle.assign(out)
+                                    self.pass_attr_list = [
+                                        {'trt_op_marker_pass': {}}
+                                    ]
+                                    self.feeds = {
+                                        "x": np.random.random(x_shape).astype(
+                                            "float32"
+                                        ),
+                                        "w": np.random.random(w_shape).astype(
+                                            "float32"
+                                        ),
+                                    }
+                                    self.fetch_list = [out]
+                                    self.valid_op_map = {
+                                        "pd_op.conv2d": 0,
+                                    }
+                                    yield [main_prog, start_prog], False
+
+    def setUp(self):
+        self.places.append(paddle.CPUPlace())
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestGroupNormSiluTRTPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for x_shape in [[2, 6, 4, 2]]:
+            dtype = None
+            if core.is_compiled_with_xpu():
+                dtype = 'float32'
+            elif core.is_compiled_with_cuda():
+                dtype = 'float16'
+            for epilson in [1e-5]:
+                for groups in [2]:
+                    rand_value = (
+                        0.001
+                        * paddle.rand(shape=[x_shape[1]], dtype=dtype).numpy()
+                    )
+                    with paddle.pir_utils.IrGuard():
+                        start_prog = paddle.static.Program()
+                        main_prog = paddle.static.Program()
+                        with paddle.pir.core.program_guard(
+                            main_prog, start_prog
+                        ):
+                            x = paddle.static.data(
+                                name='x', shape=x_shape, dtype=dtype
+                            )
+                            w = create_parameter(
+                                shape=[x_shape[1]],
+                                dtype=dtype,
+                                initializer=paddle.nn.initializer.Assign(
+                                    rand_value
+                                ),
+                            )
+                            b = create_parameter(
+                                shape=[x_shape[1]],
+                                dtype=dtype,
+                                initializer=paddle.nn.initializer.Assign(
+                                    rand_value
+                                ),
+                            )
+                            group_norm_out = paddle.nn.functional.group_norm(
+                                x,
+                                num_groups=groups,
+                                epsilon=epilson,
+                                weight=w,
+                                bias=b,
+                                data_format="NCHW",
+                            )
+                            out = paddle.nn.functional.silu(group_norm_out)
+                            out = paddle.assign(out)
+                            if core.is_compiled_with_xpu():
+                                self.pass_attr_list = [
+                                    {'trt_op_marker_pass': {}}
+                                ]
+                            elif core.is_compiled_with_cuda():
+                                self.pass_attr_list = [
+                                    {'trt_op_marker_pass': {}}
+                                ]
+                            self.feeds = {
+                                "x": np.random.random(x_shape).astype(dtype),
+                            }
+                            self.fetch_list = [out]
+                            if core.is_compiled_with_xpu():
+                                self.valid_op_map = {
+                                    "pd_op.group_norm_silu_xpu": 0,
+                                }
+                            elif core.is_compiled_with_cuda():
+                                self.valid_op_map = {
+                                    "pd_op.add_group_norm_silu": 0,
+                                }
+
+                            yield [main_prog, start_prog], False
+
+    def setUp(self):
+        if core.is_compiled_with_xpu():
+            self.places.append(paddle.XPUPlace(0))
+        elif core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
 if __name__ == "__main__":
     unittest.main()