-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More convenient way to initialize LoftQ #1543
More convenient way to initialize LoftQ #1543
Conversation
Related to huggingface#1532 At the moment, using LoftQ is quite cumbersome, as shown in this example: https://github.com/huggingface/peft/tree/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/examples/loftq_finetuning Essentially, users have to: 1. Load the non-quantized model with LoftQ (which can be quite huge) 2. Modify the PEFT config 3. Save the adapter 4. Unwrap the base model with custom functions 5. Save the base model with modified weights (i.e. a whole copy of the base model) 6. Load the base model from step 5 with bnb quantization 7. Load the adapter from step 3 Yes, there is a helper script to do this, but this still has the advantage that we need to load the non-quantized model and that we have to create a completely new model checkpoint with the modified weights. This PR aims to make this process more convenient by adding a single function replace_lora_weights_loftq. This function takes the bnb-quantized LoRA model as input. Then it goes through each module with LoRA weights, lazily loads the corresponding non-quantized weights one at a time using safetensors, computes the quantization error, and replaces the LoRA weights with LoftQ-initialized LoRA weights. This is much more convenient because we only require very little extra memory thanks to lazy loading, and we don't have to keep an extra copy of the weights. While working on this, I still found that LoftQ initialization often did not seem to help a lot, as mentioned in huggingface#1532. I measured this by creating (1) logits with the base model, (2) with the quantized+LoRA model, and (3) with the quantized+LoRA+LoftQ model. The expectation is that (1) should be closer to (3) than to (2). This was often not the case. I therefore added the possibility to run a check each time that we replace a LoRA weight with the LoftQ weights. If this check returns True, we proceed to the next weight, otherwise we discard the change. That way, we only make the replacement with LoftQ weights if we see a real improvement. Of course, this is only a form of greedy optimization, but it seems to work in practice. And since it's optional, users can choose not to use it. This PR is not yet finished since I ran into an issue with matching the key names from safetensors not matching. Furthermore, for now this doesn't support 8bit quantization and the num_iter arguments of LoftQ, which I'm not sure is really working. However, I guess the replace_lora_weights_loftq function could be called multiple times in a row.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really cool function! 👏
I think we should either remove the section above or provide some clarification for when you should use each method, other than the replace_lora_weights_loftq
is easier, in which case most users would probably just pick that way. 😛
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @BenjaminBossan !
The API looks really great ! Thinking about it we should probably not worry about supporting .bin
files for this API, but I left them as an open question, wdyt?
@@ -15,10 +15,16 @@ | |||
# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py | |||
# Reference paper: https://arxiv.org/abs/2310.08659 | |||
|
|||
from __future__ import annotations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this import?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I guess with the type annotations used in this file, it wouldn't be necessary, but I also don't think it hurts.
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello Banjamin, thank you For making it easier to use loftQ, however, I have left a couple comments.
src/peft/utils/loftq_utils.py
Outdated
prefix = "base_model.model." | ||
any_match = False | ||
|
||
with safe_open(model_path, framework="pt", device="cpu") as f: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will this apply for model that have state dict sharded across multiple files?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, https://huggingface.co/meta-llama/Llama-2-70b-hf
|
||
|
||
@torch.no_grad() | ||
def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed internally, since this approach intends not to adapt the quantized weights (to avoid having an extra copy of all these weights), only single step is implemented.
Makes more sense than from the automatic merge
Better results, bigger margins.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reviewer comments should be addressed.
@pacman100 loading sharded models should now work. It's maybe not the most elegant solution, but this is what I gleaned from combing through transformers code. I tested it on gemma-2b
, which has 2 shards, and it worked.
Moreover, I adjusted the test to use all-linear
, now the margin is much larger (>1.5 for MAE and > 2.5 for MSE).
Please check again.
@@ -15,10 +15,16 @@ | |||
# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py | |||
# Reference paper: https://arxiv.org/abs/2310.08659 | |||
|
|||
from __future__ import annotations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I guess with the type annotations used in this file, it wouldn't be necessary, but I also don't think it hurts.
|
||
|
||
@torch.no_grad() | ||
def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed internally, since this approach intends not to adapt the quantized weights (to avoid having an extra copy of all these weights), only single step is implemented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @BenjaminBossan for iterating to support larger sharded models and using all-linear
inline with reducing quantization error for all layers targetted by bitsandbytes
. Much better results in tests and notebook example! 🔥🚀✨
Related to #1532
At the moment, using LoftQ is quite cumbersome, as shown in this example:
https://github.com/huggingface/peft/tree/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/examples/loftq_finetuning
Essentially, users have to:
Yes, there is a helper script to do this, but this still has the advantage that we need to load the non-quantized model and that we have to create a completely new model checkpoint with the modified weights.
This PR aims to make this process more convenient by adding a single function
replace_lora_weights_loftq
. This function takes the bnb-quantized LoRA model as input. Then it goes through each module with LoRA weights, lazily loads the corresponding non-quantized weights one at a time using safetensors, computes the quantization error, and replaces the LoRA weights with LoftQ-initialized LoRA weights.This is much more convenient because we only require very little extra memory thanks to lazy loading, and we don't have to keep an extra copy of the whole model weights.
While working on this, I still found that LoftQ initialization often did not seem to help a lot, as mentioned in #1532. I measured this by creating (1) logits with the base model, (2) logits with the quantized+LoRA model, and (3) logits with the quantized+LoRA+LoftQ model. The expectation is that (1) should be closer to (3) than to (2). This was often not the case.
I therefore added the possibility to run a check each time that we replace a LoRA weight with the LoftQ weights. If this check returns True, we keep the change and proceed to the next weight, otherwise we discard the change before proceeding. That way, we only make the replacement with LoftQ weights if we see a real improvement. Of course, this is only a form of greedy optimization, but it seems to work in practice. And since it's optional, users can choose not to use it.
This PR is not yet finished since I ran into an issue with matching the key names from safetensors not matching.
Furthermore, for now this doesn't support 8bit quantization and the
num_iter
arguments of LoftQ, which I'm not sure is really working. However, I guess thereplace_lora_weights_loftq
function could be called multiple times in a row.ping @yxli2123