Add ComposedPress and compress method #29

Merged
merged 8 commits on Dec 10, 2024
Changes from 1 commit
Add ComposedPress
SimJeg committed Dec 9, 2024
commit aaa1b18a1e7d39cf084766632c1074005307bff3
3 changes: 2 additions & 1 deletion README.md
@@ -60,7 +60,8 @@ All current presses are training free. We provide the following presses associat
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
-- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.
+- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
+- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -13,9 +13,11 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress
+from kvpress.presses.composed_press import ComposedPress

__all__ = [
"BasePress",
"ComposedPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
18 changes: 18 additions & 0 deletions kvpress/presses/composed_press.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from kvpress.presses.base_press import BasePress


@dataclass
class ComposedPress(BasePress):
"""
Chain multiple presses together to create a composed press
"""

presses: list[BasePress]

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1
        for press in self.presses:
            output = press.forward_hook(module, input, kwargs, output)
            self.compression_ratio *= press.compression_ratio
        return output
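For reference (not part of the diff): a minimal usage sketch of the new `ComposedPress`, modeled on the test added in this PR and the context-manager pattern used throughout kvpress. The model name and prompt are illustrative assumptions, not taken from the PR.

```python
# Minimal sketch of using ComposedPress during pre-filling.
# Assumptions: model name and prompt are placeholders chosen for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

from kvpress import ComposedPress, KnormPress, ThinKPress

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical choice
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Chain a sequence-length press (KnormPress) with a key-dimension press (ThinKPress)
press = ComposedPress([KnormPress(0.5), ThinKPress(compression_ratio=0.5)])

inputs = tokenizer("A long context to compress ...", return_tensors="pt").to(model.device)
with press(model):  # registers the composed forward hooks during pre-filling
    past_key_values = model(**inputs, past_key_values=DynamicCache()).past_key_values

# forward_hook above multiplies the per-press ratios
print(press.compression_ratio)
```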
12 changes: 3 additions & 9 deletions kvpress/presses/think_press.py
@@ -3,7 +3,6 @@


from dataclasses import dataclass
-from typing import Optional

import torch
from torch import nn
@@ -18,7 +17,7 @@ class ThinKPress(BasePress):
"""
ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
Hence it can be combined with any other press that compresses the sequence length, e.g.
press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])

Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
@@ -28,7 +27,6 @@ class ThinKPress(BasePress):
"""

compression_ratio: float = 0.0
inner_press: Optional[BasePress] = None
window_size: int = 32

def compute_window_queries(self, module, hidden_states):
@@ -58,14 +56,10 @@ def compute_window_queries(self, module, hidden_states):

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
-        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
-        we will create a dedicated DimensionBasePress class to avoid code duplication.
+        If other similar presses are requested, we might create a generic forward_hook for dimension pruning
+        to avoid code duplication.
        """

-        # Apply the forward hook of the inner press
-        if self.inner_press is not None:
-            output = self.inner_press.forward_hook(module, input, kwargs, output)
-
        # Don't compress if the compression ratio is 0 or this is not pre-filling
        cache = output[-1]
        hidden_states = kwargs["hidden_states"]
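(Not part of the diff.) Since `inner_press` is removed here, nesting another press inside `ThinKPress` no longer works; a sketch of the equivalent chaining with `ComposedPress`, following the updated docstring example above:

```python
from kvpress import ComposedPress, SnapKVPress, ThinKPress

# Before this PR: sequence-length compression was nested inside ThinKPress
# press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

# After this PR: chain the two presses explicitly instead
press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])
```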
9 changes: 6 additions & 3 deletions tests/presses/test_presses.py
@@ -8,6 +8,7 @@

from kvpress import (
    BasePress,
+    ComposedPress,
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
@@ -21,9 +22,11 @@
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401


-def test_think_inner_press(unit_test_model):  # noqa: F811
-    press = ThinKPress(compression_ratio=0.5, window_size=2, inner_press=KnormPress(0.5))
-    with press(unit_test_model):
+def test_composed_press(unit_test_model):  # noqa: F811
+    press1 = KnormPress(0.5)
+    press2 = ThinKPress(compression_ratio=0.5, window_size=2)
+    composed_press = ComposedPress([press1, press2])
+    with composed_press(unit_test_model):
        input_ids = unit_test_model.dummy_inputs["input_ids"]
        unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
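(Not part of the diff.) A possible extra check for this test, hypothetical and not in the PR: given that `ComposedPress.forward_hook` multiplies the per-press ratios, the composed ratio after pre-filling should equal their product.

```python
    # Hypothetical follow-up assertion (not in the PR): ComposedPress.forward_hook
    # multiplies the per-press ratios during pre-filling, so 0.5 * 0.5 = 0.25 here
    assert composed_press.compression_ratio == press1.compression_ratio * press2.compression_ratio
```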