From 06eb3cc28ba07b9f213c78419c070b94e10a4fef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 8 Mar 2024 11:48:16 +0100
Subject: [PATCH] Pass `enabled` down to `_BackwardSyncControl` (#19577)

---
 src/lightning/fabric/CHANGELOG.md              | 3 ++-
 src/lightning/fabric/fabric.py                 | 4 ++--
 src/lightning/fabric/strategies/ddp.py         | 5 ++++-
 src/lightning/fabric/strategies/fsdp.py        | 6 ++++--
 src/lightning/fabric/strategies/strategy.py    | 2 +-
 src/lightning/fabric/strategies/xla_fsdp.py    | 4 +++-
 tests/tests_fabric/strategies/test_ddp.py      | 8 +++++---
 tests/tests_fabric/strategies/test_fsdp.py     | 8 +++++---
 tests/tests_fabric/strategies/test_xla_fsdp.py | 8 ++++++--
 tests/tests_fabric/test_fabric.py              | 6 +++---
 10 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 4c1814bf9f11a..94741aee042fc 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -26,7 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))
 
--
+- `_BackwardSyncControl` can now control what to do when gradient accumulation is disabled ([#19577](https://github.com/Lightning-AI/lightning/pull/19577))
+
 
 ### Deprecated
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 09b1203b7b85b..4b9c14eb06e62 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -672,7 +672,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
                 "You need to set up the model first before you can call `fabric.no_backward_sync()`:"
                 " `model = fabric.setup(model, ...)`"
             )
-        if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
+        if isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
             return nullcontext()
         if self._strategy._backward_sync_control is None:
             rank_zero_warn(
@@ -683,7 +683,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
             return nullcontext()
 
         forward_module, _ = _unwrap_compiled(module._forward_module)
-        return self._strategy._backward_sync_control.no_backward_sync(forward_module)
+        return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)
 
     def sharded_model(self) -> ContextManager:
         r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
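Note on usage (not part of the patch): `Fabric.no_backward_sync()` keeps its public signature, but the `enabled` flag is now forwarded to the strategy's `_BackwardSyncControl` instead of being short-circuited inside `Fabric`. A minimal gradient-accumulation sketch using this context manager follows; the toy model, optimizer, dataloader, and accumulation factor are illustrative placeholders rather than code from this change.

```python
import torch
import lightning as L

# Minimal sketch of gradient accumulation with `no_backward_sync`.
# Assumes a DDP-capable environment; model and data are stand-ins.
fabric = L.Fabric(accelerator="cpu", devices=2, strategy="ddp")
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(
    torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)
)

accumulate = 4  # hypothetical accumulation factor
for i, batch in enumerate(dataloader):
    is_accumulating = (i + 1) % accumulate != 0
    # `enabled` now travels down to the strategy's `_BackwardSyncControl`,
    # which decides whether to skip gradient synchronization on this step.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = model(batch).sum()
        fabric.backward(loss)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```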
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 3b1de32a5d98b..0ec5df1a6b0ae 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -224,9 +224,12 @@ def _determine_ddp_device_ids(self) -> Optional[List[int]]:
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the
         :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper."""
+        if not enabled:
+            return nullcontext()
+
         if not isinstance(module, DistributedDataParallel):
             raise TypeError(
                 "Blocking backward sync is only possible if the module passed to"
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 5679155380bea..ed89629f720e8 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
-from contextlib import ExitStack
+from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
@@ -768,9 +768,11 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa
 
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the
         :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper."""
+        if not enabled:
+            return nullcontext()
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
         if not isinstance(module, FullyShardedDataParallel):
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6c29b95d2a481..1c64f97394fa2 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -424,7 +424,7 @@ class _BackwardSyncControl(ABC):
     """
 
     @abstractmethod
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks the synchronization of gradients during the backward pass.
 
         This is a context manager. It is only effective if it wraps a call to `.backward()`.
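Illustration only (not part of the patch): with the abstract signature above, every `_BackwardSyncControl` implementation now receives the `enabled` flag and chooses its own behavior. A hypothetical custom control that mirrors the built-in DDP-style pattern could look like the sketch below; the class name and the assumption that the wrapped module exposes a `no_sync()` context manager are made up for the example.

```python
from contextlib import nullcontext
from typing import ContextManager

from torch.nn import Module
from typing_extensions import override

from lightning.fabric.strategies.strategy import _BackwardSyncControl


class _NoSyncBackwardSyncControl(_BackwardSyncControl):  # hypothetical subclass
    @override
    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
        # New contract: the control itself decides what to do when gradient
        # accumulation is disabled; here we simply return a no-op context.
        if not enabled:
            return nullcontext()
        # Assumes the wrapped module provides a DDP-style `no_sync()` manager.
        return module.no_sync()
```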
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 19a9b77238223..1b53292ff1581 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -679,9 +679,11 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict
 
 class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the
         :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` wrapper."""
+        if not enabled:
+            return nullcontext()
         from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
 
         if not isinstance(module, XLAFSDP):
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index 0365f75a7e4e1..beea7eccb69c2 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -61,13 +61,15 @@ def test_ddp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock()):
+    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
         pass
 
     module = MagicMock(spec=DistributedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+    with strategy._backward_sync_control.no_backward_sync(module, False):
+        pass
+    module.no_sync.assert_not_called()
+    with strategy._backward_sync_control.no_backward_sync(module, True):
         pass
-
     module.no_sync.assert_called_once()
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index c6d3135f621bb..3f2d02e06be2a 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -150,13 +150,15 @@ def test_fsdp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock()):
+    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
         pass
 
     module = MagicMock(spec=FullyShardedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+    with strategy._backward_sync_control.no_backward_sync(module, False):
+        pass
+    module.no_sync.assert_not_called()
+    with strategy._backward_sync_control.no_backward_sync(module, True):
         pass
-
     module.no_sync.assert_called_once()
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index 9548794a3e332..bcd2a6e637417 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -50,13 +50,17 @@ def test_xla_fsdp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(object()):
+    ), strategy._backward_sync_control.no_backward_sync(object(), True):
         pass
 
     module = MagicMock(spec=XlaFullyShardedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+
+    with strategy._backward_sync_control.no_backward_sync(module, False):
         pass
+    module.no_sync.assert_not_called()
+    with strategy._backward_sync_control.no_backward_sync(module, True):
+        pass
     module.no_sync.assert_called_once()
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 9568bda6a79e4..fde9479c73eaf 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -767,11 +767,11 @@ def test_no_backward_sync():
     # disabling the context manager makes it a no-op
     with fabric.no_backward_sync(model, enabled=False):
         pass
-    fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
-    # when enabled, the wrapped module gets passed down
+    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, False)
+    fabric._strategy._backward_sync_control.reset_mock()
     with fabric.no_backward_sync(model):
         pass
-    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
+    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, True)
 
 
 def test_launch_without_function():
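Background for the test changes (not part of the patch): the built-in controls ultimately delegate to the wrapper's own `no_sync()` context manager, e.g. `DistributedDataParallel.no_sync()`. A raw-PyTorch sketch of the same accumulation pattern, without Fabric, is shown below; the process-group setup, toy model, and loop bounds are placeholders.

```python
from contextlib import nullcontext

import torch
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
model = DistributedDataParallel(torch.nn.Linear(32, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulate = 4  # hypothetical accumulation factor
for i in range(16):
    batch = torch.randn(8, 32)
    is_accumulating = (i + 1) % accumulate != 0
    # `no_sync()` skips the gradient all-reduce for this backward pass; local
    # gradients keep accumulating and are synchronized on the next backward
    # performed outside the context manager.
    sync_ctx = model.no_sync() if is_accumulating else nullcontext()
    with sync_ctx:
        model(batch).sum().backward()
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```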