Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon NO.75] 为 Paddle-TRT 添加 expend_as_v2 算子 #51028

Merged
merged 7 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2553,6 +2553,7 @@ USE_TRT_CONVERTER(tanh_shrink)
USE_TRT_CONVERTER(logsigmoid)
USE_TRT_CONVERTER(lookup_table)
USE_TRT_CONVERTER(expand_v2)
USE_TRT_CONVERTER(expand_as_v2)
USE_TRT_CONVERTER(take_along_axis)
USE_TRT_CONVERTER(skip_groupnorm_act)
USE_TRT_CONVERTER(preln_groupnorm_act)
Expand Down
88 changes: 53 additions & 35 deletions paddle/fluid/inference/tensorrt/convert/expand_v2_op.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,26 +14,16 @@ limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

class ExpandV2OpConverter : public OpConverter {
class ExpandOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a paddle expand_v2 op to trt expand layer.";
VLOG(3) << "convert a paddle " << op_type_ << " op to trt expand layer.";
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto inputs = op_desc.Inputs();
Expand All @@ -43,25 +33,40 @@ class ExpandV2OpConverter : public OpConverter {

nvinfer1::ITensor* shape_tensor = nullptr;
int32_t shape_rank = 0;
if (inputs.find("Shape") != inputs.end() &&
op_desc.Input("Shape").size() >= 1) {
shape_tensor = engine_->GetITensor(op_desc.Input("Shape")[0]);
shape_rank = shape_tensor->getDimensions().d[0];
} else if (inputs.find("expand_shapes_tensor") != inputs.end() &&
op_desc.Input("expand_shapes_tensor").size() >= 1) {
int shape_size = op_desc.Input("expand_shapes_tensor").size();
std::vector<nvinfer1::ITensor*> shape_tensors;
for (int i = 0; i < shape_size; ++i) {
shape_tensors.push_back(
engine_->GetITensor(op_desc.Input("expand_shapes_tensor")[i]));

if (op_type_ == "expand_v2") {
if (inputs.find("Shape") != inputs.end() &&
op_desc.Input("Shape").size() >= 1) {
shape_tensor = engine_->GetITensor(op_desc.Input("Shape")[0]);
shape_rank = shape_tensor->getDimensions().nbDims;
} else if (inputs.find("expand_shapes_tensor") != inputs.end() &&
op_desc.Input("expand_shapes_tensor").size() >= 1) {
int shape_size = op_desc.Input("expand_shapes_tensor").size();
std::vector<nvinfer1::ITensor*> shape_tensors;
for (int i = 0; i < shape_size; ++i) {
shape_tensors.push_back(
engine_->GetITensor(op_desc.Input("expand_shapes_tensor")[i]));
}
shape_tensor = Concat(shape_tensors);
shape_rank = shape_size;
} else {
std::vector<int32_t> shape =
PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("shape"));
shape_tensor =
Add1DConstantLayer(shape, output_name + "_shape_tensor_");
shape_rank = shape.size();
}
} else if (op_type_ == "expand_as_v2") {
if (inputs.find("Y") != inputs.end()) {
shape_tensor = engine_->GetITensor(op_desc.Input("Y")[0]);
shape_rank = shape_tensor->getDimensions().nbDims;
} else {
std::vector<int32_t> shape = PADDLE_GET_CONST(
std::vector<int32_t>, op_desc.GetAttr("target_shape"));
shape_tensor =
Add1DConstantLayer(shape, output_name + "_target_shape_tensor_");
shape_rank = shape.size();
}
shape_tensor = Concat(shape_tensors);
shape_rank = shape_size;
} else {
std::vector<int32_t> shape =
PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("shape"));
shape_tensor = Add1DConstantLayer(shape, output_name + "_shape_tensor_");
shape_rank = shape.size();
}

nvinfer1::ITensor* input_shape_tensor;
Expand All @@ -78,8 +83,7 @@ class ExpandV2OpConverter : public OpConverter {
input_shape_tensor = Shape(input);
}

auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
shuffle->setInput(1, *input_shape_tensor);
auto* newInputTensor = Reshape(input, input_shape_tensor);

std::vector<int32_t> start_vec(shape_rank, 0);
nvinfer1::Dims start;
Expand All @@ -101,17 +105,31 @@ class ExpandV2OpConverter : public OpConverter {
auto strides_tensor = Min(one_tensor, input_sub_tensor);

auto layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *shuffle->getOutput(0), start, size, stride);
engine_, Slice, *newInputTensor, start, size, stride);
layer->setInput(1, *starts_tensor);
layer->setInput(2, *sizes_tensor);
layer->setInput(3, *strides_tensor);

RreplenishLayerAndOutput(layer, "expand_v2", {output_name}, test_mode);
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
}

protected:
std::string op_type_;
};

class ExpandV2OpConverter : public ExpandOpConverter {
public:
ExpandV2OpConverter() { op_type_ = "expand_v2"; }
};

class ExpandAsV2OpConverter : public ExpandOpConverter {
public:
ExpandAsV2OpConverter() { op_type_ = "expand_as_v2"; }
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(expand_v2, ExpandV2OpConverter);
REGISTER_TRT_OP_CONVERTER(expand_as_v2, ExpandAsV2OpConverter);
11 changes: 11 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,17 @@ class OpConverter {
return TRT_ENGINE_ADD_LAYER(engine_, Shape, *input)->getOutput(0);
}

nvinfer1::ITensor* Reshape(nvinfer1::ITensor* input,
nvinfer1::ITensor* newShape) {
nvinfer1::ITensor* oldShape = Shape(input);
if (oldShape == newShape) {
return input;
}
auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
shuffle->setInput(1, *newShape);
return shuffle->getOutput(0);
}

// Concat not make rank changed
nvinfer1::ITensor* Concat(const std::vector<nvinfer1::ITensor*>& inputs,
int axis = 0) {
Expand Down
30 changes: 28 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2533,11 +2533,35 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "expand_v2") {
if (op_type == "expand_as_v2" || op_type == "expand_v2") {
if (!with_dynamic_shape) {
VLOG(3) << "the " << op_type
<< "does not support "
"static shape yet";
return false;
}
if (!desc.HasAttr("shape")) {

auto inputs = desc.Inputs();
if (op_type == "expand_as_v2") {
if (!desc.HasAttr("target_shape") && inputs.find("Y") == inputs.end()) {
VLOG(3)
<< "expand_as_v2 op need have input(Y) or attr(target_shape). ";
return false;
}
} else if (op_type == "expand_v2") {
if (!desc.HasAttr("shape") && inputs.find("Shape") == inputs.end() &&
inputs.find("expand_shapes_tensor") == inputs.end()) {
VLOG(3) << "expand_v2 op need have input(Shape) or "
"input(expand_shapes_tensor) or attr(shape) . ";
return false;
}
}

auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
}
Expand Down Expand Up @@ -2699,6 +2723,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"skip_merge_layernorm",
"lookup_table_v2",
"expand_v2",
"expand_as_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
Expand Down Expand Up @@ -2851,6 +2876,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"lookup_table",
"lookup_table_v2",
"expand_v2",
"expand_as_v2",
"fuse_eleadd_transpose",
"skip_groupnorm_act",
"preln_groupnorm_act"};
Expand Down
Loading