FIX PiSSA & OLoRA with rank/alpha pattern, rslora #1930

Conversation

BenjaminBossan (Member) commented Jul 16, 2024

See #1929 (comment)

At the moment, when using PiSSA or OLoRA with weight conversion to restore the original base weights, there is an error when any of rank_pattern, alpha_pattern, or rslora is used. This PR fixes this.

The issue is that we need to double the rank of the LoRA adapter. Right now, this is done by simply doubling r and alpha. But if rank_pattern and alpha_pattern are being used, those need to be doubled too.

Furthermore, when using rslora, the scaling is again different, namely alpha / sqrt(r). This also needs to be adjusted.
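To make this concrete, here is a rough sketch of the arithmetic (illustrative only, with a hypothetical helper; this is not the actual saving code):

import math

def doubled_lora_settings(r, lora_alpha, rank_pattern, alpha_pattern, use_rslora):
    # standard scaling is lora_alpha / r:      doubling r requires lora_alpha * 2
    # rslora scaling is lora_alpha / sqrt(r):  doubling r requires lora_alpha * sqrt(2)
    factor = math.sqrt(2) if use_rslora else 2
    return {
        "r": 2 * r,
        "lora_alpha": factor * lora_alpha,
        # the per-layer overrides need the same doubling; the combination of rslora
        # with these patterns is rejected instead (see below)
        "rank_pattern": {key: 2 * rank for key, rank in rank_pattern.items()},
        "alpha_pattern": {key: 2 * alpha for key, alpha in alpha_pattern.items()},
    }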

Unfortunately, when using rslora together with rank_pattern and alpha_pattern, this gets way more complicated. Since we don't store the scaling in the state_dict, we would have to resolve all the patterns here to determine the correct scaling, i.e. reimplement the whole matching and init logic. This is a lot of work for a very edgy edge case.

Therefore, I opted to raise an error instead. This is not super nice, as the error is only raised when trying to save the model, i.e. a lot of time may already have been spent training the model. But we cannot know this earlier, so not much can be done.

Overall, this fix is ugly because it further couples unrelated code. For instance, if we add new init methods that affect the scaling, we need to remember to change the saving logic accordingly. If anyone has a better idea, LMK.

Update

I noticed that the test test_lora_pissa_conversion_same_output_after_loading_with_quantization was failing locally. This is because the argument convert_mutated_to_lora was renamed to path_initial_model_for_weight_conversion in #1828, but this test was not updated; that oversight has been addressed. However, this also means that the test did not run in CI (the nightly CI with GPU). The reason was that the PiSSA and OLoRA tests were missing the right marker, so I added the missing markers as well.

BenjaminBossan (Member, Author) commented:

Tagging @fxmeng and @tokenizer-decode as this affects the code they contributed.

tokenizer-decode (Contributor) commented Jul 16, 2024

Oh, so we raise the error after all the training is done. Can you tell me why we are checking if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern): only after all the training is done and just before saving the model? Why are we not doing this when initializing PiSSA or OLoRA? Aren't the peft_config.rank_pattern and peft_config.alpha_pattern options determined before doing anything else? Or maybe I did not understand you correctly.

BenjaminBossan (Member, Author) commented:

Yes, valid question. The issue is that this is only a problem if users pass path_initial_model_for_weight_conversion to save_pretrained, so we can only really check this once save_pretrained is called.
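Roughly, the constraint looks like this (a sketch with a hypothetical helper name; the actual check sits in the saving logic):

def check_conversion_supported(peft_config, path_initial_model_for_weight_conversion):
    # the conversion is only requested via this argument to save_pretrained, so this
    # is the earliest point at which the unsupported combination is known for sure
    if path_initial_model_for_weight_conversion is None:
        return
    if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):
        raise ValueError(
            "Weight conversion is not supported when use_rslora is combined with "
            "rank_pattern or alpha_pattern."
        )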

tokenizer-decode (Contributor) commented Jul 16, 2024

Hmm. Maybe at initialization, if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern) is True, we could warn the user that they cannot use path_initial_model_for_weight_conversion? There is a chance they may stop there. I don't know. This is also not a good fix, but it at least gives the user a chance. I would be pissed after spending a bunch of money on an EC2 instance and then getting this error.

BenjaminBossan (Member, Author) commented:

@tokenizer-decode Good suggestion about adding a warning. I did that in the last commit, though unfortunately many users don't read warnings. Still, there is not much else that can be done here.
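For illustration, the warning could look roughly like this (a sketch; not the exact wording or placement used in the commit):

import warnings

def maybe_warn_about_weight_conversion(peft_config):
    if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):
        warnings.warn(
            "Using rslora together with rank_pattern or alpha_pattern: you will not be "
            "able to pass path_initial_model_for_weight_conversion to save_pretrained "
            "to restore the original base weights later."
        )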

tokenizer-decode (Contributor) commented:

Yeah, you are right, there is not much else to do here. But I agree that this is a very edgy edge case, so we should be OK.

BenjaminBossan requested a review from sayakpaul (July 17, 2024, 09:52)
BenjaminBossan (Member, Author) commented:

@sayakpaul LMK if you don't want to review this, I'll look for someone else.

BenjaminBossan (Member, Author) commented:

@tokenizer-decode While working on this, I also discovered a small error in a test caused by your PR, which I have fixed (see update section of PR description).

tokenizer-decode (Contributor) commented:

Oh I totally missed that. Thanks

sayakpaul (Member) left a comment:

Thanks! Seems like a very deep edge case. I have left some comments to better understand a few things. Once they are replied to, I will review one more time.

src/peft/peft_model.py (review thread resolved)
src/peft/tuners/lora/config.py (review thread resolved)
output_pissa = peft_model(data)[0]

# sanity check
tol = 1e-06
Member:

(nit) A stricter check for negations would be to relax the tolerance a bit and see if the assertion still holds, IMO.

Member Author:

True. At the moment, we use the same tolerance for assert not torch.allclose(...) and assert torch.allclose(...) throughout the whole PEFT code base, so I wouldn't change this now only for these specific tests. I think this could be a nice use case for a code LLM: go through all tests and rewrite the negative checks to use a higher tolerance :)
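For illustration, a negative check remains meaningful even at a much looser tolerance (toy tensors, not an actual PEFT test):

import torch

a = torch.zeros(10)
b = a + 0.5  # deliberately different

tol = 1e-06
assert not torch.allclose(a, b, atol=tol, rtol=tol)  # current pattern
# stricter negation: the tensors still differ even when the tolerance is relaxed
assert not torch.allclose(a, b, atol=1e-2, rtol=1e-2)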

Comment on lines +369 to +370
assert model_loaded.peft_config["default"].r == 8
assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 32
Member:

What is the difference between r being 8 and rank_pattern being 32 for linear here?

target_modules=["linear"], r=8, rank_pattern={"linear": 32}

Member Author:

Sorry, I missed this comment earlier, so addressing it now.

For this simple test, we only have a small model with a single linear layer. Therefore, if we set rank_pattern={"linear": 32}, r=8 is effectively ignored because rank_pattern takes precedence. With a real model, we could, however, set the rank of all layers to 8 via r and then pick out individual layers that should get a different rank via rank_pattern.
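A minimal sketch of that precedence (hypothetical toy model, mirroring the test setup):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 32)

config = LoraConfig(target_modules=["linear"], r=8, rank_pattern={"linear": 32})
peft_model = get_peft_model(Toy(), config)

# rank_pattern takes precedence over r for the layers it matches:
# lora_A for "linear" has 32 rows, not 8
assert peft_model.base_model.model.linear.lora_A["default"].weight.shape[0] == 32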

Does that answer the question?

Member:

Yeah, it does. Maybe just add a comment explaining this in the test? Many developers refer to tests to understand different usage patterns, and they may get confused here as well.

Comment on lines 346 to 352
# use rank_pattern here
config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8, rank_pattern={"linear": 32})
peft_model = get_peft_model(deepcopy(model), config)
# save the initial model
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(tmp_path / "init-model")
peft_model.peft_config["default"].init_lora_weights = "pissa"
Member:

Help me unpack this bit of code.

We are already using "pissa" for init_lora_weights in the LoraConfig. So I don't understand what would differ between peft_model.peft_config["default"].init_lora_weights = True and peft_model.peft_config["default"].init_lora_weights = "pissa".

tokenizer-decode (Contributor) commented Jul 18, 2024:

That was due to a bug in the original PiSSA implementation. I fixed it when I was implementing OLoRA, but I didn't change the tests. What would differ is that init_lora_weights should be True in the saved config file. See here

model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
)

def test_lora_pissa_conversion_same_output_after_loading_with_alpha_pattern(self, data, tmp_path):
Member:

The same comments apply here, too.

BenjaminBossan requested a review from sayakpaul (July 18, 2024, 10:29)
sayakpaul (Member) left a comment:

Very minor changes requested. I think this is the best we can have right now.

BenjaminBossan (Member, Author) commented:

@sayakpaul I added clarifying comments to the tests in my latest commit.

sayakpaul merged commit e02b938 into huggingface:main on Jul 19, 2024
14 checks passed
BenjaminBossan deleted the fix-pissa-olora-rank_pattern-alpha_pattern-rslora branch on July 22, 2024