Skip to content

Commit

Permalink
Refactor5: Generic way of registring operators for xformers.info
Browse files Browse the repository at this point in the history
ghstack-source-id: d2cdf875d1755166cd441fa07ab9a76d6f8ed36e
Pull Request resolved: #561
  • Loading branch information
danthe3rd committed Dec 8, 2022
1 parent 43894f7 commit ece0801
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 22 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ install_repo: &install_repo
# Test import.
$CONDA_PYTHON -c 'import sys; sys.path = sys.path[1:]; import xformers'
$CONDA_PYTHON -m xformers.info
install_experimental_repo: &install_experimental_repo
- run:
Expand Down
16 changes: 5 additions & 11 deletions xformers/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@
# LICENSE file in the root directory of this source tree.


from typing import Dict, List, Type
from typing import Dict

import torch

from . import __version__, _is_functorch_available, _is_triton_available, ops
from .ops.common import OPERATORS_REGISTRY


def get_features_status() -> Dict[str, str]:
ALL_OPS: List[Type[ops.AttentionOpBase]] = [
ops.MemoryEfficientAttentionFlashAttentionOp,
ops.MemoryEfficientAttentionCutlassOp,
ops.MemoryEfficientAttentionOp,
ops.TritonFlashAttentionOp,
ops.MemoryEfficientAttentionTritonFwdFlashBwOp,
]
features = {}
for op in ALL_OPS:
features[f"memory_efficient_attention.{op.NAME}"] = op.info()
for op in OPERATORS_REGISTRY:
features[f"{op.OPERATOR_CATEGORY}.{op.NAME}"] = op.info()
for k, v in ops.swiglu_op._info().items():
features[f"swiglu.{k}"] = v
features["is_triton_available"] = str(_is_triton_available())
Expand All @@ -42,7 +36,7 @@ def print_info():
else:
features["pytorch.cuda"] = "not available"
for name, status in features.items():
print("{:<40} {}".format(f"{name}:", status))
print("{:<50} {}".format(f"{name}:", status))


if __name__ == "__main__":
Expand Down
13 changes: 13 additions & 0 deletions xformers/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, TypeVar

import torch


Expand All @@ -16,3 +18,14 @@ def no_such_operator(*args, **kwargs):
return getattr(torch.ops.xformers, name)
except (RuntimeError, AttributeError):
return no_such_operator


OPERATORS_REGISTRY: List[Any] = []

ClsT = TypeVar("ClsT")


def register_operator(cls: ClsT) -> ClsT:
global OPERATORS_REGISTRY
OPERATORS_REGISTRY.append(cls)
return cls
5 changes: 3 additions & 2 deletions xformers/ops/fmha/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,15 @@ class AttentionOpBase:
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
NAME: str
OPERATOR_CATEGORY = "memory_efficient_attention"

_TEST_BATCH_SIZES: List[int] = [1, 300]
_TEST_K: List[int] = [32, 128]

@classmethod
def info(cls):
if cls.OPERATOR.__name__ == "no_such_operator":
return "not built"
if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator":
return "unavailable"
return "available"

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion xformers/ops/fmha/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from ..common import get_xformers_operator
from ..common import get_xformers_operator, register_operator
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Expand Down Expand Up @@ -43,6 +43,7 @@ def _minimum_gemm_alignment(inp: Inputs) -> int:
return matmul_alignment_mn


@register_operator
class FwOp(AttentionFwOpBase):
"""xFormers' MHA kernel based on CUTLASS.
Supports a large number of settings (including without TensorCores, f32 ...)
Expand Down Expand Up @@ -102,6 +103,7 @@ def supports(cls, d: Inputs) -> bool:
return True


@register_operator
class BwOp(AttentionBwOpBase):
OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
Expand Down
3 changes: 3 additions & 0 deletions xformers/ops/fmha/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch

from ..common import register_operator
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Expand Down Expand Up @@ -69,6 +70,7 @@ def _convert_input_format(
return new_inp, softmax_scale, cu_seqlens_q, seqlen_q, cu_seqlens_k, seqlen_kv


@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
Expand Down Expand Up @@ -154,6 +156,7 @@ def apply(
return out, Context(out=out, lse=softmax_lse)


@register_operator
class BwOp(AttentionBwOpBase):
OPERATOR = _C_flashattention_bwd
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
Expand Down
6 changes: 4 additions & 2 deletions xformers/ops/fmha/small_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch

from ..common import get_xformers_operator
from ..common import get_xformers_operator, register_operator
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Expand All @@ -27,6 +27,7 @@ def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
)


@register_operator
class FwOp(AttentionFwOpBase):
"""An operator optimized for very small values of K (``K <= 32``) \
and f32 pre-Ampere as it does not use TensorCores.
Expand Down Expand Up @@ -85,11 +86,12 @@ def apply(
p=inp.p,
)
out = bmk2bmhk(out, num_heads)
lse = lse.reshape([lse.shape[0], 1, lse.shape[1]])
lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
ctx = Context(out=out, lse=lse) if needs_gradient else None
return out, ctx


@register_operator
class BwOp(AttentionBwOpBase):
OPERATOR = get_xformers_operator("efficient_attention_backward")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
Expand Down
9 changes: 3 additions & 6 deletions xformers/ops/fmha/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch

from ... import _is_triton_available
from ..common import register_operator

if TYPE_CHECKING or _is_triton_available():
from ..._flash_attn.flash_attn_triton import (
Expand Down Expand Up @@ -48,6 +49,7 @@ def _prepare_inputs(inp: Inputs) -> Inputs:
return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value)


@register_operator
class FwOp(AttentionFwOpBase):
OPERATOR = triton_flash_forward
SUPPORTED_DEVICES = {"cuda"}
Expand All @@ -63,12 +65,6 @@ class FwOp(AttentionFwOpBase):
SUPPORTS_CUSTOM_SCALE = True
NAME = "tritonflashattF"

@classmethod
def info(cls):
if cls.OPERATOR is None:
return "not built"
return "available"

@classmethod
def supports(cls, d: "Inputs") -> bool:
if cls.OPERATOR is None:
Expand All @@ -95,6 +91,7 @@ def apply(
return out, Context(lse=lse, out=out)


@register_operator
class BwOp(AttentionBwOpBase):
OPERATOR = triton_flash_backward
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
Expand Down

0 comments on commit ece0801

Please sign in to comment.