Fix fsdp weight tying #1856
Conversation
LGTM as far as I can tell, will let abhi or others approve. Also, have you tested with the examples repo (without Vitaliy's recent fix)? I'd like to know two things: 1) Does it properly respect the tied weights there? and 2) Does it change the memory/throughput?
Left a comment, but it generally looks good.
Also, what's the point of the file composer/scratch?
What does this PR do?
When FSDP is initialized with device='meta', it undoes weight tying. This is a known issue in PyTorch with deferred initialization. To address it, all weight-tied modules must live in the same FSDP module, so this PR makes a best effort to force weight-tied parameters into the same FSDP module.
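For context, a minimal sketch of the failure mode (not the actual Composer fix): a toy model whose output head shares its weight with the input embedding, built on the meta device for deferred initialization. The names (`TiedLM`, `embed`, `lm_head`) are illustrative only, and the meta-device context manager assumes PyTorch 2.x. After FSDP flattens parameters and materializes them, the tie survives only if both tied modules end up in the same FSDP unit, which is what this change tries to enforce.

```python
import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Toy language model with the output head tied to the input embedding."""

    def __init__(self, vocab_size: int = 100, d_model: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: both modules share a single parameter tensor.
        self.lm_head.weight = self.embed.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed(tokens))


# Build on the meta device, as deferred initialization does.
with torch.device('meta'):
    model = TiedLM()

# Before FSDP wrapping, the tie holds: both attributes are the same tensor.
assert model.lm_head.weight is model.embed.weight

# If `embed` and `lm_head` are placed in different FSDP units, re-initializing
# the meta parameters produces two separate tensors and silently breaks the
# tie. Keeping the tied modules in the same FSDP unit preserves it.
```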
What issue(s) does this change relate to?
CO-1511
Before submitting
Did you run pre-commit on your change? (see the pre-commit section of the prerequisites)