From 6301da889dcf6a888c6dabbd1ce4b1488bd8dd12 Mon Sep 17 00:00:00 2001 From: Brandon Date: Thu, 15 Dec 2022 18:10:26 +0000 Subject: [PATCH 1/9] Redoing weight tying with FSDP --- composer/trainer/dist_strategy.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index e683335de8..35120869ec 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -3,6 +3,7 @@ """Helpers for running distributed data parallel training.""" +import collections import logging from contextlib import contextmanager, nullcontext from typing import Any, Callable, ContextManager, Dict, Optional, Sequence, Union, cast @@ -247,9 +248,34 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch for name, obj in model.named_children(): if not isinstance(obj, (Metric, MetricCollection)): - # If `obj` contains meta tensors, try to use `obj.param_init_fn` to initialize them def _param_init_fn(module: torch.nn.Module) -> None: + # A dictionary of all tied parameters + tied_pointers = {} + + # Goes through all modules finding which weights have the same pointers + for n, p in module.named_modules(): + if hasattr(p, 'weight'): + ptr = id(p.weight) + tied_pointers[ptr] = tied_pointers.get(ptr, set()) | set([n]) + + # Creates a dictionary of param names should be tied together + tied_params = collections.defaultdict(lambda: []) + for _, s in tied_pointers.items(): + if len(s) == 1: + continue + first = next(s.__iter__()) + for elem in s: + tied_params[first].append(elem) + module.to_empty(device=f'cuda:{torch.cuda.current_device()}') + + # Redoes weight tying + for n, tied_names in tied_params.items(): + params = module.get_submodule(n).weight + for tied_name in tied_names: + dest_module = module.get_submodule(tied_name) + dest_module.weight = params + if hasattr(obj, 'param_init_fn') and isinstance(obj.param_init_fn, Callable): module.apply(obj.param_init_fn) elif hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable): From 4cf67ed7b69ccc92cd6d9fead6afcbf6700275f7 Mon Sep 17 00:00:00 2001 From: brandon Date: Mon, 19 Dec 2022 19:38:51 +0000 Subject: [PATCH 2/9] Adding in custom safe apply for modules --- composer/trainer/dist_strategy.py | 55 +++++++++++++----- composer/trainer/meta_safe_apply.py | 89 +++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 15 deletions(-) create mode 100644 composer/trainer/meta_safe_apply.py diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 35120869ec..a2b4c83155 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -15,6 +15,7 @@ from composer.core import Precision from composer.core.state import State +from composer.trainer.meta_safe_apply import meta_safe_apply from composer.utils import StringEnum, dist, ensure_tuple __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module'] @@ -249,32 +250,56 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch if not isinstance(obj, (Metric, MetricCollection)): def _param_init_fn(module: torch.nn.Module) -> None: - # A dictionary of all tied parameters + # A dictionary of all tied parameter pointers to module names tied_pointers = {} # Goes through all modules finding which weights have the same pointers - for n, p in module.named_modules(): - if hasattr(p, 'weight'): - ptr = id(p.weight) - 
tied_pointers[ptr] = tied_pointers.get(ptr, set()) | set([n]) - - # Creates a dictionary of param names should be tied together - tied_params = collections.defaultdict(lambda: []) - for _, s in tied_pointers.items(): + for name, mod in module.named_modules(): + for attr in ['weight', 'bias']: + if hasattr(mod, attr): + ptr = id(getattr(mod, attr)) + ptr_attr = (ptr, attr) + tied_pointers[ptr_attr] = tied_pointers.get(ptr_attr, set()) | set([name]) + + # Creates a dictionary of module names should be tied together + tied_mod_names = collections.defaultdict(lambda: []) + # Creates a set of modules we should initialize + should_init_params = set() + should_not_init_params = set() + for ptr_attr_type, s in tied_pointers.items(): + _, attr_type = ptr_attr_type if len(s) == 1: + should_init_params.add(next(s.__iter__())) continue first = next(s.__iter__()) + should_init_params.add(first) for elem in s: - tied_params[first].append(elem) + should_not_init_params.add('.'.join([elem, attr_type])) + tied_mod_names[(first, attr_type)].append(elem) + # Make sure at least one of the tied parameters is initialized + should_not_init_params.remove('.'.join([first, attr_type])) - module.to_empty(device=f'cuda:{torch.cuda.current_device()}') + meta_safe_apply(module, + lambda t: torch.empty_like(t, device=f'cuda:{torch.cuda.current_device()}'), + should_not_init_params, + module_name='') # Redoes weight tying - for n, tied_names in tied_params.items(): - params = module.get_submodule(n).weight + for name_attr, tied_names in tied_mod_names.items(): + name, attr = name_attr + src_mod = module.get_submodule(name) + # We need to make sure the source and destination + # modules end up in the same FSDP block otherwise + # with sharding weight tying gets violated + src_mod._fsdp_wrap = False # type: ignore + src_params = getattr(src_mod, attr) for tied_name in tied_names: - dest_module = module.get_submodule(tied_name) - dest_module.weight = params + dest_mod = module.get_submodule(tied_name) + dest_mod._fsdp_wrap = False # type: ignore + if attr == 'weight': + dest_mod.weight = src_params + elif attr == 'bias': + dest_mod.bias = src_params if hasattr(obj, 'param_init_fn') and isinstance(obj.param_init_fn, Callable): module.apply(obj.param_init_fn) diff --git a/composer/trainer/meta_safe_apply.py b/composer/trainer/meta_safe_apply.py new file mode 100644 index 0000000000..67af4978dc --- /dev/null +++ b/composer/trainer/meta_safe_apply.py @@ -0,0 +1,89 @@ +# Source code is compiled from a modified version of: +# https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py +# This code will need to be removed when PyTorch correctly supports delayed initialization +# with meta tensors. + +from typing import Set + +import torch +from torch.nn.parameter import Parameter + + +def meta_safe_apply(self, fn, ignored_modules: Set, module_name: str): + """Applies the function recursively to a module's children and the module itself. + This variant allows us to ignore modules to apply the function. + The function is a slightly modified version of the one from PyTorch: + https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py#L637 + + Args: + self: the module to apply fn to. + fn: the function called to each submodule + ignored_modules: a set of names of modules to not apply fn. + module_name: the current module's name. 
+ """ + for name, module in self.named_children(): + module_name_list = [module_name, name] + if module_name == '': + module_name_list = [name] + curr_module_name = concatenate_strings(module_name_list) + meta_safe_apply(module, fn, ignored_modules, curr_module_name) + + def compute_should_use_set_data(tensor, tensor_applied): + if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): + # If the new tensor has compatible tensor type as the existing tensor, + # the current behavior is to change the tensor in-place using `.data =`, + # and the future behavior is to overwrite the existing tensor. However, + # changing the current behavior is a BC-breaking change, and we want it + # to happen in future releases. So for now we introduce the + # `torch.__future__.get_overwrite_module_params_on_conversion()` + # global flag to let the user control whether they want the future + # behavior of overwriting the existing tensor or not. + return not torch.__future__.get_overwrite_module_params_on_conversion() + else: + return False + + for key, param in self._parameters.items(): + curr_name = concatenate_strings([module_name, key]) + if param is None or curr_name in ignored_modules: + continue + # Tensors stored in modules are graph leaves, and we don't want to + # track autograd history of `param_applied`, so we have to use + # `with torch.no_grad():` + with torch.no_grad(): + param_applied = fn(param) + should_use_set_data = compute_should_use_set_data(param, param_applied) + if should_use_set_data: + param.data = param_applied + out_param = param + else: + assert isinstance(param, Parameter) + assert param.is_leaf + out_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = out_param + + if param.grad is not None: + with torch.no_grad(): + grad_applied = fn(param.grad) + should_use_set_data = compute_should_use_set_data(param.grad, grad_applied) + if should_use_set_data: + assert out_param.grad is not None + out_param.grad.data = grad_applied + else: + assert param.grad.is_leaf + out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + +def concatenate_strings(str_list, delim='.'): + """Concatenates a list of strings together with a delimiter in between the strings in the list. + + Args: + str_list: a list of string to join. + delim: the delimiter to separate all strings + """ + return delim.join(str_list) From 755d9e983a249e65850cbe086592ade1733840b9 Mon Sep 17 00:00:00 2001 From: brandon Date: Mon, 19 Dec 2022 19:53:07 +0000 Subject: [PATCH 3/9] Adding in a warning on FSDP modules with weight tying --- composer/trainer/dist_strategy.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index a2b4c83155..2453c81394 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -5,6 +5,7 @@ import collections import logging +import warnings from contextlib import contextmanager, nullcontext from typing import Any, Callable, ContextManager, Dict, Optional, Sequence, Union, cast @@ -284,12 +285,19 @@ def _param_init_fn(module: torch.nn.Module) -> None: should_not_init_params, module_name='') + if len(tied_mod_names) > 0: + warnings.warn(('The passed in model appears to have tied weights. In order to ' + 'support effective weight tying, the tied modules need to be ' + 'in the same FSDP module. 
If the weights are not properly tied ' + 'it can lead to loss spikes. We have tried our best to ensure ' + 'the tied weights are in the same FSDP module.')) + # Redoes weight tying for name_attr, tied_names in tied_mod_names.items(): name, attr = name_attr src_mod = module.get_submodule(name) # We need to make sure the source and destination - # modules end up in the same FSDP block otherwise + # modules end up in the same FSDP module otherwise # with sharding weight tying gets violated src_mod._fsdp_wrap = False # type: ignore src_params = getattr(src_mod, attr) From e8bb34f78e3ce092b42fb98e7c65922abfb7664e Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 5 Jan 2023 18:10:41 +0000 Subject: [PATCH 4/9] adding tests for fsdp weight tying and initialization --- tests/common/__init__.py | 4 ++- tests/common/datasets.py | 2 +- tests/common/models.py | 68 ++++++++++++++++++++++++++++++++++++++ tests/trainer/test_fsdp.py | 42 +++++++++++++++++++++++ 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/trainer/test_fsdp.py diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 93b05149a9..e5a92f6a9b 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -8,7 +8,7 @@ from tests.common.datasets import RandomClassificationDataset, RandomImageDataset, RandomSegmentationDataset from tests.common.events import EventCounterCallback from tests.common.markers import device, world_size -from tests.common.models import ConvModel, SimpleConvModel, SimpleModel +from tests.common.models import ConvModel, EmbeddedWeightTiedModel, SimpleConvModel, SimpleModel, SimpleWeightTiedModel from tests.common.state import assert_state_equivalent @@ -25,6 +25,8 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]: 'ConvModel', 'SimpleConvModel', 'SimpleModel', + 'EmbeddedWeightTiedModel', + 'SimpleWeightTiedModel', 'EventCounterCallback', 'deep_compare', 'device', diff --git a/tests/common/datasets.py b/tests/common/datasets.py index ca452185ea..a736892e97 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -14,7 +14,7 @@ class RandomClassificationDataset(Dataset): """Classification dataset drawn from a normal distribution. Args: - shape (Sequence[int]): shape of features (default: (5, 1, 1)) + shape (Sequence[int]): shape of features (default: (1, 1, 1)) size (int): number of samples (default: 100) num_classes (int): number of classes (default: 2) """ diff --git a/tests/common/models.py b/tests/common/models.py index 0d468a9cab..f6321f8993 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -47,6 +47,74 @@ def __init__(self, num_features: int = 1, num_classes: int = 2) -> None: self.fc2 = fc2 +class SimpleMLP(torch.nn.Module): + + def __init__(self, num_features: int, device: str): + super().__init__() + self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + + self.net = torch.nn.Sequential(self.fc1, torch.nn.ReLU(), self.fc2) + + def forward(self, x): + return self.net(x) + + +class SimpleWeightTiedModel(ComposerClassifier): + """Small classification model with tied weights. 
+ Typically this model will be used to test weight tying w/ FSDP + + Args: + num_features (int): number of input features (default: 1) + tie_weights (bool): whether or not to tie weights (default: True) + device (str): the device to initialize the model (default: 'cpu') + """ + + def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: + self.num_features = num_features + + mlp = SimpleMLP(num_features, device) + + net = torch.nn.Sequential( + mlp, + torch.nn.Softmax(dim=-1), + ) + + super().__init__(module=net) + + self.mlp = mlp + self.net = net + + self.mlp.fc1.weight = self.mlp.fc2.weight + + +class EmbeddedWeightTiedModel(ComposerClassifier): + """A small classification model that consists of two + Typically this model will be used to test weight tying w/ FSDP. + + Args: + num_features (int): number of input features (default: 1) + device (str): the device to initialize the model (default: 'cpu') + """ + + def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: + net1 = SimpleMLP(num_features, device) + net2 = SimpleMLP(num_features, device) + + net = torch.nn.Sequential( + net1, + net2, + torch.nn.Softmax(dim=-1), + ) + + super().__init__(module=net) + + self.net1 = net1 + self.net2 = net2 + + self.net1.fc1.weight = self.net2.fc1.weight + + class SimpleConvModel(ComposerClassifier): """Small convolutional classifier. diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py new file mode 100644 index 0000000000..af2c437ad5 --- /dev/null +++ b/tests/trainer/test_fsdp.py @@ -0,0 +1,42 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torch.utils.data import DataLoader + +from composer.models import ComposerClassifier +from composer.trainer.trainer import Trainer +from composer.utils import dist +from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleWeightTiedModel + + +@pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel]) +@pytest.mark.parametrize('device', ['cpu', 'meta']) +@pytest.mark.filterwarnings('ignore::UserWarning') +def test_fsdp_device_initialization(model: ComposerClassifier, device: str): + """test FSDP device initialization for a simple model with weight tying and a model where two modules + from separate submodules have weight tying applied. This test also covers both 'cpu' and + 'meta' devices. 
This is because 'meta' will result in deferred initialization until FSDP is initialized + + """ + num_classes = 10 + model = model(num_features=num_classes, device=device) + dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes) + dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + trainer = Trainer( + model=model, + optimizers=optimizer, + train_dataloader=dataloader, + fsdp_config={}, + max_duration='3ba', + ) + + trainer.fit() + if isinstance(model, SimpleWeightTiedModel): + assert (torch.equal(model.mlp.fc1.weight, model.mlp.fc2.weight)) + + if isinstance(model, EmbeddedWeightTiedModel): + assert (torch.equal(model.net1.fc1.weight, model.net2.fc1.weight)) From 6211de6a6aa318d4278348a896bd5b83db75980b Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 5 Jan 2023 18:23:46 +0000 Subject: [PATCH 5/9] Removing extra code --- composer/trainer/dist_strategy.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 2453c81394..418d62f523 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -264,16 +264,11 @@ def _param_init_fn(module: torch.nn.Module) -> None: # Creates a dictionary of module names should be tied together tied_mod_names = collections.defaultdict(lambda: []) - # Creates a set of modules we should initialize - should_init_params = set() + # Creates a set of modules we should not initialize should_not_init_params = set() for ptr_attr_type, s in tied_pointers.items(): _, attr_type = ptr_attr_type - if len(s) == 1: - should_init_params.add(next(s.__iter__())) - continue first = next(s.__iter__()) - should_init_params.add(first) for elem in s: should_not_init_params.add('.'.join([elem, attr_type])) tied_mod_names[(first, attr_type)].append(elem) From abbc7f3bfb9814198bed9e7947661e3d5c7e475d Mon Sep 17 00:00:00 2001 From: Brandon Date: Fri, 6 Jan 2023 20:54:13 +0000 Subject: [PATCH 6/9] Resolving comments, cleaning up a bit of code --- composer/scratch | 0 composer/trainer/dist_strategy.py | 14 ++++++++------ composer/trainer/meta_safe_apply.py | 9 ++++++++- tests/common/models.py | 3 ++- 4 files changed, 18 insertions(+), 8 deletions(-) create mode 100644 composer/scratch diff --git a/composer/scratch b/composer/scratch new file mode 100644 index 0000000000..e69de29bb2 diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 418d62f523..1d806700e4 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -260,16 +260,18 @@ def _param_init_fn(module: torch.nn.Module) -> None: if hasattr(mod, attr): ptr = id(getattr(mod, attr)) ptr_attr = (ptr, attr) - tied_pointers[ptr_attr] = tied_pointers.get(ptr_attr, set()) | set([name]) + name_list = tied_pointers.get(ptr_attr, []) + name_list.append(name) + tied_pointers[ptr_attr] = name_list - # Creates a dictionary of module names should be tied together - tied_mod_names = collections.defaultdict(lambda: []) + # Creates a dictionary of module names that should be tied together + tied_mod_names = collections.defaultdict(list) # Creates a set of modules we should not initialize should_not_init_params = set() - for ptr_attr_type, s in tied_pointers.items(): + for ptr_attr_type, mod_names in tied_pointers.items(): _, attr_type = ptr_attr_type - first = next(s.__iter__()) - for elem in s: + first = next(mod_names.__iter__()) + for elem in 
mod_names: should_not_init_params.add('.'.join([elem, attr_type])) tied_mod_names[(first, attr_type)].append(elem) # Make sure at least one of the tied parameters is initialized diff --git a/composer/trainer/meta_safe_apply.py b/composer/trainer/meta_safe_apply.py index 67af4978dc..f5a813a082 100644 --- a/composer/trainer/meta_safe_apply.py +++ b/composer/trainer/meta_safe_apply.py @@ -1,8 +1,14 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + # Source code is compiled from a modified version of: # https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py -# This code will need to be removed when PyTorch correctly supports delayed initialization +# Link to PyTorch License File: https://github.com/pytorch/pytorch/blob/master/LICENSE +# TODO: This code will need to be removed when PyTorch correctly supports delayed initialization # with meta tensors. +"""Helper function to safely call .apply for initializing meta tensors in PyTorch.""" + from typing import Set import torch @@ -11,6 +17,7 @@ def meta_safe_apply(self, fn, ignored_modules: Set, module_name: str): """Applies the function recursively to a module's children and the module itself. + This variant allows us to ignore modules to apply the function. The function is a slightly modified version of the one from PyTorch: https://github.com/pytorch/pytorch/blob/v1.13.0/torch/nn/modules/module.py#L637 diff --git a/tests/common/models.py b/tests/common/models.py index f6321f8993..008aad7db8 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -89,7 +89,8 @@ def __init__(self, num_features: int = 1, device: str = 'cpu') -> None: class EmbeddedWeightTiedModel(ComposerClassifier): - """A small classification model that consists of two + """A small classification model that consists of two simple MLPs, + and we tie weights across the simple MLPs. Typically this model will be used to test weight tying w/ FSDP. Args: From 11d1a81acf15a201245575fae72fd4111d955628 Mon Sep 17 00:00:00 2001 From: Brandon Date: Fri, 6 Jan 2023 22:33:50 +0000 Subject: [PATCH 7/9] Adding in qualifications for fsdp meta tensor tests --- tests/trainer/test_fsdp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index af2c437ad5..6c8ad66ceb 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -3,6 +3,7 @@ import pytest import torch +from packaging import version from torch.utils.data import DataLoader from composer.models import ComposerClassifier @@ -14,6 +15,9 @@ @pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel]) @pytest.mark.parametrize('device', ['cpu', 'meta']) @pytest.mark.filterwarnings('ignore::UserWarning') +@pytest.mark.gpu +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), + reason='requires PyTorch 1.13 or higher') def test_fsdp_device_initialization(model: ComposerClassifier, device: str): """test FSDP device initialization for a simple model with weight tying and a model where two modules from separate submodules have weight tying applied. 
This test also covers both 'cpu' and From c31ca4ff16eb61cc9d3aa2640ce0332ebbe50ba0 Mon Sep 17 00:00:00 2001 From: brandon Date: Fri, 13 Jan 2023 19:19:23 +0000 Subject: [PATCH 8/9] Removing extraneous file --- composer/scratch | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 composer/scratch diff --git a/composer/scratch b/composer/scratch deleted file mode 100644 index e69de29bb2..0000000000 From c24c1b573c7ba07e3ace601f319de7d049158daf Mon Sep 17 00:00:00 2001 From: brandon Date: Fri, 13 Jan 2023 19:28:56 +0000 Subject: [PATCH 9/9] Cleaning up python attributes --- composer/trainer/dist_strategy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 1d806700e4..12bf873656 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -269,6 +269,9 @@ def _param_init_fn(module: torch.nn.Module) -> None: # Creates a set of modules we should not initialize should_not_init_params = set() for ptr_attr_type, mod_names in tied_pointers.items(): + # No modules for this pointer are tied + if len(mod_names) == 1: + continue _, attr_type = ptr_attr_type first = next(mod_names.__iter__()) for elem in mod_names: @@ -301,10 +304,7 @@ def _param_init_fn(module: torch.nn.Module) -> None: for tied_name in tied_names: dest_mod = module.get_submodule(tied_name) dest_mod._fsdp_wrap = False # type: ignore - if attr == 'weight': - dest_mod.weight = src_params - elif attr == 'bias': - dest_mod.bias = src_params + setattr(dest_mod, attr, src_params) if hasattr(obj, 'param_init_fn') and isinstance(obj.param_init_fn, Callable): module.apply(obj.param_init_fn)
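
The patches above do the tied-weight handling inside FSDP's `_param_init_fn`: record which submodules share a parameter pointer, materialize the meta tensors while skipping the duplicate copies (via `meta_safe_apply`), then re-point the duplicates at the one materialized parameter and keep the tied modules in the same FSDP unit. Below is a minimal standalone sketch of that detect-and-retie idea in plain PyTorch (1.13 or newer, matching the test's version gate). The `TiedMLP` module and `retie_after_materialization` helper are illustrative names and are not part of Composer; the sketch deliberately uses a bare `to_empty` instead of `meta_safe_apply` and omits the `_fsdp_wrap = False` flagging that the real `_param_init_fn` also performs.

    import collections

    import torch


    class TiedMLP(torch.nn.Module):
        """Toy module whose fc1.weight and fc2.weight are tied (illustration only)."""

        def __init__(self, dim: int, device: str = 'meta'):
            super().__init__()
            self.fc1 = torch.nn.Linear(dim, dim, bias=False, device=device)
            self.fc2 = torch.nn.Linear(dim, dim, bias=False, device=device)
            self.fc2.weight = self.fc1.weight  # weight tying

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))


    def retie_after_materialization(module: torch.nn.Module, device: str = 'cpu') -> torch.nn.Module:
        # 1) Group submodule names by the id() of their 'weight'/'bias' parameters,
        #    the same way the patch detects ties: a shared pointer means a tied parameter.
        tied = collections.defaultdict(list)
        for name, mod in module.named_modules():
            for attr in ('weight', 'bias'):
                param = getattr(mod, attr, None)
                if isinstance(param, torch.nn.Parameter):
                    tied[(id(param), attr)].append(name)

        # 2) Materialize the meta parameters. This can re-create each module's
        #    parameters independently, which is the tie-breaking behavior the
        #    patch series works around.
        module.to_empty(device=device)

        # 3) Re-point every tied module at the first module's materialized parameter.
        for (_, attr), names in tied.items():
            if len(names) < 2:
                continue  # pointer was not shared; nothing to retie
            src_param = getattr(module.get_submodule(names[0]), attr)
            for other in names[1:]:
                setattr(module.get_submodule(other), attr, src_param)
        return module


    model = retie_after_materialization(TiedMLP(4), device='cpu')
    assert model.fc1.weight is model.fc2.weight  # the tie is restored

Pointer identity (`id(param)`) is used rather than any value comparison because tied parameters are, by definition, the same `Parameter` object before materialization; that identity is lost once the meta tensors are materialized, so the mapping has to be recorded first and re-applied afterwards. The same reasoning is why the real implementation marks the source and destination modules with `_fsdp_wrap = False`: if sharding split them into different FSDP units, the restored tie would be violated again.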