[feat] Add causal masking option to Nystrom. (facebookresearch#85)
* add causal masking option
* minor, caching the causal masks and moving them to device
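
A minimal usage sketch of the new option, hedged: it assumes the attention class in nystrom.py is exposed as NystromAttention, that the constructor accepts the keyword arguments visible in the diff below (any other required arguments are omitted here), and that inputs are shaped (batch * heads, seq_len, head_dim) as read by forward(); the sizes are hypothetical.

import torch
from xformers.components.attention.nystrom import NystromAttention  # assumed class name and import path

# seq_len must be divisible by num_landmarks (see the assertion in forward()).
attn = NystromAttention(
    dropout=0.0,
    num_heads=4,
    num_landmarks=64,
    causal=True,  # the option added by this commit
)

q = k = v = torch.randn(8, 256, 32)  # (batch * heads, seq_len, head_dim)
out = attn(q, k, v)                  # causal=True: queries cannot attend to future positions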

Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@gmail.com>
dianaml0 and blefaudeux authored May 3, 2021
1 parent a4e2355 commit 84d8fcd
Showing 2 changed files with 36 additions and 6 deletions.
.isort.cfg (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
[settings]
known_third_party =matplotlib,pandas,pytest,seaborn,setuptools,sklearn,torch,tqdm
known_third_party =matplotlib,numpy,pandas,pytest,seaborn,setuptools,sklearn,torch,tqdm
xformers/components/attention/nystrom.py (40 changes: 35 additions & 5 deletions)
@@ -18,6 +18,7 @@ class NystromSelfAttentionConfig(AttentionConfig):
num_heads Number of heads.
num_landmarks Number of landmarks to use for softmax approximation. 64 is often sufficient for a good
approximation according to https://arxiv.org/pdf/2102.03902.pdf.
causal Apply a causal mask so that attention cannot attend to future positions.
use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
inverse, otherwise use standard torch inverse.
pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using
@@ -35,6 +36,7 @@ class NystromSelfAttentionConfig(AttentionConfig):

num_heads: int
num_landmarks: Optional[int]
causal: Optional[bool]
pinverse_original_init: Optional[bool]
inv_iterations: Optional[int]
v_skip_connection: Optional[nn.Module]
@@ -50,6 +52,7 @@ def __init__(
dropout: float,
num_heads: int,
num_landmarks: int = 64,
causal: bool = False,
use_razavi_pinverse: bool = True,
pinverse_original_init: bool = False,
inv_iterations: int = 6, # recommended default in paper was 6.
@@ -77,6 +80,7 @@ def __init__(
self.inv_iterations = inv_iterations
self.attn_drop = nn.Dropout(dropout)
self.skip_connection = v_skip_connection
self.causal = causal

if self.skip_connection is None and conv_kernel_size is not None:
self.skip_connection = nn.Conv2d(
@@ -88,16 +92,21 @@ def __init__(
groups=self.num_heads,
)

# Optional lower triangular masks for causal attention
self.causal_mask_1: Optional[torch.Tensor] = None
self.causal_mask_2: Optional[torch.Tensor] = None
self.causal_mask_3: Optional[torch.Tensor] = None

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
):

batched_dim = k.size(0)
head_dim = k.size(-1)
seq_len = k.size(-2)

@@ -106,7 +115,10 @@ def forward(
), "the sequence length needs to be divisible by the number of landmarks"

if self.num_landmarks == seq_len:
x = scaled_dot_product_attention(q, k, v, att_mask)
mask = None
if self.causal:
mask = self._tril_mask(batched_dim, seq_len, seq_len)
x = scaled_dot_product_attention(q, k, v, mask)

else:
q_landmarks = q.reshape(
@@ -122,9 +134,24 @@ def forward(
head_dim,
).mean(dim=-2)

kernel_1 = scaled_query_key_softmax(q, k_landmarks, None)
kernel_2 = scaled_query_key_softmax(q_landmarks, k_landmarks, None)
kernel_3 = scaled_dot_product_attention(q_landmarks, k, v, None)
if self.causal and self.causal_mask_1 is None:
self.causal_mask_1 = self._tril_mask(
batched_dim, seq_len, self.num_landmarks
).to(q.device)
self.causal_mask_2 = self._tril_mask(
batched_dim, self.num_landmarks, self.num_landmarks
).to(q.device)
self.causal_mask_3 = self._tril_mask(
batched_dim, self.num_landmarks, seq_len
).to(q.device)

kernel_1 = scaled_query_key_softmax(q, k_landmarks, self.causal_mask_1)
kernel_2 = scaled_query_key_softmax(
q_landmarks, k_landmarks, self.causal_mask_2
)
kernel_3 = scaled_dot_product_attention(
q_landmarks, k, v, self.causal_mask_3
)

kernel_2_inv = (
iterative_pinv(
@@ -151,6 +178,9 @@ def forward(
x = self.attn_drop(x)
return x

def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int):
return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)

@classmethod
def from_config(cls, config: AttentionConfig) -> "Attention":
return cls(**NystromSelfAttentionConfig.as_patchy_dict(config))
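
For reference, a standalone sketch (torch only; names and sizes are illustrative, not part of the library API) of the three lower-triangular masks that the commit builds with _tril_mask and caches on the first causal forward pass. Their shapes mirror the kernel_1, kernel_2 and kernel_3 products above: queries against key landmarks, query landmarks against key landmarks, and query landmarks against the full keys/values.

import torch

batched_dim, seq_len, num_landmarks = 2, 8, 4  # illustrative sizes only

def tril_mask(dim_1: int, dim_2: int, dim_3: int) -> torch.Tensor:
    # Same construction as the new _tril_mask helper: a boolean mask whose
    # lower-triangular (non-future) region is True.
    return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)

causal_mask_1 = tril_mask(batched_dim, seq_len, num_landmarks)        # for kernel_1
causal_mask_2 = tril_mask(batched_dim, num_landmarks, num_landmarks)  # for kernel_2
causal_mask_3 = tril_mask(batched_dim, num_landmarks, seq_len)        # for kernel_3

print(causal_mask_1.shape, causal_mask_2.shape, causal_mask_3.shape)
# torch.Size([2, 8, 4]) torch.Size([2, 4, 4]) torch.Size([2, 4, 8])

In the diff above these masks are created once, moved to q.device, and reused on subsequent calls, which is the caching referred to in the commit message.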
