Add FSDP strategy (#1553)
abhi-mosaic authored Sep 29, 2022
1 parent 7db8a99 commit aa19e24
Showing 10 changed files with 1,189 additions and 170 deletions.
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
      entry: pydocstyle
      language: python
      types: [python]
-     exclude: '(?:tests|.ci|composer\/algorithms|composer\/datasets|composer\/models)\/.*'
+     exclude: '(?:tests|.ci|composer\/algorithms|composer\/datasets|composer\/models)\/.*|composer\/trainer\/activation_checkpointing.py'
      additional_dependencies:
      - "toml"
    rev: 6.1.1
@@ -78,6 +78,8 @@ repos:
      - --comment-style
      - "#"
      types: [python]
+     exclude: 'composer\/trainer\/activation_checkpointing.py'
+ 
  - repo: https://github.com/kynan/nbstripout
    rev: 0.5.0
    hooks:
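Both excludes carve the vendored composer/trainer/activation_checkpointing.py out of the pydocstyle and license-header hooks. As a quick, illustrative check of how the updated pattern behaves (not part of the commit; pre-commit evaluates exclude as a Python regex against each file path, roughly via re.search):

    import re

    # Updated pydocstyle exclude pattern from the first hunk above.
    pattern = (r'(?:tests|.ci|composer\/algorithms|composer\/datasets|composer\/models)\/.*'
               r'|composer\/trainer\/activation_checkpointing.py')

    # The vendored file is now skipped by the hook...
    assert re.search(pattern, 'composer/trainer/activation_checkpointing.py')
    # ...while other trainer files are still checked.
    assert not re.search(pattern, 'composer/trainer/trainer.py')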
87 changes: 78 additions & 9 deletions composer/core/state.py
@@ -5,12 +5,15 @@
  from __future__ import annotations
  
  import collections.abc
+ import contextlib
  import logging
  import warnings
+ from contextlib import contextmanager
  from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Sequence, Union, cast
  
  import torch
  import torch.nn.modules.utils
+ from packaging import version
  from torch.nn.parallel import DistributedDataParallel
  from torch.optim import Optimizer
  from torchmetrics import Metric
@@ -34,6 +37,32 @@
  logger = logging.getLogger(__name__)
  
  
+ @contextmanager
+ def get_fsdp_rank0_cpu_save_context(obj: torch.nn.Module):
+     if version.parse(torch.__version__) < version.parse('1.12.0'):
+         raise RuntimeError('To use FSDP with Composer, you must use torch>=1.12.0.')
+     from torch.distributed.fsdp import FullStateDictConfig
+     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+     from torch.distributed.fsdp import StateDictType
+     full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+     with FSDP.state_dict_type(obj, StateDictType.FULL_STATE_DICT, full_state_dict_config):
+         yield
+ 
+ 
+ def get_fsdp_sharded_optim_state_dict(full_optim_state_dict: Dict[str, Any], model: torch.nn.Module):
+     if version.parse(torch.__version__) < version.parse('1.12.0'):
+         raise RuntimeError('To use FSDP with Composer, you must use torch>=1.12.0.')
+     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+     return FSDP.scatter_full_optim_state_dict(full_optim_state_dict=full_optim_state_dict, model=model)
+ 
+ 
+ def get_fsdp_full_optim_state_dict(model: torch.nn.Module, optim: torch.optim.Optimizer, rank0_only: bool = True):
+     if version.parse(torch.__version__) < version.parse('1.12.0'):
+         raise RuntimeError('To use FSDP with Composer, you must use torch>=1.12.0.')
+     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+     return FSDP.full_optim_state_dict(model=model, optim=optim, rank0_only=rank0_only)
+ 
+ 
  def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
      # v0.4.1 removed the leading underscores for the keys in the state_dict
      # It also renamed _is_model_ddp_wrapped to is_model_ddp
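For orientation: get_fsdp_rank0_cpu_save_context wraps PyTorch's FSDP.state_dict_type so that any state_dict() call inside the block returns full, unflattened parameters, offloaded to CPU and populated only on rank 0. A minimal usage sketch (not from the commit; fsdp_model is a placeholder for an FSDP-wrapped module, and it assumes torch>=1.12 with an initialized process group):

    # Sketch only: assumes torch.distributed is initialized and `fsdp_model`
    # is (or contains) an FSDP-wrapped module.
    with get_fsdp_rank0_cpu_save_context(fsdp_model):
        model_state = fsdp_model.state_dict()  # full params on CPU; populated on rank 0

    if torch.distributed.get_rank() == 0:
        torch.save(model_state, 'model.pt')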
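The two optimizer helpers are inverses of each other: get_fsdp_full_optim_state_dict gathers the sharded optimizer state into a single rank-0 dict keyed by unflattened parameter names, and get_fsdp_sharded_optim_state_dict re-shards such a dict so each rank can load its slice. A round-trip sketch under the same assumptions (optimizer is a placeholder name):

    # Save side: gather the sharded optimizer state onto rank 0.
    full_osd = get_fsdp_full_optim_state_dict(model=fsdp_model, optim=optimizer, rank0_only=True)

    # Load side: scatter the full dict back into each rank's shard, then load.
    sharded_osd = get_fsdp_sharded_optim_state_dict(full_optim_state_dict=full_osd, model=fsdp_model)
    optimizer.load_state_dict(sharded_osd)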
@@ -456,6 +485,17 @@ def deepspeed_enabled(self):
"""Indicates if deepspeed is enabled."""
return self.deepspeed_config is not None

@property
def fsdp_enabled(self):
"""Indicates if FSDP is enabled."""
if version.parse(torch.__version__) < version.parse('1.12.0'):
return False
from torch.distributed.fsdp import FullyShardedDataParallel
for module in self.model.modules():
if isinstance(module, FullyShardedDataParallel):
return True
return False

def state_dict(self) -> Dict[str, Any]:
state_dict = {}

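Note that the property walks every submodule rather than checking only the root, so wrapping any part of the model with FSDP flips it to True. An illustrative equivalent of the traversal (the two-layer model is made up, and constructing FSDP modules requires an initialized process group):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
    model[1] = FSDP(model[1])  # wrap just one submodule

    # Same check the property performs:
    assert any(isinstance(m, FSDP) for m in model.modules())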
@@ -464,17 +504,29 @@ def state_dict(self) -> Dict[str, Any]:
              if attribute_name == 'model':
                  # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel
                  # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
-                 model_state = attribute_value.state_dict()
+                 with get_fsdp_rank0_cpu_save_context(
+                         attribute_value) if self.fsdp_enabled else contextlib.nullcontext():
+                     model_state = attribute_value.state_dict()
  
                  if self.is_model_ddp:
                      torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state, 'module.')
                  serialized_value = model_state
-             else:
-                 if attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
+             elif attribute_name == 'optimizers':
+                 if self.fsdp_enabled:
+                     serialized_value = {}
+                     for obj in ensure_tuple(attribute_value):
+                         serialized_value[type(obj).__qualname__] = get_fsdp_full_optim_state_dict(
+                             model=self.model, optim=obj, rank0_only=True)
+                 else:
+                     serialized_value = {
+                         type(obj).__qualname__: obj.state_dict() for obj in ensure_tuple(attribute_value)
+                     }
-             else:
-                 serialized_value = attribute_value
+             elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
+                 serialized_value = {type(obj).__qualname__: obj.state_dict() for obj in ensure_tuple(attribute_value)}
+             else:
+                 serialized_value = attribute_value
  
              state_dict[attribute_name] = serialized_value

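Under FSDP, the 'optimizers' entry thus holds a full (rank-0) optimizer state dict keyed by the optimizer's class name, mirroring the shape of the non-FSDP branch. A sketch of the resulting layout for a run with a single DecoupledAdamW optimizer (trainer is a hypothetical Composer Trainer):

    state = trainer.state.state_dict()

    # state['model']      -> full, unflattened model weights (gathered to CPU on rank 0)
    # state['optimizers'] -> {'DecoupledAdamW': <full optimizer state dict>}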
@@ -492,12 +544,28 @@ def load_model_state(self, state_dict: Dict[str, Any], strict: bool):
          # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
          # with the `module.` prefix
          torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')
-         missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
+ 
+         with get_fsdp_rank0_cpu_save_context(self.model) if self.fsdp_enabled else contextlib.nullcontext():
+             missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
          if len(missing_keys) > 0:
              logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
          if len(unexpected_keys) > 0:
              logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
  
+     def load_optim_state(self, state_dict: Dict[str, Any]):
+         serialized_value = state_dict['optimizers']
+         for target in ensure_tuple(self.optimizers):
+             if type(target).__qualname__ not in serialized_value:
+                 warnings.warn(f'{type(target).__qualname__} is not in the state_dict. Its state will not be restored.',
+                               category=UserWarning)
+                 continue
+             source = serialized_value[type(target).__qualname__]
+             if self.fsdp_enabled:
+                 sharded_osd = get_fsdp_sharded_optim_state_dict(full_optim_state_dict=source, model=self.model)
+                 target.load_state_dict(sharded_osd)
+             else:
+                 target.load_state_dict(source)
+ 
      def _verify_required_algorithms_enabled(self, state: Dict[str, Any]):
          """Verifies all required algorithms are enabled when loading state.
@@ -553,9 +621,10 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False):

              if attribute_name == 'model':
                  self.load_model_state(state, strict=strict)
-                 continue
-             state_field_value = getattr(self, attribute_name)
-             if attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
+             elif attribute_name == 'optimizers':
+                 self.load_optim_state(state)
+             elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
+                 state_field_value = getattr(self, attribute_name)
                  for target in ensure_tuple(state_field_value):
                      if type(target).__qualname__ not in serialized_value:
                          warnings.warn(
11 changes: 6 additions & 5 deletions composer/optim/__init__.py
@@ -14,17 +14,18 @@
"""

from composer.optim.decoupled_weight_decay import DecoupledAdamW, DecoupledSGDW
from composer.optim.scheduler import (ComposerScheduler, ConstantScheduler, CosineAnnealingScheduler,
CosineAnnealingWarmRestartsScheduler, CosineAnnealingWithWarmupScheduler,
ExponentialScheduler, LinearScheduler, LinearWithWarmupScheduler,
MultiStepScheduler, MultiStepWithWarmupScheduler, PolynomialScheduler,
PolynomialWithWarmupScheduler, StepScheduler)
from composer.optim.scheduler import (ComposerScheduler, ConstantScheduler, ConstantWithWarmupScheduler,
CosineAnnealingScheduler, CosineAnnealingWarmRestartsScheduler,
CosineAnnealingWithWarmupScheduler, ExponentialScheduler, LinearScheduler,
LinearWithWarmupScheduler, MultiStepScheduler, MultiStepWithWarmupScheduler,
PolynomialScheduler, PolynomialWithWarmupScheduler, StepScheduler)

__all__ = [
'DecoupledAdamW',
'DecoupledSGDW',
'ComposerScheduler',
'ConstantScheduler',
'ConstantWithWarmupScheduler',
'CosineAnnealingScheduler',
'CosineAnnealingWarmRestartsScheduler',
'CosineAnnealingWithWarmupScheduler',
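The import reflow also exports the new ConstantWithWarmupScheduler. A small usage sketch (illustrative; '1ep' is Composer's time-string for one epoch, and the signature is assumed from the other *WithWarmup schedulers):

    from composer.optim import ConstantWithWarmupScheduler

    # Hold the learning rate constant after a one-epoch warmup.
    scheduler = ConstantWithWarmupScheduler(t_warmup='1ep')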
(Diffs for the remaining seven changed files are not shown.)
