General plugin mechanism #45355

Merged 1 commit on Sep 15, 2022
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -2185,6 +2185,8 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
14 changes: 13 additions & 1 deletion paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -12,10 +12,18 @@ else()
SRCS engine.cc trt_int8_calibrator.cc
DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
endif()
nv_library(
tensorrt_dynamic_shape_infermeta_factory
SRCS dynamic_shape_infermeta.cc
DEPS framework_proto)
nv_library(
tensorrt_plugin_arg_mapping_context
SRCS plugin_arg_mapping_context.cc
DEPS framework_proto)
nv_library(
tensorrt_op_teller
SRCS op_teller.cc
DEPS framework_proto device_context)
DEPS framework_proto device_context tensorrt_dynamic_shape_infermeta_factory)
nv_test(
test_tensorrt
SRCS test_tensorrt.cc
@@ -24,6 +32,10 @@ nv_test(
test_tensorrt_engine
SRCS test_engine.cc test_dynamic_engine.cc
DEPS dynload_cuda tensorrt_engine tensorrt_plugin)
nv_test(
test_arg_mapping_context
SRCS test_arg_mapping_context.cc
DEPS framework_proto tensorrt_plugin_arg_mapping_context)

if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
21 changes: 19 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -76,7 +76,8 @@ list(
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc)
layernorm_shift_partition_op.cc
generic_and_custom_plugin_creater.cc)

if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
@@ -85,7 +86,12 @@ endif()
nv_library(
tensorrt_converter
SRCS ${CONVERT_FILES}
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
DEPS tensorrt_engine
tensorrt_plugin
operator
scope
framework_proto
tensorrt_op_teller
op_registry)

nv_test(
@@ -94,6 +100,17 @@ nv_test(
DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine
tensorrt_converter)

nv_test(
test_custom_plugin_creater
SRCS test_custom_plugin_creater.cc
DEPS paddle_framework
${GLOB_OPERATOR_DEPS}
tensorrt_engine
tensorrt_plugin
tensorrt_converter
op_meta_info
custom_operator)

if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
# be build only in CI, so suppose the generator in Windows is Ninja.
248 changes: 248 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc
@@ -0,0 +1,248 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"

namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Convert a custom (user-registered) op from fluid to a TensorRT plugin layer.
 */
class CustomPluginCreater : public OpConverter {
public:
void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
VLOG(3) << "convert " << op_desc.Type() << " op to custom plugin layer";

std::string plugin_name;

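// The plugin creator is looked up by naming convention:
// "<op_type>_paddle_trt_dynamic_plugin" when dynamic shape is enabled,
// "<op_type>_paddle_trt_plugin" otherwise.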
if (engine_->with_dynamic_shape()) {
plugin_name = op_desc.Type() + "_paddle_trt_dynamic_plugin";
} else {
plugin_name = op_desc.Type() + "_paddle_trt_plugin";
}

nvinfer1::ILayer *layer = nullptr;
std::vector<nvinfer1::ITensor *> inputs;

auto &op_meta_info_map = OpMetaInfoMap::Instance();
const auto &meta_info_map = op_meta_info_map.GetMap();
auto &op_info = meta_info_map.at(op_desc.Type()).front();

// set inputs
auto &op_input_names = framework::OpMetaInfoHelper::GetInputs(op_info);
for (auto &param_name : op_input_names) {
for (auto &arg_name : op_desc.Input(param_name)) {
framework::Variable *X_v = nullptr;
X_v = scope.FindVar(arg_name);
// If this weight is not shared between ops, it needs to be converted to
// an ITensor.
if (X_v && !engine_->GetITensorMap()->count(arg_name)) {
ConvertWeight2ITensor(scope, arg_name);
}
inputs.push_back(engine_->GetITensor(arg_name));
}
}
auto creator =
GetPluginRegistry()->getPluginCreator(plugin_name.c_str(), "1");
CHECK(creator);

// set attrs
std::vector<nvinfer1::PluginField> plugindatas;
auto &op_attrs_names = framework::OpMetaInfoHelper::GetAttrs(op_info);
auto &attrs = op_desc.GetAttrMap();

std::list<int> int_attrs;
std::list<float> float_attrs;
std::list<std::string> string_attrs;
std::list<std::vector<int>> ints_attrs;
std::list<std::vector<float>> floats_attrs;

for (auto &attr_name : op_attrs_names) {
nvinfer1::PluginField plugindata;
plugindata.name = attr_name.c_str();
if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) {
Review comment (Contributor): It looks like this could be folded with a macro; the same applies to the other switch cases.
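A minimal sketch of that suggestion, assuming an illustrative macro name HANDLE_SCALAR_ATTR (not part of this PR); each expansion reproduces one scalar branch, and the trailing else chains it into the next expansion:

#define HANDLE_SCALAR_ATTR(proto_type, cpp_type, attr_list, field_type)   \
  if (op_desc.GetAttrType(attr_name) ==                                   \
      framework::proto::AttrType::proto_type) {                           \
    attr_list.push_back(PADDLE_GET_CONST(cpp_type, attrs.at(attr_name))); \
    plugindata.data = &attr_list.back();                                  \
    plugindata.type = nvinfer1::PluginFieldType::field_type;              \
    plugindata.length = 1;                                                \
  } else

// Inside the attribute loop, the scalar part of the chain then collapses to:
//   HANDLE_SCALAR_ATTR(INT, int, int_attrs, kINT32)
//   HANDLE_SCALAR_ATTR(FLOAT, float, float_attrs, kFLOAT32)
//   HANDLE_SCALAR_ATTR(BOOLEAN, bool, int_attrs, kINT32)
//   { CHECK(false) << "UNKNOWN PluginFieldType."; }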

int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name)));
plugindata.data = &int_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::FLOAT) {
float_attrs.push_back(PADDLE_GET_CONST(float, attrs.at(attr_name)));
plugindata.data = &float_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::BOOLEAN) {
int_attrs.push_back(PADDLE_GET_CONST(bool, attrs.at(attr_name)));
plugindata.data = &int_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::STRING) {
string_attrs.push_back(
PADDLE_GET_CONST(std::string, attrs.at(attr_name)));
plugindata.data = string_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kCHAR;
plugindata.length =
string_attrs.back().size() + 1;  // string ends with '\0'
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::INTS) {
ints_attrs.push_back(
PADDLE_GET_CONST(std::vector<int>, attrs.at(attr_name)));
plugindata.data = ints_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = ints_attrs.back().size();
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::FLOATS) {
floats_attrs.push_back(
PADDLE_GET_CONST(std::vector<float>, attrs.at(attr_name)));
plugindata.data = floats_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
plugindata.length = floats_attrs.back().size();
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::BOOLEANS) {
auto bools_attr =
PADDLE_GET_CONST(std::vector<bool>, attrs.at(attr_name));
std::vector<int> convert_to_ints_attr;
for (bool i : bools_attr) convert_to_ints_attr.push_back(i);
ints_attrs.push_back(convert_to_ints_attr);
plugindata.data = ints_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = ints_attrs.back().size();
} else {
CHECK(false) << "UNKNOWN PluginFieldType.";
}
Review comment (Contributor): Why not support float and the other dtypes?

Reply (Contributor Author, @weishengying, Sep 8, 2022): Done; the remaining types have been added.

plugindatas.push_back(plugindata);
}

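// Pack the collected attributes into a PluginFieldCollection and let the
// registered creator instantiate the plugin.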
nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugindatas.size(),
plugindatas.data()};

auto *plugin = creator->createPlugin(op_desc.Type().c_str(), &plugin_fc);
CHECK(plugin);

if (engine_->with_dynamic_shape()) {
layer =
engine_->AddDynamicPlugin(inputs.data(),
inputs.size(),
(plugin::DynamicPluginTensorRT *)plugin);
} else {
layer = engine_->AddPlugin(
inputs.data(), inputs.size(), (plugin::PluginTensorRT *)plugin);
}

CHECK(layer);

// set outputs
auto &op_output_names = framework::OpMetaInfoHelper::GetOutputs(op_info);
std::vector<std::string> output_names;
for (auto &param_name : op_output_names) {
for (auto &arg_name : op_desc.Output(param_name))
output_names.push_back(arg_name);
}

RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
}
};

class GenericPluginCreater : public OpConverter {
public:
void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
CHECK(block_);
const framework::BlockDesc block_desc(
nullptr, const_cast<framework::proto::BlockDesc *>(block_));

nvinfer1::ILayer *layer = nullptr;
std::vector<nvinfer1::ITensor *> inputs;

phi::KernelSignature phi_kernel_signature;
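// Prefer the op's registered phi argument-mapping function when one exists;
// otherwise fall back to the default kernel signature for the op.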
if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_desc.Type())) {
const phi::ArgumentMappingFn *argument_mapping_func =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_desc.Type());
PluginArgumentMappingContext argument_mapping_context(&op_desc);
phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context);
} else {
phi_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
}

plugin::GenericPlugin::InputOutPutVarInfo in_out_info;

for (auto &param_name : phi_kernel_signature.input_names) {
for (auto &arg_name : op_desc.Input(param_name)) {
framework::Variable *X_v = nullptr;
X_v = scope.FindVar(arg_name);
// If this weight is not shared between ops, it needs to be converted to
// an ITensor.
if (X_v && !engine_->GetITensorMap()->count(arg_name)) {
ConvertWeight2ITensor(scope, arg_name);
}

inputs.push_back(engine_->GetITensor(arg_name));
auto *var = block_desc.FindVar(arg_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"There is no variable called %s in block.", arg_name.c_str()));
PADDLE_ENFORCE_EQ(
var->GetType(),
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
in_out_info.inputs_data_type.push_back(var->GetDataType());
}
}

std::vector<std::string> output_names;
for (auto &param_name : phi_kernel_signature.output_names) {
for (auto &arg_name : op_desc.Output(param_name)) {
output_names.push_back(arg_name);
auto *var = block_desc.FindVar(arg_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"There is no variable called %s in block.", arg_name.c_str()));
PADDLE_ENFORCE_EQ(
var->GetType(),
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
in_out_info.outputs_data_type.push_back(var->GetDataType());
}
}
plugin::GenericPlugin *plugin = new plugin::GenericPlugin(op, in_out_info);
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);

RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(custom_plugin_creater, CustomPluginCreater);
REGISTER_TRT_OP_CONVERTER(generic_plugin_creater, GenericPluginCreater);
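
// Note: these converters are enabled in the predictor through the matching
// USE_TRT_CONVERTER(generic_plugin_creater) and
// USE_TRT_CONVERTER(custom_plugin_creater) declarations added to
// analysis_predictor.cc at the top of this diff.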