RFC: Ops dispatch (#356)
* Ops dispatch

* CI: Fix doc build

* memory_efficient_attention raises when no implementation is available

* type: ignore

* Fix torch.device/str comparison

* Make mypy happy

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>
2 people authored and fmassa committed Aug 10, 2022
1 parent a04fae4 commit 2568a84
Showing 2 changed files with 150 additions and 29 deletions.
52 changes: 34 additions & 18 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -7,6 +7,7 @@
import itertools
import math
from functools import partial
+from typing import cast

import torch
from torch.utils import benchmark
@@ -40,8 +41,9 @@ def ref_attention(q, k, v, attn_bias=None, p=0.0):


p = 0.0
-op = xformers.ops.MemoryEfficientAttentionOp
-# op = xformers.ops.MemoryEfficientAttentionGenericForwardOp
+FORCE_OP = None
+# FORCE_OP = xformers.ops.MemoryEfficientAttentionOp
+# FORCE_OP = xformers.ops.MemoryEfficientAttentionGenericForwardOp


def product_dict(**kwargs):
@@ -63,23 +65,28 @@ def product_dict(**kwargs):

def benchmark_forward(shape, num_threads: int, use_attn_bias: bool, dtype):
B, M, K = shape
-    if (
-        K > op.SUPPORTED_MAX_K
-        or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
-        or (dtype not in op.SUPPORTED_DTYPES)
-    ):
-        return
q = torch.rand(shape, device=device, dtype=dtype)
attn_bias = None
if use_attn_bias:
attn_bias = torch.rand(
shape[0], 1, shape[1], device=device, dtype=dtype
).expand(shape[0], shape[1], shape[1])

+    dispatch = xformers.ops.AttentionOpDispatch(
+        dtype=dtype, device=device, k=K, has_attn_bias=use_attn_bias, has_dropout=False
+    )
+    try:
+        op = dispatch.op if FORCE_OP is None else FORCE_OP
+    except NotImplementedError:
+        return
+    if not op.supports(dispatch):
+        return

dtype_str = {
torch.half: "f16",
torch.float: "f32",
}[dtype]
sub_label = f"{dtype_str} B={B}, M={M}, K={K}"
sub_label = f"{dtype_str} {op.NAME} B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op).float()
@@ -128,25 +135,34 @@ def benchmark_backward(shape, num_threads: int, use_attn_bias: bool, dtype):
attn_bias = torch.rand(shape[0], 1, shape[1], device=device).expand(
shape[0], shape[1], shape[1]
)
sub_label = f"B={B}, M={M}, K={K}"

-    if (
-        K > op.SUPPORTED_MAX_K
-        or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
-        # only fp32 is supported at the moment
-        or (dtype not in {torch.float})
-    ):
-        return

+    dispatch = xformers.ops.AttentionOpDispatch(
+        dtype=dtype, device=device, k=K, has_attn_bias=use_attn_bias, has_dropout=False
+    )
+    try:
+        op = dispatch.op if FORCE_OP is None else FORCE_OP
+    except NotImplementedError:
+        return
+    if not op.supports(dispatch):
+        return

+    dtype_str = {
+        torch.half: "f16",
+        torch.float: "f32",
+    }[dtype]
+    sub_label = f"{dtype_str} {op.NAME} B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op)
r.backward(torch.ones_like(q))

-    grad = q.grad
+    grad = cast(torch.Tensor, q.grad)
q.grad = None

rr = ref_attention(q, q, q, attn_bias)
rr.backward(torch.ones_like(q))
atol = 2e-4 + 2e-6 * K * M * math.sqrt(B) * math.sqrt(M)
# type: ignore
assert (grad - q.grad).abs().max() < atol, f"{(grad - q.grad).abs().max()}"
q.grad = None
del r, rr, grad
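For context (not part of the diff): a minimal sketch of how the benchmark now picks its operator. The device, dtype and head dimension k=64 below are illustrative, and a CUDA build of xformers with both kernels compiled is assumed.

import torch
import xformers.ops

# Describe the attention workload to benchmark.
dispatch = xformers.ops.AttentionOpDispatch(
    dtype=torch.half,
    device="cuda",
    k=64,  # head dimension
    has_dropout=False,
    has_attn_bias=False,
)

# FORCE_OP = None lets the dispatcher decide; setting it pins a specific kernel.
FORCE_OP = None  # e.g. xformers.ops.MemoryEfficientAttentionGenericForwardOp

try:
    op = dispatch.op if FORCE_OP is None else FORCE_OP
except NotImplementedError:
    op = None  # no kernel supports this configuration; the benchmark skips it

if op is not None and op.supports(dispatch):
    # For f16 with K=64 this resolves to "fwd_gen": the small-K kernel only
    # supports fp32 and K <= 32, so the generic forward op is chosen.
    print(f"benchmarking {op.NAME}")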
127 changes: 116 additions & 11 deletions xformers/ops.py
@@ -5,7 +5,8 @@


import math
-from typing import Optional
+from dataclasses import dataclass
+from typing import Any, List, Optional, Set, Type, Union

import torch

@@ -40,17 +41,47 @@ def no_such_operator(*args, **kwargs):

try:
return getattr(torch.ops.xformers, name)
-    except RuntimeError:
+    except (RuntimeError, AttributeError):
return no_such_operator


-class MemoryEfficientAttentionOp(torch.autograd.Function):
-    FORWARD_OPERATOR = _get_xformers_operator("efficient_attention")
-    SUPPORTED_DEVICES = {"cuda", "cpu"}
-    SUPPORTED_DTYPES = {torch.float}
-    SUPPORTED_MAX_K: float = 32
-    SUPPORTS_ATTN_BIAS = True
-    SUPPORTS_DROPOUT = True
+def _ref_attention(
+    query, key, value, compute_logsumexp: bool, attn_bias=None, p: float = 0.0
+):
+    query = query * (1.0 / query.shape[-1] ** 0.5)
+    if attn_bias is None:
+        attn = query @ key.transpose(-2, -1)
+    else:
+        # equivalent to (query @ key.transpose(-2, -1) + m).softmax(-1) @ v
+        # but faster, and is what is used in PyTorch now
+        attn = torch.baddbmm(attn_bias, query, key.transpose(-2, -1))
+    dtype = attn.dtype
+    attn = attn.to(torch.float).softmax(-1).to(dtype)
+    if p > 0:
+        attn = torch.nn.functional.dropout(attn, p=p)
+    rng_seed = 0
+    rng_offset = 0
+    return (
+        attn @ value,
+        attn.logsumexp(-1) if compute_logsumexp else None,
+        rng_seed,
+        rng_offset,
+    )


+class AttentionOpBase(torch.autograd.Function):
+    """
+    Manually doing what our efficient kernels do with Pytorch.
+    Allows to support forward/backwards when not implemented otherwise
+    """

+    FORWARD_OPERATOR: Any
+    SUPPORTED_DEVICES: Set[str]
+    SUPPORTED_DTYPES: Set[torch.dtype]
+    SUPPORTED_MAX_K: float
+    SUPPORTS_ATTN_BIAS: bool
+    SUPPORTS_DROPOUT: bool
+    NAME: str

@classmethod
def forward(cls, ctx, query, key, value, attn_bias, p):
@@ -68,6 +99,31 @@ def forward(cls, ctx, query, key, value, attn_bias, p):
ctx.rng_offset = rng_offset
return out

+    @classmethod
+    def supports(cls, d: "AttentionOpDispatch") -> bool:
+        device_type = d.device if isinstance(d.device, str) else d.device.type
+        if device_type not in cls.SUPPORTED_DEVICES:
+            return False
+        if d.dtype not in cls.SUPPORTED_DTYPES:
+            return False
+        if d.k > cls.SUPPORTED_MAX_K:
+            return False
+        if d.has_attn_bias and not cls.SUPPORTS_ATTN_BIAS:
+            return False
+        if d.has_dropout and not cls.SUPPORTS_DROPOUT:
+            return False
+        return True


+class MemoryEfficientAttentionOp(AttentionOpBase):
+    FORWARD_OPERATOR = _get_xformers_operator("efficient_attention")
+    SUPPORTED_DEVICES = {"cuda", "cpu"}
+    SUPPORTED_DTYPES = {torch.float}
+    SUPPORTED_MAX_K: float = 32
+    SUPPORTS_ATTN_BIAS = True
+    SUPPORTS_DROPOUT = True
+    NAME = "small_k"

@staticmethod
def backward(ctx, grad):
query, key, value, lse, attn_bias, out = ctx.saved_tensors
@@ -80,13 +136,54 @@ def backward(ctx, grad):
return grad_q, grad_k, grad_v, None, None


-class MemoryEfficientAttentionGenericForwardOp(MemoryEfficientAttentionOp):
+class MemoryEfficientAttentionGenericForwardOp(AttentionOpBase):
FORWARD_OPERATOR = _get_xformers_operator("efficient_attention_forward_generic")
SUPPORTED_DEVICES = {"cuda"}
SUPPORTED_DTYPES = {torch.float, torch.half}
SUPPORTED_MAX_K = math.inf
SUPPORTS_ATTN_BIAS = False
SUPPORTS_DROPOUT = False
NAME = "fwd_gen"

+    @classmethod
+    def backward(cls, ctx, grad):
+        query, key, value, lse, attn_bias, out = ctx.saved_tensors
+        p = ctx.p
+        rng_seed = ctx.rng_seed
+        rng_offset = ctx.rng_offset
+        grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward(
+            grad.float(),
+            query.float(),
+            key.float(),
+            value.float(),
+            lse.float(),
+            out.float(),
+            attn_bias,
+            p,
+            rng_seed,
+            rng_offset,
+        )
+        return grad_q, grad_k, grad_v, None, None


+@dataclass
+class AttentionOpDispatch:
+    dtype: torch.dtype
+    device: Union[torch.device, str]
+    k: int
+    has_dropout: bool
+    has_attn_bias: bool

+    @property
+    def op(self) -> Type[AttentionOpBase]:
+        priority_list_ops: List[Type[AttentionOpBase]] = [
+            MemoryEfficientAttentionOp,
+            MemoryEfficientAttentionGenericForwardOp,
+        ]
+        for op in priority_list_ops:
+            if op.supports(self):
+                return op
+        raise NotImplementedError(f"No operator found for this attention: {self}")


def memory_efficient_attention(
Expand All @@ -96,13 +193,21 @@ def memory_efficient_attention(
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
*,
-    op=MemoryEfficientAttentionOp,
+    op=None,
):
"""
Implements the memory-efficient attention mechanism following
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
"""
+    if op is None:
+        op = AttentionOpDispatch(
+            dtype=query.dtype,
+            device=query.device,
+            k=query.shape[-1],
+            has_dropout=p > 0.0,
+            has_attn_bias=attn_bias is not None,
+        ).op
# fast-path that doesn't require computing the logsumexp for backward computation
if all(x.requires_grad is False for x in [query, key, value]):
return op.FORWARD_OPERATOR(query, key, value, False, attn_bias, p)[0]
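A minimal usage sketch of the dispatch API introduced above (not part of the diff; assumes a CUDA build of xformers with both kernels available and the 3-d (batch, seq_len, head_dim) inputs this version of memory_efficient_attention works on; shapes and dtypes are illustrative):

import torch
import xformers.ops

q = torch.randn(1, 128, 64, device="cuda", dtype=torch.half)

# With op=None (the new default), memory_efficient_attention builds an
# AttentionOpDispatch from its inputs and takes the first entry of the
# priority list whose supports() check passes; it raises NotImplementedError
# if no operator can handle the configuration.
out = xformers.ops.memory_efficient_attention(q, q, q)

# The same choice can be inspected, or overridden explicitly via op=...
dispatch = xformers.ops.AttentionOpDispatch(
    dtype=q.dtype,
    device=q.device,
    k=q.shape[-1],
    has_dropout=False,
    has_attn_bias=False,
)
assert dispatch.op is xformers.ops.MemoryEfficientAttentionGenericForwardOp  # f16, K=64 selects "fwd_gen"
out = xformers.ops.memory_efficient_attention(
    q, q, q, op=xformers.ops.MemoryEfficientAttentionGenericForwardOp
)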
