[Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252, #3265) (#3264)

* Use scaled_dot_product_attention in WavLM attention (#3252)

Summary:
Fix #3219.

`torch.nn.MultiheadAttention` throws an error when it is called under `torch.no_grad()` and a mask is given at the same time. This pull request fixes it by replacing the `torch.nn.MultiheadAttention` call in the forward method with `torch.nn.functional.scaled_dot_product_attention`.
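For illustration, a minimal sketch (not the torchaudio code itself) of driving an existing `nn.MultiheadAttention` module through `torch.nn.functional.scaled_dot_product_attention` while reusing the module's own projection weights. The function name, dimensions, and shapes below are assumptions for the example, and it requires PyTorch >= 2.0:

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

embed_dim, num_heads = 768, 12
head_dim = embed_dim // num_heads
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

def sdpa_self_attention(x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    bsz, seq_len, _ = x.shape
    # One fused projection yields q, k, v, just as MultiheadAttention does internally.
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    # Split heads: (batch, num_heads, seq_len, head_dim).
    q, k, v = (t.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
    # Merge heads back and apply the module's output projection.
    out = out.transpose(1, 2).reshape(bsz, seq_len, embed_dim)
    return mha.out_proj(out)

x = torch.randn(2, 100, embed_dim)
with torch.no_grad():  # works under no_grad, with or without an additive float mask
    y = sdpa_self_attention(x)
```

This mirrors the strategy the diff below applies inside the WavLM attention forward method.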

Pull Request resolved: #3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665

* Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265)

Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` into a single mask before being passed to the `scaled_dot_product_attention` function.
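For illustration, a hedged sketch of merging a boolean padding mask into an additive float mask (the commit itself uses the private `torch.nn.functional._canonical_mask` helper); all tensor names and shapes here are assumptions for the example:

```python
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 12, 100, 64
# Float bias from relative positions, already expanded per head.
attn_mask_rel_pos = torch.randn(bsz, num_heads, seq_len, seq_len)
# Boolean padding mask: True marks padded key positions.
key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
key_padding_mask[1, 90:] = True

# Convert the boolean mask to an additive float mask, broadcast it over heads
# and query positions, then add it to the relative-position bias.
float_padding = torch.zeros(bsz, 1, 1, seq_len).masked_fill(
    key_padding_mask[:, None, None, :], float("-inf")
)
merged_mask = attn_mask_rel_pos + float_padding  # (bsz, num_heads, seq_len, seq_len)

q = k = v = torch.randn(bsz, num_heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=merged_mask)
```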

Pull Request resolved: #3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
nateanl authored Apr 12, 2023
1 parent e99de15 commit 54f6c1f
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions torchaudio/models/wav2vec2/wavlm_attention.py
@@ -73,6 +73,7 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
+        self.dropout = dropout
         self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
 
         self.gru_rel_pos = gru_rel_pos
@@ -165,7 +166,7 @@ def forward(
 
         if self.rel_attn_embed is not None and position_bias is None:
             position_bias = self.compute_bias(seq_len, seq_len)
-            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, seq_len, seq_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)
 
         attn_mask_rel_pos: Optional[Tensor] = None
         if position_bias is not None:
@@ -178,11 +179,36 @@ def forward(
                     self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                 ).chunk(2, dim=-1)
                 gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
-                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
-
-            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
-
-        attn_output, _ = self.attention(
-            query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False
+                attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias
+
+            attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
+
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+            key_padding_mask = torch.nn.functional._canonical_mask(
+                mask=key_padding_mask,
+                mask_name="key_padding_mask",
+                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
+                other_name="",
+                target_type=query.dtype,
+            )
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
+        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
+        query, key, value = query_projected.chunk(3, -1)
+        shape = (bsz, seq_len, self.num_heads, self.head_dim)
+        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask_rel_pos,
+            dropout_p=dropout,
+            is_causal=False,
         )
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
+        attn_output = self.attention.out_proj(attn_output)
         return attn_output, position_bias
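For context, a hedged usage sketch of the scenario this commit addresses: running WavLM feature extraction under `torch.no_grad()` with padded inputs, so that a key padding mask is in play (the failure mode reported in #3219). The factory and method names (`torchaudio.models.wavlm_base`, `extract_features`) are the public torchaudio API; sample rate and lengths are illustrative:

```python
import torch
import torchaudio

model = torchaudio.models.wavlm_base().eval()
waveforms = torch.randn(2, 16000)       # two 1-second clips at 16 kHz
lengths = torch.tensor([16000, 12000])  # second clip is padded, so a padding mask is built

# Before this fix, combining no_grad with a mask could raise inside nn.MultiheadAttention.
with torch.no_grad():
    features, out_lengths = model.extract_features(waveforms, lengths)
```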
