Skip to content

Commit

Permalink
[Backend] Add TensorRT FP16 support for AdaptivePool2d (#1116)
Browse files Browse the repository at this point in the history
* add fp16 cuda kernel

* fix code bug

* update code
  • Loading branch information
yeliang2258 authored Jan 13, 2023
1 parent d00df3d commit 829fe07
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
54 changes: 40 additions & 14 deletions fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "adaptive_pool2d_kernel.h"

namespace fastdeploy {

__global__ void CudaCastKernel(const float* in, float* out, int edge,
template <typename T1, typename T2>
__global__ void CudaCastKernel(const T1* in, T2* out, int edge,
int out_bc_offset, int in_bc_offset, int ih,
int iw, int oh, int ow, bool is_avg) {
int position = blockDim.x * blockIdx.x + threadIdx.x;
Expand All @@ -32,29 +32,37 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge,
int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
int wstart = floorf(static_cast<float>(w * iw) / ow);
int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
float ele_val = 0.0;
if (is_avg) {
out[position] = 0.0;
ele_val = 0.0;
} else {
out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
ele_val =
static_cast<float>(in[offset * in_bc_offset + hstart * iw + wstart]);
}
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = h * iw + w;
if (is_avg) {
out[position] = out[position] + in[offset * in_bc_offset + input_idx];
ele_val =
ele_val + static_cast<float>(in[offset * in_bc_offset + input_idx]);
} else {
out[position] =
max(out[position], in[offset * in_bc_offset + input_idx]);
ele_val =
(ele_val >
static_cast<float>(in[offset * in_bc_offset + input_idx]))
? ele_val
: static_cast<float>(in[offset * in_bc_offset + input_idx]);
}
}
}
out[position] = out[position] / ((hend - hstart) * (wend - wstart));
out[position] = static_cast<T2>(
ele_val / static_cast<float>(((hend - hstart) * (wend - wstart))));
}

void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& output_dims, float* output,
const float* input, void* compute_stream,
const std::string& pooling_type) {
const std::vector<int64_t>& output_dims, void* output,
const void* input, void* compute_stream,
const std::string& pooling_type, const std::string& dtype,
const std::string& out_dtype) {
auto casted_compute_stream = reinterpret_cast<cudaStream_t>(compute_stream);
int out_bc_offset = output_dims[2] * output_dims[3];
int in_bc_offset = input_dims[2] * input_dims[3];
Expand All @@ -65,9 +73,27 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
bool is_avg = pooling_type == "avg";
int threads = 256;
int blocks = ceil(jobs / static_cast<float>(threads));
CudaCastKernel<<<blocks, threads, 0, casted_compute_stream>>>(
input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
if (dtype == "float") {
CudaCastKernel<float, float><<<blocks, threads, 0, casted_compute_stream>>>(
static_cast<const float*>(input), static_cast<float*>(output), jobs,
out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]),
int(output_dims[2]), int(output_dims[3]), is_avg);
} else if (dtype == "half") {
if (out_dtype == "half") {
CudaCastKernel<half, half><<<blocks, threads, 0, casted_compute_stream>>>(
static_cast<const half*>(input), static_cast<half*>(output), jobs,
out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]),
int(output_dims[2]), int(output_dims[3]), is_avg);
}
if (out_dtype == "float") {
CudaCastKernel<half, float>
<<<blocks, threads, 0, casted_compute_stream>>>(
static_cast<const half*>(input), static_cast<float*>(output),
jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
int(input_dims[3]), int(output_dims[2]), int(output_dims[3]),
is_avg);
}
}
}
} // namespace fastdeploy
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#pragma once

#include <cuda_fp16.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
Expand All @@ -25,8 +26,10 @@
namespace fastdeploy {

void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& output_dims, float* output,
const float* input, void* compute_stream,
const std::string& pooling_type);
const std::vector<int64_t>& output_dims, void* output,
const void* input, void* compute_stream,
const std::string& pooling_type,
const std::string& dtype = "float",
const std::string& out_dtype = "float");

} // namespace fastdeploy
26 changes: 18 additions & 8 deletions fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,25 @@ int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) noexcept {
if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
return -1;
}
auto const* data = static_cast<float const*>(inputs[0]);
auto* result = static_cast<float*>(outputs[0]);
int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
std::vector<int64_t> input_size, output_size;
for (int i = 0; i < 4; i++) {
input_size.push_back(inputDesc[0].dims.d[i]);
output_size.push_back(outputDesc[0].dims.d[i]);
}
CudaAdaptivePool(input_size, output_size, result, data, stream,
pooling_type_);
if (inputDesc[0].type == nvinfer1::DataType::kHALF) {
if (outputDesc[0].type == nvinfer1::DataType::kHALF) {
CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
pooling_type_, "half", "half");
} else if (outputDesc[0].type == nvinfer1::DataType::kFLOAT) {
CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
pooling_type_, "half", "float");
}
} else if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) {
CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream,
pooling_type_, "float", "float");
}
return cudaPeekAtLastError();
}

Expand Down Expand Up @@ -106,7 +111,12 @@ nvinfer1::DataType AdaptivePool2d::getOutputDataType(
bool AdaptivePool2d::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) noexcept {
return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
if ((inOut[pos].format == nvinfer1::PluginFormat::kLINEAR) &&
(inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kHALF)) {
return true;
}
return false;
}

int AdaptivePool2d::initialize() noexcept { return 0; }
Expand Down

0 comments on commit 829fe07

Please sign in to comment.