diff --git a/setup.py b/setup.py
index 14323400d6..19767569d2 100644
--- a/setup.py
+++ b/setup.py
@@ -276,6 +276,7 @@ def run(self):
         version=version,
         install_requires=fetch_requirements(),
         packages=setuptools.find_packages(exclude=("tests", "tests.*")),
+        dependency_links=["file:///./third_party/flash-attention#egg=flash-attention"],
         ext_modules=get_extensions(),
         cmdclass={
             "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 9649437249..4f1f9f652e 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -79,6 +79,8 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     xformers.ops.MemoryEfficientAttentionCutlassOp,
     xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
     xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
+    xformers.ops.TritonFlashAttentionOp,
+    xformers.ops.MemoryEfficientAttentionTritonFwdFlashBwOp,
 ]
 
 
@@ -135,11 +137,14 @@ def assert_allclose(
     flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
     max_pos = flatten_diff.argmax()
     max_diff = flatten_diff[max_pos]
+    num_different = torch.count_nonzero(flatten_diff > 0)
+    percentage = num_different / flatten_diff.numel()
     del flatten_diff
     assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
         f"{msg}: "
         f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
         f"/ atol={atol}, rtol={rtol}"
+        f"/ total failing elements: {num_different}, percentage={percentage}"
     )
 
 
diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py
index e17c1c5dac..72d7e8f5a4 100644
--- a/xformers/benchmarks/benchmark_mem_eff_attention.py
+++ b/xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -101,6 +101,8 @@ def T(t):
     (1, 16384, 16, 40),  # 1024x1024
     # ParlAI model
     (256, 4096, 16, 64),
+    # Zetta B M H K
+    (8, 2048, 20, 128),
    *sorted(
        list(itertools.product([16, 64], [128, 512, 1024], [16], [16, 32, 64, 128]))
    ),
@@ -112,6 +114,9 @@ def T(t):
 # FORCE_OP = xformers.ops.MemoryEfficientAttentionOp
 # FORCE_OP = xformers.ops.MemoryEfficientAttentionCutlassOp
 # FORCE_OP = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+# FORCE_OP = xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp
+# FORCE_OP = xformers.ops.TritonFlashAttentionOp
+# FORCE_OP = xformers.ops.MemoryEfficientAttentionTritonFwdFlashBwOp
 
 
 def product_dict(**kwargs):
diff --git a/xformers/info.py b/xformers/info.py
index 5df206c972..7321a22aa0 100644
--- a/xformers/info.py
+++ b/xformers/info.py
@@ -16,6 +16,8 @@ def get_features_status() -> Dict[str, str]:
         ops.MemoryEfficientAttentionFlashAttentionOp,
         ops.MemoryEfficientAttentionCutlassOp,
         ops.MemoryEfficientAttentionOp,
+        ops.TritonFlashAttentionOp,
+        ops.MemoryEfficientAttentionTritonFwdFlashBwOp,
     ]
     features = {}
     for op in ALL_OPS:
diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py
index cf82f280f8..67cd9cc12e 100644
--- a/xformers/ops/__init__.py
+++ b/xformers/ops/__init__.py
@@ -15,6 +15,8 @@
     MemoryEfficientAttentionCutlassOp,
     MemoryEfficientAttentionFlashAttentionOp,
     MemoryEfficientAttentionOp,
