[FSDP] relax checking root condition #620
Conversation
cc @min-xu-ai, @SeanNaren
Nice. Thanks! Once CI passes I can merge if you don't see the merge button.
if not self._has_params:
    assert m._queue_wait_for_post_backward_closure is None
    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
assert m._is_root is None or m._is_root == False
Suggested change:
- assert m._is_root is None or m._is_root == False
+ # We relax the assert for non-root instance. A lightning unit test triggers this otherwise.
+ assert m._is_root is None or m._is_root == False
Looks like CI passes.
if not self._has_params:
    assert m._queue_wait_for_post_backward_closure is None
    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
# We relax the assert for non-root instance. A lightning unit test triggers this otherwise.
Not sure we need to mention Lightning here inside of fairscale. Eventually this comment will also be unclear about what the assert was relaxed from or why it was relaxed.
I don't want to leave it without a comment, since it might be hard to figure out in the future why this was relaxed. Maybe you can suggest something clearer?
@ananthsub makes a good point, apologies for being lazy on this one, here is a pure pytorch test we can use to simulate this, and remove the PL line:

import os
import unittest
from unittest import mock
import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
import torch.nn.functional as F


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
def test_wrapping_module():
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    device = torch.device("cuda")
    torch.cuda.set_device(0)
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
    module = nn.Sequential(
        FullyShardedDataParallel(nn.Linear(5, 5)),
    )
    model = FullyShardedDataParallel(module).to(device)
    input = torch.rand((1, 5), dtype=torch.float).to(device)
    output = model(input)
    loss = F.mse_loss(input, output)
    loss.backward()
    model = FullyShardedDataParallel(module).to(device)
    second_output = model(input)
    assert torch.allclose(output, second_output)
    torch.distributed.destroy_process_group()

We can add this as a unit test to ensure this behaviour works!
@SeanNaren, this is lovely. I can add a test file once my bug 617 work is done. @shuyingsunshine21, you can add a new test file as well, but please use the other FSDP tests as examples. We can't use Sean's code above as-is, since we don't want a hard-coded TCP port, which may cause port conflicts when multiple people run the tests on the same machine. Also, a new test file needs to be added to one of the test list text files under the tests dir. It is totally fine to leave this to me if you can wait a bit.
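To make the port concern concrete, here is a minimal sketch of how a test could ask the OS for a free port instead of hard-coding one. The find_free_port helper is purely illustrative (it is not a fairscale API), and fairscale's own dist_init test utility may already handle this:

import socket
from contextlib import closing


def find_free_port() -> int:
    """Ask the OS for an unused TCP port so concurrent test runs do not collide."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("localhost", 0))  # port 0 tells the kernel to pick any free port
        return s.getsockname()[1]


# Usage sketch: replace the hard-coded "1337" in the test above with a dynamic port, e.g.
# os.environ["MASTER_PORT"] = str(find_free_port())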
The CI test failure might be related to #624.
Yeah, sorry about that. It will be merged within the next hour.
tests/ci_test_list_3.txt (outdated)
@@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
+tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
Can you put the file in list_1.txt, since it is the shortest right now?
I was curious about where to put it and what the difference is (so I put it in a similar place as the rest of the FSDP tests).
No problem.
Please merge with master. I think Ben has fixed it already. My PR is going in soon too.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
- from fairscale.utils.testing import dist_init, teardown, torch_version
+ from fairscale.utils.testing import dist_init, teardown, torch_version, skip_if_no_cuda
from torch.nn import Linear, Module, Sequential
Is this moved by black/isort? If not, CI will fail again. Our CI is pretty strict; it will take a bit of time to get used to, but it is really good once you do. :-)
Thanks, after running black I forgot to run isort.
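For reference, a rough sketch of how isort would group the imports from the hunk above; the exact ordering depends on fairscale's isort configuration, which is assumed here to treat torch as third-party and fairscale as first-party:

from torch.nn import Linear, Module, Sequential

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version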
Weird that it did not trigger CI, though.
Yeah, I have seen it today too. Perhaps a CI bug. I ended up making and pushing a new commit to trigger it.
All passed :)
Nice! Thanks again!
Before submitting
What does this PR do?
When integrating with Lightning, we found that the model remains nested in an FSDP wrapper after training, and when we then call
trainer.test(model)
it fails the assertion that the root is not yet set: in this case, the non-root instances have already been set. We relax this assertion in this PR. (Link to discussion: Lightning-AI/pytorch-lightning#6152)
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