diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 0ff5b04b30b0a..6d0dc2dd4073f 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -38,6 +38,7 @@
 from lightning_utilities.core.overrides import is_overridden
 from torch import Tensor
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
 import lightning.fabric
@@ -212,6 +213,7 @@ def setup(
         self,
         module: nn.Module,
         *optimizers: Optimizer,
+        scheduler: Optional[_LRScheduler] = None,
         move_to_device: bool = True,
         _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
@@ -220,6 +222,7 @@ def setup(
         Args:
             module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
+            scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
             _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -242,8 +245,8 @@ def setup(
 
         # Let accelerator/plugin wrap and connect the models and optimizers
         if optimizers:
-            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
-                module, list(optimizers)
+            module, optimizers, scheduler = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers), scheduler
             )
         else:
             module = self._strategy.setup_module(module)
@@ -272,7 +275,7 @@ def setup(
 
         if optimizers:
             # join both types in a tuple for API convenience
-            return (module, *optimizers)
+            return (module, *optimizers, scheduler)
         return module
 
     def setup_module(
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 93a17f10c8998..e0d9d0091657f 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -25,6 +25,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
@@ -311,9 +312,9 @@ def model(self) -> "DeepSpeedEngine":
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple["DeepSpeedEngine", List[Optimizer]]:
-        """Set up a model and multiple optimizers together.
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
+        """Set up a model and multiple optimizers together along with an optional learning rate scheduler.
 
         Currently, only a single optimizer is supported.
 
@@ -328,9 +329,9 @@ def setup_module_and_optimizers(
                 f" Got {len(optimizers)} optimizers instead."
             )
-        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
+        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
         self._set_deepspeed_activation_checkpointing()
-        return self._deepspeed_engine, [optimizer]
+        return self._deepspeed_engine, [optimizer], scheduler
 
     @override
     def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -339,7 +340,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
 
         For training, see :meth:`setup_module_and_optimizers`.
         """
-        self._deepspeed_engine, _ = self._initialize_engine(module)
+        self._deepspeed_engine, _, _ = self._initialize_engine(module)
         return self._deepspeed_engine
 
     @override
@@ -592,10 +593,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )
 
     def _initialize_engine(
-        self,
-        model: Module,
-        optimizer: Optional[Optimizer] = None,
-    ) -> Tuple["DeepSpeedEngine", Optimizer]:
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
         This calls :func:`deepspeed.initialize` internally.
@@ -604,15 +603,16 @@ def _initialize_engine(
         import deepspeed
 
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
            args=argparse.Namespace(device_rank=self.root_device.index),
            config=self.config,
            model=model,
            model_parameters=model_parameters,
            optimizer=optimizer,
+           lr_scheduler=scheduler,
            dist_init_required=False,
        )
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
 
     @override
     def setup_environment(self) -> None:
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index e7fdd29f6287f..6efc372db627b 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -39,6 +39,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from typing_extensions import TypeGuard, override
 
 from lightning.fabric.accelerators import Accelerator
@@ -267,8 +268,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
         """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
         module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
         use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -280,7 +281,7 @@ def setup_module_and_optimizers(
                 " call `setup_optimizer`."
             )
         module = self.setup_module(module)
-        return module, optimizers
+        return module, optimizers, scheduler
 
     @override
     def setup_module(self, module: Module) -> Module:
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6bfed6a270b68..96e856e68e5af 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -20,6 +20,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
@@ -144,8 +145,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag
         return stack
 
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
         """Set up a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
@@ -154,7 +155,7 @@ def setup_module_and_optimizers(
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
-        return module, optimizers
+        return module, optimizers, scheduler
 
     def setup_module(self, module: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 6da693bafb1c8..b2236aedab43f 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -21,6 +21,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from typing_extensions import override
 
@@ -196,8 +197,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
         """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
         raise NotImplementedError(
             f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
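
Usage note (not part of the diff): a minimal sketch of how the new `scheduler` argument to `Fabric.setup` could be exercised once this patch is applied. The toy model, optimizer, and single-device `Fabric(accelerator="cpu")` configuration are illustrative assumptions, not code taken from this change.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# setup() now also accepts and returns the scheduler; with the DeepSpeed strategy the
# returned scheduler is the one produced by deepspeed.initialize(), while the default
# strategy hands the object back unchanged.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 32)).sum()
    fabric.backward(loss)
    optimizer.step()
    scheduler.step()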