
Implement PagedAttention V2 #1348

Merged 33 commits on Oct 16, 2023
Commits (33):
7d057f9  PagedAttention V1 (WoosukKwon, Oct 12, 2023)
2cc7bff  Mid (WoosukKwon, Oct 12, 2023)
8946093  PagedAttention V1 (WoosukKwon, Oct 12, 2023)
f5b05fc  Undef DIVIDE_ROUND_UP (WoosukKwon, Oct 12, 2023)
235f273  Add empty PagedAttention V2 (WoosukKwon, Oct 12, 2023)
472ee66  Minor (WoosukKwon, Oct 12, 2023)
3827e24  Minor (WoosukKwon, Oct 12, 2023)
2605c6e  Implement PagedAttention V2 (WoosukKwon, Oct 12, 2023)
877a3f5  Add comment (WoosukKwon, Oct 12, 2023)
634f961  Fix performance bug (WoosukKwon, Oct 12, 2023)
7585101  Fix attention test (WoosukKwon, Oct 13, 2023)
3ea3891  Add heuristic (WoosukKwon, Oct 13, 2023)
ab89848  Minor optimization (WoosukKwon, Oct 13, 2023)
d83ce92  Add benchmark (WoosukKwon, Oct 13, 2023)
760e7a2  Minor (WoosukKwon, Oct 13, 2023)
e6d8a15  yapf (WoosukKwon, Oct 15, 2023)
4313691  Minor fix on comments (WoosukKwon, Oct 15, 2023)
c0021c1  Add comment on heuristic (WoosukKwon, Oct 15, 2023)
8ddb426  Fix test_attention (WoosukKwon, Oct 15, 2023)
ae14bba  Merge branch 'main' into pa-v2 (WoosukKwon, Oct 15, 2023)
08e92c3  yapf (WoosukKwon, Oct 15, 2023)
dac5e24  Minor (WoosukKwon, Oct 15, 2023)
d674616  Minor (WoosukKwon, Oct 15, 2023)
612236b  Reimplement (WoosukKwon, Oct 15, 2023)
3d2eff1  Rename (WoosukKwon, Oct 15, 2023)
57b3071  Minor (WoosukKwon, Oct 15, 2023)
cb3af6d  yapf (WoosukKwon, Oct 15, 2023)
000abdf  Remove unnecessary fns (WoosukKwon, Oct 15, 2023)
f80f49f  Address comments (WoosukKwon, Oct 16, 2023)
5b0a536  Minor fix (WoosukKwon, Oct 16, 2023)
f3c8cb0  Support attention with ALiBi (WoosukKwon, Oct 16, 2023)
bfa8569  yapf (WoosukKwon, Oct 16, 2023)
9451b2d  yapf (WoosukKwon, Oct 16, 2023)
197 changes: 197 additions & 0 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -0,0 +1,197 @@
import argparse
import random
import time

import torch

from vllm import attention_ops

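# Number of physical KV-cache blocks allocated for the benchmark, and the
# per-partition token count used when sizing the V2 kernel's workspace below.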
NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
version: str,
num_seqs: int,
context_len: int,
num_query_heads: int,
num_kv_heads: int,
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
seed: int,
do_profile: bool,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")

context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

# Create the KV cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
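    # x elements of the head dimension share one 16-byte chunk, matching the
    # kernel's vectorized loads; hence the [..., head_size // x, block_size, x]
    # key-cache layout below.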
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
key_cache.uniform_(-scale, scale)
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device="cuda")
value_cache.uniform_(-scale, scale)

# Prepare for the paged attention kernel.
output = torch.empty_like(query)
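    # For V2, each sequence's context is split into partitions of
    # PARTITION_SIZE tokens. Every partition writes a partial output
    # (tmp_output) plus its softmax statistics (exp_sums, max_logits),
    # which a final reduction step combines into `output`.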
if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_query_heads, num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)

def run_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()

for _ in range(num_iters):
if version == "v1":
attention_ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
elif version == "v2":
attention_ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()

end_time = time.perf_counter()
if profile:
            torch.cuda.cudart().cudaProfilerStop()
return (end_time - start_time) / num_iters

# Warmup.
print("Warming up...")
run_benchmark(num_iters=3, profile=False)

# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=100, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
print(args)

if args.num_query_heads % args.num_kv_heads != 0:
raise ValueError("num_query_heads must be divisible by num_kv_heads")
dtype_to_torch_dtype = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
main(
version=args.version,
num_seqs=args.batch_size,
context_len=args.context_len,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_size=args.head_size,
block_size=args.block_size,
use_alibi=args.use_alibi,
dtype=dtype_to_torch_dtype[args.dtype],
seed=args.seed,
do_profile=args.profile,
)
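For intuition about why V2 needs the exp_sums, max_logits, and tmp_output buffers allocated above, the following is a minimal PyTorch sketch of how per-partition results can be merged. It assumes tmp_output holds each partition's softmax-normalized partial output; it illustrates only the math of the reduction step, not the kernel's actual implementation, and the function name is made up for this example.

```python
import torch

def combine_partitions(tmp_output: torch.Tensor,
                       exp_sums: torch.Tensor,
                       max_logits: torch.Tensor) -> torch.Tensor:
    # tmp_output:  [num_seqs, num_heads, num_partitions, head_size]
    # exp_sums:    [num_seqs, num_heads, num_partitions]
    # max_logits:  [num_seqs, num_heads, num_partitions]
    # Global max over partitions, for a numerically stable merge.
    global_max = max_logits.max(dim=-1, keepdim=True).values
    # Rescale each partition's exp-sum to the shared global max.
    rescaled = exp_sums * torch.exp(max_logits - global_max)
    # Each partition's weight in the final convex combination.
    weights = rescaled / rescaled.sum(dim=-1, keepdim=True)
    # Weighted sum of the per-partition partial outputs.
    return (tmp_output * weights.unsqueeze(-1)).sum(dim=2)
```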
28 changes: 24 additions & 4 deletions csrc/attention.cpp
@@ -1,7 +1,7 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);

+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
m.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
}
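For Python callers, the practical effect of this binding change is that the old single_query_cached_kv_attention op is replaced by paged_attention_v1 and joined by paged_attention_v2. Below is a hedged sketch of how a caller might choose between the two ops; the single-partition shortcut, the 512-token partition size, and the wrapper name are assumptions for illustration that mirror the benchmark above, not vLLM's actual dispatch heuristic.

```python
import torch
from vllm import attention_ops

_PARTITION_SIZE = 512  # assumed to match the kernel's partition size

def paged_attention(output, query, key_cache, value_cache, head_mapping,
                    scale, block_tables, context_lens, block_size,
                    max_context_len, alibi_slopes=None):
    num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    if num_partitions == 1:
        # Short contexts fit in one partition, so V1 can skip the reduction pass.
        attention_ops.paged_attention_v1(
            output, query, key_cache, value_cache, head_mapping, scale,
            block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)
        return
    num_seqs, num_heads, head_size = query.shape
    # V2 needs workspace for per-partition outputs and softmax statistics.
    tmp_output = torch.empty((num_seqs, num_heads, num_partitions, head_size),
                             dtype=output.dtype, device=output.device)
    exp_sums = torch.empty((num_seqs, num_heads, num_partitions),
                           dtype=torch.float32, device=output.device)
    max_logits = torch.empty_like(exp_sums)
    attention_ops.paged_attention_v2(
        output, exp_sums, max_logits, tmp_output, query, key_cache,
        value_cache, head_mapping, scale, block_tables, context_lens,
        block_size, max_context_len, alibi_slopes)
```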