forked from NVIDIA/kvpress
Commit: add_AdaKV_initial_version AdaKV
Signed-off-by: FFY0 <ffyfengyuan@gmail.com>
Showing 15 changed files with 1,128 additions and 13 deletions.
@@ -0,0 +1,247 @@
# Copyright (c) 2024 YuanFeng
#
# This file is part of the YuanFeng project and is licensed under the MIT License.
# SPDX-License-Identifier: MIT

import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10

from kvpress.ada_cache import DynamicCacheSplitHeadFlatten

logger = logging.getLogger(__name__)

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func


@dataclass
class MetaData:
    """Metadata describing the flattened KV cache layout of one attention layer."""

    decoding_cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    max_seqlen_k: int
    cu_offset: torch.Tensor
    cu_head_offset: torch.Tensor
    head_lens: torch.Tensor
    bsz: int
    num_key_value_heads: int
    seen_tokens: int


class AdaLlamaFlashAttention(LlamaAttention):
    """
    Llama flash attention module for AdaKV. This module inherits from `LlamaAttention`, as the weights of the
    module stay untouched. It uses `flash_attn_varlen_func` from the flash_attn library to perform attention
    over a flattened KV cache layout (a short sketch of the metadata bookkeeping follows this file).
    """

    # update the metadata for the flattened cache after a compression step changes the per-head KV lengths
    def _update_metadata_while_compressing(self, head_lens, cu_seqlens_k, max_seqlen_k):
        self.metadata.head_lens = head_lens
        self.metadata.cu_seqlens_k = cu_seqlens_k
        self.metadata.max_seqlen_k = max_seqlen_k

    # update the metadata for the flattened cache during the decoding phase
    def _update_metadata(self, key_states):
        bsz, num_key_value_heads, seqlen, head_dim = key_states.shape

        self.metadata.max_seqlen_k += seqlen
        self.metadata.cu_seqlens_k += self.metadata.cu_offset * seqlen
        self.metadata.head_lens += seqlen
        self.metadata.seen_tokens += seqlen

    # init the metadata for the flattened cache during the prefilling phase
    def _init_metadata(self, key_states):
        """
        Initialize the metadata for the flattened cache. The input `key_states` is a regular key tensor
        with shape [bsz, num_key_value_heads, seqlen, head_dim].
        """
        bsz, num_key_value_heads, k_len, head_dim = key_states.shape
        k_lens = bsz * num_key_value_heads * [k_len]
        _device = key_states.device
        max_seqlen_k = max(k_lens)

        # cumulative KV lengths, one segment per (batch, kv-head) pair of the flattened cache
        head_seqlens_k = torch.tensor(k_lens, dtype=torch.int32, device=_device)
        cu_seqlens = torch.cumsum(head_seqlens_k, dim=0, dtype=torch.int32)
        cu_seqlens_k = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=_device), cu_seqlens], dim=0)

        # during decoding, each (batch, kv-head) segment contributes exactly one query token
        decoding_q_lens = bsz * num_key_value_heads * [1]
        decoding_head_seqlens_q = torch.tensor(decoding_q_lens, dtype=torch.int32, device=_device)
        decoding_cu_seqlens_q = torch.cumsum(decoding_head_seqlens_q, dim=0, dtype=torch.int32)
        decoding_cu_seqlens_q = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=_device), decoding_cu_seqlens_q], dim=0)

        cu_offset = torch.arange(0, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device)
        cu_head_offset = torch.arange(1, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device)

        self.metadata = MetaData(
            decoding_cu_seqlens_q=decoding_cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=max_seqlen_k,
            cu_offset=cu_offset,
            cu_head_offset=cu_head_offset,
            head_lens=head_seqlens_k,
            bsz=bsz,
            num_key_value_heads=num_key_value_heads,
            seen_tokens=k_len,
        )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

        # used to store the metadata for the flattened cache
        self.metadata = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if not isinstance(past_key_value, DynamicCacheSplitHeadFlatten):
            raise ValueError(
                "The current implementation of `AdaKV` only supports `DynamicCacheSplitHeadFlatten` as the cache type."
            )
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # reshape to [bsz, num_heads, q_len, head_dim] so that RoPE can be applied per head;
        # the flattened varlen layout required by flash_attn_varlen_func is built further below
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "attn": self}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input"
                f" in {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # flatten the queries so that each (batch, kv-head) pair becomes an independent varlen sequence:
        # [bsz, num_heads, q_len, head_dim] -> [bsz * num_kv_heads * q_len, num_key_value_groups, head_dim]
        query_states = query_states.view(bsz, -1, self.num_key_value_groups, q_len, self.head_dim)
        query_states = query_states.transpose(2, 3)
        query_states = query_states.reshape(-1, self.num_key_value_groups, self.head_dim)

        # the keys and values returned by `DynamicCacheSplitHeadFlatten` are already flattened: one token per row
        key_states = key_states.view(-1, 1, self.head_dim)
        value_states = value_states.view(-1, 1, self.head_dim)

        if q_len == 1:
            # decoding phase: reuse the precomputed cumulative query lengths (one query token per kv head)
            cu_seqlens_q = self.metadata.decoding_cu_seqlens_q
            max_seqlen_q = 1
        else:
            # prefilling phase: build cumulative query lengths for the flattened query states
            prefill_q_lens = bsz * (self.num_heads // self.num_key_value_groups) * [q_len]
            head_seqlens_q = torch.tensor(prefill_q_lens, dtype=torch.int32, device=query_states.device)
            cu_seqlens_q = torch.cumsum(head_seqlens_q, dim=0, dtype=torch.int32)
            cu_seqlens_q = torch.cat(
                [torch.tensor([0], dtype=torch.int32, device=query_states.device), cu_seqlens_q], dim=0)
            max_seqlen_q = q_len

        cu_seqlens_k = self.metadata.cu_seqlens_k
        max_seqlen_k = self.metadata.max_seqlen_k

        # TODO: support batch size > 1
        assert bsz == 1
        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q,
                                             cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True)

        attn_output = attn_output.reshape(bsz, self.num_key_value_heads, q_len, self.num_key_value_groups, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

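To make the metadata bookkeeping above concrete, here is a small, self-contained sketch in plain PyTorch (no flash_attn required) of how `head_lens` and `cu_seqlens_k` evolve from prefill to a single decoding step, mirroring `_init_metadata` and `_update_metadata`. The toy sizes (one sequence, two KV heads, five prefilled tokens per head) are invented for illustration and are not part of the commit.

import torch

# Toy sizes: 1 sequence, 2 KV heads, 5 prefilled tokens per head (illustrative only).
bsz, num_key_value_heads, k_len = 1, 2, 5

# Prefill (mirrors `_init_metadata`): every (batch, kv-head) segment starts with k_len tokens,
# so the flattened cache holds bsz * num_key_value_heads * k_len rows and cu_seqlens_k marks
# where each segment begins and ends.
head_lens = torch.full((bsz * num_key_value_heads,), k_len, dtype=torch.int32)
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32),
                          torch.cumsum(head_lens, dim=0, dtype=torch.int32)])
cu_offset = torch.arange(0, bsz * num_key_value_heads + 1, dtype=torch.int32)
print(cu_seqlens_k)   # tensor([ 0,  5, 10], dtype=torch.int32)

# One decoding step appends one token to every head (mirrors `_update_metadata`):
# segment i grows by 1, so boundary i shifts by i, which is exactly cu_offset * seqlen.
seqlen = 1
head_lens += seqlen
cu_seqlens_k = cu_seqlens_k + cu_offset * seqlen
print(head_lens)      # tensor([6, 6], dtype=torch.int32)
print(cu_seqlens_k)   # tensor([ 0,  6, 12], dtype=torch.int32)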
@@ -0,0 +1,108 @@
# Copyright (c) 2024 YuanFeng
#
# This file is part of the YuanFeng project and is licensed under the MIT License.
# SPDX-License-Identifier: MIT

from typing import List, Optional, Tuple

import torch
from transformers.cache_utils import Cache


class DynamicCacheSplitHeadFlatten(Cache):
    """
    Flattened KV cache layout with a customized update kernel. Each layer stores its keys (and values) as a
    single flattened tensor of shape [bsz * (head_0_len + head_1_len + ... + head_n_len), head_dim], so every
    KV head can keep a different number of tokens.
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0

    def __len__(self):
        return len(self.key_cache)

    def __iter__(self):
        for layer_idx in range(len(self)):
            yield (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))

    def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        if layer_idx < len(self):
            return (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # NOTE: the incoming k, v have the regular shape [bsz, num_key_value_heads, seqlen, head_dim];
        # each layer is stored in a flattened layout:
        # [bsz * (head_0_len + head_1_len + ... + head_n_len), head_dim]
        attn = cache_kwargs.get("attn", None)
        if len(self.key_cache) <= layer_idx:
            # prefilling: flatten key and value states
            bsz, head_num, seqlen, head_dim = key_states.shape
            flatten_key_cache = key_states.reshape(bsz * head_num * seqlen, head_dim)
            flatten_value_cache = value_states.reshape(bsz * head_num * seqlen, head_dim)
            self.key_cache.append(flatten_key_cache)
            self.value_cache.append(flatten_value_cache)

            # init metadata for the flattened key states
            attn._init_metadata(key_states)
            self._seen_tokens = attn.metadata.seen_tokens
        else:
            # decoding
            assert self.key_cache[layer_idx].dim() == 2
            bsz, head_num, seqlen, head_dim = key_states.shape

            # TODO: currently only supports bsz == 1
            assert bsz == 1, f"bsz: {bsz}"
            # NOTE: phase 2. we get [bsz, num_heads, seqlen, head_dim] as the k, v input
            head_lens = attn.metadata.head_lens
            cu_seqlens_k = attn.metadata.cu_seqlens_k

            # TODO: wrap as a python interface
            from tiny_api_cuda import update_flatten_klenN_view
            new_key_cache = update_flatten_klenN_view(
                self.key_cache[layer_idx].view(-1, head_dim), key_states, head_lens, cu_seqlens_k)
            new_value_cache = update_flatten_klenN_view(
                self.value_cache[layer_idx].view(-1, head_dim), value_states, head_lens, cu_seqlens_k)

            self.key_cache[layer_idx] = new_key_cache
            self.value_cache[layer_idx] = new_value_cache

            # update metadata
            attn._update_metadata(key_states)
            self._seen_tokens = attn.metadata.seen_tokens

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0

        # TODO: return 1 for now to signal that the cache has content; the per-head lengths live in the attention metadata
        return 1

    def get_max_length(self) -> Optional[int]:
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCacheSplitHeadFlatten` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "DynamicCacheSplitHeadFlatten":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCacheSplitHeadFlatten`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.key_cache.append(key_states)
                cache.value_cache.append(value_states)

        # TODO: seen tokens should be recovered from the legacy cache
        cache._seen_tokens = None
        return cache