fix: Load 8-bit quantized models for eval after fine-tuning #3606
Conversation
ludwig/trainers/trainer.py
Outdated
if torch.cuda.is_available():
    self.model.model.cuda()
    self.model.model.cpu()
Ooof, maybe one callout here might be that the .cuda() call is unique and overridden for the Linear8bitLt layers, which internally does some extra handling for the Int8Params?
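For context, a minimal sketch of what that override looks like from the outside, assuming bitsandbytes is installed and a CUDA device is available (layer sizes are arbitrary):

import bitsandbytes as bnb

# Linear8bitLt stores its weight as Int8Params rather than a plain nn.Parameter,
# and moving the layer to GPU is where the int8 quantization actually happens,
# so .cuda() is not just a device copy.
layer = bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False)
print(type(layer.weight).__name__)  # Int8Params
layer = layer.cuda()                # quantization is performed inside this call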
Yeah, it's not clear to me why we need to move to GPU then back to CPU like this. Comment would be great so I don't need to read the full PR description.
Added a comment and removed the move to CPU. It turns out the model was on GPU all along: self.model.device reports that it is on CPU, but self.model.model.device and deeper modules in the model all report that they are on GPU.
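A quick way to see this, as a sketch where wrapper stands in for the Ludwig model object discussed above:

def report_devices(wrapper):
    # The wrapper-level attribute can report "cpu" even after fine-tuning on GPU.
    print("wrapper.device:      ", wrapper.device)
    # The wrapped HF model and its parameters are the ground truth.
    print("wrapper.model.device:", wrapper.model.device)
    print("parameter devices:   ", {p.device for p in wrapper.model.parameters()})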
only_weights_format_keys = all("weights_format" in k for k in unexpected_keys)
assert (
    unexpected_keys == [] or only_weights_format_keys
), f"Unexpected keys found in state dict: {unexpected_keys}"
Add something about the only_weights_format_keys to the error message.
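For example, a rough sketch of how that could read (not the change actually made in the PR):

non_weights_format_keys = [k for k in unexpected_keys if "weights_format" not in k]
assert not non_weights_format_keys, (
    f"Unexpected keys found in state dict: {non_weights_format_keys}. "
    "Keys containing 'weights_format' are bitsandbytes metadata and are expected."
)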
Added in b8b487d
Hmm, am I missing something? I still don't see anything in the assert message about it.
ludwig/trainers/trainer.py
Outdated
# to a RuntimeError in `load_state_dict`. Explicitly call `model.cuda()` to make sure the
# matrices are part of model state. This workaround is necessary because the matrices are
# deleted during the model's forward pass.
if self.device == torch.device("cuda"):
Check self.device.type == "cuda" as the device might be cuda:0, etc.
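A quick illustration of the difference, assuming a CUDA build of PyTorch:

import torch

device = torch.device("cuda:0")
print(device == torch.device("cuda"))  # False: the index (0 vs None) differs
print(device.type == "cuda")           # True: matches any CUDA device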
Errors
After training or fine-tuning, the best model checkpoint is loaded for evaluation. When loading an 8-bit quantized model that was fine-tuned on GPU, the following errors occur:
RuntimeError: Loading a quantized checkpoint into non-quantized Linear8bitLt is not supported. Please call module.cuda() before module.load_state_dict()
Causes
These issues can be reproduced by running tests/integration_tests/test_llm.py::test_llm_finetuning_strategies with 8-bit quantization. Both issues are the result of custom handling in bitsandbytes:
1. load_state_dict in bitsandbytes raises the RuntimeError above unless module.cuda() is called first.
2. bitsandbytes adds a number of weight_format entries to the state dict behind the scenes. These are metadata entries that are used in load_state_dict to reconstruct the quantized parameters. Since these weight_format entries are never registered in model state, on load they are returned in the unexpected_keys list. On load for eval, we assert that no unexpected keys were returned.
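To illustrate the mechanism with plain PyTorch, here is a small sketch using an ordinary nn.Linear and a hypothetical metadata key standing in for what bitsandbytes writes:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
state_dict = model.state_dict()
# Stand-in for the metadata entries bitsandbytes adds to the checkpoint.
state_dict["weight_format"] = torch.tensor(0)

# With strict=False the extra key is not an error; it is reported back in
# unexpected_keys instead of raising.
result = model.load_state_dict(state_dict, strict=False)
print(result.unexpected_keys)  # ['weight_format']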
Workaround
This update puts in a workaround that addresses both issues. For 8-bit quantized models only, at the call to load_state_dict we first move the model to GPU and back to solve 1., then we ensure that the only unexpected keys are weight_format keys to handle 2. This should unblock 8-bit quantization for the time being, though we should double-check model quality.
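Putting both pieces together, a hedged sketch of the workaround (the wrapper name, device handling, and key filtering below are illustrative, not the exact PR code):

def load_checkpoint_for_eval(wrapper, state_dict, device):
    if device.type == "cuda":
        # bitsandbytes deletes the int8 matrices during the forward pass;
        # calling .cuda() re-materializes them so load_state_dict can succeed.
        wrapper.model.cuda()

    result = wrapper.load_state_dict(state_dict, strict=False)

    # The weights_format entries are bitsandbytes metadata, not real parameters,
    # so they are the only unexpected keys we tolerate.
    bad_keys = [k for k in result.unexpected_keys if "weights_format" not in k]
    assert not bad_keys, f"Unexpected keys found in state dict: {bad_keys}"
    return result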