Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support 5d for nearest interp #38868

Merged
merged 4 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions paddle/fluid/operators/interpolate_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

PADDLE_ENFORCE_EQ(
"trilinear", interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" or "
"\"nearest\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));

Expand Down
126 changes: 126 additions & 0 deletions paddle/fluid/operators/interpolate_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,61 @@ __global__ void KeNearestNeighborInterpFw(
}
}

template <typename T>
__global__ void KeNearestNeighbor3DInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w; // ncdhw
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;

int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}

int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);

int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);

if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
out[tid] = in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
}

template <typename T>
__global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
Expand Down Expand Up @@ -114,6 +169,63 @@ __global__ void KeNearestNeighborInterpBw(
}
}

template <typename T>
__global__ void KeNearestNeighbor3DInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;

int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}

int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);

T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
in_pos = &in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}

template <typename T>
__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w,
const size_t input_w, T* out,
Expand Down Expand Up @@ -1376,6 +1488,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
}
}

Expand Down Expand Up @@ -1801,6 +1920,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
}
}

Expand Down
77 changes: 77 additions & 0 deletions paddle/fluid/operators/interpolate_v2_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,39 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
}
}

template <typename T>
static void NearestNeighbor3DInterpolate(
const Tensor& input, Tensor* output, const float ratio_d,
const float ratio_h, const float ratio_w, const int n, const int c,
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
for (int d = 0; d < out_d; d++) { // loop for images
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) {
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);

for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);

for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l);
} else { // NDHWC
output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j);
}
}
}
}
}
}
}

template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_w, const int in_w,
Expand Down Expand Up @@ -584,6 +617,42 @@ static void NearestNeighborInterpolateGrad(
}
}

template <typename T>
static void NearestNeighbor3DInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int n, const int c,
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);

for (int d = 0; d < out_d; d++) {
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);

for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);

for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_d, in_k, in_l) +=
output_grad_t(i, j, d, k, l);
} else {
input_grad_t(i, in_d, in_k, in_l, j) +=
output_grad_t(i, d, k, l, j);
}
}
}
}
}
}
}

template <typename T>
static void BilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
Expand Down Expand Up @@ -1137,6 +1206,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
in_h, in_w, n, c, out_d, out_h, out_w,
align_corners, align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolate<T>(input, output, ratio_d, ratio_h, ratio_w, n,
c, out_d, out_h, out_w, align_corners,
data_layout);
}
}

Expand Down Expand Up @@ -1489,6 +1562,10 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
TrilinearInterpolationGrad<T>(
output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolateGrad<T>(output_grad, input_grad, ratio_d,
ratio_h, ratio_w, n, c, out_d, out_h,
out_w, align_corners, data_layout);
}
}

Expand Down
Loading