init_adakv
add_AdaKV_initial_version

AdaKV

Signed-off-by: FFY0 <ffyfengyuan@gmail.com>
FFY0 committed Nov 30, 2024
1 parent 51f3877 commit 662f3f2
Showing 15 changed files with 1,128 additions and 13 deletions.
10 changes: 10 additions & 0 deletions .gitignore
@@ -166,3 +166,13 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Vscode
.vscode/

# bash
evaluation/modified_evaluate.py
evaluation/A100_evaluate.sh
evaluation/4090_evaluate_ada.sh
evaluation/logs/*
evaluation/results/*
2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -12,6 +12,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.ada_snapkv_press import AdaSnapKVPress

__all__ = [
"BasePress",
@@ -24,4 +25,5 @@
"TOVAPress",
"KVPressTextGenerationPipeline",
"apply_per_layer_compression",
"AdaSnapKVPress"
]
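For orientation, here is a hedged usage sketch of the newly exported AdaSnapKVPress with the kvpress text-generation pipeline. The press module itself (ada_snapkv_press.py) is not shown in this excerpt, so the constructor argument compression_ratio, the model name, and the pipeline options below are assumptions based on how the other presses are used; treat this as a sketch, not the committed API.

# Hypothetical usage sketch, not part of this commit.
from transformers import pipeline
from kvpress import AdaSnapKVPress

pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed model choice
    device="cuda",
    torch_dtype="auto",
)
press = AdaSnapKVPress(compression_ratio=0.5)  # assumed parameter name
context = "...long document..."                # placeholder inputs
question = "What does AdaKV change?"
answer = pipe(context, question=question, press=press)["answer"]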
247 changes: 247 additions & 0 deletions kvpress/ada_attn.py
@@ -0,0 +1,247 @@

# Copyright (c) 2024 YuanFeng
#
# This file is part of the YuanFeng project and is licensed under the MIT License.
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from transformers.utils import is_flash_attn_greater_or_equal_2_10
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
)


import logging
from typing import Optional, Tuple

import torch
from transformers import Cache

from kvpress.ada_cache import DynamicCacheSplitHeadFlatten

logger = logging.getLogger(__name__)


from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func


@dataclass
class MetaData:
decoding_cu_seqlens_q: torch.Tensor
cu_seqlens_k: torch.Tensor
max_seqlen_k: int
cu_offset: torch.Tensor
cu_head_offset: torch.Tensor
head_lens: torch.Tensor
bsz: int
num_key_value_heads: int
seen_tokens: int




class AdaLlamaFlashAttention(LlamaAttention):

"""
    Llama flash attention module for AdaKV. This module inherits from `LlamaAttention`, as the weights of the module stay untouched.
    It uses `flash_attn_varlen_func` from the flash_attn library to perform attention over the flattened KV cache layout.
"""

    # update the metadata for the flattened cache when compression changes the per-head lengths
    def _update_metadata_while_compressing(self, head_lens, cu_seqlens_k, max_seqlen_k):
self.metadata.head_lens = head_lens
self.metadata.cu_seqlens_k = cu_seqlens_k
self.metadata.max_seqlen_k = max_seqlen_k


    # extend the metadata when newly decoded tokens are appended to every head
    def _update_metadata(self, key_states):
bs, head, seqlen, dim = key_states.shape

self.metadata.max_seqlen_k += seqlen
self.metadata.cu_seqlens_k += self.metadata.cu_offset * seqlen
self.metadata.head_lens += seqlen
self.metadata.seen_tokens += seqlen

# init the metadata for the flattened cache during the prefilling phase
def _init_metadata(self, key_states):

"""
        This method initializes the metadata for the flattened cache.
        The input `key_states` is a regular key tensor of shape [bsz, num_key_value_heads, seqlen, head_dim].
"""

bsz, num_key_value_heads, k_len, head_dim = key_states.shape
k_lens = bsz * num_key_value_heads * [k_len]
_device = key_states.device
max_seqlen_k = max(k_lens)

head_seqlens_k = torch.tensor(k_lens, dtype=torch.int32, device=_device)
cu_seqlens = torch.cumsum(head_seqlens_k, dim=0, dtype=torch.int32)
cu_seqlens_k = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=_device), cu_seqlens], dim=0)


decoding_q_lens = bsz * num_key_value_heads * [1]
decoding_head_seqlens_q = torch.tensor(decoding_q_lens, dtype=torch.int32,device=_device)
decoding_cu_seqlens_q = torch.cumsum(decoding_head_seqlens_q, dim=0, dtype=torch.int32)
decoding_cu_seqlens_q = torch.cat(
[ torch.tensor([0], dtype=torch.int32, device=_device), decoding_cu_seqlens_q], dim=0)


cu_offset = torch.arange(0, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device)
cu_head_offset = torch.arange(1, bsz * num_key_value_heads + 1, dtype=torch.int32, device=_device)

self.metadata = MetaData(
decoding_cu_seqlens_q = decoding_cu_seqlens_q,
cu_seqlens_k = cu_seqlens_k,
max_seqlen_k = max_seqlen_k,
cu_offset = cu_offset,
cu_head_offset = cu_head_offset,
head_lens = head_seqlens_k,
bsz = bsz,
num_key_value_heads = num_key_value_heads,
seen_tokens= k_len
)
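
        # Illustrative example (comment added for clarity, not executed): with bsz=1,
        # num_key_value_heads=2 and k_len=3, _init_metadata produces
        #   cu_seqlens_k          = [0, 3, 6]   (per-head key boundaries in the flat buffer)
        #   decoding_cu_seqlens_q = [0, 1, 2]   (one query token per head when decoding)
        #   cu_offset             = [0, 1, 2]
        #   cu_head_offset        = [1, 2]
        #   head_lens = [3, 3], max_seqlen_k = 3, seen_tokens = 3
        # After one decoded token, _update_metadata shifts cu_seqlens_k by cu_offset,
        # giving [0, 4, 8] and head_lens = [4, 4]. Once heads are compressed to different
        # lengths, head_lens and cu_seqlens_k become ragged, which is exactly the layout
        # flash_attn_varlen_func consumes.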



def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

# used to store the metadata for the flatten cache
self.metadata = None

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if not isinstance(past_key_value, DynamicCacheSplitHeadFlatten):
raise ValueError(
"current implementation of `AdaKV` only supports `DynamicCacheSplitHeadFlatten` as the cache type."
)
output_attentions = False

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "attn": self}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)

# dropout_rate = self.attention_dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)



        # regroup the flattened queries so that each KV head's group of query heads forms one varlen sequence
        query_states = query_states.view(bsz, -1, self.num_key_value_groups, q_len, self.head_dim)
        query_states = query_states.transpose(2, 3)
        query_states = query_states.reshape(-1, self.num_key_value_groups, self.head_dim)

        key_states = key_states.view(-1, 1, self.head_dim)
        value_states = value_states.view(-1, 1, self.head_dim)


        if q_len == 1:
            # decoding phase: a single query token per head, reuse the precomputed cumulative query lengths
            cu_seqlens_q = self.metadata.decoding_cu_seqlens_q
            max_seqlen_q = 1
        else:
            # prefilling phase: build cumulative query lengths for the flattened query states
            prefill_q_lens = bsz * (self.num_heads // self.num_key_value_groups) * [q_len]
            head_seqlens_q = torch.tensor(prefill_q_lens, dtype=torch.int32, device=query_states.device)
            cu_seqlens_q = torch.cumsum(head_seqlens_q, dim=0, dtype=torch.int32)
            cu_seqlens_q = torch.cat(
                [torch.tensor([0], dtype=torch.int32, device=query_states.device), cu_seqlens_q], dim=0)
            max_seqlen_q = q_len

cu_seqlens_k = self.metadata.cu_seqlens_k
max_seqlen_k = self.metadata.max_seqlen_k


attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q,
cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True)
# TODO: support batch size > 1
assert bsz == 1

attn_output = attn_output.reshape(bsz, self.num_key_value_heads, q_len, self.num_key_value_groups, self.head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

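To make the flattened layout concrete, the sketch below shows what the decoding-path call to flash_attn_varlen_func above computes, written as a naive per-head loop: each key/value head is stored as its own variable-length segment in one flat buffer, indexed by cu_seqlens_k, and the grouped query heads of that KV head attend only to their segment. The tensor names and sizes are illustrative assumptions, and the fused kernel replaces this loop; it is not code from the commit.

# Naive reference for the decoding case (one query token per KV head).
import torch

head_dim, groups = 4, 2                        # e.g. 2 query heads share each KV head (GQA)
head_lens = torch.tensor([5, 3])               # ragged per-head KV lengths after compression
cu_seqlens_k = torch.tensor([0, 5, 8])         # cumulative boundaries into the flat KV buffer

total_k = int(head_lens.sum())
k_flat = torch.randn(total_k, 1, head_dim)     # flattened keys,   [total_k, 1, head_dim]
v_flat = torch.randn(total_k, 1, head_dim)     # flattened values, [total_k, 1, head_dim]
q = torch.randn(head_lens.numel(), groups, head_dim)  # one query token per KV head

outputs = []
for h in range(head_lens.numel()):
    s, e = int(cu_seqlens_k[h]), int(cu_seqlens_k[h + 1])
    k_h, v_h = k_flat[s:e, 0], v_flat[s:e, 0]           # [len_h, head_dim]
    scores = q[h] @ k_h.T / head_dim ** 0.5             # [groups, len_h]
    outputs.append(torch.softmax(scores, dim=-1) @ v_h)
attn_out = torch.stack(outputs)                         # [num_kv_heads, groups, head_dim]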
108 changes: 108 additions & 0 deletions kvpress/ada_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2024 YuanFeng
#
# This file is part of the YuanFeng project and is licensed under the MIT License.
# SPDX-License-Identifier: MIT

from transformers.cache_utils import Cache
from typing import List, Optional, Tuple
import torch


class DynamicCacheSplitHeadFlatten(Cache):

"""
    Flattened KV cache layout with a customized update kernel.
"""

    def __init__(self) -> None:
super().__init__()
self.key_cache: List[List[torch.Tensor]] = []
self.value_cache: List[List[torch.Tensor]] = []
self._seen_tokens = 0

def __len__(self):
return len(self.key_cache)

def __iter__(self):
for layer_idx in range(len(self)):
yield (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))

def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor],Tuple[torch.Tensor]]:
if layer_idx < len(self):
return (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # NOTE: incoming key/value states have shape (bs, num_heads, seqlen, head_dim);
        # each layer is stored in a flattened layout:
        #   [bsz * (head_0_len + head_1_len + ... + head_n_len), head_dim]
attn = cache_kwargs.get("attn", None)
if len(self.key_cache) <= layer_idx:
# prefilling
# flatten key and value
bs, head_num, seqlen, head_dim = key_states.shape
            flatten_key_cache = key_states.reshape(bs * head_num * seqlen, head_dim)
            flatten_value_cache = value_states.reshape(bs * head_num * seqlen, head_dim)
            self.key_cache.append(flatten_key_cache)
            self.value_cache.append(flatten_value_cache)

# init metadata for flatten key states
attn._init_metadata(key_states)
self._seen_tokens = attn.metadata.seen_tokens
else:
# decoding
assert self.key_cache[layer_idx].dim() == 2
bs, head, seqlen, head_dim = key_states.shape

# TODO: Currently only support bs == 1
assert bs == 1 , f"bs: {bs}"
            # NOTE: decoding phase; k, v arrive with shape [bs, head, seqlen, dim]
head_lens = attn.metadata.head_lens
cu_seqlens_k = attn.metadata.cu_seqlens_k

# TODO: wrap as a python interface
from tiny_api_cuda import update_flatten_klenN_view
new_key_cache = update_flatten_klenN_view(self.key_cache[layer_idx].view(-1, head_dim), key_states, head_lens, cu_seqlens_k)
new_value_cache = update_flatten_klenN_view(self.value_cache[layer_idx].view(-1, head_dim), value_states, head_lens, cu_seqlens_k)

self.key_cache[layer_idx] = new_key_cache
self.value_cache[layer_idx] = new_value_cache

# update metadata
attn._update_metadata(key_states)
self._seen_tokens = attn.metadata.seen_tokens

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
if len(self.key_cache) <= layer_idx:
return 0

        # TODO: for now, return 1 to indicate that the cache has content
        return 1

def get_max_length(self) -> Optional[int]:
return None

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx],),)
return legacy_cache

@classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCacheSplitHeadFlatten":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCacheSplitHeadFlatten`."""
cache = cls()
print(f"from_legacy_cache past_key_values")
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.key_cache.append(key_states)
cache.value_cache.append(value_states)

# TODO seen tokens should be updated
cache._seen_tokens = None
return cache
Binary file added kvpress/csrc.zip
Binary file not shown.
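
The decoding-path update in DynamicCacheSplitHeadFlatten.update calls a custom CUDA kernel, update_flatten_klenN_view, shipped as the binary archive csrc.zip whose source is not part of this excerpt. Below is a hedged pure-PyTorch sketch of the behaviour the surrounding code implies: splicing each head's newly decoded tokens after that head's existing segment in the flat buffer. It is an assumption for readability (and a possible reference to test against), not the kernel's actual implementation.

# Hedged pure-PyTorch stand-in; the real kernel fuses this without materializing chunks.
import torch

def update_flatten_reference(flat_cache, new_states, head_lens, cu_seqlens_k):
    # flat_cache: [sum(head_lens), head_dim]
    # new_states: [1, num_heads, new_len, head_dim]  (bs == 1, as asserted in update)
    _, num_heads, _, _ = new_states.shape
    chunks = []
    for h in range(num_heads):
        start = int(cu_seqlens_k[h])
        end = start + int(head_lens[h])
        chunks.append(flat_cache[start:end])   # head h's existing tokens
        chunks.append(new_states[0, h])        # head h's new tokens, appended at its boundary
    return torch.cat(chunks, dim=0)            # new flat buffer with updated boundaries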
