diff --git a/xformers/ops/memory_efficient_attention.py b/xformers/ops/memory_efficient_attention.py
index 6694c89de4..5476e858cd 100644
--- a/xformers/ops/memory_efficient_attention.py
+++ b/xformers/ops/memory_efficient_attention.py
@@ -270,6 +270,8 @@ def forward_no_grad(
         attn_bias: Optional[Union[torch.Tensor, AttentionMask]],
         p: float,
     ) -> torch.Tensor:
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         return cls.FORWARD_OPERATOR(
             query=query,
             key=key,
@@ -283,6 +285,8 @@ def forward_no_grad(

     @classmethod
     def forward(cls, ctx, query, key, value, attn_bias, p):
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         causal = isinstance(attn_bias, LowerTriangularMask)
         out, lse = cls.FORWARD_OPERATOR(
             query=query,
@@ -438,6 +442,8 @@ def prepare_inputs(

     @classmethod
     def forward(cls, ctx, query, key, value, attn_bias, p):
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         causal = isinstance(attn_bias, LowerTriangularMask)
         return_softmax = False
         ctx_flash = ctx if ctx is not None else SimpleNamespace()
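
In short, each guarded entry point now fails fast when `attn_bias` is anything other than `None` or a `LowerTriangularMask`, instead of silently ignoring an unsupported bias. Below is a minimal sketch of that guard in isolation; the `xformers.ops` import path and the standalone helper `_validate_attn_bias` are assumptions for illustration and are not part of this diff.

```python
# Hedged sketch of the validation added in this patch.
# `LowerTriangularMask` and the error message come from the diff itself;
# the helper name and import path below are hypothetical.
from typing import Optional, Union

import torch
from xformers.ops import LowerTriangularMask  # assumed import path


def _validate_attn_bias(
    attn_bias: Optional[Union[torch.Tensor, LowerTriangularMask]]
) -> None:
    # Mirrors the check inserted into forward_no_grad/forward: only `None`
    # or a causal lower-triangular mask is accepted; anything else raises
    # before the kernel is launched.
    if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
        raise NotImplementedError("Unsupported attn_bias type")


_validate_attn_bias(None)                   # ok: no bias
_validate_attn_bias(LowerTriangularMask())  # ok: handled as causal=True
try:
    _validate_attn_bias(torch.zeros(16, 16))  # dense tensor bias: rejected
except NotImplementedError as err:
    print(err)  # Unsupported attn_bias type
```

Duplicating the check in every forward path keeps the failure at the op boundary, so an unsupported bias surfaces as an explicit error rather than a silently incorrect attention result.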