-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SimLayerKVPress #28
Conversation
Thanks for adding this i was banging my head just trying to implement this Right now Im actually busy I'll try reviewing tommorow if that's ok 😔 |
@dame-cell no pb, I also ask for a review on the official repository (see sail-sg/SimLayerKV#4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the PR!
In general, the PR is in a good shape. I left some comments, they should be fast to fix.
Regrading the press itself, it looks good to me. I haven't studied the original work in detail, so it may make sense to also wait if the authors give some feedback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code + implementation LGTM, thanks for adding the press!
I haven't checked in detail if the press is equivalent to the original one, if there's no feedback in the next day(s), I will have a look at this, as well.
@SimJeg I have added a comment please check it out and tell me if this will work or not ? |
@dame-cell I don't see your comment, can you provide a link to it ? |
@SimJeg forgive me here is the comment if 'llama3' in out_path:
threshold = 0.9
elif 'llama2' in out_path:
threshold = 0.65
elif 'mistral' in out_path:
threshold = 0.8
elif 'qwen' in out_path:
threshold = 0.85
Adding something similar could make this even more versatile and model-aware. Just a thought—curious to hear your perspective! 😊 maybe something like this def get_lazy_threshold(model_name: str) -> float:
if 'llama3' in model_name:
return 0.9
elif 'llama2' in model_name:
return 0.65
elif 'mistral' in model_name:
return 0.8
elif 'qwen' in model_name:
return 0.85
else:
return 0.7 # Default threshold
# Example
module.config.get("model_name", "")
lazy_threshold = get_lazy_threshold(model_name) |
I prefered to let the user specify the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
I left two small comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reapproving
Hi, Thanks for incorporating our SimLayerKV! We are currently working on version 2 of our project, with a major update being its integration with flash attention to enhance efficiency. For your reference, here is the source code:
|
Add
SimLayerKVPress
(paper, official repository) following issue #19 and PR #22.SimLayerKV uses a layer-wise approach to compression:
- layers identified as lazy use the Streaming LLM approach (only initial and recent KV pairs are kept)
- other layers use the full KV cache
To identify lazy layers, the last attention weights are used. If the sum of attention weights of the last tokens
over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy.
As for the wrapper for layer-wise compression ratio, this press only works using flash attention (more investigation to be done on why). A notable difference from other presses is that the input of
SimLayerKVPress
is not a compression ratio but thelazy_threshold
as defined in the paper. However I implemented acompression_ratio
property that is computed dynamically.@dame-cell, after implementing it, I found it was not necessary after all to implement a new cache class. Could you please review and comment this implementation ? I will also share it on the official repository to get a feedback from the authors.