
Implement PagedAttention V2 #1348

Merged: 33 commits merged from pa-v2 into main on Oct 16, 2023
Conversation

@WoosukKwon (Collaborator) commented Oct 13, 2023

This PR implements the first part of the PagedAttention V2 kernel, which uses sequence-level parallelism for better work partitioning. Compared to V1, the V2 kernel achieves a large speedup when the batch size is small (e.g., <= 8). We will further optimize the kernel in follow-up work.
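
For intuition, here is a minimal single-query PyTorch sketch of the partition-then-reduce idea (this is an editor's illustration under simplifying assumptions, not the PR's CUDA kernel; the name partitioned_attention and the shapes are hypothetical): each partition of the key/value context produces a partial output along with its max logit and exp-sum, and a final reduction combines the partials with a numerically stable rescaling.

import torch

def partitioned_attention(q, k, v, partition_size=512):
    # q: [head_size]; k, v: [context_len, head_size]
    scale = q.shape[-1] ** -0.5
    outs, maxes, sums = [], [], []
    for start in range(0, k.shape[0], partition_size):
        k_p = k[start:start + partition_size]
        v_p = v[start:start + partition_size]
        logits = (k_p @ q) * scale        # [partition_len]
        m_p = logits.max()                # per-partition max logit
        e = torch.exp(logits - m_p)
        l_p = e.sum()                     # per-partition exp-sum
        outs.append((e @ v_p) / l_p)      # normalized partial output
        maxes.append(m_p)
        sums.append(l_p)
    # Reduction: reweight each partial output by its rescaled exp-sum.
    m = torch.stack(maxes).max()
    w = torch.stack(sums) * torch.exp(torch.stack(maxes) - m)
    w = w / w.sum()
    return (torch.stack(outs) * w[:, None]).sum(dim=0)

# Sanity check against single-pass softmax attention.
q, k, v = torch.randn(64), torch.randn(2048, 64), torch.randn(2048, 64)
ref = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
print(torch.allclose(partitioned_attention(q, k, v), ref, atol=1e-5))

In the kernel, each (sequence, head, partition) combination gets its own thread block, which is what restores parallelism when the batch is small.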

@WoosukKwon changed the title from "[WIP] Paged Attention V2" to "Implement PagedAttention V2" on Oct 15, 2023
@WoosukKwon marked this pull request as ready for review on Oct 15, 2023
@WoosukKwon requested a review from zhuohan123 on Oct 15, 2023
@WoosukKwon mentioned this pull request on Oct 14, 2023
@zhuohan123 (Member) left a comment

Thanks for the great work! In general LGTM. Left some style comments.

block_size,
input_metadata.max_context_len,
None, # alibi_slopes
)
@zhuohan123 (Member):

Should we modify the Alibi paged attention to let it use paged attention v2?

@WoosukKwon (Collaborator, Author):

Good catch! Fixed.

Additional review threads on benchmarks/kernels/benchmark_paged_attention.py, csrc/attention/attention_kernels.cu, and tests/kernels/test_attention.py were marked outdated and resolved.
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
@zhuohan123 (Member):

Why is the threshold 512? Is this number related to the number of SMs a GPU has?

@WoosukKwon (Collaborator, Author):

Yes. As we discussed offline, this is a simple heuristic to make sure the V1 kernel is used when num_seqs * num_heads is roughly larger than 4x the SM count of A100 and H100 GPUs. This could be improved by querying the GPU's actual SM count; for now, I'm leaving that as future work.
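
For illustration only, a hedged sketch of how that threshold could be tied to the device's actual SM count (this helper is hypothetical and not part of the PR; the 128-SM fallback is an assumed constant chosen to reproduce the 512 default):

import torch

def should_use_v1(num_seqs, num_heads, max_num_partitions, device=0):
    # Query the real SM count when a GPU is visible; otherwise fall back to an
    # assumed 128 SMs, which gives the PR's fixed threshold of 4 * 128 = 512.
    sm_count = (torch.cuda.get_device_properties(device).multi_processor_count
                if torch.cuda.is_available() else 128)
    # Use V1 when there is no reduction to do, or when there are already about
    # four thread blocks' worth of work per SM.
    return max_num_partitions == 1 or num_seqs * num_heads > 4 * sm_count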


#define LAUNCH_PAGED_ATTENTION_V2(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE) \
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
@Yard1 (Collaborator):

Do we not need to set cudaFuncAttributeMaxDynamicSharedMemorySize here like we do for V1?

@WoosukKwon (Collaborator, Author) commented Oct 16, 2023:

No, it's not necessary, because in V2 each thread block only handles PARTITION_SIZE (= 512) tokens. So if we actually used V2 in all cases, we could remove the shared-memory check and support (almost) arbitrary context lengths on all GPUs.
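
As a back-of-the-envelope illustration of that point (the byte counts assume float32 logits and are an editor's approximation, not figures from the kernel source): V1 keeps logits for the whole context in a thread block's shared memory, while V2 only ever keeps one partition's worth.

PARTITION_SIZE = 512       # tokens handled per thread block in V2
BYTES_PER_LOGIT = 4        # assuming float32 logits in shared memory

def v1_logits_smem(context_len):
    # V1: a single thread block covers the entire context.
    return context_len * BYTES_PER_LOGIT

def v2_logits_smem(context_len):
    # V2: each thread block covers at most one partition.
    return min(context_len, PARTITION_SIZE) * BYTES_PER_LOGIT

print(v1_logits_smem(32768))   # 131072 bytes, which can exceed the per-block limit
print(v2_logits_smem(32768))   # 2048 bytes, bounded regardless of context length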

@Yard1 (Collaborator) commented Oct 16, 2023:

Awesome, thanks for explaining! Should we then force V2 to be used when the check fails? It could be done in a follow-up PR.

@WoosukKwon (Collaborator, Author):

@Yard1 That's a good idea! Let's do it in a followup PR.

@WoosukKwon requested a review from zhuohan123 on Oct 16, 2023

@WoosukKwon (Collaborator, Author):
@zhuohan123 I addressed your comments. PTAL.

@zhuohan123 (Member) left a comment

LGTM! Thanks for the awesome work!

@WoosukKwon merged commit 928de46 into main on Oct 16, 2023
2 checks passed
@WoosukKwon deleted the pa-v2 branch on Oct 16, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024