Skip to content

Commit

Permalink
General Plugin Mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
weishengying committed Sep 14, 2022
1 parent 1349584 commit a3a20a4
Show file tree
Hide file tree
Showing 23 changed files with 2,581 additions and 389 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,8 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
Expand Down
14 changes: 13 additions & 1 deletion paddle/fluid/inference/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@ else()
SRCS engine.cc trt_int8_calibrator.cc
DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
endif()
nv_library(
tensorrt_dynamic_shape_infermeta_factory
SRCS dynamic_shape_infermeta.cc
DEPS framework_proto)
nv_library(
tensorrt_plugin_arg_mapping_context
SRCS plugin_arg_mapping_context.cc
DEPS framework_proto)
nv_library(
tensorrt_op_teller
SRCS op_teller.cc
DEPS framework_proto device_context)
DEPS framework_proto device_context tensorrt_dynamic_shape_infermeta_factory)
nv_test(
test_tensorrt
SRCS test_tensorrt.cc
Expand All @@ -24,6 +32,10 @@ nv_test(
test_tensorrt_engine
SRCS test_engine.cc test_dynamic_engine.cc
DEPS dynload_cuda tensorrt_engine tensorrt_plugin)
nv_test(
test_arg_mapping_context
SRCS test_arg_mapping_context.cc
DEPS framework_proto tensorrt_plugin_arg_mapping_context)

if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
Expand Down
21 changes: 19 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ list(
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc)
layernorm_shift_partition_op.cc
generic_and_custom_plugin_creater.cc)

if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
Expand All @@ -85,7 +86,12 @@ endif()
nv_library(
tensorrt_converter
SRCS ${CONVERT_FILES}
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
DEPS tensorrt_engine
tensorrt_plugin
operator
scope
framework_proto
tensorrt_op_teller
op_registry)

