Refactor press

maxjeblick authored Dec 10, 2024
1 parent ac2445e commit 9ecd556
Showing 22 changed files with 286 additions and 2,186 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -168,7 +168,7 @@ However, the `generate` method does not allow to exclude the question from the c
 ### How to create a new press ?
 </summary>

-All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `BasePress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.
+All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.

 Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of repository and to add it in [test_presses.py](tests/presses/test_presses.py).

@@ -181,9 +181,9 @@ Before opening a pull request with a new press, make sure to register it in the

 We provide an experimental feature, which only works with flash attention:
 ```python
-from kvpress import apply_per_layer_compression
+from kvpress import PerLayerCompressionPress
 # compression_ratios should have the same length as the number of layers
-press = apply_per_layer_compression(press, compression_ratios=[...])
+press = PerLayerCompressionPress(press, compression_ratios=[...])
 ```

 Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
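As a companion to the updated instructions above, here is a minimal sketch of what a new press might look like after this refactor. The `score` signature is copied from the code removed from `base_press.py` further down; the class name is illustrative, and the `ScorerPress` internals are assumed to mirror the old `BasePress` scoring flow:

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class RecencyPress(ScorerPress):
    """Illustrative press: score KV pairs by position so that recent tokens are kept."""

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Scores must have shape keys.shape[:-1]; the highest-scoring pairs are kept.
        positions = torch.arange(keys.shape[-2], device=keys.device, dtype=keys.dtype)
        return positions.expand(keys.shape[:-1])
```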
9 changes: 6 additions & 3 deletions kvpress/__init__.py
@@ -2,20 +2,21 @@
 # SPDX-License-Identifier: Apache-2.0


-from kvpress.per_layer_compression_wrapper import apply_per_layer_compression
 from kvpress.pipeline import KVPressTextGenerationPipeline
 from kvpress.presses.base_press import BasePress
 from kvpress.presses.expected_attention_press import ExpectedAttentionPress
 from kvpress.presses.knorm_press import KnormPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
 from kvpress.presses.random_press import RandomPress
+from kvpress.presses.scorer_press import ScorerPress
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
-from kvpress.presses.tova_press import TOVAPress
 from kvpress.presses.think_press import ThinKPress

 __all__ = [
     "BasePress",
+    "ScorerPress",
     "ExpectedAttentionPress",
     "KnormPress",
     "ObservedAttentionPress",
@@ -25,5 +26,7 @@
     "ThinKPress",
     "TOVAPress",
     "KVPressTextGenerationPipeline",
-    "apply_per_layer_compression",
+    "PerLayerCompressionPress",
 ]
+
+from kvpress.presses.tova_press import TOVAPress
48 changes: 0 additions & 48 deletions kvpress/per_layer_compression_wrapper.py

This file was deleted.

2 changes: 1 addition & 1 deletion kvpress/pipeline.py
@@ -7,7 +7,7 @@
 from typing import Optional

 import torch
-from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
+from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
 from transformers.pipelines import PIPELINE_REGISTRY
 from transformers.pipelines.base import GenericTensor
103 changes: 10 additions & 93 deletions kvpress/presses/base_press.py
@@ -4,71 +4,27 @@

 import logging
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Generator

 import torch
 from torch import nn
-from transformers import (
-    LlamaForCausalLM,
-    MistralForCausalLM,
-    Phi3ForCausalLM,
-    PreTrainedModel,
-    Qwen2ForCausalLM,
-    QuantizedCache,
-)
+from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM

 logger = logging.getLogger(__name__)


 @dataclass
 class BasePress:
-    """Base class for pruning methods.
-    Each pruning method should implement a `score` method that computes the scores for each KV pair in a layer.
-    This score is used to prune the KV pairs with the lowest scores in the `hook` method
-    The `hook` method is called after the forward pass of a layer and updates the cache with the pruned KV pairs.
-    The press can be applied to a model by calling it with the model as an argument.
-    """
-
-    def __init__(self, compression_ratio: float = 0.0):
-        self.compression_ratio = compression_ratio
-        assert 0 <= compression_ratio < 1, "Compression ratio must be between 0 and 1"
-
-    def score(
-        self,
-        module: nn.Module,
-        hidden_states: torch.Tensor,
-        keys: torch.Tensor,
-        values: torch.Tensor,
-        attentions: torch.Tensor,
-        kwargs,
-    ) -> torch.Tensor:
-        """Compute the scores for each KV pair in the layer.
-        Parameters
-        ----------
-        module :
-            Transformer layer, see `hook` method for more details.
-        hidden_states :
-            Hidden states of the layer.
-        keys :
-            Keys of the cache. Note keys are after RoPE.
-        values :
-            Values of the cache.
-        attentions :
-            Attention weights of the layer.
-        kwargs :
-            Keyword arguments, as given to the forward pass of the layer.
-        Returns
-        -------
-        Scores for each KV pair in the layer, shape keys.shape[:-1].
-        """
-        raise NotImplementedError
+    """
+    Base class for all pruning methods.
+    The `forward_hook` method is called after the forward pass of an attention layer.
+    Any pruning/updating method should be implemented in the derived class.
+    """

     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
-        """Cache compression hook called after the forward pass of a decoder layer.
+        """Cache compression hook called after the forward pass of an attention layer.
         The hook is applied only during the pre-filling phase if there is some pruning ratio.
         The current implementation only allows to remove a constant number of KV pairs.
         Parameters
         ----------
@@ -84,47 +40,8 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
         Returns
         -------
         Modified output of the forward pass of the layer.
         """
-        # See e.g. LlamaDecoderLayer.forward for the output structure
-        if len(output) == 3:
-            _, attentions, cache = output
-        else:
-            attentions, cache = None, output[-1]
-
-        hidden_states = kwargs["hidden_states"]
-        q_len = hidden_states.shape[1]
-
-        # Don't compress if the compression ratio is 0 or this is not pre-filling
-        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
-            return output
-
-        if isinstance(cache, QuantizedCache):
-            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
-            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
-        else:
-            keys = cache.key_cache[module.layer_idx]
-            values = cache.value_cache[module.layer_idx]
-
-        with torch.no_grad():
-            scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
-
-        # Prune KV pairs with the lowest scores
-        n_kept = int(q_len * (1 - self.compression_ratio))
-        indices = scores.topk(n_kept, dim=-1).indices
-        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
-
-        # Update cache
-        keys = keys.gather(2, indices).contiguous()
-        values = values.gather(2, indices).contiguous()
-        if isinstance(cache, QuantizedCache):
-            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
-            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
-        else:
-            cache.key_cache[module.layer_idx] = keys
-            cache.value_cache[module.layer_idx] = values
-
-        return output
+        raise NotImplementedError("forward_hook method must be implemented in the derived class")

     @contextmanager
     def __call__(self, model: PreTrainedModel) -> Generator:
@@ -141,8 +58,8 @@ def __call__(self, model: PreTrainedModel) -> Generator:
         if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
             logger.warning(f"Model {type(model)} not tested")

+        hooks = []
         try:
-            hooks = []
             for layer in model.model.layers:
                 hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
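To make the refactor easier to follow: the pruning logic deleted above presumably moves into the new `ScorerPress`. A condensed sketch reconstructed from the removed code (the quantized-cache branch is omitted, and the actual `scorer_press.py` may differ):

```python
# Sketch of kvpress/presses/scorer_press.py, assembled from the BasePress code
# deleted in this commit. Assumption: the refactor moved the score -> topk ->
# gather flow here largely unchanged.
import torch

from kvpress.presses.base_press import BasePress


class ScorerPress(BasePress):
    def forward_hook(self, module, input, kwargs, output):
        attentions = output[1] if len(output) == 3 else None
        cache = output[-1]
        hidden_states = kwargs["hidden_states"]
        q_len = hidden_states.shape[1]

        # Only compress during pre-filling, and only if a compression ratio is set
        if self.compression_ratio == 0 or cache.seen_tokens > q_len:
            return output

        keys = cache.key_cache[module.layer_idx]
        values = cache.value_cache[module.layer_idx]

        with torch.no_grad():
            scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Keep the KV pairs with the highest scores
        n_kept = int(q_len * (1 - self.compression_ratio))
        indices = scores.topk(n_kept, dim=-1).indices
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        cache.key_cache[module.layer_idx] = keys.gather(2, indices).contiguous()
        cache.value_cache[module.layer_idx] = values.gather(2, indices).contiguous()
        return output
```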
6 changes: 3 additions & 3 deletions kvpress/presses/expected_attention_press.py
@@ -7,15 +7,15 @@
 from dataclasses import dataclass

 import torch
-import torch.nn.functional as F
 from torch import nn
+from torch.nn import functional as F
 from transformers.models.llama.modeling_llama import repeat_kv

-from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress


 @dataclass
-class ExpectedAttentionPress(BasePress):
+class ExpectedAttentionPress(ScorerPress):
     """
     Compute scores based on the expected attention on next positions. To do so
     1. Compute the mean and covariance matrix of the queries before RoPE.
7 changes: 5 additions & 2 deletions kvpress/presses/knorm_press.py
@@ -2,13 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0


+from dataclasses import dataclass
+
 import torch
 from torch import nn

-from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress


-class KnormPress(BasePress):
+@dataclass
+class KnormPress(ScorerPress):
     """Prune KV pairs with highest L2 norm of keys (https://arxiv.org/pdf/2406.11430)"""

     def score(
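The `score` body is collapsed in this view. Judging from the docstring, it plausibly reduces to the negative L2 norm of the keys; a sketch, not the verbatim file:

```python
import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


class KnormSketch(ScorerPress):  # illustrative stand-in for KnormPress
    def score(self, module: nn.Module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        # Keys with the largest L2 norm get the lowest scores, so they are pruned first.
        return -keys.norm(dim=-1)  # shape keys.shape[:-1], as the score contract requires
```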
24 changes: 14 additions & 10 deletions kvpress/presses/observed_attention_press.py
@@ -2,27 +2,36 @@
 # SPDX-License-Identifier: Apache-2.0


+import logging
 from dataclasses import dataclass

 import torch
 from torch import nn
-from transformers.utils import logging

-from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress

-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)


 @dataclass
-class ObservedAttentionPress(BasePress):
-    """The observed attention score is defined as the average attention weight over all prompt tokens
+class ObservedAttentionPress(ScorerPress):
+    """
+    The observed attention score is defined as the average attention weight over all prompt tokens
     Requires output_attentions=True and attn_implementation="eager" to have access to attentions
     This approach is related to H2O (https://arxiv.org/abs/2306.14048).
     """

     compression_ratio: float = 0.0
     output_attentions: bool = False

+    def __post_init__(self):
+        if not self.output_attentions:
+            logger.warning(
+                "Model will not return attentions in its output to save memory. Please use DefaultPruner if"
+                " attentions are needed in the output."
+            )
+        super().__post_init__()
+
     def score(
         self,
         module: nn.Module,
@@ -42,14 +51,9 @@ def score(

     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
         output = super().forward_hook(module, input, kwargs, output)
-
         # attentions are needed as input for the hook, but unless the user wants to return them in the output,
         # we can remove them to save memory
         if not self.output_attentions:
-            logger.warning_once(
-                "Model will not return attentions in its output to save memory. "
-                "Set output_attentions=True in ObservedAttentionPress to return attentions."
-            )
             output = list(output)
             output[-2] = None
             output = tuple(output)
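Since this press consumes attention weights, it only runs with eager attention. Usage presumably looks as follows (the task name is assumed to match the pipeline registered in `pipeline.py`; model name and compression ratio are illustrative):

```python
from transformers import pipeline

from kvpress import ObservedAttentionPress

# attn_implementation="eager" is required so that attention weights are materialized
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    attn_implementation="eager",
)

context = "..."   # long prompt whose KV cache gets compressed
question = "..."  # query excluded from compression
press = ObservedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```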
49 changes: 49 additions & 0 deletions kvpress/presses/per_layer_compression_press.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import inspect
+import logging
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from torch import nn
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PerLayerCompressionPress(BasePress):
+    press: ScorerPress
+    compression_ratios: List[float]
+
+    def __post_init__(self):
+        logger.warning(
+            "Per layer compression wrapper is an experimental feature and only works with flash attention. "
+            "Please make sure that the model uses flash attention."
+        )
+        assert (
+            "compression_ratio"
+            in inspect.signature(
+                self.press.__init__  # type:ignore[misc]
+            ).parameters
+        ), f"compression_ratio can't be set in the provided press: {self.press.__class__}"
+
+    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        original_compression_ratio = self.press.compression_ratio  # type:ignore[attr-defined]
+        self.press.compression_ratio = self.compression_ratios[module.layer_idx]  # type:ignore[attr-defined]
+        output = self.press.forward_hook(module, input, kwargs, output)
+        self.press.compression_ratio = original_compression_ratio  # type:ignore[attr-defined]
+        return output
+
+    @property
+    def compression_ratio(self):
+        return sum(self.compression_ratios) / len(self.compression_ratios)
+
+    @compression_ratio.setter
+    def compression_ratio(self, value):
+        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
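A usage sketch for the new class (the wrapped press and the per-layer ratios are illustrative; any `ScorerPress` exposing `compression_ratio` satisfies the assert above):

```python
from kvpress import KnormPress, PerLayerCompressionPress

# Illustrative: compress later layers more aggressively on a 32-layer model
base_press = KnormPress(compression_ratio=0.5)
press = PerLayerCompressionPress(press=base_press, compression_ratios=[i / 62 for i in range(32)])
# press.compression_ratio reports the mean ratio across layers (here 0.25)
```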