Add ComposedPress and compress method #29

Merged on Dec 10, 2024 · 8 commits
17 changes: 12 additions & 5 deletions README.md
@@ -19,7 +19,7 @@ pip install flash-attn --no-build-isolation

## Usage

This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
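A minimal sketch of such a pipeline call, consistent with the snippet referenced in the hunk below (the model name and prompt strings are assumptions for illustration):

```python
from transformers import pipeline

from kvpress import ExpectedAttentionPress

# Model name is an assumption for illustration; any supported chat model should work
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda:0",
    torch_dtype="auto",
)

context = "A very long text you want to compress once and for all"
question = "A question about the compressed context"

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```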



@@ -41,7 +41,7 @@ answer = pipe(context, question=question, press=press)["answer"]
In the snippet above, the compression is only applied to the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example.

> [!IMPORTANT]
> We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long-context prompts. This would typically apply to improving prompt caching systems.

> [!NOTE]
> To use the `ObservedAttentionPress`, pass `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention), as sketched below.
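A hedged sketch of this setup (the model name is an assumption for illustration):

```python
from transformers import pipeline

from kvpress import ObservedAttentionPress

# Eager attention materializes the attention weights that ObservedAttentionPress scores
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumption for illustration
    model_kwargs={"attn_implementation": "eager"},
)
press = ObservedAttentionPress(compression_ratio=0.5)
```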
@@ -51,16 +51,23 @@ In the snippet above, the compression is only applied on the context tokens so t
We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.

## Available presses

All current presses are training-free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with the lowest importance:

- `RandomPress`: random score
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

We also provide presses relying on a different logic:
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))

Finally, we provide two special presses:
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks (see the sketch below)
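
A hedged sketch of composing presses; the press choices and ratios are illustrative, `pipe`, `context`, and `question` are assumed to be defined as in the usage snippet above, and the `compression_ratio` keyword for `ThinKPress` is an assumption based on the attribute read in `ComposedPress.forward_hook`:

```python
from kvpress import ComposedPress, KnormPress, ThinKPress

# Illustrative composition: token pruning (KnormPress) chained with
# key-channel compression (ThinKPress); the ratios are arbitrary here
press = ComposedPress(presses=[KnormPress(compression_ratio=0.5), ThinKPress(compression_ratio=0.5)])
answer = pipe(context, question=question, press=press)["answer"]
```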

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression).

2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -13,9 +13,11 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.composed_press import ComposedPress

__all__ = [
    "BasePress",
    "ComposedPress",
    "ScorerPress",
    "ExpectedAttentionPress",
    "KnormPress",
88 changes: 81 additions & 7 deletions kvpress/presses/base_press.py
@@ -9,22 +9,66 @@

import torch
from torch import nn
from transformers import (
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    PreTrainedModel,
    Qwen2ForCausalLM,
    QuantizedCache,
)

logger = logging.getLogger(__name__)


@dataclass
class BasePress:
    """
    Base class for all KV cache compression methods.
    The `forward_hook` method is called after the forward pass of an attention layer to update the cache.
    """

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        The core logic of the compression method.

        Parameters
        ----------
        module :
            Transformer layer, see `hook` method for more details
        hidden_states :
            Hidden states of the layer
        keys :
            Keys of the cache (unquantized)
        values :
            Values of the cache (unquantized)
        attentions :
            Attention weights of the layer
        kwargs :
            Keyword arguments, as given to the forward pass of the layer

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Updated keys and values
        """

        raise NotImplementedError("compress method must be implemented in subclass")

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Default forward hook called after the forward pass of an attention layer.
        The hook calls the compress method to compress the KV cache while ensuring:
        - compression is only applied during the pre-filling phase
        - KV cache quantization is handled correctly

        Parameters
        ----------
@@ -40,8 +84,38 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        Returns
        -------
        Modified output of the forward pass of the layer.
        """
        # See e.g. LlamaDecoderLayer.forward for the output structure
        if len(output) == 3:
            _, attentions, cache = output
        else:
            attentions, cache = None, output[-1]

        hidden_states = kwargs["hidden_states"]
        q_len = hidden_states.shape[1]

        # Don't compress if this is not the pre-filling phase
        if cache.seen_tokens > q_len:
            return output

        if isinstance(cache, QuantizedCache):
            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
        else:
            keys = cache.key_cache[module.layer_idx]
            values = cache.value_cache[module.layer_idx]

        keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)

        if isinstance(cache, QuantizedCache):
            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
        else:
            cache.key_cache[module.layer_idx] = keys
            cache.value_cache[module.layer_idx] = values

        return output

    @contextmanager
    def __call__(self, model: PreTrainedModel) -> Generator:
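With this refactoring, a press only needs to implement `compress`. Below is a hypothetical subclass (not part of this PR) that keeps the most recent KV pairs, similar in spirit to `StreamingLLMPress`:

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress


@dataclass
class RecentTokensPress(BasePress):
    """Hypothetical press: keep only the last `n_kept` KV pairs."""

    n_kept: int = 1024

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # keys and values have shape (bsz, num_key_value_heads, seq_len, head_dim);
        # slicing keeps everything when seq_len <= n_kept
        return keys[:, :, -self.n_kept :], values[:, :, -self.n_kept :]
```

Since `BasePress.__call__` is a context manager over a model (see above), such a press would be applied with `with press(model): ...`.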
18 changes: 18 additions & 0 deletions kvpress/presses/composed_press.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass

from kvpress.presses.base_press import BasePress


@dataclass
class ComposedPress(BasePress):
    """
    Chain multiple presses together to create a composed press
    """

    presses: list[BasePress]

    def forward_hook(self, module, input, kwargs, output):
        self.compression_ratio = 1
        for press in self.presses:
            output = press.forward_hook(module, input, kwargs, output)
            self.compression_ratio *= press.compression_ratio  # type: ignore
        return output
99 changes: 24 additions & 75 deletions kvpress/presses/scorer_press.py
@@ -7,7 +7,6 @@

import torch
from torch import nn

from kvpress.presses.base_press import BasePress

@@ -18,8 +17,9 @@
class ScorerPress(BasePress):
    """
    Default press method for using a score method.
    Any ScorerPress subclass must implement the `score` method that computes a tensor of scores for each key-value pair.
    The KV pairs with the lowest scores will be pruned in the `compress` method.
    The cache is uniformly pruned across all heads and layers using the compression_ratio parameter.
    """

    compression_ratio: float = 0.0
@@ -36,87 +36,36 @@ def score(
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        """
        Compute a tensor of scores with shape (bsz, num_key_value_heads, q_len).
        The KV pairs with the lowest scores will be pruned in the `compress` method.
        """
        raise NotImplementedError

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        if self.compression_ratio == 0:
            return keys, values

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Get indices of the KV pairs with the highest scores (the rest are pruned)
        q_len = hidden_states.shape[1]
        n_kept = int(q_len * (1 - self.compression_ratio))
        indices = scores.topk(n_kept, dim=-1).indices
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        # Prune keys and values
        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values
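
Likewise, a minimal hypothetical `ScorerPress` subclass (not part of this PR) only implements `score`, returning a tensor of shape (bsz, num_key_value_heads, q_len) as the docstring above specifies. This sketch scores KV pairs by negative key norm, akin to `KnormPress`:

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class NegativeKnormPress(ScorerPress):
    """Hypothetical scorer: higher score for smaller key norm, so large-norm keys are pruned first."""

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Shape (bsz, num_key_value_heads, q_len), matching the contract above
        return -keys.norm(dim=-1)
```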