From 7d057f90463b065b814e91acf4a359be36711363 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 03:26:38 +0000 Subject: [PATCH 01/32] PagedAttention V1 --- csrc/attention/attention_kernels.cu | 45 ++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd78..575ee7af4ec82 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -65,14 +65,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// Grid: (num_heads, num_seqs). +// TODO(woosuk): Flatten the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, num_partitions). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> -__global__ void single_query_cached_kv_attention_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -98,6 +102,7 @@ __global__ void single_query_cached_kv_attention_kernel( const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -338,13 +343,39 @@ __global__ void single_query_cached_kv_attention_kernel( } } +// Grid: (num_heads, num_seqs, 1). 
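// The V1 wrapper below forwards to paged_attention_kernel with the default
// PARTITION_SIZE = 0 (no partitioning) and null exp_sums/max_logits, so each
// (head, seq) pair is still processed end-to-end by a single thread block.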
+template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + } // namespace vllm #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ + vllm::paged_attention_v1, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::single_query_cached_kv_attention_kernel \ + vllm::paged_attention_v1 \ <<>>( \ out_ptr, \ query_ptr, \ @@ -408,7 +439,7 @@ void single_query_cached_kv_attention_launcher( // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs); + dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { From 2cc7bff861268cadb60947d4180b7e2194219072 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 04:01:01 +0000 Subject: [PATCH 02/32] Mid --- csrc/attention/attention_kernels.cu | 35 ++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 575ee7af4ec82..a36ad71c31bfa 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,6 +26,7 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -89,10 +90,29 @@ __device__ void paged_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { + // FIXME(woosuk): Optimize. + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int start_block_idx = USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; + int num_blocks; + if (USE_PARTITIONING) { + num_blocks = MIN(num_context_blocks - start_block_idx, num_blocks_per_partition); + } else { + num_blocks = num_context_blocks; + } + if (num_blocks <= 0) { + return; + } + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -101,8 +121,6 @@ __device__ void paged_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -147,15 +165,12 @@ __device__ void paged_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -242,7 +257,7 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. 
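// Each thread accumulates its NUM_ROWS_PER_THREAD rows of the output head
// dimension in fp32 and converts back to scalar_t only at the final write.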
float accs[NUM_ROWS_PER_THREAD]; @@ -432,7 +447,7 @@ void single_query_cached_kv_attention_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len From 8946093d99d45b49c196ccd7e52d42ab02fd6277 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 04:17:21 +0000 Subject: [PATCH 03/32] PagedAttention V1 --- csrc/attention.cpp | 8 +++--- csrc/attention/attention_kernels.cu | 36 ++++++++++++------------- vllm/model_executor/layers/attention.py | 2 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/csrc/attention.cpp b/csrc/attention.cpp index 6be8a6d25ae49..ae1821b25e3a7 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -1,7 +1,7 @@ #include #include -void single_query_cached_kv_attention( +void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -16,7 +16,7 @@ void single_query_cached_kv_attention( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( - "single_query_cached_kv_attention", - &single_query_cached_kv_attention, - "Compute the attention between an input query and the cached key/value tensors"); + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index a36ad71c31bfa..574001513b41a 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -66,7 +66,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// TODO(woosuk): Flatten the last two dimensions of the grid. +// TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, num_partitions). template< typename scalar_t, @@ -364,7 +364,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS> -__global__ void paged_attention_v1( +__global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] @@ -386,11 +386,11 @@ __global__ void paged_attention_v1( } // namespace vllm -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +#define LAUNCH_PAGED_ATTENTION_V1(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1 \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -411,7 +411,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void single_query_cached_kv_attention_launcher( +void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -461,31 +461,31 @@ void single_query_cached_kv_attention_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted head sizes // 32, 160, 192. 
// case 32: - // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // LAUNCH_PAGED_ATTENTION_V1(T, 32, BLOCK_SIZE, NUM_THREADS); // break; case 64: - LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 64, BLOCK_SIZE, NUM_THREADS); break; case 80: - LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 80, BLOCK_SIZE, NUM_THREADS); break; case 96: - LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 96, BLOCK_SIZE, NUM_THREADS); break; case 112: - LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 112, BLOCK_SIZE, NUM_THREADS); break; case 128: - LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 128, BLOCK_SIZE, NUM_THREADS); break; // case 160: - // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // LAUNCH_PAGED_ATTENTION_V1(T, 160, BLOCK_SIZE, NUM_THREADS); // break; // case 192: - // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // LAUNCH_PAGED_ATTENTION_V1(T, 192, BLOCK_SIZE, NUM_THREADS); // break; case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(T, 256, BLOCK_SIZE, NUM_THREADS); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -494,7 +494,7 @@ void single_query_cached_kv_attention_launcher( } #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ + paged_attention_v1_launcher( \ out, \ query, \ key_cache, \ @@ -542,7 +542,7 @@ void single_query_cached_kv_attention_launcher( break; \ } -void single_query_cached_kv_attention( +void paged_attention_v1( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b1d0588d97f7e..4d32a63214953 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -150,7 +150,7 @@ def single_query_cached_kv_attention( input_metadata: metadata for paged attention. 
""" block_size = value_cache.shape[3] - attention_ops.single_query_cached_kv_attention( + attention_ops.paged_attention_v1( output, query, key_cache, From f5b05fc1869c8a16d375de1422717592e86d4300 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 04:20:01 +0000 Subject: [PATCH 04/32] Undef DIVIDE_ROUND_UP --- csrc/attention/attention_kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 574001513b41a..fad348ef1dfda 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -568,3 +568,4 @@ void paged_attention_v1( #undef WARP_SIZE #undef MAX #undef MIN +#undef DIVIDE_ROUND_UP From 235f273d5613970ba5b59a9e003763dd2687a4c5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 04:55:19 +0000 Subject: [PATCH 05/32] Add empty PagedAttention V2 --- csrc/attention.cpp | 20 ++++++++ csrc/attention/attention_kernels.cu | 65 ++++++++++++++++++++++++- vllm/model_executor/layers/attention.py | 62 ++++++++++++++++++----- 3 files changed, 132 insertions(+), 15 deletions(-) diff --git a/csrc/attention.cpp b/csrc/attention.cpp index ae1821b25e3a7..bd93fd71b733d 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -14,9 +14,29 @@ void paged_attention_v1( int max_context_len, const c10::optional& alibi_slopes); +void paged_attention_v2( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "paged_attention_v1", &paged_attention_v1, "Compute the attention between an input query and the cached keys/values using PagedAttention."); + m.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index fad348ef1dfda..b56d52e026e5a 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -241,9 +241,21 @@ __device__ void paged_attention_kernel( exp_sum += val; } exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * num_partitions + + head_idx * num_partitions + + partition_idx; + *max_logits_ptr = inv_sum; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * num_partitions + + head_idx * num_partitions + + partition_idx; + *exp_sums_ptr = exp_sum; + } // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); for (int i = thread_idx; i < context_len; i += NUM_THREADS) { logits[i] *= inv_sum; } @@ -347,7 +359,9 @@ __device__ void paged_attention_kernel( // Write the final output. 
if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * num_partitions * HEAD_SIZE + + head_idx * num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -384,6 +398,35 @@ __global__ void paged_attention_v1_kernel( max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } +// Grid: (num_heads, num_seqs, num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + } // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ @@ -565,6 +608,24 @@ void paged_attention_v1( } } +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + // TODO +} + #undef WARP_SIZE #undef MAX #undef MIN diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 4d32a63214953..43f17d9bb8872 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,6 +15,7 @@ RotaryEmbedding) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +_PAGED_ATTENTION_PARTITION_SIZE = 512 class PagedAttention(nn.Module): @@ -149,20 +150,55 @@ def single_query_cached_kv_attention( block_size] input_metadata: metadata for paged attention. 
""" + num_partitions = (input_metadata.max_context_len + + _PAGED_ATTENTION_PARTITION_SIZE - + 1) // _PAGED_ATTENTION_PARTITION_SIZE block_size = value_cache.shape[3] - attention_ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - None, # alibi_slopes - ) + if num_partitions == 1: + # Short context. Run PagedAttention V1. + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + ) + else: + # Long context. Run PagedAttention V2. + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + ) def forward( self, From 472ee66393b98d92e9b827346938497acd1c2fc5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 05:02:51 +0000 Subject: [PATCH 06/32] Minor --- vllm/model_executor/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 43f17d9bb8872..dd4fa2971d58b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -171,6 +171,7 @@ def single_query_cached_kv_attention( ) else: # Long context. Run PagedAttention V2. + assert _PAGED_ATTENTION_PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( size=(num_seqs, num_heads, num_partitions, head_size), From 3827e241ec8c776a0b5bdfe2bba6013bcd6c3a27 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 09:05:30 +0000 Subject: [PATCH 07/32] Minor --- vllm/model_executor/layers/attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index dd4fa2971d58b..f95ff9b0dd667 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,7 +15,7 @@ RotaryEmbedding) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] -_PAGED_ATTENTION_PARTITION_SIZE = 512 +_PARTITION_SIZE = 512 class PagedAttention(nn.Module): @@ -150,11 +150,11 @@ def single_query_cached_kv_attention( block_size] input_metadata: metadata for paged attention. """ - num_partitions = (input_metadata.max_context_len + - _PAGED_ATTENTION_PARTITION_SIZE - - 1) // _PAGED_ATTENTION_PARTITION_SIZE + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) block_size = value_cache.shape[3] - if num_partitions == 1: + if max_num_partitions == 1: # Short context. Run PagedAttention V1. attention_ops.paged_attention_v1( output, @@ -171,15 +171,15 @@ def single_query_cached_kv_attention( ) else: # Long context. Run PagedAttention V2. 
- assert _PAGED_ATTENTION_PARTITION_SIZE % block_size == 0 + assert _PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), + size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, device=output.device, ) exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), + size=(num_seqs, num_heads, max_num_partitions), dtype=torch.float32, device=output.device, ) From 2605c6ebefeaee38f5ebc816de3b400ded581906 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 09:05:45 +0000 Subject: [PATCH 08/32] Implement PagedAttention V2 --- csrc/attention/attention_kernels.cu | 370 +++++++++++++++++++++++----- csrc/attention/dtype_bfloat16.cuh | 5 + 2 files changed, 315 insertions(+), 60 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index b56d52e026e5a..40b4d0f13b17d 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -67,7 +67,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { } // TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, num_partitions). +// Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, int HEAD_SIZE, @@ -75,9 +75,9 @@ template< int NUM_THREADS, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, num_partitions, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -93,22 +93,26 @@ __device__ void paged_attention_kernel( // FIXME(woosuk): Optimize. const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; - const int num_partitions = gridDim.z; + const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; const int context_len = context_lens[seq_idx]; const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - int num_blocks; - if (USE_PARTITIONING) { - num_blocks = MIN(num_context_blocks - start_block_idx, num_blocks_per_partition); - } else { - num_blocks = num_context_blocks; - } + int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + int num_blocks = end_block_idx - start_block_idx; if (num_blocks <= 0) { + // No work to do. Terminate the thread block. return; } + // [start_token_idx, end_token_idx) is the range of tokens to process. 
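// For example, with PARTITION_SIZE = 512 and BLOCK_SIZE = 16, partition 1 of a
// sequence with context_len = 700 covers blocks [32, 44) and tokens [512, 700).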
+ int start_token_idx = start_block_idx * BLOCK_SIZE; + int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + int num_tokens = end_token_idx - start_token_idx; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); @@ -170,7 +174,7 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) { + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -204,7 +208,7 @@ __device__ void paged_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx] = mask ? 0.f : qk; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -235,28 +239,28 @@ __device__ void paged_attention_kernel( // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; } exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * num_partitions - + head_idx * num_partitions + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; - *max_logits_ptr = inv_sum; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * num_partitions - + head_idx * num_partitions + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } // Compute softmax. 
- for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); @@ -280,12 +284,12 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -295,7 +299,7 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_blocks - 1) { + if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -359,8 +363,8 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * num_partitions * HEAD_SIZE - + head_idx * num_partitions * HEAD_SIZE + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { @@ -398,7 +402,7 @@ __global__ void paged_attention_v1_kernel( max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } -// Grid: (num_heads, num_seqs, num_partitions). +// Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, int HEAD_SIZE, @@ -406,9 +410,9 @@ template< int NUM_THREADS, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, num_partitions, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -427,6 +431,105 @@ __global__ void paged_attention_v2_kernel( q_stride, kv_block_stride, kv_head_stride); } +// Grid: (num_heads, num_seqs). 
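// One thread block per (head, seq) pair merges the per-partition partial outputs
// in tmp_out, using the stored max_logits and exp_sums in a log-sum-exp style
// reduction, which recovers the exact softmax over the full context.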
+template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
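// out[i] = sum_j tmp_out[j][i] * exp_sums[j] * exp(max_logits[j] - global_max) / global_exp_sum;
// shared_exp_sums[j] already holds the rescaled term, and multiplying by exp_sums[j]
// here compensates for the per-partition normalization done in the main kernel.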
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + } // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ @@ -536,7 +639,7 @@ void paged_attention_v1_launcher( } } -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_launcher( \ out, \ query, \ @@ -551,35 +654,17 @@ void paged_attention_v1_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ - /* case 1: */ \ - /* CALL_KERNEL_LAUNCHER(T, 1); */ \ - /* break; */ \ - /* case 2: */ \ - /* CALL_KERNEL_LAUNCHER(T, 2); */ \ - /* break; */ \ - /* case 4: */ \ - /* CALL_KERNEL_LAUNCHER(T, 4); */ \ - /* break; */ \ case 8: \ - CALL_KERNEL_LAUNCHER(T, 8); \ + CALL_V1_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_KERNEL_LAUNCHER(T, 16); \ + CALL_V1_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_KERNEL_LAUNCHER(T, 32); \ + CALL_V1_LAUNCHER(T, 32); \ break; \ - /* case 64: */ \ - /* CALL_KERNEL_LAUNCHER(T, 64); */ \ - /* break; */ \ - /* case 128: */ \ - /* CALL_KERNEL_LAUNCHER(T, 128); */ \ - /* break; */ \ - /* case 256: */ \ - /* CALL_KERNEL_LAUNCHER(T, 256); */ \ - /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ @@ -598,21 +683,178 @@ void paged_attention_v1( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + CALL_V1_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } } +#define LAUNCH_PAGED_ATTENTION_V2(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + 
int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 grid2(num_heads, num_seqs); + int shared_mem_size2 = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_PAGED_ATTENTION_V2(T, 32, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + // break; + case 64: + LAUNCH_PAGED_ATTENTION_V2(T, 64, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(T, 80, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(T, 96, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(T, 112, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(T, 128, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + // case 160: + // LAUNCH_PAGED_ATTENTION_V2(T, 160, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + // break; + // case 192: + // LAUNCH_PAGED_ATTENTION_V2(T, 192, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + // break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(T, 256, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + void paged_attention_v2( torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, num_partitions, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -623,7 +865,15 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes) { - // TODO + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } #undef WARP_SIZE diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 2154bfcf8631a..5786f77f7bca6 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + // Zero-out a variable. inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 From 877a3f5df60fd37e63fadf170262adab4eb59eac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 09:09:18 +0000 Subject: [PATCH 09/32] Add comment --- vllm/model_executor/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f95ff9b0dd667..826522bd7e5ab 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,6 +15,7 @@ RotaryEmbedding) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 From 634f9618d38270ab47437dd3489c5343afa59c58 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 12 Oct 2023 23:52:04 +0000 Subject: [PATCH 10/32] Fix performance bug --- csrc/attention/attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 40b4d0f13b17d..7764ebd3bcbb3 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -247,7 +247,7 @@ __device__ void paged_attention_kernel( exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); // If partitioning is enabled, store the max logit and exp_sum. 
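// After the block-wide reductions, every thread holds the same qk_max and exp_sum,
// so having only thread 0 perform the two global stores (the change below) avoids
// NUM_THREADS redundant writes per partition.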
- if (USE_PARTITIONING) { + if (USE_PARTITIONING && thread_idx == 0) { float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; From 758510189c894dd1a2815260c51b1461f0f76824 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 03:24:58 +0000 Subject: [PATCH 11/32] Fix attention test --- tests/kernels/test_attention.py | 59 +++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 59d8b0a59ce6f..d55ca1e99fa80 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -162,19 +162,52 @@ def test_single_query_cached_kv_attention( # Call the paged attention kernel. output = torch.empty_like(query) - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) + PARTITION_SIZE = 512 + num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE + if num_partitions == 1: + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + else: + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) # Run the reference implementation. ref_output = torch.empty_like(query) From 3ea3891472ae1822732a372a68dd680eae1a1324 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 06:40:53 +0000 Subject: [PATCH 12/32] Add heuristic --- vllm/model_executor/layers/attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 826522bd7e5ab..1cb5f552e7bac 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -151,11 +151,13 @@ def single_query_cached_kv_attention( block_size] input_metadata: metadata for paged attention. """ + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape max_num_partitions = ( (input_metadata.max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) - block_size = value_cache.shape[3] - if max_num_partitions == 1: + use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + if use_v1: # Short context. Run PagedAttention V1. attention_ops.paged_attention_v1( output, @@ -173,7 +175,6 @@ def single_query_cached_kv_attention( else: # Long context. Run PagedAttention V2. 
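            # V2 needs per-partition workspaces: tmp_output holds one partial output
            # per (seq, head, partition), and exp_sums/max_logits hold the per-partition
            # softmax statistics that the reduce kernel uses to combine them.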
assert _PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, From ab89848678f946f5028d257eec1d621cb61ab5db Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 06:41:13 +0000 Subject: [PATCH 13/32] Minor optimization --- csrc/attention/attention_kernels.cu | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 7764ebd3bcbb3..b046b6cdb55bf 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -90,12 +90,16 @@ __device__ void paged_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - // FIXME(woosuk): Optimize. const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; @@ -103,10 +107,6 @@ __device__ void paged_attention_kernel( const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); int num_blocks = end_block_idx - start_block_idx; - if (num_blocks <= 0) { - // No work to do. Terminate the thread block. - return; - } // [start_token_idx, end_token_idx) is the range of tokens to process. int start_token_idx = start_block_idx * BLOCK_SIZE; @@ -246,6 +246,13 @@ __device__ void paged_attention_kernel( } exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions @@ -258,13 +265,6 @@ __device__ void paged_attention_kernel( *exp_sums_ptr = exp_sum; } - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - // Each thread will fetch 16 bytes from the value cache at a time. 
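// That is 8 fp16/bf16 values or 4 fp32 values per fetch, capped at BLOCK_SIZE.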
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; From d83ce92f90cd694467cb258acd0e761b6053f182 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 06:41:24 +0000 Subject: [PATCH 14/32] Add benchmark --- benchmarks/benchmark_attention.py | 200 ++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 benchmarks/benchmark_attention.py diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py new file mode 100644 index 0000000000000..d42a777a402d8 --- /dev/null +++ b/benchmarks/benchmark_attention.py @@ -0,0 +1,200 @@ +import random +from typing import List, Optional, Tuple + +import time +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + +from vllm import attention_ops +from vllm.utils import get_max_shared_memory_bytes + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +NUM_BLOCKS = 128 # Arbitrary values for testing + +DTYPES = [torch.half] +NUM_GEN_SEQS = [1, 4, 16, 64, 128] # Arbitrary values for testing +CONTEXT_LENS = [1024, 2048, 4096, 8192, 16384] # Arbitrary values for testing +NUM_HEADS = [(40, 40)] # Arbitrary values for testing +HEAD_SIZES = [128] +BLOCK_SIZES = [16] +USE_ALIBI = [False] +SEEDS = [0] + +def create_kv_caches( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + seed: int, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device='cuda') + key_cache.uniform_(-scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device='cuda') + value_cache.uniform_(-scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches + + +@torch.inference_mode() +def test_single_query_cached_kv_attention( + num_seqs: int, + context_len: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + version: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + context_lens = [random.randint(1, context_len) for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. 
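    # Physical block indices are drawn at random here purely for benchmarking;
    # only the access pattern, not the cached values, matters for the timing.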
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + key_caches, value_caches = create_kv_caches(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, dtype, + seed) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Call the paged attention kernel. + output = torch.empty_like(query) + PARTITION_SIZE = 512 + num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE + + def f(): + num_seqs, num_heads, head_size = output.shape + use_v1 = num_partitions == 1 or num_seqs * num_heads > 512 + if version == 1 or use_v1: + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + else: + assert PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + for _ in range(3): + f() + torch.cuda.synchronize() + + start = time.time() + for _ in range(100): + f() + torch.cuda.synchronize() + end = time.time() + print(f"Time: {(end - start) / 100 * 1000:.3f} ms") + + +if __name__ == "__main__": + for context_len in CONTEXT_LENS: + for num_seqs in NUM_GEN_SEQS: + for num_heads in NUM_HEADS: + for dtype in DTYPES: + for version in [1, 2]: + print( + f"Testing: V{version} {num_seqs}, {context_len}" + ) + test_single_query_cached_kv_attention( + num_seqs, + context_len, + num_heads, + 128, + False, + 16, + dtype, + 0, + version, + ) From 760e7a29793175bed629b2c6d1638e23daac660e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 13 Oct 2023 16:29:19 +0000 Subject: [PATCH 15/32] Minor --- benchmarks/benchmark_attention.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py index d42a777a402d8..ddbad1b1243a3 100644 --- a/benchmarks/benchmark_attention.py +++ b/benchmarks/benchmark_attention.py @@ -175,7 +175,9 @@ def f(): f() torch.cuda.synchronize() end = time.time() - print(f"Time: {(end - start) / 100 * 1000:.3f} ms") + t = (end - start) / 100 * 1000 + print(f"Time: {t:.3f} ms") + return t if __name__ == "__main__": @@ -183,11 +185,12 @@ def f(): for num_seqs in NUM_GEN_SEQS: for num_heads in NUM_HEADS: for dtype in DTYPES: + ts = [] for version in [1, 2]: print( f"Testing: V{version} {num_seqs}, {context_len}" ) - test_single_query_cached_kv_attention( + t = test_single_query_cached_kv_attention( num_seqs, context_len, num_heads, @@ -198,3 +201,5 @@ def f(): 0, version, ) + ts.append(t) + print(f"Speedup: {ts[0] / ts[1]:.3f}") From e6d8a15c62d6b8e90e94015d9c09007d63991d9f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 06:35:23 +0000 Subject: [PATCH 16/32] yapf --- benchmarks/benchmark_attention.py | 37 
+++++++++++++++---------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py index ddbad1b1243a3..06ba01aa469ae 100644 --- a/benchmarks/benchmark_attention.py +++ b/benchmarks/benchmark_attention.py @@ -23,6 +23,7 @@ USE_ALIBI = [False] SEEDS = [0] + def create_kv_caches( num_blocks: int, block_size: int, @@ -184,22 +185,20 @@ def f(): for context_len in CONTEXT_LENS: for num_seqs in NUM_GEN_SEQS: for num_heads in NUM_HEADS: - for dtype in DTYPES: - ts = [] - for version in [1, 2]: - print( - f"Testing: V{version} {num_seqs}, {context_len}" - ) - t = test_single_query_cached_kv_attention( - num_seqs, - context_len, - num_heads, - 128, - False, - 16, - dtype, - 0, - version, - ) - ts.append(t) - print(f"Speedup: {ts[0] / ts[1]:.3f}") + for dtype in DTYPES: + ts = [] + for version in [1, 2]: + print(f"Testing: V{version} {num_seqs}, {context_len}") + t = test_single_query_cached_kv_attention( + num_seqs, + context_len, + num_heads, + 128, + False, + 16, + dtype, + 0, + version, + ) + ts.append(t) + print(f"Speedup: {ts[0] / ts[1]:.3f}") From 4313691d02ed8453c8b368ddbf7057ae0d6d65a5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 06:46:54 +0000 Subject: [PATCH 17/32] Minor fix on comments --- csrc/attention/attention_kernels.cu | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index b046b6cdb55bf..3dccfad96a2c6 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -604,11 +604,9 @@ void paged_attention_v1_launcher( dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we omitted head sizes - // 32, 160, 192. - // case 32: - // LAUNCH_PAGED_ATTENTION_V1(T, 32, BLOCK_SIZE, NUM_THREADS); - // break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. case 64: LAUNCH_PAGED_ATTENTION_V1(T, 64, BLOCK_SIZE, NUM_THREADS); break; @@ -624,12 +622,6 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(T, 128, BLOCK_SIZE, NUM_THREADS); break; - // case 160: - // LAUNCH_PAGED_ATTENTION_V1(T, 160, BLOCK_SIZE, NUM_THREADS); - // break; - // case 192: - // LAUNCH_PAGED_ATTENTION_V1(T, 192, BLOCK_SIZE, NUM_THREADS); - // break; case 256: LAUNCH_PAGED_ATTENTION_V1(T, 256, BLOCK_SIZE, NUM_THREADS); break; @@ -781,11 +773,9 @@ void paged_attention_v2_launcher( dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we omitted head sizes - // 32, 160, 192. - // case 32: - // LAUNCH_PAGED_ATTENTION_V2(T, 32, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); - // break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
case 64: LAUNCH_PAGED_ATTENTION_V2(T, 64, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); break; @@ -801,12 +791,6 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(T, 128, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); break; - // case 160: - // LAUNCH_PAGED_ATTENTION_V2(T, 160, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); - // break; - // case 192: - // LAUNCH_PAGED_ATTENTION_V2(T, 192, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); - // break; case 256: LAUNCH_PAGED_ATTENTION_V2(T, 256, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); break; From c0021c16cf8a10ec689aba5e506c63af9f7eaa15 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 06:52:10 +0000 Subject: [PATCH 18/32] Add comment on heuristic --- vllm/model_executor/layers/attention.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 1cb5f552e7bac..1a2e2bb895d1d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,6 +156,12 @@ def single_query_cached_kv_attention( max_num_partitions = ( (input_metadata.max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 if use_v1: # Short context. Run PagedAttention V1. From 8ddb4268e9b9da8d2041667be223db10719aad6b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:01:23 +0000 Subject: [PATCH 19/32] Fix test_attention --- tests/kernels/test_attention.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index d55ca1e99fa80..a57da65e141ba 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -14,13 +14,14 @@ # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_GEN_SEQS = [7] # Arbitrary values for testing -NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] +BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] @@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) +@pytest.mark.parametrize("use_v2", [False, True]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_single_query_cached_kv_attention( +def test_paged_attention( kv_cache_factory, + use_v2: bool, num_seqs: int, num_heads: Tuple[int, int], head_size: int, @@ -162,9 +164,7 @@ def test_single_query_cached_kv_attention( # Call the paged attention kernel. 
output = torch.empty_like(query) - PARTITION_SIZE = 512 - num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE - if num_partitions == 1: + if not use_v2: attention_ops.paged_attention_v1( output, query, @@ -179,6 +179,7 @@ def test_single_query_cached_kv_attention( alibi_slopes, ) else: + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( From 08e92c36928c76c8cc04066e6321de2df0410b73 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:03:27 +0000 Subject: [PATCH 20/32] yapf --- tests/kernels/test_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index a57da65e141ba..96627a13c343a 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -179,7 +179,8 @@ def test_paged_attention( alibi_slopes, ) else: - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( From dac5e24fe65c80a548e51b608f08557bcf6f1c07 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:04:18 +0000 Subject: [PATCH 21/32] Minor --- tests/kernels/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 96627a13c343a..4070e6a93733d 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -97,7 +97,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("use_v2", [False, True]) +@pytest.mark.parametrize("use_v2", [True, False]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From d674616ba5830cad2c8bcde955eccdc9093fceb1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:05:38 +0000 Subject: [PATCH 22/32] Minor --- vllm/model_executor/layers/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 1a2e2bb895d1d..d26b5bd02012c 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -164,7 +164,7 @@ def single_query_cached_kv_attention( # TODO(woosuk): Tune this heuristic. use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 if use_v1: - # Short context. Run PagedAttention V1. + # Run PagedAttention V1. attention_ops.paged_attention_v1( output, query, @@ -179,7 +179,7 @@ def single_query_cached_kv_attention( None, # alibi_slopes ) else: - # Long context. Run PagedAttention V2. + # Run PagedAttention V2. 
assert _PARTITION_SIZE % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), From 612236b1b4e059bebe5ced44c2e38b2f347e3bd5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:43:46 +0000 Subject: [PATCH 23/32] Reimplement --- benchmarks/benchmark_attention.py | 204 ------------------- benchmarks/kernels/benchmark_paged_attn.py | 218 +++++++++++++++++++++ 2 files changed, 218 insertions(+), 204 deletions(-) delete mode 100644 benchmarks/benchmark_attention.py create mode 100644 benchmarks/kernels/benchmark_paged_attn.py diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py deleted file mode 100644 index 06ba01aa469ae..0000000000000 --- a/benchmarks/benchmark_attention.py +++ /dev/null @@ -1,204 +0,0 @@ -import random -from typing import List, Optional, Tuple - -import time -import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - -from vllm import attention_ops -from vllm.utils import get_max_shared_memory_bytes - -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# This will change depending on the compute capability. -# - 512 as a buffer -NUM_BLOCKS = 128 # Arbitrary values for testing - -DTYPES = [torch.half] -NUM_GEN_SEQS = [1, 4, 16, 64, 128] # Arbitrary values for testing -CONTEXT_LENS = [1024, 2048, 4096, 8192, 16384] # Arbitrary values for testing -NUM_HEADS = [(40, 40)] # Arbitrary values for testing -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -USE_ALIBI = [False] -SEEDS = [0] - - -def create_kv_caches( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - seed: int, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device='cuda') - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device='cuda') - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches - - -@torch.inference_mode() -def test_single_query_cached_kv_attention( - num_seqs: int, - context_len: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - seed: int, - version: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device="cuda") - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, - dtype=torch.float, - device="cuda") - - context_lens = [random.randint(1, context_len) for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = 
torch.tensor(context_lens, dtype=torch.int, device="cuda") - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - - # Create the KV caches. - key_caches, value_caches = create_kv_caches(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, dtype, - seed) - key_cache, value_cache = key_caches[0], value_caches[0] - - # Call the paged attention kernel. - output = torch.empty_like(query) - PARTITION_SIZE = 512 - num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE - - def f(): - num_seqs, num_heads, head_size = output.shape - use_v1 = num_partitions == 1 or num_seqs * num_heads > 512 - if version == 1 or use_v1: - attention_ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) - else: - assert PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) - - for _ in range(3): - f() - torch.cuda.synchronize() - - start = time.time() - for _ in range(100): - f() - torch.cuda.synchronize() - end = time.time() - t = (end - start) / 100 * 1000 - print(f"Time: {t:.3f} ms") - return t - - -if __name__ == "__main__": - for context_len in CONTEXT_LENS: - for num_seqs in NUM_GEN_SEQS: - for num_heads in NUM_HEADS: - for dtype in DTYPES: - ts = [] - for version in [1, 2]: - print(f"Testing: V{version} {num_seqs}, {context_len}") - t = test_single_query_cached_kv_attention( - num_seqs, - context_len, - num_heads, - 128, - False, - 16, - dtype, - 0, - version, - ) - ts.append(t) - print(f"Speedup: {ts[0] / ts[1]:.3f}") diff --git a/benchmarks/kernels/benchmark_paged_attn.py b/benchmarks/kernels/benchmark_paged_attn.py new file mode 100644 index 0000000000000..52fc1f390d174 --- /dev/null +++ b/benchmarks/kernels/benchmark_paged_attn.py @@ -0,0 +1,218 @@ +import argparse +import random +from typing import List, Tuple + +import time +import torch + +from vllm import attention_ops + +NUM_BLOCKS = 1024 +PARTITION_SIZE = 512 + + +def create_kv_caches( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device="cuda") + key_cache.uniform_(-scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device="cuda") + 
value_cache.uniform_(-scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches + + +@torch.inference_mode() +def main( + version: int, + num_seqs: int, + context_len: int, + num_query_heads: int, + num_kv_heads: int, + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + do_profile: bool, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + context_lens = [context_len for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV cache. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device="cuda") + key_cache.uniform_(-scale, scale) + value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) + value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + + # Prepare for the paged attention kernel. + output = torch.empty_like(query) + if version == 2: + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + tmp_output = torch.empty( + size=(num_seqs, num_query_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_query_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + def run_benchmark(num_iters: int, profile: bool = False) -> float: + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if version == 1: + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + else: + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark(num_iters=3, profile=False) + torch.cuda.synchronize() + + # Benchmark. 
+ if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=100, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Benchmark the paged attention kernel.") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--num-query-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-size", type=int, choices=[64, 80, 96, 112, 128, 256], default=128) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--version", type=int, choices=[1, 2], default=2) + parser.add_argument("--use-alibi", action="store_true") + parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + print(args) + + if args.num_query_heads % args.num_kv_heads != 0: + raise ValueError("num_query_heads must be divisible by num_kv_heads") + dtype_to_torch_dtype = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + } + main( + version=args.version, + num_seqs=args.batch_size, + context_len=args.context_len, + num_query_heads=args.num_query_heads, + num_kv_heads=args.num_kv_heads, + head_size=args.head_size, + block_size=args.block_size, + use_alibi=args.use_alibi, + dtype=dtype_to_torch_dtype[args.dtype], + seed=args.seed, + do_profile=args.profile, + ) From 3d2eff17d9fcb3ce01d38dfe60cf564203d69ad1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:44:14 +0000 Subject: [PATCH 24/32] Rename --- .../{benchmark_paged_attn.py => benchmark_paged_attention.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmarks/kernels/{benchmark_paged_attn.py => benchmark_paged_attention.py} (100%) diff --git a/benchmarks/kernels/benchmark_paged_attn.py b/benchmarks/kernels/benchmark_paged_attention.py similarity index 100% rename from benchmarks/kernels/benchmark_paged_attn.py rename to benchmarks/kernels/benchmark_paged_attention.py From 57b30711fea5d35b1467ee31dc0678a2b31b6cfb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:46:07 +0000 Subject: [PATCH 25/32] Minor --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 52fc1f390d174..80a520783b6e4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -182,13 +182,13 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: if __name__ == '__main__': parser = argparse.ArgumentParser( description="Benchmark the paged attention kernel.") + parser.add_argument("--version", type=int, choices=[1, 2], default=2) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, choices=[64, 80, 96, 112, 128, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) - parser.add_argument("--version", 
type=int, choices=[1, 2], default=2) parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half") parser.add_argument("--seed", type=int, default=0) From cb3af6d3141e3ec15a8aa8417fc30eb887be298e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:46:24 +0000 Subject: [PATCH 26/32] yapf --- .../kernels/benchmark_paged_attention.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 80a520783b6e4..1e05933986b10 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -96,20 +96,19 @@ def main( # Create the KV cache. x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) - key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device="cuda") + key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda") key_cache.uniform_(-scale, scale) value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device="cuda") + dtype=dtype, + device="cuda") value_cache.uniform_(-scale, scale) # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == 2: - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -187,10 +186,16 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) - parser.add_argument("--head-size", type=int, choices=[64, 80, 96, 112, 128, 256], default=128) + parser.add_argument("--head-size", + type=int, + choices=[64, 80, 96, 112, 128, 256], + default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") - parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") args = parser.parse_args() From 000abdf3f6f09e2f27f9605c384f8b54d06ef06a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 15 Oct 2023 07:47:44 +0000 Subject: [PATCH 27/32] Remove unnecessary fns --- .../kernels/benchmark_paged_attention.py | 33 +------------------ 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 1e05933986b10..a6c0873f707b8 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,8 +1,7 @@ import argparse import random -from typing import List, Tuple - import time + import torch from vllm import attention_ops @@ -11,36 +10,6 @@ PARTITION_SIZE = 512 -def create_kv_caches( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, -) -> 
Tuple[List[torch.Tensor], List[torch.Tensor]]: - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device="cuda") - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches - - @torch.inference_mode() def main( version: int, From f80f49f77c3f52001e0c4b0fdf878411f67e8ff1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 06:29:14 +0000 Subject: [PATCH 28/32] Address comments --- .../kernels/benchmark_paged_attention.py | 12 ++--- csrc/attention/attention_kernels.cu | 44 +++++++++---------- tests/kernels/test_attention.py | 10 +++-- 3 files changed, 35 insertions(+), 31 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a6c0873f707b8..32218395119d2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -12,7 +12,7 @@ @torch.inference_mode() def main( - version: int, + version: str, num_seqs: int, context_len: int, num_query_heads: int, @@ -91,12 +91,13 @@ def main( max_logits = torch.empty_like(exp_sums) def run_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() if profile: torch.cuda.cudart().cudaProfilerStart() start_time = time.perf_counter() for _ in range(num_iters): - if version == 1: + if version == "v1": attention_ops.paged_attention_v1( output, query, @@ -110,7 +111,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: max_context_len, alibi_slopes, ) - else: + elif version == "v2": attention_ops.paged_attention_v2( output, exp_sums, @@ -127,6 +128,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: max_context_len, alibi_slopes, ) + else: + raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() end_time = time.perf_counter() @@ -137,7 +140,6 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: # Warmup. print("Warming up...") run_benchmark(num_iters=3, profile=False) - torch.cuda.synchronize() # Benchmark. if do_profile: @@ -150,7 +152,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: if __name__ == '__main__': parser = argparse.ArgumentParser( description="Benchmark the paged attention kernel.") - parser.add_argument("--version", type=int, choices=[1, 2], default=2) + parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3dccfad96a2c6..ee6b715adaef0 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -105,13 +105,13 @@ __device__ void paged_attention_kernel( // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; - int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - int num_blocks = end_block_idx - start_block_idx; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. - int start_token_idx = start_block_idx * BLOCK_SIZE; - int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - int num_tokens = end_token_idx - start_token_idx; + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS @@ -532,7 +532,7 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#define LAUNCH_PAGED_ATTENTION_V1(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ @@ -608,22 +608,22 @@ void paged_attention_v1_launcher( // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. case 64: - LAUNCH_PAGED_ATTENTION_V1(T, 64, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(64); break; case 80: - LAUNCH_PAGED_ATTENTION_V1(T, 80, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(80); break; case 96: - LAUNCH_PAGED_ATTENTION_V1(T, 96, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(96); break; case 112: - LAUNCH_PAGED_ATTENTION_V1(T, 112, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(112); break; case 128: - LAUNCH_PAGED_ATTENTION_V1(T, 128, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(128); break; case 256: - LAUNCH_PAGED_ATTENTION_V1(T, 256, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V1(256); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -685,7 +685,7 @@ void paged_attention_v1( } } -#define LAUNCH_PAGED_ATTENTION_V2(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE) \ +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel \ <<>>( \ exp_sums_ptr, \ @@ -704,7 +704,7 @@ void paged_attention_v1( kv_block_stride, \ kv_head_stride); \ vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ + <<>>( \ out_ptr, \ exp_sums_ptr, \ max_logits_ptr, \ @@ -767,8 +767,8 @@ void paged_attention_v2_launcher( dim3 grid(num_heads, num_seqs, max_num_partitions); int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. - dim3 grid2(num_heads, num_seqs); - int shared_mem_size2 = 2 * max_num_partitions * sizeof(float); + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -777,22 +777,22 @@ void paged_attention_v2_launcher( // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. 
case 64: - LAUNCH_PAGED_ATTENTION_V2(T, 64, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(64); break; case 80: - LAUNCH_PAGED_ATTENTION_V2(T, 80, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(80); break; case 96: - LAUNCH_PAGED_ATTENTION_V2(T, 96, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(96); break; case 112: - LAUNCH_PAGED_ATTENTION_V2(T, 112, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(112); break; case 128: - LAUNCH_PAGED_ATTENTION_V2(T, 128, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(128); break; case 256: - LAUNCH_PAGED_ATTENTION_V2(T, 256, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE); + LAUNCH_PAGED_ATTENTION_V2(256); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4070e6a93733d..31d78dd1bcf90 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -97,7 +97,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("use_v2", [True, False]) +@pytest.mark.parametrize("version", ["v1", "v2"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -107,7 +107,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("seed", SEEDS) def test_paged_attention( kv_cache_factory, - use_v2: bool, + version: str, num_seqs: int, num_heads: Tuple[int, int], head_size: int, @@ -164,7 +164,7 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - if not use_v2: + if version == "v1": attention_ops.paged_attention_v1( output, query, @@ -178,7 +178,7 @@ def test_paged_attention( max_context_len, alibi_slopes, ) - else: + elif version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 @@ -210,6 +210,8 @@ def test_paged_attention( max_context_len, alibi_slopes, ) + else: + assert False, f"Unknown version: {version}" # Run the reference implementation. ref_output = torch.empty_like(query) From 5b0a536e774d76a619c953548396010063fad07b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 06:39:10 +0000 Subject: [PATCH 29/32] Minor fix --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 32218395119d2..3ba7d5a504925 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -75,7 +75,7 @@ def main( # Prepare for the paged attention kernel. 
output = torch.empty_like(query) - if version == 2: + if version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( From f3c8cb0c4ee17978634fa2dc8b8efd583e0113ca Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 06:59:55 +0000 Subject: [PATCH 30/32] Support attention with ALiBi --- vllm/model_executor/layers/attention.py | 52 ++++++++----------------- 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index d26b5bd02012c..3aa0b00dbacd8 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -132,6 +132,14 @@ def multi_query_kv_attention( output.copy_(out.squeeze(0)) return output + def get_alibi_slopes(self) -> Optional[torch.Tensor]: + """Returns the slopes for the alibi attention bias. + + Returns: + slopes: shape = [num_heads] + """ + return None + def single_query_cached_kv_attention( self, output: torch.Tensor, @@ -139,6 +147,7 @@ def single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, + alibi_slopes: Optional[torch.Tensor], ) -> None: """PagedAttention for the generation tokens. @@ -150,6 +159,7 @@ def single_query_cached_kv_attention( value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for paged attention. + alibi_slopes: shape = [num_heads] """ block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape @@ -176,7 +186,7 @@ def single_query_cached_kv_attention( input_metadata.context_lens, block_size, input_metadata.max_context_len, - None, # alibi_slopes + alibi_slopes, ) else: # Run PagedAttention V2. @@ -206,7 +216,7 @@ def single_query_cached_kv_attention( input_metadata.context_lens, block_size, input_metadata.max_context_len, - None, # alibi_slopes + alibi_slopes, ) def forward( @@ -294,11 +304,12 @@ def forward( assert key_cache is not None and value_cache is not None, ( "key_cache and value_cache must be provided when " "generating tokens.") + alibi_slopes = self.get_alibi_slopes() # Compute the attention op for generation tokens. self.single_query_cached_kv_attention( output[num_prompt_tokens:num_valid_tokens], query[num_prompt_tokens:num_valid_tokens], key_cache, - value_cache, input_metadata) + value_cache, input_metadata, alibi_slopes) # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. @@ -476,36 +487,5 @@ def multi_query_kv_attention( start += prompt_len return output - def single_query_cached_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - ) -> None: - """PagedAttention with ALiBi bias for the generation tokens. - - Args: - output: shape = [num_generation_tokens, num_heads, head_size] - query: shape = [num_generation_tokens, num_heads, head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for paged attention. 
- """ - block_size = value_cache.shape[3] - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - self.alibi_slopes, - ) + def get_alibi_slopes(self) -> Optional[torch.Tensor]: + return self.alibi_slopes From bfa856950dfe3ac74ccf17a941a227eca4a5e29b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 07:02:00 +0000 Subject: [PATCH 31/32] yapf --- vllm/model_executor/layers/attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 3aa0b00dbacd8..0677ebbae792d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -304,12 +304,11 @@ def forward( assert key_cache is not None and value_cache is not None, ( "key_cache and value_cache must be provided when " "generating tokens.") - alibi_slopes = self.get_alibi_slopes() # Compute the attention op for generation tokens. self.single_query_cached_kv_attention( output[num_prompt_tokens:num_valid_tokens], query[num_prompt_tokens:num_valid_tokens], key_cache, - value_cache, input_metadata, alibi_slopes) + value_cache, input_metadata, self.get_alibi_slopes()) # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. From 9451b2d8e3282d27797e9b02780448f703dc80d6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 07:03:28 +0000 Subject: [PATCH 32/32] yapf --- benchmarks/kernels/benchmark_paged_attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 3ba7d5a504925..0ef8030767677 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -152,7 +152,10 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: if __name__ == '__main__': parser = argparse.ArgumentParser( description="Benchmark the paged attention kernel.") - parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") + parser.add_argument("--version", + type=str, + choices=["v1", "v2"], + default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64)