Sdpa cmake #1606

Draft · wants to merge 49 commits into base: rocm_gemm_ck
Commits (49)
b7a21fb  add ck blas backend selector (jeffdaily, Jun 26, 2024)
9534797  add composable_kernel submodule (jeffdaily, Jun 27, 2024)
851d4b5  copy bfloat16 gemm from fbgemm (jeffdaily, Jun 27, 2024)
9cbbb40  CK gemm header (#1445) (alugorey, Jun 28, 2024)
ae6a64b  use BLAS arg types for ck gemm kernel (jeffdaily, Jul 1, 2024)
9420e57  swap bf16 for float (jeffdaily, Jul 2, 2024)
6cdf163  Ck template (#1447) (alugorey, Jul 2, 2024)
442ccd7  Float CK GEMM backend - Initial prototype (alugorey, Jul 17, 2024)
de05526  bfloat16 and half prototype added (jeffdaily, Jul 18, 2024)
dc29f54  lint (jeffdaily, Jul 18, 2024)
6dddb73  split ck_gemm.hip into 3 files for faster compilation (jeffdaily, Jul 18, 2024)
248fb31  support all layouts (alugorey, Aug 1, 2024)
fb7957d  Only run CK tests on ROCm (#1527) (alugorey, Aug 16, 2024)
9c2a52c  Fix layouts (#1548) (alugorey, Aug 20, 2024)
11af46c  Initial commit (groenenboomj, Jul 22, 2024)
67d6d87  Remove AOT link for now (groenenboomj, Jul 22, 2024)
96f3c03  Also remove mem eff for now (groenenboomj, Jul 22, 2024)
0101f43  Remove extension implementation (groenenboomj, Jul 22, 2024)
e60c536  Move over API for flash v2 (groenenboomj, Jul 22, 2024)
fb9ee10  Didn't need to do that, squash (groenenboomj, Jul 22, 2024)
cbc7d30  CK only version (groenenboomj, Jul 22, 2024)
3e19185  Massage and add headers to fix the build (groenenboomj, Jul 23, 2024)
54fcb02  type updates (groenenboomj, Jul 24, 2024)
4833464  Method porting is good now (groenenboomj, Jul 24, 2024)
69fbbcd  types done (groenenboomj, Jul 24, 2024)
b36e17e  Added a few temp files, namespaces (groenenboomj, Jul 24, 2024)
dea5266  REVERTME: Generate these kernels at build (groenenboomj, Jul 25, 2024)
03d9697  Also remove these mods (groenenboomj, Jul 25, 2024)
939424e  Fix forward (groenenboomj, Aug 2, 2024)
0924a33  Match signatures, fix dropout. (groenenboomj, Aug 2, 2024)
64611b2  Working for new APIs (groenenboomj, Aug 23, 2024)
2a8a6db  Testing hacks (groenenboomj, Aug 29, 2024)
d914581  Update CK (alugorey, Sep 6, 2024)
72f66e2  removing files (alugorey, Sep 12, 2024)
1bddcc7  Can generate files (alugorey, Sep 12, 2024)
f9e8937  Remove files that are supposed to be generated. causes issues because… (alugorey, Sep 26, 2024)
8622013  Update fwd_args for newer CK (alugorey, Sep 26, 2024)
a35a724  hipified new version of mha_varlen_bwd.cpp (alugorey, Sep 26, 2024)
c48861c  hipified new version of mha_bwd.cpp (alugorey, Sep 26, 2024)
1da210c  Add KBatch arg to ck_gemm_template (alugorey, Sep 27, 2024)
413755f  update CK (alugorey, Oct 1, 2024)
1ec4b85  change cmakelists (alugorey, Oct 1, 2024)
c0baa45  add all cmake stuff from before (alugorey, Oct 1, 2024)
db84054  Fix blob_list relative path (alugorey, Oct 1, 2024)
962a1e3  Add receipt flag (alugorey, Oct 1, 2024)
4d9297c  change receipt to 3 (alugorey, Oct 1, 2024)
d80889e  attempt to add no-func-template (alugorey, Oct 1, 2024)
ef5fc2d  move property setting in cmakelists (alugorey, Oct 1, 2024)
5ce26a9  declare source file properties (alugorey, Oct 1, 2024)
4 changes: 4 additions & 0 deletions .gitmodules
@@ -128,3 +128,7 @@
path = third_party/cpp-httplib
url = https://github.com/yhirose/cpp-httplib.git
branch = v0.15.3
[submodule "third_party/composable_kernel"]
path = third_party/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = develop
97 changes: 95 additions & 2 deletions CMakeLists.txt
@@ -17,6 +17,10 @@ cmake_policy(SET CMP0069 NEW)
# and it's possible on our Windows configs.
cmake_policy(SET CMP0092 NEW)

include(CMakePrintHelpers)



# Prohibit in-source builds
if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
message(FATAL_ERROR "In-source build are not supported")
@@ -773,7 +777,7 @@ set(CAFFE2_ALLOWLIST
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Build type not set - defaulting to Release")
set(CMAKE_BUILD_TYPE
"Release"
"Debug"
CACHE
STRING
"Choose the type of build from: Debug Release RelWithDebInfo MinSizeRel Coverage."
@@ -851,6 +855,8 @@ endif()
# aotriton build decision later.

include(cmake/Dependencies.cmake)
message("BEFORE USE_FLASH ATTENTION IS SUPPOSEDLY CREATED")
cmake_print_variables(USE_FLASH_ATTENTION)

cmake_dependent_option(
USE_FLASH_ATTENTION
@@ -860,6 +866,93 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

message("AFTER USE_FLASH_ATTENTION IS CREATED")


if(USE_FLASH_ATTENTION)
message("MADE IT HERE")
cmake_print_variables(Python3_EXECUTABLE)
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/fwd_blob_list.txt --receipt 3
)
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --list_blobs ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/bwd_blob_list.txt --receipt 3
)
message("MADE IT PAST EXECUTE_PROCESS")
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)

#add_custom_command(
# OUTPUT ${FMHA_FWD_GEN_BLOBS}
# COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
# --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_LIST_DIR}
#)

#add_custom_command(
# OUTPUT ${FMHA_BWD_GEN_BLOBS}
# COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
# --api bwd --output_dir ${CMAKE_CURRENT_LIST_DIR}
#)
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/ --receipt 3
)
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --output_dir ${CMAKE_CURRENT_LIST_DIR}/aten/src/ATen/native/transformers/hip/flash_attn/ --receipt 3
)

execute_process(COMMAND ls ${CMAKE_CURRENT_LIST_DIR})


set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/)
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})

set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/third_party/composable_kernel/example/ck_tile/)
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})

# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
endif()

set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()

# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)

target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

message("ANDY! WE MADE IT TO THE END OF OUR BLURB!")
endif()


# We are currenlty not using alibi attention for Flash So we disable this
# feature by default We dont currently document this feature because we don't
# Suspect users building from source will need this
@@ -871,7 +964,7 @@ cmake_dependent_option(
USE_MEM_EFF_ATTENTION
"Enable memory-efficient attention for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM" OFF)
"USE_CUDA" OFF)

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
4 changes: 3 additions & 1 deletion aten/src/ATen/BlasBackend.h
@@ -7,14 +7,16 @@

namespace at {

enum class BlasBackend : int8_t { Cublas, Cublaslt };
enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };

inline std::string BlasBackendToString(at::BlasBackend backend) {
switch (backend) {
case BlasBackend::Cublas:
return "at::BlasBackend::Cublas";
case BlasBackend::Cublaslt:
return "at::BlasBackend::Cublaslt";
case BlasBackend::Ck:
return "at::BlasBackend::Ck";
default:
TORCH_CHECK(false, "Unknown blas backend");
}
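For context on the hunk above: the new enum value is consumed by Context::setBlasPreferredBackend (see the Context.cpp hunk below). A minimal sketch of how a caller might opt into the CK backend; the function name prefer_ck_gemms is illustrative, not part of this PR:

```cpp
#include <ATen/Context.h>

// Sketch only: select CK as the preferred BLAS backend. On a build
// without ROCm, setBlasPreferredBackend rejects BlasBackend::Ck via
// the TORCH_CHECK added in Context.cpp below.
void prefer_ck_gemms() {
  at::globalContext().setBlasPreferredBackend(at::BlasBackend::Ck);
}
```

From Python the same switch is exposed through torch.backends.cuda.preferred_blas_library, which the TORCH_WARN_ONCE in Context.cpp flags as experimental.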
10 changes: 10 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -169,6 +169,7 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# flash_attention sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
file(GLOB flash_attention_hip_cpp "native/transformers/hip/flash_attn/*.cpp")

#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
@@ -184,6 +185,14 @@ if(USE_FLASH_ATTENTION)

list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
add_subdirectory(native/transformers/hip/flash_attn)
set_source_files_properties(
${flash_attention_hip_cpp}
DIRECTORY "native/transformers/hip/flash_attn/"
PROPERTIES
COMPILE_FLAGS "-Wno-undefined-func-template"
)
list(APPEND native_transformers_hip_cpp ${flash_attention_hip_cpp})
endif()

if(USE_MEM_EFF_ATTENTION)
@@ -309,6 +318,7 @@ if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha/)
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
2 changes: 2 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -284,6 +284,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
if (b != at::BlasBackend::Cublas) {
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
3 changes: 3 additions & 0 deletions aten/src/ATen/Context.h
@@ -126,6 +126,9 @@ class TORCH_API Context {
static bool hasCuBLASLt() {
return detail::getCUDAHooks().hasCuBLASLt();
}
static bool hasROCM() {
return detail::getCUDAHooks().hasROCM();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
22 changes: 22 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -18,6 +18,7 @@
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#include <ATen/native/hip/ck_gemm.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
@@ -792,6 +793,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
AT_ERROR("at::cuda::blas::gemm_internal_cublas: not implemented for ", typeid(Dtype).name());
}


template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
@@ -1000,6 +1002,11 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
#endif
else {
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1011,6 +1018,11 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
#endif
else {
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
}
@@ -1054,6 +1066,11 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#endif
else {
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1065,6 +1082,11 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#endif
else {
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
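The four hunks above repeat one pattern per dtype (double, float, at::Half, at::BFloat16). Condensed into a single schematic template (a restatement of the change, not the literal source, which writes out each specialization by hand):

```cpp
// Schematic of the dispatch added to each gemm_internal<T> specialization:
// cuBLASLt if preferred, then the new CK path on ROCm builds, otherwise
// the existing cuBLAS/rocBLAS implementation.
template <typename T>
void gemm_internal(CUDABLAS_GEMM_ARGTYPES(T)) {
  if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<T>(CUDABLAS_GEMM_ARGS(T));
  }
#ifdef USE_ROCM
  else if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Ck) {
    at::native::gemm_internal_ck<T>(CUDABLAS_GEMM_ARGS(T));  // new in this PR
  }
#endif
  else {
    gemm_internal_cublas<T>(CUDABLAS_GEMM_ARGS(T));
  }
}
```

Note the #ifdef placement: on a non-ROCm build a Ck preference would fall through to the cuBLAS arm, but the TORCH_CHECK added in Context.cpp prevents that preference from being set in the first place.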
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -79,6 +79,7 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
transpose_tensor = tensor.is_contiguous();
return resolve_conj_if_indicated(tensor, true);
}

IntArrayRef tensor_strides = tensor.strides();
IntArrayRef tensor_sizes = tensor.sizes();
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
24 changes: 24 additions & 0 deletions aten/src/ATen/native/hip/ck_gemm.h
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/OpMathType.h>
#include <ATen/hip/HIPBlas.h>
namespace at::native {


template <typename Dtype>
inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
}

template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));



} // namespace at::native
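The static_assert in the primary template uses the classic dependent-false idiom: false && sizeof(Dtype) is value-dependent on the template parameter, so the assertion fires only when the primary template is actually instantiated, i.e. when gemm_internal_ck is called with a type that has no explicit specialization. A standalone illustration of the pattern (not PR code; gemm_ck and its pointer signature are invented for the example):

```cpp
#include <cstdio>

// Primary template: compiles fine as long as it is never instantiated,
// because the static_assert condition depends on T.
template <typename T>
void gemm_ck(const T* a, const T* b, T* c) {
  static_assert(false && sizeof(T), "gemm_ck: not implemented for this type");
}

// Explicit specialization, analogous to the declarations in ck_gemm.h
// (whose definitions live in separate .hip translation units).
template <>
void gemm_ck<float>(const float* a, const float* b, float* c) {
  std::printf("float specialization called\n");
}

int main() {
  float x = 1.0f;
  gemm_ck<float>(&x, &x, &x);  // OK: resolves to the specialization
  // gemm_ck<int>(nullptr, nullptr, nullptr);  // compile error via static_assert
}
```

This makes an unsupported dtype a compile-time failure, whereas the gemm_internal_cublas primary template shown earlier reports the same condition at run time via AT_ERROR.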