
FIX: Error with OLoRA init when using bnb #2011

Merged
8 commits merged into huggingface:main from fix-olora-bnb on Sep 3, 2024

Conversation

BenjaminBossan
Member

@BenjaminBossan BenjaminBossan commented Aug 16, 2024

Resolves #1999

  • the dtype check for bnb-quantized weights can be misleading when bnb_4bit_quant_storage is used
  • the updated base weights must be re-quantized if the original weight is bnb-quantized
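
For illustration, here is a minimal sketch (not the PR code; the helper name is made up) of detecting the bnb parameter type from the class name instead of the dtype. With 4-bit weights, weight.dtype reflects the storage dtype chosen via bnb_4bit_quant_storage, so a plain dtype check can be misleading:

from typing import Optional

import torch


def get_bnb_param_type(weight: torch.nn.Parameter) -> Optional[str]:
    # Checking the class name avoids importing bitsandbytes, which is only an
    # optional dependency of PEFT (see the discussion below).
    if weight.__class__.__name__ == "Params4bit":
        return "4bit"
    if weight.__class__.__name__ == "Int8Params":
        return "8bit"
    # a regular torch.nn.Parameter: fall back to the usual float dtype check
    return None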

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member Author

@tokenizer-decode Would be great if you could review.

@BenjaminBossan BenjaminBossan marked this pull request as ready for review August 21, 2024 08:11
@tokenizer-decode
Contributor

Been very busy lately. Is this urgent? I'll take a look at it this week.

@BenjaminBossan
Member Author

Been very busy lately. Is this urgent? I'll take a look at it this week.

That's more than enough, thanks.

@tokenizer-decode
Contributor

tokenizer-decode commented Aug 23, 2024

Interesting that this happens for specific models. I get the point that we need to create the 4bit and 8bit objects ourselves. I guess bnb is not a peft requirement and that's why you are conditionally importing it? The implementation looks good. But besides the actual problem, I think olora_init is becoming very complex and hard to maintain at this point. How about doing something like this:

def transform_if_necessary(weights: torch.nn.Parameter) -> torch.nn.Parameter:
    if weights.__class__.__name__ in ["Params4bit", "Int8Params"]:
        return weights.__class__(dequantize_module_weight(weights), quant_type=weights.quant_type).to(weights.device)
    return weights

And we would do:

    def olora_init(self, adapter_name):
        base_layer = self.get_base_layer()
        orig_weight = base_layer.weight
        dtype = orig_weight.dtype

        if dtype in [torch.float32, torch.float16, torch.bfloat16, "Actual_Possible_BNB_Types"]:
            weight_tensor = orig_weight
        else:
            raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")

        scale_factor = self.scaling[adapter_name]
        r = self.r[adapter_name]
        weight_tensor = weight_tensor.to(torch.float32)
        Q, R = torch.linalg.qr(weight_tensor.data)

        Qr, Rr = Q[:, :r], R[:r]
        self.lora_A[adapter_name].weight.data = Rr.contiguous()
        self.lora_B[adapter_name].weight.data = Qr.contiguous()

        weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
        base_layer.weight = transform_if_necessary(orig_weight)

Haven't tested this. It may need adjustments. I might be being too pedantic here. Just a recommendation. Otherwise it looks good.

@BenjaminBossan
Member Author

I guess bnb is not a peft requirement and that's why you are conditionally importing it?

Exactly.

I think olora_init is becoming very complex and hard to maintain at this point.

This is unfortunately the nature of the beast. As usage grows, more edge cases are discovered that need to be taken care of. I don't see a big advantage in moving the layer creation into a separate function, it could actually make the code harder to understand (I'd do it if we foresee it being used elsewhere too). But I see a point in using weight.__class__ to avoid the bnb import. WDYT?

@tokenizer-decode
Contributor

tokenizer-decode commented Aug 23, 2024

it could actually make the code harder to understand

Fair enough. But I think there is an advantage in eliminating the repetition with a single weight.__class__ call

But I see a point in using weight.__class__ to avoid the bnb import. WDYT?

Yeah I saw that after sending the comment. Nice touch. Wouldn't hurt imo.

@BenjaminBossan
Member Author

But I think there is an advantage in eliminating the repetition with a single weight.__class__ call

The issue is that the init of the 4bit vs 8bit params class is a bit different, so we cannot make it the same call, or do you mean something else?

@tokenizer-decode
Contributor

tokenizer-decode commented Aug 23, 2024

I mean we make a single call when we do this:
if weights.__class__.__name__ in ["Params4bit", "Int8Params"]: return weights.__class__(dequantize_module_weight(weights), quant_type=weights.quant_type).to(weights.device)

@BenjaminBossan
Member Author

Hmm, can we though? The constructors of Params4bit and Int8Params are different, e.g. the former has requires_grad=False and the latter requires_grad=True (not sure why).

@tokenizer-decode
Contributor

tokenizer-decode commented Aug 23, 2024

Didn't notice that. Wouldn't that mean Int8Params will not be updated? Maybe it should be set to True for inference. Have you tried both with requires_grad=True?

@BenjaminBossan
Member Author

Wouldn't that mean Int8Params will not be updated?

Well, the point of QLoRA is exactly that the quantized base weights are not updated ;-) Not sure if it's even possible tbh but in any case we don't want that for PEFT.

@tokenizer-decode
Contributor

tokenizer-decode commented Aug 23, 2024

But Int8Params is set to True by default. Weird. I don't know much about QLoRA tbh. Anyway, this still wouldn't invalidate our approach. You would instead do:
weights.__class__(dequantize_module_weight(weights), quant_type=weights.quant_type, requires_grad=False).to(weights.device)

@BenjaminBossan
Member Author

You would instead do:
weights.__class__(dequantize_module_weight(weights), quant_type=weights.quant_type, requires_grad=False).to(weights.device)

This still does not quite work, as 8bit params don't have the quant_type attribute. I simplified the code to use __class__ now, but I don't think more than that is possible. Please take a look.

@tokenizer-decode
Contributor

Okay, at least we don't import bnb. LGTM.

@BenjaminBossan BenjaminBossan requested a review from SunMarc August 26, 2024 10:10
Member

@SunMarc SunMarc left a comment


Thanks for the PR @BenjaminBossan! I left a small question to better understand what's happening!

Comment on lines 194 to 204
if bnb_param_type == "4bit":
    weight_tensor = orig_weight.__class__(weight_tensor, quant_type=orig_weight.quant_type).to(
        orig_weight.device
    )
    base_layer.weight = weight_tensor
elif bnb_param_type == "8bit":
    weight_tensor = orig_weight.__class__(weight_tensor, requires_grad=False).to(orig_weight.device)
    base_layer.weight = weight_tensor
else:
    weight_tensor = weight_tensor.to(dtype)
    base_layer.weight.data = weight_tensor
Member


Why are we quantizing the weights this time?

Member Author

@BenjaminBossan BenjaminBossan Aug 30, 2024


Normally for bnb weights, the tensors are flattened, e.g. shape [64, 1]. But after dequantizing, the weight_tensor that we assign here is not flat anymore, e.g. shape [16, 16]. My reasoning was that we should get back a "correct" tensor, so better to re-initialize it.

I tried what happens when I remove this and just do base_layer.weight.data = weight_tensor, and curiously, this seems to work too and the test passes, even though the shape is now wrong. This makes me wonder if bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems)? Not sure, any suggestion?
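
For reference, a small illustrative snippet of the shape mismatch (assuming bitsandbytes with a CUDA device is available; this is not part of the PR):

import torch
import bitsandbytes as bnb
from peft.utils.integrations import dequantize_module_weight

# a toy 4-bit linear layer; quantization happens when moving it to the GPU
linear = bnb.nn.Linear4bit(16, 16, compute_dtype=torch.float16, quant_type="nf4").to("cuda")

print(linear.weight.shape)                     # flattened packed storage, e.g. torch.Size([128, 1])
print(linear.weight.quant_state.shape)         # original shape tracked by bnb: torch.Size([16, 16])
print(dequantize_module_weight(linear).shape)  # dequantized weight: torch.Size([16, 16])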

Member

@SunMarc SunMarc Aug 30, 2024


I tried what happens when I remove this and just do base_layer.weight.data = weight_tensor

Wow that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this. cc @matthewdouglas

This makes me wonder if bnb somehow handles this automatically and we should not re-initialize (which could cause its own problems)?

I think that's fine as long as you pass the relevant kwargs that you can get from orig_weight. However, make sure not to pass the bnb_quantized arg for the 4-bit case. Then, with to(orig_weight.device), it should quantize the weights properly.

Member Author


Great, thanks for the additional info.

However, make sure to not pass bnb_quantized arg for the 4-bit case. Then, with to(orig_weight.device), it should quantize the weights properly.

To clarify, is the present code in alignment with what you suggest or do I need to call to(orig_weight.device) too?

Wow that's really strange indeed. I tried to check the code in bnb and it doesn't look like they handle this.

Okay, then it's probably better to get Matthew's opinion before merging this.

Member

@matthewdouglas matthewdouglas Aug 30, 2024


As far as the shapes go, both 4bit and 8bit have some mechanisms in place to track the original shapes, but it's different for each. The Linear8bitLt has a state.SB and for 4bit that information is part of quant_state. The main expectation is that it is all stored in a contiguous row-major format.

That said, it's not really clear to me that dequantize_module_weight() is doing all that it would need to do. Maybe it would pass the test here but I would think the updated weights would not be quantized properly afterwards, so re-initializing it is probably the way to go.

To clarify, is the present code in alignment with what you suggest or do I need to call to(orig_weight.device) too?

You'd want to have .to(orig_weight.device) in addition to the other kwargs as @SunMarc mentioned.
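
Putting both suggestions together, a rough sketch (reusing the variable names from the diff above; the final merged code passes additional constructor arguments, see below):

if bnb_param_type == "4bit":
    # Rebuild the Params4bit from the updated float tensor. bnb_quantized is not
    # passed, so the .to(...) call re-quantizes the data on the original device.
    base_layer.weight = orig_weight.__class__(
        weight_tensor, requires_grad=False, quant_type=orig_weight.quant_type
    ).to(orig_weight.device)
elif bnb_param_type == "8bit":
    # Int8Params has no quant_type argument, hence the separate branch.
    base_layer.weight = orig_weight.__class__(weight_tensor, requires_grad=False).to(orig_weight.device)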

Member Author


I updated the inits to take all arguments into account. Unfortunately, this may get out of date if bnb is updated, but I don't think there is a method like bnb.create_param_like(tensor) that we could offload this work to.

It would be great if you could do a final pass over the change.

Member


Can't we just pass orig_weight.__dict__ as the kwargs? This is how we did it in transformers.

Member Author


Hmm, I wonder if that's really more robust. If a new attribute is added that is not an __init__ argument, this would fail, right?

class Foo:
    def __init__(self, x):
        self.x = x
        self.y = 123

foo = Foo("hi")
Foo(**foo.__dict__)  # TypeError: Foo.__init__() got an unexpected keyword argument 'y'

So no matter what, this code may break if there is some change to the __init__ code in bnb.

Member


Oh yeah, that's right :/

Member Author


Okay, I merged it as is then. The code is eventually going to break one way or the other :D

Member

@SunMarc SunMarc left a comment


LGTM! Thanks for iterating!

@BenjaminBossan BenjaminBossan merged commit 37b9c5c into huggingface:main Sep 3, 2024
14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-olora-bnb branch September 3, 2024 12:08
Linked issue: Lora initialisation with olora and pissa not working with quantisation.