Merge pull request PaddlePaddle#4 from kircle888/smask_switch
remove cudaMallocAsync
kircle888 authored Jun 18, 2024
2 parents 1e82ffc + 13d73df commit f20f94d
Showing 8 changed files with 122 additions and 71 deletions.
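In short: this commit threads a caller-provided workspace, flashmask_maxmin_ptr, through the C API and the kernel parameter struct, in place of the workspace the backward launch path previously obtained from a stream-ordered allocation (the cudaFreeAsync removed below, plus, per the commit title, a matching cudaMallocAsync inside prepare_sparsemask, which is not shown on this page). It also introduces an enable_mask_bypass flag, set when seqlen_q >= 1024, that gates the sparse-mask block-skipping logic in the backward kernel.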
16 changes: 15 additions & 1 deletion csrc/capi/flash_attn.cu
@@ -129,6 +129,7 @@ void set_params_fprop_strided(Flash_fwd_params &params,
 void * attn_mask = nullptr,
 void * attn_mask_start_row_indices = nullptr,
 void * attn_mask_end_row_indices = nullptr,
+void * flashmask_maxmin_ptr = nullptr,
 const int attn_mask_start_row = 0,
 int mask_head_mod_size = 0,
 int mask_seq_q_mod_size = 0) {
@@ -188,10 +189,15 @@ void set_params_fprop_strided(Flash_fwd_params &params,
 // sparse mask row index
 params.attn_mask_start_row_indices_ptr = attn_mask_start_row_indices;
 params.attn_mask_end_row_indices_ptr = attn_mask_end_row_indices;
+params.flashmask_maxmin_ptr = static_cast<int*>(flashmask_maxmin_ptr);
 params.attn_mask_start_row = attn_mask_start_row;
+params.enable_mask_bypass = seqlen_q >= 1024;
 if(attn_mask_start_row_indices!=nullptr||attn_mask_end_row_indices!=nullptr) {
   params.h_sparsemask = mask_head_mod_size;
   params.h_h_sparsemask_ratio = h / mask_head_mod_size;
+  if (params.enable_mask_bypass){
+    ASSERT_CHECK(params.flashmask_maxmin_ptr != nullptr);
+  }
 }
 
 // Set the different scale values.
@@ -466,6 +472,7 @@ void set_params_dgrad_strided(Flash_bwd_params &params,
 void * attn_mask = nullptr,
 void * attn_mask_start_row_indices = nullptr,
 void * attn_mask_end_row_indices = nullptr,
+void * flashmask_maxmin_ptr = nullptr,
 const int attn_mask_start_row = 0,
 int mask_head_mod_size = 0,
 int mask_seq_q_mod_size = 0) {
@@ -490,6 +497,7 @@ void set_params_dgrad_strided(Flash_bwd_params &params,
 attn_mask,
 attn_mask_start_row_indices,
 attn_mask_end_row_indices,
+flashmask_maxmin_ptr,
 attn_mask_start_row,
 mask_head_mod_size,
 mask_seq_q_mod_size);
@@ -565,6 +573,7 @@ bool flash_attn_fwd(const void * const q,
 const void * const attn_mask_start_row_indices,
 const int64_t * const attn_mask_start_row_indices_dims,
 const void * const attn_mask_end_row_indices,
+const void * const flashmask_maxmin_ptr,
 const int attn_mask_start_row,
 const int q_row_stride,
 const int k_row_stride,
@@ -621,6 +630,7 @@ bool flash_attn_fwd(const void * const q,
 const_cast<void *>(attn_mask),
 const_cast<void *>(attn_mask_start_row_indices),
 const_cast<void *>(attn_mask_end_row_indices),
+const_cast<void *>(flashmask_maxmin_ptr),
 attn_mask_start_row,
 mask_head_mod_size,
 mask_seq_q_mod_size);
@@ -725,10 +735,11 @@ bool flash_attn_varlen_fwd(const void * const q,
 const_cast<void *>(attn_mask),
 nullptr,
 nullptr,
+nullptr,
 -1,
 mask_head_mod_size,
 mask_seq_q_mod_size
-);
+);
 
 params.rng_state = static_cast<uint64_t*>(rng_state);
 
@@ -800,6 +811,7 @@ bool flash_attn_bwd(const void * const dout,
 const void * const attn_mask_start_row_indices,
 const int64_t * const attn_mask_start_row_indices_dims,
 const void * const attn_mask_end_row_indices,
+const void * const flashmask_maxmin_ptr,
 const int attn_mask_start_row,
 const int q_row_stride,
 const int k_row_stride,
@@ -893,6 +905,7 @@ bool flash_attn_bwd(const void * const dout,
 const_cast<void *>(attn_mask),
 const_cast<void *>(attn_mask_start_row_indices),
 const_cast<void *>(attn_mask_end_row_indices),
+const_cast<void *>(flashmask_maxmin_ptr),
 attn_mask_start_row,
 mask_head_mod_size,
 mask_seq_q_mod_size);
@@ -1040,6 +1053,7 @@ bool flash_attn_varlen_bwd(const void * const dout,
 const_cast<void *>(attn_mask),
 nullptr,
 nullptr,
+nullptr,
 -1,
 mask_head_mod_size,
 mask_seq_q_mod_size);
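All four entry points (flash_attn_fwd, flash_attn_varlen_fwd, flash_attn_bwd, flash_attn_varlen_bwd) now receive the workspace explicitly; the varlen paths pass nullptr, and set_params_fprop_strided asserts the pointer is non-null only when a sparse mask is in use and the bypass is enabled. A minimal caller-side sketch of providing that workspace (the helper and its sizing are assumptions for illustration; the real element count is dictated by prepare_sparsemask, which this commit does not show):

    #include <cuda_runtime.h>

    static inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Hypothetical helper, not part of this API: allocate the flashmask
    // max/min workspace once, up front, and reuse it across calls -- the
    // point of the commit is that the library no longer allocates it per
    // call. Sizing assumes four int arrays (up/down x max/min), one entry
    // per kBlockN column block, per batch and per sparse-mask head.
    int *alloc_flashmask_maxmin(int batch_size, int mask_heads,
                                int seqlen_k, int kBlockN) {
        size_t n_blocks = static_cast<size_t>(ceil_div(seqlen_k, kBlockN));
        size_t count = 4 * static_cast<size_t>(batch_size) * mask_heads * n_blocks;
        int *buf = nullptr;
        cudaMalloc(&buf, count * sizeof(int));  // caller owns this; free with cudaFree
        return buf;  // pass as flashmask_maxmin_ptr to flash_attn_fwd / flash_attn_bwd
    }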
2 changes: 2 additions & 0 deletions csrc/capi/flash_attn.h
@@ -38,6 +38,7 @@ bool flash_attn_fwd(const void * const q, // batch_size x seqlen_q x num
 const void * const attn_mask_start_row_indices,
 const int64_t * const attn_mask_start_row_indices_dims,
 const void * const attn_mask_end_row_indices,
+const void * const flashmask_maxmin_ptr,
 const int attn_mask_start_row,
 const int q_row_stride,
 const int k_row_stride,
@@ -130,6 +131,7 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
 const void * const attn_mask_start_row_indices,
 const int64_t * const attn_mask_start_row_indices_dims,
 const void * const attn_mask_end_row_indices,
+const void * const flashmask_maxmin_ptr,
 const int attn_mask_start_row,
 const int q_row_stride,
 const int k_row_stride,
2 changes: 2 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -111,13 +111,15 @@ struct Flash_fwd_params : public Qkv_params {
   bool varlen_padded_input = false;
   void * __restrict__ attn_mask_start_row_indices_ptr;
   void * __restrict__ attn_mask_end_row_indices_ptr;
+  int *__restrict__ flashmask_maxmin_ptr = nullptr;
   int *__restrict__ attn_sparsemask_up_nblockmax = nullptr;
   int *__restrict__ attn_sparsemask_up_nblockmin = nullptr;
   int *__restrict__ attn_sparsemask_down_nblockmax = nullptr;
   int *__restrict__ attn_sparsemask_down_nblockmin = nullptr;
   int attn_mask_start_row;
   int h_sparsemask;
   int h_h_sparsemask_ratio;
+  bool enable_mask_bypass;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
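The new pointer sits beside the four attn_sparsemask_*_nblock* arrays it is meant to back. One plausible reading of the relationship (an assumption; the actual carving happens in prepare_sparsemask, outside this diff) is that the single caller buffer is split into those four arrays as contiguous quarters:

    // Sketch under the assumption above (uses the Flash_fwd_params
    // definition shown in this hunk): carve the caller's workspace into
    // the four per-block extrema arrays, n_entries elements each.
    inline void partition_flashmask_buffer(Flash_fwd_params &params,
                                           size_t n_entries) {
        int *p = params.flashmask_maxmin_ptr;
        params.attn_sparsemask_up_nblockmax   = p + 0 * n_entries;
        params.attn_sparsemask_up_nblockmin   = p + 1 * n_entries;
        params.attn_sparsemask_down_nblockmax = p + 2 * n_entries;
        params.attn_sparsemask_down_nblockmin = p + 3 * n_entries;
    }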
6 changes: 4 additions & 2 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -475,7 +475,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
 int attn_mask_end_row = 0;
-if (Is_sparse_attn_mask) {
+const bool enable_mask_bypass = params.enable_mask_bypass;
+
+if (Is_sparse_attn_mask && enable_mask_bypass) {
   m_block_max = min(m_block_max,
                     cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
   attn_mask_start_row = gSparseMaskDownMin[n_block];
@@ -714,7 +716,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 int m_block = m_block_max - 1;
 int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
-if (!Is_causal && Is_sparse_attn_mask) {
+if (!Is_causal && Is_sparse_attn_mask && enable_mask_bypass) {
   m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
 }
 
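Note why the extra condition matters: set_params_fprop_strided only asserts that flashmask_maxmin_ptr is non-null when enable_mask_bypass is true, so for short sequences (seqlen_q < 1024) the per-block extrema arrays may simply not exist and the kernel must not read them. Read together, the two hunks make the row-block range clamping look roughly like this (a consolidated sketch of the post-change logic, with gSparseMaskDownMax/Min and gSparseMaskUpMin assumed to hold the extremal masked query rows per kBlockN column block):

    // Per column block n_block, the backward pass walks row blocks from
    // m_block_max - 1 down to m_block_min; with the bypass on, both ends
    // are tightened from the precomputed mask extrema so fully-masked row
    // blocks are skipped. With the bypass off, the full range is used.
    int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
    if (Is_sparse_attn_mask && enable_mask_bypass) {
        m_block_max = min(m_block_max,
                          cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
    }
    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
    if (!Is_causal && Is_sparse_attn_mask && enable_mask_bypass) {
        m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
    }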
5 changes: 1 addition & 4 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -65,7 +65,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
 const bool is_deterministic = params.num_splits == 1;
 // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
 params.attn_mask_start_row = (int)(params.attn_mask_start_row / Kernel_traits::kBlockM) * Kernel_traits::kBlockM;
-int *nblock_sparsemask = prepare_sparsemask<Kernel_traits>(params, stream);
+prepare_sparsemask<Kernel_traits>(params, stream);
 BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
   BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
@@ -84,9 +84,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     });
   });
 });
-if (nblock_sparsemask != nullptr) {
-  cudaFreeAsync(nblock_sparsemask, stream);
-}
 auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
 if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
   C10_CUDA_CHECK(cudaFuncSetAttribute(
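With prepare_sparsemask now writing into the caller-owned workspace, its return value and the trailing cudaFreeAsync have nothing left to do, and, per the commit title, the matching cudaMallocAsync inside prepare_sparsemask goes away as well. A plausible motivation, not stated in the commit, is keeping stream-ordered alloc/free traffic out of the per-call hot path and avoiding the constraints cudaMallocAsync places on setups such as CUDA graph capture.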
