Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use scaled_dot_product_attention in WavLM attention #3252

Closed
wants to merge 1 commit into from

Conversation

nateanl
Copy link
Member

@nateanl nateanl commented Apr 7, 2023

Fix #3219.

torch.nn.MultiheadAttention will throw an error if torch.no_grad() and mask are both given. The pull request fixes it by replacing the forward method with torch.nn.functional.scaled_dot_product_attention.

@nateanl nateanl changed the title Use scaled_dot_product_attention in WavLM attention [Cherry-pick] Use scaled_dot_product_attention in WavLM attention Apr 7, 2023
@nateanl nateanl mentioned this pull request Apr 7, 2023
@nateanl
Copy link
Member Author

nateanl commented Apr 7, 2023

Here are the benchmark results with new changes. The benchmark script can be found in https://gist.github.com/nateanl/97b2f9adb39c05a4e854fbd924de01f6.

MultiheadAttention
[--------------- WavLM benchmark ---------------]
                              |    CPU    |  CUDA
1 threads: --------------------------------------
      wavlm_base 5 seconds    |   1561.1  |  14.2
      wavlm_base 10 seconds   |   3325.5  |  22.1
      wavlm_base 15 seconds   |   5423.0  |  34.4
      wavlm_base 20 seconds   |   7722.4  |  50.7
      wavlm_large 5 seconds   |   4451.5  |  25.2
      wavlm_large 10 seconds  |   9465.5  |  48.3
      wavlm_large 15 seconds  |  15140.8  |  72.6
      wavlm_large 20 seconds  |  22030.9  |  97.9
Times are in milliseconds (ms).


scaled_dot_product_attention
[--------------- WavLM benchmark ---------------]
                              |    CPU    |  CUDA
1 threads: --------------------------------------
      wavlm_base 5 seconds    |   1562.8  |  13.2
      wavlm_base 10 seconds   |   3164.2  |  21.1
      wavlm_base 15 seconds   |   5175.0  |  33.1
      wavlm_base 20 seconds   |   7357.2  |  48.9
      wavlm_large 5 seconds   |   4415.9  |  23.3
      wavlm_large 10 seconds  |   9279.8  |  47.8
      wavlm_large 15 seconds  |  14930.8  |  71.1
      wavlm_large 20 seconds  |  21524.6  |  97.3
Times are in milliseconds (ms).

@nateanl nateanl requested a review from a team April 7, 2023 19:06
@nateanl nateanl changed the title [Cherry-pick] Use scaled_dot_product_attention in WavLM attention Use scaled_dot_product_attention in WavLM attention Apr 7, 2023
@facebook-github-bot
Copy link
Contributor

@nateanl has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@nateanl
Copy link
Member Author

nateanl commented Apr 8, 2023

Here is the script which shows the new scaled_dot_product_attention achieves identical outputs as original implementation.

@facebook-github-bot
Copy link
Contributor

@nateanl merged this pull request in adb0338.

@github-actions
Copy link

Hey @nateanl.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

nateanl added a commit to nateanl/audio that referenced this pull request Apr 11, 2023
Summary:
Fix pytorch#3219.

`torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.

Pull Request resolved: pytorch#3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
nateanl added a commit to nateanl/audio that referenced this pull request Apr 12, 2023
Summary:
Fix pytorch#3219.

`torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.

Pull Request resolved: pytorch#3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
nateanl added a commit that referenced this pull request Apr 12, 2023
, #3265) (#3264)

* Use scaled_dot_product_attention in WavLM attention (#3252)

Summary:
Fix #3219.

`torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.

Pull Request resolved: #3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665

* Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265)

Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` as one mask for `scaled_dot_product_attention` function.

Pull Request resolved: #3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torchaudio.pipelines.WAVLM encounted error: Mask shape should match input
3 participants