Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Head-Specific KV Cache Compression Feature (Ada-SnapKV, AdaKV) #25

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,14 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# bash
evaluation/modified_evaluate.py
evaluation/a100_*.sh
evaluation/4090_*.sh
evaluation/logs/*
evaluation/results/*
evaluation/ruler/data/*
10 changes: 9 additions & 1 deletion evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import torch
from datasets import load_dataset
import os
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from kvpress.ada_attn import replace_var_flash_attn
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
Expand All @@ -23,6 +25,8 @@
RandomPress,
SnapKVPress,
StreamingLLMPress,
AdaSnapKVPress,
AdaScorerPress
)

logger = logging.getLogger(__name__)
Expand All @@ -48,6 +52,7 @@
"random": RandomPress(),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"ada_snapkv": AdaSnapKVPress()
}


Expand Down Expand Up @@ -124,10 +129,13 @@ def evaluate(
# Initialize pipeline with the correct attention implementation
if isinstance(press, ObservedAttentionPress):
model_kwargs = {"attn_implementation": "eager"}
# Support AdaKV
elif isinstance(press, AdaScorerPress):
replace_var_flash_attn(model=model)
model_kwargs = {"attn_implementation": "flash_attention_2"}
else:
try:
import flash_attn # noqa: F401

model_kwargs = {"attn_implementation": "flash_attention_2"}
except ImportError:
model_kwargs = {}
Expand Down
2 changes: 1 addition & 1 deletion evaluation/evaluate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dataset="ruler"
data_dir="4096"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
compression_ratios=(0.1 0.25 0.5)
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv")
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv" "ada_snapkv")

# Check if the number of press names is less than or equal to the number of available GPUs
num_gpus=$(nvidia-smi --list-gpus | wc -l)
Expand Down
7 changes: 6 additions & 1 deletion kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.ada_scorer_press import AdaScorerPress
from kvpress.presses.ada_snapkv_press import AdaSnapKVPress


__all__ = [
"BasePress",
"ComposedPress",
"ScorerPress",
"AdaScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
Expand All @@ -29,6 +34,6 @@
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"AdaSnapKVPress",
]

from kvpress.presses.tova_press import TOVAPress
Loading