FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (#2103) #2104
Conversation
Thanks for this PR. First of all, my apologies for not responding earlier. The notification somehow slipped my attention and I just wasn't aware of this PR. In the future, feel free to ping me after a couple of days when there is no response.
The fix looks good, thanks for that. Let's add some tests to ensure that this bug doesn't happen again. For this, could you please add the following tests to the existing PiSSA tests:
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_gpt2_pissa_4bit(self, device, tmp_path):
    # see 2104
    self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_gpt2_pissa_8bit(self, device, tmp_path):
    # see 2104
    self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
For this to work, we need to make some changes to these lines though:

peft/tests/test_gpu_examples.py, line 1721 (at 5e91b54):
    if isinstance(module, torch.nn.Linear) and "lm_head" not in name:

peft/tests/test_gpu_examples.py, line 1730 (at 5e91b54):
    if isinstance(module, torch.nn.Linear) and "lm_head" not in name:

There, we need to change isinstance(module, torch.nn.Linear) to isinstance(module, (torch.nn.Linear, Conv1D)), where Conv1D is imported from transformers.pytorch_utils.
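To illustrate why this check matters: GPT-2's attention and MLP projections are transformers' Conv1D modules, which act like linear layers but store their weight transposed, so an isinstance check against torch.nn.Linear alone silently skips them. A minimal sketch of the widened check (the fallback Conv1D class is only a stand-in with the same weight layout, for when transformers is not installed; the module names are made up for the example):

```python
import torch
from torch import nn

try:
    from transformers.pytorch_utils import Conv1D
except ImportError:
    # Stand-in mimicking transformers' Conv1D: weight is stored as
    # (in_features, out_features), the transpose of nn.Linear's layout.
    class Conv1D(nn.Module):
        def __init__(self, nf, nx):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(nx, nf))
            self.bias = nn.Parameter(torch.zeros(nf))

modules = {
    "attn.c_attn": Conv1D(12, 4),  # GPT-2 style; weight shape is (4, 12)
    "lm_head": nn.Linear(4, 8),    # excluded by name, as in the tests
    "mlp.fc": nn.Linear(4, 4),
}

# The widened check from this PR: match Conv1D as well as nn.Linear.
selected = [
    name
    for name, module in modules.items()
    if isinstance(module, (nn.Linear, Conv1D)) and "lm_head" not in name
]
print(selected)  # ['attn.c_attn', 'mlp.fc']
```

With the original isinstance(module, torch.nn.Linear) check, "attn.c_attn" would be missing from the list, which is exactly why GPT-2 needs the Conv1D branch.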
Force-pushed 2a513f6 to 4d77af8:
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (huggingface#2103)

This update addresses an issue where the weight matrix was converted to float32 without considering the need for transposition. The weight matrix is now transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.

To ensure this fix is robust, the tests have been updated to include parameterized cases for different devices and bit configurations. Additionally, the isinstance checks have been modified to include Conv1D layers, ensuring all relevant layers are processed correctly.
Force-pushed 4d77af8 to 1bf7d7a:
@BenjaminBossan Thank you for your feedback and suggestions, I've updated this PR.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for this fix, LGTM.
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (huggingface#2104)

Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization.

Co-authored-by: Yang Su <suyang360@gmail.com>
Previously, the weight matrix was converted to float32 without considering the need for transposition. This update ensures that the weight matrix is transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.
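The core of the fix can be sketched as follows. Note that weight_for_pissa is a hypothetical helper name, not the actual PEFT function; it only mirrors the behavior described above (cast to float32, then transpose when fan_in_fan_out is set, as it is for GPT-2's Conv1D layers):

```python
import torch

def weight_for_pissa(weight: torch.Tensor, fan_in_fan_out: bool) -> torch.Tensor:
    # Cast to float32 for the SVD, then transpose Conv1D-style
    # (in_features, out_features) weights into the (out_features,
    # in_features) layout that nn.Linear uses.
    w = weight.to(torch.float32)
    if fan_in_fan_out:
        w = w.T
    return w

# nn.Linear already stores (out_features, in_features): no transpose.
linear_w = torch.randn(3, 4, dtype=torch.float16)
print(weight_for_pissa(linear_w, fan_in_fan_out=False).shape)  # torch.Size([3, 4])

# GPT-2's Conv1D stores (in_features, out_features): transpose first,
# otherwise the decomposed factors have mismatched dimensions in training.
conv1d_w = torch.randn(4, 3, dtype=torch.float16)
print(weight_for_pissa(conv1d_w, fan_in_fan_out=True).shape)   # torch.Size([3, 4])
```

Without the transpose branch, the Conv1D weight would enter the PiSSA SVD with its dimensions swapped, which is the dimension mismatch this PR resolves.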