diff --git a/.circleci/config.yml b/.circleci/config.yml
index 24badafea7..a7e8f27bf6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -126,6 +126,7 @@ install_repo: &install_repo
 
         # Test import.
         $CONDA_PYTHON -c 'import sys; sys.path = sys.path[1:]; import xformers'
+        $CONDA_PYTHON -m xformers.info
 
 install_experimental_repo: &install_experimental_repo
   - run:
diff --git a/xformers/info.py b/xformers/info.py
index 7321a22aa0..c02354cdb3 100644
--- a/xformers/info.py
+++ b/xformers/info.py
@@ -4,24 +4,18 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Dict, List, Type
+from typing import Dict
 
 import torch
 
 from . import __version__, _is_functorch_available, _is_triton_available, ops
+from .ops.common import OPERATORS_REGISTRY
 
 
 def get_features_status() -> Dict[str, str]:
-    ALL_OPS: List[Type[ops.AttentionOpBase]] = [
-        ops.MemoryEfficientAttentionFlashAttentionOp,
-        ops.MemoryEfficientAttentionCutlassOp,
-        ops.MemoryEfficientAttentionOp,
-        ops.TritonFlashAttentionOp,
-        ops.MemoryEfficientAttentionTritonFwdFlashBwOp,
-    ]
     features = {}
-    for op in ALL_OPS:
-        features[f"memory_efficient_attention.{op.NAME}"] = op.info()
+    for op in OPERATORS_REGISTRY:
+        features[f"{op.OPERATOR_CATEGORY}.{op.NAME}"] = op.info()
     for k, v in ops.swiglu_op._info().items():
         features[f"swiglu.{k}"] = v
     features["is_triton_available"] = str(_is_triton_available())
@@ -42,7 +36,7 @@ def print_info():
     else:
         features["pytorch.cuda"] = "not available"
     for name, status in features.items():
-        print("{:<40} {}".format(f"{name}:", status))
+        print("{:<50} {}".format(f"{name}:", status))
 
 
 if __name__ == "__main__":
diff --git a/xformers/ops/common.py b/xformers/ops/common.py
index 150540f246..dd35ce67f4 100644
--- a/xformers/ops/common.py
+++ b/xformers/ops/common.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, List, TypeVar
+
 import torch
 
 
@@ -16,3 +18,14 @@ def no_such_operator(*args, **kwargs):
         return getattr(torch.ops.xformers, name)
     except (RuntimeError, AttributeError):
         return no_such_operator
+
+
+OPERATORS_REGISTRY: List[Any] = []
+
+ClsT = TypeVar("ClsT")
+
+
+def register_operator(cls: ClsT) -> ClsT:
+    global OPERATORS_REGISTRY
+    OPERATORS_REGISTRY.append(cls)
+    return cls
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index 5d2b3c5452..a5a52de02c 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -141,14 +141,15 @@ class AttentionOpBase:
     SUPPORTS_CUSTOM_SCALE: bool = False
     SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
     NAME: str
+    OPERATOR_CATEGORY = "memory_efficient_attention"
 
     _TEST_BATCH_SIZES: List[int] = [1, 300]
     _TEST_K: List[int] = [32, 128]
 
     @classmethod
     def info(cls):
-        if cls.OPERATOR.__name__ == "no_such_operator":
-            return "not built"
+        if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator":
+            return "unavailable"
         return "available"
 
     @classmethod
diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py
index f008edb147..6720ef04bb 100644
--- a/xformers/ops/fmha/cutlass.py
+++ b/xformers/ops/fmha/cutlass.py
@@ -9,7 +9,7 @@
 
 import torch
 
-from ..common import get_xformers_operator
+from ..common import get_xformers_operator, register_operator
 from .common import (
     AttentionBwOpBase,
     AttentionFwOpBase,
@@ -43,6 +43,7 @@ def _minimum_gemm_alignment(inp: Inputs) -> int:
     return matmul_alignment_mn
 
 
+@register_operator
 class FwOp(AttentionFwOpBase):
     """xFormers' MHA kernel based on CUTLASS.
     Supports a large number of settings (including without TensorCores, f32 ...)
@@ -102,6 +103,7 @@ def supports(cls, d: Inputs) -> bool:
         return True
 
 
+@register_operator
 class BwOp(AttentionBwOpBase):
     OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 0befce47bb..d24b649135 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -9,6 +9,7 @@
 
 import torch
 
+from ..common import register_operator
 from .common import (
     AttentionBwOpBase,
     AttentionFwOpBase,
@@ -69,6 +70,7 @@ def _convert_input_format(
     return new_inp, softmax_scale, cu_seqlens_q, seqlen_q, cu_seqlens_k, seqlen_kv
 
 
+@register_operator
 class FwOp(AttentionFwOpBase):
     """Operator that computes memory-efficient attention using \
         `Flash-Attention `_ \
@@ -154,6 +156,7 @@ def apply(
         return out, Context(out=out, lse=softmax_lse)
 
 
+@register_operator
 class BwOp(AttentionBwOpBase):
     OPERATOR = _C_flashattention_bwd
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
diff --git a/xformers/ops/fmha/small_k.py b/xformers/ops/fmha/small_k.py
index 91c6a5f84e..0146da10a1 100644
--- a/xformers/ops/fmha/small_k.py
+++ b/xformers/ops/fmha/small_k.py
@@ -7,7 +7,7 @@
 
 import torch
 
-from ..common import get_xformers_operator
+from ..common import get_xformers_operator, register_operator
 from .common import (
     AttentionBwOpBase,
     AttentionFwOpBase,
@@ -27,6 +27,7 @@ def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
     )
 
 
+@register_operator
 class FwOp(AttentionFwOpBase):
     """An operator optimized for very small values of K (``K <= 32``) \
         and f32 pre-Ampere as it does not use TensorCores.
@@ -85,11 +86,12 @@ def apply(
             p=inp.p,
         )
         out = bmk2bmhk(out, num_heads)
-        lse = lse.reshape([lse.shape[0], 1, lse.shape[1]])
+        lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
         ctx = Context(out=out, lse=lse) if needs_gradient else None
         return out, ctx
 
 
+@register_operator
 class BwOp(AttentionBwOpBase):
     OPERATOR = get_xformers_operator("efficient_attention_backward")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
diff --git a/xformers/ops/fmha/triton.py b/xformers/ops/fmha/triton.py
index 0fab52c587..bb89b3947f 100644
--- a/xformers/ops/fmha/triton.py
+++ b/xformers/ops/fmha/triton.py
@@ -10,6 +10,7 @@
 import torch
 
 from ... import _is_triton_available
+from ..common import register_operator
 
 if TYPE_CHECKING or _is_triton_available():
     from ..._flash_attn.flash_attn_triton import (
@@ -48,6 +49,7 @@ def _prepare_inputs(inp: Inputs) -> Inputs:
     return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value)
 
 
+@register_operator
 class FwOp(AttentionFwOpBase):
     OPERATOR = triton_flash_forward
     SUPPORTED_DEVICES = {"cuda"}
@@ -63,12 +65,6 @@ class FwOp(AttentionFwOpBase):
     SUPPORTS_CUSTOM_SCALE = True
     NAME = "tritonflashattF"
 
-    @classmethod
-    def info(cls):
-        if cls.OPERATOR is None:
-            return "not built"
-        return "available"
-
     @classmethod
     def supports(cls, d: "Inputs") -> bool:
         if cls.OPERATOR is None:
@@ -95,6 +91,7 @@ def apply(
         return out, Context(lse=lse, out=out)
 
 
+@register_operator
 class BwOp(AttentionBwOpBase):
     OPERATOR = triton_flash_backward
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
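
Usage note: with this patch, `xformers.info` no longer hard-codes a list of attention ops; any class decorated with `register_operator` lands in `OPERATORS_REGISTRY` and is reported as `<OPERATOR_CATEGORY>.<NAME>`. The snippet below is a minimal, self-contained sketch of that pattern under those assumptions; it does not import xformers, and `DummyOp` is a hypothetical stand-in for an `AttentionOpBase` subclass rather than a real operator.

```python
from typing import Any, List, TypeVar

# Sketch of the registry introduced in xformers/ops/common.py.
OPERATORS_REGISTRY: List[Any] = []

ClsT = TypeVar("ClsT")


def register_operator(cls: ClsT) -> ClsT:
    # Class decorator: record the class so an info tool can enumerate it later.
    OPERATORS_REGISTRY.append(cls)
    return cls


@register_operator
class DummyOp:
    # Hypothetical operator; mirrors the attributes the report relies on.
    OPERATOR_CATEGORY = "memory_efficient_attention"
    NAME = "dummyF"
    OPERATOR = None  # pretend the kernel was not built

    @classmethod
    def info(cls) -> str:
        return "unavailable" if cls.OPERATOR is None else "available"


# Mirrors the new loop in xformers/info.py: one "<category>.<name>" entry per op,
# printed left-aligned in a 50-character column followed by its status.
for op in OPERATORS_REGISTRY:
    print("{:<50} {}".format(f"{op.OPERATOR_CATEGORY}.{op.NAME}:", op.info()))
```

This is also why the CI change above runs `$CONDA_PYTHON -m xformers.info` after the import test: every registered operator gets listed automatically, so a missing kernel shows up as "unavailable" in the build logs without maintaining a separate list.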