[PaddlePaddle Hackathon 4 No.40] Optimize the performance of the kthvalue op on GPU for Paddle #51835

Merged: 22 commits, Mar 24, 2023

Commits (22)
c8ae296  untracked files (thunder95, Feb 20, 2023)
6aa02f0  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 23, 2023)
d599110  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 23, 2023)
264894d  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 25, 2023)
98d1e1c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 25, 2023)
b958122  Merge branch 'develop' of https://github.com/thunder95/Paddle into de… (thunder95, Feb 25, 2023)
760e099  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 26, 2023)
e16076d  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Feb 26, 2023)
085c7a6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Mar 2, 2023)
b1edf68  Merge branch 'develop' of https://github.com/thunder95/Paddle into de… (thunder95, Mar 2, 2023)
f2887e5  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Mar 2, 2023)
6a62308  Merge branch 'develop' of https://github.com/thunder95/Paddle into de… (thunder95, Mar 2, 2023)
6620e88  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Mar 8, 2023)
0cb7cc4  kthvalue perf (thunder95, Mar 19, 2023)
f44b58d  fix conflict (thunder95, Mar 19, 2023)
ec62479  remove unused files (thunder95, Mar 19, 2023)
4dcf15e  fix isnan (thunder95, Mar 20, 2023)
a8ffbf3  fix isnan2 (thunder95, Mar 21, 2023)
bde6c3a  fix bug (thunder95, Mar 21, 2023)
0475b44  try to fix rocm error (thunder95, Mar 22, 2023)
9736b0b  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (thunder95, Mar 22, 2023)
7c57924  Merge branch 'kthvalue_perf' of https://github.com/thunder95/Paddle i… (thunder95, Mar 22, 2023)
66 changes: 66 additions & 0 deletions paddle/phi/kernels/funcs/top_k_function_cuda.h
@@ -32,6 +32,17 @@ limitations under the License. */
#include "paddle/phi/kernels/primitive/functor_primitives.h"

#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
#define MAX_NUM_THREADS 1024

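// Ceiling division and round-up-to-a-multiple helpers, used below to size kernel launches.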
inline static size_t divide_round_up(size_t n, size_t q) {
return n % q == 0 ? n / q : n / q + 1;
}

inline static size_t round_up(size_t n, size_t q) {
return divide_round_up(n, q) * q;
}

#ifdef __HIPCC__
namespace rocprim {
namespace detail {
@@ -808,6 +819,61 @@ __device__ void RadixSearch(
*kth_value = RadixTypeConfig<T>::Deconvert(desired);
}

template <typename T>
__global__ void GatherKthValue(const T* input,
const int k,
const int64_t num_rows,
const int64_t num_cols,
T* output,
int64_t* indices) {
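  // One thread block cooperatively processes one row of the input;
  // shared_mem provides the scratch space RadixSearch needs.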
__shared__ int shared_mem[32];
int row =
blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
const T* cur_input = input + row * num_cols;

// 1. Find the k-th value
T kth_value = static_cast<T>(0);
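  // RadixSearch narrows in on the k-th smallest element by scanning the
  // radix digits of the converted bit pattern, avoiding a full sort.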
RadixSearch<T, RadixTypeConfig<T>::RadixType, false>(
cur_input, k, num_cols, shared_mem, &kth_value);
const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);

  // 2. Find the index of the k-th value (NaN compares equal to NaN here)
int64_t kth_index = 0;
bool foundKValue = false;
for (int64_t i = threadIdx.x; i < num_cols; i += blockDim.x) {
bool inRange = (i < num_cols);
T v = inRange ? cur_input[i] : static_cast<T>(0);
bool isKValue =
inRange && ((v == kth_value) || (isnan(static_cast<float>(v)) &&
isnan(static_cast<float>(kth_value))));
if (isKValue) {
kth_index = i;
foundKValue = true;
break;
}
}

if (foundKValue) {
output[row] = kth_value;
indices[row] = kth_index;
}
}

template <typename T>
void LaunchGatherKthValue(const phi::GPUContext& dev_ctx,
const T* input_data,
const int64_t num_cols,
const int64_t num_rows,
const int k,
T* out_data,
int64_t* indices_data) {
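  // Launch one block per row; the block width is num_cols rounded up to a
  // whole warp and capped at MAX_NUM_THREADS.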
int num_threads = std::min(
static_cast<int>(round_up(static_cast<int>(num_cols), WARP_SIZE)),
MAX_NUM_THREADS);
GatherKthValue<T><<<num_rows, num_threads, 0, dev_ctx.stream()>>>(
input_data, k, num_rows, num_cols, out_data, indices_data);
}

template <typename T, bool Largest>
__global__ void RadixTopK(const T* input,
int k,
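As a quick, standalone sanity check of the launch-width arithmetic above (an illustration only, not part of the patch), the host-only sketch below mirrors divide_round_up/round_up and prints the block width LaunchGatherKthValue would pick for a few row lengths; the constants shadow WARP_SIZE (32) and MAX_NUM_THREADS (1024) from the header.

#include <algorithm>
#include <cstdio>

// Host mirrors of the helpers added in top_k_function_cuda.h.
static size_t divide_round_up(size_t n, size_t q) {
  return n % q == 0 ? n / q : n / q + 1;
}
static size_t round_up(size_t n, size_t q) { return divide_round_up(n, q) * q; }

int main() {
  const int kWarpSize = 32;      // WARP_SIZE
  const int kMaxThreads = 1024;  // MAX_NUM_THREADS
  for (size_t num_cols : {1, 31, 32, 300, 5000}) {
    int num_threads =
        std::min(static_cast<int>(round_up(num_cols, kWarpSize)), kMaxThreads);
    // e.g. 300 columns -> 320 threads (10 warps); 5000 -> capped at 1024.
    std::printf("num_cols=%zu -> block width %d\n", num_cols, num_threads);
  }
  return 0;
}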
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
float,
double,
int,
- int64_t) {}
+ int64_t,
+ phi::dtype::float16) {}
34 changes: 29 additions & 5 deletions paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -163,7 +163,6 @@ void KthvalueKernel(const Context& dev_ctx,
const auto& in_dims = x.dims();
if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims();
- const T* input_data = x.data<T>();
T* output_data = dev_ctx.template Alloc<T>(output);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);

@@ -180,15 +179,28 @@
phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}

if (axis == in_dims.size() - 1) {
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
const T* input_data = x.data<T>();
funcs::LaunchGatherKthValue<T>(dev_ctx,
input_data,
input_width,
input_height,
k,
output_data,
indices_data);
#else
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(
dev_ctx, &x, input_width, input_height, k, output, indices),
true,
phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif

return;
} else {
std::vector<int> trans;
@@ -222,18 +234,28 @@
trans_out_dims[in_dims.size() - 1] = 1;
DenseTensor trans_input;
trans_input.Resize(trans_dims);
- dev_ctx.template Alloc<T>(&trans_input);
+ T* tran_input_data = dev_ctx.template Alloc<T>(&trans_input);
int ndims = trans.size();
funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, x, &trans_input, trans);
DenseTensor trans_ind, trans_out;
trans_ind.Resize(trans_out_dims);
trans_out.Resize(trans_out_dims);
- dev_ctx.template Alloc<int64_t>(&trans_ind);
- dev_ctx.template Alloc<T>(&trans_out);
+ int64_t* tran_indices_data = dev_ctx.template Alloc<int64_t>(&trans_ind);
+ T* tran_output_data = dev_ctx.template Alloc<T>(&trans_out);
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
funcs::LaunchGatherKthValue<T>(dev_ctx,
tran_input_data,
input_width,
input_height,
k,
tran_output_data,
tran_indices_data);
#else
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(dev_ctx,
&trans_input,
@@ -244,6 +266,7 @@
&trans_ind),
true,
phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif
funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
funcs::TransCompute<phi::GPUContext, T>(
@@ -263,6 +286,7 @@
float,
double,
int,
- int64_t) {
+ int64_t,
+ phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
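
For context on what the new GPU path must preserve: kthvalue returns, per row along the reduced axis, the k-th smallest value (k is 1-based) together with an index where it occurs. Below is a minimal host-side reference in standard C++, independent of Paddle and shown only as an illustration; kthvalue_ref and its parameters are hypothetical names, not part of this PR.

#include <algorithm>
#include <cstdio>
#include <vector>

// Reference kthvalue along the last axis of row-major rows x cols data; k is 1-based.
void kthvalue_ref(const std::vector<float>& x, int rows, int cols, int k,
                  std::vector<float>* vals, std::vector<int64_t>* idxs) {
  for (int r = 0; r < rows; ++r) {
    std::vector<float> row(x.begin() + r * cols, x.begin() + (r + 1) * cols);
    std::nth_element(row.begin(), row.begin() + (k - 1), row.end());
    const float kth = row[k - 1];
    (*vals)[r] = kth;
    for (int64_t c = 0; c < cols; ++c) {
      // Report a position holding the k-th value, as the kernel's gather step does.
      if (x[r * cols + c] == kth) {
        (*idxs)[r] = c;
        break;
      }
    }
  }
}

int main() {
  const std::vector<float> x = {3, 1, 2, 9, 7, 8};  // 2 rows x 3 cols
  std::vector<float> vals(2);
  std::vector<int64_t> idxs(2);
  kthvalue_ref(x, /*rows=*/2, /*cols=*/3, /*k=*/2, &vals, &idxs);
  // Expected: value 2 at index 2 for row 0, value 8 at index 2 for row 1.
  std::printf("row0: %g@%lld  row1: %g@%lld\n",
              vals[0], static_cast<long long>(idxs[0]),
              vals[1], static_cast<long long>(idxs[1]));
  return 0;
}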