Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SimLayerKVPress #28

Merged
merged 12 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ All current presses are training free. We provide the following presses associat
- `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio.
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.

Expand Down
4 changes: 3 additions & 1 deletion kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

__all__ = [
"BasePress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
Expand Down
89 changes: 89 additions & 0 deletions kvpress/presses/simlayerkv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from dataclasses import dataclass

import torch
from torch import nn
from transformers import QuantizedCache

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class SimLayerKVPress(BasePress):
"""
SimLayerKV (https://arxiv.org/abs/2410.13846) uses a layer-wise approach to compression:
- layers identified as lazy use the Streaming LLM approach (only initial and recent KV pairs are kept)
- other layers use the full KV cache

To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens
over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy.

Official implementation: https://github.com/sail-sg/SimLayerKV. We use n_last=1 to match SKLV-decode
"""

lazy_threshold: float
SimJeg marked this conversation as resolved.
Show resolved Hide resolved
n_last: int = 1
n_recent: int = 1024
n_initial: int = 4

def is_lazy(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
) -> bool:
"""
Compute the average attention weights of the last tokens over the initial and recent tokens.
A slight difference with the original implementation is that we
maxjeblick marked this conversation as resolved.
Show resolved Hide resolved
"""

attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last)
attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size
score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum()
return score.item() > self.lazy_threshold

@property
def compression_ratio(self):
if hasattr(self, "compression_ratios"):
maxjeblick marked this conversation as resolved.
Show resolved Hide resolved
return sum(self.compression_ratios) / len(self.compression_ratios)
else:
raise ValueError("Forward pass must be run to compute the compression ratio")

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):

cache = output[-1]
hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]

# Don't compress if this is not pre-filling or if there are not enough tokens
if (cache.seen_tokens > q_len) or (cache.seen_tokens < self.n_initial + self.n_recent):
return output

# Re-initialize the compression_ratios list
if module.layer_idx == 0:
self.compression_ratios = []
assert hidden_states.shape[1] > self.n_last, "Query length should be greater than the window size"

if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

if self.is_lazy(module, hidden_states, keys):
# If layer is lazy, only keep the initial and recent KV pairs
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2)
values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2)
self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len)
else:
self.compression_ratios.append(0)

if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values

return output
19 changes: 11 additions & 8 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class SnapKVPress(BasePress):
window_size: int = 64
kernel_size: int = 5

def compute_window_attention(self, module, hidden_states, keys):
@staticmethod
def compute_window_attention(
module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int
) -> torch.Tensor:
"""
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""
Expand All @@ -34,28 +37,28 @@ def compute_window_attention(self, module, hidden_states, keys):

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -self.window_size :])
query_states = module.q_proj(hidden_states[:, -window_size:])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : module.num_heads * module.head_dim]
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2)

# Apply RoPE
position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, module.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
attention_mask = torch.ones_like(attn_weights) * float("-inf")
attention_mask = torch.triu(attention_mask, diagonal=q_len - self.window_size + 1)
attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = attn_weights[..., : -self.window_size]
attn_weights = attn_weights[..., :-window_size]

return attn_weights

Expand All @@ -76,7 +79,7 @@ def score(
if attentions is not None:
attn_weights = attentions[..., -self.window_size :, : -self.window_size]
else:
attn_weights = self.compute_window_attention(module, hidden_states, keys)
attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size)

scores = attn_weights.mean(dim=-2)
scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
Expand Down
6 changes: 3 additions & 3 deletions kvpress/presses/tova_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from torch import nn
import torch.nn.functional as F

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class TOVAPress(SnapKVPress):
class TOVAPress(BasePress):
"""
TOVA (https://arxiv.org/abs/2401.06104) use the attention of the last token averaged across heads
to estimate the importance of the previous KV pairs. This press was reviewed by Michael Hassid,
Expand All @@ -21,7 +22,6 @@ class TOVAPress(SnapKVPress):
"""

compression_ratio: float = 0.0
window_size: int = 1 # re-use the attention weight computation from SnapKVPress for last token

def score(
self,
Expand All @@ -36,7 +36,7 @@ def score(
if attentions is not None:
attn_weights = attentions[..., -1:, :-1]
else:
attn_weights = self.compute_window_attention(module, hidden_states, keys)
attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, 1)

# Average across heads and repeat num_key_value_head times
scores = attn_weights.mean(1)
Expand Down
14 changes: 12 additions & 2 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ObservedAttentionPress,
RandomPress,
SnapKVPress,
SimLayerKVPress,
StreamingLLMPress,
TOVAPress,
ThinKPress,
Expand All @@ -29,9 +30,18 @@ def test_think_inner_press(unit_test_model): # noqa: F811


def test_presses_run(unit_test_model): # noqa: F811
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]:
for cls in [
KnormPress,
ExpectedAttentionPress,
RandomPress,
SimLayerKVPress,
StreamingLLMPress,
SnapKVPress,
TOVAPress,
ThinKPress,
]:
for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
press = cls(compression_ratio=compression_ratio)
press = cls(compression_ratio)
maxjeblick marked this conversation as resolved.
Show resolved Hide resolved
if cls in [SnapKVPress, ThinKPress]:
press.window_size = 2
with press(unit_test_model):
Expand Down
Loading