FIX: Error with OLoRA init when using bnb #2011
Conversation
[WIP] Resolves huggingface#1999
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@tokenizer-decode Would be great if you could review.
Been very busy lately. Is this urgent? I'll take a look at it this week.
That's more than enough, thanks.
Interesting that this happens for specific models. I got the point that we need to create the 4bit and 8bit objects ourselves. I guess bnb is not a peft requirement and that's why you are conditionally importing it? The implementation looks good. But besides the actual problem, I think
And we would do:
Haven't tested this. It may need adjustments. I might be being too pedantic here. Just a recommendation. Otherwise it looks good.
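(A minimal sketch of the kind of refactor being suggested here, reusing the names from the PR's diff shown further down; the helper name and exact signature are illustrative, not taken from the PR.)

```python
def _recreate_bnb_weight(orig_weight, weight_tensor, bnb_param_type):
    """Re-create a quantized weight from a dequantized tensor, reusing the class
    of the original parameter so that bnb never has to be imported here."""
    if bnb_param_type == "4bit":
        return orig_weight.__class__(
            weight_tensor, quant_type=orig_weight.quant_type
        ).to(orig_weight.device)
    if bnb_param_type == "8bit":
        return orig_weight.__class__(
            weight_tensor, requires_grad=False
        ).to(orig_weight.device)
    return weight_tensor  # non-bnb case: plain tensor

# the call site would then collapse to something like:
# base_layer.weight = _recreate_bnb_weight(orig_weight, weight_tensor, bnb_param_type)
```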
Exactly.
This is unfortunately the nature of the beast. As usage grows, more edge cases are discovered that need to be taken care of. I don't see a big advantage in moving the layer creation into a separate function, it could actually make the code harder to understand (I'd do it if we foresee it being used elsewhere too). But I see a point in using
Fair enough. But I think there is an advantage in eliminating the repetition with a single
Yeah I saw that after sending the comment. Nice touch. Wouldn't hurt imo.
The issue is that the init of the 4bit vs 8bit params class is a bit different, so we cannot make it the same call, or do you mean something else?
I mean we make a single call when we do this:
Hmm, can we though? The constructors of
Didn't notice that. Wouldn't that mean Int8Params will not be updated? Maybe it should be set to True for inference. Have you tried both with
Well, the point of QLoRA is exactly that the quantized base weights are not updated ;-) Not sure if it's even possible tbh, but in any case we don't want that for PEFT.
But Int8 is set to True by default. Weird. I don't know much about QLoRA tbh. Anyway, this still would not invalidate our approach. You would instead do:
This still does not quite work, as 8bit params don't have the
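(For context, a minimal sketch of the two constructors being contrasted here, assuming recent bitsandbytes versions; this is not code from the PR.)

```python
import torch
import bitsandbytes as bnb

data = torch.randn(16, 16)

# Params4bit is configured through quantization kwargs such as quant_type and,
# in recent bitsandbytes versions, defaults to requires_grad=False.
w4 = bnb.nn.Params4bit(data, quant_type="nf4")

# Int8Params has no quant_type and defaults to requires_grad=True,
# so it has to be frozen explicitly.
w8 = bnb.nn.Int8Params(data, requires_grad=False)
```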
Okay at least we don't import bnb. Lgtm.
Thanks for the PR @BenjaminBossan! I left a small question to better understand what's happening!
src/peft/tuners/lora/layer.py (Outdated)
```python
if bnb_param_type == "4bit":
    weight_tensor = orig_weight.__class__(weight_tensor, quant_type=orig_weight.quant_type).to(
        orig_weight.device
    )
    base_layer.weight = weight_tensor
elif bnb_param_type == "8bit":
    weight_tensor = orig_weight.__class__(weight_tensor, requires_grad=False).to(orig_weight.device)
    base_layer.weight = weight_tensor
else:
    weight_tensor = weight_tensor.to(dtype)
    base_layer.weight.data = weight_tensor
```
Why are we quantizing the weights this time?
Normally for bnb weights, the tensors are flattened, e.g. shape [64, 1]. But after dequantizing, the `weight_tensor` that we assign here is not flat anymore, e.g. shape [16, 16]. My reasoning was that we should get back a "correct" tensor, so better to re-initialize it.
I tried what happens when I remove this and just do `base_layer.weight.data = weight_tensor`, and curiously, this seems to work too and the test passes, even if the shape is now wrong. This makes me wonder if bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems)? Not sure, any suggestion?
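(To illustrate the shape difference being described, a small sketch assuming a CUDA device and a recent bitsandbytes; the exact packed shape depends on the storage dtype.)

```python
import bitsandbytes as bnb

# A 4-bit linear layer quantizes its weight when moved to the GPU.
linear = bnb.nn.Linear4bit(16, 16, quant_type="nf4").cuda()

print(linear.weight.shape)  # packed, flattened storage, e.g. torch.Size([128, 1])

# Dequantizing restores the original 2D shape.
dequantized = bnb.functional.dequantize_4bit(
    linear.weight.data, quant_state=linear.weight.quant_state
)
print(dequantized.shape)  # torch.Size([16, 16])
```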
> I tried what happens when I remove this and just do `base_layer.weight.data = weight_tensor`

Wow that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this. cc @matthewdouglas

> This makes me wonder if bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems)?

I think that's fine as long as you pass the relevant kwargs that you can get from `orig_weight`. However, make sure to not pass the `bnb_quantized` arg for the 4-bit case. Then, with `to(orig_weight.device)`, it should quantize the weights properly.
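(A sketch of what that amounts to for the 4-bit case; the attributes read off `orig_weight` are assumed from bitsandbytes' `Params4bit`, and this is not the exact diff that was merged.)

```python
import torch
import bitsandbytes as bnb


def requantize_4bit(weight_tensor: torch.Tensor, orig_weight) -> "bnb.nn.Params4bit":
    # Pass the quantization settings read from the original parameter, but do not
    # pass bnb_quantized, so that .to(device) performs the actual re-quantization.
    return bnb.nn.Params4bit(
        weight_tensor,
        requires_grad=False,
        quant_type=orig_weight.quant_type,
        blocksize=orig_weight.blocksize,
        compress_statistics=orig_weight.compress_statistics,
    ).to(orig_weight.device)
```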
Great, thanks for the additional info.
> However, make sure to not pass the `bnb_quantized` arg for the 4-bit case. Then, with `to(orig_weight.device)`, it should quantize the weights properly.

To clarify, is the present code in alignment with what you suggest, or do I need to call `to(orig_weight.device)` too?

> Wow that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this.
Okay, then it's probably better to get Matthew's opinion before merging this.
As far as the shapes go, both 4bit and 8bit have some mechanisms in place to track the original shapes, but it's different for each. The `Linear8bitLt` has a `state.SB`, and for 4bit that information is part of `quant_state`. The main expectation is that it is all stored in a contiguous row-major format.

That said, it's not really clear to me that `dequantize_module_weight()` is doing all that it would need to do. Maybe it would pass the test here, but I would think the updated weights would not be quantized properly afterwards, so re-initializing it is probably the way to go.

> To clarify, is the present code in alignment with what you suggest or do I need to call to(orig_weight.device) too?

You'd want to have `.to(orig_weight.device)` in addition to the other kwargs as @SunMarc mentioned.
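(And correspondingly for the 8-bit case, a sketch under the assumption that `has_fp16_weights` can be read off the original `Int8Params`.)

```python
import torch
import bitsandbytes as bnb


def requantize_8bit(weight_tensor: torch.Tensor, orig_weight) -> "bnb.nn.Int8Params":
    # Int8Params defaults to requires_grad=True, so freeze it explicitly;
    # moving it back to the original device triggers the int8 re-quantization.
    return bnb.nn.Int8Params(
        weight_tensor,
        requires_grad=False,
        has_fp16_weights=orig_weight.has_fp16_weights,
    ).to(orig_weight.device)
```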
I updated the inits to take into account all arguments. Unfortunately, this may get out of date if bnb is updated, but I think there is no method such as `bnb.create_param_like(tensor)` to offload this work to bnb.
It would be great if you could do a final pass over the change.
Can't we just pass `orig_weight.__dict__` as the kwargs? This is how we did it in transformers.
Hmm, I wonder if that's really more robust. If a new attribute is added that is not an `__init__` argument, this would fail, right?

```python
class Foo:
    def __init__(self, x):
        self.x = x
        self.y = 123

foo = Foo("hi")
Foo(**foo.__dict__)  # TypeError: Foo.__init__() got an unexpected keyword argument 'y'
```

So no matter what, this code may break if there is some change to the `__init__` code in bnb.
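(One conceivable guard against that, not what this PR does: intersect the instance `__dict__` with the constructor's signature. Shown here with the `Foo` class from above; for `torch.nn.Parameter` subclasses like the bnb params, the relevant arguments live on `__new__` rather than `__init__`.)

```python
import inspect


def init_kwargs_from(obj):
    """Keep only the attributes that the class constructor actually accepts."""
    accepted = inspect.signature(type(obj).__init__).parameters
    return {k: v for k, v in vars(obj).items() if k in accepted}


foo = Foo("hi")
Foo(**init_kwargs_from(foo))  # works: the extra 'y' attribute is filtered out
```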
Oh yeah, that's right :/
Okay, I merged as is then. Code is going to eventually break one way or the other :D
LGTM! Thanks for iterating!