Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#2 from lizexu123/add_trt
Browse files Browse the repository at this point in the history
Add trt
  • Loading branch information
vivienfanghuagood committed Jul 11, 2024
2 parents 3d0ae20 + 03b777b commit 6034d99
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 64 deletions.
123 changes: 59 additions & 64 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,19 @@ namespace {

inline auto kCanRunTrtAttr = paddle::dialect::kCanRunTrtAttr;

#define DEFINE_GENERAL_PATTERN(OpName, OpType) \
class OpName##OpPattern : public pir::OpRewritePattern<OpType> { \
public: \
using pir::OpRewritePattern<OpType>::OpRewritePattern; \
bool MatchAndRewrite(OpType op, \
pir::PatternRewriter &rewriter) const override { \
if (op->HasAttribute(kCanRunTrtAttr) && \
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { \
VLOG(3) \
<< "Op " << #OpName \
<< "already has kCanRunTrtAttr set to true. Skipping rewrite."; \
return false; \
} \
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); \
return true; \
} \
#define DEFINE_GENERAL_PATTERN(OpName, OpType) \
class OpName##OpPattern : public pir::OpRewritePattern<OpType> { \
public: \
using pir::OpRewritePattern<OpType>::OpRewritePattern; \
bool MatchAndRewrite(OpType op, \
pir::PatternRewriter &rewriter) const override { \
if (op->HasAttribute(kCanRunTrtAttr) && \
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { \
return false; \
} \
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); \
return true; \
} \
};

DEFINE_GENERAL_PATTERN(Matmul, paddle::dialect::MatmulOp)
Expand All @@ -55,14 +52,16 @@ DEFINE_GENERAL_PATTERN(Relu, paddle::dialect::ReluOp)
DEFINE_GENERAL_PATTERN(FullIntArray, paddle::dialect::FullIntArrayOp)
DEFINE_GENERAL_PATTERN(Reshape, paddle::dialect::ReshapeOp)
DEFINE_GENERAL_PATTERN(Dropout, paddle::dialect::DropoutOp)
DEFINE_GENERAL_PATTERN(bmm, paddle::dialect::BmmOp)
DEFINE_GENERAL_PATTERN(concat, paddle::dialect::ConcatOp)
DEFINE_GENERAL_PATTERN(flatten, paddle::dialect::FlattenOp)
DEFINE_GENERAL_PATTERN(fused_gemm_epilogue, paddle::dialect::FusedGemmEpilogueOp)
DEFINE_GENERAL_PATTERN(layer_norm, paddle::dialect::LayerNormOp)
DEFINE_GENERAL_PATTERN(add, paddle::dialect::AddOp)
DEFINE_GENERAL_PATTERN(full, paddle::dialect::FullOp)
DEFINE_GENERAL_PATTERN(scale, paddle::dialect::ScaleOp)
DEFINE_GENERAL_PATTERN(Bmm, paddle::dialect::BmmOp)
DEFINE_GENERAL_PATTERN(Concat, paddle::dialect::ConcatOp)
DEFINE_GENERAL_PATTERN(Flatten, paddle::dialect::FlattenOp)
DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
paddle::dialect::FusedGemmEpilogueOp)
DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
DEFINE_GENERAL_PATTERN(Add, paddle::dialect::AddOp)
DEFINE_GENERAL_PATTERN(Full, paddle::dialect::FullOp)
DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)

#undef DEFINE_GENERAL_PATTERN

