From ea82692a5265ffb5ea66b3dc3d09b5e30a1a1096 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 9 Dec 2024 12:41:10 +0000 Subject: [PATCH 01/11] add simlayerkvpress --- README.md | 3 +- kvpress/__init__.py | 4 +- kvpress/presses/simlayerkv_press.py | 79 +++++++++++++++++++++++++++++ kvpress/presses/snapkv_press.py | 19 ++++--- kvpress/presses/tova_press.py | 6 +-- tests/presses/test_presses.py | 14 ++++- 6 files changed, 110 insertions(+), 15 deletions(-) create mode 100644 kvpress/presses/simlayerkv_press.py diff --git a/README.md b/README.md index 4ebf50b..634353a 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,8 @@ All current presses are training free. We provide the following presses associat - `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) - `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469)) - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb)) -- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453)) +- `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453)) +- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio. - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above. diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 5f7fc8c..81d959c 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -9,10 +9,11 @@ from kvpress.presses.knorm_press import KnormPress from kvpress.presses.observed_attention_press import ObservedAttentionPress from kvpress.presses.random_press import RandomPress +from kvpress.presses.simlayerkv_press import SimLayerKVPress from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress -from kvpress.presses.tova_press import TOVAPress from kvpress.presses.think_press import ThinKPress +from kvpress.presses.tova_press import TOVAPress __all__ = [ "BasePress", @@ -20,6 +21,7 @@ "KnormPress", "ObservedAttentionPress", "RandomPress", + "SimLayerKVPress", "SnapKVPress", "StreamingLLMPress", "ThinKPress", diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py new file mode 100644 index 0000000..da4826b --- /dev/null +++ b/kvpress/presses/simlayerkv_press.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass + +import torch +from torch import nn + +from kvpress.presses.base_press import BasePress +from kvpress.presses.snapkv_press import SnapKVPress + + +@dataclass +class SimLayerKVPress(BasePress): + """ + SimLayerKV (https://arxiv.org/abs/2410.13846) uses a layer-wise approach to compression: + - layers identified as lazy use the Streaming LLM approach (only initial and recent KV pairs are kept) + - other layers use the full KV cache + + To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens + over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy. + + Official implementation: https://github.com/sail-sg/SimLayerKV. We use n_last=1 to match SKLV-decode + """ + + lazy_threshold: float + n_last: int = 1 + n_recent: int = 1024 + n_initial: int = 4 + + def is_lazy( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + ) -> bool: + """ + Compute the average attention weights of the last tokens over the initial and recent tokens. + A slight difference with the original implementation is that we + """ + + attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last) + attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size + score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum() + return score.item() > self.lazy_threshold + + @property + def compression_ratio(self): + if hasattr(self, "compression_ratios"): + return sum(self.compression_ratios) / len(self.compression_ratios) + else: + raise ValueError("Forward pass must be run to compute the compression ratio") + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + + cache = output[-1] + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Don't compress if this is not pre-filling or if there are not enough tokens + if (cache.seen_tokens > q_len) or (cache.seen_tokens < self.n_initial + self.n_recent): + return output + + # Re-initialize the compression_ratios list + if module.layer_idx == 0: + self.compression_ratios = [] + assert hidden_states.shape[1] > self.n_last, "Query length should be greater than the window size" + + keys, values = cache.key_cache[module.layer_idx], cache.value_cache[module.layer_idx] + if self.is_lazy(module, hidden_states, keys): + # If layer is lazy, only keep the initial and recent KV pairs + cache.key_cache[module.layer_idx] = torch.cat( + [keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2 + ) + cache.value_cache[module.layer_idx] = torch.cat( + [values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2 + ) + self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len) + else: + self.compression_ratios.append(0) + + return output diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 05753a1..89760ea 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -25,7 +25,10 @@ class SnapKVPress(BasePress): window_size: int = 64 kernel_size: int = 5 - def compute_window_attention(self, module, hidden_states, keys): + @staticmethod + def compute_window_attention( + module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int + ) -> torch.Tensor: """ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ @@ -34,17 +37,17 @@ def compute_window_attention(self, module, hidden_states, keys): # Get last window_size queries if hasattr(module, "q_proj"): - query_states = module.q_proj(hidden_states[:, -self.window_size :]) + query_states = module.q_proj(hidden_states[:, -window_size:]) elif hasattr(module, "qkv_proj"): - qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) + qkv = module.qkv_proj(hidden_states[:, -window_size:]) query_states = qkv[..., : module.num_heads * module.head_dim] else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2) # Apply RoPE - position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) + position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device) cos, sin = module.rotary_emb(query_states, position_ids) query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) @@ -52,10 +55,10 @@ def compute_window_attention(self, module, hidden_states, keys): key_states = repeat_kv(keys, module.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) attention_mask = torch.ones_like(attn_weights) * float("-inf") - attention_mask = torch.triu(attention_mask, diagonal=q_len - self.window_size + 1) + attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1) attn_weights += attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = attn_weights[..., : -self.window_size] + attn_weights = attn_weights[..., :-window_size] return attn_weights @@ -76,7 +79,7 @@ def score( if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] else: - attn_weights = self.compute_window_attention(module, hidden_states, keys) + attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size) scores = attn_weights.mean(dim=-2) scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) diff --git a/kvpress/presses/tova_press.py b/kvpress/presses/tova_press.py index 0addcd9..14d8aa1 100644 --- a/kvpress/presses/tova_press.py +++ b/kvpress/presses/tova_press.py @@ -7,11 +7,12 @@ from torch import nn import torch.nn.functional as F +from kvpress.presses.base_press import BasePress from kvpress.presses.snapkv_press import SnapKVPress @dataclass -class TOVAPress(SnapKVPress): +class TOVAPress(BasePress): """ TOVA (https://arxiv.org/abs/2401.06104) use the attention of the last token averaged across heads to estimate the importance of the previous KV pairs. This press was reviewed by Michael Hassid, @@ -21,7 +22,6 @@ class TOVAPress(SnapKVPress): """ compression_ratio: float = 0.0 - window_size: int = 1 # re-use the attention weight computation from SnapKVPress for last token def score( self, @@ -36,7 +36,7 @@ def score( if attentions is not None: attn_weights = attentions[..., -1:, :-1] else: - attn_weights = self.compute_window_attention(module, hidden_states, keys) + attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, 1) # Average across heads and repeat num_key_value_head times scores = attn_weights.mean(1) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 4c2d361..a835e59 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -13,6 +13,7 @@ ObservedAttentionPress, RandomPress, SnapKVPress, + SimLayerKVPress, StreamingLLMPress, TOVAPress, ThinKPress, @@ -29,9 +30,18 @@ def test_think_inner_press(unit_test_model): # noqa: F811 def test_presses_run(unit_test_model): # noqa: F811 - for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]: + for cls in [ + KnormPress, + ExpectedAttentionPress, + RandomPress, + SimLayerKVPress, + StreamingLLMPress, + SnapKVPress, + TOVAPress, + ThinKPress, + ]: for compression_ratio in [0.2, 0.4, 0.6, 0.8]: - press = cls(compression_ratio=compression_ratio) + press = cls(compression_ratio) if cls in [SnapKVPress, ThinKPress]: press.window_size = 2 with press(unit_test_model): From 190a1dbbf680ef1902a5ce99c813751446210d75 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 9 Dec 2024 15:58:32 +0000 Subject: [PATCH 02/11] add quantization support --- kvpress/presses/simlayerkv_press.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index da4826b..968d932 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -2,6 +2,7 @@ import torch from torch import nn +from transformers import QuantizedCache from kvpress.presses.base_press import BasePress from kvpress.presses.snapkv_press import SnapKVPress @@ -63,17 +64,26 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic self.compression_ratios = [] assert hidden_states.shape[1] > self.n_last, "Query length should be greater than the window size" - keys, values = cache.key_cache[module.layer_idx], cache.value_cache[module.layer_idx] + if isinstance(cache, QuantizedCache): + keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) + values = cache._dequantize(cache._quantized_value_cache[module.layer_idx]) + else: + keys = cache.key_cache[module.layer_idx] + values = cache.value_cache[module.layer_idx] + if self.is_lazy(module, hidden_states, keys): # If layer is lazy, only keep the initial and recent KV pairs - cache.key_cache[module.layer_idx] = torch.cat( - [keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2 - ) - cache.value_cache[module.layer_idx] = torch.cat( - [values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2 - ) + keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) + values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len) else: self.compression_ratios.append(0) + if isinstance(cache, QuantizedCache): + cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) + cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value) + else: + cache.key_cache[module.layer_idx] = keys + cache.value_cache[module.layer_idx] = values + return output From 445b62b3fd4bb7ceb8252f394e2811c1c562f5a2 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 08:52:55 +0000 Subject: [PATCH 03/11] Adress PR feedback --- kvpress/presses/simlayerkv_press.py | 13 ++++++++++--- tests/presses/test_presses.py | 8 ++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 968d932..55e28a6 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -26,6 +26,9 @@ class SimLayerKVPress(BasePress): n_recent: int = 1024 n_initial: int = 4 + def __post_init__(self): + self.compression_ratios = None + def is_lazy( self, module: nn.Module, @@ -33,8 +36,8 @@ def is_lazy( keys: torch.Tensor, ) -> bool: """ - Compute the average attention weights of the last tokens over the initial and recent tokens. - A slight difference with the original implementation is that we + Compute the attention weights of the last tokens over the initial and recent tokens. + The layer is considered lazy if the sum of these attention weights is above the lazy_threshold. """ attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last) @@ -44,11 +47,15 @@ def is_lazy( @property def compression_ratio(self): - if hasattr(self, "compression_ratios"): + if self.compression_ratios is not None: return sum(self.compression_ratios) / len(self.compression_ratios) else: raise ValueError("Forward pass must be run to compute the compression ratio") + @compression_ratio.setter + def compression_ratio(self, value): + raise AttributeError("Cannot set the compression ratio") + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): cache = output[-1] diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index a835e59..568482d 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -40,8 +40,12 @@ def test_presses_run(unit_test_model): # noqa: F811 TOVAPress, ThinKPress, ]: - for compression_ratio in [0.2, 0.4, 0.6, 0.8]: - press = cls(compression_ratio) + for value in [0.2, 0.4, 0.6, 0.8]: + if cls == SimLayerKVPress: + press = cls(lazy_threshold=value) + else: + press = cls(compression_ratio=value) + if cls in [SnapKVPress, ThinKPress]: press.window_size = 2 with press(unit_test_model): From 5e5a6bcfc91e0e194e371994dc0e063c9398207d Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 16:57:35 +0000 Subject: [PATCH 04/11] Add lazy_threshold recommendations in docstring --- kvpress/presses/simlayerkv_press.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 76476ae..2523e30 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -2,7 +2,6 @@ import torch from torch import nn -from transformers import QuantizedCache from kvpress.presses.base_press import BasePress from kvpress.presses.snapkv_press import SnapKVPress @@ -17,12 +16,17 @@ class SimLayerKVPress(BasePress): To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy. - - Official implementation: https://github.com/sail-sg/SimLayerKV. We use n_last=1 to match SKLV-decode + + Recommended values for lazy_threshold from the official repository: + - llama3: 0.9 + - llama2: 0.65 + - mistral: 0.8 + - qwen: 0.85 + (Source: https://github.com/sail-sg/SimLayerKV/blob/main/LongBench/pred.py#L167) """ lazy_threshold: float - n_last: int = 1 + n_last: int = 1 # n_last=1 to match SKLV-decode n_recent: int = 1024 n_initial: int = 4 From ea5708510b1d0a99c9cb0f7c48f284823a3bd8e8 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 10 Dec 2024 17:01:18 +0000 Subject: [PATCH 05/11] Fix style --- kvpress/presses/simlayerkv_press.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 2523e30..2ccff43 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -16,7 +16,7 @@ class SimLayerKVPress(BasePress): To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy. - + Recommended values for lazy_threshold from the official repository: - llama3: 0.9 - llama2: 0.65 From 73cde6bcc944866ecd20a5f847216bf4ef0c678a Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 13:34:11 +0000 Subject: [PATCH 06/11] Adress PR feedback --- kvpress/presses/simlayerkv_press.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 2ccff43..4f0984d 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass import torch @@ -6,6 +7,8 @@ from kvpress.presses.base_press import BasePress from kvpress.presses.snapkv_press import SnapKVPress +logger = logging.getLogger(__name__) + @dataclass class SimLayerKVPress(BasePress): @@ -22,10 +25,11 @@ class SimLayerKVPress(BasePress): - llama2: 0.65 - mistral: 0.8 - qwen: 0.85 + By default, lazy_threshold is set to 1.0 (no compression) (Source: https://github.com/sail-sg/SimLayerKV/blob/main/LongBench/pred.py#L167) """ - lazy_threshold: float + lazy_threshold: float = 1.0 # no compression n_last: int = 1 # n_last=1 to match SKLV-decode n_recent: int = 1024 n_initial: int = 4 @@ -72,7 +76,11 @@ def compress( # Don't compress if the query length is less than the initial and recent tokens q_len = hidden_states.shape[1] - if q_len < self.n_initial + self.n_recent: + if (q_len < self.n_initial + self.n_recent) or (self.lazy_threshold == 1): + if self.lazy_threshold != 1: + logger.warning( + f"Query length {q_len} is less than the initial and recent tokens {self.n_initial} + {self.n_recent}. No compression will be applied." # noqa + ) self.compression_ratios = [0.0] return keys, values From befdbbf4a4d714fd31c86050fa025a1bc6723436 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 13:52:44 +0000 Subject: [PATCH 07/11] Update compression_ratio management --- kvpress/presses/simlayerkv_press.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 4f0984d..ab12277 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -29,13 +29,17 @@ class SimLayerKVPress(BasePress): (Source: https://github.com/sail-sg/SimLayerKV/blob/main/LongBench/pred.py#L167) """ - lazy_threshold: float = 1.0 # no compression + lazy_threshold: float = 1.0 n_last: int = 1 # n_last=1 to match SKLV-decode n_recent: int = 1024 n_initial: int = 4 def __post_init__(self): - self.compression_ratios = None + assert 0.0 <= self.lazy_threshold <= 1.0, "lazy_threshold should be in [0, 1]" + if self.lazy_threshold == 1.0: + self.compression_ratios = [0.0] + else: + self.compression_ratios = [] def is_lazy( self, @@ -55,7 +59,7 @@ def is_lazy( @property def compression_ratio(self): - if self.compression_ratios is not None: + if len(self.compression_ratios) > 0: return sum(self.compression_ratios) / len(self.compression_ratios) else: raise ValueError("Forward pass must be run to compute the compression ratio") @@ -74,20 +78,15 @@ def compress( kwargs: dict, ) -> tuple[torch.Tensor, torch.Tensor]: - # Don't compress if the query length is less than the initial and recent tokens - q_len = hidden_states.shape[1] - if (q_len < self.n_initial + self.n_recent) or (self.lazy_threshold == 1): - if self.lazy_threshold != 1: - logger.warning( - f"Query length {q_len} is less than the initial and recent tokens {self.n_initial} + {self.n_recent}. No compression will be applied." # noqa - ) - self.compression_ratios = [0.0] + if self.lazy_threshold == 1.0: return keys, values - # If first layer, initialize compression_ratios + # Sanity check and compression_ratios initialization + q_len = hidden_states.shape[1] if module.layer_idx == 0: self.compression_ratios = [] - assert hidden_states.shape[1] > self.n_last, "Query length should be greater than the window size" + min_length = self.n_initial + self.n_recent + self.n_last + assert q_len >= min_length, f"Query length should be greater than the window size ({min_length})" if self.is_lazy(module, hidden_states, keys): # If layer is lazy, only keep the initial and recent KV pairs From 7776670addaf4966482bd1ffcd48eeabac05a533 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 13:55:44 +0000 Subject: [PATCH 08/11] Update cr --- kvpress/presses/simlayerkv_press.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index ab12277..33775b8 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -36,10 +36,7 @@ class SimLayerKVPress(BasePress): def __post_init__(self): assert 0.0 <= self.lazy_threshold <= 1.0, "lazy_threshold should be in [0, 1]" - if self.lazy_threshold == 1.0: - self.compression_ratios = [0.0] - else: - self.compression_ratios = [] + self.compression_ratios = [] def is_lazy( self, @@ -78,9 +75,6 @@ def compress( kwargs: dict, ) -> tuple[torch.Tensor, torch.Tensor]: - if self.lazy_threshold == 1.0: - return keys, values - # Sanity check and compression_ratios initialization q_len = hidden_states.shape[1] if module.layer_idx == 0: @@ -88,12 +82,16 @@ def compress( min_length = self.n_initial + self.n_recent + self.n_last assert q_len >= min_length, f"Query length should be greater than the window size ({min_length})" + if self.lazy_threshold == 1.0: + self.compression_ratios.append(0.0) + return keys, values + if self.is_lazy(module, hidden_states, keys): # If layer is lazy, only keep the initial and recent KV pairs keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) self.compression_ratios.append((q_len - self.n_initial - self.n_recent + 1) / q_len) else: - self.compression_ratios.append(0) + self.compression_ratios.append(0.0) return keys, values From e5d69e4075d3ee763a11576054fa0780bd3d1a07 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 13:55:54 +0000 Subject: [PATCH 09/11] Update cr --- kvpress/presses/simlayerkv_press.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 33775b8..70277e7 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -82,6 +82,7 @@ def compress( min_length = self.n_initial + self.n_recent + self.n_last assert q_len >= min_length, f"Query length should be greater than the window size ({min_length})" + # Do not compress if lazy_threshold is 1.0 if self.lazy_threshold == 1.0: self.compression_ratios.append(0.0) return keys, values From d85727005dd60b50b6f49a2c44a971b8c45072cb Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 13:56:45 +0000 Subject: [PATCH 10/11] Fix tests --- tests/presses/test_presses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 9b461b6..7950617 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -48,7 +48,7 @@ def test_presses_run(unit_test_model): # noqa: F811 if cls == ThinKPress: press = cls(key_channel_compression_ratio=value, window_size=2) elif cls == SimLayerKVPress: - press = cls(lazy_threshold=value) + press = cls(lazy_threshold=value, n_recent=2) else: press = cls(compression_ratio=value) if cls == SnapKVPress: From fbf225e736e1b1118763e1dfd8009b87deeda7e8 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 11 Dec 2024 14:04:05 +0000 Subject: [PATCH 11/11] Remove assert --- kvpress/presses/simlayerkv_press.py | 16 ++++++++++------ tests/presses/test_presses.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 70277e7..8693015 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -75,18 +75,22 @@ def compress( kwargs: dict, ) -> tuple[torch.Tensor, torch.Tensor]: - # Sanity check and compression_ratios initialization - q_len = hidden_states.shape[1] + # Initialize the compression ratios if module.layer_idx == 0: self.compression_ratios = [] - min_length = self.n_initial + self.n_recent + self.n_last - assert q_len >= min_length, f"Query length should be greater than the window size ({min_length})" - # Do not compress if lazy_threshold is 1.0 - if self.lazy_threshold == 1.0: + # Check if compression is needed + q_len = hidden_states.shape[1] + min_length = self.n_initial + self.n_recent + self.n_last + + if q_len <= min_length: + logger.warning(f"Sequence length is shorter than {min_length}: no compression applied") + + if (self.lazy_threshold == 1.0) or (q_len <= min_length): self.compression_ratios.append(0.0) return keys, values + # Compression if self.is_lazy(module, hidden_states, keys): # If layer is lazy, only keep the initial and recent KV pairs keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 7950617..86dd35a 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -48,7 +48,7 @@ def test_presses_run(unit_test_model): # noqa: F811 if cls == ThinKPress: press = cls(key_channel_compression_ratio=value, window_size=2) elif cls == SimLayerKVPress: - press = cls(lazy_threshold=value, n_recent=2) + press = cls(lazy_threshold=value, n_initial=1, n_recent=1, n_last=1) else: press = cls(compression_ratio=value) if cls == SnapKVPress: