
Fix fsdp weight tying #1856

Merged: 10 commits, Jan 13, 2023
Empty file added composer/scratch
Empty file.
60 changes: 58 additions & 2 deletions composer/trainer/dist_strategy.py
@@ -3,7 +3,9 @@

"""Helpers for running distributed data parallel training."""

import collections
import logging
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, ContextManager, Dict, Optional, Sequence, Union, cast

@@ -14,6 +16,7 @@

from composer.core import Precision
from composer.core.state import State
from composer.trainer.meta_safe_apply import meta_safe_apply
from composer.utils import StringEnum, dist, ensure_tuple

__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module']
@@ -247,9 +250,62 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
for name, obj in model.named_children():
if not isinstance(obj, (Metric, MetricCollection)):

# If `obj` contains meta tensors, try to use `obj.param_init_fn` to initialize them
def _param_init_fn(module: torch.nn.Module) -> None:
module.to_empty(device=f'cuda:{torch.cuda.current_device()}')
# A dictionary of all tied parameter pointers to module names
tied_pointers = {}

# Goes through all modules finding which weights have the same pointers
for name, mod in module.named_modules():
for attr in ['weight', 'bias']:
if hasattr(mod, attr):
ptr = id(getattr(mod, attr))
ptr_attr = (ptr, attr)
name_list = tied_pointers.get(ptr_attr, [])
name_list.append(name)
tied_pointers[ptr_attr] = name_list

# Creates a dictionary of module names that should be tied together
tied_mod_names = collections.defaultdict(list)
# Creates a set of modules we should not initialize
should_not_init_params = set()
for ptr_attr_type, mod_names in tied_pointers.items():
_, attr_type = ptr_attr_type
first = next(mod_names.__iter__())
for elem in mod_names:
should_not_init_params.add('.'.join([elem, attr_type]))
tied_mod_names[(first, attr_type)].append(elem)
# Make sure at least one of the tied parameters is initialized
should_not_init_params.remove('.'.join([first, attr_type]))

meta_safe_apply(module,
lambda t: torch.empty_like(t, device=f'cuda:{torch.cuda.current_device()}'),
should_not_init_params,
module_name='')

if len(tied_mod_names) > 0:
warnings.warn(('The passed in model appears to have tied weights. In order to '
'support effective weight tying, the tied modules need to be '
'in the same FSDP module. If the weights are not properly tied '
'it can lead to loss spikes. We have tried our best to ensure '
'the tied weights are in the same FSDP module.'))

# Redoes weight tying
for name_attr, tied_names in tied_mod_names.items():
name, attr = name_attr
src_mod = module.get_submodule(name)
# We need to make sure the source and destination
# modules end up in the same FSDP module; otherwise,
# sharding would break the weight tying
src_mod._fsdp_wrap = False # type: ignore
src_params = getattr(src_mod, attr)
for tied_name in tied_names:
dest_mod = module.get_submodule(tied_name)
dest_mod._fsdp_wrap = False # type: ignore
if attr == 'weight':
dest_mod.weight = src_params
elif attr == 'bias':
dest_mod.bias = src_params

if hasattr(obj, 'param_init_fn') and isinstance(obj.param_init_fn, Callable):
module.apply(obj.param_init_fn)
elif hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
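To make the tied-weight detection above easier to follow, here is a small standalone sketch (a hypothetical two-layer model, outside Composer, and slightly simplified in that it skips attributes that are None) showing how grouping parameters by id() reveals tied weights:

import collections

import torch

# Hypothetical toy model whose two linear layers share one weight Parameter.
class TinyTied(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4, bias=False)
        self.fc2 = torch.nn.Linear(4, 4, bias=False)
        self.fc2.weight = self.fc1.weight  # tie the weights

model = TinyTied()

# Group module names by the identity of their 'weight'/'bias' attributes,
# mirroring the detection loop in prepare_fsdp_module above.
tied_pointers = collections.defaultdict(list)
for name, mod in model.named_modules():
    for attr in ('weight', 'bias'):
        if getattr(mod, attr, None) is not None:
            tied_pointers[(id(getattr(mod, attr)), attr)].append(name)

# Any group with more than one module name is a set of tied parameters; all but
# the first member are excluded from re-initialization and re-tied afterwards.
tied_groups = [names for names in tied_pointers.values() if len(names) > 1]
print(tied_groups)  # [['fc1', 'fc2']]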
96 changes: 96 additions & 0 deletions composer/trainer/meta_safe_apply.py
@@ -0,0 +1,96 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

# Source code is adapted from a modified version of:
# https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py
# Link to PyTorch License File: https://github.com/pytorch/pytorch/blob/master/LICENSE
# TODO: This code will need to be removed when PyTorch correctly supports delayed initialization
# with meta tensors.

"""Helper function to safely call .apply for initializing meta tensors in PyTorch."""

from typing import Set

import torch
from torch.nn.parameter import Parameter


def meta_safe_apply(self, fn, ignored_modules: Set, module_name: str):
"""Applies the function recursively to a module's children and the module itself.

This variant allows us to skip applying the function to selected parameters.
The function is a slightly modified version of the one from PyTorch:
https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py#L637

Args:
self: the module to apply fn to.
fn: the function to apply to each submodule.
ignored_modules: a set of fully-qualified parameter names that fn should skip.
module_name: the current module's name.
"""
for name, module in self.named_children():
module_name_list = [module_name, name]
if module_name == '':
module_name_list = [name]
curr_module_name = concatenate_strings(module_name_list)
meta_safe_apply(module, fn, ignored_modules, curr_module_name)

def compute_should_use_set_data(tensor, tensor_applied):
if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# If the new tensor has compatible tensor type as the existing tensor,
# the current behavior is to change the tensor in-place using `.data =`,
# and the future behavior is to overwrite the existing tensor. However,
# changing the current behavior is a BC-breaking change, and we want it
# to happen in future releases. So for now we introduce the
# `torch.__future__.get_overwrite_module_params_on_conversion()`
# global flag to let the user control whether they want the future
# behavior of overwriting the existing tensor or not.
return not torch.__future__.get_overwrite_module_params_on_conversion()
else:
return False

for key, param in self._parameters.items():
curr_name = concatenate_strings([module_name, key])
if param is None or curr_name in ignored_modules:
continue
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with torch.no_grad():
param_applied = fn(param)
should_use_set_data = compute_should_use_set_data(param, param_applied)
if should_use_set_data:
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param

if param.grad is not None:
with torch.no_grad():
grad_applied = fn(param.grad)
should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
if should_use_set_data:
assert out_param.grad is not None
out_param.grad.data = grad_applied
else:
assert param.grad.is_leaf
out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)

for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)

return self


def concatenate_strings(str_list, delim='.'):
"""Concatenates a list of strings, separated by a delimiter.

Args:
str_list: a list of strings to join.
delim: the delimiter to place between the strings (default: '.').
"""
return delim.join(str_list)
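
As a usage illustration, here is a minimal sketch (assumptions: a toy two-layer meta-device module with tied weights and CPU as the target device) of how this helper could be called to materialize parameters while skipping the duplicate copy of a tied weight:

import torch

from composer.trainer.meta_safe_apply import meta_safe_apply

# Hypothetical meta-device model whose two linear layers share one weight.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8, bias=False, device='meta'),
    torch.nn.Linear(8, 8, bias=False, device='meta'),
)
model[1].weight = model[0].weight  # tie the weights

# Materialize every parameter except '1.weight', the duplicate copy of the tied
# parameter; the caller is expected to re-tie it afterwards, as
# prepare_fsdp_module does above.
meta_safe_apply(model,
                lambda t: torch.empty_like(t, device='cpu'),
                ignored_modules={'1.weight'},
                module_name='')

print(model[0].weight.device)  # cpu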
4 changes: 3 additions & 1 deletion tests/common/__init__.py
@@ -8,7 +8,7 @@
from tests.common.datasets import RandomClassificationDataset, RandomImageDataset, RandomSegmentationDataset
from tests.common.events import EventCounterCallback
from tests.common.markers import device, world_size
from tests.common.models import ConvModel, SimpleConvModel, SimpleModel
from tests.common.models import ConvModel, EmbeddedWeightTiedModel, SimpleConvModel, SimpleModel, SimpleWeightTiedModel
from tests.common.state import assert_state_equivalent


@@ -25,6 +25,8 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]:
'ConvModel',
'SimpleConvModel',
'SimpleModel',
'EmbeddedWeightTiedModel',
'SimpleWeightTiedModel',
'EventCounterCallback',
'deep_compare',
'device',
2 changes: 1 addition & 1 deletion tests/common/datasets.py
@@ -14,7 +14,7 @@ class RandomClassificationDataset(Dataset):
"""Classification dataset drawn from a normal distribution.

Args:
shape (Sequence[int]): shape of features (default: (5, 1, 1))
shape (Sequence[int]): shape of features (default: (1, 1, 1))
size (int): number of samples (default: 100)
num_classes (int): number of classes (default: 2)
"""
69 changes: 69 additions & 0 deletions tests/common/models.py
@@ -47,6 +47,75 @@ def __init__(self, num_features: int = 1, num_classes: int = 2) -> None:
self.fc2 = fc2


class SimpleMLP(torch.nn.Module):

def __init__(self, num_features: int, device: str):
super().__init__()
self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False)

self.net = torch.nn.Sequential(self.fc1, torch.nn.ReLU(), self.fc2)

def forward(self, x):
return self.net(x)


class SimpleWeightTiedModel(ComposerClassifier):
"""Small classification model with tied weights.
Typically this model will be used to test weight tying w/ FSDP

Args:
num_features (int): number of input features (default: 1)
device (str): the device to initialize the model (default: 'cpu')
"""

def __init__(self, num_features: int = 1, device: str = 'cpu') -> None:
self.num_features = num_features

mlp = SimpleMLP(num_features, device)

net = torch.nn.Sequential(
mlp,
torch.nn.Softmax(dim=-1),
)

super().__init__(module=net)

self.mlp = mlp
self.net = net

self.mlp.fc1.weight = self.mlp.fc2.weight


class EmbeddedWeightTiedModel(ComposerClassifier):
"""A small classification model that consists of two simple MLPs,
and we tie weights across the simple MLPs.
Typically this model will be used to test weight tying w/ FSDP.

Args:
num_features (int): number of input features (default: 1)
device (str): the device to initialize the model (default: 'cpu')
"""

def __init__(self, num_features: int = 1, device: str = 'cpu') -> None:
net1 = SimpleMLP(num_features, device)
net2 = SimpleMLP(num_features, device)

net = torch.nn.Sequential(
net1,
net2,
torch.nn.Softmax(dim=-1),
)

super().__init__(module=net)

self.net1 = net1
self.net2 = net2

self.net1.fc1.weight = self.net2.fc1.weight


class SimpleConvModel(ComposerClassifier):
"""Small convolutional classifier.

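For reference, a quick sanity check (assuming the classes above are importable from tests.common, as the __init__.py change earlier in this diff suggests) that construction ties the parameters as the same object rather than merely equal values:

from tests.common import EmbeddedWeightTiedModel, SimpleWeightTiedModel

simple = SimpleWeightTiedModel(num_features=4, device='cpu')
assert simple.mlp.fc1.weight is simple.mlp.fc2.weight  # same Parameter object

embedded = EmbeddedWeightTiedModel(num_features=4, device='cpu')
assert embedded.net1.fc1.weight is embedded.net2.fc1.weight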
46 changes: 46 additions & 0 deletions tests/trainer/test_fsdp.py
@@ -0,0 +1,46 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from packaging import version
from torch.utils.data import DataLoader

from composer.models import ComposerClassifier
from composer.trainer.trainer import Trainer
from composer.utils import dist
from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleWeightTiedModel


@pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel])
@pytest.mark.parametrize('device', ['cpu', 'meta'])
@pytest.mark.filterwarnings('ignore::UserWarning')
@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='requires PyTorch 1.13 or higher')
def test_fsdp_device_initialization(model: ComposerClassifier, device: str):
"""test FSDP device initialization for a simple model with weight tying and a model where two modules
from separate submodules have weight tying applied. This test also covers both 'cpu' and
'meta' devices. This is because 'meta' will result in deferred initialization until FSDP is initialized

"""
num_classes = 10
model = model(num_features=num_classes, device=device)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
fsdp_config={},
max_duration='3ba',
)

trainer.fit()
if isinstance(model, SimpleWeightTiedModel):
assert (torch.equal(model.mlp.fc1.weight, model.mlp.fc2.weight))

if isinstance(model, EmbeddedWeightTiedModel):
assert (torch.equal(model.net1.fc1.weight, model.net2.fc1.weight))