Fetching arguments for FSDP #2710

Merged (16 commits) on Nov 14, 2023
20 changes: 20 additions & 0 deletions composer/trainer/dist_strategy.py
@@ -133,9 +133,11 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('activation_checkpointing_reentrant', True)
fsdp_config.setdefault('activation_cpu_offload', False)
fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
+ fsdp_config.setdefault('backward_prefetch_limit', 1)
fsdp_config.setdefault('cpu_offload', False)
fsdp_config.setdefault('flatten_parameters', True)
fsdp_config.setdefault('forward_prefetch', False)
+ fsdp_config.setdefault('forward_prefetch_limit', 1)
fsdp_config.setdefault('ignored_modules', None)
fsdp_config.setdefault('keep_low_precision_grads', False)
fsdp_config.setdefault('limit_all_gathers', True)
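
A quick aside (not part of the diff): because `set_fsdp_default` uses `setdefault`, a value the caller supplies for either new key is preserved, and only missing keys fall back to the defaults added above. A minimal sketch, assuming `set_fsdp_default` is importable from the module shown in this file:

```python
# Sketch: setdefault only fills keys the user did not provide.
from composer.trainer.dist_strategy import set_fsdp_default

fsdp_config = {'forward_prefetch_limit': 2}  # user-supplied value
set_fsdp_default(fsdp_config)                # mutates the dict in place

assert fsdp_config['forward_prefetch_limit'] == 2   # kept as given
assert fsdp_config['backward_prefetch_limit'] == 1  # filled from the new default
```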
@@ -508,6 +510,24 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
**kwargs,
)

+ if hasattr(fsdp_obj, '_exec_order_data'):
+     if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
+         fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
+     else:
+         warnings.warn('FSDP._exec_order_data does not have attribute _forward_prefetch_limit, '
+                       'which is unexpected and will result in `forward_prefetch_limit` from the '
+                       'FSDP config being ignored. Please open an issue with Composer to report this.')
+     if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
+         fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
+     else:
+         warnings.warn('FSDP._exec_order_data does not have attribute _backward_prefetch_limit, '
+                       'which is unexpected and will result in `backward_prefetch_limit` from the '
+                       'FSDP config being ignored. Please open an issue with Composer to report this.')
+ else:
+     warnings.warn('FSDP does not have attribute _exec_order_data, which is unexpected and will '
+                   'result in `forward_prefetch_limit` and `backward_prefetch_limit` from the '
+                   'FSDP config being ignored. Please open an issue with Composer to report this.')

# Activation Checkpointing
if activation_checkpointing or activation_cpu_offload:
if not activation_checkpointing_reentrant:
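For context on the block added above: it copies the configured limits onto FSDP's private execution-order bookkeeping. A minimal sketch of the same idea applied directly to a raw PyTorch FSDP instance follows; the `set_prefetch_limits` helper is hypothetical, and `_exec_order_data` with its `_forward_prefetch_limit` / `_backward_prefetch_limit` fields are private PyTorch internals, so every access is guarded just as the Composer code guards it.

```python
# Sketch (not part of the diff): set forward/backward prefetch limits on an
# already-constructed FSDP module via the same private attributes the code above targets.
# These attributes are PyTorch internals and may change between releases.
import warnings

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def set_prefetch_limits(fsdp_obj: FSDP, forward_limit: int = 1, backward_limit: int = 1) -> None:
    exec_order_data = getattr(fsdp_obj, '_exec_order_data', None)
    if exec_order_data is None:
        warnings.warn('FSDP instance has no _exec_order_data; prefetch limits were not applied.')
        return
    if hasattr(exec_order_data, '_forward_prefetch_limit'):
        exec_order_data._forward_prefetch_limit = forward_limit
    if hasattr(exec_order_data, '_backward_prefetch_limit'):
        exec_order_data._backward_prefetch_limit = backward_limit
```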
42 changes: 37 additions & 5 deletions tests/trainer/test_fsdp.py
@@ -9,18 +9,21 @@
from composer.models import ComposerClassifier
from composer.trainer.trainer import Trainer
from composer.utils import dist
- from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel
+ from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
+                           world_size)


@pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.parametrize('device', ['cpu', 'meta'])
@pytest.mark.parametrize('reentrant', [True, False])
- @pytest.mark.filterwarnings('ignore::UserWarning')
+ @world_size(2)
@pytest.mark.gpu
+ @pytest.mark.filterwarnings('ignore:The passed in model appears to have tied weights.*:UserWarning')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
- def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, device: str, reentrant: bool):
+ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, reentrant: bool, world_size: int,
+                                     device: str):
"""test FSDP device initialization for a simple model with weight tying and a model where two modules
from separate submodules have weight tying applied. This test also covers both 'cpu' and
'meta' devices. This is because 'meta' will result in deferred initialization until FSDP is initialized
@@ -62,15 +65,16 @@ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision:
@pytest.mark.parametrize('model', [SimpleModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.gpu
+ @world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
- def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', device: str = 'meta'):
+ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', world_size: int):
"""
This test is intended to test FSDP for meta initialization when there are attributes
that are `None` and ensure we don't raise nasty UserWarnings.
"""
num_classes = 2
- model = model(num_features=1, num_classes=num_classes, device=device, bias=False)
+ model = model(num_features=1, num_classes=num_classes, device='meta', bias=False)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -85,3 +89,31 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
},
max_duration='3ba',
)


+ @pytest.mark.parametrize('forward_prefetch_limit', [1, 2])
+ @pytest.mark.parametrize('backward_prefetch_limit', [1, 2])
+ @pytest.mark.gpu
+ @world_size(2)
+ @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                     reason='FSDP requires PyTorch 1.13 or higher')
+ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limit: int, world_size: int):
+     model = SimpleModel()
+     model.fc1._fsdp_wrap = True
+     model.fc2._fsdp_wrap = True
+     dataset = RandomClassificationDataset(size=10)
+     dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+     trainer = Trainer(
+         model=model,
+         optimizers=optimizer,
+         train_dataloader=dataloader,
+         fsdp_config={
+             'forward_prefetch_limit': forward_prefetch_limit,
+             'backward_prefetch_limit': backward_prefetch_limit,
+         },
+         max_duration='3ba',
+     )
+
+     trainer.fit()
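
A possible follow-up assertion for the test above, sketched here as an assumption rather than part of this diff: walk the trainer's wrapped model, find the FSDP submodules, and check the same private attributes that `dist_strategy.py` writes. `trainer`, `forward_prefetch_limit`, and `backward_prefetch_limit` refer to the locals of `test_fsdp_prefetch_limit`.

```python
# Hypothetical check (not in the diff): confirm the limits landed on every FSDP-wrapped module.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

for module in trainer.state.model.modules():
    if isinstance(module, FSDP) and hasattr(module, '_exec_order_data'):
        assert module._exec_order_data._forward_prefetch_limit == forward_prefetch_limit
        assert module._exec_order_data._backward_prefetch_limit == backward_prefetch_limit
```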