Skip to content

Commit

Permalink
add index initialization in the block loop for index_sample kernel wh…
Browse files Browse the repository at this point in the history
…en dealing with a input tensor whose shape is larger than block_dim * grid_dim (#39736)

* add block and grid loop for index_sample kernel to deal with a large-shape tensor

* fix code format

* limit grid dim

* fix the omissive initialization of index_i in the second cycle for index_sample kernel

* fix conflicts
  • Loading branch information
FlyingQianMM authored Feb 20, 2022
1 parent 553afc0 commit c6950ab
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data,
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
Expand All @@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;

for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
Expand Down

0 comments on commit c6950ab

Please sign in to comment.