From cfdc3ffac03282bb10f0f379bd3099b9b9d83a84 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Fri, 12 May 2023 13:42:23 +0800 Subject: [PATCH] fix add_n kernel of large shape --- paddle/phi/kernels/gpu/add_n_kernel.cu | 28 ++++++-------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu index 3c224d1e246c8..a30d25b018800 100644 --- a/paddle/phi/kernels/gpu/add_n_kernel.cu +++ b/paddle/phi/kernels/gpu/add_n_kernel.cu @@ -21,34 +21,20 @@ namespace phi { #define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) -template -__global__ void Sum2CUDAKernel(const T *in_0, - const T *in_1, - T *out, - int64_t N) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - while (id < N) { - out[id] = in_0[id] + in_1[id]; - id += blockDim.x * gridDim.x; - } -} - template __global__ void SumArrayCUDAKernel( T **in, T *out, int64_t N, size_t in_size, bool read_dst) { using MPType = typename phi::dtype::MPTypeTrait::Type; - int id = blockIdx.x * blockDim.x + threadIdx.x; - while (id < N) { - MPType total(read_dst ? static_cast(out[id]) + CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) { + MPType total(read_dst ? static_cast(out[idx]) : static_cast(0)); for (int i = 0; i < in_size; ++i) { const T *tmp = in[i]; if (tmp) { - total += static_cast(tmp[id]); + total += static_cast(tmp[idx]); } } - out[id] = static_cast(total); - id += blockDim.x * gridDim.x; + out[idx] = static_cast(total); } } @@ -56,16 +42,14 @@ template __global__ void SumSelectedRowsCUDAKernel(T **sr_in_out, int64_t N, size_t rows) { - int id = blockIdx.x * blockDim.x + threadIdx.x; - while (id < N) { + CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) { for (int i = 0; i < 2 * rows; i += 2) { const T *tmp = sr_in_out[i]; T *tmp_out = sr_in_out[i + 1]; if (tmp && tmp_out) { - tmp_out[id] += tmp[id]; + tmp_out[idx] += tmp[idx]; } } - id += blockDim.x * gridDim.x; } }