Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#2 from kircle888/smask_up
Browse files Browse the repository at this point in the history
fix sparsemask bug
  • Loading branch information
kircle888 authored May 27, 2024
2 parents c3212be + 12d3752 commit 43c517d
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,21 +521,25 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
}

// Block-level skip predicates for the sparse attention mask. These macros
// expand in place and read m_block, kBlockM, Is_causal, gSparseMaskDownMax and
// gSparseMaskUpMin from the enclosing scope; N_BLOCK is the key-block index
// used to subscript the two per-block arrays.
//
// SPARSE_MASKED_DOWN: true when m_block * kBlockM is greater than or equal to
// gSparseMaskDownMax[N_BLOCK].
// NOTE(review): m_block * kBlockM is presumably the first query row of the
// current M block, and gSparseMaskDownMax a per-key-block bound past which
// every row is masked -- confirm against where the array is built.
#define SPARSE_MASKED_DOWN(N_BLOCK) \
((m_block * kBlockM) >= gSparseMaskDownMax[(N_BLOCK)])
// SPARSE_MASKED_UP: always false when Is_causal is set; otherwise true when
// (m_block + 1) * kBlockM is strictly less than gSparseMaskUpMin[N_BLOCK].
// NOTE(review): the strict '<' looks conservative by one row if
// (m_block + 1) * kBlockM is an exclusive row bound -- confirm intent.
#define SPARSE_MASKED_UP(N_BLOCK) \
(!Is_causal && (m_block + 1) * kBlockM < gSparseMaskUpMin[(N_BLOCK)])
// SPARSE_MASKED: true when either predicate above holds. The surrounding
// function uses it to skip a key block in the main loop and to step the K
// pointer past skipped blocks before the next load is issued.
#define SPARSE_MASKED(N_BLOCK) \
(SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))
// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
if (Is_sparse_attn_mask) {
if ((m_block * kBlockM >= gSparseMaskDownMax[n_block]) ||
(!Is_causal &&
(m_block + 1) * kBlockM < gSparseMaskUpMin[n_block])) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
if (!Is_causal)
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
if(Return_softmax)
tPgP.data() = tPgP.data() + (-kBlockN);
continue;
if (Is_sparse_attn_mask && SPARSE_MASKED(n_block)) {
if (n_block == 0) {
flash::cp_async_wait<0>();
__syncthreads();
}
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
if (!Is_causal)
gSparseMaskUp.data() = gSparseMaskUp.data() + (-kBlockN);
if (Return_softmax) tPgP.data() = tPgP.data() + (-kBlockN);
continue;
}
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
Expand All @@ -556,6 +560,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
if (Is_sparse_attn_mask) {
auto in_block = n_block - 1;
for (; in_block > 0 && SPARSE_MASKED(in_block); --in_block) {
tKgK.data() =
tKgK.data() + (-int(kBlockN * params.k_row_stride));
}
__syncwarp();
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
Expand Down

0 comments on commit 43c517d

Please sign in to comment.