Skip to content

Commit

Permalink
Add strategy without exec
Browse files Browse the repository at this point in the history
  • Loading branch information
SimJeg committed Dec 18, 2024
1 parent d9e7a54 commit b6b16af
Show file tree
Hide file tree
Showing 9 changed files with 5,961 additions and 14 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ max-line-length = 120
per-file-ignores =
__init__.py:F401
evaluation/infinite_bench/create_huggingface_dataset.py:E501
exclude = kvpress/models/modeling_*
# E203, W503 - black-compatible config
extend-ignore = E203, W503
32 changes: 18 additions & 14 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import re
import inspect
import importlib

from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
Expand All @@ -22,16 +18,24 @@
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

# Hack to add query_states to the cache_kwargs of the attention classes for DynamicHeadCache
# NOTE(review): this rewrites each attention class's source with re.sub and re-executes it
# inside the transformers module namespace. exec of rewritten library source is fragile
# (breaks if the cache_kwargs literal changes upstream) and a security concern — the
# pre-generated modules under kvpress/models avoid the exec entirely.
for name in ["llama", "mistral", "phi3", "qwen2"]:
module = importlib.import_module(f"transformers.models.{name}.modeling_{name}")
attention_classes = getattr(module, f"{name.upper()}_ATTENTION_CLASSES")
for key, cls in attention_classes.items():
updated_source_code = re.sub(
r"cache_kwargs = {(.*?)\}", r'cache_kwargs = {\1, "query_states": query_states}', inspect.getsource(cls)
)
exec(updated_source_code, module.__dict__)  # security risk here
attention_classes[key] = module.__dict__[cls.__name__]

# Strategy 1: compact but use exec
# import re
# import inspect
# import importlib
# for name in ["llama", "mistral", "phi3", "qwen2"]:
# module = importlib.import_module(f"transformers.models.{name}.modeling_{name}")
# attention_classes = getattr(module, f"{name.upper()}_ATTENTION_CLASSES")
# for key, cls in attention_classes.items():
# updated_source_code = re.sub(
# r"cache_kwargs = {(.*?)\}", r'cache_kwargs = {\1, "query_states": query_states}', inspect.getsource(cls)
# )
# exec(updated_source_code, module.__dict__) # security risk here
# attention_classes[key] = module.__dict__[cls.__name__]

# Strategy 2: cleaner but less compact
# Swap the transformers attention classes for the pre-generated kvpress ones at import time
from kvpress.models.utils import update_attn_implementations
update_attn_implementations()


__all__ = [
Expand Down
5 changes: 5 additions & 0 deletions kvpress/models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## Registering new attention modules

The `modeling_{name}.py` files have been created by running the `utils.rewrite_modeling_scripts` function. This function simply adds the `query_states` argument to the `cache_kwargs` dictionary.

The `utils.update_attn_implementations` function can then be used to register the attention classes in the `{NAME}_ATTENTION_CLASSES` dictionary.
Empty file added kvpress/models/__init__.py
Empty file.
1,475 changes: 1,475 additions & 0 deletions kvpress/models/modeling_llama.py

Large diffs are not rendered by default.

1,395 changes: 1,395 additions & 0 deletions kvpress/models/modeling_mistral.py

Large diffs are not rendered by default.

1,523 changes: 1,523 additions & 0 deletions kvpress/models/modeling_phi3.py

Large diffs are not rendered by default.

1,497 changes: 1,497 additions & 0 deletions kvpress/models/modeling_qwen2.py

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions kvpress/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import re
import inspect
import importlib
import transformers
from pathlib import Path

MODEL_NAMES = ["llama", "mistral", "phi3", "qwen2"]


def rewrite_modeling_scripts():
    """
    Regenerate the local modeling_{name}.py files from the installed transformers
    sources, injecting the query_states argument into every cache_kwargs dict.

    For each model family in MODEL_NAMES, the transformers modeling module source
    is copied, rewritten, and saved next to this file as modeling_{name}.py.
    """

    # Loop-invariant pieces: the cache_kwargs rewrite and the provenance header.
    cache_kwargs_pattern = r"cache_kwargs = {(.*?)\}"
    cache_kwargs_repl = r'cache_kwargs = {\1, "query_states": query_states}'
    version = transformers.__version__
    header = f"# 🚨🚨🚨 This code has been automatically generated using transformers {version} 🚨🚨🚨\n"
    output_dir = Path(__file__).resolve().parent

    for name in MODEL_NAMES:
        hf_module = importlib.import_module(f"transformers.models.{name}.modeling_{name}")
        code = header + inspect.getsource(hf_module)
        # Add query_states to each cache_kwargs literal.
        code = re.sub(cache_kwargs_pattern, cache_kwargs_repl, code)
        # Rewrite relative imports as absolute so the generated file works outside
        # the transformers package ("from ..." must be handled before "from .").
        code = code.replace("from ...", "from transformers.")
        code = code.replace("from .", f"from transformers.models.{name}.")
        (output_dir / f"modeling_{name}.py").write_text(code)


def update_attn_implementations():
    """
    Register the kvpress attention classes in the {NAME}_ATTENTION_CLASSES
    dictionaries of the transformers models.

    The transformers dictionaries are mutated in place, key by key, so existing
    references to those dictionaries see the kvpress classes.
    """

    for name in MODEL_NAMES:
        registry_name = f"{name.upper()}_ATTENTION_CLASSES"

        hf_module = importlib.import_module(f"transformers.models.{name}.modeling_{name}")
        hf_registry = getattr(hf_module, registry_name)

        kvpress_module = importlib.import_module(f"kvpress.models.modeling_{name}")
        kvpress_registry = getattr(kvpress_module, registry_name)

        # Replace each entry in place rather than rebinding the attribute,
        # keeping the original dictionary object (and its keys) intact.
        for key in hf_registry:
            hf_registry[key] = kvpress_registry[key]


if __name__ == "__main__":
    # Regenerate the patched modeling_{name}.py files from the currently
    # installed transformers version (run after upgrading transformers)
    rewrite_modeling_scripts()

0 comments on commit b6b16af

Please sign in to comment.