Add ComposedPress and compress method (#29)
SimJeg authored Dec 10, 2024
1 parent 9ecd556 commit 715f8a7
Showing 9 changed files with 202 additions and 167 deletions.
17 changes: 12 additions & 5 deletions README.md
@@ -19,7 +19,7 @@ pip install flash-attn --no-build-isolation

## Usage

This repository provides a set of "presses" that compress the KV cache by pruning the least important key-value pairs in each attention head. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that controls the amount of pruning. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
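A minimal sketch of the usage described above (the checkpoint name and compression ratio are illustrative assumptions; any press exported by `kvpress` can be substituted):

```python
from transformers import pipeline

from kvpress import ExpectedAttentionPress

# Illustrative checkpoint: any decoder-only chat model supported by kvpress should work
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda:0",
)

context = "A very long text you want to compress once and for all"
question = "A question about the compressed context"  # optional

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```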



@@ -41,7 +41,7 @@ answer = pipe(context, question=question, press=press)["answer"]
In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example.

> [!IMPORTANT]
> We focus on pruning during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long-context prompts. This typically applies to improving prompt caching systems.
> We focus on compression during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long-context prompts. This typically applies to improving prompt caching systems.
> [!NOTE]
> To use the `ObservedAttentionPress`, use `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention).
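For example, a sketch of the `ObservedAttentionPress` setup (the checkpoint name is again an illustrative assumption):

```python
from transformers import pipeline

from kvpress import ObservedAttentionPress

pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # illustrative checkpoint
    model_kwargs={"attn_implementation": "eager"},  # materialize attention weights
)
press = ObservedAttentionPress(compression_ratio=0.5)
```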
@@ -51,16 +51,23 @@ In the snippet above, the compression is only applied on the context tokens so t
We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones, or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.

## Available presses
All current presses are training free. We provide the following presses associated with the following scores:

All current presses are training-free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with the lowest importance:

- `RandomPress`: random score
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

We also provide presses relying on a different logic:
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))

Finally, we provide two special presses:
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
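For instance, a sketch combining a token-pruning press with `ThinKPress` through `ComposedPress` (ratios and checkpoint are illustrative, and ThinK's ratio parameter name may differ across versions):

```python
from transformers import pipeline

from kvpress import ComposedPress, SnapKVPress, ThinKPress

# Prune 50% of the KV pairs with SnapKV, then compress the key channels with ThinK
press = ComposedPress(presses=[SnapKVPress(0.5), ThinKPress(0.5)])

pipe = pipeline("kv-press-text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct")
context = "A very long text you want to compress"
answer = pipe(context, question="A question about the context", press=press)["answer"]
```

As implemented in this commit, the composed `compression_ratio` is tracked as the product of the individual ratios once the forward hooks have run.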

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -13,9 +13,11 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.composed_press import ComposedPress

__all__ = [
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
88 changes: 81 additions & 7 deletions kvpress/presses/base_press.py
@@ -9,22 +9,66 @@

import torch
from torch import nn
from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM
from transformers import (
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
PreTrainedModel,
Qwen2ForCausalLM,
QuantizedCache,
)

logger = logging.getLogger(__name__)


@dataclass
class BasePress:
"""
Base class for all pruning methods.
The `forward_hook` method is called after the forward pass of an attention layer.
Any pruning/updating method should be implemented in the derived class.
Base class for all KV cache compression methods.
The `forward_hook` method is called after the forward pass of an attention layer to update the cache.
"""

def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
The core logic of the compression method.
Parameters
----------
module :
Transformer layer, see `hook` method for more details
hidden_states :
Hidden states of the layer
keys :
Keys of the cache (unquantized)
values :
Values of the cache (unquantized)
attentions :
Attention weights of the layer
kwargs :
Keyword arguments, as given to the forward pass of the layer
Returns
-------
tuple[torch.Tensor, torch.Tensor]
Updated keys and values
"""

raise NotImplementedError("compress method must be implemented in subclass")

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""Cache compression hook called after the forward pass of an attention layer.
The hook is applied only during the pre-filling phase if there is some pruning ratio.
"""
Default forward hook called after the forward pass of an attention layer.
The hook calls the compress method to compress the KV cache while ensuring:
- compression is only applied during the pre-filling phase
- KV cache quantization is handled correctly
Parameters
----------
@@ -40,8 +84,38 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
Returns
-------
Modified output of the forward pass of the layer.
"""
raise NotImplementedError("forward_hook method must be implemented in the derived class")
# See e.g. LlamaDecoderLayer.forward for the output structure
if len(output) == 3:
_, attentions, cache = output
else:
attentions, cache = None, output[-1]

hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if cache.seen_tokens > q_len:
return output

if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)

if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values

return output

@contextmanager
def __call__(self, model: PreTrainedModel) -> Generator:
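To illustrate the new `compress` API, here is a hypothetical press that is not part of kvpress: it only implements `compress` and relies on the inherited `forward_hook` to skip the decoding phase and to handle quantized caches. The commented usage assumes the `__call__` context manager above registers the hook on every attention layer.

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress


@dataclass
class TruncatePress(BasePress):
    """Hypothetical example: drop the first `n_skip` KV pairs of every head."""

    n_skip: int = 128  # number of leading KV pairs to prune (illustrative)

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # keys and values have shape (bsz, num_kv_heads, seq_len, head_dim):
        # prune along the sequence axis and let forward_hook update the cache
        return keys[:, :, self.n_skip :], values[:, :, self.n_skip :]


# Direct usage outside the pipeline, assuming `model` and `input_ids` exist:
# with TruncatePress(n_skip=128)(model):
#     model(input_ids)
```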
21 changes: 21 additions & 0 deletions kvpress/presses/composed_press.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from kvpress.presses.base_press import BasePress


@dataclass
class ComposedPress(BasePress):
"""
Chain multiple presses together to create a composed press
"""

presses: list[BasePress]

def __post_init__(self):
self.compression_ratio = None

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1.0
for press in self.presses:
output = press.forward_hook(module, input, kwargs, output)
self.compression_ratio *= press.compression_ratio # type: ignore
return output
99 changes: 24 additions & 75 deletions kvpress/presses/scorer_press.py
@@ -7,7 +7,6 @@

import torch
from torch import nn
from transformers import QuantizedCache

from kvpress.presses.base_press import BasePress

@@ -18,8 +17,9 @@
class ScorerPress(BasePress):
"""
Default press method for using a score method.
The `forward_hook` method is called after the forward pass of an attention layer.
and updates the cache with the pruned KV pairs.
Any ScorerPress subclass must implement the `score` method that computes a tensor of scores for each key-value pair.
The KV pairs with the lowest scores will be pruned in the `compress` method.
The cache is uniformly pruned across all heads and layers using the compression_ratio parameter.
"""

compression_ratio: float = 0.0
@@ -36,87 +36,36 @@ def score(
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
"""Compute the scores for each KV pair in the layer.
Parameters
----------
module :
Transformer layer, see `hook` method for more details.
hidden_states :
Hidden states of the layer.
keys :
Keys of the cache. Note keys are after RoPE.
values :
Values of the cache.
attentions :
Attention weights of the layer.
kwargs :
Keyword arguments, as given to the forward pass of the layer.
Returns
-------
Scores for each KV pair in the layer, shape keys.shape[:-1].
"""
raise NotImplementedError

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""
Default cache compression hook called after the forward pass of an attention layer.
The hook is applied only during the pre-filling phase if there is some pruning ratio.
This implementation allows to remove a constant number of KV pairs.
Parameters
----------
module :
Transformer attention layer.
input :
Input to the hook. This is the input to the forward pass of the layer.
kwargs :
Keyword arguments, as given to the forward pass of the layer.
output :
Output of the hook. This is the original output of the forward pass of the layer.
Returns
-------
Modified output of the forward pass of the layer.
Compute a tensor of scores with shape (bsz, num_key_value_heads, q_len)
The KV pairs with lowest scores will be pruned in the `compress` method.
"""
# See e.g. LlamaDecoderLayer.forward for the output structure
if len(output) == 3:
_, attentions, cache = output
else:
attentions, cache = None, output[-1]

hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]
raise NotImplementedError

# Don't compress if the compression ratio is 0 or this is not pre-filling
if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
return output
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:

if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]
if self.compression_ratio == 0:
return keys, values

with torch.no_grad():
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Compute scores
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

# Prune KV pairs with the lowest scores
# Get indices of KV pairs with the lowest scores
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Update cache
# Prune keys and values
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values

return output

return keys, values
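For comparison, a hypothetical scorer press (not part of kvpress) showing the refactored contract: a `ScorerPress` subclass only implements `score`, while `compress` and `forward_hook` are inherited.

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ValueNormPress(ScorerPress):
    """Hypothetical press scoring each KV pair by the L2 norm of its value vector."""

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Shape (bsz, num_key_value_heads, q_len): pairs with the lowest norm are pruned
        return values.norm(dim=-1)
```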