class Pool2dOpPattern
Expand All @@ -73,11 +72,9 @@ class Pool2dOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "Pool2d already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
auto padding_attr = op->attribute<pir::ArrayAttribute>("padding");
auto padding_attr = op->attribute<pir::ArrayAttribute>("paddings");
std::vector<int32_t> paddings;
for (const auto &attr : padding_attr.AsVector()) {
paddings.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
Expand Down Expand Up @@ -153,8 +150,6 @@ class Conv2dOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "conv2d already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(7000)
Expand Down Expand Up @@ -192,8 +187,6 @@ class Conv2dTransposeOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "conv2d_transpose already has kCanRunTrtAttr set to true. "
"Skipping rewrite.";
return false;
}
if (!op->HasAttribute("dilations")) {
Expand Down Expand Up @@ -226,8 +219,6 @@ class FusedConv2dAddActOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "fused_conv2d_add_act already has kCanRunTrtAttr set to true. "
"Skipping rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(7000)
Expand Down Expand Up @@ -265,8 +256,6 @@ class DepthwiseConv2dOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "depthwise_conv2d already has kCanRunTrtAttr set to true. "
"Skipping rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(7000)
Expand Down Expand Up @@ -305,8 +294,6 @@ class DepthwiseConv2dTransposeOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "depthwise_conv2d_transpose already has kCanRunTrtAttr set to "
"true. Skipping rewrite.";
return false;
}
if (!op->HasAttribute("dilations")) {
Expand Down Expand Up @@ -341,8 +328,6 @@ class DeformableConvOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "deformable_conv already has kCanRunTrtAttr set to "
"true. Skipping rewrite.";
return false;
}
if (!op->HasAttribute("groups") || !op->HasAttribute("strides") ||
Expand Down Expand Up @@ -402,8 +387,6 @@ class ArangeOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "arange already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(8400)
Expand All @@ -427,8 +410,6 @@ class SignOpPattern : public pir::OpRewritePattern<paddle::dialect::SignOp> {
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "sign already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(8200)
Expand All @@ -448,8 +429,6 @@ class LogicalNotOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "logical_not already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(8400)
Expand All @@ -470,18 +449,16 @@ class GroupNormOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "group_norm already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
if (!op->HasAttribute("epsilon") || !op->HasAttribute("groups") ||
!op->HasAttribute("data_layout")) {
VLOG(3) << "In group_norm, epsilon or groups or data_layout attributes "
!op->HasAttribute("data_format")) {
VLOG(3) << "In group_norm, epsilon or groups or data_format attributes "
"do not exist";
return false;
}
std::string layout_str =
op->attribute<pir::StrAttribute>("data_layout").AsString();
op->attribute<pir::StrAttribute>("data_format").AsString();
if (layout_str != "NCHW") {
VLOG(3) << "Group norm trt plugin only support NCHW layout, but got "
<< layout_str;
Expand All @@ -500,8 +477,6 @@ class TransposeOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "transpose already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
pir::Value x = op.operand_source(0);
Expand Down Expand Up @@ -542,8 +517,6 @@ class GatherOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "gather already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
pir::Value axis = op.operand_source(2);
Expand Down Expand Up @@ -573,8 +546,6 @@ class GatherNdOpPattern
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
VLOG(3) << "gather_nd already has kCanRunTrtAttr set to true. Skipping "
"rewrite.";
return false;
}
#if IS_TRT_VERSION_LT(8200)
Expand Down Expand Up @@ -602,6 +573,29 @@ class GatherNdOpPattern
}
};

class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ScaleOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::ScaleOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
auto x_dtype = pir::GetDataTypeFromValue(x);
if (!(x_dtype.isa<pir::Float32Type>() || x_dtype.isa<pir::Float64Type>() ||
x_dtype.isa<pir::Float16Type>() || x_dtype.isa<pir::Int32Type>() ||
x_dtype.isa<pir::Int64Type>())) {
VLOG(3) << "At present, ScaleOp only support float32 or float16 or "
"float64 or int32 or int64 into trt.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
Expand All @@ -619,14 +613,14 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(FullIntArray)
ADD_PATTERN(Reshape)
ADD_PATTERN(Dropout)
ADD_PATTERN(bmm)
ADD_PATTERN(concat)
ADD_PATTERN(flatten)
ADD_PATTERN(full)
ADD_PATTERN(fused_gemm_epilogue)
ADD_PATTERN(add)
ADD_PATTERN(layer_norm)

ADD_PATTERN(Bmm)
ADD_PATTERN(Concat)
ADD_PATTERN(Flatten)
ADD_PATTERN(Full)
ADD_PATTERN(Fused_gemm_epilogue)
ADD_PATTERN(Add)
ADD_PATTERN(Layer_norm)
ADD_PATTERN(Silu)

#undef ADD_PATTERN
ps.Add(std::make_unique<Pool2dOpPattern>(context));
Expand All @@ -643,6 +637,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<TransposeOpPattern>(context));
ps.Add(std::make_unique<GatherOpPattern>(context));
ps.Add(std::make_unique<GatherNdOpPattern>(context));
ps.Add(std::make_unique<ScaleOpPattern>(context));
return ps;
}
};
Expand Down
Loading

0 comments on commit 6034d99

Please sign in to comment.