Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add index initialization in the block loop for index_sample kernel when dealing with a input tensor whose shape is larger than block_dim * grid_dim #39736

Merged
merged 7 commits into from
Feb 20, 2022
2 changes: 2 additions & 0 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data,
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
Expand All @@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;

for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
Expand Down