diff --git a/.gitmodules b/.gitmodules index c031c2fd5ad38..f488047609c6b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -128,3 +128,7 @@ path = third_party/cpp-httplib url = https://github.com/yhirose/cpp-httplib.git branch = v0.15.3 +[submodule "third_party/composable_kernel"] + path = third_party/composable_kernel + url = https://github.com/ROCm/composable_kernel.git + branch = develop diff --git a/CMakeLists.txt b/CMakeLists.txt index c4cd4b2c2a98e..a67a274540775 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,10 @@ cmake_policy(SET CMP0069 NEW) # and it's possible on our Windows configs. cmake_policy(SET CMP0092 NEW) +include(CMakePrintHelpers) + + + # Prohibit in-source builds if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR}) message(FATAL_ERROR "In-source build are not supported") @@ -773,7 +777,7 @@ set(CAFFE2_ALLOWLIST if(NOT CMAKE_BUILD_TYPE) message(STATUS "Build type not set - defaulting to Release") set(CMAKE_BUILD_TYPE - "Release" + "Debug" CACHE STRING "Choose the type of build from: Debug Release RelWithDebInfo MinSizeRel Coverage." @@ -851,6 +855,8 @@ endif() # aotriton build decision later. include(cmake/Dependencies.cmake) +message("BEFORE USE_FLASH ATTENTION IS SUPPOSEDLY CREATED") +cmake_print_variables(USE_FLASH_ATTENTION) cmake_dependent_option( USE_FLASH_ATTENTION @@ -860,6 +866,93 @@ cmake_dependent_option( "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) +message("AFTER USE_FLASH_ATTENTION IS CREATED") + + +if(USE_FLASH_ATTENTION) + message("MADE IT HERE") + cmake_print_variables(Python3_EXECUTABLE) + execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/fwd_blob_list.txt --receipt 3 + ) + execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --list_blobs ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/bwd_blob_list.txt --receipt 3 + ) + message("MADE IT PAST EXECUTE_PROCESS") + # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory + # as current cmake list, otherwise will not figure out the dependency properly + file(STRINGS ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) + file(STRINGS ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) + + #add_custom_command( + # OUTPUT ${FMHA_FWD_GEN_BLOBS} + # COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + # --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_LIST_DIR} + #) + + #add_custom_command( + # OUTPUT ${FMHA_BWD_GEN_BLOBS} + # COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + # --api bwd --output_dir ${CMAKE_CURRENT_LIST_DIR} + #) + execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/ --receipt 3 + ) + execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --output_dir ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/ --receipt 3 + ) + + execute_process(COMMAND ls 
${CMAKE_CURRENT_LIST_DIR}) + + + set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") + # not using add_example_executable() to add this target, since we don't want this to have + # to be included in "make all/install/check" + message("adding example ${EXAMPLE_FMHA_FWD}") + add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/fmha_fwd.cpp) + target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/) + target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + + set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") + # not using add_example_executable() to add this target, since we don't want this to have + # to be included in "make all/install/check" + message("adding example ${EXAMPLE_FMHA_BWD}") + add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/fmha_bwd.cpp) + target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/) + target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + + # NOTE: this is dangerous since will change the whole kernel to flush denormals + # WIP with compiler team for an exp2 intrinsic..., then remove this + if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) + endif() + + set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) + + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + # ... because they are auto-generated + if(FMHA_FWD_FAST_EXP2) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + endif() + + # Allow comparing floating points directly in order to check sentinel values + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) + + target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) + target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + + # TODO: we have to turn off this global prop, otherwise the progress bar generated + # by cmake will print too many files, execvp: /bin/sh: Argument list too long + # however, this property may affect global + # TODO: consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) + + message("ANDY! 
WE MADE IT TO THE END OF OUR BLURB!") +endif() + + # We are currenlty not using alibi attention for Flash So we disable this # feature by default We dont currently document this feature because we don't # Suspect users building from source will need this @@ -871,7 +964,7 @@ cmake_dependent_option( USE_MEM_EFF_ATTENTION "Enable memory-efficient attention for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM" OFF) + "USE_CUDA" OFF) if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") diff --git a/aten/src/ATen/BlasBackend.h b/aten/src/ATen/BlasBackend.h index 7f8c321ad9fa2..521addefc5ee1 100644 --- a/aten/src/ATen/BlasBackend.h +++ b/aten/src/ATen/BlasBackend.h @@ -7,7 +7,7 @@ namespace at { -enum class BlasBackend : int8_t { Cublas, Cublaslt }; +enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck }; inline std::string BlasBackendToString(at::BlasBackend backend) { switch (backend) { @@ -15,6 +15,8 @@ inline std::string BlasBackendToString(at::BlasBackend backend) { return "at::BlasBackend::Cublas"; case BlasBackend::Cublaslt: return "at::BlasBackend::Cublaslt"; + case BlasBackend::Ck: + return "at::BlasBackend::Ck"; default: TORCH_CHECK(false, "Unknown blas backend"); } diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6d9152a4d07df..42006349dbb4b 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -169,6 +169,7 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") # flash_attention sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +file(GLOB flash_attention_hip_cpp "native/transformers/hip/flash_attn/*.cpp") #Mem_eff attention sources file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu") @@ -184,6 +185,14 @@ if(USE_FLASH_ATTENTION) list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip}) list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip}) + add_subdirectory(native/transformers/hip/flash_attn) + set_source_files_properties( + ${flash_attention_hip_cpp} + DIRECTORY "native/transformers/hip/flash_attn/" + PROPERTIES + COMPILE_FLAGS "-Wno-undefined-func-template" + ) + list(APPEND native_transformers_hip_cpp ${flash_attention_hip_cpp}) endif() if(USE_MEM_EFF_ATTENTION) @@ -309,6 +318,7 @@ if(USE_ROCM) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha/) list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 5997b9435c59e..218451bc0c96e 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -284,6 +284,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #else TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(), "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt."); + TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(), + "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm."); if (b != at::BlasBackend::Cublas) { TORCH_WARN_ONCE( 
"torch.backends.cuda.preferred_blas_library is an experimental feature. " diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index d7ae18ed1a3b7..9213ad88e1d48 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -126,6 +126,9 @@ class TORCH_API Context { static bool hasCuBLASLt() { return detail::getCUDAHooks().hasCuBLASLt(); } + static bool hasROCM() { + return detail::getCUDAHooks().hasROCM(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index ce991a9bcad4e..b18d782dc1b57 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -18,6 +18,7 @@ // until hipblas has an API to accept flags, we must use rocblas here #include #include +#include #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) // needed to work around calling rocblas API instead of hipblas API @@ -792,6 +793,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { AT_ERROR("at::cuda::blas::gemm_internal_cublas: not implemented for ", typeid(Dtype).name()); } + template <> void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] @@ -1000,6 +1002,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(double)); } @@ -1011,6 +1018,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(float)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(float)); } @@ -1054,6 +1066,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1065,6 +1082,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::BFloat16)); } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 7d796c3d67e2b..02978a81c2aad 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -79,6 +79,7 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b transpose_tensor = tensor.is_contiguous(); return resolve_conj_if_indicated(tensor, true); } + IntArrayRef tensor_strides = tensor.strides(); IntArrayRef tensor_sizes = tensor.sizes(); if ((tensor_strides[0] == 1) && 
(tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { diff --git a/aten/src/ATen/native/hip/ck_gemm.h b/aten/src/ATen/native/hip/ck_gemm.h new file mode 100644 index 0000000000000..176cbabd5e01c --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +namespace at::native { + + +template +inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); + + + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip new file mode 100644 index 0000000000000..dd1503de89cb1 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -0,0 +1,479 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. + bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + } else { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + 
S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +} + + + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip new file mode 100644 index 0000000000000..b8301a47981c6 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -0,0 +1,486 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) { + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. 
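+  // When m, n, or k is not divisible by 256/128/64, the PADDING=true gemm_impl
+  // instantiations below are used; transa_/transb_ then select the layout-specific specialization.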
+ bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + } else { + + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, 
+ 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +} + + + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)) { + dispatch_float_gemm(CUDABLAS_GEMM_ARGS(float)); +} + +// temporarily put this here until we implement double support +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)) { + return; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip new file mode 100644 index 0000000000000..60b64ca275c54 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -0,0 +1,306 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include + +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { +#if 0 + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. + + bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + + + } else { + + if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + true>(CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { 
+ if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2> + 1, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +#endif +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { + dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_template.h b/aten/src/ATen/native/hip/ck_gemm_template.h new file mode 100644 index 0000000000000..6d006e8d37d14 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_template.h @@ -0,0 +1,291 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include + + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +// Define commonly used types. +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace at::native { + +template +struct CkMathType { + using dtype = T; +}; + +template <> +struct CkMathType { + using dtype = ck::bhalf_t; +}; + +template <> +struct CkMathType { + using dtype = ck::half_t; +}; + + +template +struct CkTensorLayout { + // default goes to row-wise for now + using a_layout = Row; + using b_layout = Row; +}; + +// True denotes transpose is necessary. 
Default is Col, so return Row +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Col; +}; + + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Row; +}; + + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Row; +}; + + +// Elementwise Operators +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(C& c, const AB& ab) const; + + template<> + __host__ __device__ constexpr void operator() + (float& c, const float& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::bhalf_t& c, const ck::bhalf_t& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::half_t& c, const ck::half_t& ab) const + { + c = alpha_ * ab; + }; + + float alpha_; + // TODO: Leaving for now, will use later + float beta_; +}; + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int MPER_XDL, + int NPER_XDL, + int MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_AK1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename BLOCK_CLUSTER_LENS, + typename CDE_SCALAR_VEC, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = AlphaBetaAdd; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? 
GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmSpec, + BLOCK_SIZE, + MBLOCK, + NBLOCK, + KBLOCK, + AK1, + BK1, + MPER_XDL, + NPER_XDL, + MPER_WAVE, + NPER_WAVE, + ABLOCK_CLUSTER_LENS, + ABLOCK_CLUSTER_ORDER, + ABLOCK_SRC_ORDER, + ABLOCK_VECTOR_DIM, + ABLOCK_SCALAR_VEC, + ABLOCK_SCALAR_VEC_AK1, + ABLOCK_LDS_EXTRAM, + BBLOCK_CLUSTER_LENS, + BBLOCK_CLUSTER_ORDER, + BBLOCK_SRC_ORDER, + BBLOCK_VECTOR_DIM, + BBLOCK_SCALAR_VEC, + BBLOCK_SCALAR_VEC_AK1, + BBLOCK_LDS_EXTRAN, + CMPER_WAVE, + CNPER_WAVE, + BLOCK_CLUSTER_LENS, + CDE_SCALAR_VEC>; + + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{alpha, beta}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + DDataArray, + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h deleted file mode 100644 index 1c238c751a05c..0000000000000 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ /dev/null @@ -1,130 +0,0 @@ -#pragma once - -#ifdef USE_ROCM - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// -// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h -//////////////////////////////////////////////////////////////////////////////// - -#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK(TENSOR.is_contiguous()); - -#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK( \ - TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); - -#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ - TORCH_CHECK( \ - uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK( \ - B < std::numeric_limits::max(), #B " overflows"); \ - } - -namespace sdp { - -namespace aotriton_adapter { - -inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) -{ -#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname - CAST_TYPE(kByte, kUInt8); - CAST_TYPE(kUInt16, kUInt16); - CAST_TYPE(kUInt32, kUInt32); - CAST_TYPE(kUInt64, kUInt64); - CAST_TYPE(kChar, kInt8); - CAST_TYPE(kShort, kInt16); - CAST_TYPE(kInt, kInt32); - CAST_TYPE(kLong, kInt64); - CAST_TYPE(kHalf, kFloat16); - CAST_TYPE(kFloat, kFloat32); - 
CAST_TYPE(kBFloat16, kBFloat16); - return aotriton::DType::kUnknown; -#undef CAST_TYPE -} - -template -struct IntArrayRefCaster { - // std::array cast(IntArrayRef); -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ static_cast(ref.at(0)) }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)), - static_cast(ref.at(3)) - }}; - } -}; - - -template -aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) -{ - const auto strides = q.strides(); - int real_rank = strides.size(); - if (real_rank != Rank) { // Lazy convertion of tensor_name - TORCH_CHECK(false, - std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) - + " but is " + std::to_string(real_rank)); - } - return aotriton::TensorView(reinterpret_cast(q.data_ptr()), - IntArrayRefCaster::cast(q.sizes()), - IntArrayRefCaster::cast(strides), - cast_dtype(q.dtype())); -} - -} // namespace aotriton_adapter - -} // namespace sdp - -namespace at::native { - -inline int64_t ceil_div(int64_t numerator, int64_t denominator) { - return (numerator + (denominator - 1)) / denominator; -} - -} - -#endif // USE_ROCM diff --git a/aten/src/ATen/native/transformers/hip/attention.hip b/aten/src/ATen/native/transformers/hip/attention.hip new file mode 100644 index 0000000000000..b4dfee27409e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/attention.hip @@ -0,0 +1,1465 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif + +#include + +#include +#include +#include +#include +#include + +#ifdef USE_FLASH_ATTENTION +// FlashAttention Specific Imports +#include +#endif +#ifdef USE_MEM_EFF_ATTENTION +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA +#include +#include +#include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif +#endif + +namespace at { + +namespace native { + +namespace { + + +static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4; + +template +__global__ void transform_bias_rescale_qkv_kernel( + // [B, T, 3 * D] + const PackedTensorAccessor64 qkv, + // [3 * D] + const PackedTensorAccessor64 qkv_bias, + // [3, B, NH, T, DH] + PackedTensorAccessor64 q_k_v, + const scalar_t inv_sqrt_dim_per_head) { + // warp per DH. + // so launch B * NH * T warps. 
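+  // blockIdx.x enumerates (b, t) pairs (the launch uses B * T blocks); each thread
+  // strides over the D = NH * DH columns, splitting d into nh = d / DH, dh = d % DH.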
+ auto NH = q_k_v.size(2); + auto T = q_k_v.size(3); + auto DH = q_k_v.size(4); + + auto t = blockIdx.x % T; + auto b = blockIdx.x / T; + + auto D = NH * DH; + + if (assume_aligned) { + constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC; + using LoadT = memory::aligned_vector; + for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) { + auto d = d_v * VEC; + auto nh = d / DH; + auto dh = d % DH; + scalar_t qkv_bias_q[VEC]; + scalar_t qkv_bias_k[VEC]; + scalar_t qkv_bias_v[VEC]; + scalar_t qkv_q[VEC]; + scalar_t qkv_k[VEC]; + scalar_t qkv_v[VEC]; + + // Here we require D % VEC == 0 for these vectorized loads. + *reinterpret_cast(&qkv_bias_q) = + *reinterpret_cast(&qkv_bias[d + 0 * D]); + *reinterpret_cast(&qkv_bias_k) = + *reinterpret_cast(&qkv_bias[d + 1 * D]); + *reinterpret_cast(&qkv_bias_v) = + *reinterpret_cast(&qkv_bias[d + 2 * D]); + + *reinterpret_cast(&qkv_q) = + *reinterpret_cast(&qkv[b][t][d + 0 * D]); + *reinterpret_cast(&qkv_k) = + *reinterpret_cast(&qkv[b][t][d + 1 * D]); + *reinterpret_cast(&qkv_v) = + *reinterpret_cast(&qkv[b][t][d + 2 * D]); + +#pragma unroll + // TODO: specialize for float2half2/half2float2? + for (auto ii = 0; ii < VEC; ++ii) { + qkv_q[ii] = static_cast( + (static_cast(qkv_q[ii]) + + static_cast(qkv_bias_q[ii])) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k[ii] = static_cast( + (static_cast(qkv_k[ii]) + + static_cast(qkv_bias_k[ii]))); + qkv_v[ii] = static_cast( + (static_cast(qkv_v[ii]) + + static_cast(qkv_bias_v[ii]))); + } + + // Here we require DH % VEC == 0 for these vectorized stores. + *reinterpret_cast(&q_k_v[0][b][nh][t][dh]) = + *reinterpret_cast(&qkv_q); + *reinterpret_cast(&q_k_v[1][b][nh][t][dh]) = + *reinterpret_cast(&qkv_k); + *reinterpret_cast(&q_k_v[2][b][nh][t][dh]) = + *reinterpret_cast(&qkv_v); + } + } else { + // Same as above, but we can't vectorize memory access. + for (int32_t d = threadIdx.x; d < D; d += blockDim.x) { + auto nh = d / DH; + auto dh = d % DH; + scalar_t qkv_bias_q = qkv_bias[d + 0 * D]; + scalar_t qkv_bias_k = qkv_bias[d + 1 * D]; + scalar_t qkv_bias_v = qkv_bias[d + 2 * D]; + scalar_t qkv_q = qkv[b][t][d + 0 * D]; + scalar_t qkv_k = qkv[b][t][d + 1 * D]; + scalar_t qkv_v = qkv[b][t][d + 2 * D]; + qkv_q = static_cast( + (static_cast(qkv_q) + + static_cast(qkv_bias_q)) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k = static_cast( + (static_cast(qkv_k) + + static_cast(qkv_bias_k))); + qkv_v = static_cast( + (static_cast(qkv_v) + + static_cast(qkv_bias_v))); + + q_k_v[0][b][nh][t][dh] = qkv_q; + q_k_v[1][b][nh][t][dh] = qkv_k; + q_k_v[2][b][nh][t][dh] = qkv_v; + } + } +} + +template +__global__ void transform_bias_rescale_qkv_add_padding_kernel( + // [B, T, 3 * D], but it's a NestedTensor buffer + const PackedTensorAccessor64 qkv, + // [3 * D] + const PackedTensorAccessor64 qkv_bias, + const int* offsets, + const int* input_sizes, + // [3, B, NH, T, DH] + PackedTensorAccessor64 q_k_v, + const scalar_t inv_sqrt_dim_per_head) { + // warp per DH. + // so launch B * NH * T warps. 
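+  // offsets[b] is the start of batch b in the packed nested-tensor buffer and
+  // sizes_i[0] is the number of valid elements in that row; reads past it are zero-filled.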
+ const auto NH = q_k_v.size(2); + const auto T = q_k_v.size(3); + const auto DH = q_k_v.size(4); + + const auto t = blockIdx.x % T; + const auto b = blockIdx.x / T; + + const auto D = NH * DH; + const auto _3D = 3 * D; + + const auto offset_for_batch = offsets[b]; + const auto input_dim = 1; + const auto* sizes_i = input_sizes + b * input_dim; + if (assume_aligned) { + constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC; + using LoadT = memory::aligned_vector; + for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) { + auto d = d_v * VEC; + auto nh = d / DH; + auto dh = d % DH; + scalar_t qkv_bias_q[VEC]; + scalar_t qkv_bias_k[VEC]; + scalar_t qkv_bias_v[VEC]; + scalar_t qkv_q[VEC]; + scalar_t qkv_k[VEC]; + scalar_t qkv_v[VEC]; + + const auto first_item_offset = t * _3D + d; + const auto last_item_offset = first_item_offset + VEC - 1; + const bool first_item_in_bounds = first_item_offset < sizes_i[0]; + const bool entire_vec_in_bounds = last_item_offset < sizes_i[0]; + + // Here we require D % VEC == 0 for these vectorized loads. + *reinterpret_cast(&qkv_bias_q) = + *reinterpret_cast(&qkv_bias[d + 0 * D]); + *reinterpret_cast(&qkv_bias_k) = + *reinterpret_cast(&qkv_bias[d + 1 * D]); + *reinterpret_cast(&qkv_bias_v) = + *reinterpret_cast(&qkv_bias[d + 2 * D]); + + if (entire_vec_in_bounds) { + const auto offset = offset_for_batch + first_item_offset; + *reinterpret_cast(&qkv_q) = + *reinterpret_cast(&qkv[offset + 0 * D]); + *reinterpret_cast(&qkv_k) = + *reinterpret_cast(&qkv[offset + 1 * D]); + *reinterpret_cast(&qkv_v) = + *reinterpret_cast(&qkv[offset + 2 * D]); +#pragma unroll + // TODO: specialize for float2half2/half2float2? + for (auto ii = 0; ii < VEC; ++ii) { + qkv_q[ii] = static_cast( + (static_cast(qkv_q[ii]) + + static_cast(qkv_bias_q[ii])) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k[ii] = static_cast( + (static_cast(qkv_k[ii]) + + static_cast(qkv_bias_k[ii]))); + qkv_v[ii] = static_cast( + (static_cast(qkv_v[ii]) + + static_cast(qkv_bias_v[ii]))); + } + } else if (first_item_in_bounds) { + const auto offset = offset_for_batch + first_item_offset; + qkv_q[0] = qkv[offset + 0 * D]; + qkv_k[0] = qkv[offset + 1 * D]; + qkv_v[0] = qkv[offset + 2 * D]; + qkv_q[0] = static_cast( + (static_cast(qkv_q[0]) + + static_cast(qkv_bias_q[0])) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k[0] = static_cast( + (static_cast(qkv_k[0]) + + static_cast(qkv_bias_k[0]))); + qkv_v[0] = static_cast( + (static_cast(qkv_v[0]) + + static_cast(qkv_bias_v[0]))); +#pragma unroll + for (auto ii = 1; ii < VEC; ++ii) { + const auto loop_offset = offset + ii; + if (loop_offset < sizes_i[0]) { + qkv_q[ii] = qkv[loop_offset + 0 * D]; + qkv_k[ii] = qkv[loop_offset + 1 * D]; + qkv_v[ii] = qkv[loop_offset + 2 * D]; + qkv_q[ii] = static_cast( + (static_cast(qkv_q[ii]) + + static_cast(qkv_bias_q[ii])) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k[ii] = static_cast( + (static_cast(qkv_k[ii]) + + static_cast(qkv_bias_k[ii]))); + qkv_v[ii] = static_cast( + (static_cast(qkv_v[ii]) + + static_cast(qkv_bias_v[ii]))); + } else { + qkv_q[ii] = 0; + qkv_k[ii] = 0; + qkv_v[ii] = 0; + } + } + } else { +#pragma unroll + for (auto ii = 0; ii < VEC; ++ii) { + qkv_q[ii] = 0; + qkv_k[ii] = 0; + qkv_v[ii] = 0; + } + } + + // Here we require DH % VEC == 0 for these vectorized stores. 
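+  // Lanes beyond sizes_i[0] were zero-filled above, so these full-vector stores
+  // write zero padding into the dense [3, B, NH, T, DH] output.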
+ *reinterpret_cast(&q_k_v[0][b][nh][t][dh]) = + *reinterpret_cast(&qkv_q); + *reinterpret_cast(&q_k_v[1][b][nh][t][dh]) = + *reinterpret_cast(&qkv_k); + *reinterpret_cast(&q_k_v[2][b][nh][t][dh]) = + *reinterpret_cast(&qkv_v); + } + } else { + for (int32_t d = threadIdx.x; d < D; d += blockDim.x) { + auto nh = d / DH; + auto dh = d % DH; + scalar_t qkv_bias_q = qkv_bias[d + 0 * D]; + scalar_t qkv_bias_k = qkv_bias[d + 1 * D]; + scalar_t qkv_bias_v = qkv_bias[d + 2 * D]; + + const auto item_offset = t * _3D + d; + const bool in_bounds = item_offset < sizes_i[0]; + scalar_t qkv_q, qkv_k, qkv_v; + if (in_bounds) { + const auto qkv_offset = offset_for_batch + item_offset; + qkv_q = qkv[qkv_offset + 0 * D]; + qkv_k = qkv[qkv_offset + 1 * D]; + qkv_v = qkv[qkv_offset + 2 * D]; + qkv_q = static_cast( + (static_cast(qkv_q) + + static_cast(qkv_bias_q)) * + static_cast(inv_sqrt_dim_per_head)); + qkv_k = static_cast( + (static_cast(qkv_k) + + static_cast(qkv_bias_k))); + qkv_v = static_cast( + (static_cast(qkv_v) + + static_cast(qkv_bias_v))); + } else { + qkv_q = 0; + qkv_k = 0; + qkv_v = 0; + } + + q_k_v[0][b][nh][t][dh] = qkv_q; + q_k_v[1][b][nh][t][dh] = qkv_k; + q_k_v[2][b][nh][t][dh] = qkv_v; + } + } +} + +Tensor collapse_dims_1_and_2(const Tensor& sizes) { + auto sizes_dim1 = at::native::narrow_symint(sizes, 1, 0, 1); + auto sizes_dim2 = at::native::narrow_symint(sizes, 1, 1, 1); + + return (sizes_dim1 * sizes_dim2).contiguous(); +} + +} // namespace +// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias +__host__ std::tuple transform_bias_rescale_qkv_cuda( + const Tensor& qkv, + const Tensor& qkv_bias, + const int64_t num_head) { + auto B = qkv.is_nested() + ? get_nested_tensor_impl(qkv)->get_nested_sizes().size(0) + : qkv.size(0); + // TODO: calculate this without the std::vector -- NestedTensor_to_mask wants + // this too + auto T = qkv.is_nested() + ? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0] + : qkv.size(1); + if (qkv.is_nested()) { + // Don't mess with non-nested case for now since it's not set up to fiddle + // with mask size. + + // Round T up to next multiple of 8 so as to be able to utilize Tensor + // cores. Otherwise, sometimes with padding, *no* row will have the maximum + // sequence length and so we'll have a non-divisible-by-8 dimension even if + // the model author chose a multiple of 8. 
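+    // e.g. T = 13 becomes 13 + (8 - 13 % 8) % 8 = 16, while T = 16 is unchanged.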
+ T = T + (8 - (T % 8)) % 8; + } + auto _3D = qkv_bias.size(0); + auto D = _3D / 3; + TORCH_CHECK(D % num_head == 0); + const auto dim_per_head = D / num_head; + auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options()); +#define CALL_KERNEL(assume_aligned) \ + hipLaunchKernelGGL(( transform_bias_rescale_qkv_kernel) \ + , dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), \ + qkv.packed_accessor64(), \ + qkv_bias.packed_accessor64(), \ + q_k_v.packed_accessor64(), \ + 1.0 / std::sqrt(static_cast(dim_per_head))) +#define CALL_ADD_PADDING_KERNEL(assume_aligned) \ + hipLaunchKernelGGL(( transform_bias_rescale_qkv_add_padding_kernel< \ + scalar_t, \ + accscalar_t, \ + assume_aligned>) \ + , dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), \ + nt_qkv_buffer \ + .packed_accessor64(), \ + qkv_bias.packed_accessor64(), \ + offsets_ptr, \ + sizes_ptr, \ + q_k_v.packed_accessor64(), \ + 1.0 / std::sqrt(static_cast(dim_per_head))) + + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + qkv.scalar_type(), + "transform_bias_rescale_qkv", + [&] { + using accscalar_t = acc_type; + auto threads = ::max( + std::min(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1); + auto blocks = B * T; + const bool aligned = + ((dim_per_head % TRANSFORM_BIAS_RESCALE_VEC) == 0) && + ((reinterpret_cast(qkv_bias.data_ptr()) % + TRANSFORM_BIAS_RESCALE_VEC) == 0); + if (aligned) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + D % TRANSFORM_BIAS_RESCALE_VEC == 0, + "D = num_heads * dim_per_head, so we should have dim_per_head % " + "TRANSFORM_BIAS_RESCALE_VEC == 0 => " + "D % TRANSFORM_BIAS_RESCALE_VEC == 0"); + } + if (qkv.is_nested()) { + auto* nt_qkv = get_nested_tensor_impl(qkv); + const at::Tensor& nt_qkv_buffer = nt_qkv->get_buffer(); + auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_sizes()); + auto offsets = + NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel()); + at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel()) + .copy_(sizes.reshape({-1})); + auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true); + const auto offsets_ptr = metadata.data_ptr(); + const auto sizes_ptr = offsets_ptr + sizes.numel() + 1; + const auto input_dim = sizes.sizes()[1]; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input_dim == 1); + if (aligned && + ((reinterpret_cast(qkv.data_ptr()) % + TRANSFORM_BIAS_RESCALE_VEC) == 0)) { + CALL_ADD_PADDING_KERNEL(true); + } else { + CALL_ADD_PADDING_KERNEL(false); + } + } else if (aligned) { + CALL_KERNEL(true); + } else { + CALL_KERNEL(false); + } + C10_HIP_KERNEL_LAUNCH_CHECK(); + }); +#undef CALL_ADD_PADDING_KERNEL +#undef CALL_KERNEL + auto q_k_v_s = + at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0); + return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]); +} + +std::tuple native_multi_head_attention_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const int64_t embed_dim, + const int64_t num_head, + const Tensor& qkv_weight, + const Tensor& qkv_bias, + const Tensor& proj_weight, + const Tensor& proj_bias, + const std::optional& mask, + bool need_weights, + bool average_attn_weights, + const std::optional mask_type) { + // query shape: [B, T, D] + // qkv_weight shape: [3 * D, D] + + TORCH_CHECK( + !mask || !query.is_nested(), + "NestedTensor with mask is not supported yet"); + const auto D = embed_dim; + TORCH_CHECK( + query.dim() == 3, + "expected 3-D `query`, got ", + query.dim(), + "-D tensor"); + 
TORCH_CHECK( + query.is_nested() || query.sizes()[2] == embed_dim, + "passed-in embed_dim ", + embed_dim, + " didn't match last dim of query ", + query.sizes()[2]); + TORCH_CHECK( + key.dim() == 3, + "expected 3-D `key`, got ", + key.dim(), + "-D tensor"); + TORCH_CHECK( + value.dim() == 3, + "expected 3-D `value`, got ", + value.dim(), + "-D tensor"); + TORCH_CHECK( + query.is_nested() || key.is_nested() || value.is_nested() || + (query.sizes() == key.sizes() && key.sizes() == value.sizes()), + "expected `query`/`key`/`value` shapes to match"); + TORCH_CHECK( + qkv_weight.dim() == 2, + "expected 2-D `qkv_weight`, got ", + qkv_weight.dim(), + "-D tensor"); + TORCH_CHECK( + D * 3 == qkv_weight.sizes()[0], + "expected `qkv_weight` first dim to be 3x embed_dim"); + TORCH_CHECK( + D == qkv_weight.sizes()[1], + "expected `qkv_weight` second dim to be embed_Dim"); + TORCH_CHECK( + qkv_bias.dim() == 1, + "expected 1-D `qkv_bias`, got ", + qkv_bias.dim(), + "-D tensor"); + TORCH_CHECK( + qkv_bias.sizes()[0] == 3 * D, + "expected `qkv_bias` first dim and first dim of query to be equal"); + TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`"); + +#ifndef NDEBUG + const auto B = query.is_nested() + ? get_nested_tensor_impl(query)->get_nested_sizes().size(0) + : query.sizes()[0]; + auto T = query.is_nested() ? 0 : query.sizes()[1]; + +#endif + const auto dim_per_head = D / num_head; + if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && !need_weights) { + + // We have not done linear projection yet but the input for SDP + // Is expected to be 4 dimensional. We "cheaply" create view tensors + // That will then be used for checking hot path conditions with select_sd_backend + auto q = query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + + sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false}; + auto backend = select_sdp_backend(kernel_params); + // strides from packed projection for nested tensors when seq_len is 1 will be + // and will trigger a contiguous call in the kernel, so we prevent this + bool no_seq_len_1_nested = query.is_nested() ? 
check_for_seq_len_1_nested_tensor(kernel_params, false) : true; + // The API for transformer_encoder is a mask of shape (Batch_Size, Seq_len_q) + // For mem-eff attention this will cause the expand call to error + // For now I am going to turn of that path not have to deal with all the annoying + // Mask type shape grossness + if (!mask.has_value() && no_seq_len_1_nested && + (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention || + backend == sdp::SDPBackend::cudnn_attention)) { + auto x = at::linear(query, qkv_weight, qkv_bias); + auto chunks = x.chunk(3, -1); + auto x_size_0 = x.size(0); + + chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + auto y = at::scaled_dot_product_attention( + chunks[0], chunks[1], chunks[2], mask, 0.0, false, c10::nullopt); + + auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim}); + return std::make_tuple( + at::linear(past_sdp, proj_weight, proj_bias), Tensor()); + } + // Returned math or error lets not use it + } + + // shape: [B, T, 3 x D] + auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight); + + if (!qkv.is_nested() && qkv.numel() == 0) { + if (query.is_nested()) { + return std::make_tuple(Tensor(), Tensor()); + } + return std::make_tuple(at::empty_like(query), Tensor()); + } + +#ifndef NDEBUG + if (!query.is_nested() || !qkv.is_nested()) { + if (query.is_nested()) { + T = qkv.size(1); + } + debug_assert_shape(__LINE__, qkv, {B, T, 3 * D}); + } +#endif + +#ifdef DEBUG_PRINT_EACH_STEP + if (!qkv.is_nested()) { + std::cerr << "qkv: " << qkv << std::endl; + } +#endif + // shape: 3 x [B, num_head, T, dim_per_head] + auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head); + qkv = Tensor(); // Not used any more, allow free + auto& q = std::get<0>(q_k_v); + const auto& k = std::get<1>(q_k_v); + const auto& v = std::get<2>(q_k_v); +#ifndef NDEBUG + debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head}); + debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head}); + debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "q: " << q << std::endl; + std::cerr << "k: " << k << std::endl; + std::cerr << "v: " << v << std::endl; +#endif + + // shape: [B, num_head, T, T] + auto qkt = bmm_nt(q, k); + // q & k are dead but cannot be freed because they were packed with v +#ifndef NDEBUG + debug_assert_shape(__LINE__, qkt, {B, num_head, T, T}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "qkt: " << qkt << std::endl; +#endif + + // shape: [B, num_head, T, T] + // TODO: long-term, have a kernel that works with + // NestedTensor directly if there is no mask passed + qkt = masked_softmax(qkt, mask, query, mask_type); +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "qkt after softmax: " << qkt << std::endl; +#endif + + // shape: [B, num_head, T, dim_per_head] + // reuse storage for q; we're done with it + auto attn_ctx = bmm_nn(q, qkt, v); + // qkv is not dead; we just reused storage for q! 
+ if (!need_weights) { + qkt = Tensor(); + } +#ifndef NDEBUG + debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "attn_ctx: " << attn_ctx << std::endl; +#endif + + // shape: [B, T, D] + // Fuse transform_0213 inside + auto proj = transform0213_gemm_nt_bias( + attn_ctx, proj_weight, proj_bias, query); +#ifndef NDEBUG + debug_assert_shape(__LINE__, proj, {B, T, D}); +#endif + if (need_weights && average_attn_weights) { + // weights are not needed for full transformer, so don't worry too + // much about performance -- we implement this just to make use + // cases that don't disable need_weights still get some speedup. + qkt = qkt.sum(1); + qkt /= num_head; + } + return std::make_tuple(std::move(proj), std::move(qkt)); +} +std::tuple _scaled_dot_product_flash_attention_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + // Used for tracking usage statistics + C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention"); + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + + const int64_t max_seqlen_batch_q = query.size(2); + const int64_t max_seqlen_batch_k = key.size(2); + const int64_t max_seqlen_batch_v = value.size(2); + TORCH_CHECK( + max_seqlen_batch_k == max_seqlen_batch_v, + "Key and Value must have the same sequence length"); + + // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) + Tensor q_t = query.transpose(1, 2); + Tensor k_t = key.transpose(1, 2); + Tensor v_t = value.transpose(1, 2); + + auto + [output, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask] = + at::_flash_attention_forward( + q_t, + k_t, + v_t, + c10::nullopt, + c10::nullopt, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + is_causal, + return_debug_mask, + scale, + c10::nullopt, + c10::nullopt); + // Reshape output to convert nnz to batch_size and seq_len + Tensor attention = output.transpose(1,2); + + return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask); +} + +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) { + if (arg.captured_) { + *seed_ptr = static_cast(*arg.seed_.ptr); + *offset_ptr = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + *seed_ptr = static_cast(arg.seed_.val); + *offset_ptr = static_cast(arg.offset_.val); + } +} + +std::tuple _scaled_dot_product_cudnn_attention_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const std::optional& attn_bias, + bool compute_logsumexp, + double dropout_p, + bool is_causal, + bool return_debug_mask, + c10::optional scale) { + // Used for tracking usage statistics + C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn"); + // TODO(eqy): debug mask support + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t 
max_seqlen_batch_q = query.size(2); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_batch_k = key.size(2); + const int64_t max_seqlen_batch_v = value.size(2); + TORCH_CHECK( + max_seqlen_batch_k == max_seqlen_batch_v, + "Key and Value must have the same sequence length"); + + Tensor attention, log_sumexp; + + at::Tensor cudnn_seed, cudnn_offset; + cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + + // See Note [Seed and Offset Device] in _efficient_attention_forward + at::PhiloxCudaState philox_state; + const bool in_capture_stream = + at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; + if (use_dropout) { + // Device + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + // TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount + philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k); + hipLaunchKernelGGL(( unpack_cudnn), dim3(1), dim3(1), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), + philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr())); + } + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + Tensor debugmask; + + run_cudnn_SDP_fprop(batch_size/*int64_t b*/, + num_heads/*int64_t h*/, + max_seqlen_batch_q/*int64_t s_q*/, + max_seqlen_batch_k/*int64_t s_kv*/, + head_dim_qk/*int64_t d_qk*/, + head_dim_v/*int64_t d_v*/, + softmax_scale/*float scaling_factor*/, + compute_logsumexp/* bool */, + is_causal/* bool */, + dropout_p/*double dropout_probability*/, + query/* Tensor q*/, + key/* Tensor k*/, + value/* Tensor v*/, + log_sumexp/*Tensor softmaxstats*/, + attention/*Tensor o*/, + cudnn_seed/*Tensor dropoutseed*/, + cudnn_offset/*Tensor dropoutoffset*/); + + // TODO(eqy): support debug_attn_mask + return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor()); +} + +std::tuple _scaled_dot_product_efficient_attention_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const std::optional& attn_bias, + bool compute_log_sumexp, + double dropout_p, + bool is_causal, + std::optional scale) { + // Used for tracking usage statistics + C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention"); + // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) + Tensor q_t = query.transpose(1, 2); + Tensor k_t = key.transpose(1, 2); + Tensor v_t = value.transpose(1, 2); + + sdp::CustomMaskType custom_mask_type = is_causal + ? 
sdp::CustomMaskType::CausalFromTopLeft + : sdp::CustomMaskType::NoCustomMask; + + auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward( + q_t, + k_t, + v_t, + attn_bias, + c10::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, + dropout_p, + static_cast(custom_mask_type), + compute_log_sumexp, + scale); + + attention = attention.transpose(1, 2); + return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset)); +} + +int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, + const std::optional& attn_mask_, double dropout_p, bool is_causal, std::optional scale){ + sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal}; + auto backend = select_sdp_backend(kernel_params); + if (backend == sdp::SDPBackend::error) { + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found. ", + "This is likely due to turning off both the math kernel and the fused kernels."); + } + return static_cast(backend); +} + +std::tuple +_flash_attention_forward( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const std::optional& cumulative_sequence_length_q, + const std::optional& cumulative_sequence_length_k, + int64_t max_seqlen_batch_q, + int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale, + std::optional window_size_left, + std::optional window_size_right, + const std::optional& _seqused_k, + const std::optional& _alibi_slopes + ) { +#if defined(USE_FLASH_ATTENTION) + const auto softmax_scale = + sdp::calculate_scale(query, scale).as_float_unchecked(); + std::optional out = c10::nullopt; + + std::optional seqused_k = _seqused_k; + std::optional alibi_slopes = _alibi_slopes; + + const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; + const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; + + // We are going to have two paths: + // 1. The standard MHA path for dense tensors + // 2. The Varseqlen path + TORCH_CHECK( + cumulative_sequence_length_q.has_value() == + cumulative_sequence_length_k.has_value(), + "cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set"); + Tensor output, q_padded, k_padded, v_padded, logsumexp, output_shape, + philox_seed, philox_offset, debug_attn_mask; + if (cumulative_sequence_length_q.has_value()) { + std::tie( + output, + q_padded, + k_padded, + v_padded, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask) = + pytorch_flash::mha_varlen_fwd( + query, + key, + value, + out, + cumulative_sequence_length_q.value(), + cumulative_sequence_length_k.value(), + seqused_k, /*seqused_k*/ + alibi_slopes, /*alibi_slopes*/ + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + softmax_scale, + false /*zero_tensors*/, + is_causal, + non_null_window_left, + non_null_window_right, + return_debug_mask, + c10::nullopt /*gen_*/); + } else { + std::tie( + output, + q_padded, + k_padded, + v_padded, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask) = + pytorch_flash::mha_fwd( + query, + key, + value, + out, + alibi_slopes, + dropout_p, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_debug_mask, /*return_softmax (this is used for testing)*/ + c10::nullopt); + } + debug_attn_mask = + return_debug_mask ? 
debug_attn_mask : at::empty({0}, query.options()); + return std::make_tuple( + std::move(output), + std::move(logsumexp), + std::move(philox_seed), + std::move(philox_offset), + std::move(debug_attn_mask)); + +#endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return std::make_tuple( + Tensor(), + Tensor(), + Tensor(), + Tensor(), + Tensor()); +} + +std::tuple _efficient_attention_forward( + const at::Tensor& query, // [b, seqlen, num_heads, K] + const at::Tensor& key, // [b, seqlen, num_heads, K] + const at::Tensor& value, // [b, seqlen, num_heads, Kv] + const std::optional& bias, // [b, num_heads, seqlen, seqlen] + // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the + // position of the first query token for batch $b + const std::optional& seqstart_q, + // (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the + // position of the first key token for batch $b + const std::optional& seqstart_k, + // (Mode 1MHK only) Maximum sequence length across batches + const std::optional max_seqlen_q_, + const std::optional max_seqlen_k_, + double dropout_p, // attention matrix dropout probability + int64_t custom_mask_type, + bool compute_logsumexp, + std::optional scale, + const std::optional& seqlen_k, + const std::optional window_size) { +#if defined(USE_MEM_EFF_ATTENTION) +// TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a +// machine that is >= 5.0. In practice, this is not a problem but since +// this would avoid runtime architecture checks, we should look into it + + TORCH_CHECK(query.dim() == 4); + TORCH_CHECK(key.dim() == 4); + TORCH_CHECK(value.dim() == 4); + + // Batch sizes + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // Sequence length + TORCH_CHECK(key.size(1) == value.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + + int64_t max_seqlen_q = 0, max_seqlen_k = 0; + TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value()); + if (seqstart_q.has_value()) { + TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_q)); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_k)); + TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0)); + TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); + TORCH_CHECK(max_seqlen_q_.has_value()); + max_seqlen_q = *max_seqlen_q_; + max_seqlen_k = 0; // TODO: is this actually being set inside the kernel anywhere? 
+ // see https://github.com/pytorch/pytorch/issues/115590s + } else { + max_seqlen_q = query.size(1); + max_seqlen_k = key.size(1); + } + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::hip::HIPGuardMasqueradingAsCUDA device_guard(query.device()); + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t num_heads = query.size(-2); + int64_t K = query.size(-1); + int64_t Kv = value.size(-1); + + at::Tensor res; + at::Tensor logsumexp; + at::Tensor seed_t, offset_t; + + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + + // Note [Seed and Offset Device] + // If we are currently in graph capture mode, we need to create the seed and offset tensors on the device. + // This is necessary for CUDA graph-safe random number generation, which requires the seed and offset tensors + // to be single element tensors on device. During graph capture, when the seed and offset tensors are passed + // the pointers act as scratch space for storing the RNG state for the backwards pass. + // When calling backwards, we either construct a PhiloxState with the pointers or the actual values. + // For more information on CUDA graph-safe RNG states, see Note [CUDA Graph-safe RNG states]. + + at::PhiloxCudaState philox_state; + const bool in_capture_stream = + at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None; + auto device = in_capture_stream ? at::kCUDA : at::kCPU; + if (use_dropout) { + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + philox_state = gen->philox_cuda_state(B * num_heads * M * N); + + if (in_capture_stream) { + // The seed and offset will be populated by the kernel + seed_t = at::empty({}, at::dtype(at::kLong).device(device)); + offset_t = at::empty({}, at::dtype(at::kLong).device(device)); + } else { + auto [seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor( + at::Scalar(static_cast(seed)), at::dtype(at::kLong)); + offset_t = at::scalar_tensor( + at::Scalar(static_cast(offset)), at::dtype(at::kLong)); + } + } else { + // Not using dropout + seed_t = at::empty({}, at::dtype(at::kLong).device(device)); + offset_t = at::empty({}, at::dtype(at::kLong).device(device)); + } + +#ifdef USE_ROCM + // ROCM Implementation + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + } + + // AOTriton may accept aligned on logsumexp tensor in the future for better + // performance, but for now it requires compact logsumexp tensor, even if + // compute_logsumexp is false + constexpr int kAlignLSE = 1; + res = at::empty({B, M, num_heads, Kv}, query.options()); + logsumexp = at::empty( + { B, num_heads, max_seqlen_q }, + query.options().dtype(at::ScalarType::Float)); + at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + at::Tensor q_t = query.transpose(1, 2); + at::Tensor k_t = key.transpose(1, 2); + at::Tensor v_t = value.transpose(1, 2); + at::Tensor output_t = res.transpose(1, 2); + bool is_causal; + if 
(static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + + using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); + at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); + hipError_t err; // TODO: Error handling + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + use_dropout ? *seed_t.data_ptr() : 0, + use_dropout ? *offset_t.data_ptr() : 0, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + if (!compute_logsumexp) { + // Set the tensor to empty when compute_logsumexp is false + logsumexp = at::empty( + { B * num_heads, max_seqlen_q, 0 }, + query.options().dtype(at::ScalarType::Float)); + } +#else + // CUDA Implementation + hipDeviceProp_t* p = at::cuda::getDeviceProperties(query.device().index()); + const int computeCapability = p->major * 10 + p->minor; + + bool kernel_launched = false; + const auto maxShmem = p->sharedMemPerBlockOptin; + + auto launchKernel = [&](auto _k, auto kernel_fn) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + if (kernel_launched) { + return; + } + // Check if this kernel is compatible + if (!Kernel::kSupportsDropout && use_dropout) { + return; + } + if (!Kernel::kSupportsBias && bias.has_value()) { + return; + } + + if (value.size(3) > Kernel::kMaxK || key.size(3) > Kernel::kMaxK) { + return; + } + // Alignment + if ((query.stride(2) % Kernel::kAlignmentQ) || + (key.stride(2) % Kernel::kAlignmentK) || + (value.stride(2) % Kernel::kAlignmentV)) { + return; + } + // Uses too much shmem + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + if (smem_bytes > maxShmem) { + return; + } + kernel_launched = true; + + res = at::empty( + {B, M, num_heads, Kv}, + query.options().dtype( + CutlassToAtenDtype::atScalarType())); + + // NOTE: Should be aligned (by padding) in case M is + // not a good number for loading during backward + constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; + logsumexp = at::empty( + {seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B, + num_heads, + compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, + query.options().dtype(at::ScalarType::Float)); + typename Kernel::Params p; + p.query_ptr = (const scalar_t*)query.const_data_ptr(); + p.key_ptr = (const scalar_t*)key.const_data_ptr(); + p.value_ptr = (const scalar_t*)value.const_data_ptr(); + p.logsumexp_ptr = compute_logsumexp + ? 
(typename Kernel::lse_scalar_t*)logsumexp.data_ptr() + : nullptr; + at::Tensor output_accum; + if (Kernel::kNeedsOutputAccumulatorBuffer) { + output_accum = at::empty( + {B, M, num_heads, Kv}, + query.options().dtype( + CutlassToAtenDtype< + typename Kernel::output_accum_t>::atScalarType())); + p.output_accum_ptr = + (typename Kernel::output_accum_t*)output_accum.data_ptr(); + } else { + p.output_accum_ptr = nullptr; + } + p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); + + if (seqstart_q.has_value()) { + p.seqstart_q_ptr = (const int32_t*)seqstart_q->const_data_ptr(); + p.seqstart_k_ptr = (const int32_t*)seqstart_k->const_data_ptr(); + } + + p.num_heads = num_heads; + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = max_seqlen_q; + p.num_keys = max_seqlen_k; + p.num_batches = seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B; + p.custom_mask_type = custom_mask_type; + + p.seqlen_k_ptr = nullptr; + if (seqlen_k.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(seqlen_k.value()); + TORCH_CHECK(seqlen_k->scalar_type() == at::ScalarType::Int); + p.seqlen_k_ptr = (const int32_t*)seqlen_k->const_data_ptr(); + } + if (window_size.has_value()) { + p.window_size = *window_size; + } + p.scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.o_strideM, res.stride(1)); + + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + TORCH_CHECK( + bias->scalar_type() == CutlassToAtenDtype::atScalarType(), + "invalid dtype for bias - should match query's dtype"); + p.attn_bias_ptr = (const scalar_t*)bias->const_data_ptr(); + + TORCH_CHECK(bias->dim() == 4, "Bias expected in BMHK format"); + TORCH_CHECK( + bias->size(0) == query.size(0), + "attn_bias: wrong shape (batch dimension)"); + TORCH_CHECK( + bias->size(1) == query.size(2), + "attn_bias: wrong shape (head dimension)"); + TORCH_CHECK( + bias->size(2) == query.size(1), + "attn_bias: wrong shape (seqlenQ dimension)"); + TORCH_CHECK( + bias->size(3) == key.size(1), + "attn_bias: wrong shape (seqlenKV dimension)"); + ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias->stride(0)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(1)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(2)); + TORCH_CHECK( + bias->stride(3) == 1, + "attn_bias: wrong alignment (last dimension must be contiguous)"); + } + + p.use_dropout = use_dropout; + if (p.use_dropout) { + p.rng_engine_inputs = philox_state; + p.dropout_prob = dropout_p; + p.seed = seed_t.data_ptr(); + p.extragraph_offset = offset_t.data_ptr(); + } + + if (smem_bytes > 0xc000) { + auto err = hipFuncSetAttribute( + kernel_fn, hipFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + TORCH_CHECK( + err != hipErrorInvalidValue, + "This GPU does not have enough shared-memory (kernel requires ", + smem_bytes / 1024, + " kb)"); + AT_CUDA_CHECK(err); + } + auto blocks = p.getBlocksGrid(); + if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) { + res.zero_(); + return; + } + Kernel::check_supported(p); + hipLaunchKernelGGL(( 
kernel_fn), dim3(blocks), dim3(p.getThreadsGrid()), smem_bytes, stream, p); + }; + + // Dispatch to the right kernel + DISPATCH_TYPES(query, ([&]() { + dispatch_cutlassF(launchKernel, computeCapability); + })); + TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!"); + AT_CUDA_CHECK(hipGetLastError()); + +#endif // USE_ROCM + return std::make_tuple( + std::move(res), + std::move(logsumexp), + std::move(seed_t), + std::move(offset_t), + max_seqlen_q, + // TODO: why isn't this being set in the kernel? + max_seqlen_k_.has_value() ? max_seqlen_k_.value() : max_seqlen_k); +#endif + TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0); +} + +Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){ + TORCH_CHECK(false, "This operator should be overridden in python before use"); + return at::Tensor(); +} + +REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda); + +#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) +namespace { +/** + * simple kernel that populates a tensor with rand uniform values. + * currently only used for testing purposes, not much attention + * is paid to performance. + * + * problem is partitioned as follows: + * - (batch, head) is given by block coordinates + * - each thread handles a row for a given (batch, head) + */ +template +__global__ void rand_uniform_kernel( + int64_t n_heads, + int64_t n_queries, + int64_t n_keys, + float dropout_prob, + at::PhiloxCudaState rng_engine_inputs, + mask_t* mask_out, + int64_t mask_numel) { + const int64_t batch_id = blockIdx.x; + const int64_t head_id = blockIdx.y; + const int64_t query_idx = threadIdx.x; + + const auto seeds = at::cuda::philox::unpack(rng_engine_inputs); + + const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) + + head_id * (n_queries * n_keys); + const int64_t query_start_idx = query_idx * n_keys; + + hiprandStatePhilox4_32_10_t curand_state; + hiprand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + dropout_seq_start + query_start_idx, + &curand_state); + + for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) { + float4 rand_quad = hiprand_uniform4(&curand_state); + +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int64_t linear_idx = dropout_seq_start + query_start_idx + key_start_idx + i; + if (linear_idx < mask_numel) { + mask_out[linear_idx] = (&rand_quad.x)[i]; + } + } + } +} +} // namespace +#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) +/** + * fill tensor with random uniform values. 
only used for testing, not much + * attention is paid to performance + */ +at::Tensor& _fill_mem_eff_dropout_mask_( + Tensor& self, + double dropout_p, + const int64_t seed, + const int64_t offset) { + TORCH_CHECK(self.is_contiguous()); + TORCH_CHECK(self.dtype() == at::ScalarType::Float); + const int64_t batch_sz = self.size(0); + const int64_t n_heads = self.size(1); + const int64_t n_queries = self.size(2); + const int64_t n_keys = self.size(3); +#if defined(USE_MEM_EFF_ATTENTION) + +#ifdef USE_ROCM + using aotriton::v2::flash::debug_fill_dropout_rng; + using sdp::aotriton_adapter::mk_aotensor; + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + hipError_t err; // TODO: Error handling + + err = debug_fill_dropout_rng(mk_aotensor(self, "r"), + static_cast(seed), + static_cast(offset), + stream); +#else + at::PhiloxCudaState rng_engine_inputs; + rng_engine_inputs = at::PhiloxCudaState(seed, offset); + at::hip::HIPGuardMasqueradingAsCUDA device_guard(self.device()); + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + hipLaunchKernelGGL(( rand_uniform_kernel), dim3(dim3(batch_sz, n_heads)), dim3(n_queries), 0, stream, + n_heads, + n_queries, + n_keys, + dropout_p, + rng_engine_inputs, + reinterpret_cast(self.data_ptr()), + self.numel()); +#endif + + return self; +#endif + TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.") + return self; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/hip/attention_backward.hip b/aten/src/ATen/native/transformers/hip/attention_backward.hip new file mode 100644 index 0000000000000..c6f129c649d02 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/attention_backward.hip @@ -0,0 +1,826 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +#ifdef USE_FLASH_ATTENTION +// FlashAttention Specific Imports +#include +#endif +#ifdef USE_MEM_EFF_ATTENTION +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA +#include +#include +#include +#include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif +#endif + +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif + +namespace at::native { + +std::tuple _flash_attention_backward( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_k, + int64_t max_seqlen_batch_q, + int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + const Tensor& philox_seed, + const Tensor& philox_offset, + std::optional scale, + std::optional window_size_left, + std::optional window_size_right) { +#if defined(USE_FLASH_ATTENTION) + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + // CUDA code assumes that dout is contiguous + auto contiguous_grad_out = grad_out.contiguous(); + auto contiguous_out = out.contiguous(); + + const int non_null_window_left = window_size_left.has_value() ? 
window_size_left.value() : -1; + const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; + + std::optional dq{c10::nullopt}; + std::optional dk{c10::nullopt}; + std::optional dv{c10::nullopt}; + + // The kernel computes irregardless we will drop for this functions return + Tensor grad_softmax; + + // Currently unused args: + std::optional alibi_slopes{c10::nullopt}; + + bool determinisitic{false}; + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "Flash Attention defaults to a non-deterministic algorithm. ", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } else { + determinisitic = true; + } + } + + // We check the whether the cumulative_sequence_length_q is defined + // in order to determine whether we are using varlen or dense forward + if (cumulative_sequence_length_q.defined()) { + // Varlen forward + auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd( + contiguous_grad_out, + query, + key, + value, + contiguous_out, + logsumexp, + dq, + dk, + dv, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + alibi_slopes, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + softmax_scale, + false /*zero_tensors*/, + is_causal, + non_null_window_left, + non_null_window_right, + determinisitic, + philox_seed, + philox_offset); + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); + } else { + // Dense forward + auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd( + contiguous_grad_out, + query, + key, + value, + contiguous_out, + logsumexp, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + determinisitic, + philox_seed, + philox_offset); + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); + } +#endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build."); + return std::make_tuple(Tensor(), Tensor(), Tensor()); +} + +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& philox_seed, + const Tensor& philox_offset, + const Tensor& attn_bias, + const Tensor& cum_seq_q, + const Tensor& cum_seq_k, + const int64_t max_q, + const int64_t max_k, + double dropout_p, + bool is_causal, + c10::optional scale) { + + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "cuDNN Attention defaults to a non-deterministic algorithm. 
", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } + } + + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(dq, dk, dv); +} + +std::tuple +_efficient_attention_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& kernel_bias, // additive attention bias + const at::Tensor& out, + // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the + // position of the first query token for batch $b + const std::optional& cu_seqlens_q_dummy, + // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the + // position of the first key token for batch $b + const std::optional& cu_seqlens_k_dummy, + // (Mode 1MHK only) Maximum sequence length across batches + int64_t max_seqlen_q, + // (Mode 1MHK only) Maximum sequence length across batches + int64_t max_seqlen_k, + const at::Tensor& logsumexp, + double dropout_p, // dropout probability + const at::Tensor& philox_seed, // seed using for generating random numbers for dropout + const at::Tensor& philox_offset, // offset into random number sequence + int64_t custom_mask_type, + const bool bias_requires_grad, + const std::optional scale, + std::optional num_splits_key, + const std::optional window_size, + const bool shared_storage_dqdkdv) { + #if defined(USE_MEM_EFF_ATTENTION) + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); + } + // This path is used when we directly call _efficient_attention_forward + // from python. + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional bias, cu_seqlens_q, cu_seqlens_k; + bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias; + cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy; + cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? 
c10::nullopt : cu_seqlens_k_dummy; + + // ndim + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + TORCH_CHECK(query.dim() == 4); + + // batch size + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // seqlen + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + TORCH_CHECK(query.size(2) == grad_out_.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + TORCH_CHECK(value.size(3) == grad_out_.size(3)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); + TORCH_CHECK( + !(cu_seqlens_q.has_value() && bias.has_value()), + "cu seqlen + bias not supported"); + if (cu_seqlens_q.has_value()) { + TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); + TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); + TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); + TORCH_CHECK(max_seqlen_q > 0, "max_seqlen_q required with `cu_seqlens_q`"); + TORCH_CHECK(max_seqlen_k > 0, "max_seqlen_k required with `cu_seqlens_k`"); + TORCH_CHECK( + max_seqlen_k <= key.size(1), "Invalid max_seqlen_k:", max_seqlen_k); + TORCH_CHECK( + max_seqlen_q <= query.size(1), "Invalid max_seqlen_q:", max_seqlen_q); + } else { + max_seqlen_q = query.size(1); + max_seqlen_k = key.size(1); + } + + at::hip::HIPGuardMasqueradingAsCUDA device_guard(query.device()); + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t nH = query.size(2); + int64_t K = query.size(3); + int64_t Kv = value.size(3); + + at::Tensor grad_q, grad_k, grad_v, grad_bias; + if (shared_storage_dqdkdv) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. 
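+    // (Illustrative only, not from this change: on the Python side this corresponds to
+    //  something like `qkv = linear(x); q, k, v = qkv.chunk(3, dim=-1)`, so gradients
+    //  laid out as one [B, M, 3, nH, K] chunk can be viewed back as the gradient of
+    //  the packed `qkv` output.)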
+ // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + TORCH_CHECK( + query.size(1) == key.size(1), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same sequence length: got ", query.size(1), + " query tokens and ", key.size(1), " key/value tokens" + ); + TORCH_CHECK( + query.size(3) == key.size(3), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same embed dim: got ", query.size(3), + " for Q, and ", key.size(3), " for K" + ); + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty(query.sizes(), query.options()); + grad_k = at::empty(key.sizes(), key.options()); + grad_v = at::empty(value.sizes(), value.options()); + } + + if (bias_requires_grad) { + // force alignment for the last dim + std::vector sz = bias->sizes().vec(); + int64_t lastDim = sz[sz.size() - 1]; + int64_t alignTo = 16; + sz[sz.size() - 1] = alignTo * ((lastDim + alignTo - 1) / alignTo); + grad_bias = at::empty(sz, bias->options()) + .slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim); + } + + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + + // See Note [Seed and Offset Device] + at::PhiloxCudaState rng_engine_inputs; + if (use_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) { + rng_engine_inputs = at::PhiloxCudaState( + *philox_seed.data_ptr(), + *philox_offset.data_ptr()); + } else { // dropout + capture + rng_engine_inputs = at::PhiloxCudaState( + philox_seed.data_ptr(), + philox_offset.data_ptr(), + 0); + } + } + +#ifdef USE_ROCM + // ROCM Implementation + TORCH_CHECK(!num_splits_key.has_value(), + "ROCM does not support num_split_keys in _efficient_attention_forward"); + TORCH_CHECK(!window_size.has_value(), + "ROCM does not support window_size in _efficient_attention_forward"); + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + bool is_causal; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now"); + } + at::Tensor q_t = query.permute({0,2,1,3}); + at::Tensor k_t = key.permute({0,2,1,3}); + at::Tensor v_t = value.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = grad_q.permute({0,2,1,3}); + at::Tensor dk_t = grad_k.permute({0,2,1,3}); + at::Tensor dv_t = grad_v.permute({0,2,1,3}); + at::Tensor dout_t = grad_out.permute({0,2,1,3}); + at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + + hipError_t err; + using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? 
mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + rng_engine_inputs.seed_.val, + rng_engine_inputs.offset_.val, + is_causal, + stream); +#else + at::Tensor workspace; + hipDeviceProp_t* p = at::cuda::getDeviceProperties(query.device().index()); + const int computeCapability = p->major * 10 + p->minor; + + bool kernel_launched = false; + const auto maxK = ::max(query.size(3), value.size(3)); + const auto maxShmem = p->sharedMemPerBlockOptin; + + auto launchKernel = [&](auto _k, auto kernel_fn) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + if (kernel_launched) { + return; + } + // Check if this kernel is compatible + if (Kernel::kMaxK < maxK) { + return; + } + // Dropout must be supported if we need it + if (use_dropout && !Kernel::kApplyDropout) { + return; + } + if (Kernel::kKeysQueriesAlignedToBlockSize && + (cu_seqlens_q.has_value() || M % Kernel::kBlockSizeI || + N % Kernel::kBlockSizeJ)) { + return; + } + // Alignment + if ((query.stride(2) % Kernel::kMinimumAlignment) || + (key.stride(2) % Kernel::kMinimumAlignment) || + (value.stride(2) % Kernel::kMinimumAlignment)) { + return; + } + // Uses too much shmem + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + if (smem_bytes > maxShmem) { + return; + } + + kernel_launched = true; + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)) + .sum(-1) + .transpose(-2, -1) + .contiguous(); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == nH); + TORCH_INTERNAL_ASSERT(delta.size(2) == M); + + typename Kernel::Params p; + p.query_ptr = (const scalar_t*)query.const_data_ptr(); + p.key_ptr = (const scalar_t*)key.const_data_ptr(); + p.value_ptr = (const scalar_t*)value.const_data_ptr(); + p.logsumexp_ptr = (typename Kernel::lse_scalar_t const *)logsumexp.const_data_ptr(); + p.output_ptr = (const scalar_t*)out.const_data_ptr(); + p.grad_output_ptr = (const scalar_t*)grad_out.const_data_ptr(); + p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + p.delta_ptr = (float*)delta.data_ptr(); + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = max_seqlen_q; + p.num_keys = max_seqlen_k; + p.num_batches = cu_seqlens_q.has_value() ? 
cu_seqlens_q->size(0) - 1 : B; + p.num_heads = nH; + p.custom_mask_type = custom_mask_type; + p.scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + if (cu_seqlens_q.has_value()) { + p.cu_seqlens_q_ptr = (const int32_t*)cu_seqlens_q->const_data_ptr(); + p.cu_seqlens_k_ptr = (const int32_t*)cu_seqlens_k->const_data_ptr(); + } + if (window_size.has_value()) { + p.window_size = *window_size; + } + + ASSIGN_CHECK_OVERFLOW(p.lse_strideB, logsumexp.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.lse_strideH, logsumexp.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); + p.gQKV_strideM_multiplier = shared_storage_dqdkdv ? 3 : 1; + TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); + TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); + TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.delta_strideB, delta.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.delta_strideH, delta.stride(1)); + + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + TORCH_CHECK( + bias->scalar_type() == CutlassToAtenDtype::atScalarType(), + "invalid dtype for bias - should match query's dtype"); + + p.bias_ptr = (scalar_t*)bias->data_ptr(); + + TORCH_CHECK(bias->dim() == 4, "Bias expected in BMHK format"); + TORCH_CHECK( + bias->size(0) == query.size(0), + "attn_bias: wrong shape (batch dimension)"); + TORCH_CHECK( + bias->size(1) == query.size(2), + "attn_bias: wrong shape (head dimension)"); + TORCH_CHECK( + bias->size(2) == query.size(1), + "attn_bias: wrong shape (seqlenQ dimension)"); + TORCH_CHECK( + bias->size(3) == key.size(1), + "attn_bias: wrong shape (seqlenKV dimension)"); + TORCH_CHECK( + bias->stride(3) == 1, + "attn_bias: wrong alignment (last dimension must be contiguous)"); + ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias->stride(0)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(1)); + ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(2)); + + if (bias_requires_grad) { + p.grad_bias_ptr = (scalar_t*)grad_bias.data_ptr(); + + ASSIGN_CHECK_OVERFLOW(p.gB_strideB, grad_bias.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(2)); + } + } + + if (use_dropout) { + p.rng_engine_inputs = rng_engine_inputs; + p.dropout_prob = dropout_p; + } + + // Heuristic for finding optimal number of splits + auto parallelism_without_split_key = + p.getBlocksGrid().x * p.getBlocksGrid().y * p.getBlocksGrid().z; + 
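+    // Editor's note (hedged, illustrative numbers only): the next line starts from one
+    // split per key block, e.g. num_keys = 1024 with kBlockSizeJ = 128 gives
+    // ceil_div(1024, 128) = 8 candidate key splits; the branches below then clamp this
+    // down when the grid is already parallel enough, or when accumulating gK/gV would
+    // need too much temporary global memory.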
p.num_splits_key = cutlass::ceil_div(p.num_keys, Kernel::kBlockSizeJ); + if (num_splits_key.has_value()) { + p.num_splits_key = + std::min(p.num_splits_key, num_splits_key.value()); + } else { + // Keys splitting heuristic + + // If we already have enough parallelism, split-keys can help + // better use L2 cache. + // This is negligible when the seqlen is too small tho + if (parallelism_without_split_key >= 256 && + p.num_keys <= 2 * Kernel::kBlockSizeJ) { + p.num_splits_key = 1; + } + // Increasing `split_keys` leads to using more gmem for temporary storage + // when we need a staging area for gK/gV. let's avoid that + if (Kernel::kNeedsAccumGradK || Kernel::kNeedsAccumGradV) { + p.num_splits_key = ::min( + int(p.num_splits_key), 200 / (p.num_batches * p.num_heads)); + } + } + if (!Kernel::kEnableSplitKeys || p.num_splits_key < 1) { + p.num_splits_key = 1; + } + + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "Memory Efficient attention defaults to a non-deterministic algorithm. ", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } else { + TORCH_CHECK( + num_splits_key.value_or(1) <= 1, + "Using `num_splits_key > 1` makes the algorithm non-deterministic, and pytorch's deterministic mode is enabled"); + p.num_splits_key = 1; + } + } + int64_t size_bytes = p.workspace_size(); + if (size_bytes) { + workspace = + at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte)); + p.workspace = (float*)workspace.data_ptr(); + if (p.should_zero_workspace()) { + workspace.zero_(); + } + } + + // Handle the edge-cases where some tensors are empty + if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 || + p.num_heads == 0) { + grad_k.zero_(); + grad_v.zero_(); + grad_q.zero_(); + return; + } + Kernel::check_supported(p); + + if (smem_bytes > 0xc000) { + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability + auto err = hipFuncSetAttribute( + kernel_fn, hipFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + TORCH_CHECK( + err != hipErrorInvalidValue, + "This GPU does not have enough shared-memory (kernel requires ", + smem_bytes / 1024, + " kb)"); + AT_CUDA_CHECK(err); + } + + // second syntax resulted in the error below on windows + // error C3495: 'kernel_fn': a simple capture must be a variable + // with automatic storage duration declared + // in the reaching scope of the lambda +#ifdef _WIN32 + hipFuncAttributes attr; + AT_CUDA_CHECK(hipFuncGetAttributes(&attr, kernel_fn)); + TORCH_INTERNAL_ASSERT( + attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, + "Something went wrong in the build process"); +#else + auto checkBinaryArchMatches = [&]() { + hipFuncAttributes attr; + AT_CUDA_CHECK(hipFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); +#endif + + hipLaunchKernelGGL(( kernel_fn), dim3(p.getBlocksGrid()), dim3(p.getThreadsGrid()), smem_bytes, stream, p); + }; + + DISPATCH_TYPES(query, ([&]() { + dispatch_cutlassB(launchKernel, computeCapability); + })); + TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!"); + AT_CUDA_CHECK(hipGetLastError()); +#endif // USE_ROCM + return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), 
std::move(grad_bias)); + #endif // defined(USE_MEM_EFF_ATTENTION) + TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); +} + +std::tuple _scaled_dot_product_flash_attention_backward_cuda( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale){ + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + + Tensor q_t = query.transpose(1, 2); + Tensor k_t = key.transpose(1, 2); + Tensor v_t = value.transpose(1, 2); + + Tensor grad_out_t = grad_out_.transpose(1,2); + Tensor out_t = out.transpose(1,2); + + auto [grad_q, grad_k, grad_v] = at::_flash_attention_backward( + grad_out_t, + q_t, + k_t, + v_t, + out_t, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale); + + grad_q = grad_q.transpose(1,2); + grad_k = grad_k.transpose(1,2); + grad_v = grad_v.transpose(1,2); + + return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v)); +} + + +std::tuple _scaled_dot_product_efficient_attention_backward_cuda( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + double dropout_p, + std::array grad_input_mask, + bool causal, + std::optional scale) { + + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); + } + auto grad_out = grad_out_.transpose(1, 2); + auto out_t = out.transpose(1, 2); + auto q_t = query.transpose(1, 2); + auto k_t = key.transpose(1, 2); + auto v_t = value.transpose(1, 2); + + Tensor grad_q, grad_k, grad_v, grad_bias; + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional kernel_bias; + if (attn_bias.defined()) { + kernel_bias = attn_bias; + } + // Will add with signauter changes for dropout and bias + // We are only handling Dense inputs, but this should be passed + // from forward to backward + int64_t max_seqlen_q = q_t.size(1); + int64_t max_seqlen_k = k_t.size(1); + + sdp::CustomMaskType custom_mask_type = causal + ? 
sdp::CustomMaskType::CausalFromTopLeft + : sdp::CustomMaskType::NoCustomMask; + std::tie(grad_q, grad_k, grad_v, grad_bias) = + at::_efficient_attention_backward( + grad_out, + q_t, + k_t, + v_t, + kernel_bias, + out_t, + c10::nullopt, + c10::nullopt, + max_seqlen_q, + max_seqlen_k, + logsumexp, + dropout_p, + philox_seed, + philox_offset, + static_cast(custom_mask_type), + grad_input_mask[3], + scale, + c10::nullopt); // num_split_keys + return std::make_tuple( + grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000..605ae9cfe68a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,91 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include + +#include +#include + +namespace pytorch_flash { + +TORCH_API +std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + + +std::tuple +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip deleted file mode 100644 index 7af480a7ae495..0000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ /dev/null @@ -1,486 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Advanced Micro Devices, Inc. - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ -#include -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS - -#include -#include - -#include - -#ifdef USE_FLASH_ATTENTION -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#include -#include -#include -#include -#include -#include -#endif - -#include -#include - -#include -#include - -// AOTriton headers -#include -#include - -namespace pytorch_flash { - -namespace { - -void check_gpu_arch(hipStream_t stream) { - auto ret = aotriton::v2::flash::check_gpu(stream); - if (hipSuccess != ret) { - TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") - } -} - -} - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -std::tuple -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - - // FIXME: ROCM probably does not need this - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - int seqlen_q = sizes[1]; - int num_heads = sizes[2]; - const int head_size_og = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - if (is_causal) { window_size_right = 0; } - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - - at::Tensor q_padded, k_padded, v_padded; - q_padded = q; - k_padded = k; - v_padded = v; - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } - } else { - out = at::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - // We want to 
checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; - - if (p_dropout > 0.0) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = batch_size * num_heads * 32; - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong)); - offset_t = at::empty({}, at::dtype(at::kLong)); - } - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); - } - } - - // Transpose tensors to meet AOTriton's Flash API - at::Tensor q_t = q_padded.permute({0,2,1,3}); - at::Tensor k_t = k_padded.permute({0,2,1,3}); - at::Tensor v_t = v_padded.permute({0,2,1,3}); - at::Tensor output_t = out.permute({0,2,1,3}); - - at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse - - at::Tensor softmax_fa_t; - if (return_softmax) { - softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); - } else { - softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device())); - } - - hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); - - return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; -} - -std::tuple -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &out_, // total_q x num_heads x 
head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, - const int max_seqlen_k, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - - TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm"); - - at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat)); - at::Tensor p = at::empty({}, at::dtype(at::kFloat)); - at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor out = at::empty({}, at::dtype(at::kFloat)); - - return {out, q, k, v, softmax_lse, seed_t, offset_t, p}; -} - -std::tuple -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, // probability to drop - const float softmax_scale, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - bool is_dropout = p_dropout > 0.0; - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; - const int num_heads = sizes[2]; - const int head_size_og = dout.size(3); - const int head_size = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - - if (is_causal){ - TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); - } - - 
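
The varlen entry points declared in flash_api.h above (and implemented with CK further down in this diff) describe a packed batch through cu_seqlens_q / cu_seqlens_k: as the shape comments suggest, these are b+1 cumulative sequence lengths, so sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1]) of the flattened total_q / total_k dimension, and max_seqlen_q / max_seqlen_k are just the largest individual lengths (used, per the comment, to choose the kernel). A minimal sketch of that convention with made-up lengths:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-sequence lengths for a batch of b = 3.
  std::vector<int> seqlens = {5, 2, 7};
  // cu_seqlens has b + 1 entries: an exclusive prefix sum starting at 0.
  std::vector<int> cu_seqlens(seqlens.size() + 1, 0);
  std::partial_sum(seqlens.begin(), seqlens.end(), cu_seqlens.begin() + 1);
  // cu_seqlens == {0, 5, 7, 14}; total_q == cu_seqlens.back() == 14, and
  // sequence i lives in rows [cu_seqlens[i], cu_seqlens[i+1]) of the packed
  // (total_q, num_heads, head_size) tensors.
  for (int c : cu_seqlens) std::printf("%d ", c);
  std::printf("\n");
  return 0;
}
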
TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); - - at::Tensor dq, dk, dv; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); - } else { - dq = at::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dk = at::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dv = at::empty_like(k); - } - - // const at::Tensor& dout_padded = dout; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } - } - - at::Tensor q_t = q.permute({0,2,1,3}); - at::Tensor k_t = k.permute({0,2,1,3}); - at::Tensor v_t = v.permute({0,2,1,3}); - at::Tensor out_t = out.permute({0,2,1,3}); - at::Tensor dq_t = dq.permute({0,2,1,3}); - at::Tensor dk_t = dk.permute({0,2,1,3}); - at::Tensor dv_t = dv.permute({0,2,1,3}); - at::Tensor dout_t = dout.permute({0,2,1,3}); - - 
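
The permute({0,2,1,3}) calls just above (and the matching ones in the forward path) only reorder stride metadata: the same storage is exposed to the AOTriton kernels in (batch, heads, seqlen, head_dim) order without any copy. A small illustration of that view, using arbitrary sizes:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // (batch, seqlen, heads, head_dim): the layout the ATen-facing API receives.
  at::Tensor q = at::zeros({2, 128, 8, 64});
  // View the same storage as (batch, heads, seqlen, head_dim); no data is moved.
  at::Tensor q_t = q.permute({0, 2, 1, 3});
  std::cout << "q   sizes " << q.sizes()   << " strides " << q.strides()   << "\n";
  std::cout << "q_t sizes " << q_t.sizes() << " strides " << q_t.strides() << "\n";
  // q_t reports sizes [2, 8, 128, 64] and strides [65536, 64, 512, 1]:
  // dimensions 1 and 2 are swapped in the metadata, the buffer is untouched.
  return 0;
}
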
at::Tensor softmax_lse_cont = softmax_lse.contiguous(); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); - - int d_head = head_size_og; - hipError_t err; // TODO: Error handling - { - using aotriton::v2::flash::attn_bwd; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, - is_causal, - stream); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } - return { dq, dk, dv, softmax_d }; -#undef CALL_BWD_DROPOUT -#undef CALL_BWD -} - -std::tuple -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x s softmax logsumexp - c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads - const int max_seqlen_q, - const int max_seqlen_k, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); - - at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat)); - - return { q, k, v, softmax_d }; -} -} // namespace pytorch_fmha - -#endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp new file mode 100644 index 0000000000000..9e98fb2ee58d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp @@ -0,0 +1,57 @@ +// !!! This is a file automatically generated by hipify!!! +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. 
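
The dK/dV reduction in the backward path above (and kept in the CK backward kernels added below) is worth a short illustration: for MQA/GQA the kernel writes gradients per query head into dk_expanded / dv_expanded, and they are then summed onto the smaller set of key/value heads, since each KV head is shared by num_heads / num_heads_k query heads. A standalone sketch of that reduction, with made-up sizes:

#include <ATen/ATen.h>

int main() {
  // Hypothetical GQA configuration: 8 query heads sharing 2 key/value heads.
  const int64_t B = 2, S = 16, H = 8, Hk = 2, D = 64;
  at::Tensor dk_expanded = at::randn({B, S, H, D});   // per-query-head gradients
  at::Tensor dk = at::empty({B, S, Hk, D});           // what the caller expects
  // Group the H query heads into Hk groups of H/Hk and sum within each group,
  // mirroring the at::sum_out(..., {3}) call used in mha_bwd above.
  at::sum_out(dk, dk_expanded.reshape({B, S, Hk, H / Hk, D}), {3});
  return 0;
}
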
+//#include +//#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +//#include +#include +#include +#include +#include +#include + + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace flash { +// Copy from PyTorch +// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17 +static std::tuple unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/mha_bwd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/mha_bwd.hip new file mode 100644 index 0000000000000..78421cbef1b03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/mha_bwd.hip @@ -0,0 +1,421 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "flash_common_hip.hpp" + +#include "fmha_bwd.hpp" +#include "mask.hpp" + +namespace pytorch_flash { + +fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t nhead_stride_q = q.stride(2); + + // k: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(2); + + // v: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(2); + + // o: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_o = out.stride(0); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(2); + + // lse: (batch_size, nheads, seqlen_q) + ck_tile::index_t batch_stride_lse = softmax_lse.stride(0); + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1); + + // do: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_do = dout.stride(0); + ck_tile::index_t stride_do = dout.stride(1); + ck_tile::index_t nhead_stride_do = dout.stride(2); + + // d: (batch_size, nheads, seqlen_q) + // CK assume d share the same stride with lse + + // dq: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = dq.stride(0); + ck_tile::index_t stride_dq = dq.stride(1); + ck_tile::index_t nhead_stride_dq = dq.stride(2); + + // dk_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = dk.stride(0); + ck_tile::index_t stride_dk = dk.stride(1); + ck_tile::index_t nhead_stride_dk = dk.stride(2); + + // dv_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = dv.stride(0); + ck_tile::index_t stride_dv = dv.stride(1); + ck_tile::index_t nhead_stride_dv = dv.stride(2); + + // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t stride_dq_acc = dq_acc.stride(2); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, // seqlen_k_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + {drop_seed, drop_offset}}; +} + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentHIPStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? 
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); // unpadded hdim + const int head_size_8x = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); + dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + + if (rng_state.has_value()) { + uint64_t* d = reinterpret_cast(rng_state.value().data_ptr()); + drop_seed = d[0]; + drop_offset = d[1]; + } else if(is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + if (seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_bwd_args( + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/mha_fwd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/mha_fwd.hip new file mode 100644 index 0000000000000..b0b9757c45cca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/mha_fwd.hip @@ -0,0 +1,379 @@ +// !!! This is a file automatically generated by hipify!!! +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "flash_common_hip.hpp" + +#include "fmha_fwd.hpp" +#include "mask.hpp" + +namespace pytorch_flash { + +fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (batch_size, seqlen_q, nheads, d) + // k: (batch_size, seqlen_k, nheads_k, d) + // v: (batch_size, seqlen_k, nheads_k, d) + // o: (batch_size, seqlen_q, nheads, d) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, seqlen_q) + // randval: (batch_size, nheads, seqlen_q, seqlen_k) + + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(2); + ck_tile::index_t nhead_stride_k = k.stride(2); + ck_tile::index_t nhead_stride_v = v.stride(2); + ck_tile::index_t nhead_stride_o = out.stride(2); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t batch_stride_o = out.stride(0); + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + {drop_seed, drop_offset}}; +} + +std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + + mask_info mask; + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
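
Before the mask string is assembled below, note the head-size convention this forward path relies on (the at::pad / round_multiple code a few lines further down): head dimensions are padded up to the next multiple of 8 before the CK kernel runs and cropped back afterwards. A quick numeric sketch with an assumed head size:

#include <cstdio>

int main() {
  // Same helper as in the code below: round x up to the next multiple of m.
  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_og = 100;                              // assumed unpadded head dim
  const int head_size_8x = round_multiple(head_size_og, 8);  // 104: what CK actually sees
  const int pad = 8 - head_size_og % 8;                      // 4 extra zero columns
  std::printf("pad %d -> %d (adds %d columns)\n", head_size_og, head_size_8x, pad);
  // In the code below, at::pad(q, {0, pad}) widens the last dimension, and after
  // the kernel the outputs are cropped back with out.index({"...", Slice(None, head_size_og)}).
  return 0;
}
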
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + at::Tensor temp_q = q; + if (seqlenq_ngroups_swapped) { + temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = at::pad(temp_q, {0, 8 - head_size_og % 8}); + k_padded = at::pad(k, {0, 8 - head_size_og % 8}); + v_padded = at::pad(v, {0, 8 - head_size_og % 8}); + } + else { + q_padded = temp_q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + } + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } + else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte)); + } + + // uint64_t drop_seed = 1, drop_offset = 0; + // int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + // auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + // auto rng_state = at::empty({2}, options.dtype(at::kLong)); + + // if (p_dropout > 0.0) { + // auto gen = at::get_generator_or_default( + // gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // // See Note [Acquire lock when using random generators] + // std::lock_guard lock(gen->mutex_); + // auto philox_args = gen->philox_cuda_state(counter_offset); + // std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + // } + + // rng_state[0] = *(reinterpret_cast(&drop_seed)); + // rng_state[1] = *(reinterpret_cast(&drop_offset)); + + // We want to checkpoint and save the RNG state for backward if dropout + // We get the default generator and return the seed and offset 
which will + // be used in the backward function + uint64_t drop_seed = 1, drop_offset = 0; + at::Tensor seed_t, offset_t; + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + //auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); + if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { + auto [seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); + drop_seed = seed; + drop_offset = offset; + } else { + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + //params.seed = seed_t.data_ptr(); + //params.extragraph_offset = offset_t.data_ptr(); + } + // params.philox_args = philox_state; + } else { + if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } else { + seed_t = at::empty({}, at::dtype(at::kLong)); + offset_t = at::empty({}, at::dtype(at::kLong)); + } + + } + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + fmha_fwd(traits, args, stream_config); + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (head_size_og % 8 != 0) { + out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_bwd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_bwd.hip new file mode 100644 index 0000000000000..c729bb4e6691c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_bwd.hip @@ -0,0 +1,449 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "flash_common_hip.hpp" + +#include "fmha_bwd.hpp" +#include "mask.hpp" + +namespace pytorch_flash { + + +fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int max_seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + // q: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t nhead_stride_q = q.stride(1); + + // k: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t nhead_stride_k = k.stride(1); + + // v: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t nhead_stride_v = v.stride(1); + + // o: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_o = 0; + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t nhead_stride_o = out.stride(1); + + // lse: (nheads, total_q) + ck_tile::index_t batch_stride_lse = 0; + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0); + + // do: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_do = 0; + ck_tile::index_t stride_do = dout.stride(0); + ck_tile::index_t nhead_stride_do = dout.stride(1); + + // d: (batch_size, nheads, max_seqlen_q) + // CK assume d share the same stride with lse + + // dq: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = 0; + ck_tile::index_t stride_dq = dq.stride(0); + ck_tile::index_t nhead_stride_dq = dq.stride(1); + + + // dk_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = 0; + ck_tile::index_t stride_dk = dk.stride(0); + ck_tile::index_t nhead_stride_dk = dk.stride(1); + + // dv_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = 0; + ck_tile::index_t stride_dv = dv.stride(0); + ck_tile::index_t nhead_stride_dv = dv.stride(1); + + // dq_acc: (split, total_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = 0; + ck_tile::index_t stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == 
at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_k_ptr + total_q, + total_k, + b, + max_seqlen_q, // max_seqlen_q + max_seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + {drop_seed, drop_offset}}; +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key 
must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size_8x = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
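
Further down in this function (and in mha_bwd above), the deterministic path sizes the extra accumulator dq_acc with one slice per kN0-sized chunk of the key sequence, while the non-deterministic path allocates a single slice and, per the comment in the code, lets CK accumulate into dq atomically. The arithmetic behind that allocation, with assumed sizes:

#include <cstdio>

int main() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  // Assumed values for illustration.
  const int head_size_8x = 128;                    // padded head dim (<= 128 here)
  const int kN0 = head_size_8x <= 128 ? 128 : 64;  // tile size along seqlen_k
  const int max_seqlen_k = 1000;
  const int nsplits = ceil_div(max_seqlen_k, kN0); // 8 partial-dq slices
  std::printf("dq_acc gets %d slices of shape (total_q, nheads, %d)\n",
              nsplits, head_size_8x);
  // The non-deterministic branch instead allocates {1, total_q, nheads, head_size_8x}.
  return 0;
}
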
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, total_q, num_heads, head_size_8x); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, total_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); + dq_accum = at::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if(zero_tensors) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + + if (rng_state.has_value()) { + uint64_t* d = reinterpret_cast(rng_state.value().data_ptr()); + drop_seed = d[0]; + drop_offset = d[1]; + } else if(is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + if (max_seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_varlen_bwd_args( + mask, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_fwd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_fwd.hip new file mode 100644 index 0000000000000..e42da770e2c90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/mha_varlen_fwd.hip @@ -0,0 +1,368 @@ +// !!! This is a file automatically generated by hipify!!! +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "flash_common_hip.hpp" + +#include "fmha_fwd.hpp" +#include "mask.hpp" + +namespace pytorch_flash { + +fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (total_q, nheads, d) + // k: (total_k, nheads_k, d) + // v: (total_k, nheads_k, d) + // o: (total_q, nheads, d) + + // alibi_slopes:(batch, nheads) or (nhead) + // lse: (batch, nheads, max_seqlen_q) + // randval: (nheads, total_q, max_seqlen_k) + + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(1); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t batch_stride_o = 0; + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_kpads + total_q, + total_k, + b, + max_seqlen_q, + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + {drop_seed, drop_offset}}; +} + +std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional & /*seqused_k*/, + // c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? 
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + // TODO - Support paged_KV + // const bool paged_KV = block_table_.has_value(); + // TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + const int max_num_blocks_per_seq = 0; + const int num_blocks = 0; + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + + // TODO + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + + const int total_q = q.size(0); + const int total_k = k.size(0); + + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = at::pad(q, {0, 8 - head_size_og % 8}); + k_padded = at::pad(k, {0, 8 - head_size_og % 8}); + v_padded = at::pad(v, {0, 8 - head_size_og % 8}); + } + else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } + else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(at::kByte)); + } + + if (zero_tensors) + { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_dropout_randval) {p.zero_();} + } + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + auto rng_state = at::empty({2}, options.dtype(at::kLong)); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + rng_state[0] = *(reinterpret_cast(&drop_seed)); + rng_state[1] = *(reinterpret_cast(&drop_offset)); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_varlen_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + max_seqlen_q, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + fmha_fwd(traits, args, stream_config); + } + else { + // If seqlen_k == 0, then we have an empty tensor. 
We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + //return kludge -- TODO:: REMOVE + at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); + at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); + + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} diff --git a/aten/src/ATen/native/transformers/hip/mem_eff_attention/debug_utils.h b/aten/src/ATen/native/transformers/hip/mem_eff_attention/debug_utils.h new file mode 100644 index 0000000000000..9ac1137ca99b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/mem_eff_attention/debug_utils.h @@ -0,0 +1,212 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (int _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) 
\ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (int _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + 
PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/aten/src/ATen/native/transformers/hip/mem_eff_attention/gemm_kernel_utils.h b/aten/src/ATen/native/transformers/hip/mem_eff_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000..5ba671288fc26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/mem_eff_attention/gemm_kernel_utils.h @@ -0,0 +1,210 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + TORCH_CHECK( \ + false, \ + "Your device is too old. 
We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) +// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? 
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/aten/src/ATen/native/transformers/hip/mem_eff_attention/pytorch_utils.h b/aten/src/ATen/native/transformers/hip/mem_eff_attention/pytorch_utils.h new file mode 100644 index 0000000000000..a18c1b9a124e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/mem_eff_attention/pytorch_utils.h @@ -0,0 +1,45 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +#include +#include + + +template +struct CutlassToAtenDtype; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::half_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Half; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::bfloat16_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::BFloat16; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = float; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Float; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/sdp_utils.cpp b/aten/src/ATen/native/transformers/hip/sdp_utils.cpp new file mode 100644 index 0000000000000..f4487ab28e091 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/sdp_utils.cpp @@ -0,0 +1,761 @@ +// !!! This is a file automatically generated by hipify!!! +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if AT_CUDNN_ENABLED() +#include +#endif + +#include +#include + +// #if USE_ROCM +// #include +// #endif + +/** +* Note [SDPA Runtime Dispatch] +* SDPA relies on a runtime dispatch mechanism to select the appropriate +* kernel. This file contains exposes this through the `select_sdp_backend` +* The basic structure of this function is to call `priority_order` to get a +* list of backends to try, and then iterate through them until one succeeds. 
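A condensed sketch of the filter pattern this note describes, with hypothetical parameter and check names (the real gates, e.g. `check_head_dim_size_flash`, follow later in this file and additionally distinguish nested-tensor from dense-only constraint sets):

```cpp
#include <array>
#include <cstdio>

// Stand-ins for sdp_params and two filters; the names are hypothetical,
// only the dispatch pattern mirrors the can_use_*_attention gates below.
struct params_sketch { int head_dim; bool is_causal; };
using constraint_fn = bool (*)(const params_sketch&, bool /*debug*/);

bool check_head_dim_ok(const params_sketch& p, bool debug) {
  if (p.head_dim > 256) {
    if (debug) std::printf("head_dim %d > 256 is not supported\n", p.head_dim);
    return false;
  }
  return true;
}

bool check_causal_ok(const params_sketch&, bool) { return true; }

bool can_use_backend_sketch(const params_sketch& params, bool debug) {
  constexpr std::array<constraint_fn, 2> constraints{check_head_dim_ok, check_causal_ok};
  for (const auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;  // the first failing filter vetoes this backend
    }
  }
  return true;       // every filter passed, so the backend may be selected
}
```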
+* Each backend defines a use_ function that returns true if the +* backend can be run with the given SDP parameters. The use_ function +* will iterate over a list of "filters" that check for specific properties of +* the SDP parameters. If all filters pass, the backend can be used and use_ +* returns true. If any filter fails, then use_ returns false. +* +* In order to aid in debugging, each filter takes sdp_params and a debug flag. +* If the debug flag is set, the filter will print a warning message if it fails. +* The behavior of select_sdp_backend is to return the first backend that +* succeeds. If no backend is viable then it will run each use_ function +* with debug=true and return SDPBackend::error. +*/ + +namespace sdp { +namespace { + +// TODO(eqy): more benchmarking to determine whether this should include sm86/89 +// Needs to be kept in-sync with test_fused_chocie in test_transformers.py +bool check_prefer_cudnn_attention() { +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000 + auto dprops = at::cuda::getCurrentDeviceProperties(); + return dprops->major >= 9; +#else + return false; +#endif +} + +// flash_attention V2 is universally faster than efficient_attention and Math +std::array priority_order(sdp_params const& params) { + constexpr std::array default_order{ + SDPBackend::flash_attention, + SDPBackend::cudnn_attention, + SDPBackend::efficient_attention, + SDPBackend::math}; + constexpr std::array cudnn_order{ + SDPBackend::cudnn_attention, + SDPBackend::flash_attention, + SDPBackend::efficient_attention, + SDPBackend::math}; + static const bool prefer_cudnn = check_prefer_cudnn_attention(); + return prefer_cudnn ? cudnn_order : default_order; +} + +bool use_tensor_cores(sdp_params const& params, hipDeviceProp_t* dprops, bool is_half) { + if (dprops->major >= 8) { + return true; + } + if (dprops->major >= 7) { + return is_half; + } + return false; +} +int64_t minimum_gemm_alignment(sdp_params const& params) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + bool use_tc = use_tensor_cores(params, dprops, is_half); + int64_t matmul_alignment_mn = 1; + if (dprops->major >= 8) { + matmul_alignment_mn = 4; + } + int64_t bits_per_scalar = is_half ? 
16 : 32; + if (use_tc) { + matmul_alignment_mn = std::max(matmul_alignment_mn, 128 / bits_per_scalar); + } + return matmul_alignment_mn; +} + +bool check_head_dim_size_flash(sdp_params const& params, bool debug) { + // All head_dim sizes must be equal and less than 256 + const auto max_size = c10::SymInt(256); + const auto query_size_last = params.query.sym_size(-1); + const auto key_size_last = params.key.sym_size(-1); + const auto value_size_last = params.value.sym_size(-1); + bool same_head_dim_size = + query_size_last == key_size_last && query_size_last == value_size_last; + if (!(same_head_dim_size && (query_size_last <= max_size))) { + if (debug) { + TORCH_WARN( + "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.", + " Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + key_size_last, + ", Value.size(-1): ", + value_size_last, + " instead."); + } + return false; + } + return true; +} + +bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { + const auto max_size = c10::SymInt(256); + const auto query_size_last = params.query.sym_size(-1); + const auto key_size_last = params.key.sym_size(-1); + const auto value_size_last = params.value.sym_size(-1); + bool same_head_dim_size = + query_size_last == key_size_last && query_size_last == value_size_last; + if (!(same_head_dim_size && (query_size_last % 8 == 0) && + (query_size_last <= max_size))) { + if (debug) { + TORCH_WARN( + "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", + " Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + return true; +} + +bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) { + const auto query_size_last = params.query.sym_size(-1); + const auto value_size_last = params.value.sym_size(-1); + const int64_t alignment = minimum_gemm_alignment(params); + if (!(query_size_last == params.key.sym_size(-1) && + query_size_last % alignment == 0 && query_size_last > 0 && + value_size_last % alignment == 0 && value_size_last > 0)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention requires last dimension of inputs to be divisible by ", + alignment, + ". ", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + return true; +} + +template +struct SMVersion { + static constexpr int major = Major; + static constexpr int minor = Minor; + constexpr SMVersion() = default; +}; + +/** + * Checks if the current CUDA device architecture is inclusively within the specified range. + * + * @param lower_bound The lower bound of the CUDA device architecture range. + * @param upper_bound The upper bound of the CUDA device architecture range. + * @param params The parameters for the current operation. + * @return True if the current CUDA device architecture is within the specified range, false otherwise. 
+ */ +template +bool check_sm_version(hipDeviceProp_t * dprops) { + bool is_gte_lower_bound = dprops->major > lower_bound::major || + (dprops->major == lower_bound::major && + dprops->minor >= lower_bound::minor); + bool is_lte_upper_bound = dprops->major < upper_bound::major || + (dprops->major == upper_bound::major && + dprops->minor <= upper_bound::minor); + return is_gte_lower_bound && is_lte_upper_bound; +} + +bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) { + // Check that the gpu is capable of running flash attention + using sm80 = SMVersion<8, 0>; + using sm90 = SMVersion<9, 0>; +#if USE_ROCM + // auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + // if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (debug) { + // TORCH_WARN( + // "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); + // } + // return false; + // } +#else + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (!check_sm_version(dprops)) { + if (debug) { + TORCH_WARN( + "Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ", + dprops->major, + ".", + dprops->minor, + " gpu."); + } + return false; + } +#endif + return true; +} + +bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) { + // Mem Efficient attention supports hardware in the range [sm_50, sm_90] + using sm50 = SMVersion<5, 0>; + using sm90 = SMVersion<9, 0>; +#if USE_ROCM + // auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + // if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // if (debug) { + // TORCH_WARN( + // "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); + // } + // return false; + // } +#else + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (!check_sm_version(dprops)) { + if (debug) { + TORCH_WARN( + "Mem Efficient Attention only supports gpu architectures in the range [sm50, sm90]. 
Attempting to run on a sm ", + dprops->major, + ".", + dprops->minor, + " gpu."); + } + return false; + } +#endif + return true; +} + +bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89( + sdp_params const& params, + bool debug) { + // Flash Attention will raise an error in the backward pass if the head_dim + // size is greater than 192 And the device is between in the range [sm86, sm89] + using sm86 = SMVersion<8, 6>; + using sm89 = SMVersion<8, 9>; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm86_or_sm89 = check_sm_version(dprops); + bool is_head_dim_gt192 = params.query.sym_size(-1) > 192; + bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224; + bool is_dropout = params.dropout > 0.0; + // head_dim size in (192, 224] is not supported on sm86 and sm89 + bool cond1 = is_head_dim_gt192 && is_head_dim_lte224; + // head_dim size > 224 and is_dropout is not supported on sm86 and sm89 + bool cond2 = params.query.sym_size(-1) > 224 && is_dropout; + if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) { + if (debug) { + TORCH_WARN( + "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or " + "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].", + "Attempting to run with dropout set to: ", params.dropout, + "and head_dim: ", + params.query.sym_size(-1), " on a sm ", dprops->major, ".", + dprops->minor, " gpu."); + } + return false; + } + return true; +} + +bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) { + // FlashAttention 2 updated the default mask meaning for causal in this PR: + // 9e5e8bc91e it is now aligned to lower_right which would be a BC break + // for non-square masks. We will not support non-square masks for causal w/ FAV2 + if (params.is_causal && + !params.query.is_nested() && !params.key.is_nested() && + params.query.sym_size(-2) != params.key.sym_size(-2)) { + if (debug) { + TORCH_WARN( + "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k. ", + "Got seqlen_q: ", params.query.sym_size(-2), " seqlen_k: ", + params.key.sym_size(-2), ". If you would like to use causal attention with non-square masks, please see CausalAttnMask."); + } + return false; + } + return true; +} + +bool check_all_tensors_on_device(sdp_params const& params, bool debug) { + // Check that all tensors are on the GPU device + // This should be handled by the stub dispatch, but whe call can_use_*_attention + // directly from python we need to ensure that the tensors are on cuda + if (params.query.device().type() != at::DeviceType::CUDA) { + if (debug) { + TORCH_WARN( + "All tensors need to be on cuda device. 
Got query on device: ", + params.query.device(), + ", key on device: ", + params.key.device(), + ", value on device: ", + params.value.device()); + } + return false; + } + return true; +} + +bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { + const auto s_q = params.query.sym_size(2); + const auto s_k = params.key.sym_size(2); + const auto head_dim = params.query.sym_size(3); + long cudnn_version = at::detail::getCUDAHooks().versionCuDNN(); + if (cudnn_version >= 90000) { + if (head_dim % 8 != 0 || head_dim > 256) { + if (debug) { + TORCH_WARN("head_dim should be a multiple of 8 and no more than 256"); + } + return false; + } + } else { + if (head_dim % 8 != 0 || head_dim > 128) { + if (debug) { + TORCH_WARN("head_dim should be a multiple of 8 and no more than 128"); + } + return false; + } + } + if (cudnn_version < 8903) { + if (debug) { + TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher"); + } + return false; + } + if (params.dropout != 0.0 && cudnn_version < 8906) { + if (debug) { + TORCH_WARN("Dropout reference is only supported on 8.9.6 onwards."); + } + return false; + } + if (cudnn_version < 90000) { + if (s_q < 64) { + if (debug) { + TORCH_WARN("s_q less than 64 is not supported before cudnn 9.0.0"); + } + return false; + } + if ((s_q % 64 != 0 || s_k % 64 != 0) && params.dropout != 0.0) { + if (debug) { + TORCH_WARN( + "s_q not a multiple of 64 with padding/dropout is not supported with cudnn version 9.0.0"); + } + return false; + } + } + if (s_k % 64 != 0 && cudnn_version < 8906) { + if (debug) { + TORCH_WARN("not-multiple-of-64 seq_kv is not supported below 8.9.6"); + } + return false; + } + return true; +} + +bool check_cudnn_layout(sdp_params const& params, bool debug) { + const int64_t h = params.query.size(1); + const int64_t s_q = params.query.size(2); + const int64_t d = params.query.size(3); + const int64_t s_k = params.key.size(2); + const int64_t s_v = params.value.size(2); + // corresponds to cuDNN's "packed QKV" layout + const bool packed_query_layout_ok = (params.query.stride(0) == s_q * 3 * h * d) && + (params.query.stride(1) == d) && + (params.query.stride(2) == 3 * h * d) && + (params.query.stride(3) == 1); + const bool packed_key_layout_ok = (params.key.stride(0) == s_k * 3 * h * d) && + (params.key.stride(1) == d) && + (params.key.stride(2) == 3 * h * d) && + (params.key.stride(3) == 1); + const bool packed_value_layout_ok = (params.value.stride(0) == s_v * 3 * h * d) && + (params.value.stride(1) == d) && + (params.value.stride(2) == 3 * h * d) && + (params.value.stride(3) == 1); + + const bool packed_layout_ok = packed_query_layout_ok && packed_key_layout_ok && packed_value_layout_ok; + + const bool query_layout_ok = (params.query.stride(0) == s_q * h * d) && + (params.query.stride(1) == d) && + (params.query.stride(2) == h * d) && + (params.query.stride(3) == 1); + const bool key_layout_ok = (params.key.stride(0) == s_k * h * d) && + (params.key.stride(1) == d) && + (params.key.stride(2) == h * d) && + (params.key.stride(3) == 1); + const bool value_layout_ok = (params.value.stride(0) == s_v * h * d) && + (params.value.stride(1) == d) && + (params.value.stride(2) == h * d) && + (params.value.stride(3) == 1); + + const bool layout_ok = query_layout_ok && key_layout_ok && value_layout_ok; + + if (!packed_value_layout_ok && !layout_ok) { + if (debug) { + if (!packed_layout_ok) { + if (!packed_query_layout_ok) { + TORCH_WARN("Query tensor was not in cuDNN-supported packed QKV layout", params.query.strides()); + } + if 
(!packed_key_layout_ok) { + TORCH_WARN("Key tensor was not in cuDNN-supported packed QKV layout", params.key.strides()); + } + if (!packed_value_layout_ok) { + TORCH_WARN("Value tensor was not in cuDNN-supported packed QKV layout", params.value.strides()); + } + } + if (!layout_ok) { + if (!query_layout_ok) { + TORCH_WARN("Query tensor was not in cuDNN-supported unpacked QKV layout", params.query.strides()); + } + if (!key_layout_ok) { + TORCH_WARN("Key tensor was not in cuDNN-supported unpacked QKV layout", params.key.strides()); + } + if (!value_layout_ok) { + TORCH_WARN("Value tensor was not in cuDNN-supported unpacked QKV layout", params.value.strides()); + } + } + } + return false; + } + return true; +} + +bool check_cudnn_hardware_support(sdp_params const& params, bool debug) { + using sm80 = SMVersion<8, 0>; + using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (!check_sm_version(dprops)) { + if (debug) { + TORCH_WARN( + "cuDNN MHA only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ", + dprops->major, + ".", + dprops->minor, + " gpu."); + } + return false; + } + return true; +} + +bool check_for_nested_inputs(sdp_params const& params, bool debug) { + // Check that the input is nested + if (has_for_nested_inputs(params)) { + if (debug) { + TORCH_WARN("CuDNN currently does not support nested inputs."); + } + return false; + } + return true; +} + +bool check_dtypes_low_precision(sdp_params const& params, bool debug) { + auto dprop = at::cuda::getCurrentDeviceProperties(); + if (dprop->major >= 8) { + constexpr auto sm80_dtypes = + array_of(at::kHalf, at::kBFloat16); + return check_tensor_dtype(params, sm80_dtypes, debug); + } else { + constexpr auto default_dtypes = array_of(at::kHalf); + return check_tensor_dtype(params, default_dtypes, debug); + } +} + +bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { + // We check the global context to see if user has explicitly turned of cudnn + // sdp kernels + if (!at::globalContext().userEnabledCuDNNSDP()) { + if (debug) { + TORCH_WARN("CuDNN attention has been runtime disabled."); + } + return false; + } + return true; +} + +bool check_cudnn_deterministic(const sdp_params& params, bool debug) { + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (!ctx.deterministicAlgorithmsWarnOnly()) { + if (debug) { + TORCH_WARN("cuDNN SDPA is not deterministic."); + } + return false; + } + } + return true; +} + +} // namespace + +bool can_use_cudnn_attention(const sdp_params& params, bool debug) { +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) + TORCH_WARN_ONCE(!debug, "Torch was not compiled with cuDNN attention."); + return false; +#endif + // Define gate functions that determine if a flash kernel can be ran + // Replace with std::to_array when we migrate to c++20 + constexpr auto general_constraints = + array_of( + check_runtime_disabled_cudnn, + check_for_nested_inputs, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense*/>, + check_all_tensors_on_device, + check_tensor_shapes, + check_cudnn_tensor_shapes, + check_cudnn_deterministic, + // check_is_causal, + check_dtypes_low_precision, + check_for_attn_mask_cudnn, + check_cudnn_hardware_support + ); + for (auto& constraint : general_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + return true; +} + +bool can_use_flash_attention(sdp_params const& params, bool debug) { 
+#ifndef USE_FLASH_ATTENTION + TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention."); + return false; +#endif + + // Define gate functions that determine if a flash kernel can be ran + // Replace with std::to_array when we migrate to c++20 + constexpr auto general_constraints = array_of( + check_runtime_disabled_flash, + check_all_tensors_on_device, + check_tensor_shapes, + check_for_attn_mask, + check_head_dim_size_flash, + check_flash_attention_hardware_support, + check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89, + check_flash_causal_non_square_seqlens, + check_dtypes_low_precision); + for (auto& constraint : general_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + + if (has_for_nested_inputs(params)) { + constexpr auto nested_constraints = array_of( + check_batch_size_nested, + check_head_dim_size_flash_nested, + check_for_seq_len_0_nested_tensor); + for (auto& constraint : nested_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + if (has_only_dense_inputs(params)) { + constexpr auto dense_constraints = array_of( + check_batch_size_and_num_heads_dense, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense); + for (auto& constraint : dense_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + return true; +} + +bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { +#ifndef USE_MEM_EFF_ATTENTION + TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention."); + return false; +#endif + // Constraints specific to mem efficient attention + constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes = + array_of(at::kHalf, at::kFloat, at::kBFloat16); + constexpr auto less_than_sm80_mem_efficient_dtypes = + array_of(at::kHalf, at::kFloat); +#ifdef USE_ROCM + // constexpr auto aotriton_mem_efficient_dtypes = + // array_of(at::kHalf, at::kFloat, at::kBFloat16); +#endif + + // Define gate functions that determine if a mem efficient kernel can be ran + constexpr auto general_constraints = array_of( + check_runtime_disabled_mem_efficient, + check_all_tensors_on_device, + check_mem_efficient_hardware_support, + check_tensor_shapes, + check_head_dim_size_mem_efficient); + for (auto& constraint : general_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + + if (has_for_nested_inputs(params)) { +#ifdef USE_ROCM + TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention."); + return false; +#endif + constexpr auto nested_constraints = array_of( + check_requires_grad_and_nested, + check_batch_size_nested, + check_for_seq_len_0_nested_tensor); + for (auto& constraint : nested_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + if (has_only_dense_inputs(params)) { + constexpr auto dense_constraints = array_of( + check_batch_size_and_num_heads_dense, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense); + for (auto& constraint : dense_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + +// #ifdef USE_ROCM +// return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug); +// #else + auto dprop = at::cuda::getCurrentDeviceProperties(); + if (dprop->major >= 8) { + return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug); + } +//#endif + return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); +} + +SDPBackend 
select_sdp_backend(sdp_params const& kernel_params) { + // This function defines the priority order of the different sdp backends + // 1. Flash Attention + // 2. Mem Efficient Attention + // 3. Math fallback + auto& ctx = at::globalContext(); + if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && + !ctx.userEnabledMemEfficientSDP() && !ctx.userEnabledCuDNNSDP()) { + return SDPBackend::error; + } + // Get ideal kernel ordering + const auto ordering = priority_order(kernel_params); + + // Because TORCHCHECK checks if condition is true we negate debug so that + // The statements will be printed when debug is true + bool print_debug = false; + for (auto& backend : ordering) { + switch (backend) { + case SDPBackend::cudnn_attention: + if (sdp::can_use_cudnn_attention(kernel_params, print_debug)) { + return SDPBackend::cudnn_attention; + } + break; + case SDPBackend::flash_attention: + if (sdp::can_use_flash_attention(kernel_params, print_debug)) { + return SDPBackend::flash_attention; + } + break; + case SDPBackend::efficient_attention: + if (sdp::can_use_mem_efficient_attention(kernel_params, print_debug)) { + return SDPBackend::efficient_attention; + } + break; + case SDPBackend::math: + if (ctx.userEnabledMathSDP()) { + return SDPBackend::math; + } + break; + default: + TORCH_CHECK(false, "Invalid backend"); + } + } + // If we have gotten to this point then two things have happened: + // 1. use_flash_attention or use_mem_efficient did not satisfy the + // constraints to be ran + // 2. The user has explicitly disabled the math kernel + // We then re-run the kernel checks with debug enabled to print out the + // reason why the kernel was not selected + + print_debug = true; + TORCH_WARN("Memory efficient kernel not used because:"); + sdp::can_use_mem_efficient_attention(kernel_params, print_debug); + TORCH_WARN("Flash attention kernel not used because:"); + sdp::can_use_flash_attention(kernel_params, print_debug); + TORCH_WARN("CuDNN attention kernel not used because:"); + sdp::can_use_cudnn_attention(kernel_params, print_debug); + TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") + return SDPBackend::error; +} + +bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug) { + // When this function is called we are assured that the nt is dim==4 + if (!params.query.is_nested()) { + return true; + } + + const auto nt_q_tensor_impl = + at::native::get_nested_tensor_impl(params.query); + const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes(); + auto* sizes_ptr = sizes.data_ptr(); + const int64_t n_tensors = params.query.size(0); + const int64_t size_tensor_stride = sizes.stride(0); + + // This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + for (const auto i : c10::irange(n_tensors)) { + if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) { + if (debug) { + TORCH_WARN( + "Packed projection for fused kernels does not support sequence_length <= 1"); + } + return false; + } + } + + return true; +} + +} // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/sdp_utils.h b/aten/src/ATen/native/transformers/hip/sdp_utils.h new file mode 100644 index 0000000000000..0156f861b1f8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/sdp_utils.h @@ -0,0 +1,17 @@ +// !!! This is a file automatically generated by hipify!!! 
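The small header added next exposes these entry points to the rest of the ROCm build. A caller-side sketch of the debug behaviour described in the note above (the include path and the construction of `sdp_params` are assumed here, since that struct is defined outside this hunk):

```cpp
#include <cstdio>
// Assumed in-tree include path for the header introduced below.
#include <ATen/native/transformers/hip/sdp_utils.h>

// With debug=true every failing filter emits its own TORCH_WARN explaining the
// rejection, which is what select_sdp_backend also does when no kernel fits.
void log_viable_kernels(sdp::sdp_params const& params) {
  const bool debug = true;
  std::printf("flash=%d mem_efficient=%d cudnn=%d\n",
              int(sdp::can_use_flash_attention(params, debug)),
              int(sdp::can_use_mem_efficient_attention(params, debug)),
              int(sdp::can_use_cudnn_attention(params, debug)));
}
```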
+#pragma once + +#include +#include +#include +#include + +namespace sdp { + +bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug); +SDPBackend select_sdp_backend(sdp_params const& kernel_params); +C10_EXPORT bool can_use_flash_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_mem_efficient_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_cudnn_attention(sdp_params const& params, bool debug); + +} // namespace sdp diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3b6119087aa26..4120d1b4e66a3 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -899,9 +899,9 @@ if(USE_ROCM) set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA}) hip_add_library(torch_hip ${Caffe2_HIP_SRCS}) - if(USE_FLASH_ATTENTION) - target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) - endif() +# if(USE_FLASH_ATTENTION) +# target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) +# endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake # TODO: Not totally sure if this is live or not diff --git a/test/test_linalg.py b/test/test_linalg.py index 207290f5a6a8b..aadb501b4ff21 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8036,6 +8036,22 @@ def test_preferred_blas_library(self): self.assertEqual(out1, out2) self.assertEqual(out_ref, out2.cpu()) + @skipCUDAIfNotRocm + @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device") + @setBlasBackendsToDefaultFinally + def test_ck_blas_library(self): + m1 = torch.randint(2, 5, (7168, 8192), device='cuda', dtype=torch.float) + m2 = torch.randint(2, 5, (1280, 8192), device='cuda', dtype=torch.float) + + torch.backends.cuda.preferred_blas_library('ck') + ck_out = torch.nn.functional.linear(m1, m2) + + cpu_out = torch.nn.functional.linear(m1.cpu(), m2.cpu()) + + self.assertEqual(ck_out, cpu_out) + + + def test_permute_matmul(self): a = torch.ones([2, 5, 24, 24]) b = torch.ones([3, 2, 5, 24, 24]) diff --git a/test/test_transformers.py b/test/test_transformers.py index 9755e7b3601f2..7d5d5da619ae8 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2891,7 +2891,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") @parametrize("batch_size", [1, 8]) - @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048]) + @parametrize("seq_len_q", [4, 8, 64, 128, 143, 256, 512, 1024, 2048]) @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048]) @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256]) @parametrize("is_causal", [True, False]) @@ -2907,6 +2907,10 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1: torch.cuda.empty_cache() # Prevent memory fragmentation + if TEST_WITH_ROCM and dropout_p != 0: + self.skipTest("CK does not support tensor format dropout masks") + if TEST_WITH_ROCM and head_dim > 128: + self.skipTest("CK does not support head dims over 128") scale = scale if scale is None else (1 / head_dim) n_heads = 4 diff --git a/third_party/composable_kernel b/third_party/composable_kernel new file mode 
160000 index 0000000000000..a1c07e8d913cd --- /dev/null +++ b/third_party/composable_kernel @@ -0,0 +1 @@ +Subproject commit a1c07e8d913cd03011f4ea3d45033ab4e765e9f1 diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 3ff54b32a425b..4b19e9d868f61 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1272,6 +1272,7 @@ def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ... class _BlasBackend: Cublas: _BlasBackend Cublaslt: _BlasBackend + Ck: _BlasBackend class ConvBackend(Enum): ... diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 00f511a544e6c..9c88247b601b9 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -213,6 +213,7 @@ def preferred_linalg_library( "cublas": torch._C._BlasBackend.Cublas, "cublaslt": torch._C._BlasBackend.Cublaslt, "hipblaslt": torch._C._BlasBackend.Cublaslt, # alias + "ck": torch._C._BlasBackend.Ck, } _BlasBackends_str = ", ".join(_BlasBackends.keys()) @@ -221,16 +222,17 @@ def preferred_blas_library( backend: Union[None, str, torch._C._BlasBackend] = None ) -> torch._C._BlasBackend: r""" - Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt. + Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. .. warning:: This flag is experimental and subject to change. When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available. - For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance. + For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance. This flag (a :class:`str`) allows overriding which BLAS library to use. * If `"cublas"` is set then cuBLAS will be used wherever possible. * If `"cublaslt"` is set then cuBLASLt will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. * When no input is given, this function returns the currently preferred library. * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt globally. diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index deb0e0943452b..2f7133d0f1ad6 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1974,7 +1974,8 @@ Call this whenever a new thread is created in order to propagate values from py::enum_(py_module, "_BlasBackend") .value("Cublas", at::BlasBackend::Cublas) - .value("Cublaslt", at::BlasBackend::Cublaslt); + .value("Cublaslt", at::BlasBackend::Cublaslt) + .value("Ck", at::BlasBackend::Ck); py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) { at::globalContext().setBlasPreferredBackend(b);
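For completeness, the C++ side of the new `Ck` backend binding can be exercised directly through the same `Context` call the Python binding above forwards to; a minimal sketch (the `ATen/Context.h` include path is assumed, everything else appears in this patch):

```cpp
#include <ATen/Context.h>  // assumed header for at::globalContext() / at::BlasBackend

// Equivalent of torch.backends.cuda.preferred_blas_library("ck") on the Python
// side: prefer composable_kernel for BLAS calls where possible (ROCm builds only).
void prefer_ck_blas() {
  at::globalContext().setBlasPreferredBackend(at::BlasBackend::Ck);
}
```

This is the path the new `test_ck_blas_library` test takes via `torch.backends.cuda.preferred_blas_library('ck')` before comparing a CK-backed `linear` against the CPU reference.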