forked from PaddlePaddle/Paddle
Merge pull request #15 from Xreki/update_repo
Update the latest flash-attention repo
Showing 20 changed files with 785 additions and 280 deletions.
@@ -1,70 +1,168 @@
-from functools import partial
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

+from flash_attn import flash_attn_qkvpacked_func

-def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
+try:
+    from triton.ops.flash_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None

+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None


+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0


+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
-        attn_mask: (batch_size, seqlen)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
-        attention: softmax after dropout
    """
-    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
-    seqlen = qkv.shape[1]
-    d = qkv.shape[-1]
-    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
-    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
    if causal:
-        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
    return output.to(dtype=qkv.dtype)


-torch.manual_seed(0)
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean, time_b[1].mean


repeats = 30
-batch_size = 64
-nheads = 16
-seqlen = 1024
-n = 1024
-d = n // nheads
-dropout_p = 0.1
-causal = False
-dtype = torch.float16
device = 'cuda'
+dtype = torch.float16

+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0

+methods = (["Flash2", "Pytorch"]
+           + (["Triton"] if attention_triton is not None else [])
+           + (["xformers"] if xops is not None else []))

+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                              requires_grad=True)
+            f, b = time_fwd_bwd(
+                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+            )
+            time_f[config, "Flash2"] = f
+            time_b[config, "Flash2"] = b

+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b = time_fwd_bwd(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+                )
+            except:  # Skip if OOM
+                f, b = float('nan'), float('nan')
+            time_f[config, "Pytorch"] = f
+            time_b[config, "Pytorch"] = b

+            if attention_triton is not None:
+                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                # Try both values of sequence_parallel and pick the faster one
+                try:
+                    f, b = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        False, repeats=repeats, verbose=False
+                    )
+                except:
+                    f, b = float('nan'), float('inf')
+                try:
+                    _, b0 = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        True, repeats=repeats, verbose=False
+                    )
+                except:
+                    b0 = float('inf')
+                time_f[config, "Triton"] = f
+                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')

+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                f, b = time_fwd_bwd(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+                )
+                time_f[config, "xformers"] = f
+                time_b[config, "xformers"] = b

+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
+                speed_f[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                speed_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
+                    time_b[config, method]
+                )
+                speed_f_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
+                    time_f_b[config, method]
+                )
+                print(
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
+                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
+                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
+                )


-x = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
-Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

-lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device='cuda')
-attention_mask_bool = repeat(torch.arange(seqlen, device='cuda'), 's -> b s', b=batch_size) < lengths
-attention_mask = torch.zeros(batch_size, seqlen, device='cuda', dtype=dtype)
-attention_mask[~attention_mask_bool] = -10000.0
-attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')

-x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
-qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
-                      h=nheads).detach().requires_grad_()
-qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()

-fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
-    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
-)
-benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
-fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
-benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
+# with open('flash2_attn_time.plk', 'wb') as fp:
+#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
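For reference, the performance model the updated script uses counts 4 * batch * seqlen**2 * nheads * headdim FLOPs for a non-causal forward pass (halved when causal), weights the backward pass at 2.5x and forward plus backward at 3.5x, and divides by the measured time to report TFLOPs/s. Below is a minimal standalone sketch of that accounting; the shape is one of the benchmarked configurations, but the timing is made up purely for illustration.

import math

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    # Same accounting as the benchmark: 4 * b * s^2 * h * d, halved for causal.
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    # FLOPs divided by seconds, reported in TFLOPs/s; NaN timings (e.g. OOM) map to 0.
    return (flop / time / 10**12) if not math.isnan(time) else 0.0

# Example shape from bs_seqlen_vals / headdim_vals: batch=32, seqlen=512, headdim=64, nheads=2048//64.
fwd = flops(32, 512, 64, 32, causal=False, mode="fwd")  # 68_719_476_736 FLOPs
print(efficiency(fwd, 0.5e-3))  # a hypothetical 0.5 ms forward pass -> ~137 TFLOPs/s

Note that bs_seqlen_vals keeps batch_size * seqlen constant, so the total token count is fixed while the attention FLOPs roughly double at each step as seqlen doubles.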
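Since the update benchmarks flash_attn_qkvpacked_func against the attention_pytorch reference, a small correctness check in the same packed-QKV layout can be useful before trusting the timings. This is a sketch under assumptions (a CUDA GPU, the flash-attn package installed, dropout disabled so both paths are deterministic); it mirrors the benchmark's tensor layout but is not part of the commit.

import math
import torch
from flash_attn import flash_attn_qkvpacked_func

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 128, 4, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)

# Plain PyTorch reference in fp32: same scaled-dot-product math as attention_pytorch above, no dropout.
q, k, v = qkv.float().unbind(dim=2)
scores = torch.einsum('bthd,bshd->bhts', q, k) / math.sqrt(headdim)
ref = torch.einsum('bhts,bshd->bthd', torch.softmax(scores, dim=-1), v)

out = flash_attn_qkvpacked_func(qkv, 0.0, causal=False)
print((out.float() - ref).abs().max())  # expected to be small, on the order of fp16 rounding error

With dropout disabled the two outputs should agree up to fp16 rounding, roughly 1e-3 in absolute terms for these shapes.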