From 9474d75d4925eec3353a06d2682fce805fc8970c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 25 Mar 2021 11:03:14 -0400 Subject: [PATCH] [FSDP][feature] optimizer state dict save and load (#537) Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com> --- .../nn/data_parallel/fsdp_optim_utils.py | 157 ++++++++++++++++++ .../fully_sharded_data_parallel.py | 129 +++++++++++++- fairscale/nn/misc/flatten_params_wrapper.py | 6 +- tests/ci_test_list_3.txt | 1 + .../test_fsdp_optimizer_utils.py | 119 +++++++++++++ 5 files changed, 401 insertions(+), 11 deletions(-) create mode 100644 fairscale/nn/data_parallel/fsdp_optim_utils.py create mode 100644 tests/nn/data_parallel/test_fsdp_optimizer_utils.py diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py new file mode 100644 index 000000000..841d7a8af --- /dev/null +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states.""" +import copy +from typing import Dict, Generator, List, Tuple + +import torch + + +# This function helps shard a full optimizer state dict +def flatten_optim_state_dict(sd: Dict) -> Dict: + """Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)""" + param_id_map = sd["param_id_map"] + num_local_params = len(set(param_id_map.values())) + if sd["state"]: + new_state: Dict = {local_id: {} for local_id in range(num_local_params)} + else: + new_state = {} + non_tensor_state = {} + + # Populate `new_state["state"]`. (Assuming sd is sorted) + for expanded_pid, buffers in sd["state"].items(): + consolidated_pid = param_id_map[expanded_pid] + for buffer_name, p in buffers.items(): + if torch.is_tensor(p): + if buffer_name not in new_state[consolidated_pid]: + new_state[consolidated_pid][buffer_name] = [] + new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) + else: + non_tensor_state[buffer_name] = p + + # Now combine all tensors in each buffer using torch.cat(). + for consolidated_pid, state in new_state.items(): + for buffer_name, tensors in state.items(): + new_state[consolidated_pid][buffer_name] = torch.cat(tensors) + new_state[consolidated_pid].update(non_tensor_state) + new_sd = {"state": new_state, "param_groups": sd["param_groups"]} + + # add pointers from the `params` dict. + for pg_id, _ in enumerate(sd["param_groups"]): + # TODO: this list could be huge. Can we avoid materializing? + new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) + + return new_sd + + +def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: + n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) + msg = ( + f"Including itself, this model has {n_instances} nested instances. 
When the optimizer state was saved " + f"there were {n_local_params_in_opt}" + ) + stateless = len(full_optim_state_dict["state"]) == 0 + assert stateless or (n_instances == n_local_params_in_opt), msg + + +# All functions below here help saving the list of optimizer states, one from each rank +# build_unflat_state_dict is the interface used by FSDP +def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: + non_tensor_state = {} # This state is like the `step` count in Adam, not a tensor so we dont unpad or cat it. + for k, v in combined_state[param_id].items(): + if torch.is_tensor(v[0]): + continue + elif len(set(v)) == 1: + non_tensor_state[k] = v[0] + else: + raise TypeError(f"Dont know how to consolidate optimizer param {k} with values {v}") + return non_tensor_state + + +def _combine_state(states: List[Dict]) -> Dict[int, Dict]: + combined_state = states[0] + for param_id in combined_state: + combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} + if len(states) == 1: + return combined_state + + for rank, s in enumerate(states[1:]): + for param_id, param_state in s.items(): + for k, tensor in param_state.items(): + combined_state[param_id][k].append(tensor) + return combined_state + + +def _unflatten_optim_state( + combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]], +) -> Tuple[Dict[int, Dict], Dict[int, int]]: + # local ids are the keys in the current state (combined_state), (usually fewer) + # global ids will be the keys in the unflattened state + next_global_id = 0 # gets incremented + pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} + local_ids = [id for id in sorted(combined_state.keys())] + + # non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step". + # we check that these are identical across workers and then take the first + non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state] + + # local corresponds to flattened, global corresponds to unflattened + num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore + global_to_local_id = {} + for local_id, num_unflat in enumerate(num_unflat_params): + for _ in range(num_unflat): + global_to_local_id[next_global_id] = local_id + next_global_id += 1 + if not combined_state: + return {}, global_to_local_id + + # If the constant state is the same as the combined state, copy it N times, no unflattening needed. + unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))} + if non_tensor_state[0].keys() == combined_state[0].keys(): + return unflat_state, global_to_local_id + + local_to_global: Dict[int, List] = {i: [] for i in local_ids} + for g, l in global_to_local_id.items(): + local_to_global[l].append(g) + # loop over parameters in state. + # Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views + # get_views returns multiple tensors, each of which is a new parameter with a new "global" id. 
+ for local_id in local_ids: + # undo the work of shard_parameters + for k, v in combined_state[local_id].items(): + if k in non_tensor_state[local_id]: + continue + assert isinstance(v, list), f"got {k}: {v} for {local_id}" + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] + flat_buffer = torch.cat(v_unpad) + param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore + for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views): + assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}" + unflat_state[global_id][k] = param_view + + return unflat_state, global_to_local_id + + +def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict: + """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank.""" + world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states] + assert all(len(s) == len(instance_list) for s in world_pad_info) + assert all(len(s[0]) == 1 for s in world_pad_info) + param_groups = copy.deepcopy(world_optim_states[0]["param_groups"]) + assert len(param_groups) == 1 + + # Aggregate from a list of dictionaries to a dictionary of lists + combined_state = _combine_state([x["state"] for x in world_optim_states]) + del world_optim_states + + # local ids are in the current state, global_ids will be in returned state. + unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info) + num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore + param_groups[0]["params"] = list(range(num_params)) # This could be a large list. #TODO: is it essential + return { + "state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted + "param_id_map": global_to_local_id, + "param_groups": param_groups, + } diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 5e1e5c0de..16862cd9c 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -21,12 +21,14 @@ from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap -from fairscale.optim.utils import calc_grad_norm +from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device from fairscale.utils.containers import apply_to_tensors from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer from fairscale.utils.state_dict import replace_by_prefix_ +from . 
import fsdp_optim_utils as ou + if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -88,8 +90,8 @@ class FullyShardedDataParallel(nn.Module): import torch from fairscale.nn.auto_wrap import enable_wrap, auto_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP - fsdp_params = dict(mixed_precision=True, flatten_parameters=True) - with enable_wrap(wrapper_cls=FSDP, **fsdp_params): + fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True) + with enable_wrap(**fsdp_params): # Wraps layer in FSDP by default if within context self.l1 = wrap(torch.nn.Linear(5, 5)) assert isinstance(self.l1, FSDP) @@ -185,6 +187,8 @@ def __init__( self.buffer_dtype = buffer_dtype or self.compute_dtype self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb + + self.numel_padded_per_param: List[int] = [] self.compute_device = compute_device if self.fp32_reduce_scatter and not self.mixed_precision: @@ -412,6 +416,7 @@ def _shard_parameters_(self) -> None: allocate less memory for optimizer state, avoiding redundancy across data parallel workers. """ + self.numel_padded_per_param = [] for p in self.params: assert not hasattr(p, "_is_sharded") assert p.is_floating_point() @@ -423,16 +428,19 @@ def _shard_parameters_(self) -> None: p._orig_size = p.data.size() if not p._is_sharded: + self.numel_padded_per_param.append(0) continue p._is_sharded = True # Replace p.data with the relevant shard. orig_data = p.data - p.data = self._get_shard(p.data) + p.data, num_padded = self._get_shard(p.data) + self.numel_padded_per_param.append(num_padded) free_storage_(orig_data) + assert len(self.numel_padded_per_param) == len(self.params) - def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor: - """Return the local shard of a given full tensor.""" + def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + """Return the local shard of a full tensor.""" # Shard using torch.chunk to match all-gather/reduce-scatter. chunks = list(torch.flatten(tensor).chunk(self.world_size)) while len(chunks) < self.world_size: @@ -445,7 +453,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor: shard = chunks[self.rank].clone() if num_to_pad > 0: shard = F.pad(shard, [0, num_to_pad]) - return shard + return shard, num_to_pad def extra_repr(self) -> str: return ( @@ -684,7 +692,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge if not volatile: # Copy any changes made to the full params back into # the corresponding local shards. - local_shard = self._get_shard(full_tensor) + local_shard, _ = self._get_shard(full_tensor) p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) if safe_to_free: free_storage_(full_tensor) @@ -1346,6 +1354,111 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None traceback.print_stack() raise ValueError(msg) + def _consolidate_optim_state_dict( + self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None + ) -> List[Dict]: + """Update the consolidated state_dict list, one per rank. + + Args: + + optim (Optimizer): an optimizer instance for this FSDP rank. Its state is + used in the consolidation. However, its state is not modified. + recipient_rank (int): on which rank to materialize the full state dict. + None is a special value, which means that all ranks should have the state + + Returns: + all_states (list[dict]) the optimizer state from each rank + + + .. 
warning: This needs to be called on all replicas""" + self._lazy_init() + # NOTE(SS): we do not support param groups yet, as they seem to break FSDP + # Pull the sharded state from all the other replicas + # Store all the states in order, rank by rank + should_collect_state = recipient_rank is None or (self.rank == recipient_rank) + all_states: List[Dict[str, Any]] = [] + dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device) + for rank in range(self.world_size): + if rank == self.rank: + sd = optim.state_dict() + sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances] + else: + sd = dummy_tensor # type: ignore + sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device) # type: ignore + if should_collect_state: + assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict" + all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu"))) + + return all_states + + def gather_full_optim_state_dict( + self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0 + ) -> Optional[Dict[str, Any]]: + """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the + sharded properties are not exposed. Multiple parameter groups are not yet supported. + + This should be called only on the root FSDP instance. + + Different world_size groups in nested FSDP instances is not supported. + Args: + optim (Optimizer): an optimizer instance for this FSDP rank. Its state is + used in the consolidation. However, its state is not modified. + recipient_rank (int): on which rank to materialize the full state dict. + + Returns: + a dict with two entries + * state - a dict holding gathered optimization state, 1 entry per unflat parameter + * param_groups - a dict containing the 1 parameter group + + """ + if not self.flatten_parameters: + raise NotImplementedError("optim state dict requires flatten_parameters=True") + world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank) + if self.rank != recipient_rank and recipient_rank is not None: + return None + # Unify the shard states by concatenating tensors and unflattening params + new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states) + # TODO: check if this code supports nested instances with different world size + return new_state_dict + + @property + def _fsdp_instances(self) -> List[nn.Module]: + """Returns all fsdp modules in self.modules() including self.""" + return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] + + def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: + """Get the portion of the optimizer state dict associated with the shard + + This can be used to get the right sharded optimizer state to be loaded + into the sharded optimizer for this FSDP rank. + + Args: + full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state``, or loaded from a checkpoint. + + Returns: + (dict): a shard of the optimizer state. 
+ """ + # Assert nesting is the same as it was at save time + instance_list = self._fsdp_instances + assert all( + x.world_size == self.world_size for x in instance_list + ), "all nested instances must have same world size" + ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list)) + if self.flatten_parameters: + full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict) + assert len(full_optim_state_dict["state"]) in (0, len(instance_list)) + + # get the portion of dict associated with the shard, in place + for id, s in full_optim_state_dict["state"].items(): + for k, v in s.items(): + if torch.is_tensor(v): + v_shard, _ = self._get_shard(v) + else: + v_shard = v # dont shard entries that are not tensors + full_optim_state_dict["state"][id][k] = v_shard + + return full_optim_state_dict + @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 733000f06..9db2f6ab6 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -122,7 +122,7 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None: # register the views as plain attributes self._unflatten_params_as_views() - def _get_param_views(self, flat_param: Tensor) -> Generator: + def get_param_views(self, flat_param: Tensor) -> Generator: return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: @@ -130,7 +130,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: self.is_flattened = False flat_param = flat_param if flat_param is not None else self.flat_param - ps = self._get_param_views(flat_param) + ps = self.get_param_views(flat_param) for (m, n), p in zip(self._param_infos, ps): if hasattr(m, n): delattr(m, n) @@ -144,7 +144,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: def _unflatten_params_as_views(self) -> None: assert self.is_flattened - ps = self._get_param_views(self.flat_param) + ps = self.get_param_views(self.flat_param) for (m, n), p in zip(self._param_infos, ps): setattr(m, n, p) # This will set as plain attr for (m, n, shared_m, shared_n) in self._shared_param_infos: diff --git a/tests/ci_test_list_3.txt b/tests/ci_test_list_3.txt index 8e6cd773e..3010b6b30 100644 --- a/tests/ci_test_list_3.txt +++ b/tests/ci_test_list_3.txt @@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py tests/nn/data_parallel/test_fsdp_input.py tests/nn/data_parallel/test_fsdp_multiple_forward.py tests/nn/data_parallel/test_fsdp_regnet.py +tests/nn/data_parallel/test_fsdp_optimizer_utils.py tests/nn/data_parallel/test_sharded_ddp_features.py tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py tests/nn/pipe/skip/test_gpipe.py diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py new file mode 100644 index 000000000..806fc7868 --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import functools +from time import time + +from parameterized import parameterized +import torch +from torch.optim import SGD, Adadelta, Adam # type: ignore + +from fairscale.nn import FullyShardedDataParallel +from fairscale.optim.utils import recursive_copy_to_device +from fairscale.utils.testing import objects_are_equal + +from .test_fsdp import ( + DistributedTest, + DummyProcessGroup, + NestedWrappedModule, + TransformerWithSharedParams, + rename_test, + spawn_and_init, +) + + +def first_tensor_numel(dct): + for k, v in dct.items(): + if torch.is_tensor(v): + return v.numel() + return 0 + + +def assert_equal(a, b): + assert a == b, f"{a} != {b}" + + +class TestOptimizerUtils(DistributedTest): + @parameterized.expand( + [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]], + name_func=rename_test, + ) + def test_consolidate_optimizer(self, optim_fn, transformer): + config = {"mixed_precision": True, "flatten_parameters": True} + test_fn = functools.partial( + self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer + ) + + spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)]) + + @classmethod + def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False): + """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()""" + # Establish reference behavior. + + if transformer: + fsdp = self.get_wrapped_model(group, config=config).cuda() + unwrapped_model = TransformerWithSharedParams(group).cuda() + else: + fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda() + unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda() + + try: + fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) + optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) + except TypeError: # Adadelta + fsdp_optim = optim_fn(fsdp.parameters()) + optim_unwrapped = optim_fn(unwrapped_model.parameters()) + + fsdp_optim.zero_grad() + optim_unwrapped.zero_grad() + + x = fsdp.module.get_input(torch.device("cuda")) + output = fsdp(*x) + loss = fsdp.module.get_loss(x, output).to("cuda") + fsdp.module.run_backward(loss) + fsdp_optim.step() + + output = unwrapped_model(*x) + loss = unwrapped_model.get_loss(x, output) + unwrapped_model.run_backward(loss) + optim_unwrapped.step() + unwrapped_sd = optim_unwrapped.state_dict() + + tstart = time() + sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0) + duration = time() - tstart + # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise + assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate" + + if fsdp.rank > 0: + return + + assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) + assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) + assert_equal( + sum([first_tensor_numel(v) for k, v in sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]), + ) + + shard_sd = fsdp.get_shard_from_optim_state_dict(sd) + + original_shard_sd = fsdp_optim.state_dict() + assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"])) + assert_equal(shard_sd.keys(), original_shard_sd.keys()) + original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") + + assert_equal( + sum([first_tensor_numel(v) 
for k, v in shard_sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]), + ) + assert objects_are_equal(shard_sd, original_shard_sd) + + def test_named_params_ordering(self): + """Test assumption of consolidate_optimizer_state_dict""" + group = DummyProcessGroup(0, 1) + model = TransformerWithSharedParams(group) + named_pars = [p for n, p in model.named_parameters()] + for i, p in enumerate(model.parameters()): + assert objects_are_equal(p, named_pars[i])