Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 4 No.40】为 Paddle 优化 kthvalue op 在 GPU 上的计算性能 (#…
Browse files Browse the repository at this point in the history
…51835)

* untracked files

* kthvalue perf

* remove unused files

* fix isnan

* fix isnan2

* fix bug

* try to fix rocm error
  • Loading branch information
thunder95 authored Mar 24, 2023
1 parent 7415b10 commit e18f533
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 6 deletions.
66 changes: 66 additions & 0 deletions paddle/phi/kernels/funcs/top_k_function_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ limitations under the License. */
#include "paddle/phi/kernels/primitive/functor_primitives.h"

#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
#define MAX_NUM_THREADS 1024

inline static size_t divide_round_up(size_t n, size_t q) {
return n % q == 0 ? n / q : n / q + 1;
}

inline static size_t round_up(size_t n, size_t q) {
return divide_round_up(n, q) * q;
}

#ifdef __HIPCC__
namespace rocprim {
namespace detail {
Expand Down Expand Up @@ -808,6 +819,61 @@ __device__ void RadixSearch(
*kth_value = RadixTypeConfig<T>::Deconvert(desired);
}

template <typename T>
__global__ void GatherKthValue(const T* input,
const int k,
const int64_t num_rows,
const int64_t num_cols,
T* output,
int64_t* indices) {
__shared__ int shared_mem[32];
int row =
blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
const T* cur_input = input + row * num_cols;

// 1. Find the k-th value
T kth_value = static_cast<T>(0);
RadixSearch<T, RadixTypeConfig<T>::RadixType, false>(
cur_input, k, num_cols, shared_mem, &kth_value);
const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);

// 2. find the k-th index
int64_t kth_index = 0;
bool foundKValue = false;
for (int64_t i = threadIdx.x; i < num_cols; i += blockDim.x) {
bool inRange = (i < num_cols);
T v = inRange ? cur_input[i] : static_cast<T>(0);
bool isKValue =
inRange && ((v == kth_value) || (isnan(static_cast<float>(v)) &&
isnan(static_cast<float>(kth_value))));
if (isKValue) {
kth_index = i;
foundKValue = true;
break;
}
}

if (foundKValue) {
output[row] = kth_value;
indices[row] = kth_index;
}
}

template <typename T>
void LaunchGatherKthValue(const phi::GPUContext& dev_ctx,
const T* input_data,
const int64_t num_cols,
const int64_t num_rows,
const int k,
T* out_data,
int64_t* indices_data) {
int num_threads = std::min(
static_cast<int>(round_up(static_cast<int>(num_cols), WARP_SIZE)),
MAX_NUM_THREADS);
GatherKthValue<T><<<num_rows, num_threads, 0, dev_ctx.stream()>>>(
input_data, k, num_rows, num_cols, out_data, indices_data);
}

template <typename T, bool Largest>
__global__ void RadixTopK(const T* input,
int k,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
34 changes: 29 additions & 5 deletions paddle/phi/kernels/gpu/kthvalue_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ void KthvalueKernel(const Context& dev_ctx,
const auto& in_dims = x.dims();
if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims();
const T* input_data = x.data<T>();
T* output_data = dev_ctx.template Alloc<T>(output);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);

Expand All @@ -180,15 +179,28 @@ void KthvalueKernel(const Context& dev_ctx,
phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}

if (axis == in_dims.size() - 1) {
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
const T* input_data = x.data<T>();
funcs::LaunchGatherKthValue<T>(dev_ctx,
input_data,
input_width,
input_height,
k,
output_data,
indices_data);
#else
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(
dev_ctx, &x, input_width, input_height, k, output, indices),
true,
phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif

return;
} else {
std::vector<int> trans;
Expand Down Expand Up @@ -222,18 +234,28 @@ void KthvalueKernel(const Context& dev_ctx,
trans_out_dims[in_dims.size() - 1] = 1;
DenseTensor trans_input;
trans_input.Resize(trans_dims);
dev_ctx.template Alloc<T>(&trans_input);
T* tran_input_data = dev_ctx.template Alloc<T>(&trans_input);
int ndims = trans.size();
funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, x, &trans_input, trans);
DenseTensor trans_ind, trans_out;
trans_ind.Resize(trans_out_dims);
trans_out.Resize(trans_out_dims);
dev_ctx.template Alloc<int64_t>(&trans_ind);
dev_ctx.template Alloc<T>(&trans_out);
int64_t* tran_indices_data = dev_ctx.template Alloc<int64_t>(&trans_ind);
T* tran_output_data = dev_ctx.template Alloc<T>(&trans_out);
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
funcs::LaunchGatherKthValue<T>(dev_ctx,
tran_input_data,
input_width,
input_height,
k,
tran_output_data,
tran_indices_data);
#else
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(dev_ctx,
&trans_input,
Expand All @@ -244,6 +266,7 @@ void KthvalueKernel(const Context& dev_ctx,
&trans_ind),
true,
phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif
funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
funcs::TransCompute<phi::GPUContext, T>(
Expand All @@ -263,6 +286,7 @@ PD_REGISTER_KERNEL(kthvalue,
float,
double,
int,
int64_t) {
int64_t,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}

0 comments on commit e18f533

Please sign in to comment.