From aaa1b18a1e7d39cf084766632c1074005307bff3 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Mon, 9 Dec 2024 15:27:03 +0000
Subject: [PATCH 1/7] Add ComposedPress

---
 README.md                         |  3 ++-
 kvpress/__init__.py               |  2 ++
 kvpress/presses/composed_press.py | 18 ++++++++++++++++++
 kvpress/presses/think_press.py    | 12 +++---------
 tests/presses/test_presses.py     |  9 ++++++---
 5 files changed, 31 insertions(+), 13 deletions(-)
 create mode 100644 kvpress/presses/composed_press.py

diff --git a/README.md b/README.md
index 4ebf50b..db85813 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,8 @@ All current presses are training free. We provide the following presses associat
 - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
 - `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
 - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
-- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.
+- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
+- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index 5f7fc8c..77ee5e6 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -13,9 +13,11 @@
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.tova_press import TOVAPress
 from kvpress.presses.think_press import ThinKPress
+from kvpress.presses.composed_press import ComposedPress
 
 __all__ = [
     "BasePress",
+    "ComposedPress",
     "ExpectedAttentionPress",
     "KnormPress",
     "ObservedAttentionPress",
diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py
new file mode 100644
index 0000000..9250818
--- /dev/null
+++ b/kvpress/presses/composed_press.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from kvpress.presses.base_press import BasePress
+
+
+@dataclass
+class ComposedPress(BasePress):
+    """
+    Chain multiple presses together to create a composed press
+    """
+
+    presses: list[BasePress]
+
+    def forward_hook(self, module, input, kwargs, output):
+        self.compression_ratio = 1
+        for press in self.presses:
+            output = press.forward_hook(module, input, kwargs, output)
+            self.compression_ratio *= press.compression_ratio
+        return output
diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py
index 7aae541..498b351 100644
--- a/kvpress/presses/think_press.py
+++ b/kvpress/presses/think_press.py
@@ -3,7 +3,6 @@
 
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 from torch import nn
@@ -18,7 +17,7 @@ class ThinKPress(BasePress):
     """
     ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
     Hence it can be combined with any other press that compresses the sequence length, e.g.
-        press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
+        press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])
 
     Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
     To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
@@ -28,7 +27,6 @@ class ThinKPress(BasePress):
     """
 
     compression_ratio: float = 0.0
-    inner_press: Optional[BasePress] = None
    window_size: int = 32
 
     def compute_window_queries(self, module, hidden_states):
@@ -58,14 +56,10 @@ def compute_window_queries(self, module, hidden_states):
 
         return query_states
 
     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
         """
-        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
-        we will create a dedicated DimensionBasePress class to avoid code duplication.
+        If other similar presses are requested, we might create a generic forward_hook for dimension pruning
+        to avoid code duplication.
         """
 
-        # Apply the forward hook of the inner press
-        if self.inner_press is not None:
-            output = self.inner_press.forward_hook(module, input, kwargs, output)
-
         # Don't compress if the compression ratio is 0 or this is not pre-filling
         cache = output[-1]
         hidden_states = kwargs["hidden_states"]
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index 4c2d361..c1d6888 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -8,6 +8,7 @@
 
 from kvpress import (
     BasePress,
+    ComposedPress,
     ExpectedAttentionPress,
     KnormPress,
     ObservedAttentionPress,
@@ -21,9 +22,11 @@
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401
 
 
-def test_think_inner_press(unit_test_model):  # noqa: F811
-    press = ThinKPress(compression_ratio=0.5, window_size=2, inner_press=KnormPress(0.5))
-    with press(unit_test_model):
+def test_composed_press(unit_test_model):  # noqa: F811
+    press1 = KnormPress(0.5)
+    press2 = ThinKPress(compression_ratio=0.5, window_size=2)
+    composed_press = ComposedPress([press1, press2])
+    with composed_press(unit_test_model):
         input_ids = unit_test_model.dummy_inputs["input_ids"]
         unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

From bc39643728a7d7bc6ce17c15e8d8da3ce959a81b Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Tue, 10 Dec 2024 14:04:25 +0000
Subject: [PATCH 2/7] Fix tests

---
 kvpress/presses/think_press.py | 9 ++-------
 tests/presses/test_presses.py  | 4 ++--
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py
index 5c7f5fe..62186bf 100644
--- a/kvpress/presses/think_press.py
+++ b/kvpress/presses/think_press.py
@@ -100,13 +100,8 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
 
     @property
     def compression_ratio(self):
-        compression_ratio = self.key_channel_compression_ratio / 2
-        if self.inner_press is not None and hasattr(self.inner_press, "compression_ratio"):
-            compression_ratio += self.inner_press.compression_ratio
-        return compression_ratio
+        return self.key_channel_compression_ratio / 2
 
     @compression_ratio.setter
     def compression_ratio(self, value):
-        raise AttributeError(
-            "Cannot set the compression ratio of ThinKPress directly. " "Set key_channel_compression_ratio instead."
-        )
+        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index 3d3030e..b35a955 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -22,8 +22,8 @@
 
 
 def test_composed_press(unit_test_model):  # noqa: F811
-    press1 = KnormPress(key_channel_compression_ratio=0.5)
-    press2 = ThinKPress(compression_ratio=0.5, window_size=2)
+    press1 = KnormPress(compression_ratio=0.5)
+    press2 = ThinKPress(key_channel_compression_ratio=0.5, window_size=2)
     composed_press = ComposedPress([press1, press2])
     with composed_press(unit_test_model):
         input_ids = unit_test_model.dummy_inputs["input_ids"]

From 7fc2d1d965f59479d33f7dc5238bb422ba88772f Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Tue, 10 Dec 2024 14:46:29 +0000
Subject: [PATCH 3/7] Add compress method

---
 README.md                         | 14 +++--
 kvpress/presses/base_press.py     | 88 +++++++++++++++++++++++++---
 kvpress/presses/composed_press.py |  2 +-
 kvpress/presses/scorer_press.py   | 97 ++++++++-----------------------
 kvpress/presses/think_press.py    | 40 +++++--------
 notebooks/new_press.ipynb         | 71 +++++++++++-----------
 6 files changed, 165 insertions(+), 147 deletions(-)

diff --git a/README.md b/README.md
index c7606e0..899497c 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ pip install flash-attn --no-build-isolation
 
 ## Usage
 
-This repository provides a set of "presses" that compress the KV cache by pruning the least important key-value pairs in each attention head. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that controls the amount of pruning. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
+This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
 
 In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example.
 
 > [!IMPORTANT]  
 > We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long context prompts. This would typically apply to improving prompt caching systems.
 
 > [!NOTE]  
 > To use the `ObservedAttentionPress`, use `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention).
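For a concrete sense of the pre-filling-only behavior described above, here is a minimal sketch using the context-manager form that the tests in this series rely on (the checkpoint and prompt are placeholders, not taken from the patch):

```python
# Illustrative sketch: applying a press without the pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from kvpress import KnormPress

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder; any supported architecture (Llama, Mistral, Phi3, Qwen2)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

press = KnormPress(compression_ratio=0.5)
input_ids = tokenizer("A long context to compress", return_tensors="pt").input_ids

with press(model):  # forward hooks are registered only inside this block
    past_key_values = model(input_ids, past_key_values=DynamicCache()).past_key_values

# After pre-filling, the cache holds roughly half of the original KV pairs
```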
@@ -51,16 +51,22 @@ In the snippet above, the compression is only applied on the context tokens so t We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. ## Available presses -All current presses are training free. We provide the following presses associated with the following scores: + +All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance: - `RandomPress`: random score - `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430)) -- `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) - `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469)) - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) - `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453)) - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) +- `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) + +We also provide presses relying on a different logic: - `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)) + +Finally we provide two special presses: +- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental) - `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 6d30ff4..16fc160 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -9,7 +9,14 @@ import torch from torch import nn -from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM +from transformers import ( + LlamaForCausalLM, + MistralForCausalLM, + Phi3ForCausalLM, + PreTrainedModel, + Qwen2ForCausalLM, + QuantizedCache, +) logger = logging.getLogger(__name__) @@ -17,14 +24,51 @@ @dataclass class BasePress: """ - Base class for all pruning methods. - The `forward_hook` method is called after the forward pass of an attention layer. - Any pruning/updating method should be implemented in the derived class. + Base class for all KV cache compression methods. + The `forward_hook` method is called after the forward pass of an attention layer to update the cache. """ + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core logic of the compression method. 
+
+        Parameters
+        ----------
+        module :
+            Transformer layer, see `hook` method for more details
+        hidden_states :
+            Hidden states of the layer
+        keys :
+            Keys of the cache (unquantized)
+        values :
+            Values of the cache (unquantized)
+        attentions :
+            Attention weights of the layer
+        kwargs :
+            Keyword arguments, as given to the forward pass of the layer
+
+        Returns
+        -------
+        tuple[torch.Tensor, torch.Tensor]
+            Updated keys and values
+        """
+
+        raise NotImplementedError("compress method must be implemented in subclass")
+
     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
-        """Cache compression hook called after the forward pass of an attention layer.
-        The hook is applied only during the pre-filling phase if there is some pruning ratio.
+        """
+        Default forward hook called after the forward pass of an attention layer.
+        The hook calls the compress method to compress the KV cache while ensuring:
+        - compression is only applied during the pre-filling phase
+        - KV cache quantization is handled correctly
 
         Parameters
         ----------
@@ -40,8 +84,38 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         Returns
         -------
         Modified output of the forward pass of the layer.
+
         """
-        raise NotImplementedError("forward_hook method must be implemented in the derived class")
+        # See e.g. LlamaDecoderLayer.forward for the output structure
+        if len(output) == 3:
+            _, attentions, cache = output
+        else:
+            attentions, cache = None, output[-1]
+
+        hidden_states = kwargs["hidden_states"]
+        q_len = hidden_states.shape[1]
+
+        # Don't compress during pre-filling
+        if cache.seen_tokens > q_len:
+            return output
+
+        if isinstance(cache, QuantizedCache):
+            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
+            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
+        else:
+            keys = cache.key_cache[module.layer_idx]
+            values = cache.value_cache[module.layer_idx]
+
+        keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)
+
+        if isinstance(cache, QuantizedCache):
+            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
+            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
+        else:
+            cache.key_cache[module.layer_idx] = keys
+            cache.value_cache[module.layer_idx] = values
+
+        return output
 
     @contextmanager
     def __call__(self, model: PreTrainedModel) -> Generator:
diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py
index 9250818..a068f3f 100644
--- a/kvpress/presses/composed_press.py
+++ b/kvpress/presses/composed_press.py
@@ -14,5 +14,5 @@ def forward_hook(self, module, input, kwargs, output):
         self.compression_ratio = 1
         for press in self.presses:
             output = press.forward_hook(module, input, kwargs, output)
-            self.compression_ratio *= press.compression_ratio
+            self.compression_ratio *= press.compression_ratio  # type: ignore
         return output
diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py
index a696cc7..4fa6d68 100644
--- a/kvpress/presses/scorer_press.py
+++ b/kvpress/presses/scorer_press.py
@@ -7,7 +7,6 @@
 
 import torch
 from torch import nn
-from transformers import QuantizedCache
 
 from kvpress.presses.base_press import BasePress
 
@@ -18,10 +17,11 @@ class ScorerPress(BasePress):
     """
     Default press method for using a score method.
-    The `forward_hook` method is called after the forward pass of an attention layer.
-    and updates the cache with the pruned KV pairs.
+    Any ScorerPress subclass must implement the `score` method that computes a tensor of scores for each key-value pair.
+    The KV pairs with the lowest scores will be pruned in the `compress` method.
+    The cache is uniformly pruned across all heads and layers using the compression_ratio parameter.
     """
-    
+
     compression_ratio: float = 0.0
 
     def __post_init__(self):
@@ -36,87 +36,36 @@ def score(
         attentions: torch.Tensor,
         kwargs,
     ) -> torch.Tensor:
-        """Compute the scores for each KV pair in the layer.
-
-        Parameters
-        ----------
-        module :
-            Transformer layer, see `hook` method for more details.
-        hidden_states :
-            Hidden states of the layer.
-        keys :
-            Keys of the cache. Note keys are after RoPE.
-        values :
-            Values of the cache.
-        attentions :
-            Attention weights of the layer.
-        kwargs :
-            Keyword arguments, as given to the forward pass of the layer.
-
-        Returns
-        -------
-        Scores for each KV pair in the layer, shape keys.shape[:-1].
         """
-        raise NotImplementedError
-
-    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        Compute a tensor of scores with shape (bsz, num_key_value_heads, q_len).
+        The KV pairs with the lowest scores will be pruned in the `compress` method.
         """
-        Default cache compression hook called after the forward pass of an attention layer.
-        The hook is applied only during the pre-filling phase if there is some pruning ratio.
-        This implementation allows to remove a constant number of KV pairs.
+        raise NotImplementedError
 
-        Parameters
-        ----------
-        module :
-            Transformer attention layer.
-        input :
-            Input to the hook. This is the input to the forward pass of the layer.
-        kwargs :
-            Keyword arguments, as given to the forward pass of the layer.
-        output :
-            Output of the hook. This is the original output of the forward pass of the layer.
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        Returns
-        -------
-        Modified output of the forward pass of the layer.
+        if self.compression_ratio == 0:
+            return keys, values
 
-        """
-        # See e.g. LlamaDecoderLayer.forward for the output structure
-        if len(output) == 3:
-            _, attentions, cache = output
-        else:
-            attentions, cache = None, output[-1]
+        # Compute scores
+        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
 
-        hidden_states = kwargs["hidden_states"]
+        # Get indices of KV pairs with the lowest scores
         q_len = hidden_states.shape[1]
-
-        # Don't compress if the compression ratio is 0 or this is not pre-filling
-        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
-            return output
-
-        if isinstance(cache, QuantizedCache):
-            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
-            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
-        else:
-            keys = cache.key_cache[module.layer_idx]
-            values = cache.value_cache[module.layer_idx]
-
-        with torch.no_grad():
-            scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
-
-        # Prune KV pairs with the lowest scores
         n_kept = int(q_len * (1 - self.compression_ratio))
         indices = scores.topk(n_kept, dim=-1).indices
         indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
 
-        # Update cache
+        # Prune keys and values
         keys = keys.gather(2, indices).contiguous()
         values = values.gather(2, indices).contiguous()
-        if isinstance(cache, QuantizedCache):
-            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
-            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
-        else:
-            cache.key_cache[module.layer_idx] = keys
-            cache.value_cache[module.layer_idx] = values
 
-        return output
+        return keys, values
diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py
index 62186bf..e8be0dc 100644
--- a/kvpress/presses/think_press.py
+++ b/kvpress/presses/think_press.py
@@ -53,32 +53,26 @@ def compute_window_queries(self, module, hidden_states):
 
         return query_states
 
-    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-        If other similar presses are requested, we might create a generic forward_hook for dimension pruning
+        If other similar presses are requested, we might create a generic compress method for dimension pruning
         to avoid code duplication.
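+
+        A sketch of the channel scoring implemented below (pseudo-code for the lines that follow):
+
+            queries_norm[d] = mean of Q[..., d] ** 2 over the last window_size queries
+            keys_norm[d]    = mean of K[..., d] ** 2 over the sequence
+            score[d]        = queries_norm[d] * keys_norm[d]
+
+        The key channels with the lowest scores are zeroed out.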
""" - # Don't compress if the compression ratio is 0 or this is not pre-filling - cache = output[-1] - hidden_states = kwargs["hidden_states"] - q_len = hidden_states.shape[1] - assert q_len > self.window_size, "Query length should be greater than the window size" + if self.key_channel_compression_ratio == 0: + return keys, values - if (self.key_channel_compression_ratio == 0) or (cache.seen_tokens > q_len): - return output - - # Get keys - if isinstance(cache, QuantizedCache): - keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) - else: - keys = cache.key_cache[module.layer_idx] + # Compute scores per dimension bsz, num_key_value_heads, q_len, head_dim = keys.shape - - # ThinK specific code queries = self.compute_window_queries(module, kwargs["hidden_states"]) - - # Compute scores per dimension queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2) keys_norm = torch.pow(keys, 2).mean(dim=2) @@ -90,13 +84,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1) keys = keys.scatter_(-1, indices, 0) - # Update cache - if isinstance(cache, QuantizedCache): - cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) - else: - cache.key_cache[module.layer_idx] = keys - - return output + return keys, values @property def compression_ratio(self): diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 9d2b0fe..47bba3c 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -10,10 +10,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from dataclasses import dataclass\n", "\n", @@ -25,10 +25,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Load pipeline\n", "\n", @@ -39,10 +39,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Load data\n", "\n", @@ -62,10 +62,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "A press registers a forward hook to each attention layer during the pre-filling phase:\n", - "1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method\n", - "2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter" + "A press registers a forward hook to each attention layer during the pre-filling phase. Immediately after the forward pass, the hook is called, and it compresses the KV cache." 
    ]
   },
   {
@@ -127,7 +124,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "\n",
     "The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair.\n",
     "\n",
     "The arguments of the `score` method are obtained from the forward hook:\n",
@@ -140,10 +136,10 @@
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
     "class MyKnormPress(ScorerPress):\n",
     "    def score(\n",
@@ -181,47 +177,42 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### 2.2 Updating the `forward_hook` method "
+    "### 2.2 Updating the `compress` method "
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The `forward_hook` method defined in the `BasePress` class roughly works as follows:\n",
-    "\n",
-    "1. Get the scores\n",
-    "2. Update the key-value pairs based on the scores and the `compression_ratio`\n",
+    "The `compress` method defined in the `BasePress` class contains the core logic of the compression and returns the compressed keys and values. For instance, in the `ScorerPress`, the `compress` method calls the `score` method (which is specific to `ScorerPress`) and prunes the key-value pairs based on the scores.\n",
     "\n",
-    "While we generally do not recommend to modify this method, the following example will show how it works. We will re-implement the `StreamingLLMPress` without using the `compression_ratio` parameter. In `StreamingLLM`, only the first `n_first` and last `n_last` key-value pairs are kept."
+    "The following example will show how it works. We will re-implement the `StreamingLLMPress` in a more compact way."
    ]
   },
   {
-   "metadata": {},
    "cell_type": "code",
-   "outputs": [],
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": [
     "@dataclass\n",
     "class MyStreamingLLMPress(BasePress):\n",
     "    n_first: int = 1\n",
     "    n_last: int = 8\n",
     "\n",
-    "    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):\n",
-    "\n",
-    "        # Get the cache (transformers.cache_utils.DynamicCache object)\n",
-    "        cache = output[-1]\n",
-    "        i = module.layer_idx\n",
-    "        keys, values = cache.key_cache[i], cache.value_cache[i]\n",
+    "    def compress(\n",
+    "        self,\n",
+    "        module: nn.Module,\n",
+    "        hidden_states: torch.Tensor,\n",
+    "        keys: torch.Tensor,\n",
+    "        values: torch.Tensor,\n",
+    "        attentions: torch.Tensor,\n",
+    "        kwargs: dict,\n",
+    "    ) -> tuple[torch.Tensor, torch.Tensor]:\n",
     "\n",
-    "        # Update the cache to only keep the first and last tokens\n",
     "        mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)\n",
     "        mask[self.n_first : -self.n_last] = False\n",
-    "        cache.key_cache[i] = keys[:, :, mask, :]\n",
-    "        cache.value_cache[i] = values[:, :, mask, :]\n",
-    "\n",
-    "        # Return the updated output (output[-1] has been modified in-place)\n",
-    "        return output\n",
+    "        return keys[:, :, mask, :], values[:, :, mask, :]\n",
     "\n",
     "\n",
     "for n_last in [2, 4, 8]:\n",
     "    press = MyStreamingLLMPress(n_last=n_last)\n",
     "    print(f\"\\nn_last: {n_last}\")\n",
     "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the `compress` method is itself used in the `forward_hook` method, which ensures quantization is handled properly and that the compression is only performed during prefilling. While we don't recommend changing the `forward_hook` method directly, you can still modify it if you need to!"
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -242,7 +240,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py). We recommend not to update the `forward_hook` or `__call__` method unless necessary." + "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n", + "- register it in the `__init__.py` file of repository\n", + "- add a test [test_presses.py](tests/presses/test_presses.py)\n", + "- update the README" ] } ], @@ -262,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.3" } }, "nbformat": 4, From 15d075d143e718ed63d1592df7d9c4f6e2d7d9b0 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 14:52:02 +0000 Subject: [PATCH 4/7] Fix style --- kvpress/presses/scorer_press.py | 2 +- kvpress/presses/think_press.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 4fa6d68..ea97eab 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -21,7 +21,7 @@ class ScorerPress(BasePress): The KV pairs with the lowest scores will be pruned in the `compress` method. The cache is uniformly pruned across all heads and layers using the compression_ratio parameter. """ - + compression_ratio: float = 0.0 def __post_init__(self): diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index e8be0dc..2f882e5 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -6,7 +6,6 @@ import torch from torch import nn -from transformers.cache_utils import QuantizedCache from transformers.models.llama.modeling_llama import rotate_half from kvpress.presses.base_press import BasePress From 47b61af3607d743747259a916930bdfaf3dc0a39 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 14:55:42 +0000 Subject: [PATCH 5/7] Test compression ratio --- tests/presses/test_presses.py | 2 ++ ...mpression_wrapper.py => test_per_layer_compression_press.py} | 0 2 files changed, 2 insertions(+) rename tests/{test_per_layer_compression_wrapper.py => test_per_layer_compression_press.py} (100%) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index b35a955..07f4008 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -42,6 +42,8 @@ def test_presses_run(unit_test_model): # noqa: F811 with press(unit_test_model): input_ids = unit_test_model.dummy_inputs["input_ids"] unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values + # Check that the press has a compression_ratio attribute + assert hasattr(press, "compression_ratio") def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811 diff --git a/tests/test_per_layer_compression_wrapper.py b/tests/test_per_layer_compression_press.py similarity index 100% rename from tests/test_per_layer_compression_wrapper.py rename to tests/test_per_layer_compression_press.py From f0333136aec2dc714f9287c31b27613883a574ed Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 16:13:06 +0000 Subject: [PATCH 6/7] Adress PR feedback --- kvpress/presses/base_press.py | 2 +- kvpress/presses/composed_press.py | 3 +++ 2 files 
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 16fc160..129d176 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -95,7 +95,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         hidden_states = kwargs["hidden_states"]
         q_len = hidden_states.shape[1]
 
-        # Don't compress during pre-filling
+        # Don't compress after pre-filling
         if cache.seen_tokens > q_len:
             return output
 
diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py
index a068f3f..2872464 100644
--- a/kvpress/presses/composed_press.py
+++ b/kvpress/presses/composed_press.py
@@ -10,6 +10,9 @@ class ComposedPress(BasePress):
 
     presses: list[BasePress]
 
+    def __post_init__(self):
+        self.compression_ratio = 1.0
+
     def forward_hook(self, module, input, kwargs, output):
         self.compression_ratio = 1
         for press in self.presses:

From f5c98c9247d9c07395be61e0eab6e29fe8bb9091 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Tue, 10 Dec 2024 16:19:21 +0000
Subject: [PATCH 7/7] Update init

---
 kvpress/presses/composed_press.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py
index 2872464..3ebc5c1 100644
--- a/kvpress/presses/composed_press.py
+++ b/kvpress/presses/composed_press.py
@@ -11,10 +11,10 @@ class ComposedPress(BasePress):
     presses: list[BasePress]
 
     def __post_init__(self):
-        self.compression_ratio = 1.0
+        self.compression_ratio = None
 
     def forward_hook(self, module, input, kwargs, output):
-        self.compression_ratio = 1
+        self.compression_ratio = 1.0
         for press in self.presses:
             output = press.forward_hook(module, input, kwargs, output)
             self.compression_ratio *= press.compression_ratio  # type: ignore
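Taken together, the API introduced by this series can be exercised end-to-end roughly as follows (a sketch assembled from the README and test snippets above; the checkpoint and the exact ratios are illustrative):

```python
# Illustrative sketch: composing presses with the API from this patch series
from transformers import pipeline
from kvpress import ComposedPress, KnormPress, ThinKPress

press = ComposedPress(
    presses=[
        KnormPress(compression_ratio=0.5),              # prunes KV pairs along the sequence axis
        ThinKPress(key_channel_compression_ratio=0.5),  # zeroes out the least useful key channels
    ]
)

pipe = pipeline("kv-press-text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct")
context = "..."   # long context to compress
question = "..."  # question answered against the compressed cache
answer = pipe(context, question=question, press=press)["answer"]

# ComposedPress chains the forward hooks and multiplies the ratios during pre-filling:
# 1.0 * 0.5 (KnormPress) * 0.25 (ThinKPress reports key_channel_compression_ratio / 2) = 0.125
print(press.compression_ratio)
```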