+    MemoryEfficientAttentionTritonFwdFlashBwOp,
+    TritonFlashAttentionOp,
     memory_efficient_attention,
 )
 from .swiglu_op import (
@@ -61,6 +63,7 @@ def masked_matmul(a, b, mask=None):
     "MemoryEfficientAttentionCutlassOp",
     "MemoryEfficientAttentionFlashAttentionOp",
     "MemoryEfficientAttentionOp",
+    "MemoryEfficientAttentionTritonFwdFlashBwOp",
     "memory_efficient_attention",
     "SwiGLU",
     "SwiGLUEagerOp",
@@ -69,6 +72,7 @@ def masked_matmul(a, b, mask=None):
     "SwiGLUOpDispatch",
     "SwiGLUPackedFusedOp",
     "swiglu",
+    "TritonFlashAttentionOp",
     "unbind",
     "stack_or_none",
     "get_stack_strides",
diff --git a/xformers/ops/memory_efficient_attention.py b/xformers/ops/memory_efficient_attention.py
index 1a9e581669..b7a4b0cc65 100644
--- a/xformers/ops/memory_efficient_attention.py
+++ b/xformers/ops/memory_efficient_attention.py
@@ -20,6 +20,16 @@
 except ImportError:
     has_flashattention = False
 
+try:
+    from flash_attn.flash_attn_triton import (
+        _flash_attn_backward as triton_flash_backward,
+    )
+    from flash_attn.flash_attn_triton import _flash_attn_forward as triton_flash_forward
+
+    has_triton_flashattention = True
+except ImportError:
+    has_triton_flashattention = False
+
 
 class AttentionMask:
     """Base class for custom masks that can be applied \
@@ -89,6 +99,10 @@ class AttentionOpBase(torch.autograd.Function):
     - :attr:`xformers.ops.MemoryEfficientAttentionFlashAttentionOp`
 
     - :attr:`xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp`
+
+    - :attr:`xformers.ops.TritonFlashAttentionOp`
+
+    - :attr:`xformers.ops.MemoryEfficientAttentionTritonFwdFlashBwOp`
     """
 
     FORWARD_OPERATOR: Any
@@ -735,6 +749,165 @@ def _flash_attn_backward(
     return dq, dk, dv, softmax_d
 
 
+class TritonFlashAttentionOp(AttentionOpBase):
+    FORWARD_OPERATOR = None
+    SUPPORTED_DEVICES = {"cuda"}
+    SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
+    SUPPORTED_MAX_K = 128
+    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
+        type(None),
+        LowerTriangularMask,
+        # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
+        # torch.Tensor,
+    }
+    SUPPORTS_DROPOUT = False
+    SUPPORTS_CUSTOM_SCALE = True
+    NAME = "tritonflashatt"
+
+    @classmethod
+    def info(cls):
+        if not has_triton_flashattention:
+            return "not built"
+        return "available"
+
+    @classmethod
+    def supports(cls, d: "AttentionOpDispatch") -> bool:
+        if not has_triton_flashattention:
+            return False
+        device_capability = torch.cuda.get_device_capability(d.device)
+        is_sm80 = device_capability[0] >= 8
+        if not is_sm80:
+            return False
+        return super(TritonFlashAttentionOp, cls).supports(d)
+
+    @classmethod
+    def forward_no_grad(
+        cls,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_bias: Optional[Union[torch.Tensor, AttentionMask]],
+        p: float,
+        scale: Optional[float] = None,
+    ) -> torch.Tensor:
+        return cls.forward(
+            ctx=None,
+            query=query,
+            key=key,
+            value=value,
+            attn_bias=attn_bias,
+            p=p,
+            scale=scale,
+        )
+
+    @classmethod
+    def forward(cls, ctx, query, key, value, attn_bias, p, scale):
+        softmax_scale = query.shape[-1] ** (-0.5) if scale is None else scale
+        causal = isinstance(attn_bias, LowerTriangularMask)
+        if not causal and attn_bias is not None and attn_bias.ndim == 3:
+            B = query.shape[0]
+            h = attn_bias.shape[0] // B
+            attn_bias = attn_bias.reshape(B, h, attn_bias.shape[1], attn_bias.shape[2])
+        bias = None if causal else attn_bias
+
+        # Make sure that the last dimension is contiguous
+        query, key, value = [
+            x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value]
+        ]
+
+        o, lse, softmax_scale = triton_flash_forward(
+            q=query,
+            k=key,
+            v=value,
+            bias=bias,
+            softmax_scale=softmax_scale,
+            causal=causal,
+        )
+
+        if ctx is not None:
+            ctx.save_for_backward(query, key, value, o, lse, bias)
+            ctx.causal = causal
+            ctx.softmax_scale = softmax_scale
+        return o
+
+    @staticmethod
+    def backward(ctx, grad):
+        q, k, v, o, lse, bias = ctx.saved_tensors
+        assert not ctx.needs_input_grad[
+            3
+        ], "FlashAttention does not support bias gradient yet"
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            triton_flash_backward(
+                grad,
+                q,
+                k,
+                v,
+                o,
+                lse,
+                dq,
+                dk,
+                dv,
+                bias=bias,
+                causal=ctx.causal,
+                softmax_scale=ctx.softmax_scale,
+            )
+        return dq, dk, dv, None, None, None
+
+
+class MemoryEfficientAttentionTritonFwdFlashBwOp(TritonFlashAttentionOp):
+    """An operator that uses :attr:`xformers.ops.TritonFlashAttentionOp` for the forward pass \
+    and :attr:`xformers.ops.MemoryEfficientAttentionFlashAttentionOp` for the backward.
+    """
+
+    FW_OP = TritonFlashAttentionOp
+    BW_OP = MemoryEfficientAttentionFlashAttentionOp
+    SUPPORTS_CUSTOM_SCALE = True
+    SUPPORTED_ATTN_BIAS_TYPES = BW_OP.SUPPORTED_ATTN_BIAS_TYPES.intersection(
+        FW_OP.SUPPORTED_ATTN_BIAS_TYPES
+    )
+    SUPPORTED_DTYPES = BW_OP.SUPPORTED_DTYPES.intersection(FW_OP.SUPPORTED_DTYPES)
+    SUPPORTED_DEVICES = BW_OP.SUPPORTED_DEVICES.intersection(FW_OP.SUPPORTED_DEVICES)
+
+    NAME = "ftriton_bflsh"
+
+    @classmethod
+    def supports(cls, d: "AttentionOpDispatch") -> bool:
+        if d.requires_grad and not cls.BW_OP.supports(d):
+            return False
+        return cls.FW_OP.supports(replace(d, requires_grad=False))
+
+    @classmethod
+    def backward(cls, ctx, grad):
+        query, key, value, out, lse, bias = ctx.saved_tensors
+        ctx_flash = SimpleNamespace()
+
+        ctx_flash.causal = ctx.causal
+        ctx_flash.dropout_p = 0.0
+        query, key, value, cu_seqlens_k, cu_seqlens_q = cls.BW_OP.prepare_inputs(
+            ctx_flash, query, key, value
+        )
+        ctx_flash.kernel_output_shape = (query.shape[0], query.shape[1], value.shape[2])
+        ctx_flash.softmax_scale = (
+            query.shape[-1] ** (-0.5)
+            if ctx.softmax_scale is None
+            else ctx.softmax_scale
+        )
+        rng_state = None
+
+        out = out.reshape(ctx_flash.kernel_output_shape)
+        grad = grad.reshape(ctx_flash.kernel_output_shape)
+        return cls.BW_OP._backward(
+            ctx_flash,
+            grad,
+            [query, key, value, out, lse, cu_seqlens_q, cu_seqlens_k, rng_state],
+        )
+
+
 class MemoryEfficientAttentionCutlassFwdFlashBwOp(MemoryEfficientAttentionCutlassOp):
     """An operator that uses :attr:`xformers.ops.MemoryEfficientAttentionCutlassOp` for the forward pass \
     and :attr:`xformers.ops.MemoryEfficientAttentionFlashAttentionOp` for the backward.
@@ -811,6 +984,10 @@ def _is_cutlass_fwd_faster_than_flash(self) -> bool:
         # Large values of K
         return max(self.k, self.kv) == 128
 
+    def _is_triton_fwd_faster_than_cutlass(self) -> bool:
+        # TODO: fill out
+        return False
+
     @property
     def op(self) -> AttentionOp:
         """Computes the best operator
@@ -823,11 +1000,16 @@ def op(self) -> AttentionOp:
         """
         priority_list_ops: List[AttentionOp] = [
             MemoryEfficientAttentionFlashAttentionOp,
+            # TODO: remove once triton_faster_than_cutlass method complete
+            MemoryEfficientAttentionTritonFwdFlashBwOp,
             MemoryEfficientAttentionCutlassOp,
+            TritonFlashAttentionOp,
             MemoryEfficientAttentionOp,
         ]
         if self.requires_grad and self._is_cutlass_fwd_faster_than_flash():
             priority_list_ops.insert(0, MemoryEfficientAttentionCutlassFwdFlashBwOp)
+        if self.requires_grad and self._is_triton_fwd_faster_than_cutlass():
+            priority_list_ops.insert(0, MemoryEfficientAttentionTritonFwdFlashBwOp)
         for op in priority_list_ops:
             if op.supports(self):
                 return op