Add learning rate scheduling support for DeepSpeedStrategy #20320

Open · wants to merge 13 commits into base: master
9 changes: 6 additions & 3 deletions src/lightning/fabric/fabric.py
@@ -38,6 +38,7 @@
from lightning_utilities.core.overrides import is_overridden
from torch import Tensor
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

import lightning.fabric
@@ -212,6 +213,7 @@ def setup(
self,
module: nn.Module,
*optimizers: Optimizer,
+        scheduler: Optional[_LRScheduler] = None,
move_to_device: bool = True,
_reapply_compile: bool = True,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
@@ -220,6 +222,7 @@ def setup(
Args:
module: A :class:`torch.nn.Module` to set up
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
+            scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -242,8 +245,8 @@ def setup(

# Let accelerator/plugin wrap and connect the models and optimizers
if optimizers:
-            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
-                module, list(optimizers)
+            module, optimizers, scheduler = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers), scheduler
)
else:
module = self._strategy.setup_module(module)
Expand Down Expand Up @@ -272,7 +275,7 @@ def setup(

if optimizers:
# join both types in a tuple for API convenience
-            return (module, *optimizers)
+            return (module, *optimizers, scheduler)
return module

def setup_module(
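
For orientation, here is a minimal usage sketch of the public API change above (a hypothetical script with placeholder model, optimizer, and scheduler objects), showing the new keyword-only `scheduler` argument and the extended return tuple of `Fabric.setup`:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
fabric.launch()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# With this change, `setup` forwards the scheduler to the strategy and returns it
# as the last element, after the wrapped module and optimizer(s).
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)
```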
24 changes: 12 additions & 12 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -25,6 +25,7 @@
from lightning_utilities.core.imports import RequirementCache
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from typing_extensions import override

from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
@@ -311,9 +312,9 @@ def model(self) -> "DeepSpeedEngine":

@override
def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple["DeepSpeedEngine", List[Optimizer]]:
-        """Set up a model and multiple optimizers together.
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
+        """Set up a model and multiple optimizers together along with an optional learning rate scheduler.

Currently, only a single optimizer is supported.

@@ -328,9 +329,9 @@ def setup_module_and_optimizers(
f" Got {len(optimizers)} optimizers instead."
)

-        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
+        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
self._set_deepspeed_activation_checkpointing()
-        return self._deepspeed_engine, [optimizer]
+        return self._deepspeed_engine, [optimizer], scheduler

@override
def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -339,7 +340,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
For training, see :meth:`setup_module_and_optimizers`.

"""
-        self._deepspeed_engine, _ = self._initialize_engine(module)
+        self._deepspeed_engine, _, _ = self._initialize_engine(module)
return self._deepspeed_engine

@override
@@ -592,10 +593,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
)

def _initialize_engine(
-        self,
-        model: Module,
-        optimizer: Optional[Optimizer] = None,
-    ) -> Tuple["DeepSpeedEngine", Optimizer]:
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

This calls :func:`deepspeed.initialize` internally.
@@ -604,15 +603,16 @@ def _initialize_engine(
import deepspeed

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
config=self.config,
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
+            lr_scheduler=scheduler,
dist_init_required=False,
)
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler

@override
def setup_environment(self) -> None:
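
A corresponding sketch of how a scheduler would flow through `DeepSpeedStrategy` with this change (assuming a CUDA machine with the `deepspeed` package installed; the model, optimizer, and scheduler here are illustrative):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=2))
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# `setup` now hands the scheduler to `deepspeed.initialize(lr_scheduler=...)` and
# returns the scheduler object reported back by the DeepSpeed engine.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)
```

Note that the scheduler returned here is whatever `deepspeed.initialize` hands back, which may be a DeepSpeed-managed wrapper rather than the original `CosineAnnealingLR` instance.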
7 changes: 4 additions & 3 deletions src/lightning/fabric/strategies/fsdp.py
@@ -39,6 +39,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from typing_extensions import TypeGuard, override

from lightning.fabric.accelerators import Accelerator
@@ -267,8 +268,8 @@ def setup_environment(self) -> None:

@override
def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
"""Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -280,7 +281,7 @@
" call `setup_optimizer`."
)
module = self.setup_module(module)
-        return module, optimizers
+        return module, optimizers, scheduler

@override
def setup_module(self, module: Module) -> Module:
7 changes: 4 additions & 3 deletions src/lightning/fabric/strategies/strategy.py
@@ -20,6 +20,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from lightning.fabric.accelerators import Accelerator
@@ -144,8 +145,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
return stack

def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
"""Set up a model and multiple optimizers together.

The returned objects are expected to be in the same order they were passed in. The default implementation will
@@ -154,7 +155,7 @@ def setup_module_and_optimizers(
"""
module = self.setup_module(module)
optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
-        return module, optimizers
+        return module, optimizers, scheduler

def setup_module(self, module: Module) -> Module:
"""Performs setup for the model, e.g., by wrapping it by another class."""
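
For the base class, the default added above is a pure pass-through: the module and each optimizer are wrapped individually, and the scheduler comes back unchanged. A small standalone paraphrase (not the library code) of that contract:

```python
from typing import List, Optional, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def default_setup_module_and_optimizers(
    module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
    # In the real base class, `setup_module` and `setup_optimizer` are called here;
    # the scheduler is simply returned as-is.
    return module, optimizers, scheduler


model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
_, _, returned = default_setup_module_and_optimizers(model, [optimizer], scheduler)
assert returned is scheduler  # pass-through: the exact same scheduler object comes back
```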
5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -21,6 +21,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from typing_extensions import override

@@ -196,8 +197,8 @@ def setup_environment(self) -> None:

@override
def setup_module_and_optimizers(
-        self, module: Module, optimizers: List[Optimizer]
-    ) -> Tuple[Module, List[Optimizer]]:
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple[Module, List[Optimizer], Optional[_LRScheduler]]:
"""Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
raise NotImplementedError(
f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."