Commit 4d94946
Merge pull request #15 from Xreki/update_repo
Update the latest flash-attention repo
Xreki authored Aug 4, 2023
2 parents ee74a8f + cba96bd commit 4d94946
Showing 20 changed files with 785 additions and 280 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -101,7 +101,7 @@ Return:
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
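A minimal sketch of the GQA mapping just described, with illustrative shapes (layout `(batch, seqlen, nheads, headdim)`, half precision on GPU assumed):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 128, 64
q = torch.randn(batch, seqlen, 6, headdim, dtype=torch.float16, device='cuda')  # 6 query heads
k = torch.randn(batch, seqlen, 2, headdim, dtype=torch.float16, device='cuda')  # 2 KV heads
v = torch.randn_like(k)
# Q heads 0-2 share KV head 0; Q heads 3-5 share KV head 1 (6 is divisible by 2).
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)  # (batch, seqlen, 6, headdim)
```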
@@ -131,7 +131,7 @@ These functions have been renamed:
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
 ```python
-flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
+flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
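For the packed variant, Q, K, and V are stacked along dim 2 of a single tensor; a short sketch under the same illustrative shapes (equal head counts, since all three live in one tensor):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# (batch, seqlen, 3, nheads, headdim) -- one tensor carries Q, K, and V
qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.float16, device='cuda')
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)  # (2, 128, 8, 64)
```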
186 changes: 142 additions & 44 deletions benchmarks/benchmark_flash_attention.py
@@ -1,70 +1,168 @@
-from functools import partial
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
 import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from einops import rearrange, repeat

-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

+from flash_attn import flash_attn_qkvpacked_func

-def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
+try:
+    from triton.ops.flash_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     """
     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
-        attn_mask: (batch_size, seqlen)
         dropout_p: float
     Output:
         output: (batch_size, seqlen, nheads, head_dim)
-        attention: softmax after dropout
     """
-    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
-    seqlen = qkv.shape[1]
-    d = qkv.shape[-1]
-    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
-    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
     if causal:
-        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
     attention = torch.softmax(scores, dim=-1)
     attention_drop = F.dropout(attention, dropout_p)
     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
-    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
     return output.to(dtype=qkv.dtype)


-torch.manual_seed(0)
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean, time_b[1].mean
+
+
 repeats = 30
-batch_size = 64
-nheads = 16
-seqlen = 1024
-n = 1024
-d = n // nheads
-dropout_p = 0.1
-causal = False
-dtype = torch.float16
 device = 'cuda'
+dtype = torch.float16

+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
+
+methods = (["Flash2", "Pytorch"]
+           + (["Triton"] if attention_triton is not None else [])
+           + (["xformers"] if xops is not None else []))
+
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                              requires_grad=True)
+            f, b = time_fwd_bwd(
+                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+            )
+            time_f[config, "Flash2"] = f
+            time_b[config, "Flash2"] = b
+
+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b = time_fwd_bwd(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+                )
+            except:  # Skip if OOM
+                f, b = float('nan'), float('nan')
+            time_f[config, "Pytorch"] = f
+            time_b[config, "Pytorch"] = b
+
+            if attention_triton is not None:
+                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                # Try both values of sequence_parallel and pick the faster one
+                try:
+                    f, b = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        False, repeats=repeats, verbose=False
+                    )
+                except:
+                    f, b = float('nan'), float('inf')
+                try:
+                    _, b0 = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        True, repeats=repeats, verbose=False
+                    )
+                except:
+                    b0 = float('inf')
+                time_f[config, "Triton"] = f
+                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
+
+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                f, b = time_fwd_bwd(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+                )
+                time_f[config, "xformers"] = f
+                time_b[config, "xformers"] = b
+
+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
+                speed_f[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                speed_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
+                    time_b[config, method]
+                )
+                speed_f_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
+                    time_f_b[config, method]
+                )
+                print(
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
+                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
+                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
+                )


-x = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
-Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
-
-lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device='cuda')
-attention_mask_bool = repeat(torch.arange(seqlen, device='cuda'), 's -> b s', b=batch_size) < lengths
-attention_mask = torch.zeros(batch_size, seqlen, device='cuda', dtype=dtype)
-attention_mask[~attention_mask_bool] = -10000.0
-attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')
-
-x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
-qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
-                      h=nheads).detach().requires_grad_()
-qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
-
-fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
-    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
-)
-benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
-fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
-benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
+# with open('flash2_attn_time.plk', 'wb') as fp:
+#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
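As a sanity check on the `flops`/`efficiency` helpers added above: the factor 4 counts the two seqlen x seqlen matmuls (QK^T and PV) at 2 FLOPs per multiply-add, causal masking halves the count, and the backward pass is weighted by the conventional 2.5x. A worked example with the benchmark's first config (the 1 ms timing is made up):

```python
batch, seqlen, headdim = 32, 512, 64
nheads = 2048 // headdim                       # dim = 2048 in the benchmark -> 32 heads
f = 4 * batch * seqlen**2 * nheads * headdim   # 68,719,476,736 FLOPs, non-causal fwd
print(f / 1e-3 / 1e12)                         # ~68.7 TFLOPs/s if the fwd pass took 1 ms
```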
50 changes: 36 additions & 14 deletions csrc/flash_attn/flash_api.cpp
@@ -294,11 +294,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                      softmax_scale,
                      is_causal);

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+        int64_t counter_offset = params.b * params.h * 32;
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -315,7 +320,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }

-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }

 std::vector<at::Tensor>
@@ -448,11 +453,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_scale,
                      is_causal);

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+        int64_t counter_offset = params.b * params.h * 32;
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -469,7 +479,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         if (out_.has_value()) { out_.value().copy_(out); }
     }

-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }

 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
@@ -507,7 +517,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,         // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        c10::optional<at::Generator> gen_) {
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -669,10 +680,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;

-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }

     launch(params, stream, /*configure=*/false);
@@ -709,7 +725,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
-        c10::optional<at::Generator> gen_
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -881,10 +898,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;

-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }

     launch(params, stream, /*configure=*/false);
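The common thread in these `flash_api.cpp` changes: the forward functions now write the Philox (seed, offset) pair into a two-element `rng_state` tensor and return it, and the backward functions accept an optional `rng_state`; when it is supplied, the kernel replays exactly the dropout mask the forward pass drew instead of acquiring a fresh generator state. A minimal Python sketch of that round trip, with a `torch.Generator` standing in for the Philox state (`mock_fwd`/`mock_bwd` are hypothetical, not the extension's API):

```python
import torch

def mock_fwd(x, dropout_p, gen):
    rng_state = gen.get_state()  # analogue of saving the (seed, offset) pair
    mask = torch.rand(x.shape, generator=gen) >= dropout_p
    return x * mask / (1 - dropout_p), rng_state

def mock_bwd(grad, dropout_p, gen, rng_state):
    gen.set_state(rng_state)     # replay: regenerate the identical dropout mask
    mask = torch.rand(grad.shape, generator=gen) >= dropout_p
    return grad * mask / (1 - dropout_p)

gen = torch.Generator().manual_seed(0)
x = torch.randn(4, 8)
y, state = mock_fwd(x, 0.1, gen)
gx = mock_bwd(torch.ones_like(x), 0.1, gen, state)  # same mask as the forward pass
assert torch.equal(y != 0, gx != 0)                 # dropped positions agree
```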
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -96,6 +96,9 @@ struct Flash_fwd_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;

+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
+
     bool is_bf16;
     bool is_causal;
 };