
Avoid rearranging all caches #1483

Merged
merged 5 commits into openai:main on Jul 6, 2023

Conversation


@wangchou wangchou commented Jun 30, 2023

Since the kv_caches from the cross-attention blocks are the same for every beam, we can avoid rearranging them and recalculating them multiple times.
I saw about a 20% speed-up for the large model with beam_size=5 on the CPU backend.
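To illustrate the idea (a minimal sketch with hypothetical names, not the actual decoding.py change): during beam search only the self-attention caches depend on the tokens each beam has generated, so only those need to be reordered when beams are reshuffled; the cross-attention keys/values come from the fixed audio features and are identical for every beam.

# Sketch only: a hypothetical rearrange step during beam search.
def rearrange_caches(self_attn_cache, cross_attn_cache, source_indices):
    # Self-attention caches follow each beam's token history, so they must
    # be reordered to match the surviving beams.
    for module, tensor in self_attn_cache.items():
        self_attn_cache[module] = tensor[source_indices].detach()
    # Cross-attention caches depend only on the audio features and are the
    # same for every beam, so they are left untouched (and computed once).
    return self_attn_cache, cross_attn_cache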

@jongwook jongwook merged commit b91c907 into openai:main Jul 6, 2023
@wangchou wangchou deleted the avoid-rearranging-all-caches branch July 7, 2023 02:09
abyesilyurt pushed a commit to abyesilyurt/whisper that referenced this pull request Nov 13, 2023
* avoid rearranging all kv_caches

* avoid calculating the same kv_cache from cross attn

* Update decoding.py

* linter fix

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
@@ -721,8 +725,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
)
]

# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
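For reference, repeat_interleave along dim 0 duplicates each batch item n_group times, so every audio segment gets one copy per beam (a toy example, not whisper code):

import torch

# Two audio segments, n_group = 3 beams: rows become [a, a, a, b, b, b].
x = torch.tensor([[1.0], [2.0]])
print(x.repeat_interleave(3, dim=0).shape)  # torch.Size([6, 1])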

@wangchou I was wondering: if you remove the repeat of audio_features, where do you repeat the kv_cache for cross attention? Otherwise, during cross-attention, q @ k seems to have mismatched dims, since the tokens are repeated according to the beam_size.


@jongwook, would you mind checking this please? Thanks.


@wangchou wangchou Jan 15, 2024


@yuekaizhang The @ operator (matmul) should support broadcasting?

@yuekaizhang yuekaizhang Jan 15, 2024

import torch
q = torch.ones(70, 4, 16, 4)   # batch_size * beam_size = 7 * 10 = 70
k = torch.ones(7, 4, 4, 16)    # cross-attention k with batch_size = 7
k2 = torch.ones(70, 4, 4, 16)  # same k repeated to match the 70 queries

context2 = q @ k2              # matching batch dims: works
print(context2.shape)

context1 = q @ k               # 70 vs 7: cannot broadcast
print(context1.shape)

I ran with torch==2.0.1

RuntimeError: The size of tensor a (70) must match the size of tensor b (7) at non-singleton dimension 0
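
For context, batched matmul only broadcasts leading batch dimensions that match or are 1, which is why 7 against 70 fails (a small check assuming the same shapes as above):

import torch

q = torch.ones(70, 4, 16, 4)
k1 = torch.ones(1, 4, 4, 16)   # a leading 1 broadcasts against 70
print((q @ k1).shape)          # torch.Size([70, 4, 16, 16])
# k = torch.ones(7, 4, 4, 16)  # 7 is neither 70 nor 1 -> RuntimeError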


@wangchou wangchou Jan 15, 2024


@yuekaizhang I ran whisper with beam_size=5. It works.

python -m whisper ../samples/thatBand2ch_short.wav --language ja --model small --beam_size=5

After adding a print in qkv_attention() like

        qk = q @ k
        print("q.shape=",q.shape,", k.shape=", k.shape)

it outputs

...
q.shape= torch.Size([5, 12, 1, 64]) , k.shape= torch.Size([5, 12, 64, 6])
q.shape= torch.Size([5, 12, 1, 64]) , k.shape= torch.Size([1, 12, 64, 1500])

What arguments did you use to get k like (7, 4, 4, 16)? Where does that 7 come from?

PS: I only tested this on the Mac CPU backend. I guess that 7 comes from GPU-related code?
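
Those shapes suggest why the single-segment case works: with one audio segment the cross-attention key keeps a leading batch dim of 1, which broadcasts across the 5 beams, whereas batch_size=7 with beam_size=10 gives 70 queries against 7 keys, which can neither match nor broadcast. A small reproduction under that reading:

import torch

# batch_size = 1: the cross-attention key broadcasts across the 5 beams,
# matching the shapes printed above.
q = torch.ones(5, 12, 1, 64)      # (beams, heads, tokens, head_dim)
k = torch.ones(1, 12, 64, 1500)   # (batch=1, heads, head_dim, audio frames)
print((q @ k).shape)              # torch.Size([5, 12, 1, 1500])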


@yuekaizhang yuekaizhang Jan 15, 2024


@wangchou Did you try inference with batch_size > 1? I hit this issue when both batch_size and beam_size were > 1. The snippet above uses batch_size 7 and beam_size 10.

wangchou (Contributor Author)

@yuekaizhang I'm not aware of a batch_size option, and I can't find one with whisper --help. Sorry.

zuazo added a commit to zuazo-forks/whisper that referenced this pull request Jun 1, 2024
* It ensures that audio features are correctly duplicated across beams for each batch item.
* Added a test for `decode()` that includes a regression test for this.
* This issue was introduced in PR openai#1483.
zuazo added a commit to zuazo-forks/whisper that referenced this pull request Jun 1, 2024
* It ensures that audio features are correctly duplicated across beams for each batch item.
* Added a test for `decode()` that includes a regression test for this.
* Update *.github/workflows/test.yml* to run the new test for `decode()` in tiny.
* This issue was introduced in PR openai#1483.
zuazo added a commit to zuazo-forks/whisper that referenced this pull request Aug 31, 2024
* It ensures that audio features are correctly duplicated across beams for each batch item.
* Added a test for `decode()` that includes a regression test for this.
* Update *.github/workflows/test.yml* to run the new test for `decode()` in tiny.
* This issue was introduced in PR openai#1483.