diff --git a/csrc/capi/flash_attn.cu b/csrc/capi/flash_attn.cu
index 9cafd411b..997c24b01 100644
--- a/csrc/capi/flash_attn.cu
+++ b/csrc/capi/flash_attn.cu
@@ -87,12 +87,11 @@ void set_params_fprop(Flash_fwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       bool is_causal,
-                      bool is_bf16) {
+                      bool is_bf16) {
     // Reset the parameters
     memset(&params, 0, sizeof(params));
 
     params.is_bf16 = is_bf16;
-
     // Set the pointers and strides.
     params.q_ptr = q;
     params.k_ptr = k;
@@ -185,7 +184,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       bool is_causal,
-                      bool is_bf16) {
+                      bool is_bf16,
+                      const int num_splits=0) {
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
                      h, h_k, d, d_rounded,
@@ -196,7 +196,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
-                     is_causal, is_bf16);
+                     is_causal,
+                     is_bf16);
 
     // Set the pointers and strides.
     params.do_ptr = dout;
@@ -225,6 +226,7 @@ void set_params_dgrad(Flash_bwd_params &params,
 
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
+    params.num_splits = num_splits;
 }
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
@@ -431,6 +433,7 @@ bool flash_attn_bwd(const void * const dout,
                     const float softmax_scale,
                     const bool is_causal,
                     const bool is_bf16,
+                    const int num_splits,
                     cudaStream_t stream,
                     uint64_t seed,
                     uint64_t offset) {
@@ -465,11 +468,13 @@ bool flash_attn_bwd(const void * const dout,
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      const_cast<void *>(q),
-                     const_cast<void *>(k),
-                     const_cast<void *>(v),
-                     const_cast<void *>(out),
+                     const_cast<void *>(k),
+                     const_cast<void *>(v),
+                     const_cast<void *>(out),
                      const_cast<void *>(dout),
-                     dq, dk, dv,
+                     dq,
+                     dk,
+                     dv,
                      nullptr,
                      nullptr,
                      loop ? dq_accum : nullptr,
@@ -480,7 +485,8 @@ bool flash_attn_bwd(const void * const dout,
                      p_dropout,
                      softmax_scale,
                      is_causal,
-                     is_bf16);
+                     is_bf16,
+                     num_splits);
 
     auto launch = &run_mha_bwd;
 
@@ -527,6 +533,7 @@ bool flash_attn_varlen_bwd(const void * const dout,
                            const float softmax_scale,
                            const bool is_causal,
                            const bool is_bf16,
+                           const int num_splits,
                            cudaStream_t stream,
                            uint64_t seed,
                            uint64_t offset) {
@@ -562,7 +569,9 @@ bool flash_attn_varlen_bwd(const void * const dout,
                      const_cast<void *>(v),
                      const_cast<void *>(out),
                      const_cast<void *>(dout),
-                     dq, dk, dv,
+                     dq,
+                     dk,
+                     dv,
                      const_cast<int32_t *>(cu_seqlens_q),
                      const_cast<int32_t *>(cu_seqlens_k),
                      loop ? dq_accum : nullptr,
@@ -573,7 +582,8 @@ bool flash_attn_varlen_bwd(const void * const dout,
                      p_dropout,
                      softmax_scale,
                      is_causal,
-                     is_bf16);
+                     is_bf16,
+                     num_splits);
 
     auto launch = &run_mha_bwd;
 
diff --git a/csrc/capi/flash_attn.h b/csrc/capi/flash_attn.h
index ea54226d9..2fdf22d37 100644
--- a/csrc/capi/flash_attn.h
+++ b/csrc/capi/flash_attn.h
@@ -85,6 +85,7 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
                     const float softmax_scale,
                     const bool is_causal,
                     const bool is_bf16,
+                    const int num_splits,
                     cudaStream_t stream,
                     uint64_t seed,
                     uint64_t offset);
@@ -116,6 +117,7 @@ bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x h
                            const float softmax_scale,
                            const bool is_causal,
                            const bool is_bf16,
+                           const int num_splits,
                            cudaStream_t stream,
                            uint64_t seed,
                            uint64_t offset);
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 5bffda6c3..2fd7b1612 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -140,6 +140,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
+    int num_splits;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 7c9638b54..42872c60f 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -1508,8 +1508,16 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     const int bidb = blockIdx.y;
     // The block index for the head.
    const int bidh = blockIdx.z;
-
-    compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block);
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
+        int loop_step_x = 0;
+        for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
+            compute_dq_dk_dv_1colblock(params, bidb, bidh, loop_step_x);
+            loop_step_x += 1;
+        }
+    } else {
+        compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block);
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 15aa2ad8d..92c41e308 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -50,7 +50,7 @@ template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
-    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
+    const int num_n_block = params.num_splits == 1 ? params.num_splits : (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h);
 
     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
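Review note, not part of the patch: the only value of num_splits the new code treats specially is 1; set_params_dgrad defaults it to 0, which keeps the previous launch shape. Below is a minimal host-side sketch of the launch logic implied by the changes above; the helper name seqlen_k_grid_dim is illustrative and does not exist in the sources.

// Sketch only: mirrors how run_flash_bwd_seqk_parallel picks grid_n.x after this patch.
// kBlockN is Kernel_traits::kBlockN, the width of one column tile along seqlen_k.
static int seqlen_k_grid_dim(int seqlen_k, int kBlockN, int num_splits) {
    if (num_splits == 1) {
        // Single column block per (batch, head): compute_dq_dk_dv_seqk_parallel
        // then loops over ceil(seqlen_k / kBlockN) column tiles inside that one block.
        return 1;
    }
    // Default (num_splits == 0 or any other value): one thread block per column tile,
    // exactly as before the patch.
    return (seqlen_k + kBlockN - 1) / kBlockN;
}

Callers of flash_attn_bwd / flash_attn_varlen_bwd would therefore pass num_splits = 1 to force the serialized single-block path along seqlen_k, and 0 to keep the previous fully parallel behaviour.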