Skip to content

Commit

Permalink
Faiss GPU large transpose fix (#2075)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2075

This is a fix for #1996

namely, that large transposition jobs (where one of the dimensions is > 65535) will still work by performing a loop over the gridDim.y

Reviewed By: mdouze

Differential Revision: D31481494

fbshipit-source-id: af35af36cce27ce7d44128d95cc229dd5c4b4b56
  • Loading branch information
wickedfoo authored and facebook-github-bot committed Oct 11, 2021
1 parent 79e74fe commit 312d823
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions faiss/gpu/utils/Transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ __global__ void transposeOuter(
IndexT t1,
IndexT t2,
IndexT i1) {
IndexT gt1 = blockIdx.y;
IndexT gt2 = blockIdx.x;
for (IndexT gt1 = blockIdx.y; gt1 < t1; gt1 += gridDim.y) {
auto curIn = in + i1 * (gt1 * t2 + gt2);
auto curOut = out + i1 * (gt2 * t1 + gt1);

in += i1 * (gt1 * t2 + gt2);
out += i1 * (gt2 * t1 + gt1);

for (IndexT i = threadIdx.x; i < i1; i += blockDim.x) {
out[i] = in[i];
for (IndexT i = threadIdx.x; i < i1; i += blockDim.x) {
curOut[i] = curIn[i];
}
}
}

Expand Down Expand Up @@ -166,7 +166,10 @@ void runTransposeAny(
innerSize *= in.getSize(i);
}

auto grid = dim3(in.getSize(1), in.getSize(0));
// The grid y dimension is more limited; we do a grid loop if necessary
int maxGridY = getCurrentDeviceProperties().maxGridSize[1];
auto grid = dim3(in.getSize(1), std::min(in.getSize(0), maxGridY));

int block = (innerSize < maxThreads) ? innerSize : maxThreads;

if (totalSize <= (size_t)std::numeric_limits<int>::max()) {
Expand Down

0 comments on commit 312d823

Please sign in to comment.