Skip to content

Commit

Permalink
[Feature]: add modulated deformable conv TensorRT support (#1078)
Browse files Browse the repository at this point in the history
* add modulated dcn, better dcn plugin

* clangformat

* update documentation
  • Loading branch information
q.yao authored Jun 16, 2021
1 parent 1b59409 commit 004c006
Show file tree
Hide file tree
Showing 14 changed files with 746 additions and 60 deletions.
48 changes: 48 additions & 0 deletions docs/tensorrt_custom_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
- [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d)
- [Description](#description-8)
- [Parameters](#parameters-8)
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)

<!-- TOC -->

Expand Down Expand Up @@ -345,3 +351,45 @@ y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance a
### Type Constraints

- T:tensor(float32, Linear)

## MMCVModulatedDeformConv2d

### Description

Perform Modulated Deformable Convolution on input feature, read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for detail.

### Parameters

| Type | Parameter | Description |
| -------------- | ------------------ | ------------------------------------------------------------------------------------- |
| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) |
| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) |
| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) |
| `int` | `deformable_group` | Groups of deformable offset. |
| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. |

### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, inH and inW are the height and width of the data.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>Input offset; 4-D tensor of shape (N, deformable_group* 2* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>Input mask; 4-D tensor of shape (N, deformable_group* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
<dt><tt>inputs[4]</tt>: T, optional</dt>
<dd>Input weight; 1-D tensor of shape (output_channel).</dd>
</dl>

### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
</dl>

### Type Constraints

- T:tensor(float32, Linear)
7 changes: 4 additions & 3 deletions docs/tensorrt_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | 1.3.5 |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | 1.3.5 |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | 1.3.5 |
| MMCVModulatedDeformConv2d | [MMCVModulatedDeformConv2d](./tensorrt_custom_ops.md#mmcvmodulateddeformconv2d) | master |

Notes

Expand Down
9 changes: 7 additions & 2 deletions mmcv/ops/csrc/modulated_deform_conv_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,16 @@
#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH

#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif
#endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT

template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
Expand Down
24 changes: 24 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <cublas_v2.h>

#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"
Expand Down Expand Up @@ -64,3 +66,25 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
int *permute, int src_dim,
cudaStream_t stream);

template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const float *alpha, const float *A,
int lda, const float *B, int ldb,
const float *beta, float *C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}

template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const half *alpha, const half *A,
int lda, const half *B, int ldb,
const half *beta, half *C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}
23 changes: 10 additions & 13 deletions mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(
mDilation(dilation),
mDeformableGroup(deformableGroup),
mGroup(group),
mIm2colStep(im2colStep) {
cublasCreate(&m_cublas_handle);
}
mIm2colStep(im2colStep) {}

DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
const void *data,
Expand All @@ -46,12 +44,8 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
deserialize_value(&data, &length, &mDeformableGroup);
deserialize_value(&data, &length, &mGroup);
deserialize_value(&data, &length, &mIm2colStep);
cublasCreate(&m_cublas_handle);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
// destroy cublas handle
cublasDestroy(m_cublas_handle);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const {
DeformableConvPluginDynamic *plugin =
Expand Down Expand Up @@ -127,11 +121,6 @@ int DeformableConvPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
if (m_cuda_stream != stream) {
cublasSetStream(m_cublas_handle, stream);
m_cuda_stream = stream;
}

int batch_size = inputDesc[0].dims.d[0];
int inputChannel = inputDesc[0].dims.d[1];
int inputHeight = inputDesc[0].dims.d[2];
Expand Down Expand Up @@ -204,6 +193,14 @@ void DeformableConvPluginDynamic::destroy() {
delete this;
}

void DeformableConvPluginDynamic::attachToContext(
cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) {
m_cublas_handle = cublasContext;
}

void DeformableConvPluginDynamic::detachFromContext() {}

void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}
Expand Down
33 changes: 0 additions & 33 deletions mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <cublas_v2.h>
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
Expand Down Expand Up @@ -32,38 +31,6 @@ void trt_deformable_im2col(const T* data_input, const T* data_offset,
cudaCheckError();
}

// used to switch gemm between fp32 and fp16
template <typename scalar_t>
cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const scalar_t* alpha, const scalar_t* A, int lda,
const scalar_t* B, int ldb, const scalar_t* beta,
scalar_t* C, int ldc) {
return CUBLAS_STATUS_INTERNAL_ERROR;
}

template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const float* alpha, const float* A,
int lda, const float* B, int ldb,
const float* beta, float* C, int ldc) {
cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
ldc);
}

template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const half* alpha, const half* A,
int lda, const half* B, int ldb,
const half* beta, half* C, int ldc) {
cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
ldc);
}

template <typename scalar_t>
void DeformConvForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
Expand Down
Loading

0 comments on commit 004c006

Please sign in to comment.