
Commit

Fix flashattn
jdecourval committed Apr 30, 2024
1 parent 7eb14d5 commit 3e560c8
Showing 4 changed files with 11 additions and 17 deletions.
CMakeLists.txt (1 addition, 1 deletion)
@@ -1175,7 +1175,7 @@ add_library(ggml OBJECT
 )

 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
-target_compile_features   (ggml PUBLIC c_std_11)   # don't bump
+target_compile_features   (ggml PUBLIC cxx_std_17) # don't bump

 target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
ggml-cuda/common.cuh (2 additions, 7 deletions)
@@ -364,16 +364,11 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }

 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
     }
     return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }

 static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -399,8 +394,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {

 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA

 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
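
Note on the FP16_MMA_AVAILABLE change: instead of ruling out FP16 tensor-core math on HIP/AMD builds entirely, the macro now mirrors FP16_AVAILABLE and reports availability only on RDNA3 there. The sketch below shows how such a macro is typically consumed; the function name is hypothetical and the pattern is an assumption for illustration, not code from this commit. The defined(...) checks and the ?: operator are resolved by the preprocessor once the macro is expanded inside #if.

// Sketch (assumed usage pattern): gate a tensor-core code path on the macro.
static __device__ __forceinline__ void mma_path_or_fallback() {
#if FP16_MMA_AVAILABLE
    // wmma/rocwmma tensor-core path would go here
#else
    NO_DEVICE_CODE; // same fallback idiom used elsewhere in this file
#endif
}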
ggml-cuda/fattn.cu (8 additions, 2 deletions)
@@ -7,6 +7,12 @@
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
+inline __device__ __half2 __hmax2(__half2 x, __half2 y) {
+    return __half2_raw{
+        {{__hmax(__half2_raw(x).x, __half2_raw(y).x),
+          __hmax(__half2_raw(x).y, __half2_raw(y).y)}}
+    };
+}
 #else
 #include <mma.h>
 namespace wmma = nvcuda::wmma;
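
The __hmax2 definition added above fills a gap in HIP's fp16 intrinsics: the scalar __hmax exists, but a packed half2 maximum does not, and the flash-attention kernel presumably calls __hmax2 further down in this file. For reference, a rough equivalent written with the half2 packing helpers instead of __half2_raw; this is a sketch only, and the name hmax2_alt as well as the use of make_half2/__low2half/__high2half are assumptions, not part of the commit.

// Sketch: elementwise max over the two packed halves.
inline __device__ __half2 hmax2_alt(__half2 x, __half2 y) {
    return make_half2(__hmax(__low2half(x),  __low2half(y)),
                      __hmax(__high2half(x), __high2half(y)));
}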
@@ -339,7 +345,7 @@ static __global__ void flash_attn_ext_f16(
     frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
     for (int j = 0; j < ncols/frag_n; ++j) {
-        wmma::fill_fragment(KQ_c[j], 0.0f);
+        wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
     }
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@@ -470,7 +476,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
         }

 #pragma unroll
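
On the two fill_fragment changes: the zero fill value is now constructed explicitly in the fragment's element type (KQ_acc_t for the KQ accumulator, __half for the VKQ accumulator) rather than passed as a plain float literal, presumably because rocWMMA does not accept the implicit float-to-half conversion that nvcuda::wmma tolerated. A generic sketch of the resulting pattern follows; the helper zero_fragments is hypothetical and not part of the commit.

// Sketch: zero an array of WMMA accumulator fragments, constructing the fill
// value in the element type elem_t so no implicit conversion is required.
template <typename elem_t, typename frag_t, int n>
static __device__ void zero_fragments(frag_t (&frags)[n]) {
#pragma unroll
    for (int i = 0; i < n; ++i) {
        wmma::fill_fragment(frags[i], elem_t{0.0f});
    }
}
// e.g. zero_fragments<KQ_acc_t>(KQ_c); with frag_t and n deduced from the array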
llama.cpp (0 additions, 7 deletions)
@@ -15473,13 +15473,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }

-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
