
torchaudio.pipelines.WAVLM encountered error: Mask shape should match input #3219

Closed
anonymoussky opened this issue Mar 30, 2023 · 1 comment

anonymoussky commented Mar 30, 2023

🐛 Describe the bug

Bug: I tried to use "torchaudio.pipelines.WAVLM_BASE_PLUS", but it fails with "RuntimeError: Mask shape should match input. mask: [384, 199, 199] input: [32, 12, 199, 199]". The error only occurs under "with torch.no_grad():"; it seems no_grad enables the fast path of multi-head attention, which then crashes.

Sample Code:

import torchaudio
import torch
bundle = torchaudio.pipelines.WAVLM_BASE_PLUS
ssl_model = bundle.get_model()
ssl_model = ssl_model.cuda()
with torch.no_grad():  # without this, it works fine
    waveform = torch.rand(32, 4*16000)
    print(waveform.size())
    features, _ = ssl_model.extract_features(waveform.cuda())
    print(len(features))
    print(features[0].size()) 

Error:

Traceback (most recent call last):
  File "/data1/test_wavlm.py", line 9, in <module>
    features, _ = ssl_model.extract_features(waveform.cuda())
  File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/model.py", line 84, in extract_features
    x = self.encoder.extract_features(x, lengths, num_layers)
  File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 525, in extract_features
    return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
  File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 474, in get_intermediate_outputs
    x, _ = layer(x, attention_mask)  # Ignore position_bias
  File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 405, in forward
    x, position_bias = self.attention(
  File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py", line 185, in forward
    attn_output, _ = self.attention(
  File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1144, in forward
    return torch._native_multi_head_attention(
RuntimeError: Mask shape should match input. mask: [384, 199, 199] input: [32, 12, 199, 199]

Versions

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e650d3708b
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchvision==0.15.1+cu118
[pip3] triton==2.0.0

nateanl (Member) commented Mar 30, 2023

The issue could be related to torch.nn.MultiheadAttention in PyTorch core, according to pytorch/pytorch#97409
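
For reference, the same code path can be exercised directly through torch.nn.MultiheadAttention. The snippet below is a minimal sketch, assuming torch==2.0.0 and that the module's fast-path conditions (eval mode, no_grad, batch_first, even head count) are met; the shapes simply mirror the error message above, and this is not a confirmed upstream repro:

import torch

# batch 32, 12 heads, 199 frames, 768-dim embeddings, as in the WavLM error
mha = torch.nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True).eval()
x = torch.rand(32, 199, 768)
# WavLM hands nn.MultiheadAttention a 3D float mask of shape (batch * num_heads, seq, seq)
attn_mask = torch.zeros(32 * 12, 199, 199)

with torch.no_grad():  # eval + no_grad lets forward dispatch to torch._native_multi_head_attention
    out, _ = mha(x, x, x, attn_mask=attn_mask, need_weights=False)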

nateanl added a commit to nateanl/audio that referenced this issue Apr 11, 2023
Summary:
Fix pytorch#3219.

`torch.nn.MultiheadAttention` throws an error when `torch.no_grad()` and an attention mask are both given. The pull request fixes it by replacing the `torch.nn.MultiheadAttention` forward call with `torch.nn.functional.scaled_dot_product_attention`.

Pull Request resolved: pytorch#3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
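
For context, here is a minimal sketch of how torch.nn.functional.scaled_dot_product_attention accepts a broadcastable additive mask under torch.no_grad(); the shapes are illustrative and this is not the exact code from the pull request:

import torch
import torch.nn.functional as F

B, H, T, D = 32, 12, 199, 64  # 12 heads x 64 head dim = 768-dim model
q = torch.rand(B, H, T, D)
k = torch.rand(B, H, T, D)
v = torch.rand(B, H, T, D)
# additive float mask broadcastable to (B, H, T, T), e.g. WavLM's relative position bias
attn_mask = torch.zeros(B, H, T, T)

with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([32, 12, 199, 64])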
nateanl added a commit that referenced this issue Apr 12, 2023
Use scaled_dot_product_attention in WavLM attention (#3252, #3265) (#3264)

* Use scaled_dot_product_attention in WavLM attention (#3252)

Summary:
Fix #3219.

`torch.nn.MultiheadAttention` throws an error when `torch.no_grad()` and an attention mask are both given. The pull request fixes it by replacing the `torch.nn.MultiheadAttention` forward call with `torch.nn.functional.scaled_dot_product_attention`.

Pull Request resolved: #3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665

* Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265)

Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` into one mask for the `scaled_dot_product_attention` function.

Pull Request resolved: #3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
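
Roughly, the idea behind #3265 is to fold the boolean key_padding_mask into the additive attention mask before calling scaled_dot_product_attention. The sketch below is a hedged illustration with made-up shapes, not the exact torchaudio implementation:

import torch

B, H, T = 32, 12, 199
# additive bias (e.g. relative position bias), shape (B, H, T, T)
attn_mask_rel_pos = torch.zeros(B, H, T, T)
# True marks padded frames, shape (B, T)
key_padding_mask = torch.zeros(B, T, dtype=torch.bool)
key_padding_mask[:, 150:] = True  # pretend the last frames are padding

# broadcast the padding mask over heads and query positions, then mask keys with -inf
merged = attn_mask_rel_pos.masked_fill(key_padding_mask.view(B, 1, 1, T), float("-inf"))
# `merged` can now be passed as attn_mask to F.scaled_dot_product_attention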