Refactor4: Split fw/bw operators
ghstack-source-id: fdf6e6e9745377fdf6e25885f8cdcd56e30fb214
Pull Request resolved: #560
danthe3rd committed Dec 7, 2022
1 parent e803ef0 commit 9befff7
Showing 13 changed files with 1,034 additions and 925 deletions.
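
In short, this refactor replaces the single dispatched attention operator with an explicit (forward, backward) pair: inputs are wrapped in fmha.Inputs, fmha._dispatch_fw and fmha._dispatch_bw select an implementation for each direction, and op= arguments now carry the pair. The sketch below condenses that flow; it mirrors the benchmark changes further down rather than defining any library API, and pick_ops is a name invented for the sketch.

# Condensed sketch of the dispatch pattern introduced by this commit.
# `pick_ops` is a hypothetical helper; the fmha names come from the diff below.
import xformers.ops
import xformers.ops.fmha as fmha


def pick_ops(q, k, v, attn_bias=None, needs_grad=False):
    """Return a (forward_op, backward_op) pair for the given inputs."""
    inp = fmha.Inputs(query=q, key=k, value=v)
    inp.attn_bias = attn_bias
    fw = fmha._dispatch_fw(inp)                           # may raise NotImplementedError
    bw = fmha._dispatch_bw(inp) if needs_grad else None   # backward op only when needed
    return fw, bw


# Usage (q, k, v being [B, M, H, K] tensors):
#     op = pick_ops(q, k, v, needs_grad=True)
#     out = xformers.ops.memory_efficient_attention(q, k, v, op=op)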
204 changes: 108 additions & 96 deletions tests/test_mem_eff_attention.py

Large diffs are not rendered by default.

63 changes: 29 additions & 34 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -14,6 +14,7 @@
 from utils import benchmark_main_helper
 
 import xformers.ops
+import xformers.ops.fmha as fmha
 
 CHECK_CORRECTNESS = True
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -149,23 +150,15 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
     B, M, H, K = shape
     _, q, k, v = create_tensors(shape, dtype)
 
-    dispatch = xformers.ops.AttentionOpDispatch(
-        dtype=dtype,
-        device=device,
-        k=K,
-        attn_bias_type=attn_bias_type,
-        has_dropout=False,
-        kv_len=M,
-        q_len=M,
-    )
+    inp = fmha.Inputs(query=q, key=k, value=v)
     try:
-        op = dispatch.op if FORCE_OP is None else FORCE_OP
+        op = (fmha._dispatch_fw(inp), None) if FORCE_OP is None else FORCE_OP
     except NotImplementedError:
         return
-    if not op.supports(dispatch):
+    if not op[0].supports(inp):
         return
 
-    attn_bias = create_attn_bias(
+    inp.attn_bias = create_attn_bias(
         attn_bias_type,
         batch_size=B,
         num_heads=H,
@@ -174,6 +167,8 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
         device=device,
         dtype=dtype,
     )
+    if not op[0].supports(inp):
+        return
 
     dtype_str = {
         torch.bfloat16: "b16",
@@ -183,12 +178,14 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
     sub_label = f"{dtype_str} B={B}, M={M}, H={H}, K={K}"
 
     try:
-        r = xformers.ops.memory_efficient_attention(q, k, v, attn_bias, op=op).float()
+        r = xformers.ops.memory_efficient_attention(
+            q, k, v, inp.attn_bias, op=op
+        ).float()
         rr = ref_attention(
             q.float(),
             k.float(),
             v.float(),
-            attn_bias,
+            inp.attn_bias,
         )
         assert not CHECK_CORRECTNESS or (r - rr).abs().max() < 4e-3, (
             (r - rr).abs().max()
@@ -203,12 +200,12 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
             "q": q,
             "k": k,
             "v": v,
-            "attn_bias": attn_bias,
+            "attn_bias": inp.attn_bias,
             "p": p,
             "fn": partial(xformers.ops.memory_efficient_attention, op=op),
         },
         label=f"attention (attn_bias={attn_bias_type})",
-        description=op.NAME,
+        description=op[0].NAME,
         sub_label=sub_label,
         num_threads=num_threads,
     )
@@ -218,7 +215,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
             "q": q,
             "k": k,
             "v": v,
-            "attn_bias": attn_bias,
+            "attn_bias": inp.attn_bias,
             "p": p,
             "fn": ref_attention,
         },
@@ -233,23 +230,19 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
     B, M, H, K = shape
     qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True)
 
-    dispatch = xformers.ops.AttentionOpDispatch(
-        dtype=dtype,
-        device=device,
-        k=K,
-        attn_bias_type=attn_bias_type,
-        has_dropout=False,
-        kv_len=M,
-        q_len=M,
-    )
+    inp = fmha.Inputs(query=q, key=k, value=v)
     try:
-        op = dispatch.op if FORCE_OP is None else FORCE_OP
+        op = (
+            (fmha._dispatch_fw(inp), fmha._dispatch_bw(inp))
+            if FORCE_OP is None
+            else FORCE_OP
+        )
     except NotImplementedError:
         return
-    if not op.supports(dispatch):
+    if not op[0].supports(inp) or not op[1].supports(inp):
         return
 
-    attn_bias = create_attn_bias(
+    inp.attn_bias = create_attn_bias(
         attn_bias_type,
         batch_size=B,
         num_heads=H,
@@ -258,6 +251,8 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
         device=device,
         dtype=dtype,
     )
+    if not op[0].supports(inp) or not op[1].supports(inp):
+        return
 
     dtype_str = {
         torch.bfloat16: "b16",
@@ -266,7 +261,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
     }[dtype]
     sub_label = f"{dtype_str} B={B}, M={M}, H={H}, K={K}"
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias, p, op=op)
+    out = xformers.ops.memory_efficient_attention(q, k, v, inp.attn_bias, p, op=op)
     grad_benchmark = torch.ones_like(q)
 
     yield benchmark.Timer(
@@ -276,21 +271,21 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
             "grad": grad_benchmark,
         },
         label=f"attention backward (attn_bias={attn_bias_type})",
-        description=op.NAME,
+        description=op[1].NAME,
         sub_label=sub_label,
         num_threads=num_threads,
     )
     del out
 
     try:
         qkv.grad = None
-        r = xformers.ops.memory_efficient_attention(q, k, v, attn_bias, op=op)
+        r = xformers.ops.memory_efficient_attention(q, k, v, inp.attn_bias, op=op)
         r.backward(torch.ones_like(q))
 
         grad = cast(torch.Tensor, qkv.grad)
         qkv.grad = None
 
-        rr = ref_attention(q, k, v, attn_bias)
+        rr = ref_attention(q, k, v, inp.attn_bias)
         rr.backward(torch.ones_like(q))
         atol = 2e-4 + 2e-6 * K * M * math.sqrt(B) * math.sqrt(M)
         # type: ignore
@@ -303,7 +298,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
     yield benchmark.Timer(
         stmt="out.backward(grad, retain_graph=True)",
         globals={
-            "out": ref_attention(q, k, v, attn_bias),
+            "out": ref_attention(q, k, v, inp.attn_bias),
             "grad": grad_benchmark,
         },
         label=f"attention backward (attn_bias={attn_bias_type})",
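A detail worth noting in the updated benchmarks above: supports() is now queried per element of the op pair, and it is re-checked after inp.attn_bias is assigned, since a bias type can rule out an otherwise valid implementation. A minimal sketch of that guard, reusing names from the diff (ops_support itself is invented for the sketch):

# Minimal sketch of the support guard used by the updated benchmarks.
# `op` is a (forward_op, backward_op) pair (backward_op may be None) and
# `inp` an fmha.Inputs instance; `ops_support` is a hypothetical helper.
def ops_support(op, inp) -> bool:
    return all(o.supports(inp) for o in op if o is not None)


# if not ops_support(op, inp):
#     return  # skip this benchmark case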
@@ -180,6 +180,7 @@ mem_efficient_attention_backward_cutlass(
     p.scale = float(1.0 / std::sqrt(float(p.head_dim)));
   }
 
+  ASSIGN_CHECK_OVERFLOW(p.lse_strideM, logsumexp.stride(1));
   ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0));
   ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1));
   ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2));
@@ -214,6 +214,7 @@ struct AttentionBackwardKernel {
    int64_t q_strideB;
    int64_t k_strideB;
    int64_t v_strideB;
+   int64_t lse_strideM;
    int32_t num_batches;
 
    int64_t gO_strideB;
@@ -226,16 +227,13 @@ struct AttentionBackwardKernel {
    int64_t gV_strideH;
 
    CUTLASS_DEVICE void advance_to_block() {
-     constexpr int32_t kAlignLSE = 32; // block size of backward
-     auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
-
      int64_t batch_id = blockIdx.z;
      int32_t head_id = blockIdx.y;
 
      query_ptr += batch_id * q_strideB + head_id * q_strideH;
      key_ptr += batch_id * k_strideB + head_id * k_strideH;
      value_ptr += batch_id * v_strideB + head_id * v_strideH;
-     logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim;
+     logsumexp_ptr += (batch_id * num_heads + head_id) * lse_strideM;
      output_ptr += batch_id * o_strideB + head_id * o_strideH;
      grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
      delta_ptr += (batch_id * num_heads + head_id) * num_queries;
@@ -843,6 +841,7 @@ struct AttentionBackwardKernel {
     CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
+    XFORMERS_CHECK(p.lse_strideM % 8 == 0, "LSE is not correctly aligned");
     XFORMERS_CHECK(
         p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned");
     XFORMERS_CHECK(
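The kernel-side change above pairs with the launcher change: the backward kernel used to assume the logsumexp buffer was padded to a multiple of kAlignLSE = 32 along the query dimension and recomputed that padded length itself; now the launcher forwards logsumexp.stride(1) as lse_strideM, the kernel uses it to locate each (batch, head) row, and the new check only requires the stride to be a multiple of 8. A small Python illustration of the equivalent offset arithmetic, with assumed shapes and a padded allocation (not code from this PR):

# Illustration of the logsumexp row offset before/after this change.
# Shapes and the padded allocation are assumptions made for the sketch.
import math
import torch

B, H, M = 2, 3, 100                             # batches, heads, queries
kAlignLSE = 32                                  # old hard-coded padding
lse_dim = math.ceil(M / kAlignLSE) * kAlignLSE  # 128

lse = torch.empty(B, H, lse_dim)                # padded along the query dim

batch_id, head_id = 1, 2
# Old: the kernel re-derived the padded length itself.
offset_old = (batch_id * H + head_id) * lse_dim
# New: the host passes logsumexp.stride(1), so the kernel follows the tensor's
# actual layout (and only requires it to be 8-aligned).
lse_strideM = lse.stride(1)
assert lse_strideM % 8 == 0
offset_new = (batch_id * H + head_id) * lse_strideM
assert offset_old == offset_new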
6 changes: 6 additions & 0 deletions xformers/ops/__init__.py
@@ -18,6 +18,9 @@
     MemoryEfficientAttentionTritonFwdFlashBwOp,
     TritonFlashAttentionOp,
     memory_efficient_attention,
+    memory_efficient_attention_backward,
+    memory_efficient_attention_forward,
+    memory_efficient_attention_forward_requires_grad,
 )
 from .swiglu_op import (
     SwiGLU,
@@ -65,6 +68,9 @@ def masked_matmul(a, b, mask=None):
     "MemoryEfficientAttentionOp",
     "MemoryEfficientAttentionTritonFwdFlashBwOp",
     "memory_efficient_attention",
+    "memory_efficient_attention_backward",
+    "memory_efficient_attention_forward",
+    "memory_efficient_attention_forward_requires_grad",
     "SwiGLU",
     "SwiGLUEagerOp",
     "SwiGLUFusedOp",
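The new exports above are the user-facing side of the fw/bw split. The snippet below sketches how they might compose; the signatures are assumed from the function names (the *_requires_grad variant returning the output together with the logsumexp consumed by the backward) and are not shown in this diff, so treat them as an assumption and check the fmha module for the authoritative definitions.

# Hypothetical usage of the newly exported split entry points; signatures are
# assumptions inferred from the names, not taken from this diff.
import torch
import xformers.ops as xops

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)  # [B, M, H, K]
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inference-only forward.
out = xops.memory_efficient_attention_forward(q, k, v)

# Forward that also keeps the log-sum-exp rows needed by the backward.
out, lse = xops.memory_efficient_attention_forward_requires_grad(q, k, v)

# Explicit backward, given the gradient of the output.
grad_out = torch.ones_like(out)
grad_q, grad_k, grad_v = xops.memory_efficient_attention_backward(
    grad_out, out, lse, q, k, v
)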