nv_test(
Expand All @@ -94,6 +100,17 @@ nv_test(
DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine
tensorrt_converter)

nv_test(
test_custom_plugin_creater
SRCS test_custom_plugin_creater.cc
DEPS paddle_framework
${GLOB_OPERATOR_DEPS}
tensorrt_engine
tensorrt_plugin
tensorrt_converter
op_meta_info
custom_operator)

if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
# be build only in CI, so suppose the generator in Windows is Ninja.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"

namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Converter that wraps a user-registered custom op as a TensorRT plugin
 * layer. The plugin is located in the TensorRT plugin registry under the
 * name "<op_type>_paddle_trt_plugin" (or "<op_type>_paddle_trt_dynamic_plugin"
 * when the engine runs with dynamic shape), and the op's attributes are
 * forwarded to the plugin creator as a PluginFieldCollection.
 */
class CustomPluginCreater : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc &op,
                  const framework::Scope &scope,
                  bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    VLOG(3) << "convert " << op_desc.Type() << " op to custom plugin layer";

    std::string plugin_name;

    if (engine_->with_dynamic_shape()) {
      plugin_name = op_desc.Type() + "_paddle_trt_dynamic_plugin";
    } else {
      plugin_name = op_desc.Type() + "_paddle_trt_plugin";
    }

    nvinfer1::ILayer *layer = nullptr;
    std::vector<nvinfer1::ITensor *> inputs;

    // Look up the custom op's meta info (inputs/outputs/attrs declaration).
    auto &op_meta_info_map = OpMetaInfoMap::Instance();
    const auto &meta_info_map = op_meta_info_map.GetMap();
    auto &op_info = meta_info_map.at(op_desc.Type()).front();

    // set inputs
    auto &op_input_names = framework::OpMetaInfoHelper::GetInputs(op_info);
    for (auto &param_name : op_input_names) {
      for (auto &arg_name : op_desc.Input(param_name)) {
        framework::Variable *X_v = nullptr;
        X_v = scope.FindVar(arg_name);
        // If this weight is not shared between ops, it needs to be converted
        // to an itensor first.
        if (X_v && !engine_->GetITensorMap()->count(arg_name)) {
          ConvertWeight2ITensor(scope, arg_name);
        }
        inputs.push_back(engine_->GetITensor(arg_name));
      }
    }
    auto creator =
        GetPluginRegistry()->getPluginCreator(plugin_name.c_str(), "1");
    CHECK(creator);

    // set attrs
    //
    // Each PluginField stores a raw pointer into one of the containers
    // below; std::list is used (rather than std::vector) so that pushing
    // new elements never invalidates the addresses already handed out.
    std::vector<nvinfer1::PluginField> plugindatas;
    auto &op_attrs_names = framework::OpMetaInfoHelper::GetAttrs(op_info);
    auto &attrs = op_desc.GetAttrMap();

    std::list<int> int_attrs;
    std::list<float> float_attrs;
    std::list<std::string> string_attrs;
    std::list<std::vector<int>> ints_attrs;
    std::list<std::vector<float>> floats_attrs;

    for (auto &attr_name : op_attrs_names) {
      nvinfer1::PluginField plugindata;
      plugindata.name = attr_name.c_str();
      if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) {
        int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name)));
        plugindata.data = &int_attrs.back();
        plugindata.type = nvinfer1::PluginFieldType::kINT32;
        plugindata.length = 1;
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::FLOAT) {
        float_attrs.push_back(PADDLE_GET_CONST(float, attrs.at(attr_name)));
        plugindata.data = &float_attrs.back();
        plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
        plugindata.length = 1;
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::BOOLEAN) {
        // TensorRT has no bool field type; pass booleans as int32.
        int_attrs.push_back(PADDLE_GET_CONST(bool, attrs.at(attr_name)));
        plugindata.data = &int_attrs.back();
        plugindata.type = nvinfer1::PluginFieldType::kINT32;
        plugindata.length = 1;
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::STRING) {
        string_attrs.push_back(
            PADDLE_GET_CONST(std::string, attrs.at(attr_name)));
        plugindata.data = string_attrs.back().data();
        plugindata.type = nvinfer1::PluginFieldType::kCHAR;
        plugindata.length =
            string_attrs.back().size() + 1;  // string ends with '\0'
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::INTS) {
        ints_attrs.push_back(
            PADDLE_GET_CONST(std::vector<int>, attrs.at(attr_name)));
        plugindata.data = ints_attrs.back().data();
        plugindata.type = nvinfer1::PluginFieldType::kINT32;
        plugindata.length = ints_attrs.back().size();
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::FLOATS) {
        floats_attrs.push_back(
            PADDLE_GET_CONST(std::vector<float>, attrs.at(attr_name)));
        plugindata.data = floats_attrs.back().data();
        plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
        plugindata.length = floats_attrs.back().size();
      } else if (op_desc.GetAttrType(attr_name) ==
                 framework::proto::AttrType::BOOLEANS) {
        // std::vector<bool> is bit-packed and has no data(); widen to int32.
        auto bools_attr =
            PADDLE_GET_CONST(std::vector<bool>, attrs.at(attr_name));
        std::vector<int> convert_to_ints_attr;
        for (bool i : bools_attr) convert_to_ints_attr.push_back(i);
        ints_attrs.push_back(convert_to_ints_attr);
        plugindata.data = ints_attrs.back().data();
        plugindata.type = nvinfer1::PluginFieldType::kINT32;
        plugindata.length = ints_attrs.back().size();
      } else {
        CHECK(false) << "UNKNOWN PluginFieldType.";
      }
      plugindatas.push_back(plugindata);
    }

    nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugindatas.size(),
                                              plugindatas.data()};

    auto *plugin = creator->createPlugin(op_desc.Type().c_str(), &plugin_fc);
    CHECK(plugin);

    if (engine_->with_dynamic_shape()) {
      layer =
          engine_->AddDynamicPlugin(inputs.data(),
                                    inputs.size(),
                                    (plugin::DynamicPluginTensorRT *)plugin);
    } else {
      layer = engine_->AddPlugin(
          inputs.data(), inputs.size(), (plugin::PluginTensorRT *)plugin);
    }

    CHECK(layer);

    // set outputs
    auto &op_output_names = framework::OpMetaInfoHelper::GetOutputs(op_info);
    std::vector<std::string> output_names;
    for (auto &param_name : op_output_names) {
      for (auto &arg_name : op_desc.Output(param_name))
        output_names.push_back(arg_name);
    }

    RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
  }
};

