Add SimLayerKVPress #28

Merged
merged 12 commits into main from simon/simlayerkv on Dec 11, 2024

Conversation

@SimJeg (Collaborator) commented Dec 9, 2024

Add SimLayerKVPress (paper, official repository) following issue #19 and PR #22.

SimLayerKV uses a layer-wise approach to compression:
- layers identified as lazy use the Streaming LLM approach (only initial and recent KV pairs are kept)
- other layers use the full KV cache

To identify lazy layers, the attention weights of the last tokens are used: if the sum of their attention weights over the initial and recent tokens is above the lazy_threshold, the layer is considered lazy.
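
A minimal sketch of this criterion (names like n_sink, n_recent, n_last and their defaults are illustrative, not the actual kvpress API):

import torch

def layer_is_lazy(attn_weights: torch.Tensor, lazy_threshold: float,
                  n_sink: int = 4, n_recent: int = 1020, n_last: int = 32) -> bool:
    # attn_weights: (batch, heads, q_len, k_len) attention of a single layer
    last = attn_weights[..., -n_last:, :]
    # attention mass of the last queries over the initial (sink) + recent keys
    lazy_mass = last[..., :n_sink].sum(dim=-1) + last[..., -n_recent:].sum(dim=-1)
    # the layer is lazy if most of its attention mass already falls on the
    # tokens that the Streaming LLM approach would keep anyway
    return lazy_mass.mean().item() > lazy_threshold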

As with the wrapper for layer-wise compression ratios, this press only works with flash attention (why remains to be investigated). A notable difference from other presses is that the input of SimLayerKVPress is not a compression ratio but the lazy_threshold defined in the paper. However, I implemented a compression_ratio property that is computed dynamically.
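
A rough sketch of how such a dynamic property could look (hypothetical attribute names, assuming the number of kept KV pairs per layer is recorded during the forward pass):

from dataclasses import dataclass, field

@dataclass
class SimLayerKVPressSketch:
    lazy_threshold: float = 0.9
    kept_tokens: list = field(default_factory=list)  # KV pairs kept per layer
    seq_len: int = 0  # full sequence length seen during prefilling

    @property
    def compression_ratio(self) -> float:
        # lazy layers keep only sink + recent pairs while the others keep all
        # seq_len pairs, so the ratio is only known after the forward pass
        if not self.kept_tokens or self.seq_len == 0:
            return 0.0
        return 1.0 - sum(self.kept_tokens) / (len(self.kept_tokens) * self.seq_len)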

@dame-cell, after implementing it, I found it was not necessary after all to implement a new cache class. Could you please review and comment on this implementation? I will also share it on the official repository to get feedback from the authors.

@dame-cell commented Dec 9, 2024

Thanks for adding this, I was banging my head just trying to implement it.

Right now I'm actually busy, I'll try reviewing tomorrow if that's OK 😔

@SimJeg (Collaborator, Author) commented Dec 9, 2024

@dame-cell no problem, I also asked for a review on the official repository (see sail-sg/SimLayerKV#4).

@maxjeblick self-assigned this on Dec 9, 2024
@SimJeg changed the title from "add simlayerkvpress" to "Add SimLayerKVPress" on Dec 9, 2024
@maxjeblick (Collaborator) left a comment

Thanks a lot for the PR! In general, the PR is in good shape. I left some comments; they should be fast to fix.
Regarding the press itself, it looks good to me. I haven't studied the original work in detail, so it may make sense to also wait for feedback from the authors.

@SimJeg mentioned this pull request on Dec 10, 2024
@maxjeblick (Collaborator) left a comment

Code + implementation LGTM, thanks for adding the press!

I haven't checked in detail whether the press is equivalent to the original one; if there's no feedback in the next day(s), I will have a look at this as well.

@dame-cell commented Dec 10, 2024

@SimJeg I have added a comment, please check it out and tell me whether this will work or not.
Other than that, I have tested it myself and it seems to be working pretty well 💯

@SimJeg (Collaborator, Author) commented Dec 10, 2024

@dame-cell I don't see your comment, can you provide a link to it?

@dame-cell commented Dec 10, 2024

@SimJeg forgive me, here is the comment. In the original implementation they included different thresholds for different models:

        if 'llama3' in out_path:
            threshold = 0.9
        elif 'llama2' in out_path:
            threshold = 0.65
        elif 'mistral' in out_path:
            threshold = 0.8
        elif 'qwen' in out_path:
            threshold = 0.85

Adding something similar could make this even more versatile and model-aware. Just a thought—curious to hear your perspective! 😊

maybe something like this:

def get_lazy_threshold(model_name: str) -> float:
    if 'llama3' in model_name:
        return 0.9
    elif 'llama2' in model_name:
        return 0.65
    elif 'mistral' in model_name:
        return 0.8
    elif 'qwen' in model_name:
        return 0.85
    else:
        return 0.7  # Default threshold

# Example (assuming a Hugging Face config exposing name_or_path)
model_name = getattr(module.config, "name_or_path", "")
lazy_threshold = get_lazy_threshold(model_name)

@SimJeg (Collaborator, Author) commented Dec 10, 2024

I preferred to let the user specify the lazy_threshold argument to keep the press model-agnostic, but I will update the docstring to help the user set this value.
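
For instance, the docstring could carry the model-specific values reported in the official repository (a sketch of the wording, assuming the press is a dataclass):

from dataclasses import dataclass

@dataclass
class SimLayerKVPress:
    """
    ...
    lazy_threshold: threshold on the attention mass of the last tokens over
    the initial and recent tokens. The official repository uses
    model-specific values: 0.9 for llama3, 0.65 for llama2, 0.8 for mistral
    and 0.85 for qwen.
    """
    lazy_threshold: float = 0.9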

@SimJeg linked an issue on Dec 10, 2024 that may be closed by this pull request
@maxjeblick (Collaborator) left a comment

LGTM, thanks!
I left two small comments.

@maxjeblick (Collaborator) left a comment

LGTM, thanks!

@maxjeblick (Collaborator) left a comment

Reapproving

@SimJeg merged commit e36615c into main on Dec 11, 2024
2 checks passed
@SimJeg deleted the simon/simlayerkv branch on December 11, 2024
@jadeCurl commented Dec 23, 2024

Hi,

Thanks for incorporating our SimLayerKV!

We are currently working on version 2 of our project, with a major update being its integration with flash attention to enhance efficiency. For your reference, here is the source code:

# flash attention returning the output and the log-sum-exp (lse) of the logits
attn_out, lse = flash_attn(q, k, v, causal=True, return_lse=True)
# lazy-layer identification
w_last, w_sink, w_recent = 32, 4, 1020
# attention mass of the last w_last queries over the sink + recent keys,
# computed in log space; lse is sliced to the same w_last queries so the
# shapes match q_last
q_last = q[:, -w_last:].permute(0, 2, 1, 3)
k_comb = torch.cat([k[:, 0:w_sink], k[:, -w_recent:]], dim=1).permute(0, 2, 3, 1)
log_lazy_weight = torch.matmul(q_last, k_comb).logsumexp(dim=-1) - lse[:, :, -w_last:]
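
Presumably (my assumption, this part is not shown in the snippet), exponentiating this log-space quantity recovers the attention fraction used for the laziness test:

# exp(log_lazy_weight) = (attention mass on sink + recent keys) / (total mass)
lazy_fraction = log_lazy_weight.exp().mean()
layer_is_lazy = lazy_fraction > lazy_threshold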
