MemEff: Raise if wrong bias
ghstack-source-id: 657076325d46712005afa38e8cce40c6050723a3
Pull Request resolved: #510
danthe3rd committed Nov 10, 2022
1 parent 034464a commit 3cae847
Showing 1 changed file with 6 additions and 0 deletions.
xformers/ops/memory_efficient_attention.py (6 additions, 0 deletions)
@@ -270,6 +270,8 @@ def forward_no_grad(
         attn_bias: Optional[Union[torch.Tensor, AttentionMask]],
         p: float,
     ) -> torch.Tensor:
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         return cls.FORWARD_OPERATOR(
             query=query,
             key=key,
@@ -283,6 +285,8 @@ def forward_no_grad(
 
     @classmethod
     def forward(cls, ctx, query, key, value, attn_bias, p):
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         causal = isinstance(attn_bias, LowerTriangularMask)
         out, lse = cls.FORWARD_OPERATOR(
             query=query,
@@ -438,6 +442,8 @@ def prepare_inputs(
 
     @classmethod
     def forward(cls, ctx, query, key, value, attn_bias, p):
+        if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
+            raise NotImplementedError("Unsupported attn_bias type")
         causal = isinstance(attn_bias, LowerTriangularMask)
         return_softmax = False
         ctx_flash = ctx if ctx is not None else SimpleNamespace()
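For context, the effect of the new guard on callers: any attention bias that is neither None nor a LowerTriangularMask now fails fast with NotImplementedError before the underlying kernel is invoked, rather than reaching it with a bias type these operators do not support. A minimal standalone sketch of the check's behaviour follows; the LowerTriangularMask class and the check_attn_bias helper below are illustrative stand-ins for this page only, not the actual xformers API.

from typing import Optional, Union

import torch


class LowerTriangularMask:
    # Placeholder standing in for xformers' LowerTriangularMask,
    # the causal-attention bias type that the operators do accept.
    pass


def check_attn_bias(attn_bias: Optional[Union[torch.Tensor, LowerTriangularMask]]) -> bool:
    # Same guard the commit adds to forward_no_grad() and both forward()
    # classmethods: anything other than None or LowerTriangularMask is
    # rejected before the kernel is called.
    if attn_bias is not None and not isinstance(attn_bias, LowerTriangularMask):
        raise NotImplementedError("Unsupported attn_bias type")
    # A LowerTriangularMask is translated into causal attention.
    return isinstance(attn_bias, LowerTriangularMask)


print(check_attn_bias(None))                   # False: no bias, non-causal
print(check_attn_bias(LowerTriangularMask()))  # True: causal attention
try:
    check_attn_bias(torch.zeros(1, 8, 8))      # dense tensor bias is rejected
except NotImplementedError as exc:
    print(exc)                                 # "Unsupported attn_bias type"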
