【Hackathon No.56&38】Add float16 data type support to the deformable_conv_v1 operator and speed up its forward pass (PaddlePaddle#46111)

support fp16 for deformable conv
Rayman96 authored Oct 10, 2022
1 parent a7e1b9d commit 5e0614a
Showing 9 changed files with 275 additions and 137 deletions.
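The core of the change is a standard mixed-precision kernel pattern: data stays in the storage type T (which may be phi::dtype::float16), intermediate arithmetic runs in a wider math type MT obtained from phi::dtype::MPTypeTrait<T>::Type, and results are cast back to T only on the final store. Below is a minimal, self-contained sketch of that pattern; MathTypeOf and ScaleAndAccumulate are illustrative stand-ins, not Paddle APIs.

// Minimal sketch of the compute-in-MT, store-in-T pattern (illustration only).
template <typename T>
struct MathTypeOf {   // stand-in for phi::dtype::MPTypeTrait
  using Type = T;     // by default, do the math in the storage type itself
};
// A 16-bit storage type (float16 in Paddle) would specialize Type to float.

template <typename T>
void ScaleAndAccumulate(const T* src, const T* scale, int n, T* dst) {
  using MT = typename MathTypeOf<T>::Type;
  for (int i = 0; i < n; ++i) {
    // Arithmetic in MT avoids fp16 rounding and overflow in the intermediates...
    MT val = static_cast<MT>(src[i]) * static_cast<MT>(scale[i]);
    // ...and only the final result is narrowed back to the storage type.
    dst[i] = static_cast<T>(static_cast<MT>(dst[i]) + val);
  }
}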
34 changes: 17 additions & 17 deletions paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc
@@ -97,14 +97,14 @@ inline void ModulatedDeformableCol2imCPUKernel(
width);

*(grad_im + cur_bottom_grad_pos) =
-     *(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad;
+     *(grad_im + cur_bottom_grad_pos) + (weight * cur_top_grad);
}
}
}
}
}

- template <typename T, typename Context>
+ template <typename T, typename MT, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
@@ -116,7 +116,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& stride,
const std::vector<int>& dilation,
const int deformable_group,
- T* grad_im) {
+ MT* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];

@@ -222,22 +222,22 @@ void ModulatedDeformableCol2imCoordCPUKernel(
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
} else {
- mval += data_col_ptr[col_pos] *
-     funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
-         width,
-         height,
-         width,
-         inv_h,
-         inv_w);
+ mval += data_col_ptr[col_pos] * funcs::DmcnIm2colBilinear<T, T>(
+     data_im_ptr + cnt * height * width,
+     width,
+     height,
+     width,
+     inv_h,
+     inv_w);
}
const T weight =
- DmcnGetCoordinateWeight(inv_h,
-     inv_w,
-     height,
-     width,
-     data_im_ptr + cnt * height * width,
-     width,
-     bp_dir);
+ DmcnGetCoordinateWeight<T, T>(inv_h,
+     inv_w,
+     height,
+     width,
+     data_im_ptr + cnt * height * width,
+     width,
+     bp_dir);
if (data_mask_ptr) {
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
6 changes: 3 additions & 3 deletions paddle/phi/kernels/funcs/deformable_conv_functor.cc
@@ -13,8 +13,8 @@
// limitations under the License.

#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"

namespace phi {
namespace funcs {
@@ -82,8 +82,8 @@ inline void ModulatedDeformableIm2colCPUKernel(
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
- val =
-     DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
+ val = DmcnIm2colBilinear<T, T>(
+     data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
if (data_mask_ptr) {
41 changes: 30 additions & 11 deletions paddle/phi/kernels/funcs/deformable_conv_functor.cu
@@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/device_context.h"

namespace phi {
namespace funcs {
@@ -51,6 +54,8 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
+
+ using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
@@ -85,22 +90,22 @@
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;

- const T offset_h = data_offset_ptr[data_offset_h_ptr];
- const T offset_w = data_offset_ptr[data_offset_w_ptr];
- T val = static_cast<T>(0);
- const T h_im = h_in + i * dilation_h + offset_h;
- const T w_im = w_in + j * dilation_w + offset_w;
+ const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]);
+ const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]);
+ MT val = static_cast<MT>(0);
+ const MT h_im = h_in + i * dilation_h + offset_h;
+ const MT w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
- val =
-     DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
+ val = DmcnIm2colBilinear<T, MT>(
+     data_im_ptr, width, height, width, h_im, w_im);
}
- *data_col_ptr = val;
if (data_mask_ptr) {
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
- const T mask = data_mask_ptr[data_mask_hw_ptr];
- *data_col_ptr *= mask;
+ const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]);
+ val *= mask;
}
+ *data_col_ptr = static_cast<T>(val);
data_col_ptr += batch_size * height_col * width_col;
}
}
@@ -164,6 +169,20 @@ template void ModulatedDeformableIm2col(
const int deformable_groups,
float* data_col);

+ template void ModulatedDeformableIm2col(
+     const phi::GPUContext& dev_ctx,
+     const phi::dtype::float16* data_im,
+     const phi::dtype::float16* data_offset,
+     const phi::dtype::float16* data_mask,
+     const std::vector<int64_t>& im_shape,
+     const std::vector<int64_t>& col_shape,
+     const std::vector<int64_t>& filter_shape,
+     const std::vector<int>& paddings,
+     const std::vector<int>& strides,
+     const std::vector<int>& dilations,
+     const int deformable_groups,
+     phi::dtype::float16* data_col);
+
template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const double* data_im,
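The template void ModulatedDeformableIm2col(...) block added above is an explicit instantiation: the template definition lives in this .cu file and other translation units only see a declaration, so each supported element type needs its own instantiation, and fp16 support means adding one for phi::dtype::float16. A generic sketch of the same mechanism (Fill is a made-up example, not Paddle code):

// header: declaration only, visible to callers
template <typename T>
void Fill(T* dst, int n, T value);

// source file: definition plus explicit instantiations for the supported dtypes
template <typename T>
void Fill(T* dst, int n, T value) {
  for (int i = 0; i < n; ++i) dst[i] = value;
}

template void Fill<float>(float*, int, float);
template void Fill<double>(double*, int, double);
// Enabling fp16 amounts to one more instantiation, e.g.
// template void Fill<phi::dtype::float16>(phi::dtype::float16*, int,
//                                         phi::dtype::float16);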
55 changes: 29 additions & 26 deletions paddle/phi/kernels/funcs/deformable_conv_functor.h
@@ -14,44 +14,47 @@

#pragma once

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

- template <typename T>
- HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
-     const int data_width,
-     const int height,
-     const int width,
-     T h,
-     T w) {
+ template <typename T, typename MT>
+ HOSTDEVICE MT DmcnIm2colBilinear(const T* bottom_data,
+     const int data_width,
+     const int height,
+     const int width,
+     MT h,
+     MT w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;

- T lh = h - h_low;
- T lw = w - w_low;
- T hh = 1 - lh;
- T hw = 1 - lw;
+ MT lh = h - h_low;
+ MT lw = w - w_low;
+ MT hh = 1 - lh;
+ MT hw = 1 - lw;

- T v1 =
-     (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
- T v2 = (h_low >= 0 && w_high <= width - 1)
-     ? bottom_data[h_low * data_width + w_high]
-     : 0;
- T v3 = (h_high <= height - 1 && w_low >= 0)
-     ? bottom_data[h_high * data_width + w_low]
-     : 0;
- T v4 = (h_high <= height - 1 && w_high <= width - 1)
-     ? bottom_data[h_high * data_width + w_high]
-     : 0;
+ MT v1 = (h_low >= 0 && w_low >= 0)
+     ? static_cast<MT>(bottom_data[h_low * data_width + w_low])
+     : 0;
+ MT v2 = (h_low >= 0 && w_high <= width - 1)
+     ? static_cast<MT>(bottom_data[h_low * data_width + w_high])
+     : 0;
+ MT v3 = (h_high <= height - 1 && w_low >= 0)
+     ? static_cast<MT>(bottom_data[h_high * data_width + w_low])
+     : 0;
+ MT v4 = (h_high <= height - 1 && w_high <= width - 1)
+     ? static_cast<MT>(bottom_data[h_high * data_width + w_high])
+     : 0;

- T w1 = hh * hw;
- T w2 = hh * lw;
- T w3 = lh * hw;
- T w4 = lh * lw;
+ MT w1 = hh * hw;
+ MT w2 = hh * lw;
+ MT w3 = lh * hw;
+ MT w4 = lh * lw;

return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
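DmcnIm2colBilinear above is ordinary 2x2 bilinear interpolation; the change keeps the four neighbouring pixels in the storage type T and moves the weights and the weighted sum into MT. A standalone check of the weight math (illustration only, not Paddle code):

#include <cmath>

float Bilinear2x2(float v1, float v2, float v3, float v4, float h, float w) {
  float lh = h - std::floor(h), lw = w - std::floor(w);  // fractional offsets
  float hh = 1.0f - lh, hw = 1.0f - lw;
  // Same weight layout as in the header: w1 = hh*hw, w2 = hh*lw, w3 = lh*hw, w4 = lh*lw.
  return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4;
}
// Example: Bilinear2x2(0, 1, 2, 3, 0.5f, 0.5f) == 1.5f; halfway between the
// four neighbours each weight is 0.25, so the result is their average.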
(The remaining changed files of this commit are not rendered here.)