/*
 * Converter that maps an op onto the generic TensorRT plugin, recording the
 * data types of all LoDTensor inputs and outputs so the plugin can dispatch
 * the op's phi kernel at runtime.
 */
class GenericPluginCreater : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc &op,
                  const framework::Scope &scope,
                  bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    CHECK(block_);
    const framework::BlockDesc block_desc(
        nullptr, const_cast<framework::proto::BlockDesc *>(block_));

    nvinfer1::ILayer *layer = nullptr;
    std::vector<nvinfer1::ITensor *> inputs;

    // Resolve the phi kernel signature, preferring the op's registered
    // argument-mapping function; otherwise fall back to the default map.
    phi::KernelSignature phi_kernel_signature;
    if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_desc.Type())) {
      const phi::ArgumentMappingFn *argument_mapping_func =
          phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_desc.Type());
      PluginArgumentMappingContext argument_mapping_context(&op_desc);
      phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context);
    } else {
      phi_kernel_signature =
          phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
    }

    plugin::GenericPlugin::InputOutPutVarInfo in_out_info;

    // Finds the variable in the block, enforces that it is a LoDTensor, and
    // yields its data type.
    auto checked_dtype_of = [&block_desc](const std::string &arg_name) {
      auto *var = block_desc.FindVar(arg_name);
      PADDLE_ENFORCE_NOT_NULL(
          var,
          platform::errors::NotFound(
              "There is no variable called %s in block.", arg_name.c_str()));
      PADDLE_ENFORCE_EQ(
          var->GetType(),
          FluidDT::VarType_Type_LOD_TENSOR,
          platform::errors::InvalidArgument("TensorRT engine only takes "
                                            "LoDTensor as input"));
      return var->GetDataType();
    };

    // Collect input itensors and their data types.
    for (auto &param_name : phi_kernel_signature.input_names) {
      for (auto &arg_name : op_desc.Input(param_name)) {
        framework::Variable *scope_var = scope.FindVar(arg_name);
        // A weight not shared between ops still needs to be converted to an
        // itensor before it can feed the plugin.
        if (scope_var && !engine_->GetITensorMap()->count(arg_name)) {
          ConvertWeight2ITensor(scope, arg_name);
        }

        inputs.push_back(engine_->GetITensor(arg_name));
        in_out_info.inputs_data_type.push_back(checked_dtype_of(arg_name));
      }
    }

    // Collect output names and their data types.
    std::vector<std::string> output_names;
    for (auto &param_name : phi_kernel_signature.output_names) {
      for (auto &arg_name : op_desc.Output(param_name)) {
        output_names.push_back(arg_name);
        in_out_info.outputs_data_type.push_back(checked_dtype_of(arg_name));
      }
    }

    plugin::GenericPlugin *plugin = new plugin::GenericPlugin(op, in_out_info);
    layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);

    RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
  }
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

// Register both converters with the TensorRT op-converter registry; the
// names must match the USE_TRT_CONVERTER declarations in
// analysis_predictor.cc. (Note: "creater" spelling is part of the
// registered key and must not be changed here alone.)
REGISTER_TRT_OP_CONVERTER(custom_plugin_creater, CustomPluginCreater);
REGISTER_TRT_OP_CONVERTER(generic_plugin_creater, GenericPluginCreater);
Loading

0 comments on commit a3a20a4

Please sign in to comment.