
does vllm use Flash-Decoding? #1362

Closed · 0x1997 opened this issue Oct 16, 2023 · 5 comments

Comments


0x1997 commented Oct 16, 2023

Since vLLM depends on xformers, is vLLM already using the Flash-Decoding algorithm?

WoosukKwon (Collaborator) commented:

Hi @0x1997, yes. PagedAttention V2 (#1348) implements a similar idea to boost the performance when the batch size or the number of attention heads per GPU is small. We will announce more once all planned optimizations are merged.
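
For readers unfamiliar with Flash-Decoding: the core idea, which PagedAttention V2 shares, is to split each request's key/value sequence into partitions, reduce each partition independently (so the work can spread across more SMs even when the batch is small), and then combine the partial results with a log-sum-exp rescaling. The sketch below is a minimal, purely illustrative PyTorch version of that split-KV reduction; the shapes, names, and partition size are assumptions for the example, not vLLM's actual kernel.

```python
# Illustrative split-KV attention for decoding (one query token per head).
# Not vLLM's kernel; just the Flash-Decoding-style reduction in plain PyTorch.
import torch

def split_kv_attention(q, k, v, partition_size=512):
    """q: [num_heads, head_dim]; k, v: [seq_len, num_heads, head_dim]."""
    scale = q.shape[-1] ** -0.5
    seq_len = k.shape[0]
    maxes, sums, outs = [], [], []
    for start in range(0, seq_len, partition_size):
        k_part = k[start:start + partition_size]
        v_part = v[start:start + partition_size]
        # Attention logits for this partition: [num_heads, partition_len].
        logits = torch.einsum("hd,phd->hp", q, k_part) * scale
        m = logits.max(dim=-1, keepdim=True).values   # per-head running max
        p = torch.exp(logits - m)                      # unnormalized probabilities
        maxes.append(m)
        sums.append(p.sum(dim=-1, keepdim=True))
        outs.append(torch.einsum("hp,phd->hd", p, v_part))  # partial output
    # Reduction step: rescale every partition to a common max, then normalize.
    m_all = torch.cat(maxes, dim=-1).max(dim=-1, keepdim=True).values
    total, acc = torch.zeros_like(sums[0]), torch.zeros_like(outs[0])
    for m, s, o in zip(maxes, sums, outs):
        alpha = torch.exp(m - m_all)
        total += s * alpha
        acc += o * alpha
    return acc / total

# Sanity check against dense attention on random data.
q = torch.randn(8, 64)
k = torch.randn(1300, 8, 64)
v = torch.randn(1300, 8, 64)
ref = torch.softmax(torch.einsum("hd,phd->hp", q, k) * 64 ** -0.5, dim=-1)
ref = torch.einsum("hp,phd->hd", ref, v)
assert torch.allclose(split_kv_attention(q, k, v), ref, atol=1e-4, rtol=1e-4)
```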

leocnj commented Oct 17, 2023

It looks like PR #1348 has been merged into the 0.2.1 release. To use the V2 version, do users need to do anything when calling vLLM? Thanks

WoosukKwon (Collaborator) commented:

@leocnj Nothing is required for users. vLLM uses both V1 and V2 based on a simple heuristic:

```python
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
```

In a nutshell, currently we use V2 only when the batch size is small. Once we further optimize the performance, we will use V2 in more cases.
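
For illustration, the heuristic above can be read as a standalone predicate that depends only on the batch size, the number of attention heads per GPU, and the longest context in the batch. This is a paraphrase of the quoted expressions, not vLLM's source code, and `_PARTITION_SIZE = 512` is an assumption here (it matches the value implied by the discussion below).

```python
# Hypothetical restatement of the V1/V2 selection heuristic, for illustration only.
_PARTITION_SIZE = 512  # assumed partition size, see discussion below

def uses_v1(num_seqs: int, num_heads: int, max_context_len: int) -> bool:
    # Number of partitions needed to cover the longest context in the batch.
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_num_partitions == 1 or num_seqs * num_heads > 512

print(uses_v1(num_seqs=64, num_heads=32, max_context_len=4096))  # True  -> V1 (large batch)
print(uses_v1(num_seqs=4,  num_heads=32, max_context_len=4096))  # False -> V2 (small batch, long context)
```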

hongqing1986 commented:

> @leocnj Nothing is required for users. vLLM uses both V1 and V2 based on a simple heuristic:
>
> `use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512`
>
> In a nutshell, currently we use V2 only when the batch size is small. Once we further optimize the performance, we will use V2 in more cases.

In the condition above, the code takes the V1 path if either of the two conditions holds. You mentioned that V2 applies only when the batch size is small; however, in practice, even with a small batch size, the other condition also comes into play:

```python
max_num_partitions = (
    (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
    _PARTITION_SIZE)
```

According to this calculation, a batch whose longest sequence has a context length of 512 or less yields max_num_partitions == 1, so it also takes the V1 path. As a result, V2 only makes a difference when the batch size is small and at least one sequence has a context length greater than 512.

Is my understanding correct?

WoosukKwon (Collaborator) commented:

@hongqing1986 Yes, your analysis is correct. V2 is used when the batch size is small and at least one sequence has a context length over 512.
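
To make the 512 boundary concrete, here is a small worked example of the partition-count arithmetic discussed above, again assuming `_PARTITION_SIZE = 512` (purely illustrative, not vLLM's source):

```python
# Partition count for a few maximum context lengths (_PARTITION_SIZE assumed to be 512).
_PARTITION_SIZE = 512
for max_context_len in (256, 512, 513, 2048):
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    print(max_context_len, "->", max_num_partitions, "partition(s)")
# 256  -> 1 partition  (V1, regardless of batch size)
# 512  -> 1 partition  (V1, regardless of batch size)
# 513  -> 2 partitions (V2, provided num_seqs * num_heads <= 512)
# 2048 -> 4 partitions (V2, provided num_seqs * num_heads <= 512)
```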
