Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SimLayerKVPress #28

Merged
merged 12 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ All current presses are training free. Several of them inherit from `ScorerPress
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio.
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
Expand All @@ -23,6 +24,7 @@
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
Expand Down
99 changes: 99 additions & 0 deletions kvpress/presses/simlayerkv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress

logger = logging.getLogger(__name__)


@dataclass
class SimLayerKVPress(BasePress):
"""
SimLayerKV (https://arxiv.org/abs/2410.13846) uses a layer-wise approach to compression:
- layers identified as lazy use the Streaming LLM approach (only initial and recent KV pairs are kept)
- other layers use the full KV cache

To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens
over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy.

Recommended values for lazy_threshold from the official repository:
- llama3: 0.9
- llama2: 0.65
- mistral: 0.8
- qwen: 0.85
By default, lazy_threshold is set to 1.0 (no compression)
(Source: https://github.com/sail-sg/SimLayerKV/blob/main/LongBench/pred.py#L167)
"""

lazy_threshold: float = 1.0
n_last: int = 1 # n_last=1 to match SKLV-decode
n_recent: int = 1024
n_initial: int = 4

def __post_init__(self):
assert 0.0 <= self.lazy_threshold <= 1.0, "lazy_threshold should be in [0, 1]"
if self.lazy_threshold == 1.0:
self.compression_ratios = [0.0]
else:
self.compression_ratios = []

def is_lazy(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
) -> bool:
"""
Compute the attention weights of the last tokens over the initial and recent tokens.
The layer is considered lazy if the sum of these attention weights is above the lazy_threshold.
"""

attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last)
attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size
score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum()
return score.item() > self.lazy_threshold

@property
def compression_ratio(self):
if len(self.compression_ratios) > 0:
return sum(self.compression_ratios) / len(self.compression_ratios)
else:
raise ValueError("Forward pass must be run to compute the compression ratio")

@compression_ratio.setter
def compression_ratio(self, value):
raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")

def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:

if self.lazy_threshold == 1.0:
SimJeg marked this conversation as resolved.
Show resolved Hide resolved
return keys, values

# Sanity check and compression_ratios initialization
q_len = hidden_states.shape[1]
if module.layer_idx == 0:
self.compression_ratios = []
min_length = self.n_initial + self.n_recent + self.n_last
assert q_len >= min_length, f"Query length should be greater than the window size ({min_length})"

if self.is_lazy(module, hidden_states, keys):
# If layer is lazy, only keep the initial and recent KV pairs
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2)
values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2)
self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len)
else:
self.compression_ratios.append(0)

return keys, values
19 changes: 11 additions & 8 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class SnapKVPress(ScorerPress):
window_size: int = 64
kernel_size: int = 5

def compute_window_attention(self, module, hidden_states, keys):
@staticmethod
def compute_window_attention(
module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int
) -> torch.Tensor:
"""
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""
Expand All @@ -34,28 +37,28 @@ def compute_window_attention(self, module, hidden_states, keys):

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -self.window_size :])
query_states = module.q_proj(hidden_states[:, -window_size:])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : module.num_heads * module.head_dim]
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2)

# Apply RoPE
position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, module.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
attention_mask = torch.ones_like(attn_weights) * float("-inf")
attention_mask = torch.triu(attention_mask, diagonal=q_len - self.window_size + 1)
attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = attn_weights[..., : -self.window_size]
attn_weights = attn_weights[..., :-window_size]

return attn_weights

Expand All @@ -76,7 +79,7 @@ def score(
if attentions is not None:
attn_weights = attentions[..., -self.window_size :, : -self.window_size]
else:
attn_weights = self.compute_window_attention(module, hidden_states, keys)
attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size)

scores = attn_weights.mean(dim=-2)
scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
Expand Down
8 changes: 4 additions & 4 deletions kvpress/presses/tova_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import torch.nn.functional as F
from torch import nn

from kvpress import SnapKVPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class TOVAPress(SnapKVPress):
class TOVAPress(ScorerPress):
"""
TOVA (https://arxiv.org/abs/2401.06104) use the attention of the last token averaged across heads
to estimate the importance of the previous KV pairs. This press was reviewed by Michael Hassid,
Expand All @@ -22,7 +23,6 @@ class TOVAPress(SnapKVPress):
"""

compression_ratio: float = 0.0
window_size: int = 1 # re-use the attention weight computation from SnapKVPress for last token

def score(
self,
Expand All @@ -37,7 +37,7 @@ def score(
if attentions is not None:
attn_weights = attentions[..., -1:, :-1]
else:
attn_weights = self.compute_window_attention(module, hidden_states, keys)
attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, 1)

# Average across heads and repeat num_key_value_head times
scores = attn_weights.mean(1)
Expand Down
26 changes: 21 additions & 5 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ObservedAttentionPress,
RandomPress,
SnapKVPress,
SimLayerKVPress,
StreamingLLMPress,
TOVAPress,
)
Expand All @@ -31,14 +32,29 @@ def test_composed_press(unit_test_model): # noqa: F811


def test_presses_run(unit_test_model): # noqa: F811
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]:
for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
for cls in [
KnormPress,
ExpectedAttentionPress,
RandomPress,
StreamingLLMPress,
SimLayerKVPress,
SnapKVPress,
TOVAPress,
ThinKPress,
]:
for value in [0.2, 0.4, 0.6, 0.8]:

# Load the press
if cls == ThinKPress:
press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
press = cls(key_channel_compression_ratio=value, window_size=2)
elif cls == SimLayerKVPress:
press = cls(lazy_threshold=value)
else:
press = cls(compression_ratio=compression_ratio)
if cls in [SnapKVPress]:
press = cls(compression_ratio=value)
if cls == SnapKVPress:
press.window_size = 2

# Run the press
with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
Expand Down
Loading