Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nearest_interp forward #38528

Merged
merged 38 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
c7b68c8
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 26, 2021
0fd630e
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Aug 16, 2021
4bbb33b
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Sep 28, 2021
30a1a89
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Nov 22, 2021
ce3deec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 21, 2021
8c3620b
init commit
AshburnLee Dec 28, 2021
5719490
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 28, 2021
cb7cc51
remove comments
AshburnLee Dec 28, 2021
26a2aa8
Merge branches 'develop' and 'develop' of https://github.com/PaddlePa…
AshburnLee Dec 28, 2021
76bed9b
remove nchw branch
AshburnLee Dec 28, 2021
b7fd119
optimize code
AshburnLee Dec 28, 2021
3899e0b
apply fast div mod in 1D kernel, rm 3D kernel
AshburnLee Jan 5, 2022
88ff573
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 5, 2022
5844cc5
move init of FastDivMode to CPU
AshburnLee Jan 7, 2022
42c4038
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 7, 2022
cee38bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
35086d6
3D kernel for nchw, FastDiv for 1D kernel
AshburnLee Jan 12, 2022
5a79f12
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
5e08a97
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
0fd2b3f
debug done. process boundary
AshburnLee Jan 18, 2022
214b4a0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 18, 2022
86dd9e1
2^n
AshburnLee Jan 18, 2022
4a07e34
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 18, 2022
b2b85dd
optimize
AshburnLee Jan 19, 2022
45716d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 19, 2022
a39400c
optimize
AshburnLee Jan 20, 2022
a0d9431
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 20, 2022
efa4297
change code & optimize code
AshburnLee Jan 21, 2022
14f6927
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions paddle/fluid/operators/interpolate_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,35 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using DataLayout = framework::DataLayout;

struct FastDivModForInterpolate {
public:
platform::FastDivMod channels_;
AshburnLee marked this conversation as resolved.
Show resolved Hide resolved
platform::FastDivMod output_w_;
platform::FastDivMod outimg_w_;
platform::FastDivMod out_size_;
platform::FastDivMod outimgw_mul_chann_;

explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
const int output_w,
const int outimg_w,
const int out_size,
const int outimgw_mul_chann) {
AshburnLee marked this conversation as resolved.
Show resolved Hide resolved
channels_ = platform::FastDivMod(channels);
output_w_ = platform::FastDivMod(output_w);
outimg_w_ = platform::FastDivMod(outimg_w);
out_size_ = platform::FastDivMod(out_size);
outimgw_mul_chann_ = platform::FastDivMod(outimgw_mul_chann);
}
};

template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
Expand All @@ -33,29 +55,37 @@ __global__ void KeNearestNeighborInterpFw(
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;

for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int outimgw_mul_channels = out_img_w * num_channels;
FastDivModForInterpolate divmods(num_channels, output_w, out_img_w,
out_img_size, outimgw_mul_channels);
AshburnLee marked this conversation as resolved.
Show resolved Hide resolved

auto out_id_divmod = divmods.output_w_.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];

int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
auto channel_divmod = divmods.out_size_.Divmod(out_id_w);
channel_id = channel_divmod.val[0];
out_img_idy = divmods.outimg_w_.Divmod(channel_divmod.val[1]).val[0];
out_img_idx = divmods.outimg_w_.Divmod(tid).val[1];
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
channel_id = divmods.channels_.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.outimgw_mul_chann_.Divmod(out_id_w);
out_img_idy = outimg_id_divmod.val[0];
out_img_idx = divmods.channels_.Divmod(outimg_id_divmod.val[1]).val[0];
}
AshburnLee marked this conversation as resolved.
Show resolved Hide resolved

int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idx = ratio_w * out_img_idx;
int in_img_idy = ratio_h * out_img_idy;
if (align_corners) {
in_img_idx += 0.5;
in_img_idy += 0.5;
}

if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/platform/device/gpu/gpu_launch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
return config;
}

// TODO(wangchaochaohu): 3D will add later

} // namespace platform
} // namespace paddle

Expand Down