Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle TensorRT]拆分单测文件 #8

Open
wants to merge 10 commits into
base: pir_trt_2
Choose a base branch
from
166 changes: 166 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
DEFINE_GENERAL_PATTERN(Conv2d, paddle::dialect::Conv2dOp)
DEFINE_GENERAL_PATTERN(FusedConv2dAddAct, paddle::dialect::FusedConv2dAddActOp)
DEFINE_GENERAL_PATTERN(DepthwiseConv2d, paddle::dialect::DepthwiseConv2dOp)
DEFINE_GENERAL_PATTERN(Shape, paddle::dialect::ShapeOp)
DEFINE_GENERAL_PATTERN(Expand, paddle::dialect::ExpandOp)
DEFINE_GENERAL_PATTERN(Sigmoid, paddle::dialect::SigmoidOp)

#undef DEFINE_GENERAL_PATTERN
Expand Down Expand Up @@ -919,6 +921,161 @@ class MultiplyOpPattern
}
};

class SubtractOpPattern : public pir::OpRewritePattern<paddle::dialect::SubtractOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SubtractOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::SubtractOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_sub do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class DivideOpPattern : public pir::OpRewritePattern<paddle::dialect::DivideOp> {
public:
using pir::OpRewritePattern<paddle::dialect::DivideOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::DivideOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_div do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
class ElementwisePowOpPattern : public pir::OpRewritePattern<paddle::dialect::ElementwisePowOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ElementwisePowOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::ElementwisePowOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()||x_dtype.isa<pir::Int32Type>()|| y_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::Int32Type>()){
VLOG(3) << "elementwise_pow do not support boolean datatype and int32 datatype.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
class MinimumOpPattern : public pir::OpRewritePattern<paddle::dialect::MinimumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::MinimumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::MinimumOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_min do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
class MaximumOpPattern : public pir::OpRewritePattern<paddle::dialect::MaximumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::MaximumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::MaximumOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_max do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class FloorDivideOpPattern : public pir::OpRewritePattern<paddle::dialect::FloorDivideOp> {
public:
using pir::OpRewritePattern<paddle::dialect::FloorDivideOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::FloorDivideOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_floordiv do not support boolean datatype.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class RemainderOpPattern : public pir::OpRewritePattern<paddle::dialect::RemainderOp> {
public:
using pir::OpRewritePattern<paddle::dialect::RemainderOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::RemainderOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if(x_dtype.isa<pir::BoolType>()|| y_dtype.isa<pir::BoolType>()){
VLOG(3) << "elementwise_mod do not support boolean datatype.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
Expand Down Expand Up @@ -948,6 +1105,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(DepthwiseConv2d)
ADD_PATTERN(Nonzero)
ADD_PATTERN(Gelu)
ADD_PATTERN(Shape)
ADD_PATTERN(Expand)
ADD_PATTERN(Sigmoid)

#undef ADD_PATTERN
Expand All @@ -974,6 +1133,13 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
ps.Add(std::make_unique<MultiplyOpPattern>(context));
ps.Add(std::make_unique<SubtractOpPattern>(context));
ps.Add(std::make_unique<DivideOpPattern>(context));
ps.Add(std::make_unique<ElementwisePowOpPattern>(context));
ps.Add(std::make_unique<MinimumOpPattern>(context));
ps.Add(std::make_unique<MaximumOpPattern>(context));
ps.Add(std::make_unique<FloorDivideOpPattern>(context));
ps.Add(std::make_unique<RemainderOpPattern>(context));
return ps;
}
};
Expand Down
58 changes: 58 additions & 0 deletions test/tensorrt/test_trt_marker_divide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


class TestDivideTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[3], dtype='float32')
y = paddle.static.data(name='y', shape=[3], dtype='float32')
divide_out = paddle.divide(x, y)
out = paddle.assign(divide_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([2, 3, 4]).astype("float32"),
"y": np.array([1, 5, 2]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.divide"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()
58 changes: 58 additions & 0 deletions test/tensorrt/test_trt_marker_elementwise_pow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


class TestElementWisePowTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[3], dtype='float32')
y = paddle.static.data(name='y', shape=[1], dtype='float32')
pow_out = paddle.pow(x, y)
out = paddle.assign(pow_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([1, 2, 3]).astype("float32"),
"y": np.array([2]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.elementwise_pow"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()
56 changes: 56 additions & 0 deletions test/tensorrt/test_trt_marker_expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core


class TestExpandTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name="x", shape=[3], dtype="float32")
expand_out = paddle.expand(x, shape=[2, 3])
out = paddle.assign(expand_out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([[1, 2, 3]]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
self.trt_expected_ops = {"pd_op.expand"}

def test_check_output(self):
self.check_pass_correct()


if __name__ == '__main__':
unittest.main()
Loading