diff --git a/CMakeLists.txt b/CMakeLists.txt index 7743f0c024..b7bf8c7dd8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,13 @@ cmake_minimum_required(VERSION 3.24.0 FATAL_ERROR) set(FAISS_LANGUAGES CXX) if(FAISS_ENABLE_GPU) - list(APPEND FAISS_LANGUAGES CUDA) + # if ROCm install detected, assume ROCm/HIP is GPU device + if (EXISTS /opt/rocm) + set(USE_ROCM TRUE) + list(APPEND FAISS_LANGUAGES HIP) + else() + list(APPEND FAISS_LANGUAGES CUDA) + endif() endif() if(FAISS_ENABLE_RAFT) @@ -58,8 +64,17 @@ option(FAISS_ENABLE_PYTHON "Build Python extension." ON) option(FAISS_ENABLE_C_API "Build C API." OFF) if(FAISS_ENABLE_GPU) - set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) - enable_language(CUDA) + if(USE_ROCM) + enable_language(HIP) + add_definitions(-DUSE_ROCM) + find_package(HIP REQUIRED) + find_package(hipBLAS REQUIRED) + set(GPU_EXT_PREFIX "hip") + else () + set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) + enable_language(CUDA) + set(GPU_EXT_PREFIX "cu") + endif() endif() if(FAISS_ENABLE_RAFT AND NOT TARGET raft::raft) @@ -69,7 +84,11 @@ if(FAISS_ENABLE_RAFT AND NOT TARGET raft::raft) add_subdirectory(faiss) if(FAISS_ENABLE_GPU) - add_subdirectory(faiss/gpu) + if(USE_ROCM) + add_subdirectory(faiss/gpu-rocm) + else() + add_subdirectory(faiss/gpu) + endif() endif() if(FAISS_ENABLE_PYTHON) @@ -90,6 +109,10 @@ if(BUILD_TESTING) add_subdirectory(tests) if(FAISS_ENABLE_GPU) - add_subdirectory(faiss/gpu/test) + if(USE_ROCM) + add_subdirectory(faiss/gpu-rocm/test) + else() + add_subdirectory(faiss/gpu/test) + endif() endif() endif() diff --git a/c_api/CMakeLists.txt b/c_api/CMakeLists.txt index 60b9f23a68..06d85c6aef 100644 --- a/c_api/CMakeLists.txt +++ b/c_api/CMakeLists.txt @@ -56,5 +56,9 @@ add_executable(example_c EXCLUDE_FROM_ALL example_c.c) target_link_libraries(example_c PRIVATE faiss_c) if(FAISS_ENABLE_GPU) - add_subdirectory(gpu) + if(USE_ROCM) + add_subdirectory(gpu-rocm) + else () + add_subdirectory(gpu) + endif() endif() diff --git a/c_api/gpu/CMakeLists.txt b/c_api/gpu/CMakeLists.txt index 4ec926439d..32b374ece9 100644 --- a/c_api/gpu/CMakeLists.txt +++ b/c_api/gpu/CMakeLists.txt @@ -15,8 +15,14 @@ target_sources(faiss_c PRIVATE file(GLOB FAISS_C_API_GPU_HEADERS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.h") faiss_install_headers("${FAISS_C_API_GPU_HEADERS}" c_api/gpu) +if (USE_ROCM) +find_package(HIP REQUIRED) +find_package(hipBLAS REQUIRED) +target_link_libraries(faiss_c PUBLIC hip::host roc::hipblas) +else() find_package(CUDAToolkit REQUIRED) target_link_libraries(faiss_c PUBLIC CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) +endif() add_executable(example_gpu_c EXCLUDE_FROM_ALL example_gpu_c.c) target_link_libraries(example_gpu_c PRIVATE faiss_c) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index 5b29987379..c97bae7032 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -197,6 +197,10 @@ function(generate_ivf_interleaved_code) "64|2048|8" ) + if (USE_ROCM) + list(TRANSFORM FAISS_GPU_SRC REPLACE cu$ hip) + endif() + # Traverse through the Cartesian product of X and Y foreach(sub_codec ${SUB_CODEC_TYPE}) foreach(metric_type ${SUB_METRIC_TYPE}) @@ -210,10 +214,10 @@ function(generate_ivf_interleaved_code) set(filename "template_${sub_codec}_${metric_type}_${sub_threads}_${sub_num_warp_q}_${sub_num_thread_q}") # Remove illegal characters from filename string(REGEX REPLACE "[^A-Za-z0-9_]" "" filename ${filename}) - set(output_file "${CMAKE_CURRENT_BINARY_DIR}/${filename}.cu") + set(output_file 
"${CMAKE_CURRENT_BINARY_DIR}/${filename}.${GPU_EXT_PREFIX}") # Read the template file - file(READ "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.cu" template_content) + file(READ "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.${GPU_EXT_PREFIX}" template_content) # Replace the placeholders string(REPLACE "SUB_CODEC_TYPE" "${sub_codec}" template_content "${template_content}") @@ -290,6 +294,10 @@ if(FAISS_ENABLE_RAFT) target_compile_definitions(faiss_gpu PUBLIC USE_NVIDIA_RAFT=1) endif() +if (USE_ROCM) + list(TRANSFORM FAISS_GPU_SRC REPLACE cu$ hip) +endif() + # Export FAISS_GPU_HEADERS variable to parent scope. set(FAISS_GPU_HEADERS ${FAISS_GPU_HEADERS} PARENT_SCOPE) @@ -305,21 +313,26 @@ foreach(header ${FAISS_GPU_HEADERS}) ) endforeach() -# Prepares a host linker script and enables host linker to support -# very large device object files. -# This is what CUDA 11.5+ `nvcc -hls=gen-lcs -aug-hls` would generate -file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld" -[=[ -SECTIONS -{ -.nvFatBinSegment : { *(.nvFatBinSegment) } -__nv_relfatbin : { *(__nv_relfatbin) } -.nv_fatbin : { *(.nv_fatbin) } -} -]=] -) -target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") +if (USE_ROCM) + target_link_libraries(faiss_gpu PRIVATE $<$:hip::host> $<$:roc::hipblas>) + target_compile_options(faiss_gpu PRIVATE) +else() + # Prepares a host linker script and enables host linker to support + # very large device object files. + # This is what CUDA 11.5+ `nvcc -hls=gen-lcs -aug-hls` would generate + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld" + [=[ + SECTIONS + { + .nvFatBinSegment : { *(.nvFatBinSegment) } + __nv_relfatbin : { *(__nv_relfatbin) } + .nv_fatbin : { *(.nv_fatbin) } + } + ]=] + ) + target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") -find_package(CUDAToolkit REQUIRED) -target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> $<$:OpenMP::OpenMP_CXX>) -target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr $<$:-Xcompiler=${OpenMP_CXX_FLAGS}>>) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> $<$:OpenMP::OpenMP_CXX>) + target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr $<$:-Xcompiler=${OpenMP_CXX_FLAGS}>>) +endif() diff --git a/faiss/gpu/GpuFaissAssert.h b/faiss/gpu/GpuFaissAssert.h index 2f03a8c278..7d36fbd8b5 100644 --- a/faiss/gpu/GpuFaissAssert.h +++ b/faiss/gpu/GpuFaissAssert.h @@ -15,7 +15,7 @@ /// Assertions /// -#ifdef __CUDA_ARCH__ +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) #define GPU_FAISS_ASSERT(X) assert(X) #define GPU_FAISS_ASSERT_MSG(X, MSG) assert(X) #define GPU_FAISS_ASSERT_FMT(X, FMT, ...) 
assert(X) diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp index 78336b4994..ae3c8793e7 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp @@ -363,11 +363,20 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) { prop.major, prop.minor); +#if USE_ROCM + // Our code is pre-built with and expects warpSize == 32 or 64, validate + // that + FAISS_ASSERT_FMT( + prop.warpSize == 32 || prop.warpSize == 64, + "Device id %d does not have expected warpSize of 32 or 64", + device); +#else // Our code is pre-built with and expects warpSize == 32, validate that FAISS_ASSERT_FMT( prop.warpSize == 32, "Device id %d does not have expected warpSize of 32", device); +#endif // Create streams cudaStream_t defaultStream = nullptr; diff --git a/faiss/gpu/hipify.sh b/faiss/gpu/hipify.sh new file mode 100755 index 0000000000..09d466545e --- /dev/null +++ b/faiss/gpu/hipify.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +# go one level up from faiss/gpu +top=$(dirname "${BASH_SOURCE[0]}")/.. +echo "top=$top" +cd $top +echo "pwd=`pwd`" + +# create all destination directories for hipified files into sibling 'gpu-rocm' directory +for src in $(find ./gpu -type d) +do + dst=$(echo $src | sed 's/gpu/gpu-rocm/') + echo "Creating $dst" + mkdir -p $dst +done + +# run hipify-perl against all *.cu *.cuh *.h *.cpp files, no renaming +# run all files in parallel to speed up +for ext in cu cuh h cpp +do + for src in $(find ./gpu -name "*.$ext") + do + dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') + hipify-perl -o=$dst.tmp $src & + done +done +wait + +# rename all hipified *.cu files to *.hip +for src in $(find ./gpu-rocm -name "*.cu.tmp") +do + dst=${src%.cu.tmp}.hip.tmp + mv $src $dst +done + +# replace header include statements "@#include @' $src + sed -i 's@#include @#include @' $src + done +done + +# hipify was run in parallel above +# don't copy the tmp file if it is unchanged +for ext in hip cuh h cpp +do + for src in $(find ./gpu-rocm -name "*.$ext.tmp") + do + dst=${src%.tmp} + if test -f $dst + then + if diff -q $src $dst >& /dev/null + then + echo "$dst [unchanged]" + rm $src + else + echo "$dst" + mv $src $dst + fi + else + echo "$dst" + mv $src $dst + fi + done +done + +# copy over CMakeLists.txt +for src in $(find ./gpu -name "CMakeLists.txt") +do + dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') + if test -f $dst + then + if diff -q $src $dst >& /dev/null + then + echo "$dst [unchanged]" + else + echo "$dst" + cp $src $dst + fi + else + echo "$dst" + cp $src $dst + fi +done + +# Copy over other files +for ext in py +do + for src in $(find ./gpu -name "*.$ext") + do + dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') + if test -f $dst + then + if diff -q $src $dst >& /dev/null + then + echo "$dst [unchanged]" + else + echo "$dst" + cp $src $dst + fi + else + echo "$dst" + cp $src $dst + fi + done +done + + +################################################################################### +# C_API Support +################################################################################### + +# Now get the c_api dir +# This points to the faiss/c_api dir +top_c_api=$(dirname "${BASH_SOURCE[0]}")/../../c_api +echo "top=$top_c_api" +cd ../$top_c_api +echo "pwd=`pwd`" + + +# create all destination directories for hipified files into sibling 'gpu-rocm' directory +for src in $(find ./gpu -type d) +do + dst=$(echo $src | sed 's/gpu/gpu-rocm/') + echo "Creating $dst" + mkdir -p $dst +done + +# run hipify-perl against all *.cu *.cuh *.h *.cpp files, no 
renaming +# run all files in parallel to speed up +for ext in cu cuh h cpp c +do + for src in $(find ./gpu -name "*.$ext") + do + dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') + hipify-perl -o=$dst.tmp $src & + done +done +wait + +# rename all hipified *.cu files to *.hip +for src in $(find ./gpu-rocm -name "*.cu.tmp") +do + dst=${src%.cu.tmp}.hip.tmp + mv $src $dst +done + +# replace header include statements "@#include @' $src + sed -i 's@#include @#include @' $src + done +done + +# hipify was run in parallel above +# don't copy the tmp file if it is unchanged +for ext in hip cuh h cpp c +do + for src in $(find ./gpu-rocm -name "*.$ext.tmp") + do + dst=${src%.tmp} + if test -f $dst + then + if diff -q $src $dst >& /dev/null + then + echo "$dst [unchanged]" + rm $src + else + echo "$dst" + mv $src $dst + fi + else + echo "$dst" + mv $src $dst + fi + done +done + +# copy over CMakeLists.txt +for src in $(find ./gpu -name "CMakeLists.txt") +do + dst=$(echo $src | sed 's@./gpu@./gpu-rocm@') + if test -f $dst + then + if diff -q $src $dst >& /dev/null + then + echo "$dst [unchanged]" + else + echo "$dst" + cp $src $dst + fi + else + echo "$dst" + cp $src $dst + fi +done diff --git a/faiss/gpu/impl/BinaryDistance.cu b/faiss/gpu/impl/BinaryDistance.cu index 0183c24d37..4ea5004442 100644 --- a/faiss/gpu/impl/BinaryDistance.cu +++ b/faiss/gpu/impl/BinaryDistance.cu @@ -29,78 +29,85 @@ __launch_bounds__(kWarps* kLanes) __global__ void binaryDistanceAnySize( Tensor outK, Tensor outV, int k) { - // A matrix tile (query, k) - __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict - - // B matrix tile (vec, k) - __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict - - WarpSelect< - int, - idx_t, - false, - Comparator, - NumWarpQ, - NumThreadQ, - kWarps * kLanes> - heap(kMaxDistance, -1, k); - - int warpId = threadIdx.y; - int laneId = threadIdx.x; - - // Each warp handles a single query - idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId; - bool queryInBounds = warpQuery < query.getSize(0); - - // Each warp loops through the entire chunk of vectors - for (idx_t blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) { - int threadDistance = 0; + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + // A matrix tile (query, k) + __shared__ BinaryType + queryTile[kWarps][kLanes + 1]; // avoid bank conflict + + // B matrix tile (vec, k) + __shared__ BinaryType + vecTile[kLanes][kLanes + 1]; // avoid bank conflict + + WarpSelect< + int, + idx_t, + false, + Comparator, + NumWarpQ, + NumThreadQ, + kWarps * kLanes> + heap(kMaxDistance, -1, k); + + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + // Each warp handles a single query + idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId; + bool queryInBounds = warpQuery < query.getSize(0); + + // Each warp loops through the entire chunk of vectors + for (idx_t blockVec = 0; blockVec < vecs.getSize(0); + blockVec += kLanes) { + int threadDistance = 0; + + // Reduction dimension + for (idx_t blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) { + idx_t laneK = blockK + laneId; + bool kInBounds = laneK < vecs.getSize(1); + + queryTile[warpId][laneId] = queryInBounds && kInBounds + ? 
query[warpQuery][laneK] + : 0; + + // kWarps warps are responsible for loading 32 vecs +#pragma unroll + for (int i = 0; i < kLanes / kWarps; ++i) { + int warpVec = i * kWarps + warpId; + idx_t vec = blockVec + warpVec; + bool vecInBounds = vec < vecs.getSize(0); - // Reduction dimension - for (idx_t blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) { - idx_t laneK = blockK + laneId; - bool kInBounds = laneK < vecs.getSize(1); + vecTile[warpVec][laneId] = + vecInBounds && kInBounds ? vecs[vec][laneK] : 0; + } - queryTile[warpId][laneId] = - queryInBounds && kInBounds ? query[warpQuery][laneK] : 0; + __syncthreads(); - // kWarps warps are responsible for loading 32 vecs + // Compare distances #pragma unroll - for (int i = 0; i < kLanes / kWarps; ++i) { - int warpVec = i * kWarps + warpId; - idx_t vec = blockVec + warpVec; - bool vecInBounds = vec < vecs.getSize(0); + for (int i = 0; i < kLanes; ++i) { + threadDistance += + __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); + } - vecTile[warpVec][laneId] = - vecInBounds && kInBounds ? vecs[vec][laneK] : 0; + __syncthreads(); } - __syncthreads(); + // Lanes within a warp are different vec results against the same + // query Only submit distances which represent real (query, vec) + // pairs + bool valInBounds = + queryInBounds && (blockVec + laneId < vecs.getSize(0)); + threadDistance = valInBounds ? threadDistance : kMaxDistance; + idx_t id = valInBounds ? blockVec + laneId : idx_t(-1); - // Compare distances -#pragma unroll - for (int i = 0; i < kLanes; ++i) { - threadDistance += - __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); - } - - __syncthreads(); + heap.add(threadDistance, id); } - // Lanes within a warp are different vec results against the same query - // Only submit distances which represent real (query, vec) pairs - bool valInBounds = - queryInBounds && (blockVec + laneId < vecs.getSize(0)); - threadDistance = valInBounds ? threadDistance : kMaxDistance; - idx_t id = valInBounds ? blockVec + laneId : idx_t(-1); - - heap.add(threadDistance, id); - } + heap.reduce(); - heap.reduce(); - - if (warpQuery < query.getSize(0)) { - heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k); + if (warpQuery < query.getSize(0)) { + heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k); + } } } @@ -117,73 +124,80 @@ __global__ void __launch_bounds__(kWarps* kLanes) binaryDistanceLimitSize( Tensor outK, Tensor outV, int k) { - // A matrix tile (query, k) - __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict - - // B matrix tile (vec, k) - __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict - - WarpSelect< - int, - idx_t, - false, - Comparator, - NumWarpQ, - NumThreadQ, - kWarps * kLanes> - heap(kMaxDistance, -1, k); - - int warpId = threadIdx.y; - int laneId = threadIdx.x; - - // Each warp handles a single query - int laneK = laneId; - idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId; - bool kInBounds = laneK < vecs.getSize(1); - bool queryInBounds = warpQuery < query.getSize(0); - - queryTile[warpId][laneId] = - queryInBounds && kInBounds ? 
query[warpQuery][laneK] : 0; - - // Each warp loops through the entire chunk of vectors - for (idx_t blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) { - int threadDistance = 0; - - // kWarps warps are responsible for loading 32 vecs + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + // A matrix tile (query, k) + __shared__ BinaryType + queryTile[kWarps][kLanes + 1]; // avoid bank conflict + + // B matrix tile (vec, k) + __shared__ BinaryType + vecTile[kLanes][kLanes + 1]; // avoid bank conflict + + WarpSelect< + int, + idx_t, + false, + Comparator, + NumWarpQ, + NumThreadQ, + kWarps * kLanes> + heap(kMaxDistance, -1, k); + + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + // Each warp handles a single query + int laneK = laneId; + idx_t warpQuery = idx_t(blockIdx.x) * kWarps + warpId; + bool kInBounds = laneK < vecs.getSize(1); + bool queryInBounds = warpQuery < query.getSize(0); + + queryTile[warpId][laneId] = + queryInBounds && kInBounds ? query[warpQuery][laneK] : 0; + + // Each warp loops through the entire chunk of vectors + for (idx_t blockVec = 0; blockVec < vecs.getSize(0); + blockVec += kLanes) { + int threadDistance = 0; + + // kWarps warps are responsible for loading 32 vecs #pragma unroll - for (int i = 0; i < kLanes / kWarps; ++i) { - int warpVec = i * kWarps + warpId; - idx_t vec = blockVec + warpVec; - bool vecInBounds = vec < vecs.getSize(0); + for (int i = 0; i < kLanes / kWarps; ++i) { + int warpVec = i * kWarps + warpId; + idx_t vec = blockVec + warpVec; + bool vecInBounds = vec < vecs.getSize(0); - vecTile[warpVec][laneId] = - vecInBounds && kInBounds ? vecs[vec][laneK] : 0; - } + vecTile[warpVec][laneId] = + vecInBounds && kInBounds ? vecs[vec][laneK] : 0; + } - __syncthreads(); + __syncthreads(); - // Compare distances + // Compare distances #pragma unroll - for (int i = 0; i < ReductionLimit; ++i) { - threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); - } + for (int i = 0; i < ReductionLimit; ++i) { + threadDistance += + __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); + } - __syncthreads(); + __syncthreads(); - // Lanes within a warp are different vec results against the same query - // Only submit distances which represent real (query, vec) pairs - bool valInBounds = - queryInBounds && (blockVec + laneId < vecs.getSize(0)); - threadDistance = valInBounds ? threadDistance : kMaxDistance; - idx_t id = valInBounds ? blockVec + laneId : idx_t(-1); + // Lanes within a warp are different vec results against the same + // query Only submit distances which represent real (query, vec) + // pairs + bool valInBounds = + queryInBounds && (blockVec + laneId < vecs.getSize(0)); + threadDistance = valInBounds ? threadDistance : kMaxDistance; + idx_t id = valInBounds ? 
blockVec + laneId : idx_t(-1); - heap.add(threadDistance, id); - } + heap.add(threadDistance, id); + } - heap.reduce(); + heap.reduce(); - if (warpQuery < query.getSize(0)) { - heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k); + if (warpQuery < query.getSize(0)) { + heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k); + } } } @@ -196,12 +210,12 @@ void runBinaryDistanceAnySize( int k, cudaStream_t stream) { dim3 grid(utils::divUp(query.getSize(0), kWarps)); - dim3 block(kLanes, kWarps); + dim3 block(getWarpSizeCurrentDevice(), kWarps); if (k == 1) { binaryDistanceAnySize<1, 1, BinaryType> <<>>(vecs, query, outK, outV, k); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { binaryDistanceAnySize<32, 2, BinaryType> <<>>(vecs, query, outK, outV, k); } else if (k <= 64) { @@ -237,12 +251,12 @@ void runBinaryDistanceLimitSize( int k, cudaStream_t stream) { dim3 grid(utils::divUp(query.getSize(0), kWarps)); - dim3 block(kLanes, kWarps); + dim3 block(getWarpSizeCurrentDevice(), kWarps); if (k == 1) { binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit> <<>>(vecs, query, outK, outV, k); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit> <<>>(vecs, query, outK, outV, k); } else if (k <= 64) { diff --git a/faiss/gpu/impl/BroadcastSum.cuh b/faiss/gpu/impl/BroadcastSum.cuh index 6b62a4b913..87aeb2b367 100644 --- a/faiss/gpu/impl/BroadcastSum.cuh +++ b/faiss/gpu/impl/BroadcastSum.cuh @@ -7,6 +7,7 @@ #pragma once +#include #include namespace faiss { diff --git a/faiss/gpu/impl/GeneralDistance.cuh b/faiss/gpu/impl/GeneralDistance.cuh index 8b751f7e02..cf2c82b896 100644 --- a/faiss/gpu/impl/GeneralDistance.cuh +++ b/faiss/gpu/impl/GeneralDistance.cuh @@ -30,6 +30,12 @@ namespace faiss { namespace gpu { +// Initially kWarpSize was used for the x and y tile shape. +// This works when kWarpSize is 32 but for kWarpSize 64, +// this results in an invalid launch configuration of 64x64 block size. +// 32 is a reasonable tile size for both kWarpSize options. 
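For scale, a 64 x 64 tile would request 4096 threads in a single block, well past the usual 1024-thread-per-block cap on both CUDA and HIP devices (the exact cap is what cudaDeviceProp / hipDeviceProp_t report in maxThreadsPerBlock). A minimal host-side sketch of that arithmetic, assuming the common 1024-thread limit:

#include <cstdio>

int main() {
    const int kMaxThreadsPerBlock = 1024; // typical CUDA/HIP per-block limit
    for (int tile : {32, 64}) {
        const int threads = tile * tile; // generalDistance launches a tile x tile block
        std::printf("tile %2d -> %4d threads per block: %s\n",
                    tile,
                    threads,
                    threads <= kMaxThreadsPerBlock ? "valid" : "invalid launch");
    }
    return 0;
}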
+constexpr int TILE_SIZE = 32; + // Reduction tree operator template struct ReduceDistanceOp { @@ -56,8 +62,8 @@ struct ReduceDistanceOp { template inline __device__ DistanceOp reduce(const DistanceOp& in, - const T queryTile[kWarpSize][DimMultiple * kWarpSize + 1], - const T vecTile[kWarpSize][DimMultiple * kWarpSize + 1]) { + const T queryTile[TILE_SIZE][DimMultiple * TILE_SIZE + 1], + const T vecTile[TILE_SIZE][DimMultiple * TILE_SIZE + 1]) { DistanceOp accs[Unroll]; #pragma unroll for (int i = 0; i < Unroll; ++i) { @@ -70,8 +76,8 @@ reduce(const DistanceOp& in, #pragma unroll for (int i = 0; i < Unroll; ++i) { #pragma unroll - for (int j = 0; j < (kWarpSize * DimMultiple / Unroll); ++j) { - int idx = i * (kWarpSize * DimMultiple / Unroll) + j; + for (int j = 0; j < (TILE_SIZE * DimMultiple / Unroll); ++j) { + int idx = i * (TILE_SIZE * DimMultiple / Unroll) + j; accs[i].handle( ConvertTo::to(queryTileBase[idx]), ConvertTo::to(vecTileBase[idx])); @@ -83,23 +89,23 @@ reduce(const DistanceOp& in, // Our general distance matrix "multiplication" kernel template -__launch_bounds__(kWarpSize* kWarpSize) __global__ void generalDistance( +__launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance( Tensor query, // m x k Tensor vec, // n x k DistanceOp op, Tensor out) { // m x n constexpr int kDimMultiple = 1; - __shared__ T queryTile[kWarpSize][kWarpSize * kDimMultiple + 1]; - __shared__ T vecTile[kWarpSize][kWarpSize * kDimMultiple + 1]; + __shared__ T queryTile[TILE_SIZE][TILE_SIZE * kDimMultiple + 1]; + __shared__ T vecTile[TILE_SIZE][TILE_SIZE * kDimMultiple + 1]; // block y -> query // block x -> vector - idx_t queryBlock = idx_t(blockIdx.y) * kWarpSize; + idx_t queryBlock = idx_t(blockIdx.y) * TILE_SIZE; idx_t queryThread = queryBlock + threadIdx.y; - idx_t vecBlock = idx_t(blockIdx.x) * kWarpSize; + idx_t vecBlock = idx_t(blockIdx.x) * TILE_SIZE; idx_t vecThreadLoad = vecBlock + threadIdx.y; idx_t vecThreadSave = vecBlock + threadIdx.x; @@ -116,16 +122,16 @@ __launch_bounds__(kWarpSize* kWarpSize) __global__ void generalDistance( // Interior tile // idx_t limit = - utils::roundDown(query.getSize(1), kWarpSize * kDimMultiple); + utils::roundDown(query.getSize(1), TILE_SIZE * kDimMultiple); - for (idx_t k = threadIdx.x; k < limit; k += kWarpSize * kDimMultiple) { + for (idx_t k = threadIdx.x; k < limit; k += TILE_SIZE * kDimMultiple) { // Load query tile #pragma unroll for (int i = 0; i < kDimMultiple; ++i) { - queryTileBase[threadIdx.x + i * kWarpSize] = - queryBase[k + i * kWarpSize]; - vecTileBase[threadIdx.x + i * kWarpSize] = - vecBase[k + i * kWarpSize]; + queryTileBase[threadIdx.x + i * TILE_SIZE] = + queryBase[k + i * TILE_SIZE]; + vecTileBase[threadIdx.x + i * TILE_SIZE] = + vecBase[k + i * TILE_SIZE]; } __syncthreads(); @@ -141,13 +147,13 @@ __launch_bounds__(kWarpSize* kWarpSize) __global__ void generalDistance( if (limit < query.getSize(1)) { #pragma unroll for (int i = 0; i < kDimMultiple; ++i) { - idx_t k = limit + threadIdx.x + i * kWarpSize; + idx_t k = limit + threadIdx.x + i * TILE_SIZE; bool kInBounds = k < query.getSize(1); - queryTileBase[threadIdx.x + i * kWarpSize] = + queryTileBase[threadIdx.x + i * TILE_SIZE] = kInBounds ? queryBase[k] : ConvertTo::to(0); - vecTileBase[threadIdx.x + i * kWarpSize] = + vecTileBase[threadIdx.x + i * TILE_SIZE] = kInBounds ? 
vecBase[k] : ConvertTo::to(0); } @@ -174,9 +180,9 @@ __launch_bounds__(kWarpSize* kWarpSize) __global__ void generalDistance( bool queryThreadInBounds = queryThread < query.getSize(0); bool vecThreadInBoundsLoad = vecThreadLoad < vec.getSize(0); bool vecThreadInBoundsSave = vecThreadSave < vec.getSize(0); - idx_t limit = utils::roundDown(query.getSize(1), kWarpSize); + idx_t limit = utils::roundDown(query.getSize(1), TILE_SIZE); - for (idx_t k = threadIdx.x; k < limit; k += kWarpSize) { + for (idx_t k = threadIdx.x; k < limit; k += TILE_SIZE) { // Load query tile queryTileBase[threadIdx.x] = queryThreadInBounds ? queryBase[k] : ConvertTo::to(0); @@ -188,7 +194,7 @@ __launch_bounds__(kWarpSize* kWarpSize) __global__ void generalDistance( // thread (y, x) does (query y, vec x) #pragma unroll - for (int i = 0; i < kWarpSize; ++i) { + for (int i = 0; i < TILE_SIZE; ++i) { acc.handle( ConvertTo::to(queryTileBase[i]), ConvertTo::to(vecTile[threadIdx.x][i])); @@ -242,10 +248,10 @@ void runGeneralDistanceKernel( FAISS_ASSERT(out.getSize(1) == vecs.getSize(0)); dim3 grid( - utils::divUp(vecs.getSize(0), kWarpSize), - utils::divUp(query.getSize(0), kWarpSize)); + utils::divUp(vecs.getSize(0), TILE_SIZE), + utils::divUp(query.getSize(0), TILE_SIZE)); FAISS_ASSERT(grid.y <= getMaxGridCurrentDevice().y); - dim3 block(kWarpSize, kWarpSize); + dim3 block(TILE_SIZE, TILE_SIZE); generalDistance<<>>(query, vecs, op, out); } diff --git a/faiss/gpu/impl/IVFAppend.cu b/faiss/gpu/impl/IVFAppend.cu index e0bdc6fc52..8ee85eaed8 100644 --- a/faiss/gpu/impl/IVFAppend.cu +++ b/faiss/gpu/impl/IVFAppend.cu @@ -342,7 +342,7 @@ void runSQEncode( } // Handles appending encoded vectors (one per EncodeT word) packed into -// EncodeBits interleaved by 32 vectors. +// EncodeBits interleaved by kWarpSize vectors. // This is used by Flat, SQ and PQ code for the interleaved format. template __global__ void ivfInterleavedAppend( @@ -391,30 +391,30 @@ __global__ void ivfInterleavedAppend( // These are the actual vec IDs that we are adding (in vecs) auto listVecIds = vectorsByUniqueList[vecIdStart].data(); - // All data is written by groups of 32 vectors (to mirror the warp). + // All data is written by groups of kWarpSize vectors (to mirror the warp). // listVecStart could be in the middle of this, or even, for sub-byte // encodings, mean that the first vector piece of data that we need to // update is in the high part of a byte. // // WarpPackedBits allows writing of arbitrary bit packed data in groups of - // 32, but we ensure that it only operates on the group of 32 vectors. In - // order to do this we need to actually start updating vectors at the next - // lower multiple of 32 from listVecStart. - auto alignedListVecStart = utils::roundDown(listVecStart, 32); + // kWarpSize, but we ensure that it only operates on the group of kWarpSize + // vectors. In order to do this we need to actually start updating vectors + // at the next lower multiple of kWarpSize from listVecStart. 
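The alignment described in this comment is ordinary integer arithmetic; the sketch below uses a stand-in roundDown with the truncating behavior assumed for utils::roundDown, applied to hypothetical list offsets:

#include <cassert>

// Round n down to the nearest multiple of unit (unit > 0).
constexpr long long roundDown(long long n, long long unit) {
    return (n / unit) * unit;
}

int main() {
    // Appending at list offset 70 starts rewriting at the previous
    // warp-size multiple, so the packed group is handled as a whole.
    assert(roundDown(70, 32) == 64); // kWarpSize == 32 (NVIDIA)
    assert(roundDown(70, 64) == 64); // kWarpSize == 64 (some AMD devices)
    assert(roundDown(64, 64) == 64); // already aligned
    return 0;
}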
+ auto alignedListVecStart = utils::roundDown(listVecStart, kWarpSize); - // Each block of 32 vectors fully encodes into this many bytes - constexpr int bytesPerVectorBlockDim = EncodeBits * 32 / 8; + // Each block of kWarpSize vectors fully encodes into this many bytes + constexpr int bytesPerVectorBlockDim = EncodeBits * kWarpSize / 8; constexpr int wordsPerVectorBlockDim = bytesPerVectorBlockDim / sizeof(EncodeT); auto wordsPerVectorBlock = wordsPerVectorBlockDim * encodedVecs.getSize(1); EncodeT* listStart = ((EncodeT*)listData[listId]); - // Each warp within the block handles a different chunk of 32 - auto warpVec = alignedListVecStart + warpId * 32; + // Each warp within the block handles a different chunk of kWarpSize + auto warpVec = alignedListVecStart + warpId * kWarpSize; // The warp data starts here - EncodeT* warpData = listStart + (warpVec / 32) * wordsPerVectorBlock; + EncodeT* warpData = listStart + (warpVec / kWarpSize) * wordsPerVectorBlock; // Each warp encodes a single block for (; warpVec < listVecStart + numVecsAdding; diff --git a/faiss/gpu/impl/IVFBase.cuh b/faiss/gpu/impl/IVFBase.cuh index 04af9a906e..26411aa4e9 100644 --- a/faiss/gpu/impl/IVFBase.cuh +++ b/faiss/gpu/impl/IVFBase.cuh @@ -220,15 +220,15 @@ class IVFBase { /// Coarse quantizer centroids available on GPU DeviceTensor ivfCentroids_; - /// Whether or not our index uses an interleaved by 32 layout: + /// Whether or not our index uses an interleaved by kWarpSize layout: /// The default memory layout is [vector][PQ/SQ component]: /// (v0 d0) (v0 d1) ... (v0 dD-1) (v1 d0) (v1 d1) ... /// - /// The interleaved by 32 memory layout is: - /// [vector / 32][PQ/SQ component][vector % 32] with padding: + /// The interleaved by kWarpSize memory layout is: + /// [vector / kWarpSize][PQ/SQ component][vector % kWarpSize] with padding: /// (v0 d0) (v1 d0) ... (v31 d0) (v0 d1) (v1 d1) ... (v31 dD-1) (v32 d0) /// (v33 d0) ... so the list length is always a multiple of num quantizers * - /// 32 + /// kWarpSize bool interleavedLayout_; /// How are user indices stored on the GPU? diff --git a/faiss/gpu/impl/IVFFlat.cu b/faiss/gpu/impl/IVFFlat.cu index e0ecfd82cf..9a7fd87d34 100644 --- a/faiss/gpu/impl/IVFFlat.cu +++ b/faiss/gpu/impl/IVFFlat.cu @@ -57,14 +57,16 @@ size_t IVFFlat::getGpuVectorsEncodingSize_(idx_t numVecs) const { // bits per scalar code idx_t bits = scalarQ_ ? 
scalarQ_->bits : 32 /* float */; - // bytes to encode a block of 32 vectors (single dimension) - idx_t bytesPerDimBlock = bits * 32 / 8; + int warpSize = getWarpSizeCurrentDevice(); - // bytes to fully encode 32 vectors + // bytes to encode a block of warpSize vectors (single dimension) + idx_t bytesPerDimBlock = bits * warpSize / 8; + + // bytes to fully encode warpSize vectors idx_t bytesPerBlock = bytesPerDimBlock * dim_; - // number of blocks of 32 vectors we have - idx_t numBlocks = utils::divUp(numVecs, 32); + // number of blocks of warpSize vectors we have + idx_t numBlocks = utils::divUp(numVecs, warpSize); // total size to encode numVecs return bytesPerBlock * numBlocks; @@ -289,6 +291,7 @@ void IVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) { return; } + int warpSize = getWarpSizeCurrentDevice(); auto stream = resources_->getDefaultStreamCurrentDevice(); for (idx_t list_no = 0; list_no < numLists_; list_no++) { @@ -313,15 +316,15 @@ void IVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) { // where vectors are chunked into groups of 32, and each dimension // for each of the 32 vectors is contiguous - auto vectorChunk = offset / 32; - auto vectorWithinChunk = offset % 32; + auto vectorChunk = offset / warpSize; + auto vectorWithinChunk = offset % warpSize; auto listDataPtr = (float*)deviceListData_[list_no]->data.data(); - listDataPtr += vectorChunk * 32 * dim_ + vectorWithinChunk; + listDataPtr += vectorChunk * warpSize * dim_ + vectorWithinChunk; for (int d = 0; d < dim_; ++d) { fromDevice( - listDataPtr + 32 * d, + listDataPtr + warpSize * d, out + (id - i0) * dim_ + d, 1, stream); diff --git a/faiss/gpu/impl/IVFFlatScan.cu b/faiss/gpu/impl/IVFFlatScan.cu index 3acee58d71..66e9a396b6 100644 --- a/faiss/gpu/impl/IVFFlatScan.cu +++ b/faiss/gpu/impl/IVFFlatScan.cu @@ -211,8 +211,9 @@ void runIVFFlatScanTile( runCalcListOffsets( res, listIds, listLengths, prefixSumOffsets, thrustMem, stream); + int warpSize = getWarpSizeCurrentDevice(); auto grid = dim3(listIds.getSize(1), listIds.getSize(0)); - auto block = dim3(kWarpSize * kIVFFlatScanWarps); + auto block = dim3(warpSize * kIVFFlatScanWarps); #define RUN_IVF_FLAT \ do { \ diff --git a/faiss/gpu/impl/IVFInterleaved.cu b/faiss/gpu/impl/IVFInterleaved.cu index 9e595b4a59..9a98473f7b 100644 --- a/faiss/gpu/impl/IVFInterleaved.cu +++ b/faiss/gpu/impl/IVFInterleaved.cu @@ -26,103 +26,107 @@ __global__ void ivfInterleavedScan2( bool dir, Tensor distanceOut, Tensor indicesOut) { - int queryId = blockIdx.x; - - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ float smemK[kNumWarps * NumWarpQ]; - // The BlockSelect value type is uint32_t, as we pack together which probe - // (up to nprobe - 1) and which k (up to k - 1) from each individual list - // together, and both nprobe and k are limited to GPU_MAX_SELECTION_K. - __shared__ uint32_t smemV[kNumWarps * NumWarpQ]; - - // To avoid creating excessive specializations, we combine direction - // kernels, selecting for the smallest element. If `dir` is true, we negate - // all values being selected (so that we are selecting the largest element). - BlockSelect< - float, - uint32_t, - false, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(kFloatMax, kMaxUInt32, smemK, smemV, k); - - // nprobe x k - idx_t num = distanceIn.getSize(1) * distanceIn.getSize(2); - - const float* distanceBase = distanceIn[queryId].data(); - idx_t limit = utils::roundDown(num, kWarpSize); - - // This will keep our negation factor - float adj = dir ? 
-1 : 1; - - idx_t i = threadIdx.x; - for (; i < limit; i += blockDim.x) { - // We represent the index as (probe id)(k) - // Right now, both are limited to a maximum of 2048, but we will - // dedicate each to the high and low words of a uint32_t - static_assert(GPU_MAX_SELECTION_K <= 65536, ""); - - uint32_t curProbe = i / k; - uint32_t curK = i % k; - // Since nprobe and k are limited, we can pack both of these together - // into a uint32_t - uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff); - - // The IDs reported from the list may be -1, if a particular IVF list - // doesn't even have k entries in it - if (listIds[queryId][curProbe] != -1) { - // Adjust the value we are selecting based on the sorting order - heap.addThreadQ(distanceBase[i] * adj, index); - } - - heap.checkThreadQ(); - } - - // Handle warp divergence separately - if (i < num) { - uint32_t curProbe = i / k; - uint32_t curK = i % k; - uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff); + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + int queryId = blockIdx.x; + + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ float smemK[kNumWarps * NumWarpQ]; + // The BlockSelect value type is uint32_t, as we pack together which + // probe (up to nprobe - 1) and which k (up to k - 1) from each + // individual list together, and both nprobe and k are limited to + // GPU_MAX_SELECTION_K. + __shared__ uint32_t smemV[kNumWarps * NumWarpQ]; + + // To avoid creating excessive specializations, we combine direction + // kernels, selecting for the smallest element. If `dir` is true, we + // negate all values being selected (so that we are selecting the + // largest element). + BlockSelect< + float, + uint32_t, + false, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(kFloatMax, kMaxUInt32, smemK, smemV, k); + + // nprobe x k + idx_t num = distanceIn.getSize(1) * distanceIn.getSize(2); + + const float* distanceBase = distanceIn[queryId].data(); + idx_t limit = utils::roundDown(num, kWarpSize); + + // This will keep our negation factor + float adj = dir ? 
-1 : 1; + + idx_t i = threadIdx.x; + for (; i < limit; i += blockDim.x) { + // We represent the index as (probe id)(k) + // Right now, both are limited to a maximum of 2048, but we will + // dedicate each to the high and low words of a uint32_t + static_assert(GPU_MAX_SELECTION_K <= 65536, ""); + + uint32_t curProbe = i / k; + uint32_t curK = i % k; + // Since nprobe and k are limited, we can pack both of these + // together into a uint32_t + uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff); + + // The IDs reported from the list may be -1, if a particular IVF + // list doesn't even have k entries in it + if (listIds[queryId][curProbe] != -1) { + // Adjust the value we are selecting based on the sorting order + heap.addThreadQ(distanceBase[i] * adj, index); + } - idx_t listId = listIds[queryId][curProbe]; - if (listId != -1) { - heap.addThreadQ(distanceBase[i] * adj, index); + heap.checkThreadQ(); } - } - // Merge all final results - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += blockDim.x) { - // Re-adjust the value we are selecting based on the sorting order - distanceOut[queryId][i] = smemK[i] * adj; - auto packedIndex = smemV[i]; - - // We need to remap to the user-provided indices - idx_t index = -1; - - // We may not have at least k values to return; in this function, max - // uint32 is our sentinel value - if (packedIndex != kMaxUInt32) { - uint32_t curProbe = packedIndex >> 16; - uint32_t curK = packedIndex & 0xffff; + // Handle warp divergence separately + if (i < num) { + uint32_t curProbe = i / k; + uint32_t curK = i % k; + uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff); idx_t listId = listIds[queryId][curProbe]; - idx_t listOffset = indicesIn[queryId][curProbe][curK]; - - if (opt == INDICES_32_BIT) { - index = (idx_t)((int*)listIndices[listId])[listOffset]; - } else if (opt == INDICES_64_BIT) { - index = ((idx_t*)listIndices[listId])[listOffset]; - } else { - index = (listId << 32 | (idx_t)listOffset); + if (listId != -1) { + heap.addThreadQ(distanceBase[i] * adj, index); } } - indicesOut[queryId][i] = index; + // Merge all final results + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += blockDim.x) { + // Re-adjust the value we are selecting based on the sorting order + distanceOut[queryId][i] = smemK[i] * adj; + auto packedIndex = smemV[i]; + + // We need to remap to the user-provided indices + idx_t index = -1; + + // We may not have at least k values to return; in this function, + // max uint32 is our sentinel value + if (packedIndex != kMaxUInt32) { + uint32_t curProbe = packedIndex >> 16; + uint32_t curK = packedIndex & 0xffff; + + idx_t listId = listIds[queryId][curProbe]; + idx_t listOffset = indicesIn[queryId][curProbe][curK]; + + if (opt == INDICES_32_BIT) { + index = (idx_t)((int*)listIndices[listId])[listOffset]; + } else if (opt == INDICES_64_BIT) { + index = ((idx_t*)listIndices[listId])[listOffset]; + } else { + index = (listId << 32 | (idx_t)listOffset); + } + } + + indicesOut[queryId][i] = index; + } } } @@ -152,7 +156,7 @@ void runIVFInterleavedScan2( if (k == 1) { IVF_SCAN_2(128, 1, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { IVF_SCAN_2(128, 32, 2); } else if (k <= 64) { IVF_SCAN_2(128, 64, 3); @@ -211,7 +215,7 @@ void runIVFInterleavedScan( if (k == 1) { ivf_interleaved_call(ivfInterleavedScanImpl<128, 1, 1>); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { ivf_interleaved_call(ivfInterleavedScanImpl<128, 32, 2>); } else if (k <= 64) 
{ ivf_interleaved_call(ivfInterleavedScanImpl<128, 64, 3>); diff --git a/faiss/gpu/impl/IVFInterleaved.cuh b/faiss/gpu/impl/IVFInterleaved.cuh index 5f92c366e3..a3f218a212 100644 --- a/faiss/gpu/impl/IVFInterleaved.cuh +++ b/faiss/gpu/impl/IVFInterleaved.cuh @@ -49,166 +49,176 @@ __global__ void ivfInterleavedScan( Tensor distanceOut, Tensor indicesOut, const bool Residual) { - extern __shared__ float smem[]; + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + extern __shared__ float smem[]; - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - for (idx_t queryId = blockIdx.y; queryId < queries.getSize(0); - queryId += gridDim.y) { - int probeId = blockIdx.x; - idx_t listId = listIds[queryId][probeId]; + for (idx_t queryId = blockIdx.y; queryId < queries.getSize(0); + queryId += gridDim.y) { + int probeId = blockIdx.x; + idx_t listId = listIds[queryId][probeId]; - // Safety guard in case NaNs in input cause no list ID to be generated, - // or we have more nprobe than nlist - if (listId == -1) { - return; - } + // Safety guard in case NaNs in input cause no list ID to be + // generated, or we have more nprobe than nlist + if (listId == -1) { + return; + } - // Vector dimension is currently limited to 32 bit - int dim = queries.getSize(1); + // Vector dimension is currently limited to 32 bit + int dim = queries.getSize(1); - // FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs? - int laneId = threadIdx.x % kWarpSize; - int warpId = threadIdx.x / kWarpSize; + // FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs? + int laneId = threadIdx.x % kWarpSize; + int warpId = threadIdx.x / kWarpSize; - using EncodeT = typename Codec::EncodeT; + using EncodeT = typename Codec::EncodeT; - auto query = queries[queryId].data(); - auto vecsBase = (EncodeT*)allListData[listId]; - int numVecs = listLengths[listId]; - auto residualBaseSlice = residualBase[queryId][probeId].data(); + auto query = queries[queryId].data(); + auto vecsBase = (EncodeT*)allListData[listId]; + int numVecs = listLengths[listId]; + auto residualBaseSlice = residualBase[queryId][probeId].data(); - constexpr auto kInit = Metric::kDirection ? kFloatMin : kFloatMax; + constexpr auto kInit = Metric::kDirection ? kFloatMin : kFloatMax; - __shared__ float smemK[kNumWarps * NumWarpQ]; - __shared__ idx_t smemV[kNumWarps * NumWarpQ]; + __shared__ float smemK[kNumWarps * NumWarpQ]; + __shared__ idx_t smemV[kNumWarps * NumWarpQ]; - BlockSelect< - float, - idx_t, - Metric::kDirection, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(kInit, -1, smemK, smemV, k); + BlockSelect< + float, + idx_t, + Metric::kDirection, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(kInit, -1, smemK, smemV, k); - // The codec might be dependent upon data that we need to reference or - // store in shared memory - codec.initKernel(smem, dim); - __syncthreads(); + // The codec might be dependent upon data that we need to reference + // or store in shared memory + codec.initKernel(smem, dim); + __syncthreads(); - // How many vector blocks of 32 are in this list? - idx_t numBlocks = utils::divUp(numVecs, (idx_t)32); + // How many vector blocks of kWarpSize are in this list? 
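The per-list block bookkeeping below uses the same arithmetic as getGpuVectorsEncodingSize_ earlier in this patch; a small standalone sketch with hypothetical values (8-bit codes, a 1000-vector list) and a divUp stand-in for the utils helper:

#include <cstdio>

// Ceiling division, mirroring the utils::divUp idiom used in the kernels.
constexpr long long divUp(long long a, long long b) {
    return (a + b - 1) / b;
}

int main() {
    const long long encodeBits = 8; // hypothetical 8-bit code per dimension
    const long long numVecs = 1000; // hypothetical list length
    for (long long warpSize : {32LL, 64LL}) {
        const long long numBlocks = divUp(numVecs, warpSize);
        const long long bytesPerDimBlock = encodeBits * warpSize / 8;
        std::printf("warpSize %2lld: %lld blocks, %lld bytes per dimension per block\n",
                    warpSize, numBlocks, bytesPerDimBlock);
    }
    return 0;
}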
+ idx_t numBlocks = utils::divUp(numVecs, (idx_t)kWarpSize); - // Number of EncodeT words per each dimension of block of 32 vecs - constexpr int bytesPerVectorBlockDim = Codec::kEncodeBits * 32 / 8; - constexpr int wordsPerVectorBlockDim = - bytesPerVectorBlockDim / sizeof(EncodeT); - int wordsPerVectorBlock = wordsPerVectorBlockDim * dim; + // Number of EncodeT words per each dimension of block of kWarpSize + // vecs + constexpr int bytesPerVectorBlockDim = + Codec::kEncodeBits * kWarpSize / 8; + constexpr int wordsPerVectorBlockDim = + bytesPerVectorBlockDim / sizeof(EncodeT); + int wordsPerVectorBlock = wordsPerVectorBlockDim * dim; - int dimBlocks = utils::roundDown(dim, kWarpSize); + int dimBlocks = utils::roundDown(dim, kWarpSize); - for (idx_t block = warpId; block < numBlocks; block += kNumWarps) { - // We're handling a new vector - Metric dist = metric.zero(); + for (idx_t block = warpId; block < numBlocks; block += kNumWarps) { + // We're handling a new vector + Metric dist = metric.zero(); - // This is the vector a given lane/thread handles - idx_t vec = block * kWarpSize + laneId; - bool valid = vec < numVecs; + // This is the vector a given lane/thread handles + idx_t vec = block * kWarpSize + laneId; + bool valid = vec < numVecs; - // This is where this warp begins reading data - EncodeT* data = vecsBase + block * wordsPerVectorBlock; + // This is where this warp begins reading data + EncodeT* data = vecsBase + block * wordsPerVectorBlock; - // whole blocks - for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) { - const int loadDim = dBase + laneId; - const float queryReg = query[loadDim]; - const float residualReg = - Residual ? residualBaseSlice[loadDim] : 0; + // whole blocks + for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) { + const int loadDim = dBase + laneId; + const float queryReg = query[loadDim]; + const float residualReg = + Residual ? 
residualBaseSlice[loadDim] : 0; - constexpr int kUnroll = 4; + constexpr int kUnroll = 4; #pragma unroll - for (int i = 0; i < kWarpSize / kUnroll; - ++i, data += kUnroll * wordsPerVectorBlockDim) { - EncodeT encV[kUnroll]; + for (int i = 0; i < kWarpSize / kUnroll; + ++i, data += kUnroll * wordsPerVectorBlockDim) { + EncodeT encV[kUnroll]; #pragma unroll - for (int j = 0; j < kUnroll; ++j) { - encV[j] = WarpPackedBits:: - read(laneId, data + j * wordsPerVectorBlockDim); - } + for (int j = 0; j < kUnroll; ++j) { + encV[j] = WarpPackedBits< + EncodeT, + Codec::kEncodeBits>:: + read(laneId, + data + j * wordsPerVectorBlockDim); + } #pragma unroll - for (int j = 0; j < kUnroll; ++j) { - encV[j] = WarpPackedBits:: - postRead(laneId, encV[j]); - } + for (int j = 0; j < kUnroll; ++j) { + encV[j] = WarpPackedBits< + EncodeT, + Codec::kEncodeBits>:: + postRead(laneId, encV[j]); + } - float decV[kUnroll]; + float decV[kUnroll]; #pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int d = i * kUnroll + j; - decV[j] = codec.decodeNew(dBase + d, encV[j]); - } + for (int j = 0; j < kUnroll; ++j) { + int d = i * kUnroll + j; + decV[j] = codec.decodeNew(dBase + d, encV[j]); + } + + if (Residual) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int d = i * kUnroll + j; + decV[j] += SHFL_SYNC(residualReg, d, kWarpSize); + } + } - if (Residual) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { int d = i * kUnroll + j; - decV[j] += SHFL_SYNC(residualReg, d, kWarpSize); + float q = SHFL_SYNC(queryReg, d, kWarpSize); + dist.handle(q, decV[j]); } } + } -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int d = i * kUnroll + j; - float q = SHFL_SYNC(queryReg, d, kWarpSize); - dist.handle(q, decV[j]); + // remainder + const int loadDim = dimBlocks + laneId; + const bool loadDimInBounds = loadDim < dim; + + const float queryReg = loadDimInBounds ? query[loadDim] : 0; + const float residualReg = Residual && loadDimInBounds + ? residualBaseSlice[loadDim] + : 0; + + for (int d = 0; d < dim - dimBlocks; + ++d, data += wordsPerVectorBlockDim) { + float q = SHFL_SYNC(queryReg, d, kWarpSize); + + EncodeT enc = + WarpPackedBits::read( + laneId, data); + enc = WarpPackedBits::postRead( + laneId, enc); + float dec = codec.decodeNew(dimBlocks + d, enc); + if (Residual) { + dec += SHFL_SYNC(residualReg, d, kWarpSize); } - } - } - // remainder - const int loadDim = dimBlocks + laneId; - const bool loadDimInBounds = loadDim < dim; - - const float queryReg = loadDimInBounds ? query[loadDim] : 0; - const float residualReg = Residual && loadDimInBounds - ? 
residualBaseSlice[loadDim] - : 0; - - for (int d = 0; d < dim - dimBlocks; - ++d, data += wordsPerVectorBlockDim) { - float q = SHFL_SYNC(queryReg, d, kWarpSize); - - EncodeT enc = WarpPackedBits::read( - laneId, data); - enc = WarpPackedBits::postRead( - laneId, enc); - float dec = codec.decodeNew(dimBlocks + d, enc); - if (Residual) { - dec += SHFL_SYNC(residualReg, d, kWarpSize); + dist.handle(q, dec); } - dist.handle(q, dec); - } + if (valid) { + heap.addThreadQ(dist.reduce(), vec); + } - if (valid) { - heap.addThreadQ(dist.reduce(), vec); + heap.checkThreadQ(); } - heap.checkThreadQ(); - } - - heap.reduce(); + heap.reduce(); - auto distanceOutBase = distanceOut[queryId][probeId].data(); - auto indicesOutBase = indicesOut[queryId][probeId].data(); + auto distanceOutBase = distanceOut[queryId][probeId].data(); + auto indicesOutBase = indicesOut[queryId][probeId].data(); - for (int i = threadIdx.x; i < k; i += blockDim.x) { - distanceOutBase[i] = smemK[i]; - indicesOutBase[i] = smemV[i]; + for (int i = threadIdx.x; i < k; i += blockDim.x) { + distanceOutBase[i] = smemK[i]; + indicesOutBase[i] = smemV[i]; + } } } } @@ -218,7 +228,7 @@ __global__ void ivfInterleavedScan( // compile time using these macros to define the function body // -// Top-level IVF scan function for the interleaved by 32 layout +// Top-level IVF scan function for the interleaved by kWarpSize layout // with all implementations void runIVFInterleavedScan( Tensor& queries, diff --git a/faiss/gpu/impl/IVFPQ.cu b/faiss/gpu/impl/IVFPQ.cu index f83a764db9..a19f465b93 100644 --- a/faiss/gpu/impl/IVFPQ.cu +++ b/faiss/gpu/impl/IVFPQ.cu @@ -261,14 +261,16 @@ size_t IVFPQ::getGpuVectorsEncodingSize_(idx_t numVecs) const { // bits per PQ code idx_t bits = bitsPerSubQuantizer_; - // bytes to encode a block of 32 vectors (single PQ code) - idx_t bytesPerDimBlock = bits * 32 / 8; + int warpSize = getWarpSizeCurrentDevice(); - // bytes to fully encode 32 vectors + // bytes to encode a block of warpSize vectors (single PQ code) + idx_t bytesPerDimBlock = bits * warpSize / 8; + + // bytes to fully encode warpSize vectors idx_t bytesPerBlock = bytesPerDimBlock * numSubQuantizers_; - // number of blocks of 32 vectors we have - idx_t numBlocks = utils::divUp(numVecs, idx_t(32)); + // number of blocks of warpSize vectors we have + idx_t numBlocks = utils::divUp(numVecs, idx_t(warpSize)); // total size to encode numVecs return bytesPerBlock * numBlocks; diff --git a/faiss/gpu/impl/IVFUtilsSelect1.cu b/faiss/gpu/impl/IVFUtilsSelect1.cu index 92cae1b1b6..b778190bbb 100644 --- a/faiss/gpu/impl/IVFUtilsSelect1.cu +++ b/faiss/gpu/impl/IVFUtilsSelect1.cu @@ -34,62 +34,66 @@ __global__ void pass1SelectLists( int k, Tensor heapDistances, Tensor heapIndices) { - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ float smemK[kNumWarps * NumWarpQ]; - __shared__ IndexT smemV[kNumWarps * NumWarpQ]; - - for (IndexT queryId = blockIdx.y; queryId < prefixSumOffsets.getSize(0); - queryId += gridDim.y) { - constexpr auto kInit = Dir ? kFloatMin : kFloatMax; - BlockSelect< - float, - IndexT, - Dir, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(kInit, -1, smemK, smemV, k); - - auto sliceId = blockIdx.x; - auto numSlices = gridDim.x; - - IndexT sliceSize = (nprobe / numSlices); - IndexT sliceStart = sliceSize * sliceId; - IndexT sliceEnd = - sliceId == (numSlices - 1) ? 
nprobe : sliceStart + sliceSize; - auto offsets = prefixSumOffsets[queryId].data(); - - // We ensure that before the array (at offset -1), there is a 0 value - auto start = *(&offsets[sliceStart] - 1); - auto end = offsets[sliceEnd - 1]; - - auto num = end - start; - auto limit = utils::roundDown(num, (IndexT)kWarpSize); - - IndexT i = threadIdx.x; - auto distanceStart = distance[start].data(); - - // BlockSelect add cannot be used in a warp divergent circumstance; we - // handle the remainder warp below - for (; i < limit; i += blockDim.x) { - heap.add(distanceStart[i], IndexT(start + i)); - } - - // Handle the remainder if any separately (warp is divergent) - if (i < num) { - heap.addThreadQ(distanceStart[i], IndexT(start + i)); - } - - // Merge all final results - heap.reduce(); - - // Write out the final k-selected values; they should be all - // together - for (int i = threadIdx.x; i < k; i += blockDim.x) { - heapDistances[queryId][sliceId][i] = smemK[i]; - heapIndices[queryId][sliceId][i] = idx_t(smemV[i]); + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ float smemK[kNumWarps * NumWarpQ]; + __shared__ IndexT smemV[kNumWarps * NumWarpQ]; + + for (IndexT queryId = blockIdx.y; queryId < prefixSumOffsets.getSize(0); + queryId += gridDim.y) { + constexpr auto kInit = Dir ? kFloatMin : kFloatMax; + BlockSelect< + float, + IndexT, + Dir, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(kInit, -1, smemK, smemV, k); + + auto sliceId = blockIdx.x; + auto numSlices = gridDim.x; + + IndexT sliceSize = (nprobe / numSlices); + IndexT sliceStart = sliceSize * sliceId; + IndexT sliceEnd = sliceId == (numSlices - 1) + ? nprobe + : sliceStart + sliceSize; + auto offsets = prefixSumOffsets[queryId].data(); + + // We ensure that before the array (at offset -1), there is a 0 + // value + auto start = *(&offsets[sliceStart] - 1); + auto end = offsets[sliceEnd - 1]; + + auto num = end - start; + auto limit = utils::roundDown(num, (IndexT)kWarpSize); + + IndexT i = threadIdx.x; + auto distanceStart = distance[start].data(); + + // BlockSelect add cannot be used in a warp divergent circumstance; + // we handle the remainder warp below + for (; i < limit; i += blockDim.x) { + heap.add(distanceStart[i], IndexT(start + i)); + } + + // Handle the remainder if any separately (warp is divergent) + if (i < num) { + heap.addThreadQ(distanceStart[i], IndexT(start + i)); + } + + // Merge all final results + heap.reduce(); + + // Write out the final k-selected values; they should be all + // together + for (int i = threadIdx.x; i < k; i += blockDim.x) { + heapDistances[queryId][sliceId][i] = smemK[i]; + heapIndices[queryId][sliceId][i] = idx_t(smemV[i]); + } } } } @@ -129,46 +133,46 @@ void runPass1SelectLists( #if GPU_MAX_SELECTION_K >= 2048 // block size 128 for k <= 1024, 64 for k = 2048 -#define RUN_PASS_DIR(INDEX_T, DIR) \ - do { \ - if (k == 1) { \ - RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ - } else if (k <= 32) { \ - RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ - } else if (k <= 64) { \ - RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ - } else if (k <= 128) { \ - RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ - } else if (k <= 256) { \ - RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ - } else if (k <= 512) { \ - RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ - } else if (k <= 1024) { \ - RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ - } else if (k <= 2048) { \ - RUN_PASS(INDEX_T, 64, 2048, 8, DIR); \ - } \ +#define RUN_PASS_DIR(INDEX_T, DIR) \ 
+ do { \ + if (k == 1) { \ + RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { \ + RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ + } else if (k <= 2048) { \ + RUN_PASS(INDEX_T, 64, 2048, 8, DIR); \ + } \ } while (0) #else -#define RUN_PASS_DIR(INDEX_T, DIR) \ - do { \ - if (k == 1) { \ - RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ - } else if (k <= 32) { \ - RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ - } else if (k <= 64) { \ - RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ - } else if (k <= 128) { \ - RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ - } else if (k <= 256) { \ - RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ - } else if (k <= 512) { \ - RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ - } else if (k <= 1024) { \ - RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ - } \ +#define RUN_PASS_DIR(INDEX_T, DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { \ + RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ + } \ } while (0) #endif // GPU_MAX_SELECTION_K diff --git a/faiss/gpu/impl/IVFUtilsSelect2.cu b/faiss/gpu/impl/IVFUtilsSelect2.cu index 6587069a71..9141a19959 100644 --- a/faiss/gpu/impl/IVFUtilsSelect2.cu +++ b/faiss/gpu/impl/IVFUtilsSelect2.cu @@ -62,92 +62,95 @@ __global__ void pass2SelectLists( IndicesOptions opt, Tensor outDistances, Tensor outIndices) { - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ float smemK[kNumWarps * NumWarpQ]; - __shared__ IndexT smemV[kNumWarps * NumWarpQ]; - - constexpr auto kInit = Dir ? kFloatMin : kFloatMax; - BlockSelect< - float, - IndexT, - Dir, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(kInit, -1, smemK, smemV, k); - - auto queryId = blockIdx.x; - idx_t num = heapDistances.getSize(1); - idx_t limit = utils::roundDown(num, kWarpSize); - - idx_t i = threadIdx.x; - auto heapDistanceStart = heapDistances[queryId]; - - // BlockSelect add cannot be used in a warp divergent circumstance; we - // handle the remainder warp below - for (; i < limit; i += blockDim.x) { - heap.add(heapDistanceStart[i], IndexT(i)); - } + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ float smemK[kNumWarps * NumWarpQ]; + __shared__ IndexT smemV[kNumWarps * NumWarpQ]; + + constexpr auto kInit = Dir ? 
kFloatMin : kFloatMax; + BlockSelect< + float, + IndexT, + Dir, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(kInit, -1, smemK, smemV, k); + + auto queryId = blockIdx.x; + idx_t num = heapDistances.getSize(1); + idx_t limit = utils::roundDown(num, kWarpSize); + + idx_t i = threadIdx.x; + auto heapDistanceStart = heapDistances[queryId]; + + // BlockSelect add cannot be used in a warp divergent circumstance; we + // handle the remainder warp below + for (; i < limit; i += blockDim.x) { + heap.add(heapDistanceStart[i], IndexT(i)); + } - // Handle warp divergence separately - if (i < num) { - heap.addThreadQ(heapDistanceStart[i], IndexT(i)); - } + // Handle warp divergence separately + if (i < num) { + heap.addThreadQ(heapDistanceStart[i], IndexT(i)); + } - // Merge all final results - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += blockDim.x) { - outDistances[queryId][i] = smemK[i]; - - // `v` is the index in `heapIndices` - // We need to translate this into an original user index. The - // reason why we don't maintain intermediate results in terms of - // user indices is to substantially reduce temporary memory - // requirements and global memory write traffic for the list - // scanning. - // This code is highly divergent, but it's probably ok, since this - // is the very last step and it is happening a small number of - // times (#queries x k). - idx_t v = smemV[i]; - idx_t index = -1; - - if (v != -1) { - // `offset` is the offset of the intermediate result, as - // calculated by the original scan. - idx_t offset = heapIndices[queryId][v]; - - // In order to determine the actual user index, we need to first - // determine what list it was in. - // We do this by binary search in the prefix sum list. - idx_t probe = binarySearchForBucket( - prefixSumOffsets[queryId].data(), - prefixSumOffsets.getSize(1), - offset); - - // This is then the probe for the query; we can find the actual - // list ID from this - idx_t listId = ivfListIds[queryId][probe]; - - // Now, we need to know the offset within the list - // We ensure that before the array (at offset -1), there is a 0 - // value - idx_t listStart = *(prefixSumOffsets[queryId][probe].data() - 1); - idx_t listOffset = offset - listStart; - - // This gives us our final index - if (opt == INDICES_32_BIT) { - index = (idx_t)((int*)listIndices[listId])[listOffset]; - } else if (opt == INDICES_64_BIT) { - index = ((idx_t*)listIndices[listId])[listOffset]; - } else { - index = (listId << 32 | (idx_t)listOffset); + // Merge all final results + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += blockDim.x) { + outDistances[queryId][i] = smemK[i]; + + // `v` is the index in `heapIndices` + // We need to translate this into an original user index. The + // reason why we don't maintain intermediate results in terms of + // user indices is to substantially reduce temporary memory + // requirements and global memory write traffic for the list + // scanning. + // This code is highly divergent, but it's probably ok, since this + // is the very last step and it is happening a small number of + // times (#queries x k). + idx_t v = smemV[i]; + idx_t index = -1; + + if (v != -1) { + // `offset` is the offset of the intermediate result, as + // calculated by the original scan. + idx_t offset = heapIndices[queryId][v]; + + // In order to determine the actual user index, we need to first + // determine what list it was in. + // We do this by binary search in the prefix sum list. 
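For reference, the lookup that binarySearchForBucket performs here can be sketched on the host: given the per-query inclusive prefix sum of scanned-list sizes (with an implicit 0 in front of the array), it returns the probe whose range contains the flat intermediate offset. The helper below is a minimal standalone illustration under those assumptions, not the device routine itself; the function name and the exclusive/inclusive convention are assumptions of this sketch.

#include <cassert>
#include <cstdint>

// Find the bucket whose half-open range [prefixSum[i-1], prefixSum[i])
// contains `offset`, treating prefixSum[-1] as 0.
static int findBucketForOffset(const int64_t* prefixSum, int n, int64_t offset) {
    int lo = 0;
    int hi = n; // search in [lo, hi)
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (offset < prefixSum[mid]) {
            hi = mid;     // offset ends at or before this bucket
        } else {
            lo = mid + 1; // offset lies past this bucket's end
        }
    }
    return lo; // first bucket whose inclusive prefix sum exceeds offset
}

int main() {
    // Three probed lists contributed 4, 0 and 3 results: prefix sums 4, 4, 7.
    const int64_t ps[] = {4, 4, 7};
    assert(findBucketForOffset(ps, 3, 0) == 0);
    assert(findBucketForOffset(ps, 3, 3) == 0);
    assert(findBucketForOffset(ps, 3, 4) == 2); // the empty list 1 is skipped
    assert(findBucketForOffset(ps, 3, 6) == 2);
    return 0;
}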
+ idx_t probe = binarySearchForBucket( + prefixSumOffsets[queryId].data(), + prefixSumOffsets.getSize(1), + offset); + + // This is then the probe for the query; we can find the actual + // list ID from this + idx_t listId = ivfListIds[queryId][probe]; + + // Now, we need to know the offset within the list + // We ensure that before the array (at offset -1), there is a 0 + // value + idx_t listStart = + *(prefixSumOffsets[queryId][probe].data() - 1); + idx_t listOffset = offset - listStart; + + // This gives us our final index + if (opt == INDICES_32_BIT) { + index = (idx_t)((int*)listIndices[listId])[listOffset]; + } else if (opt == INDICES_64_BIT) { + index = ((idx_t*)listIndices[listId])[listOffset]; + } else { + index = (listId << 32 | (idx_t)listOffset); + } } - } - outIndices[queryId][i] = index; + outIndices[queryId][i] = index; + } } } @@ -187,46 +190,46 @@ void runPass2SelectLists( #if GPU_MAX_SELECTION_K >= 2048 // block size 128 for k <= 1024, 64 for k = 2048 -#define RUN_PASS_DIR(INDEX_T, DIR) \ - do { \ - if (k == 1) { \ - RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ - } else if (k <= 32) { \ - RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ - } else if (k <= 64) { \ - RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ - } else if (k <= 128) { \ - RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ - } else if (k <= 256) { \ - RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ - } else if (k <= 512) { \ - RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ - } else if (k <= 1024) { \ - RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ - } else if (k <= 2048) { \ - RUN_PASS(INDEX_T, 64, 2048, 8, DIR); \ - } \ +#define RUN_PASS_DIR(INDEX_T, DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { \ + RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ + } else if (k <= 2048) { \ + RUN_PASS(INDEX_T, 64, 2048, 8, DIR); \ + } \ } while (0) #else -#define RUN_PASS_DIR(INDEX_T, DIR) \ - do { \ - if (k == 1) { \ - RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ - } else if (k <= 32) { \ - RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ - } else if (k <= 64) { \ - RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ - } else if (k <= 128) { \ - RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ - } else if (k <= 256) { \ - RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ - } else if (k <= 512) { \ - RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ - } else if (k <= 1024) { \ - RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ - } \ +#define RUN_PASS_DIR(INDEX_T, DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(INDEX_T, 128, 1, 1, DIR); \ + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { \ + RUN_PASS(INDEX_T, 128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(INDEX_T, 128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(INDEX_T, 128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(INDEX_T, 128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(INDEX_T, 128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(INDEX_T, 128, 1024, 8, DIR); \ + } \ } while (0) #endif // GPU_MAX_SELECTION_K diff --git a/faiss/gpu/impl/IcmEncoder.cu b/faiss/gpu/impl/IcmEncoder.cu index 81f234f2b6..13cb297282 100644 --- a/faiss/gpu/impl/IcmEncoder.cu +++ b/faiss/gpu/impl/IcmEncoder.cu @@ -313,9 +313,10 @@ void IcmEncoderImpl::encode( 
res.get(), makeTempAlloc(AllocType::Other, stream), {n}); // compute how much shared memory we need - const int evaluateSmem = sizeof(float) * (dims + kWarpSize - 1) / kWarpSize; + int warpSize = getWarpSizeCurrentDevice(); + const int evaluateSmem = sizeof(float) * (dims + warpSize - 1) / warpSize; const int encodeSmem = - sizeof(Pair) * (K + kWarpSize - 1) / kWarpSize; + sizeof(Pair) * (K + warpSize - 1) / warpSize; // compute the reconstruction error for each vector runEvaluation<<>>( diff --git a/faiss/gpu/impl/InterleavedCodes.cpp b/faiss/gpu/impl/InterleavedCodes.cpp index 801ff72180..bd9464d5c8 100644 --- a/faiss/gpu/impl/InterleavedCodes.cpp +++ b/faiss/gpu/impl/InterleavedCodes.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -166,15 +167,16 @@ void unpackInterleavedWord( int numVecs, int dims, int bitsPerCode) { - int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T)); + int warpSize = getWarpSizeCurrentDevice(); + int wordsPerDimBlock = warpSize * bitsPerCode / (8 * sizeof(T)); int wordsPerBlock = wordsPerDimBlock * dims; - int numBlocks = utils::divUp(numVecs, 32); + int numBlocks = utils::divUp(numVecs, warpSize); #pragma omp parallel for for (int i = 0; i < numVecs; ++i) { - int block = i / 32; + int block = i / warpSize; FAISS_ASSERT(block < numBlocks); - int lane = i % 32; + int lane = i % warpSize; for (int j = 0; j < dims; ++j) { int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane; @@ -188,9 +190,10 @@ std::vector unpackInterleaved( int numVecs, int dims, int bitsPerCode) { - int bytesPerDimBlock = 32 * bitsPerCode / 8; + int warpSize = getWarpSizeCurrentDevice(); + int bytesPerDimBlock = warpSize * bitsPerCode / 8; int bytesPerBlock = bytesPerDimBlock * dims; - int numBlocks = utils::divUp(numVecs, 32); + int numBlocks = utils::divUp(numVecs, warpSize); size_t totalSize = (size_t)bytesPerBlock * numBlocks; FAISS_ASSERT(data.size() == totalSize); @@ -217,8 +220,8 @@ std::vector unpackInterleaved( } else if (bitsPerCode == 4) { #pragma omp parallel for for (int i = 0; i < numVecs; ++i) { - int block = i / 32; - int lane = i % 32; + int block = i / warpSize; + int lane = i % warpSize; int word = lane / 2; int subWord = lane % 2; @@ -235,8 +238,8 @@ std::vector unpackInterleaved( } else if (bitsPerCode == 5) { #pragma omp parallel for for (int i = 0; i < numVecs; ++i) { - int block = i / 32; - int blockVector = i % 32; + int block = i / warpSize; + int blockVector = i % warpSize; for (int j = 0; j < dims; ++j) { uint8_t* dimBlock = @@ -257,8 +260,8 @@ std::vector unpackInterleaved( } else if (bitsPerCode == 6) { #pragma omp parallel for for (int i = 0; i < numVecs; ++i) { - int block = i / 32; - int blockVector = i % 32; + int block = i / warpSize; + int blockVector = i % warpSize; for (int j = 0; j < dims; ++j) { uint8_t* dimBlock = @@ -442,17 +445,18 @@ void packInterleavedWord( int numVecs, int dims, int bitsPerCode) { - int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T)); + int warpSize = getWarpSizeCurrentDevice(); + int wordsPerDimBlock = warpSize * bitsPerCode / (8 * sizeof(T)); int wordsPerBlock = wordsPerDimBlock * dims; - int numBlocks = utils::divUp(numVecs, 32); + int numBlocks = utils::divUp(numVecs, warpSize); // We're guaranteed that all other slots not filled by the vectors present // are initialized to zero (from the vector constructor in packInterleaved) #pragma omp parallel for for (int i = 0; i < numVecs; ++i) { - int block = i / 32; + int block = i / warpSize; FAISS_ASSERT(block < numBlocks); - int lane = i % 32; + 
int lane = i % warpSize; for (int j = 0; j < dims; ++j) { int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane; @@ -466,9 +470,10 @@ std::vector packInterleaved( int numVecs, int dims, int bitsPerCode) { - int bytesPerDimBlock = 32 * bitsPerCode / 8; + int warpSize = getWarpSizeCurrentDevice(); + int bytesPerDimBlock = warpSize * bitsPerCode / 8; int bytesPerBlock = bytesPerDimBlock * dims; - int numBlocks = utils::divUp(numVecs, 32); + int numBlocks = utils::divUp(numVecs, warpSize); size_t totalSize = (size_t)bytesPerBlock * numBlocks; // bit codes padded to whole bytes @@ -499,7 +504,7 @@ std::vector packInterleaved( for (int i = 0; i < numBlocks; ++i) { for (int j = 0; j < dims; ++j) { for (int k = 0; k < bytesPerDimBlock; ++k) { - int loVec = i * 32 + k * 2; + int loVec = i * warpSize + k * 2; int hiVec = loVec + 1; uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0; @@ -516,7 +521,7 @@ std::vector packInterleaved( for (int j = 0; j < dims; ++j) { for (int k = 0; k < bytesPerDimBlock; ++k) { // What input vectors we are pulling from - int loVec = i * 32 + (k * 8) / 5; + int loVec = i * warpSize + (k * 8) / 5; int hiVec = loVec + 1; int hiVec2 = hiVec + 1; @@ -536,7 +541,7 @@ std::vector packInterleaved( for (int j = 0; j < dims; ++j) { for (int k = 0; k < bytesPerDimBlock; ++k) { // What input vectors we are pulling from - int loVec = i * 32 + (k * 8) / 6; + int loVec = i * warpSize + (k * 8) / 6; int hiVec = loVec + 1; uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0; diff --git a/faiss/gpu/impl/L2Norm.cu b/faiss/gpu/impl/L2Norm.cu index e0db8e2b69..1834048d23 100644 --- a/faiss/gpu/impl/L2Norm.cu +++ b/faiss/gpu/impl/L2Norm.cu @@ -199,6 +199,7 @@ void runL2Norm( // Row-major kernel /// + int warpSize = getWarpSizeCurrentDevice(); if (input.template canCastResize()) { // Can load using the vectorized type auto inputV = input.template castResize(); @@ -212,7 +213,7 @@ void runL2Norm( auto block = dim3(numThreads); auto smem = sizeof(float) * rowTileSize * - utils::divUp(numThreads, kWarpSize); + utils::divUp(numThreads, warpSize); RUN_L2_ROW_MAJOR(T, TVec, inputV); } else { @@ -227,7 +228,7 @@ void runL2Norm( auto block = dim3(numThreads); auto smem = sizeof(float) * rowTileSize * - utils::divUp(numThreads, kWarpSize); + utils::divUp(numThreads, warpSize); RUN_L2_ROW_MAJOR(T, T, input); } diff --git a/faiss/gpu/impl/L2Norm.cuh b/faiss/gpu/impl/L2Norm.cuh index db4b3dd148..c0447b91c1 100644 --- a/faiss/gpu/impl/L2Norm.cuh +++ b/faiss/gpu/impl/L2Norm.cuh @@ -7,6 +7,7 @@ #pragma once +#include #include namespace faiss { diff --git a/faiss/gpu/impl/L2Select.cu b/faiss/gpu/impl/L2Select.cu index 91c9ee1fda..660c513933 100644 --- a/faiss/gpu/impl/L2Select.cu +++ b/faiss/gpu/impl/L2Select.cu @@ -141,45 +141,47 @@ __global__ void l2SelectMinK( Tensor outIndices, int k, T initK) { - // Each block handles a single row of the distances (results) - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ T smemK[kNumWarps * NumWarpQ]; - __shared__ IndexT smemV[kNumWarps * NumWarpQ]; - - BlockSelect< - T, - IndexT, - false, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(initK, -1, smemK, smemV, k); - - IndexT row = blockIdx.x; - - // Whole warps must participate in the selection - IndexT limit = utils::roundDown(productDistances.getSize(1), kWarpSize); - IndexT i = threadIdx.x; - - for (; i < limit; i += blockDim.x) { - T v = Math::add(centroidDistances[i], productDistances[row][i]); - heap.add(v, IndexT(i)); - } + if constexpr 
((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + // Each block handles a single row of the distances (results) + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - // Handle the remainder if any separately (warp is divergent) - if (i < productDistances.getSize(1)) { - T v = Math::add(centroidDistances[i], productDistances[row][i]); - heap.addThreadQ(v, IndexT(i)); - } + __shared__ T smemK[kNumWarps * NumWarpQ]; + __shared__ IndexT smemV[kNumWarps * NumWarpQ]; + + BlockSelect< + T, + IndexT, + false, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(initK, -1, smemK, smemV, k); + + IndexT row = blockIdx.x; - // Merge all final results - heap.reduce(); + // Whole warps must participate in the selection + IndexT limit = utils::roundDown(productDistances.getSize(1), kWarpSize); + IndexT i = threadIdx.x; - for (int i = threadIdx.x; i < k; i += blockDim.x) { - outDistances[row][i] = smemK[i]; - outIndices[row][i] = idx_t(smemV[i]); + for (; i < limit; i += blockDim.x) { + T v = Math::add(centroidDistances[i], productDistances[row][i]); + heap.add(v, IndexT(i)); + } + + // Handle the remainder if any separately (warp is divergent) + if (i < productDistances.getSize(1)) { + T v = Math::add(centroidDistances[i], productDistances[row][i]); + heap.addThreadQ(v, IndexT(i)); + } + + // Merge all final results + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += blockDim.x) { + outDistances[row][i] = smemK[i]; + outIndices[row][i] = idx_t(smemV[i]); + } } } @@ -241,7 +243,7 @@ void runL2SelectMin( } while (false) // block size 128 for everything <= 1024 - if (k <= 32) { + if (k <= 32 && getWarpSizeCurrentDevice() == 32) { RUN_L2_SELECT(128, 32, 2); } else if (k <= 64) { RUN_L2_SELECT(128, 64, 3); diff --git a/faiss/gpu/impl/PQCodeDistances-inl.cuh b/faiss/gpu/impl/PQCodeDistances-inl.cuh index a05143e9d0..4306426e2d 100644 --- a/faiss/gpu/impl/PQCodeDistances-inl.cuh +++ b/faiss/gpu/impl/PQCodeDistances-inl.cuh @@ -20,6 +20,12 @@ namespace faiss { namespace gpu { +#if defined(USE_ROCM) && __AMDGCN_WAVEFRONT_SIZE == 64u +#define LAUNCH_BOUND 320 +#else +#define LAUNCH_BOUND 288 +#endif + // Kernel responsible for calculating distance from residual vector to // each product quantizer code centroid template < @@ -27,7 +33,7 @@ template < typename CentroidT, int DimsPerSubQuantizer, bool L2Distance> -__global__ void __launch_bounds__(288, 3) pqCodeDistances( +__global__ void __launch_bounds__(LAUNCH_BOUND, 3) pqCodeDistances( Tensor queries, int queriesPerBlock, Tensor coarseCentroids, @@ -632,7 +638,8 @@ void runPQCodeDistances( // Reserve one block of threads for double buffering // FIXME: probably impractical for large # of dims? 
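This is the host-side pattern applied throughout the patch: sizes that used to assume a 32-lane warp are now derived from the warp width reported by the device, which is 64 on most ROCm wavefronts. A minimal standalone sketch follows, with a plain integer standing in for getWarpSizeCurrentDevice() and the usual divUp/roundUp semantics assumed.

#include <cassert>

static int divUp(int a, int b) {
    return (a + b - 1) / b;
}
static int roundUp(int a, int b) {
    return divUp(a, b) * b;
}

int main() {
    const int dimsPerSubQuantizer = 48;
    // 32-lane warps (CUDA) vs 64-lane wavefronts (typical ROCm target)
    assert(roundUp(dimsPerSubQuantizer, 32) == 64);
    assert(roundUp(dimsPerSubQuantizer, 64) == 64);
    // warp-count based sizing for a 128-thread block
    assert(divUp(128, 32) == 4);
    assert(divUp(128, 64) == 2);
    return 0;
}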
- auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize); + int warpSize = getWarpSizeCurrentDevice(); + auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, warpSize); auto block = dim3(codesPerSubQuantizer + loadingThreads); auto smem = (3 * dimsPerSubQuantizer) * sizeof(float) + diff --git a/faiss/gpu/impl/PQCodeLoad.cuh b/faiss/gpu/impl/PQCodeLoad.cuh index e9b8fc9c92..a37e908a1d 100644 --- a/faiss/gpu/impl/PQCodeLoad.cuh +++ b/faiss/gpu/impl/PQCodeLoad.cuh @@ -47,6 +47,237 @@ inline __device__ unsigned int getByte(uint64_t v, int pos, int width) { return getBitfield(v, pos, width); } +#ifdef USE_ROCM + +template +struct LoadCode32 {}; + +template <> +struct LoadCode32<1> { + static inline __device__ void load( + unsigned int code32[1], + uint8_t* p, + int offset) { + p += offset * 1; + // using T = uint8_t __attribute__((ext_vector_type(1))); + // T* t = reinterpret_cast(p); + uint8_t* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(p); + } +}; + +template <> +struct LoadCode32<2> { + static inline __device__ void load( + unsigned int code32[1], + uint8_t* p, + int offset) { + p += offset * 2; + using T = uint8_t __attribute__((ext_vector_type(2))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<3> { + static inline __device__ void load( + unsigned int code32[1], + uint8_t* p, + int offset) { + p += offset * 3; + using T = uint8_t __attribute__((ext_vector_type(3))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[1] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<4> { + static inline __device__ void load( + unsigned int code32[1], + uint8_t* p, + int offset) { + p += offset * 4; + using T = uint32_t __attribute__((ext_vector_type(1))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<8> { + static inline __device__ void load( + unsigned int code32[2], + uint8_t* p, + int offset) { + p += offset * 8; + using T = uint32_t __attribute__((ext_vector_type(2))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<12> { + static inline __device__ void load( + unsigned int code32[3], + uint8_t* p, + int offset) { + p += offset * 12; + using T = uint32_t __attribute__((ext_vector_type(3))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<16> { + static inline __device__ void load( + unsigned int code32[4], + uint8_t* p, + int offset) { + p += offset * 16; + using T = uint32_t __attribute__((ext_vector_type(4))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<20> { + static inline __device__ void load( + unsigned int code32[5], + uint8_t* p, + int offset) { + p += offset * 20; + using T = uint32_t __attribute__((ext_vector_type(5))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<24> { + static inline __device__ void load( + unsigned int code32[6], + uint8_t* p, + int offset) { + p += offset * 24; + using T = uint32_t __attribute__((ext_vector_type(6))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = 
__builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<28> { + static inline __device__ void load( + unsigned int code32[7], + uint8_t* p, + int offset) { + p += offset * 28; + using T = uint32_t __attribute__((ext_vector_type(7))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<32> { + static inline __device__ void load( + unsigned int code32[8], + uint8_t* p, + int offset) { + p += offset * 32; + using T = uint32_t __attribute__((ext_vector_type(8))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<40> { + static inline __device__ void load( + unsigned int code32[10], + uint8_t* p, + int offset) { + p += offset * 40; + using T = uint32_t __attribute__((ext_vector_type(10))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<48> { + static inline __device__ void load( + unsigned int code32[12], + uint8_t* p, + int offset) { + p += offset * 48; + using T = uint32_t __attribute__((ext_vector_type(12))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<56> { + static inline __device__ void load( + unsigned int code32[14], + uint8_t* p, + int offset) { + p += offset * 56; + using T = uint32_t __attribute__((ext_vector_type(14))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<64> { + static inline __device__ void load( + unsigned int code32[16], + uint8_t* p, + int offset) { + p += offset * 64; + using T = uint32_t __attribute__((ext_vector_type(16))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +template <> +struct LoadCode32<96> { + static inline __device__ void load( + unsigned int code32[24], + uint8_t* p, + int offset) { + p += offset * 96; + using T = uint32_t __attribute__((ext_vector_type(24))); + T* t = reinterpret_cast(p); + T* u = reinterpret_cast(code32); + u[0] = __builtin_nontemporal_load(t); + } +}; + +#else // USE_ROCM + template struct LoadCode32 {}; @@ -378,5 +609,7 @@ struct LoadCode32<96> { } }; +#endif // USE_ROCM + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh b/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh index a162beed85..8a1fd82c32 100644 --- a/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh +++ b/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh @@ -61,11 +61,11 @@ __global__ void pqScanInterleaved( auto vecsBase = (EncodeT*)listCodes[listId]; auto numVecs = listLengths[listId]; - // How many vector blocks of 32 are in this list? - idx_t numBlocks = utils::divUp(numVecs, (idx_t)32); + // How many vector blocks of kWarpSize are in this list? 
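The interleaved code layout touched in these scan kernels groups kWarpSize vectors per block, so each dimension of a block occupies kWarpSize * bitsPerCode / 8 bytes and total storage is padded to whole blocks. A standalone recap of that arithmetic for both warp widths; the example sizes are arbitrary.

#include <cassert>

static long long divUp(long long a, long long b) {
    return (a + b - 1) / b;
}

int main() {
    const int bitsPerCode = 8; // one byte per dimension
    const int dims = 16;
    const long long numVecs = 1000;

    const int warpSizes[] = {32, 64};
    for (int warpSize : warpSizes) {
        long long bytesPerDimBlock = 1LL * warpSize * bitsPerCode / 8;
        long long bytesPerBlock = bytesPerDimBlock * dims;
        long long numBlocks = divUp(numVecs, warpSize);
        // storage is padded up to whole blocks of warpSize vectors
        long long totalSize = bytesPerBlock * numBlocks;
        assert(totalSize >= numVecs * dims);
    }
    return 0;
}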
+ idx_t numBlocks = utils::divUp(numVecs, (idx_t)kWarpSize); - // Number of EncodeT words per each dimension of block of 32 vecs - constexpr int bytesPerVectorBlockDim = EncodeBits * 32 / 8; + // Number of EncodeT words per each dimension of block of kWarpSize vecs + constexpr int bytesPerVectorBlockDim = EncodeBits * kWarpSize / 8; constexpr int wordsPerVectorBlockDim = bytesPerVectorBlockDim / sizeof(EncodeT); int wordsPerVectorBlock = wordsPerVectorBlockDim * numSubQuantizers; diff --git a/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu b/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu index 2f31ed9fc2..f50388a08c 100644 --- a/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu +++ b/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu @@ -66,11 +66,11 @@ __global__ void pqScanPrecomputedInterleaved( auto vecsBase = (EncodeT*)listCodes[listId]; idx_t numVecs = listLengths[listId]; - // How many vector blocks of 32 are in this list? - idx_t numBlocks = utils::divUp(numVecs, idx_t(32)); + // How many vector blocks of kWarpSize are in this list? + idx_t numBlocks = utils::divUp(numVecs, idx_t(kWarpSize)); - // Number of EncodeT words per each dimension of block of 32 vecs - constexpr idx_t bytesPerVectorBlockDim = EncodeBits * 32 / 8; + // Number of EncodeT words per each dimension of block of kWarpSize vecs + constexpr idx_t bytesPerVectorBlockDim = EncodeBits * kWarpSize / 8; constexpr idx_t wordsPerVectorBlockDim = bytesPerVectorBlockDim / sizeof(EncodeT); idx_t wordsPerVectorBlock = wordsPerVectorBlockDim * numSubQuantizers; diff --git a/faiss/gpu/impl/VectorResidual.cu b/faiss/gpu/impl/VectorResidual.cu index f0cba9a019..ed24a69e1f 100644 --- a/faiss/gpu/impl/VectorResidual.cu +++ b/faiss/gpu/impl/VectorResidual.cu @@ -8,7 +8,11 @@ #include #include #include +#ifdef USE_ROCM +#define CUDART_NAN_F __int_as_float(0x7fffffff) +#else #include // in CUDA SDK, for CUDART_NAN_F +#endif #include #include #include diff --git a/faiss/gpu/impl/VectorResidual.cuh b/faiss/gpu/impl/VectorResidual.cuh index 2803d188ef..6979f00e3c 100644 --- a/faiss/gpu/impl/VectorResidual.cuh +++ b/faiss/gpu/impl/VectorResidual.cuh @@ -10,6 +10,8 @@ #include #include +#include + namespace faiss { namespace gpu { diff --git a/faiss/gpu/test/CMakeLists.txt b/faiss/gpu/test/CMakeLists.txt index 60f78ef74f..2983ddc219 100644 --- a/faiss/gpu/test/CMakeLists.txt +++ b/faiss/gpu/test/CMakeLists.txt @@ -17,12 +17,15 @@ # the License. # ============================================================================= -find_package(CUDAToolkit REQUIRED) - # Defines `gtest_discover_tests()`. 
include(GoogleTest) add_library(faiss_gpu_test_helper TestUtils.cpp) -target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart $<$:raft::raft> $<$:raft::compiled>) +if(USE_ROCM) + target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest $<$:hip::host>) +else() + find_package(CUDAToolkit REQUIRED) + target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart $<$:raft::raft> $<$:raft::compiled>) +endif() macro(faiss_gpu_test file) get_filename_component(test_name ${file} NAME_WE) @@ -39,8 +42,9 @@ faiss_gpu_test(TestGpuIndexBinaryFlat.cpp) faiss_gpu_test(TestGpuMemoryException.cpp) faiss_gpu_test(TestGpuIndexIVFPQ.cpp) faiss_gpu_test(TestGpuIndexIVFScalarQuantizer.cpp) -faiss_gpu_test(TestGpuDistance.cu) -faiss_gpu_test(TestGpuSelect.cu) +faiss_gpu_test(TestGpuDistance.${GPU_EXT_PREFIX}) +faiss_gpu_test(TestGpuSelect.${GPU_EXT_PREFIX}) + if(FAISS_ENABLE_RAFT) faiss_gpu_test(TestGpuIndexCagra.cu) endif() @@ -48,5 +52,10 @@ endif() add_executable(demo_ivfpq_indexing_gpu EXCLUDE_FROM_ALL demo_ivfpq_indexing_gpu.cpp) -target_link_libraries(demo_ivfpq_indexing_gpu - PRIVATE faiss gtest_main CUDA::cudart) +if (USE_ROCM) + target_link_libraries(demo_ivfpq_indexing_gpu + PRIVATE faiss gtest_main $<$:hip::host>) +else() + target_link_libraries(demo_ivfpq_indexing_gpu + PRIVATE faiss gtest_main CUDA::cudart) +endif() diff --git a/faiss/gpu/test/TestCodePacking.cpp b/faiss/gpu/test/TestCodePacking.cpp index 44720aebe6..5d80150b20 100644 --- a/faiss/gpu/test/TestCodePacking.cpp +++ b/faiss/gpu/test/TestCodePacking.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -119,8 +120,9 @@ TEST(TestCodePacking, InterleavedCodes_UnpackPack) { std::cout << bitsPerCode << " " << dims << " " << numVecs << "\n"; - int blocks = utils::divUp(numVecs, 32); - int bytesPerDimBlock = 32 * bitsPerCode / 8; + int warpSize = getWarpSizeCurrentDevice(); + int blocks = utils::divUp(numVecs, warpSize); + int bytesPerDimBlock = warpSize * bitsPerCode / 8; int bytesPerBlock = bytesPerDimBlock * dims; int size = blocks * bytesPerBlock; @@ -132,9 +134,9 @@ TEST(TestCodePacking, InterleavedCodes_UnpackPack) { for (int i = 0; i < blocks; ++i) { for (int j = 0; j < dims; ++j) { - for (int k = 0; k < 32; ++k) { + for (int k = 0; k < warpSize; ++k) { for (int l = 0; l < bytesPerCode; ++l) { - int vec = i * 32 + k; + int vec = i * warpSize + k; if (vec < numVecs) { data[i * bytesPerBlock + j * bytesPerDimBlock + @@ -148,7 +150,8 @@ TEST(TestCodePacking, InterleavedCodes_UnpackPack) { for (int i = 0; i < blocks; ++i) { for (int j = 0; j < dims; ++j) { for (int k = 0; k < bytesPerDimBlock; ++k) { - int loVec = i * 32 + (k * 8) / bitsPerCode; + int loVec = + i * warpSize + (k * 8) / bitsPerCode; int hiVec = loVec + 1; int hiVec2 = hiVec + 1; diff --git a/faiss/gpu/utils/BlockSelectFloat.cu b/faiss/gpu/utils/BlockSelectFloat.cu index d234cd509e..f76fd6421d 100644 --- a/faiss/gpu/utils/BlockSelectFloat.cu +++ b/faiss/gpu/utils/BlockSelectFloat.cu @@ -55,7 +55,7 @@ void runBlockSelect( if (dir) { if (k == 1) { BLOCK_SELECT_CALL(float, true, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { BLOCK_SELECT_CALL(float, true, 32); } else if (k <= 64) { BLOCK_SELECT_CALL(float, true, 64); @@ -75,7 +75,7 @@ void runBlockSelect( } else { if (k == 1) { BLOCK_SELECT_CALL(float, false, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { BLOCK_SELECT_CALL(float, false, 32); } else if (k <= 64) { 
BLOCK_SELECT_CALL(float, false, 64); @@ -108,7 +108,7 @@ void runBlockSelectPair( if (dir) { if (k == 1) { BLOCK_SELECT_PAIR_CALL(float, true, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { BLOCK_SELECT_PAIR_CALL(float, true, 32); } else if (k <= 64) { BLOCK_SELECT_PAIR_CALL(float, true, 64); @@ -128,7 +128,7 @@ void runBlockSelectPair( } else { if (k == 1) { BLOCK_SELECT_PAIR_CALL(float, false, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { BLOCK_SELECT_PAIR_CALL(float, false, 32); } else if (k <= 64) { BLOCK_SELECT_PAIR_CALL(float, false, 64); diff --git a/faiss/gpu/utils/BlockSelectKernel.cuh b/faiss/gpu/utils/BlockSelectKernel.cuh index 79f8aa0c1c..e4dc8e20d3 100644 --- a/faiss/gpu/utils/BlockSelectKernel.cuh +++ b/faiss/gpu/utils/BlockSelectKernel.cuh @@ -26,45 +26,47 @@ __global__ void blockSelect( K initK, IndexType initV, int k) { - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ K smemK[kNumWarps * NumWarpQ]; - __shared__ IndexType smemV[kNumWarps * NumWarpQ]; - - BlockSelect< - K, - IndexType, - Dir, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - idx_t row = blockIdx.x; - - idx_t i = threadIdx.x; - K* inStart = in[row][i].data(); - - // Whole warps must participate in the selection - idx_t limit = utils::roundDown(in.getSize(1), kWarpSize); - - for (; i < limit; i += ThreadsPerBlock) { - heap.add(*inStart, (IndexType)i); - inStart += ThreadsPerBlock; - } - - // Handle last remainder fraction of a warp of elements - if (i < in.getSize(1)) { - heap.addThreadQ(*inStart, (IndexType)i); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { - outK[row][i] = smemK[i]; - outV[row][i] = smemV[i]; + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect< + K, + IndexType, + Dir, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + idx_t row = blockIdx.x; + + idx_t i = threadIdx.x; + K* inStart = in[row][i].data(); + + // Whole warps must participate in the selection + idx_t limit = utils::roundDown(in.getSize(1), kWarpSize); + + for (; i < limit; i += ThreadsPerBlock) { + heap.add(*inStart, (IndexType)i); + inStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < in.getSize(1)) { + heap.addThreadQ(*inStart, (IndexType)i); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } } } @@ -83,47 +85,49 @@ __global__ void blockSelectPair( K initK, IndexType initV, int k) { - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - - __shared__ K smemK[kNumWarps * NumWarpQ]; - __shared__ IndexType smemV[kNumWarps * NumWarpQ]; - - BlockSelect< - K, - IndexType, - Dir, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - idx_t row = blockIdx.x; - - idx_t i = threadIdx.x; - K* inKStart = inK[row][i].data(); - IndexType* inVStart = inV[row][i].data(); - - // Whole warps must participate in the selection - idx_t limit = utils::roundDown(inK.getSize(1), 
(idx_t)kWarpSize); - - for (; i < limit; i += ThreadsPerBlock) { - heap.add(*inKStart, *inVStart); - inKStart += ThreadsPerBlock; - inVStart += ThreadsPerBlock; - } - - // Handle last remainder fraction of a warp of elements - if (i < inK.getSize(1)) { - heap.addThreadQ(*inKStart, *inVStart); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { - outK[row][i] = smemK[i]; - outV[row][i] = smemV[i]; + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect< + K, + IndexType, + Dir, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + idx_t row = blockIdx.x; + + idx_t i = threadIdx.x; + K* inKStart = inK[row][i].data(); + IndexType* inVStart = inV[row][i].data(); + + // Whole warps must participate in the selection + idx_t limit = utils::roundDown(inK.getSize(1), (idx_t)kWarpSize); + + for (; i < limit; i += ThreadsPerBlock) { + heap.add(*inKStart, *inVStart); + inKStart += ThreadsPerBlock; + inVStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < inK.getSize(1)) { + heap.addThreadQ(*inKStart, *inVStart); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } } } diff --git a/faiss/gpu/utils/DeviceDefs.cuh b/faiss/gpu/utils/DeviceDefs.cuh index bdba8d720b..521f104265 100644 --- a/faiss/gpu/utils/DeviceDefs.cuh +++ b/faiss/gpu/utils/DeviceDefs.cuh @@ -12,6 +12,23 @@ namespace faiss { namespace gpu { +#ifdef USE_ROCM + +#if __AMDGCN_WAVEFRONT_SIZE == 32u +constexpr int kWarpSize = 32; +#else +constexpr int kWarpSize = 64; +#endif + +// This is a memory barrier for intra-warp writes to shared memory. 
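On ROCm the warp ("wavefront") width is fixed per target: 32 on RDNA-class GPUs and 64 on CDNA/GCN, which is why kWarpSize is selected from __AMDGCN_WAVEFRONT_SIZE and why per-block shared-memory sizing changes. Below is a host-side sketch of the same selection, with WAVEFRONT_SIZE standing in for the compiler builtin; the macro name and default are assumptions made for illustration.

#include <cassert>

#ifndef WAVEFRONT_SIZE
#define WAVEFRONT_SIZE 64 // assumed default; real code keys off __AMDGCN_WAVEFRONT_SIZE
#endif

#if WAVEFRONT_SIZE == 32
constexpr int kSketchWarpSize = 32;
#else
constexpr int kSketchWarpSize = 64;
#endif

int main() {
    // A 128-thread block holds 4 warps at width 32 but only 2 at width 64,
    // which is what drives the kNumWarps-based shared-memory sizing.
    constexpr int threadsPerBlock = 128;
    constexpr int numWarps = threadsPerBlock / kSketchWarpSize;
    static_assert(numWarps >= 1, "block must hold at least one warp");
    assert(numWarps == (kSketchWarpSize == 32 ? 4 : 2));
    return 0;
}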
+__forceinline__ __device__ void warpFence() { + __threadfence_block(); +} + +#define GPU_MAX_SELECTION_K 2048 + +#else // USE_ROCM + // We require at least CUDA 8.0 for compilation #if CUDA_VERSION < 8000 #error "CUDA >= 8.0 is required" @@ -39,5 +56,7 @@ __forceinline__ __device__ void warpFence() { #define GPU_MAX_SELECTION_K 1024 #endif +#endif // USE_ROCM + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/utils/DeviceUtils.cu b/faiss/gpu/utils/DeviceUtils.cu index 56f5a25a42..1664e218da 100644 --- a/faiss/gpu/utils/DeviceUtils.cu +++ b/faiss/gpu/utils/DeviceUtils.cu @@ -124,6 +124,14 @@ int getDeviceForAddress(const void* p) { return -1; } +#if USE_ROCM + if (att.type != hipMemoryTypeHost && + att.type != hipMemoryTypeUnregistered) { + return att.device; + } else { + return -1; + } +#else // memoryType is deprecated for CUDA 10.0+ #if CUDA_VERSION < 10000 if (att.memoryType == cudaMemoryTypeHost) { @@ -139,6 +147,7 @@ int getDeviceForAddress(const void* p) { return -1; } #endif +#endif } bool getFullUnifiedMemSupport(int device) { @@ -159,6 +168,15 @@ bool getTensorCoreSupportCurrentDevice() { return getTensorCoreSupport(getCurrentDevice()); } +int getWarpSize(int device) { + const auto& prop = getDeviceProperties(device); + return prop.warpSize; +} + +int getWarpSizeCurrentDevice() { + return getWarpSize(getCurrentDevice()); +} + size_t getFreeMemory(int device) { DeviceScope scope(device); diff --git a/faiss/gpu/utils/DeviceUtils.h b/faiss/gpu/utils/DeviceUtils.h index 0bbfca8206..dee13079c1 100644 --- a/faiss/gpu/utils/DeviceUtils.h +++ b/faiss/gpu/utils/DeviceUtils.h @@ -76,6 +76,12 @@ bool getTensorCoreSupport(int device); /// Equivalent to getTensorCoreSupport(getCurrentDevice()) bool getTensorCoreSupportCurrentDevice(); +/// Returns the warp size of the given GPU device +int getWarpSize(int device); + +/// Equivalent to getWarpSize(getCurrentDevice()) +int getWarpSizeCurrentDevice(); + /// Returns the amount of currently available memory on the given device size_t getFreeMemory(int device); diff --git a/faiss/gpu/utils/Float16.cuh b/faiss/gpu/utils/Float16.cuh index 42fb9878b9..3a676538c3 100644 --- a/faiss/gpu/utils/Float16.cuh +++ b/faiss/gpu/utils/Float16.cuh @@ -12,7 +12,7 @@ #include // Some compute capabilities have full float16 ALUs. 
-#if __CUDA_ARCH__ >= 530 +#if __CUDA_ARCH__ >= 530 || defined(USE_ROCM) #define FAISS_USE_FULL_FLOAT16 1 #endif // __CUDA_ARCH__ types diff --git a/faiss/gpu/utils/Limits.cuh b/faiss/gpu/utils/Limits.cuh index c136375474..7a65fbf1b6 100644 --- a/faiss/gpu/utils/Limits.cuh +++ b/faiss/gpu/utils/Limits.cuh @@ -33,7 +33,7 @@ struct Limits { }; inline __device__ __host__ half kGetHalf(unsigned short v) { -#if CUDA_VERSION >= 9000 +#if CUDA_VERSION >= 9000 || defined(USE_ROCM) __half_raw h; h.x = v; return __half(h); diff --git a/faiss/gpu/utils/LoadStoreOperators.cuh b/faiss/gpu/utils/LoadStoreOperators.cuh index a342c40075..e00c5d85df 100644 --- a/faiss/gpu/utils/LoadStoreOperators.cuh +++ b/faiss/gpu/utils/LoadStoreOperators.cuh @@ -23,6 +23,51 @@ namespace faiss { namespace gpu { +#ifdef USE_ROCM + +template +struct LoadStore { + static inline __device__ T load(void* p) { + return *((T*)p); + } + + static inline __device__ void store(void* p, const T& v) { + *((T*)p) = v; + } +}; + +template <> +struct LoadStore { + static inline __device__ Half4 load(void* p) { + Half4 out; + Half4* t = reinterpret_cast(p); + out = *t; + return out; + } + + static inline __device__ void store(void* p, Half4& v) { + Half4* t = reinterpret_cast(p); + *t = v; + } +}; + +template <> +struct LoadStore { + static inline __device__ Half8 load(void* p) { + Half8 out; + Half8* t = reinterpret_cast(p); + out = *t; + return out; + } + + static inline __device__ void store(void* p, Half8& v) { + Half8* t = reinterpret_cast(p); + *t = v; + } +}; + +#else // USE_ROCM + template struct LoadStore { static inline __device__ T load(void* p) { @@ -97,5 +142,7 @@ struct LoadStore { } }; +#endif // USE_ROCM + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/utils/MathOperators.cuh b/faiss/gpu/utils/MathOperators.cuh index 2b2bf7d64b..87f779feba 100644 --- a/faiss/gpu/utils/MathOperators.cuh +++ b/faiss/gpu/utils/MathOperators.cuh @@ -282,7 +282,7 @@ struct Math { } static inline __device__ half zero() { -#if CUDA_VERSION >= 9000 +#if CUDA_VERSION >= 9000 || defined(USE_ROCM) return 0; #else half h; diff --git a/faiss/gpu/utils/MatrixMult-inl.cuh b/faiss/gpu/utils/MatrixMult-inl.cuh index da4b201f0d..ce9922e071 100644 --- a/faiss/gpu/utils/MatrixMult-inl.cuh +++ b/faiss/gpu/utils/MatrixMult-inl.cuh @@ -20,6 +20,17 @@ namespace gpu { template struct GetCudaType; +#ifdef USE_ROCM +template <> +struct GetCudaType { + static constexpr hipblasDatatype_t Type = HIPBLAS_R_32F; +}; + +template <> +struct GetCudaType { + static constexpr hipblasDatatype_t Type = HIPBLAS_R_16F; +}; +#else template <> struct GetCudaType { static constexpr cudaDataType_t Type = CUDA_R_32F; @@ -29,6 +40,7 @@ template <> struct GetCudaType { static constexpr cudaDataType_t Type = CUDA_R_16F; }; +#endif template cublasStatus_t rawGemm( @@ -49,6 +61,29 @@ cublasStatus_t rawGemm( auto cAT = GetCudaType::Type; auto cBT = GetCudaType::Type; +#ifdef USE_ROCM + return hipblasGemmEx( + handle, + transa, + transb, + m, + n, + k, + &fAlpha, + A, + cAT, + lda, + B, + cBT, + ldb, + &fBeta, + C, + HIPBLAS_R_32F, + ldc, + HIPBLAS_R_32F, + HIPBLAS_GEMM_DEFAULT); +#else + // FIXME: some weird CUDA 11 bug? where cublasSgemmEx on // f16 (8, 64) x f16 (64, 64)' = f32 (8, 64) returns "not supported". 
// cublasGemmEx using CUBLAS_COMPUTE_32F also fails, but @@ -100,6 +135,7 @@ cublasStatus_t rawGemm( C, CUDA_R_32F, ldc); +#endif // USE_ROCM } template @@ -126,6 +162,32 @@ cublasStatus_t rawBatchGemm( auto cBT = GetCudaType::Type; // Always accumulate in f32 +#ifdef USE_ROCM + return hipblasGemmStridedBatchedEx( + handle, + transa, + transb, + m, + n, + k, + &fAlpha, + A, + cAT, + lda, + strideA, + B, + cBT, + ldb, + strideB, + &fBeta, + C, + HIPBLAS_R_32F, + ldc, + strideC, + batchCount, + HIPBLAS_R_32F, + HIPBLAS_GEMM_DEFAULT); +#else return cublasGemmStridedBatchedEx( handle, transa, @@ -150,6 +212,7 @@ cublasStatus_t rawBatchGemm( batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); +#endif } template diff --git a/faiss/gpu/utils/MergeNetworkWarp.cuh b/faiss/gpu/utils/MergeNetworkWarp.cuh index fcd1f648d8..4f64d490c0 100644 --- a/faiss/gpu/utils/MergeNetworkWarp.cuh +++ b/faiss/gpu/utils/MergeNetworkWarp.cuh @@ -164,7 +164,11 @@ template struct BitonicMergeStep { static inline __device__ void merge(K k[1], V v[1]) { // Use warp shuffles - warpBitonicMergeLE16(k[0], v[0]); + if constexpr (kWarpSize == 32) { + warpBitonicMergeLE16(k[0], v[0]); + } else { + warpBitonicMergeLE16(k[0], v[0]); + } } }; @@ -529,13 +533,17 @@ struct BitonicSortStep { static inline __device__ void sort(K k[1], V v[1]) { // Update this code if this changes // should go from 1 -> kWarpSize in multiples of 2 - static_assert(kWarpSize == 32, "unexpected warp size"); + static_assert( + kWarpSize == 32 || kWarpSize == 64, "unexpected warp size"); warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); + if constexpr (kWarpSize == 64) { + warpBitonicMergeLE16(k[0], v[0]); + } } }; diff --git a/faiss/gpu/utils/PtxUtils.cuh b/faiss/gpu/utils/PtxUtils.cuh index 35e23dd632..a6617aa0b6 100644 --- a/faiss/gpu/utils/PtxUtils.cuh +++ b/faiss/gpu/utils/PtxUtils.cuh @@ -8,10 +8,51 @@ #pragma once #include +#ifdef USE_ROCM +#include +#endif namespace faiss { namespace gpu { +#ifdef USE_ROCM + +#define GET_BITFIELD_U32(OUT, VAL, POS, LEN) \ + do { \ + OUT = getBitfield((uint32_t)VAL, POS, LEN); \ + } while (0) + +#define GET_BITFIELD_U64(OUT, VAL, POS, LEN) \ + do { \ + OUT = getBitfield((uint64_t)VAL, POS, LEN); \ + } while (0) + +__device__ __forceinline__ uint32_t +getBitfield(uint32_t val, int pos, int len) { + return __bitextract_u32(val, pos, len); +} + +__device__ __forceinline__ uint64_t +getBitfield(uint64_t val, int pos, int len) { + return __bitextract_u64(val, pos, len); +} + +__device__ __forceinline__ unsigned int setBitfield( + unsigned int val, + unsigned int toInsert, + int pos, + int len) { + unsigned int ret{0}; + printf("Runtime Error of %s: Unimplemented\n", __PRETTY_FUNCTION__); + return ret; +} + +__device__ __forceinline__ int getLaneId() { + return ::__lane_id(); +} + +#else // USE_ROCM + // defines to simplify the SASS assembly structure file/line in the profiler #define GET_BITFIELD_U32(OUT, VAL, POS, LEN) \ asm("bfe.u32 %0, %1, %2, %3;" : "=r"(OUT) : "r"(VAL), "r"(POS), "r"(LEN)); @@ -88,5 +129,7 @@ __device__ __forceinline__ void namedBarrierArrived(int name, int numThreads) { : "memory"); } +#endif // USE_ROCM + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/utils/Select.cuh b/faiss/gpu/utils/Select.cuh index 4ad89239bb..f4d05a6cc8 100644 --- a/faiss/gpu/utils/Select.cuh +++ b/faiss/gpu/utils/Select.cuh @@ -207,10 +207,10 @@ struct BlockSelect { __device__ inline void 
checkThreadQ() { bool needSort = (numVals == NumThreadQ); -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else +#if CUDA_VERSION < 9000 || defined(USE_ROCM) needSort = __any(needSort); +#else + needSort = __any_sync(0xffffffff, needSort); #endif if (!needSort) { @@ -484,10 +484,10 @@ struct WarpSelect { __device__ inline void checkThreadQ() { bool needSort = (numVals == NumThreadQ); -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else +#if CUDA_VERSION < 9000 || defined(USE_ROCM) needSort = __any(needSort); +#else + needSort = __any_sync(0xffffffff, needSort); #endif if (!needSort) { diff --git a/faiss/gpu/utils/Tensor.cuh b/faiss/gpu/utils/Tensor.cuh index 0fbb2417b3..5cfb19c02c 100644 --- a/faiss/gpu/utils/Tensor.cuh +++ b/faiss/gpu/utils/Tensor.cuh @@ -469,7 +469,7 @@ class SubTensor { /// Use the texture cache for reads __device__ inline typename TensorType::DataType ldg() const { -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) return __ldg(data_); #else return *data_; @@ -479,7 +479,7 @@ class SubTensor { /// Use the texture cache for reads; cast as a particular type template __device__ inline T ldgAs() const { -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) return __ldg(dataAs()); #else return as(); @@ -605,7 +605,7 @@ class SubTensor { /// Use the texture cache for reads __device__ inline typename TensorType::DataType ldg() const { -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) return __ldg(data_); #else return *data_; @@ -615,7 +615,7 @@ class SubTensor { /// Use the texture cache for reads; cast as a particular type template __device__ inline T ldgAs() const { -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) return __ldg(dataAs()); #else return as(); diff --git a/faiss/gpu/utils/Transpose.cuh b/faiss/gpu/utils/Transpose.cuh index 2d0303106c..f5aaa8551e 100644 --- a/faiss/gpu/utils/Transpose.cuh +++ b/faiss/gpu/utils/Transpose.cuh @@ -84,7 +84,7 @@ __global__ void transposeAny( auto inputOffset = TensorInfoOffset::get(input, i); auto outputOffset = TensorInfoOffset::get(output, i); -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) output.data[outputOffset] = __ldg(&input.data[inputOffset]); #else output.data[outputOffset] = input.data[inputOffset]; diff --git a/faiss/gpu/utils/WarpPackedBits.cuh b/faiss/gpu/utils/WarpPackedBits.cuh index c079874ad9..ebb7806a75 100644 --- a/faiss/gpu/utils/WarpPackedBits.cuh +++ b/faiss/gpu/utils/WarpPackedBits.cuh @@ -45,7 +45,7 @@ struct WarpPackedBits { uint8_t v, bool valid, uint8_t* out) { - // Lower 24 lanes wwrite out packed data + // Lower kWarpSize*3/4 lanes (24 or 48) write out packed data int laneFrom = (laneId * 8) / 6; v = valid ? v : 0; @@ -80,7 +80,7 @@ struct WarpPackedBits { break; } - if (laneId < 24) { + if (laneId < kWarpSize * 3 / 4) { // There could be prior data out[laneId] |= vOut; } @@ -89,7 +89,7 @@ struct WarpPackedBits { static inline __device__ uint8_t read(int laneId, uint8_t* in) { uint8_t v = 0; - if (laneId < 24) { + if (laneId < kWarpSize * 3 / 4) { v = in[laneId]; } @@ -242,7 +242,7 @@ struct WarpPackedBits { uint8_t v, bool valid, uint8_t* out) { - // Lower 16 lanes write out packed data + // Lower kWarpSize/2 (16 or 32) lanes write out packed data int laneFrom = laneId * 2; v = valid ? 
v : 0; @@ -254,7 +254,7 @@ struct WarpPackedBits { uint8_t vOut = (vLower & 0xf) | (vUpper << 4); - if (laneId < 16) { + if (laneId < kWarpSize / 2) { // There could be prior data out[laneId] |= vOut; } @@ -263,7 +263,7 @@ struct WarpPackedBits { static inline __device__ uint8_t read(int laneId, uint8_t* in) { uint8_t v = 0; - if (laneId < 16) { + if (laneId < kWarpSize / 2) { v = in[laneId]; } diff --git a/faiss/gpu/utils/WarpSelectFloat.cu b/faiss/gpu/utils/WarpSelectFloat.cu index 89aef1f01d..221dad9703 100644 --- a/faiss/gpu/utils/WarpSelectFloat.cu +++ b/faiss/gpu/utils/WarpSelectFloat.cu @@ -55,7 +55,7 @@ void runWarpSelect( if (dir) { if (k == 1) { WARP_SELECT_CALL(float, true, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { WARP_SELECT_CALL(float, true, 32); } else if (k <= 64) { WARP_SELECT_CALL(float, true, 64); @@ -75,7 +75,7 @@ void runWarpSelect( } else { if (k == 1) { WARP_SELECT_CALL(float, false, 1); - } else if (k <= 32) { + } else if (k <= 32 && getWarpSizeCurrentDevice() == 32) { WARP_SELECT_CALL(float, false, 32); } else if (k <= 64) { WARP_SELECT_CALL(float, false, 64); diff --git a/faiss/gpu/utils/WarpSelectKernel.cuh b/faiss/gpu/utils/WarpSelectKernel.cuh index 55f74a608b..fa7a34109d 100644 --- a/faiss/gpu/utils/WarpSelectKernel.cuh +++ b/faiss/gpu/utils/WarpSelectKernel.cuh @@ -26,43 +26,45 @@ __global__ void warpSelect( K initK, IndexType initV, int k) { - constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + if constexpr ((NumWarpQ == 1 && NumThreadQ == 1) || NumWarpQ >= kWarpSize) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; - WarpSelect< - K, - IndexType, - Dir, - Comparator, - NumWarpQ, - NumThreadQ, - ThreadsPerBlock> - heap(initK, initV, k); + WarpSelect< + K, + IndexType, + Dir, + Comparator, + NumWarpQ, + NumThreadQ, + ThreadsPerBlock> + heap(initK, initV, k); - int warpId = threadIdx.x / kWarpSize; - idx_t row = idx_t(blockIdx.x) * kNumWarps + warpId; + int warpId = threadIdx.x / kWarpSize; + idx_t row = idx_t(blockIdx.x) * kNumWarps + warpId; - if (row >= in.getSize(0)) { - return; - } + if (row >= in.getSize(0)) { + return; + } - idx_t i = getLaneId(); - K* inStart = in[row][i].data(); + idx_t i = getLaneId(); + K* inStart = in[row][i].data(); - // Whole warps must participate in the selection - idx_t limit = utils::roundDown(in.getSize(1), kWarpSize); + // Whole warps must participate in the selection + idx_t limit = utils::roundDown(in.getSize(1), kWarpSize); - for (; i < limit; i += kWarpSize) { - heap.add(*inStart, (IndexType)i); - inStart += kWarpSize; - } + for (; i < limit; i += kWarpSize) { + heap.add(*inStart, (IndexType)i); + inStart += kWarpSize; + } - // Handle non-warp multiple remainder - if (i < in.getSize(1)) { - heap.addThreadQ(*inStart, (IndexType)i); - } + // Handle non-warp multiple remainder + if (i < in.getSize(1)) { + heap.addThreadQ(*inStart, (IndexType)i); + } - heap.reduce(); - heap.writeOut(outK[row].data(), outV[row].data(), k); + heap.reduce(); + heap.writeOut(outK[row].data(), outV[row].data(), k); + } } void runWarpSelect( diff --git a/faiss/gpu/utils/WarpShuffles.cuh b/faiss/gpu/utils/WarpShuffles.cuh index 23dbb437fa..5af6d71ae7 100644 --- a/faiss/gpu/utils/WarpShuffles.cuh +++ b/faiss/gpu/utils/WarpShuffles.cuh @@ -102,6 +102,22 @@ inline __device__ T* shfl_xor( return (T*)shfl_xor(v, laneMask, width); } +#ifdef USE_ROCM + +inline __device__ half shfl(half v, int srcLane, int width = kWarpSize) { + unsigned int vu = __half2uint_rn(v); + vu = __shfl(vu, srcLane, 
width); + return __uint2half_rn(vu); +} + +inline __device__ half shfl_xor(half v, int laneMask, int width = kWarpSize) { + unsigned int vu = __half2uint_rn(v); + vu = __shfl_xor(vu, laneMask, width); + return __uint2half_rn(vu); +} + +#else + // CUDA 9.0+ has half shuffle #if CUDA_VERSION < 9000 inline __device__ half shfl(half v, int srcLane, int width = kWarpSize) { @@ -123,5 +139,7 @@ inline __device__ half shfl_xor(half v, int laneMask, int width = kWarpSize) { } #endif // CUDA_VERSION +#endif // USE_ROCM + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh b/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh index 8362ffefb5..eb49816a62 100644 --- a/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh +++ b/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh @@ -25,9 +25,10 @@ bool dir, \ int k, \ cudaStream_t stream) { \ + int warpSize = getWarpSizeCurrentDevice(); \ constexpr int kWarpSelectNumThreads = 128; \ auto grid = dim3(utils::divUp( \ - in.getSize(0), (kWarpSelectNumThreads / kWarpSize))); \ + in.getSize(0), (kWarpSelectNumThreads / warpSize))); \ auto block = dim3(kWarpSelectNumThreads); \ \ FAISS_ASSERT(k <= WARP_Q); \ diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index d8d933497e..84bc331421 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -38,6 +38,12 @@ macro(configure_swigfaiss source) set_source_files_properties(${source} PROPERTIES COMPILE_DEFINITIONS GPU_WRAPPER ) + if (USE_ROCM) + message(USE_ROCM="${USE_ROCM}") + set_source_files_properties(${source} PROPERTIES + COMPILE_DEFINITIONS USE_ROCM + ) + endif() if (FAISS_ENABLE_RAFT) set_property(SOURCE ${source} APPEND PROPERTY COMPILE_DEFINITIONS FAISS_ENABLE_RAFT @@ -66,12 +72,20 @@ if(TARGET faiss) list(APPEND SWIG_MODULE_swigfaiss_avx512_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/${h}") list(APPEND SWIG_MODULE_swigfaiss_sve_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/${h}") endforeach() - foreach(h ${FAISS_GPU_HEADERS}) - list(APPEND SWIG_MODULE_swigfaiss_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") - list(APPEND SWIG_MODULE_swigfaiss_avx2_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") - list(APPEND SWIG_MODULE_swigfaiss_avx512_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") - list(APPEND SWIG_MODULE_swigfaiss_sve_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") - endforeach() + if(USE_ROCM) + foreach(h ${FAISS_GPU_HEADERS}) + list(APPEND SWIG_MODULE_swigfaiss_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu-rocm/${h}") + list(APPEND SWIG_MODULE_swigfaiss_avx2_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu-rocm/${h}") + list(APPEND SWIG_MODULE_swigfaiss_avx512_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu-rocm/${h}") + endforeach() + else() + foreach(h ${FAISS_GPU_HEADERS}) + list(APPEND SWIG_MODULE_swigfaiss_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") + list(APPEND SWIG_MODULE_swigfaiss_avx2_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") + list(APPEND SWIG_MODULE_swigfaiss_avx512_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") + list(APPEND SWIG_MODULE_swigfaiss_sve_EXTRA_DEPS "${faiss_SOURCE_DIR}/faiss/gpu/${h}") + endforeach() + endif() else() find_package(faiss REQUIRED) endif() @@ -143,14 +157,21 @@ else() endif() if(FAISS_ENABLE_GPU) - find_package(CUDAToolkit REQUIRED) - if(FAISS_ENABLE_RAFT) - find_package(raft COMPONENTS compiled distributed) + if(USE_ROCM) + find_package(HIP REQUIRED) + target_link_libraries(swigfaiss PRIVATE $<$:hip::host>) + target_link_libraries(swigfaiss_avx2 PRIVATE $<$:hip::host>) + 
target_link_libraries(swigfaiss_avx512 PRIVATE $<$:hip::host>) + else() + find_package(CUDAToolkit REQUIRED) + if(FAISS_ENABLE_RAFT) + find_package(raft COMPONENTS compiled distributed) + endif() + target_link_libraries(swigfaiss PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) + target_link_libraries(swigfaiss_avx2 PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) + target_link_libraries(swigfaiss_avx512 PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) + target_link_libraries(swigfaiss_sve PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) endif() - target_link_libraries(swigfaiss PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) - target_link_libraries(swigfaiss_avx2 PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) - target_link_libraries(swigfaiss_avx512 PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) - target_link_libraries(swigfaiss_sve PRIVATE CUDA::cudart $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) endif() find_package(OpenMP REQUIRED) diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index f63d76dc0e..b507843f3c 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -295,6 +295,77 @@ void gpu_profiler_stop(); void gpu_sync_all_devices(); #ifdef GPU_WRAPPER +#ifdef USE_ROCM + +%shared_ptr(faiss::gpu::GpuResources); +%shared_ptr(faiss::gpu::StandardGpuResourcesImpl); + +%{ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int get_num_gpus() +{ + return faiss::gpu::getNumDevices(); +} + +void gpu_profiler_start() +{ + return faiss::gpu::profilerStart(); +} + +void gpu_profiler_stop() +{ + return faiss::gpu::profilerStop(); +} + +void gpu_sync_all_devices() +{ + return faiss::gpu::synchronizeAllDevices(); +} + +%} + +%template() std::pair; +%template() std::map >; +%template() std::map > >; + +// causes weird wrapper bug +%ignore *::allocMemoryHandle; +%ignore faiss::gpu::GpuMemoryReservation; +%ignore faiss::gpu::GpuMemoryReservation::operator=(GpuMemoryReservation&&); +%ignore faiss::gpu::AllocType; + +%include +%include + +%inline %{ + +// interop between pytorch exposed hipStream_t and faiss +hipStream_t cast_integer_to_cudastream_t(int64_t x) { + return (hipStream_t) x; +} + +int64_t cast_cudastream_t_to_integer(hipStream_t x) { + return (int64_t) x; +} + +%} + +#else // USE_ROCM %shared_ptr(faiss::gpu::GpuResources); %shared_ptr(faiss::gpu::StandardGpuResourcesImpl); @@ -367,7 +438,8 @@ int64_t cast_cudastream_t_to_integer(cudaStream_t x) { %} -#else +#endif // USE_ROCM +#else // GPU_WRAPPER %{ int get_num_gpus() @@ -389,7 +461,7 @@ void gpu_sync_all_devices() %} -#endif +#endif // GPU_WRAPPER // order matters because includes are not recursive @@ -559,6 +631,25 @@ struct faiss::simd16uint16 {}; #ifdef GPU_WRAPPER +#ifdef USE_ROCM + +// quiet SWIG warnings +%ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF; + +%include +%include +%include +%include +%include +%include +%include +%include +%include +%include +%include + +#else // USE_ROCM + // quiet SWIG warnings %ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF; @@ -577,7 +668,7 @@ struct faiss::simd16uint16 {}; %include %include - +#endif // USE_ROCM #endif @@ -585,8 +676,6 @@ struct faiss::simd16uint16 {}; - - /******************************************************************* * downcast return of some functions so that the sub-class is used * instead of the generic 
upper-class. @@ -815,6 +904,16 @@ faiss::Quantizer * downcast_Quantizer (faiss::Quantizer *aq) #ifdef GPU_WRAPPER +#ifdef USE_ROCM +%include + +%newobject index_gpu_to_cpu; +%newobject index_cpu_to_gpu; +%newobject index_cpu_to_gpu_multiple; + +%include + +#else // USE_ROCM %include %newobject index_gpu_to_cpu; @@ -823,6 +922,7 @@ faiss::Quantizer * downcast_Quantizer (faiss::Quantizer *aq) %include +#endif // USE_ROCM #endif diff --git a/faiss/utils/hamming_distance/generic-inl.h b/faiss/utils/hamming_distance/generic-inl.h index e0907a1586..a006814877 100644 --- a/faiss/utils/hamming_distance/generic-inl.h +++ b/faiss/utils/hamming_distance/generic-inl.h @@ -166,9 +166,12 @@ struct HammingComputer20 { void set(const uint8_t* a8, int code_size) { assert(code_size == 20); const uint64_t* a = (uint64_t*)a8; + const uint32_t* b = (uint32_t*)a8; a0 = a[0]; a1 = a[1]; - a2 = a[2]; + // can't read a[2] since it is uint64_t, not uint32_t + // results in AddressSanitizer failure reading past end of array + a2 = b[4]; } inline int hamming(const uint8_t* b8) const {
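The final hunk avoids an 8-byte load that would read 4 bytes past the end of a 20-byte code: the last word is now fetched as a 32-bit value. A standalone sketch of the safe layout follows; the struct, field names, and memcpy-based loads are illustrative only, not the FAISS implementation.

#include <cassert>
#include <cstdint>
#include <cstring>

struct Code20View {
    uint64_t a0, a1; // bytes 0..15 of the 20-byte code
    uint32_t a2;     // bytes 16..19, read as 32 bits, never as 64

    void set(const uint8_t* a8) {
        std::memcpy(&a0, a8, 8);
        std::memcpy(&a1, a8 + 8, 8);
        std::memcpy(&a2, a8 + 16, 4); // stays inside the 20-byte buffer
    }
};

int main() {
    uint8_t code[20];
    for (int i = 0; i < 20; ++i) {
        code[i] = static_cast<uint8_t>(i);
    }
    Code20View v;
    v.set(code);

    uint32_t tail = 0;
    std::memcpy(&tail, code + 16, 4);
    assert(v.a2 == tail); // same bytes, loaded without overrunning the code
    return 0;
}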