From 808bf2b48b3624e2749b7c1e5c2ef8635e9d946e Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 26 Sep 2022 19:18:54 +0800 Subject: [PATCH] fix shard_index kernel (#46491) --- paddle/phi/kernels/gpu/shard_index_kernel.cu | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/shard_index_kernel.cu b/paddle/phi/kernels/gpu/shard_index_kernel.cu index d2497f56a0c76..96fd3911c0d45 100644 --- a/paddle/phi/kernels/gpu/shard_index_kernel.cu +++ b/paddle/phi/kernels/gpu/shard_index_kernel.cu @@ -33,7 +33,15 @@ __global__ void ShardIndexInner(const T* in_data, int shard_size = (index_num + nshards - 1) / nshards; int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < numel) { - assert(in_data[idx] >= 0 && in_data[idx] < index_num); + PADDLE_ENFORCE(in_data[idx] >= 0, + "The input_index for Op(shard_index) must be " + "greater or equal to 0, but the value given is %d.", + in_data[idx]); + PADDLE_ENFORCE(in_data[idx] < index_num, + "The input_index for Op(shard_index) must be less " + "than index_num (%d), but the value given is %d.", + index_num, + in_data[idx]); if (in_data[idx] / shard_size == shard_id) { out_data[idx] = in_data[idx] % shard_size; } else {