From 004c00675f912b9ac59ba07655295c86a39e29d1 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 16 Jun 2021 21:38:07 +0800 Subject: [PATCH] [Feature]: add modulated deformable conv TensorRT support (#1078) * add modulated dcn, better dcn plugin * clangformat * update documentation --- docs/tensorrt_custom_ops.md | 48 +++ docs/tensorrt_plugin.md | 7 +- .../modulated_deform_conv_cuda_kernel.cuh | 9 +- .../csrc/tensorrt/plugins/trt_cuda_helper.cu | 24 ++ .../csrc/tensorrt/plugins/trt_deform_conv.cpp | 23 +- .../plugins/trt_deform_conv_kernel.cu | 33 -- .../plugins/trt_modulated_deform_conv.cpp | 307 ++++++++++++++++++ .../trt_modulated_deform_conv_kernel.cu | 133 ++++++++ mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp | 2 + mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh | 14 +- mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp | 4 +- .../tensorrt/trt_modulated_deform_conv.hpp | 120 +++++++ mmcv/ops/modulated_deform_conv.py | 11 +- tests/test_ops/test_tensorrt.py | 71 ++++ 14 files changed, 746 insertions(+), 60 deletions(-) create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu create mode 100644 mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp diff --git a/docs/tensorrt_custom_ops.md b/docs/tensorrt_custom_ops.md index 0b5b1b83a7..1ef48ece06 100644 --- a/docs/tensorrt_custom_ops.md +++ b/docs/tensorrt_custom_ops.md @@ -51,6 +51,12 @@ - [Inputs](#inputs-7) - [Outputs](#outputs-7) - [Type Constraints](#type-constraints-7) + - [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d) + - [Description](#description-8) + - [Parameters](#parameters-8) + - [Inputs](#inputs-8) + - [Outputs](#outputs-8) + - [Type Constraints](#type-constraints-8) @@ -345,3 +351,45 @@ y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance a ### Type Constraints - T:tensor(float32, Linear) + +## MMCVModulatedDeformConv2d + +### Description + +Perform Modulated Deformable Convolution on input feature, read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for detail. + +### Parameters + +| Type | Parameter | Description | +| -------------- | ------------------ | ------------------------------------------------------------------------------------- | +| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) | +| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) | +| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) | +| `int` | `deformable_group` | Groups of deformable offset. | +| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. | + +### Inputs + +
+<dl>
+<dt><tt>inputs[0]</tt>: T</dt>
+<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, and inH and inW are the height and width of the data.</dd>
+<dt><tt>inputs[1]</tt>: T</dt>
+<dd>Input offset; 4-D tensor of shape (N, deformable_group * 2 * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
+<dt><tt>inputs[2]</tt>: T</dt>
+<dd>Input mask; 4-D tensor of shape (N, deformable_group * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
+<dt><tt>inputs[3]</tt>: T</dt>
+<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
+<dt><tt>inputs[4]</tt>: T, optional</dt>
+<dd>Input bias; 1-D tensor of shape (output_channel).</dd>
+</dl>
+
+### Outputs
+
+<dl>
+<dt><tt>outputs[0]</tt>: T</dt>
+<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
+</dl>
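
For a quick sanity check of the shapes above, here is a minimal, self-contained sketch in plain Python. The concrete values (8-channel 32×32 input, 3×3 kernel, stride 1, padding 1, dilation 1, one deformable group, 16 output channels) are assumed for illustration only; the output-size formula is the same one used by the CUDA kernel launcher added in this patch.

```python
# Shape sketch for MMCVModulatedDeformConv2d (illustrative values only).
N, C, inH, inW = 1, 8, 32, 32          # inputs[0] shape
output_channel = 16
kH, kW = 3, 3                          # kernel size
sH, sW = 1, 1                          # stride
padH, padW = 1, 1                      # padding
dH, dW = 1, 1                          # dilation
deformable_group = 1

# Same output-size formula as a standard convolution
# (and as ModulatedDeformConvForwardCUDAKernelLauncher).
outH = (inH + 2 * padH - (dH * (kH - 1) + 1)) // sH + 1
outW = (inW + 2 * padW - (dW * (kW - 1) + 1)) // sW + 1

print('inputs[1] offset :', (N, deformable_group * 2 * kH * kW, outH, outW))
print('inputs[2] mask   :', (N, deformable_group * kH * kW, outH, outW))
print('inputs[3] weight :', (output_channel, C, kH, kW))
print('inputs[4] bias   :', (output_channel,))
print('outputs[0]       :', (N, output_channel, outH, outW))
```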
+ +### Type Constraints + +- T:tensor(float32, Linear) diff --git a/docs/tensorrt_plugin.md b/docs/tensorrt_plugin.md index 0da0d1b23e..325c79762e 100644 --- a/docs/tensorrt_plugin.md +++ b/docs/tensorrt_plugin.md @@ -31,9 +31,10 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u | NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 | | MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 | | grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 | -| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master | -| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master | -| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master | +| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | 1.3.5 | +| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | 1.3.5 | +| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | 1.3.5 | +| MMCVModulatedDeformConv2d | [MMCVModulatedDeformConv2d](./tensorrt_custom_ops.md#mmcvmodulateddeformconv2d) | master | Notes diff --git a/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh b/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh index 04bf5c308d..ca0e91a252 100644 --- a/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh +++ b/mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh @@ -66,11 +66,16 @@ #ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH #define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH +#include +#ifdef MMCV_WITH_TRT +#include "common_cuda_helper.hpp" +#else // MMCV_WITH_TRT #ifdef MMCV_USE_PARROTS #include "parrots_cuda_helper.hpp" -#else +#else // MMCV_USE_PARROTS #include "pytorch_cuda_helper.hpp" -#endif +#endif // MMCV_USE_PARROTS +#endif // MMCV_WITH_TRT template __device__ T dmcn_im2col_bilinear(const T *input, const int data_width, diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu index 5b85a4e567..8ddcca9703 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu @@ -1,3 +1,5 @@ +#include + #include "common_cuda_helper.hpp" #include "trt_cuda_helper.cuh" #include "trt_plugin_helper.hpp" @@ -64,3 +66,25 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, template void memcpyPermute(float *dst, const float *src, int *src_size, int *permute, int src_dim, cudaStream_t stream); + +template <> +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, int m, int n, + int k, const float *alpha, const float *A, + int lda, const float *B, int ldb, + const float *beta, float *C, int ldc) { + return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, int m, int n, + int k, const half *alpha, const half *A, + int lda, const half *B, int ldb, + const half *beta, half *C, int ldc) { + return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, + beta, C, ldc); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp index 988e9bc46e..fa008e4190 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp @@ -32,9 +32,7 @@ 
DeformableConvPluginDynamic::DeformableConvPluginDynamic( mDilation(dilation), mDeformableGroup(deformableGroup), mGroup(group), - mIm2colStep(im2colStep) { - cublasCreate(&m_cublas_handle); -} + mIm2colStep(im2colStep) {} DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void *data, @@ -46,12 +44,8 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, deserialize_value(&data, &length, &mDeformableGroup); deserialize_value(&data, &length, &mGroup); deserialize_value(&data, &length, &mIm2colStep); - cublasCreate(&m_cublas_handle); -} -DeformableConvPluginDynamic::~DeformableConvPluginDynamic() { - // destroy cublas handle - cublasDestroy(m_cublas_handle); } +DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const { DeformableConvPluginDynamic *plugin = @@ -127,11 +121,6 @@ int DeformableConvPluginDynamic::enqueue( const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workSpace, cudaStream_t stream) { - if (m_cuda_stream != stream) { - cublasSetStream(m_cublas_handle, stream); - m_cuda_stream = stream; - } - int batch_size = inputDesc[0].dims.d[0]; int inputChannel = inputDesc[0].dims.d[1]; int inputHeight = inputDesc[0].dims.d[2]; @@ -204,6 +193,14 @@ void DeformableConvPluginDynamic::destroy() { delete this; } +void DeformableConvPluginDynamic::attachToContext( + cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) { + m_cublas_handle = cublasContext; +} + +void DeformableConvPluginDynamic::detachFromContext() {} + void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) { mNamespace = libNamespace; } diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu index 36a63dea9d..b5eefa6e71 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu @@ -1,4 +1,3 @@ -#include #include #include "common_cuda_helper.hpp" @@ -32,38 +31,6 @@ void trt_deformable_im2col(const T* data_input, const T* data_offset, cudaCheckError(); } -// used to switch gemm between fp32 and fp16 -template -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const scalar_t* alpha, const scalar_t* A, int lda, - const scalar_t* B, int ldb, const scalar_t* beta, - scalar_t* C, int ldc) { - return CUBLAS_STATUS_INTERNAL_ERROR; -} - -template <> -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, int m, int n, - int k, const float* alpha, const float* A, - int lda, const float* B, int ldb, - const float* beta, float* C, int ldc) { - cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, - ldc); -} - -template <> -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, int m, int n, - int k, const half* alpha, const half* A, - int lda, const half* B, int ldb, - const half* beta, half* C, int ldc) { - cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, - ldc); -} - template void DeformConvForwardCUDAKernelLauncher( const scalar_t* input, const scalar_t* weight, const scalar_t* offset, diff --git 
a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp new file mode 100644 index 0000000000..dc5f960524 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp @@ -0,0 +1,307 @@ +#include "trt_modulated_deform_conv.hpp" + +#include + +#include + +#include "trt_serialize.hpp" + +void ModulatedDeformConvForwardCUDAKernelLauncher_float( + const float *input, const float *weight, const float *bias, + const float *offset, const float *mask, float *output, void *workspace, + int batch, int channels, int height, int width, int channels_out, + int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, + int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, + int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); + +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"}; +} // namespace + +nvinfer1::PluginFieldCollection + ModulatedDeformableConvPluginDynamicCreator::mFC{}; +std::vector + ModulatedDeformableConvPluginDynamicCreator::mPluginAttributes; + +ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic( + const std::string &name, const nvinfer1::Dims stride, + const nvinfer1::Dims padding, const nvinfer1::Dims dilation, + const int deformableGroup, const int group) + : mLayerName(name), + mStride(stride), + mPadding(padding), + mDilation(dilation), + mDeformableGroup(deformableGroup), + mGroup(group) { + mWithBias = false; +} + +ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic( + const std::string name, const void *data, size_t length) + : mLayerName(name) { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); + mWithBias = false; +} +ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {} + +nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone() + const { + ModulatedDeformableConvPluginDynamic *plugin = + new ModulatedDeformableConvPluginDynamic( + mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[2].d[0]; + + ret.d[2] = inputs[1].d[2]; + ret.d[3] = inputs[1].d[3]; + + return ret; +} + +bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, + int nbOutputs) { + if (pos == 0) { + return (inOut[pos].type == nvinfer1::DataType::kFLOAT && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); + + } else { + return inOut[pos].type == inOut[0].type && + inOut[pos].format == inOut[0].format; + } +} + +void ModulatedDeformableConvPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) { + if (nbInputs == 5) { + mWithBias = true; + } +} + +size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const 
nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const { + int sizeof_dtype = mmcv::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[3].dims.d[2]; + int kH = inputs[3].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = mmcv::getAlignedSize(nInputPlane * kW * kH * outputHeight * + outputWidth * sizeof_dtype); + + return col_size; +} + +int ModulatedDeformableConvPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workSpace, cudaStream_t stream) { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[3].dims.d[2]; + int kernel_w = inputDesc[3].dims.d[3]; + + const void *x = inputs[0]; + const void *offset = inputs[1]; + const void *mask = inputs[2]; + const void *weight = inputs[3]; + const void *bias = mWithBias ? inputs[4] : nullptr; + void *output = outputs[0]; + int im2col_step = std::min(batch, 32); + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + ModulatedDeformConvForwardCUDAKernelLauncher_float( + (float *)x, (float *)weight, (float *)bias, (float *)offset, + (float *)mask, (float *)output, workSpace, batch, channels, height, + width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1], + mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup, + mDeformableGroup, im2col_step, m_cublas_handle, stream); + break; + default: + return 1; + break; + } + + return 0; +} + +nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *ModulatedDeformableConvPluginDynamic::getPluginType() const { + return PLUGIN_NAME; +} + +const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const { + return PLUGIN_VERSION; +} + +int ModulatedDeformableConvPluginDynamic::getNbOutputs() const { return 1; } + +int ModulatedDeformableConvPluginDynamic::initialize() { return 0; } + +void ModulatedDeformableConvPluginDynamic::terminate() {} + +size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const { + return sizeof(mStride) + sizeof(mPadding) + sizeof(mDilation) + + sizeof(mDeformableGroup) + sizeof(mGroup); +} + +void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); +} + +void ModulatedDeformableConvPluginDynamic::destroy() { + // This gets called when the network containing plugin is destroyed + delete this; +} + +void ModulatedDeformableConvPluginDynamic::attachToContext( + cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) { + m_cublas_handle = cublasContext; +} + +void ModulatedDeformableConvPluginDynamic::detachFromContext() {} + +void 
ModulatedDeformableConvPluginDynamic::setPluginNamespace( + const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *ModulatedDeformableConvPluginDynamic::getPluginNamespace() const { + return mNamespace.c_str(); +} + +////////////////////// creator ///////////////////////////// + +ModulatedDeformableConvPluginDynamicCreator:: + ModulatedDeformableConvPluginDynamicCreator() { + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const { + return PLUGIN_NAME; +} + +const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion() + const { + return PLUGIN_VERSION; +} + +const nvinfer1::PluginFieldCollection * +ModulatedDeformableConvPluginDynamicCreator::getFieldNames() { + return &mFC; +} + +nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deformable_group") == 0) { + deformableGroup = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("group") == 0) { + group = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) { + stride.nbDims = 2; + stride.d[0] = static_cast(fc->fields[i].data)[0]; + stride.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) { + padding.nbDims = 2; + padding.d[0] = static_cast(fc->fields[i].data)[0]; + padding.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) { + dilation.nbDims = 2; + dilation.d[0] = static_cast(fc->fields[i].data)[0]; + dilation.d[1] = static_cast(fc->fields[i].data)[1]; + } + } + + ModulatedDeformableConvPluginDynamic *plugin = + new ModulatedDeformableConvPluginDynamic(name, stride, padding, dilation, + deformableGroup, group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 * +ModulatedDeformableConvPluginDynamicCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) { + auto plugin = + new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +void ModulatedDeformableConvPluginDynamicCreator::setPluginNamespace( + const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *ModulatedDeformableConvPluginDynamicCreator::getPluginNamespace() + const { + return mNamespace.c_str(); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu new file mode 100644 index 0000000000..258ae783f6 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu @@ -0,0 +1,133 @@ +#include +#include + +#include "common_cuda_helper.hpp" +#include 
"modulated_deform_conv_cuda_kernel.cuh" +#include "trt_cuda_helper.cuh" +#include "trt_plugin_helper.hpp" + +template +void trt_modulated_deformable_im2col( + const T* data_im_, const T* data_offset_, const T* data_mask_, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, T* data_col_, + cudaStream_t stream) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + modulated_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, + kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h, + dilation_w, channel_per_deformable_group, batch_size, channels, + deformable_group, height_col, width_col, data_col_); + + cudaCheckError(); +} + +template +__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias, + size_t step_batch, size_t step_channel, + size_t n) { + CUDA_1D_KERNEL_LOOP(index, n) { + output[index] += bias[(index % step_batch) / step_channel]; + } +} + +template +static void output_add_bias(scalar_t* output, const scalar_t* bias, + size_t batch, size_t channel, size_t height, + size_t width, cudaStream_t stream) { + size_t step_channel = height * width; + size_t step_batch = step_channel * channel; + size_t n = step_batch * batch; + output_add_bias_kernel<<>>( + output, bias, step_batch, step_channel, n); +} + +template +void ModulatedDeformConvForwardCUDAKernelLauncher( + const scalar_t* input, const scalar_t* weight, const scalar_t* bias, + const scalar_t* offset, const scalar_t* mask, scalar_t* output, + void* workspace, int batch, int channels, int height, int width, + int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, + int pad_w, int pad_h, int dilation_w, int dilation_h, int group, + int deformable_group, int im2col_step, cublasHandle_t cublas_handle, + cudaStream_t stream) { + size_t sizeof_dtype = sizeof(scalar_t); + bool with_bias = (bias != nullptr); + + im2col_step = std::min(int(batch), im2col_step); + assert(batch % im2col_step == 0); + const int channels_kernel = channels / group; + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + scalar_t* columns = (scalar_t*)workspace; + + const size_t input_step = channels * height * width; + const size_t offset_step = + deformable_group * kernel_h * kernel_w * 2 * height * width; + const size_t mask_step = + deformable_group * kernel_h * kernel_w * height * width; + const size_t out_step = channels_out * height_out * width_out; + const size_t out_group_step = out_step / group; + const size_t col_g_step = + channels * kernel_w * kernel_h / group * height_out * width_out; + const size_t weight_g_step = + channels_out / group * channels / group * kernel_h * kernel_w; + + const int m = channels_out / group; + const int n = height_out * width_out; + const int k = channels / group * kernel_h * kernel_w; + scalar_t alpha = 1.; + scalar_t beta = 0.; + + for (int b = 0; b < batch; b++) { + const scalar_t* input_start = input + b * input_step; + const scalar_t* offset_start = 
offset + b * offset_step; + const scalar_t* mask_start = mask + b * mask_step; + trt_modulated_deformable_im2col( + input_start, offset_start, mask_start, 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, columns, stream); + + for (int g = 0; g < group; g++) { + const scalar_t* weight_start = weight + g * weight_g_step; + scalar_t* col_start = columns + g * col_g_step; + scalar_t* out_buffer_start = output + b * out_step + g * out_group_step; + + // cudaMemsetAsync(out_buffer_start, 0, 1, stream); + cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, + &alpha, col_start, n, weight_start, k, &beta, + out_buffer_start, n); + cudaCheckError(); + } + } + + if (with_bias) { + output_add_bias(output, bias, batch, channels_out, height_out, + width_out, stream); + } +} + +void ModulatedDeformConvForwardCUDAKernelLauncher_float( + const float* input, const float* weight, const float* bias, + const float* offset, const float* mask, float* output, void* workspace, + int batch, int channels, int height, int width, int channels_out, + int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, + int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, + int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) { + ModulatedDeformConvForwardCUDAKernelLauncher( + input, weight, bias, offset, mask, output, workspace, batch, channels, + height, width, channels_out, kernel_w, kernel_h, stride_w, stride_h, + pad_w, pad_h, dilation_w, dilation_h, group, deformable_group, + im2col_step, cublas_handle, stream); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp index 81f724f162..c7b946b5dd 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp @@ -4,6 +4,7 @@ #include "trt_deform_conv.hpp" #include "trt_grid_sampler.hpp" #include "trt_instance_norm.hpp" +#include "trt_modulated_deform_conv.hpp" #include "trt_nms.hpp" #include "trt_roi_align.hpp" #include "trt_scatternd.hpp" @@ -12,6 +13,7 @@ REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator); REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); +REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); diff --git a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh index a4635dcdd5..db42dae9e1 100644 --- a/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh +++ b/mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh @@ -1,5 +1,6 @@ #ifndef TRT_CUDA_HELPER_HPP #define TRT_CUDA_HELPER_HPP +#include #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) @@ -24,7 +25,16 @@ * @param[in] stream cuda stream handle */ template -void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, - int *permute, int src_dim, cudaStream_t stream = 0); +void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, + int* permute, int src_dim, cudaStream_t stream = 0); + +template +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const scalar_t* alpha, const scalar_t* A, int lda, + const scalar_t* B, int ldb, const 
scalar_t* beta, + scalar_t* C, int ldc) { + return CUBLAS_STATUS_INTERNAL_ERROR; +} #endif // TRT_CUDA_HELPER_HPP diff --git a/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp index b8762f7868..fc48ac5dd9 100644 --- a/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp +++ b/mmcv/ops/csrc/tensorrt/trt_deform_conv.hpp @@ -44,6 +44,9 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt { const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) override; + void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) override; + void detachFromContext() override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType(int index, @@ -74,7 +77,6 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt { int mIm2colStep; cublasHandle_t m_cublas_handle; - cudaStream_t m_cuda_stream; protected: // To prevent compiler warnings. diff --git a/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp b/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp new file mode 100644 index 0000000000..0907e7ea85 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/trt_modulated_deform_conv.hpp @@ -0,0 +1,120 @@ +#ifndef TRT_MODULATED_DEFORM_CONV_HPP +#define TRT_MODULATED_DEFORM_CONV_HPP +#include + +#include +#include +#include + +#include "trt_plugin_helper.hpp" + +class ModulatedDeformableConvPluginDynamic + : public nvinfer1::IPluginV2DynamicExt { + public: + ModulatedDeformableConvPluginDynamic(const std::string &name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group); + + ModulatedDeformableConvPluginDynamic(const std::string name, const void *data, + size_t length); + + ModulatedDeformableConvPluginDynamic() = delete; + + ~ModulatedDeformableConvPluginDynamic(); + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc *inOut, + int nbInputs, int nbOutputs) override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, + cudaStream_t stream) override; + void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) override; + void detachFromContext() override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType *inputTypes, + int nbInputs) const override; + + // IPluginV2 Methods + const char *getPluginType() const override; + const char *getPluginVersion() const override; + int getNbOutputs() const override; + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void *buffer) const override; + void destroy() override; + void 
setPluginNamespace(const char *pluginNamespace) override; + const char *getPluginNamespace() const override; + + private: + const std::string mLayerName; + std::string mNamespace; + + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + bool mWithBias; + + cublasHandle_t m_cublas_handle; + + protected: + // To prevent compiler warnings. + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::configurePlugin; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; +}; + +class ModulatedDeformableConvPluginDynamicCreator + : public nvinfer1::IPluginCreator { + public: + ModulatedDeformableConvPluginDynamicCreator(); + + const char *getPluginName() const override; + + const char *getPluginVersion() const override; + + const nvinfer1::PluginFieldCollection *getFieldNames() override; + + nvinfer1::IPluginV2 *createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, + const void *serialData, + size_t serialLength) override; + + void setPluginNamespace(const char *pluginNamespace) override; + + const char *getPluginNamespace() const override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; +#endif // TRT_MODULATED_DEFORM_CONV_HPP diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index b3dfd0b003..d26f61a0a1 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -20,13 +20,12 @@ class ModulatedDeformConv2dFunction(Function): @staticmethod def symbolic(g, input, offset, mask, weight, bias, stride, padding, dilation, groups, deform_groups): + input_tensors = [input, offset, mask, weight] + if bias is not None: + input_tensors.append(bias) return g.op( - 'MMCVModulatedDeformConv2d', - input, - offset, - mask, - weight, - bias, + 'mmcv::MMCVModulatedDeformConv2d', + *input_tensors, stride_i=stride, padding_i=padding, dilation_i=dilation, diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py index 362a403430..d65308ba8a 100644 --- a/tests/test_ops/test_tensorrt.py +++ b/tests/test_ops/test_tensorrt.py @@ -406,6 +406,77 @@ def test_deform_conv(): assert torch.allclose(pytorch_results, trt_results) +@pytest.mark.parametrize('with_bias', [True, False]) +def test_modulated_deform_conv(with_bias): + try: + from mmcv.ops import ModulatedDeformConv2dPack + except (ImportError, ModuleNotFoundError): + pytest.skip('test requires compilation') + + input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] + + x = torch.Tensor(input).cuda() + model = ModulatedDeformConv2dPack( + 1, + 1, + kernel_size=(2, 2), + stride=1, + padding=1, + deform_groups=1, + bias=with_bias) + model.weight.data.fill_(1.) 
+ model.type(torch.float32) + model = model.cuda().eval() + + input_names = ['input'] + output_names = ['output'] + + with torch.no_grad(): + torch.onnx.export( + model, (x.clone(), ), + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=input_names, + output_names=output_names, + opset_version=11) + + onnx_model = onnx.load(onnx_file) + + # create trt engine and wraper + opt_shape_dict = { + 'input': [list(x.shape), list(x.shape), + list(x.shape)], + } + # trt config + fp16_mode = False + max_workspace_size = 1 << 30 + + trt_engine = onnx2trt( + onnx_model, + opt_shape_dict, + fp16_mode=fp16_mode, + max_workspace_size=max_workspace_size) + + save_trt_engine(trt_engine, trt_file) + trt_model = TRTWrapper(trt_file, input_names, output_names) + + with torch.no_grad(): + trt_outputs = trt_model({'input': x.clone()}) + trt_results = trt_outputs['output'] + + # compute pytorch_output + with torch.no_grad(): + pytorch_results = model(x.clone()) + + # allclose + if os.path.exists(onnx_file): + os.remove(onnx_file) + if os.path.exists(trt_file): + os.remove(trt_file) + torch.testing.assert_allclose(pytorch_results, trt_results) + + @pytest.mark.parametrize('mode', ['bilinear', 'nearest']) @pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection']) @pytest.mark.parametrize('align_corners', [True, False])