From ee088bb6619d911d2a1f02647432329b57b22a79 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Mar 2021 20:47:53 -0400 Subject: [PATCH 01/31] consolidate works --- .../fully_sharded_data_parallel.py | 120 +++++++++++++++++- fairscale/optim/oss.py | 4 +- tests/nn/data_parallel/test_fsdp.py | 38 +++++- tests/optim/test_oss.py | 3 + 4 files changed, 161 insertions(+), 4 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 4bfaee988..6712247a2 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -177,6 +177,7 @@ def __init__( self.buffer_dtype = buffer_dtype or self.compute_dtype self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb + self._all_optimizer_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") @@ -1267,6 +1268,124 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None print(msg) raise ValueError(msg) + # Optim State dict interfaces + def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: + """Update the consolidated state_dict list, one per rank. + + Arguments: + recipient_rank (int): on which rank to materialize the full state dict. + -1 is a special value, which means that all ranks should have the state + + .. warning: This needs to be called on all replicas""" + + # Sync lr and other attributes in case its been updated + from fairscale.optim import OSS + from fairscale.optim.utils import recursive_copy_to_device, broadcast_object + _default_device = torch.device('cuda') + + # OSS._sync_param_groups(self.param_groups, optim.param_groups) + + # Pull the sharded state from all the other replicas + # Store all the states in order, rank by rank + #print("Pulling the sharded optimizer state from all replicas") + + self._all_optimizer_states = [] + should_collect_state = self.rank == recipient_rank or recipient_rank == -1 + should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 + print(f'rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}') + + for rank in range(self.world_size): + if rank == self.rank: + if should_collect_state: + print(f"{rank} Saving self state") + self._all_optimizer_states.append( + recursive_copy_to_device(optim.state_dict(), non_blocking=True, device=torch.device("cpu")) + ) + + # Sync with other replicas + state_to_share = ( + optim.state_dict() + if should_send_state + else torch.tensor([0], dtype=torch.uint8, device=_default_device) + ) + broadcast_object( + state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device, + ) + else: + # Fetch the optim state from the other replicas + replica_state = broadcast_object( + torch.tensor([0], dtype=torch.uint8, device=_default_device), + src_rank=rank, + group=self.process_group, + dist_device=_default_device, + ) + + if should_collect_state: + self._all_optimizer_states.append( + recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) + ) + + print(f"State from rank {rank} received: {self._all_optimizer_states[-1]}") + + def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: + """Return the last known 
global optimizer state. The returned state is compatible with Pytorch, in that the + sharded properties are not exposed. + + + Arguments: + all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on + all ranks + + Returns: + a dict with two entries + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + + * param_groups - a dict containing all parameter groups + + .. warning: + Returning the global state is limited to the replica which was responsible for the consolidation, + if `all_ranks` was not set to `True`. In that case, the state may also not be up to date, + depending on when `consolidate_state_dict` was last called. + """ + + if not all_ranks and len(self._all_optimizer_states) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. \ + Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state" + ) + + if all_ranks: + # Consolidate the state on every rank + self.consolidate_state_dict(recipient_rank=-1) + + # Unify the shard states and the state that pytorch would expect, given the model. + + state_dict = self._all_optimizer_states[0] + + #for c in state + from collections import defaultdict + #state_dict_state = defaultdict(defaultdict() + if self.world_size == 1: return state_dict + + # - go through the per-shard states + + for rank, s in enumerate(self._all_optimizer_states[1:]): + # -- match the local indexing and the global partition, update the corresponding saved state globally + for local_pg in s["param_groups"]: + + for local_param_index in local_pg["params"]: + # Update the state, if any + if local_param_index in s["state"].keys(): + #global_id = self.param_to_index[local_index_to_param_id[local_param_index]] + # This next line is way wrong! + state_dict["state"] = s["state"][local_param_index] + + # Make sure that the parameters are sorted in the state, as expected for a pytorch dict + state_dict["state"] = dict(sorted(state_dict["state"].items())) + + return state_dict + @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: @@ -1313,7 +1432,6 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: assert data.storage().size() == 0 data.storage().resize_(size.numel()) - def _post_state_dict_hook( module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any ) -> "OrderedDict[str, torch.Tensor]": diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index fbb6ab2e4..ed0a6178f 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -318,7 +318,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None: self._all_states = [] should_collect_state = self.rank == recipient_rank or recipient_rank == -1 - should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 + should_send_state = (self.rank != recipient_rank) or recipient_rank == -1 for rank in range(self.world_size): if rank == self.rank: @@ -351,7 +351,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None: recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) ) - logging.debug("State from rank %s received", rank) + print(f"State from rank {rank} received: ") def local_state_dict(self) -> dict: """ .. 
deprecated:: 0.1.5 diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 9813d1142..4d4508e67 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -44,7 +44,7 @@ def setUp(self): raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") @staticmethod - def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None): + def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None, save_optim=False): model_device = next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests @@ -254,6 +254,42 @@ def test_cpu_offload_and_cuda_grads_breaks(self): ) spawn_and_init(test_fn) + def test_consolidate_optimizer(self): + config = {"mixed_precision": True} + test_fn = functools.partial( + self._test_consolidated_optimizer, config, + ) + spawn_and_init(test_fn) + + @classmethod + def _test_consolidated_optimizer(self, config, rank, group): + """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" + # Establish reference behavior. + fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) + fsdp_optim = torch.optim.SGD(fsdp.parameters(), lr=0.01, momentum=0.9) + fsdp_optim.zero_grad() + + src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) + output = fsdp(src_ids, tgt_ids) + loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to('cuda') + fsdp.module.run_backward(loss) + fsdp_optim.step() + fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) + + if rank == 0: + assert fsdp._all_optimizer_states + torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') + fsdp_state_dict = fsdp.optim_state_dict() + unwrapped_model = TransformerWithSharedParams(group).cuda() + optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) + output = unwrapped_model(src_ids, tgt_ids) + loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) + unwrapped_model.run_backward(loss) + optim_unwrapped.step() + #assert objects_are_equal(fsdp_state_dict, optim_unwrapped.state_dict(), raise_exception=True) + #optim_unwrapped.load_state_dict(fsdp_state_dict) + + def test_delayed_optim_step(self): # We use a model with a long CUDA delay right before the optimizer step. 
# This tests our streams logic, and that we don't start the FP32 -> FP16 diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index 8d8e141c1..b11f07a92 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -441,6 +441,9 @@ def closure(): # Update the optimizer state on the reference rank optimizer.consolidate_state_dict(recipient_rank=reference_rank) + # if rank == reference_rank: + # #types = [type(x) for x in optimizer._all_states] + # #assert all(isinstance(s, dict) for s in optimizer._all_states), types # Fetch the state on the reference rank # - check that it has the correct size From ad7df24a443de2c8ff68c3590a383587b1da3f8c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Mar 2021 22:35:48 -0400 Subject: [PATCH 02/31] cat --- .../fully_sharded_data_parallel.py | 46 ++++++++++++++----- fairscale/nn/misc/flatten_params_wrapper.py | 9 ++-- tests/nn/data_parallel/test_fsdp.py | 10 ++++ 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 6712247a2..0db05ec70 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -7,6 +7,7 @@ import copy from enum import Enum, auto import functools +from collections import defaultdict from math import inf from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union @@ -340,6 +341,7 @@ def _shard_parameters_(self) -> None: allocate less memory for optimizer state, avoiding redundancy across data parallel workers. """ + self.num_padded = [] for p in self.params: assert not hasattr(p, "_is_sharded") assert p.is_floating_point() @@ -351,15 +353,17 @@ def _shard_parameters_(self) -> None: p._orig_size = p.data.size() if not p._is_sharded: + self.num_padded.append(0) continue p._is_sharded = True # Replace p.data with the relevant shard. orig_data = p.data - p.data = self._get_shard(p.data) + p.data, num_padded = self._get_shard(p.data) + self.num_padded.append(num_padded) free_storage_(orig_data) - def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor: + def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: """Return the local shard of a given full tensor.""" # Shard using torch.chunk to match all-gather/reduce-scatter. chunks = list(torch.flatten(tensor).chunk(self.world_size)) @@ -373,7 +377,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor: shard = chunks[self.rank].clone() if num_to_pad > 0: shard = F.pad(shard, [0, num_to_pad]) - return shard + return shard, num_to_pad def extra_repr(self) -> str: return ( @@ -609,7 +613,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge if not volatile: # Copy any changes made to the full params back into # the corresponding local shards. 
- local_shard = self._get_shard(full_tensor) + local_shard, _ = self._get_shard(full_tensor) p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) if safe_to_free: free_storage_(full_tensor) @@ -1294,6 +1298,8 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 print(f'rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}') + + for rank in range(self.world_size): if rank == self.rank: if should_collect_state: @@ -1361,30 +1367,48 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: # Unify the shard states and the state that pytorch would expect, given the model. - state_dict = self._all_optimizer_states[0] + sd0 = self._all_optimizer_states[0] #for c in state from collections import defaultdict #state_dict_state = defaultdict(defaultdict() - if self.world_size == 1: return state_dict + if self.world_size == 1: return sd0 # - go through the per-shard states + assert len(sd0['param_groups']) == 1, 'not yet supported' + #new = sd0.copy() + for pg0 in sd0['param_groups']: + for param_id in pg0["params"]: + sd0['state'][param_id] = {k: [v] for k,v in sd0['state'][param_id].items()} # so we can append for rank, s in enumerate(self._all_optimizer_states[1:]): # -- match the local indexing and the global partition, update the corresponding saved state globally for local_pg in s["param_groups"]: - for local_param_index in local_pg["params"]: # Update the state, if any if local_param_index in s["state"].keys(): #global_id = self.param_to_index[local_index_to_param_id[local_param_index]] - # This next line is way wrong! - state_dict["state"] = s["state"][local_param_index] + for k in s['state'][local_param_index]: + new_entry = s['state'][local_param_index][k] + sd0['state'][local_param_index][k].append(new_entry) + else: + raise KeyError(f'lost {local_param_index} from rank {rank}') + + # Concatenate everything + for pg0 in sd0['param_groups']: + for param_id in pg0["params"]: + # This attempts to undo the work of shard_parameters. 
+ # It might be assuming self.flatten_parameters=True + #stuff = sd0['state'][param_id] + #assert isinstance(sd0['state'][param_id], list), f'{param_id}, {stuff}' + #print(f'{stuff[0]}') + for k,v in sd0['state'][param_id].items(): + sd0['state'][param_id][k] = torch.cat(v) # Make sure that the parameters are sorted in the state, as expected for a pytorch dict - state_dict["state"] = dict(sorted(state_dict["state"].items())) + sd0["state"] = dict(sorted(sd0["state"].items())) - return state_dict + return sd0 @torch.no_grad() diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 733000f06..bf4cf40dd 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -122,15 +122,16 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None: # register the views as plain attributes self._unflatten_params_as_views() - def _get_param_views(self, flat_param: Tensor) -> Generator: - return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) + def get_param_views(self, flat_param: Tensor) -> Generator: + splat = flat_param.split(self._param_numels) + return (t.view(s) for (t, s) in zip(splat, self._param_shapes)) def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: assert self.is_flattened or flat_param is not None self.is_flattened = False flat_param = flat_param if flat_param is not None else self.flat_param - ps = self._get_param_views(flat_param) + ps = self.get_param_views(flat_param) for (m, n), p in zip(self._param_infos, ps): if hasattr(m, n): delattr(m, n) @@ -144,7 +145,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: def _unflatten_params_as_views(self) -> None: assert self.is_flattened - ps = self._get_param_views(self.flat_param) + ps = self.get_param_views(self.flat_param) for (m, n), p in zip(self._param_infos, ps): setattr(m, n, p) # This will set as plain attr for (m, n, shared_m, shared_n) in self._shared_param_infos: diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 4d4508e67..81f01efe4 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -280,6 +280,7 @@ def _test_consolidated_optimizer(self, config, rank, group): assert fsdp._all_optimizer_states torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') fsdp_state_dict = fsdp.optim_state_dict() + torch.save(fsdp_state_dict, f'fsdp_consolidated_{fsdp.world_size}.pt') unwrapped_model = TransformerWithSharedParams(group).cuda() optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) output = unwrapped_model(src_ids, tgt_ids) @@ -424,6 +425,15 @@ def _test_param_change_after_init(self, rank, group, config): assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init" + def test_named_params_ordering(self): + """Test assumption of consolidate_optimizer_state_dict""" + group = DummyProcessGroup(0, 1) + model = TransformerWithSharedParams(group) + named_pars = [p for n,p in model.named_parameters()] + for i, p in enumerate(model.parameters()): + assert p.shape == named_pars[i].shape + + class TestSerialization(DistributedTest): @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) def test_pickle(self, mixed_precision, cpu_offload): From ed7526aba2ec998daba28b42ca73c62a6a91acbf Mon Sep 17 00:00:00 
2001 From: Sam Shleifer Date: Thu, 18 Mar 2021 23:16:30 -0400 Subject: [PATCH 03/31] Unpad before cat --- .../fully_sharded_data_parallel.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 0db05ec70..c2b196ad9 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1298,19 +1298,19 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 print(f'rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}') - - for rank in range(self.world_size): if rank == self.rank: + sd = optim.state_dict() + sd['num_padded'] = self.num_padded # Communicate between ranks if should_collect_state: - print(f"{rank} Saving self state") + print(f"{rank} Saving self state keys {list(sd.keys())}") self._all_optimizer_states.append( - recursive_copy_to_device(optim.state_dict(), non_blocking=True, device=torch.device("cpu")) + recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu")) ) # Sync with other replicas state_to_share = ( - optim.state_dict() + sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device) ) @@ -1354,6 +1354,8 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: if `all_ranks` was not set to `True`. In that case, the state may also not be up to date, depending on when `consolidate_state_dict` was last called. """ + if not self.flatten_parameters: + raise NotImplementedError('optim state dict requires flatten_parameters=True') if not all_ranks and len(self._all_optimizer_states) == 0: raise RuntimeError( @@ -1368,10 +1370,9 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: # Unify the shard states and the state that pytorch would expect, given the model. 
sd0 = self._all_optimizer_states[0] - - #for c in state - from collections import defaultdict - #state_dict_state = defaultdict(defaultdict() + assert 'num_padded' in sd0 + all_num_padded = [s.pop('num_padded')[0] for s in self._all_optimizer_states] + assert all_num_padded[0] == 0, f'this code assumes rank 0 param not padded {all_num_padded[0]}' if self.world_size == 1: return sd0 # - go through the per-shard states @@ -1379,19 +1380,20 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: #new = sd0.copy() for pg0 in sd0['param_groups']: for param_id in pg0["params"]: + # BUG if rank 0's param is padded sd0['state'][param_id] = {k: [v] for k,v in sd0['state'][param_id].items()} # so we can append for rank, s in enumerate(self._all_optimizer_states[1:]): - # -- match the local indexing and the global partition, update the corresponding saved state globally for local_pg in s["param_groups"]: for local_param_index in local_pg["params"]: # Update the state, if any if local_param_index in s["state"].keys(): - #global_id = self.param_to_index[local_index_to_param_id[local_param_index]] for k in s['state'][local_param_index]: new_entry = s['state'][local_param_index][k] sd0['state'][local_param_index][k].append(new_entry) else: + + # OSS does not raise in this case, maybe we shouldn't either raise KeyError(f'lost {local_param_index} from rank {rank}') # Concatenate everything @@ -1399,12 +1401,12 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: for param_id in pg0["params"]: # This attempts to undo the work of shard_parameters. # It might be assuming self.flatten_parameters=True - #stuff = sd0['state'][param_id] - #assert isinstance(sd0['state'][param_id], list), f'{param_id}, {stuff}' - #print(f'{stuff[0]}') for k,v in sd0['state'][param_id].items(): - sd0['state'][param_id][k] = torch.cat(v) - + def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v + v_unpad = [maybe_unpad(t, np) for t,np in zip(v, all_num_padded)] + flat_buffer = torch.cat(v_unpad) + flat_buffer = list(self.module.get_param_views(flat_buffer)) + sd0['state'][param_id][k] = flat_buffer # Make sure that the parameters are sorted in the state, as expected for a pytorch dict sd0["state"] = dict(sorted(sd0["state"].items())) From ed75c593eb27c8cc42f0a0b51800057a41f3ba95 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Mar 2021 23:57:20 -0400 Subject: [PATCH 04/31] update params list --- .../fully_sharded_data_parallel.py | 41 ++++++++++++++++--- tests/nn/data_parallel/test_fsdp.py | 37 +++++++++++------ 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index c2b196ad9..1b7533456 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -7,7 +7,6 @@ import copy from enum import Enum, auto import functools -from collections import defaultdict from math import inf from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union @@ -364,7 +363,7 @@ def _shard_parameters_(self) -> None: free_storage_(orig_data) def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: - """Return the local shard of a given full tensor.""" + """Return the local shard of a full tensor.""" # Shard using torch.chunk to match all-gather/reduce-scatter. 
chunks = list(torch.flatten(tensor).chunk(self.world_size)) while len(chunks) < self.world_size: @@ -1397,21 +1396,50 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: raise KeyError(f'lost {local_param_index} from rank {rank}') # Concatenate everything - for pg0 in sd0['param_groups']: + for pg_id, pg0 in enumerate(sd0['param_groups']): + n_params = 1 for param_id in pg0["params"]: + assert param_id == 0 + # This attempts to undo the work of shard_parameters. # It might be assuming self.flatten_parameters=True + for k,v in sd0['state'][param_id].items(): def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v v_unpad = [maybe_unpad(t, np) for t,np in zip(v, all_num_padded)] flat_buffer = torch.cat(v_unpad) - flat_buffer = list(self.module.get_param_views(flat_buffer)) - sd0['state'][param_id][k] = flat_buffer + flat_buffer = self.module.get_param_views(flat_buffer) + for i, entry in enumerate(flat_buffer): + if i not in sd0['state']: + sd0['state'][i] = {} + sd0['state'][i][k] = entry + n_params += i + sd0['param_groups'][pg_id]['params'] = list(range(n_params)) + # Make sure that the parameters are sorted in the state, as expected for a pytorch dict sd0["state"] = dict(sorted(sd0["state"].items())) - return sd0 + def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: + sd = full_optim_state_dict + if self.flatten_parameters: + sd = self.flatten_optim_state_dict(sd) + for id, s in sd['state'].items(): + for k,v in s: + tensor, _ = self._get_shard(v) + sd['state'][id][k] = tensor + + return sd + + + def flatten_optim_state_dict(self, sd) -> Dict: + return sd + + + + + + @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: @@ -1458,6 +1486,7 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None: assert data.storage().size() == 0 data.storage().resize_(size.numel()) + def _post_state_dict_hook( module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any ) -> "OrderedDict[str, torch.Tensor]": diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 81f01efe4..35447b0b3 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -276,19 +276,30 @@ def _test_consolidated_optimizer(self, config, rank, group): fsdp_optim.step() fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) - if rank == 0: - assert fsdp._all_optimizer_states - torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') - fsdp_state_dict = fsdp.optim_state_dict() - torch.save(fsdp_state_dict, f'fsdp_consolidated_{fsdp.world_size}.pt') - unwrapped_model = TransformerWithSharedParams(group).cuda() - optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) - output = unwrapped_model(src_ids, tgt_ids) - loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) - unwrapped_model.run_backward(loss) - optim_unwrapped.step() - #assert objects_are_equal(fsdp_state_dict, optim_unwrapped.state_dict(), raise_exception=True) - #optim_unwrapped.load_state_dict(fsdp_state_dict) + if rank > 0 or fsdp.world_size == 1: + return + + unwrapped_model = TransformerWithSharedParams(group).cuda() + n_pars = len(list(unwrapped_model.parameters())) + assert fsdp._all_optimizer_states + torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') + sd = fsdp.optim_state_dict() + torch.save(sd, f'fsdp_consolidated_{fsdp.world_size}.pt') + + st = 
sd['state'] + assert len(sd['state']) == n_pars, f'{len(st)} != {n_pars}' + + assert torch.is_tensor(sd['state'][21]['momentum_buffer']) + def assert_equal(a,b): + assert a == b, f'{a} != {b}' + + assert_equal(len(sd['param_groups'][0]['params']), len(sd['state'])) + + optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) + output = unwrapped_model(src_ids, tgt_ids) + loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) + unwrapped_model.run_backward(loss) + optim_unwrapped.step() def test_delayed_optim_step(self): From 44158f785c2e4d05a2b62c6793b5d549707939a5 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 19 Mar 2021 00:35:36 -0400 Subject: [PATCH 05/31] simple case passing --- .../fully_sharded_data_parallel.py | 20 ++++++++++++++++++- tests/nn/data_parallel/test_fsdp.py | 4 ++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 1b7533456..4f490901f 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1424,8 +1424,10 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: sd = full_optim_state_dict if self.flatten_parameters: sd = self.flatten_optim_state_dict(sd) + assert len(sd['state']) == 1 + assert len(sd['param_groups'][0]['params']) == 1 for id, s in sd['state'].items(): - for k,v in s: + for k,v in s.items(): tensor, _ = self._get_shard(v) sd['state'][id][k] = tensor @@ -1433,6 +1435,22 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: def flatten_optim_state_dict(self, sd) -> Dict: + from collections import defaultdict + + flat_params = defaultdict(list) + + for _, buffers in sd['state'].items(): + for k, p in buffers.items(): + flat_params[k].append(p.reshape(-1)) + state = {0: {}} + for k,v in flat_params.items(): + state[0][k] = torch.cat(v) + + assert state[0][k].dim() == 1, state[0][k].dim() + sd['state'] = state + for pg_id, _ in enumerate(sd['param_groups']): + sd['param_groups'][pg_id]['params'] = list(range(1)) + return sd diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 35447b0b3..a2c1718bd 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -295,6 +295,10 @@ def assert_equal(a,b): assert_equal(len(sd['param_groups'][0]['params']), len(sd['state'])) + shard_sd = fsdp.get_shard_from_optim_state_dict(sd) + from fairscale.optim.utils import recursive_copy_to_device + assert objects_are_equal(shard_sd, recursive_copy_to_device(fsdp_optim.state_dict(), non_blocking=False, device='cpu')) + optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) output = unwrapped_model(src_ids, tgt_ids) loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) From f82f3b6594e20ffaad04dee145729c50ce8c8a83 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 19 Mar 2021 01:25:44 -0400 Subject: [PATCH 06/31] found other bug --- tests/nn/data_parallel/test_fsdp.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index a2c1718bd..b4de8e814 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -48,7 +48,7 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None model_device = 
next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests - optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) + optim = torch.optim.SGD(_make_optimizer_groups(model), lr=lr, momentum=0.9) for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -210,6 +210,11 @@ def _reduce_scatter(output, input_list, **kwargs): CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))] +def _make_optimizer_groups(model): + return [{'params': model.module.output_proj.parameters()}, + {'params': model.module.transformer.parameters(), 'lr': 1e-3}] + + def rename_test(testcase_func, param_num, param): return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) @@ -245,6 +250,7 @@ def test_cpu_offload_and_cpu_grads(self): ) spawn_and_init(test_fn) + def test_cpu_offload_and_cuda_grads_breaks(self): # If grads are on gpu, but model and optimizer are on cpu, backward breaks. config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} @@ -254,7 +260,8 @@ def test_cpu_offload_and_cuda_grads_breaks(self): ) spawn_and_init(test_fn) - def test_consolidate_optimizer(self): + @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + def test_consolidate_optimizer(self, optim_fn): config = {"mixed_precision": True} test_fn = functools.partial( self._test_consolidated_optimizer, config, @@ -262,11 +269,17 @@ def test_consolidate_optimizer(self): spawn_and_init(test_fn) @classmethod - def _test_consolidated_optimizer(self, config, rank, group): + def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, optim_fn=torch.optim.SGD): """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" # Establish reference behavior. 
fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) - fsdp_optim = torch.optim.SGD(fsdp.parameters(), lr=0.01, momentum=0.9) + + + + if lr_groups: + fsdp_optim = optim_fn(_make_optimizer_groups(model), lr=0.01, momentum=0.9) + else: + fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01, momentum=0.9) fsdp_optim.zero_grad() src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) From 1022e1e056a7c7b9a8c35aa61ec6e1b0cdf4caee Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 19 Mar 2021 01:41:23 -0400 Subject: [PATCH 07/31] Broken tests for other optimizers --- tests/nn/data_parallel/test_fsdp.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index b4de8e814..86f8af45b 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -18,6 +18,7 @@ from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper +from fairscale.optim import AdaScale from fairscale.utils.testing import ( DeviceAndTypeCheckModule, DummyProcessGroup, @@ -48,7 +49,7 @@ def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None model_device = next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests - optim = torch.optim.SGD(_make_optimizer_groups(model), lr=lr, momentum=0.9) + optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) for _ in range(num_steps): optim.zero_grad() with torch.cuda.amp.autocast(enabled=autocast): @@ -210,11 +211,6 @@ def _reduce_scatter(output, input_list, **kwargs): CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))] -def _make_optimizer_groups(model): - return [{'params': model.module.output_proj.parameters()}, - {'params': model.module.transformer.parameters(), 'lr': 1e-3}] - - def rename_test(testcase_func, param_num, param): return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),) @@ -260,11 +256,13 @@ def test_cpu_offload_and_cuda_grads_breaks(self): ) spawn_and_init(test_fn) - @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) + @parameterized.expand( + [[functools.partial(torch.optim.SGD, momentum=0.9)], [torch.optim.SGD], [torch.optim.Adam], [AdaScale]], + name_func=rename_test) def test_consolidate_optimizer(self, optim_fn): config = {"mixed_precision": True} test_fn = functools.partial( - self._test_consolidated_optimizer, config, + self._test_consolidated_optimizer, config, optim_fn=optim_fn ) spawn_and_init(test_fn) @@ -273,13 +271,7 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" # Establish reference behavior. 
fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) - - - - if lr_groups: - fsdp_optim = optim_fn(_make_optimizer_groups(model), lr=0.01, momentum=0.9) - else: - fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01, momentum=0.9) + fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) fsdp_optim.zero_grad() src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) From 75119c25a4a651ab4a24773acdf01fcd1e3a53c2 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 19 Mar 2021 20:27:07 -0400 Subject: [PATCH 08/31] boom boom --- .../fully_sharded_data_parallel.py | 69 ++++++++++++++----- tests/nn/data_parallel/test_fsdp.py | 48 ++++++++----- 2 files changed, 79 insertions(+), 38 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 4f490901f..bfb15893c 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -86,8 +86,8 @@ class FullyShardedDataParallel(nn.Module): import torch from fairscale.nn.auto_wrap import enable_wrap, auto_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP - fsdp_params = dict(mixed_precision=True, flatten_parameters=True) - with enable_wrap(wrapper_cls=FSDP, **fsdp_params): + fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True) + with enable_wrap(**fsdp_params): # Wraps layer in FSDP by default if within context self.l1 = wrap(torch.nn.Linear(5, 5)) assert isinstance(self.l1, FSDP) @@ -1332,7 +1332,7 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: print(f"State from rank {rank} received: {self._all_optimizer_states[-1]}") - def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: + def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the sharded properties are not exposed. @@ -1366,22 +1366,25 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: # Consolidate the state on every rank self.consolidate_state_dict(recipient_rank=-1) - # Unify the shard states and the state that pytorch would expect, given the model. - + # Unify the shard states by concatenating tensors and otherwise assuming rank zero is correct. sd0 = self._all_optimizer_states[0] assert 'num_padded' in sd0 all_num_padded = [s.pop('num_padded')[0] for s in self._all_optimizer_states] assert all_num_padded[0] == 0, f'this code assumes rank 0 param not padded {all_num_padded[0]}' - if self.world_size == 1: return sd0 # - go through the per-shard states assert len(sd0['param_groups']) == 1, 'not yet supported' - #new = sd0.copy() + + if len(sd0['state']) == 0: + # This is a stateless optimizer, like vanilla SGD. 
+ sd0['param_groups'][0]['params'] = [0] + return sd0 + + for pg0 in sd0['param_groups']: for param_id in pg0["params"]: - # BUG if rank 0's param is padded sd0['state'][param_id] = {k: [v] for k,v in sd0['state'][param_id].items()} # so we can append - + other_states = self._all_optimizer_states[1:] if self.world_size > 1 else [] for rank, s in enumerate(self._all_optimizer_states[1:]): for local_pg in s["param_groups"]: for local_param_index in local_pg["params"]: @@ -1397,14 +1400,19 @@ def optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: # Concatenate everything for pg_id, pg0 in enumerate(sd0['param_groups']): - n_params = 1 + n_params = 0 for param_id in pg0["params"]: assert param_id == 0 # This attempts to undo the work of shard_parameters. # It might be assuming self.flatten_parameters=True + constant_state = self.extract_constant_state(sd0, param_id) for k,v in sd0['state'][param_id].items(): + assert isinstance(v, list), f'expected list, got {v}' + if k in constant_state: + continue + def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v v_unpad = [maybe_unpad(t, np) for t,np in zip(v, all_num_padded)] flat_buffer = torch.cat(v_unpad) @@ -1413,23 +1421,42 @@ def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v if i not in sd0['state']: sd0['state'][i] = {} sd0['state'][i][k] = entry - n_params += i - sd0['param_groups'][pg_id]['params'] = list(range(n_params)) + sd0['state'][i].update(constant_state.copy()) + n_params = max(i, n_params) + sd0['param_groups'][pg_id]['params'] = list(range(n_params+1)) # Make sure that the parameters are sorted in the state, as expected for a pytorch dict sd0["state"] = dict(sorted(sd0["state"].items())) return sd0 + def extract_constant_state(self, sd0, param_id): + constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. + for k, v in sd0['state'][param_id].items(): + if torch.is_tensor(v[0]): + continue + elif len(set(v)) == 1: + constant_state[k] = v[0] + else: + raise ValueError(f'Dont know how to expand optimizer param {k} with value {v}') + return constant_state + def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: sd = full_optim_state_dict + + if self.flatten_parameters: sd = self.flatten_optim_state_dict(sd) assert len(sd['state']) == 1 assert len(sd['param_groups'][0]['params']) == 1 + + # get the portion of dict associated with the shard for id, s in sd['state'].items(): for k,v in s.items(): - tensor, _ = self._get_shard(v) - sd['state'][id][k] = tensor + if torch.is_tensor(v): + v_shard, _ = self._get_shard(v) + else: + v_shard = v # dont partition entries that are not tensors + sd['state'][id][k] = v_shard return sd @@ -1438,14 +1465,18 @@ def flatten_optim_state_dict(self, sd) -> Dict: from collections import defaultdict flat_params = defaultdict(list) - + constant_state = {}# self.extract_constant_state(sd, 0) for _, buffers in sd['state'].items(): for k, p in buffers.items(): - flat_params[k].append(p.reshape(-1)) - state = {0: {}} - for k,v in flat_params.items(): - state[0][k] = torch.cat(v) + if torch.is_tensor(p): + flat_params[k].append(p.reshape(-1)) + else: + assert isinstance(p, int) + constant_state[k] = p # THIS COULD BE WAY WRONG. What if step is different for different params... At least check. 
+ state = {0: constant_state} + for k, v in flat_params.items(): + state[0][k] = torch.cat(v) assert state[0][k].dim() == 1, state[0][k].dim() sd['state'] = state for pg_id, _ in enumerate(sd['param_groups']): diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 86f8af45b..b45993762 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -28,11 +28,16 @@ spawn_for_all_world_sizes, torch_version, ) +from fairscale.optim.utils import recursive_copy_to_device +from fairseq.optim.cpu_adam import CPUAdam # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 # All helper functions called by spawn must be either @classmethod, @staticmethod +def assert_equal(a, b): + assert a == b, f'{a} != {b}' + class DistributedTest(unittest.TestCase): def setUp(self): if torch_version() < (1, 6, 0): @@ -257,7 +262,8 @@ def test_cpu_offload_and_cuda_grads_breaks(self): spawn_and_init(test_fn) @parameterized.expand( - [[functools.partial(torch.optim.SGD, momentum=0.9)], [torch.optim.SGD], [torch.optim.Adam], [AdaScale]], + [[functools.partial(torch.optim.SGD, momentum=0.9)], + [torch.optim.SGD], [torch.optim.Adam], [CPUAdam]], name_func=rename_test) def test_consolidate_optimizer(self, optim_fn): config = {"mixed_precision": True} @@ -271,8 +277,17 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" # Establish reference behavior. fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) - fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) + unwrapped_model = TransformerWithSharedParams(group).cuda() + try: + fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) + optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) + except TypeError: # AdaScale + fsdp_optim = optim_fn(fsdp.parameters()) + optim_unwrapped = optim_fn(unwrapped_model.parameters()) + + fsdp_optim.zero_grad() + optim_unwrapped.zero_grad() src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) output = fsdp(src_ids, tgt_ids) @@ -284,31 +299,26 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op if rank > 0 or fsdp.world_size == 1: return - unwrapped_model = TransformerWithSharedParams(group).cuda() + + + output = unwrapped_model(src_ids, tgt_ids) + loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) + unwrapped_model.run_backward(loss) + optim_unwrapped.step() + unwrapped_sd = optim_unwrapped.state_dict() + n_pars = len(list(unwrapped_model.parameters())) - assert fsdp._all_optimizer_states + assert len(fsdp._all_optimizer_states) == fsdp.world_size torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') - sd = fsdp.optim_state_dict() + sd = fsdp.gather_full_optim_state_dict() torch.save(sd, f'fsdp_consolidated_{fsdp.world_size}.pt') - st = sd['state'] - assert len(sd['state']) == n_pars, f'{len(st)} != {n_pars}' - - assert torch.is_tensor(sd['state'][21]['momentum_buffer']) - def assert_equal(a,b): - assert a == b, f'{a} != {b}' - - assert_equal(len(sd['param_groups'][0]['params']), len(sd['state'])) + assert_equal(len(sd['state']), len(unwrapped_sd['state'])) + assert_equal(len(sd['param_groups'][0]['params']), len(unwrapped_sd['param_groups'][0]['params'])) shard_sd = fsdp.get_shard_from_optim_state_dict(sd) - from fairscale.optim.utils import recursive_copy_to_device assert objects_are_equal(shard_sd, 
recursive_copy_to_device(fsdp_optim.state_dict(), non_blocking=False, device='cpu')) - optim_unwrapped = torch.optim.SGD(unwrapped_model.parameters(), lr=0.01, momentum=0.9) - output = unwrapped_model(src_ids, tgt_ids) - loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) - unwrapped_model.run_backward(loss) - optim_unwrapped.step() def test_delayed_optim_step(self): From 8dcf0a8db6944216aa7a7693f87b899a576c1d23 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 19 Mar 2021 20:41:00 -0400 Subject: [PATCH 09/31] remove oss changes --- .../fully_sharded_data_parallel.py | 99 +++++++++---------- fairscale/optim/oss.py | 4 +- tests/nn/data_parallel/test_fsdp.py | 44 ++++----- tests/optim/test_oss.py | 3 - 4 files changed, 68 insertions(+), 82 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index e722880d7..8c6cff09e 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1291,24 +1291,25 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: # Sync lr and other attributes in case its been updated from fairscale.optim import OSS - from fairscale.optim.utils import recursive_copy_to_device, broadcast_object - _default_device = torch.device('cuda') + from fairscale.optim.utils import broadcast_object, recursive_copy_to_device + + _default_device = torch.device("cuda") # OSS._sync_param_groups(self.param_groups, optim.param_groups) # Pull the sharded state from all the other replicas # Store all the states in order, rank by rank - #print("Pulling the sharded optimizer state from all replicas") + # print("Pulling the sharded optimizer state from all replicas") self._all_optimizer_states = [] should_collect_state = self.rank == recipient_rank or recipient_rank == -1 should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 - print(f'rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}') + print(f"rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}") for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() - sd['num_padded'] = self.num_padded # Communicate between ranks + sd["num_padded"] = self.num_padded # Communicate between ranks if should_collect_state: print(f"{rank} Saving self state keys {list(sd.keys())}") self._all_optimizer_states.append( @@ -1317,9 +1318,7 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: # Sync with other replicas state_to_share = ( - sd - if should_send_state - else torch.tensor([0], dtype=torch.uint8, device=_default_device) + sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device) ) broadcast_object( state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device, @@ -1362,7 +1361,7 @@ def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any depending on when `consolidate_state_dict` was last called. 
""" if not self.flatten_parameters: - raise NotImplementedError('optim state dict requires flatten_parameters=True') + raise NotImplementedError("optim state dict requires flatten_parameters=True") if not all_ranks and len(self._all_optimizer_states) == 0: raise RuntimeError( @@ -1376,38 +1375,37 @@ def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any # Unify the shard states by concatenating tensors and otherwise assuming rank zero is correct. sd0 = self._all_optimizer_states[0] - assert 'num_padded' in sd0 - all_num_padded = [s.pop('num_padded')[0] for s in self._all_optimizer_states] - assert all_num_padded[0] == 0, f'this code assumes rank 0 param not padded {all_num_padded[0]}' + assert "num_padded" in sd0 + all_num_padded = [s.pop("num_padded")[0] for s in self._all_optimizer_states] + assert all_num_padded[0] == 0, f"this code assumes rank 0 param not padded {all_num_padded[0]}" # - go through the per-shard states - assert len(sd0['param_groups']) == 1, 'not yet supported' + assert len(sd0["param_groups"]) == 1, "not yet supported" - if len(sd0['state']) == 0: + if len(sd0["state"]) == 0: # This is a stateless optimizer, like vanilla SGD. - sd0['param_groups'][0]['params'] = [0] + sd0["param_groups"][0]["params"] = [0] return sd0 - - for pg0 in sd0['param_groups']: + for pg0 in sd0["param_groups"]: for param_id in pg0["params"]: - sd0['state'][param_id] = {k: [v] for k,v in sd0['state'][param_id].items()} # so we can append + sd0["state"][param_id] = {k: [v] for k, v in sd0["state"][param_id].items()} # so we can append other_states = self._all_optimizer_states[1:] if self.world_size > 1 else [] for rank, s in enumerate(self._all_optimizer_states[1:]): for local_pg in s["param_groups"]: for local_param_index in local_pg["params"]: # Update the state, if any if local_param_index in s["state"].keys(): - for k in s['state'][local_param_index]: - new_entry = s['state'][local_param_index][k] - sd0['state'][local_param_index][k].append(new_entry) + for k in s["state"][local_param_index]: + new_entry = s["state"][local_param_index][k] + sd0["state"][local_param_index][k].append(new_entry) else: # OSS does not raise in this case, maybe we shouldn't either - raise KeyError(f'lost {local_param_index} from rank {rank}') + raise KeyError(f"lost {local_param_index} from rank {rank}") # Concatenate everything - for pg_id, pg0 in enumerate(sd0['param_groups']): + for pg_id, pg0 in enumerate(sd0["param_groups"]): n_params = 0 for param_id in pg0["params"]: assert param_id == 0 @@ -1416,22 +1414,24 @@ def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any # It might be assuming self.flatten_parameters=True constant_state = self.extract_constant_state(sd0, param_id) - for k,v in sd0['state'][param_id].items(): - assert isinstance(v, list), f'expected list, got {v}' + for k, v in sd0["state"][param_id].items(): + assert isinstance(v, list), f"expected list, got {v}" if k in constant_state: continue - def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v - v_unpad = [maybe_unpad(t, np) for t,np in zip(v, all_num_padded)] + def maybe_unpad(v, num_pad): + return v[:-num_pad] if num_pad > 0 else v + + v_unpad = [maybe_unpad(t, np) for t, np in zip(v, all_num_padded)] flat_buffer = torch.cat(v_unpad) flat_buffer = self.module.get_param_views(flat_buffer) for i, entry in enumerate(flat_buffer): - if i not in sd0['state']: - sd0['state'][i] = {} - sd0['state'][i][k] = entry - sd0['state'][i].update(constant_state.copy()) + if i not in 
sd0["state"]: + sd0["state"][i] = {} + sd0["state"][i][k] = entry + sd0["state"][i].update(constant_state.copy()) n_params = max(i, n_params) - sd0['param_groups'][pg_id]['params'] = list(range(n_params+1)) + sd0["param_groups"][pg_id]["params"] = list(range(n_params + 1)) # Make sure that the parameters are sorted in the state, as expected for a pytorch dict sd0["state"] = dict(sorted(sd0["state"].items())) @@ -1439,65 +1439,60 @@ def maybe_unpad(v, num_pad): return v[:-num_pad] if num_pad > 0 else v def extract_constant_state(self, sd0, param_id): constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. - for k, v in sd0['state'][param_id].items(): + for k, v in sd0["state"][param_id].items(): if torch.is_tensor(v[0]): continue elif len(set(v)) == 1: constant_state[k] = v[0] else: - raise ValueError(f'Dont know how to expand optimizer param {k} with value {v}') + raise ValueError(f"Dont know how to expand optimizer param {k} with value {v}") return constant_state def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: sd = full_optim_state_dict - if self.flatten_parameters: sd = self.flatten_optim_state_dict(sd) - assert len(sd['state']) == 1 - assert len(sd['param_groups'][0]['params']) == 1 + assert len(sd["state"]) == 1 + assert len(sd["param_groups"][0]["params"]) == 1 # get the portion of dict associated with the shard - for id, s in sd['state'].items(): - for k,v in s.items(): + for id, s in sd["state"].items(): + for k, v in s.items(): if torch.is_tensor(v): v_shard, _ = self._get_shard(v) else: v_shard = v # dont partition entries that are not tensors - sd['state'][id][k] = v_shard + sd["state"][id][k] = v_shard return sd - def flatten_optim_state_dict(self, sd) -> Dict: from collections import defaultdict flat_params = defaultdict(list) - constant_state = {}# self.extract_constant_state(sd, 0) - for _, buffers in sd['state'].items(): + constant_state = {} # self.extract_constant_state(sd, 0) + for _, buffers in sd["state"].items(): for k, p in buffers.items(): if torch.is_tensor(p): flat_params[k].append(p.reshape(-1)) else: assert isinstance(p, int) - constant_state[k] = p # THIS COULD BE WAY WRONG. What if step is different for different params... At least check. + constant_state[ + k + ] = p # THIS COULD BE WAY WRONG. What if step is different for different params... At least check. 
state = {0: constant_state} for k, v in flat_params.items(): state[0][k] = torch.cat(v) assert state[0][k].dim() == 1, state[0][k].dim() - sd['state'] = state - for pg_id, _ in enumerate(sd['param_groups']): - sd['param_groups'][pg_id]['params'] = list(range(1)) + sd["state"] = state + for pg_id, _ in enumerate(sd["param_groups"]): + sd["param_groups"][pg_id]["params"] = list(range(1)) return sd - - - - - @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: """ diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index 6dd81eaf0..9912660b6 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -326,7 +326,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None: self._all_states = [] should_collect_state = self.rank == recipient_rank or recipient_rank == -1 - should_send_state = (self.rank != recipient_rank) or recipient_rank == -1 + should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 for rank in range(self.world_size): if rank == self.rank: @@ -359,7 +359,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None: recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) ) - print(f"State from rank {rank} received: ") + logging.debug("State from rank %s received", rank) def local_state_dict(self) -> dict: """ .. deprecated:: 0.1.5 diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index b45993762..1dcc6fbb5 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -12,6 +12,7 @@ import unittest from unittest import mock +from fairseq.optim.cpu_adam import CPUAdam from parameterized import parameterized import torch from torch import nn @@ -19,6 +20,7 @@ from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper from fairscale.optim import AdaScale +from fairscale.optim.utils import recursive_copy_to_device from fairscale.utils.testing import ( DeviceAndTypeCheckModule, DummyProcessGroup, @@ -28,15 +30,14 @@ spawn_for_all_world_sizes, torch_version, ) -from fairscale.optim.utils import recursive_copy_to_device -from fairseq.optim.cpu_adam import CPUAdam # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 # All helper functions called by spawn must be either @classmethod, @staticmethod def assert_equal(a, b): - assert a == b, f'{a} != {b}' + assert a == b, f"{a} != {b}" + class DistributedTest(unittest.TestCase): def setUp(self): @@ -251,7 +252,6 @@ def test_cpu_offload_and_cpu_grads(self): ) spawn_and_init(test_fn) - def test_cpu_offload_and_cuda_grads_breaks(self): # If grads are on gpu, but model and optimizer are on cpu, backward breaks. 
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False} @@ -262,18 +262,16 @@ def test_cpu_offload_and_cuda_grads_breaks(self): spawn_and_init(test_fn) @parameterized.expand( - [[functools.partial(torch.optim.SGD, momentum=0.9)], - [torch.optim.SGD], [torch.optim.Adam], [CPUAdam]], - name_func=rename_test) + [[functools.partial(torch.optim.SGD, momentum=0.9)], [torch.optim.SGD], [torch.optim.Adam], [CPUAdam]], + name_func=rename_test, + ) def test_consolidate_optimizer(self, optim_fn): config = {"mixed_precision": True} - test_fn = functools.partial( - self._test_consolidated_optimizer, config, optim_fn=optim_fn - ) + test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=optim_fn) spawn_and_init(test_fn) @classmethod - def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, optim_fn=torch.optim.SGD): + def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, optim_fn=torch.optim.SGD): """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" # Establish reference behavior. fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) @@ -281,17 +279,16 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op try: fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) - except TypeError: # AdaScale + except TypeError: # AdaScale fsdp_optim = optim_fn(fsdp.parameters()) optim_unwrapped = optim_fn(unwrapped_model.parameters()) - fsdp_optim.zero_grad() optim_unwrapped.zero_grad() src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) output = fsdp(src_ids, tgt_ids) - loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to('cuda') + loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to("cuda") fsdp.module.run_backward(loss) fsdp_optim.step() fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) @@ -299,8 +296,6 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op if rank > 0 or fsdp.world_size == 1: return - - output = unwrapped_model(src_ids, tgt_ids) loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) unwrapped_model.run_backward(loss) @@ -309,17 +304,17 @@ def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, op n_pars = len(list(unwrapped_model.parameters())) assert len(fsdp._all_optimizer_states) == fsdp.world_size - torch.save(fsdp._all_optimizer_states, f'all_optim_states_world_size_{fsdp.world_size}.pt') + torch.save(fsdp._all_optimizer_states, f"all_optim_states_world_size_{fsdp.world_size}.pt") sd = fsdp.gather_full_optim_state_dict() - torch.save(sd, f'fsdp_consolidated_{fsdp.world_size}.pt') + torch.save(sd, f"fsdp_consolidated_{fsdp.world_size}.pt") - assert_equal(len(sd['state']), len(unwrapped_sd['state'])) - assert_equal(len(sd['param_groups'][0]['params']), len(unwrapped_sd['param_groups'][0]['params'])) + assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) + assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) shard_sd = fsdp.get_shard_from_optim_state_dict(sd) - assert objects_are_equal(shard_sd, recursive_copy_to_device(fsdp_optim.state_dict(), non_blocking=False, device='cpu')) - - + assert objects_are_equal( + shard_sd, recursive_copy_to_device(fsdp_optim.state_dict(), non_blocking=False, device="cpu") + ) def test_delayed_optim_step(self): # We use a model with a long CUDA delay right before the optimizer step. 
@@ -454,12 +449,11 @@ def _test_param_change_after_init(self, rank, group, config): assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init" - def test_named_params_ordering(self): """Test assumption of consolidate_optimizer_state_dict""" group = DummyProcessGroup(0, 1) model = TransformerWithSharedParams(group) - named_pars = [p for n,p in model.named_parameters()] + named_pars = [p for n, p in model.named_parameters()] for i, p in enumerate(model.parameters()): assert p.shape == named_pars[i].shape diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index 089e7bdb1..87ff6056f 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -442,9 +442,6 @@ def closure(): # Update the optimizer state on the reference rank optimizer.consolidate_state_dict(recipient_rank=reference_rank) - # if rank == reference_rank: - # #types = [type(x) for x in optimizer._all_states] - # #assert all(isinstance(s, dict) for s in optimizer._all_states), types # Fetch the state on the reference rank # - check that it has the correct size From 2caf9287f2f0a164b317fcda3ed480ff9f1fa918 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 20 Mar 2021 14:42:55 -0400 Subject: [PATCH 10/31] passing besides mypy --- .../fully_sharded_data_parallel.py | 119 ++++++++---------- tests/nn/data_parallel/test_fsdp.py | 58 --------- .../test_fsdp_optimizer_utils.py | 72 +++++++++++ 3 files changed, 122 insertions(+), 127 deletions(-) create mode 100644 tests/nn/data_parallel/test_fsdp_optimizer_utils.py diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 8c6cff09e..787c617fc 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from collections import defaultdict import contextlib import copy from enum import Enum, auto @@ -21,7 +22,7 @@ from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap -from fairscale.optim.utils import calc_grad_norm +from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device from fairscale.utils.containers import apply_to_tensors from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer @@ -179,7 +180,6 @@ def __init__( self.buffer_dtype = buffer_dtype or self.compute_dtype self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb - self._all_optimizer_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") @@ -1280,7 +1280,9 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None raise ValueError(msg) # Optim State dict interfaces - def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: + def consolidate_optim_state_dict( + self, optim: torch.optim.Optimizer, recipient_rank: int = 0 + ) -> List[Dict[str, Any]]: """Update the consolidated state_dict list, one per rank. 
Arguments: @@ -1290,29 +1292,24 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: .. warning: This needs to be called on all replicas""" # Sync lr and other attributes in case its been updated - from fairscale.optim import OSS - from fairscale.optim.utils import broadcast_object, recursive_copy_to_device _default_device = torch.device("cuda") - + # NOTE(SS): we do not support param groups yet, as these seem to break FSDP # OSS._sync_param_groups(self.param_groups, optim.param_groups) # Pull the sharded state from all the other replicas # Store all the states in order, rank by rank - # print("Pulling the sharded optimizer state from all replicas") - - self._all_optimizer_states = [] should_collect_state = self.rank == recipient_rank or recipient_rank == -1 should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 print(f"rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}") - + _all_optimizer_states: List[Dict[str, Any]] = [] for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() sd["num_padded"] = self.num_padded # Communicate between ranks if should_collect_state: print(f"{rank} Saving self state keys {list(sd.keys())}") - self._all_optimizer_states.append( + _all_optimizer_states.append( recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu")) ) @@ -1333,20 +1330,20 @@ def consolidate_optim_state_dict(self, optim, recipient_rank: int = 0) -> None: ) if should_collect_state: - self._all_optimizer_states.append( + _all_optimizer_states.append( recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) ) - print(f"State from rank {rank} received: {self._all_optimizer_states[-1]}") + print(f"State from rank {rank} received: {_all_optimizer_states[-1]}") + return _all_optimizer_states - def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any]: + def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> Dict[str, Any]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the sharded properties are not exposed. Arguments: all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on - all ranks Returns: a dict with two entries @@ -1363,35 +1360,22 @@ def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any if not self.flatten_parameters: raise NotImplementedError("optim state dict requires flatten_parameters=True") - if not all_ranks and len(self._all_optimizer_states) == 0: - raise RuntimeError( - "Optimizer state has not been consolidated on this rank. \ - Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state" - ) - - if all_ranks: - # Consolidate the state on every rank - self.consolidate_state_dict(recipient_rank=-1) + _all_optimizer_states: List[Dict] = self.consolidate_optim_state_dict(optim, recipient_rank=recipient_rank) - # Unify the shard states by concatenating tensors and otherwise assuming rank zero is correct. 
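The two predicates above decide each rank's role during consolidation; pulled out of the distributed plumbing they behave as follows (plain-Python restatement, no collectives):

def consolidation_roles(rank, recipient_rank):
    # Mirror of the should_collect_state / should_send_state logic above.
    should_collect = rank == recipient_rank or recipient_rank == -1
    should_send = (rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
    return should_collect, should_send

# recipient_rank=0: rank 0 only collects, every other rank only sends.
assert consolidation_roles(0, 0) == (True, False)
assert consolidation_roles(1, 0) == (False, True)
# recipient_rank=-1: every rank both sends and collects, so all replicas end up with the state.
assert consolidation_roles(0, -1) == (True, True)
assert consolidation_roles(1, -1) == (True, True)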
- sd0 = self._all_optimizer_states[0] + # Unify the shard states by concatenating tensors and unflattening params + sd0 = _all_optimizer_states[0] assert "num_padded" in sd0 - all_num_padded = [s.pop("num_padded")[0] for s in self._all_optimizer_states] + all_num_padded = [s.pop("num_padded")[0] for s in _all_optimizer_states] assert all_num_padded[0] == 0, f"this code assumes rank 0 param not padded {all_num_padded[0]}" - - # - go through the per-shard states assert len(sd0["param_groups"]) == 1, "not yet supported" - if len(sd0["state"]) == 0: - # This is a stateless optimizer, like vanilla SGD. - sd0["param_groups"][0]["params"] = [0] - return sd0 - for pg0 in sd0["param_groups"]: for param_id in pg0["params"]: - sd0["state"][param_id] = {k: [v] for k, v in sd0["state"][param_id].items()} # so we can append - other_states = self._all_optimizer_states[1:] if self.world_size > 1 else [] - for rank, s in enumerate(self._all_optimizer_states[1:]): + if param_id in sd0["state"]: + sd0["state"][param_id] = {k: [v] for k, v in sd0["state"][param_id].items()} # so we can append + + other_states = _all_optimizer_states[1:] if self.world_size > 1 else [] + for rank, s in enumerate(other_states): for local_pg in s["param_groups"]: for local_param_index in local_pg["params"]: # Update the state, if any @@ -1399,45 +1383,43 @@ def gather_full_optim_state_dict(self, all_ranks: bool = False) -> Dict[str, Any for k in s["state"][local_param_index]: new_entry = s["state"][local_param_index][k] sd0["state"][local_param_index][k].append(new_entry) - else: - - # OSS does not raise in this case, maybe we shouldn't either - raise KeyError(f"lost {local_param_index} from rank {rank}") # Concatenate everything for pg_id, pg0 in enumerate(sd0["param_groups"]): n_params = 0 for param_id in pg0["params"]: + if param_id not in sd0["state"]: + continue assert param_id == 0 - # This attempts to undo the work of shard_parameters. - # It might be assuming self.flatten_parameters=True - constant_state = self.extract_constant_state(sd0, param_id) - + # It assumes self.flatten_parameters=True + constant_state = self._extract_constant_state(sd0, param_id) for k, v in sd0["state"][param_id].items(): assert isinstance(v, list), f"expected list, got {v}" if k in constant_state: continue - def maybe_unpad(v, num_pad): - return v[:-num_pad] if num_pad > 0 else v - - v_unpad = [maybe_unpad(t, np) for t, np in zip(v, all_num_padded)] + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, all_num_padded)] flat_buffer = torch.cat(v_unpad) flat_buffer = self.module.get_param_views(flat_buffer) for i, entry in enumerate(flat_buffer): if i not in sd0["state"]: sd0["state"][i] = {} sd0["state"][i][k] = entry - sd0["state"][i].update(constant_state.copy()) + sd0["state"][i].update(constant_state) n_params = max(i, n_params) - sd0["param_groups"][pg_id]["params"] = list(range(n_params + 1)) + + if n_params > 0: + sd0["param_groups"][pg_id]["params"] = list(range(n_params + 1)) + else: + sd0["param_groups"][pg_id]["params"] = list(range(len(self.module._param_infos))) # Make sure that the parameters are sorted in the state, as expected for a pytorch dict sd0["state"] = dict(sorted(sd0["state"].items())) return sd0 - def extract_constant_state(self, sd0, param_id): + @staticmethod + def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. 
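What get_param_views has to do with the concatenated flat buffer above can be shown without FlattenParamsWrapper; the shapes below are made up for illustration:

import torch

def get_param_views(flat_buffer, shapes):
    # Split one flat 1-D buffer back into views with the original parameter shapes.
    numels = [int(torch.Size(s).numel()) for s in shapes]
    assert flat_buffer.numel() == sum(numels)
    return (chunk.view(s) for chunk, s in zip(flat_buffer.split(numels), shapes))

shapes = [(4, 3), (3,), (2, 2)]
flat = torch.randn(sum(torch.Size(s).numel() for s in shapes))
views = list(get_param_views(flat, shapes))
assert [tuple(v.shape) for v in views] == shapes
assert torch.equal(views[0].reshape(-1), flat[:12])  # views share the flat buffer's data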
for k, v in sd0["state"][param_id].items(): if torch.is_tensor(v[0]): @@ -1445,31 +1427,28 @@ def extract_constant_state(self, sd0, param_id): elif len(set(v)) == 1: constant_state[k] = v[0] else: - raise ValueError(f"Dont know how to expand optimizer param {k} with value {v}") + raise TypeError(f"Dont know how to expand optimizer param {k} with value {v}") return constant_state - def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict: - sd = full_optim_state_dict - + def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict[str, Any]: if self.flatten_parameters: - sd = self.flatten_optim_state_dict(sd) - assert len(sd["state"]) == 1 - assert len(sd["param_groups"][0]["params"]) == 1 + full_optim_state_dict = self.flatten_optim_state_dict(full_optim_state_dict) + assert len(full_optim_state_dict["state"]) <= 1 + assert len(full_optim_state_dict["param_groups"][0]["params"]) == 1 # get the portion of dict associated with the shard - for id, s in sd["state"].items(): + for id, s in full_optim_state_dict["state"].items(): for k, v in s.items(): if torch.is_tensor(v): v_shard, _ = self._get_shard(v) else: v_shard = v # dont partition entries that are not tensors - sd["state"][id][k] = v_shard - - return sd + full_optim_state_dict["state"][id][k] = v_shard - def flatten_optim_state_dict(self, sd) -> Dict: - from collections import defaultdict + return full_optim_state_dict + @staticmethod + def flatten_optim_state_dict(sd: Dict) -> Dict: flat_params = defaultdict(list) constant_state = {} # self.extract_constant_state(sd, 0) for _, buffers in sd["state"].items(): @@ -1478,15 +1457,17 @@ def flatten_optim_state_dict(self, sd) -> Dict: flat_params[k].append(p.reshape(-1)) else: assert isinstance(p, int) - constant_state[ - k - ] = p # THIS COULD BE WAY WRONG. What if step is different for different params... At least check. + constant_state[k] = p + # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... 
At least check state = {0: constant_state} for k, v in flat_params.items(): state[0][k] = torch.cat(v) assert state[0][k].dim() == 1, state[0][k].dim() - sd["state"] = state + + # Do not put empty state for stateless optimizers + sd["state"] = state if state != {0: {}} else {} + for pg_id, _ in enumerate(sd["param_groups"]): sd["param_groups"][pg_id]["params"] = list(range(1)) diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 1dcc6fbb5..564a062d5 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -12,15 +12,12 @@ import unittest from unittest import mock -from fairseq.optim.cpu_adam import CPUAdam from parameterized import parameterized import torch from torch import nn from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper -from fairscale.optim import AdaScale -from fairscale.optim.utils import recursive_copy_to_device from fairscale.utils.testing import ( DeviceAndTypeCheckModule, DummyProcessGroup, @@ -261,61 +258,6 @@ def test_cpu_offload_and_cuda_grads_breaks(self): ) spawn_and_init(test_fn) - @parameterized.expand( - [[functools.partial(torch.optim.SGD, momentum=0.9)], [torch.optim.SGD], [torch.optim.Adam], [CPUAdam]], - name_func=rename_test, - ) - def test_consolidate_optimizer(self, optim_fn): - config = {"mixed_precision": True} - test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=optim_fn) - spawn_and_init(test_fn) - - @classmethod - def _test_consolidated_optimizer(self, config, rank, group, lr_groups=False, optim_fn=torch.optim.SGD): - """FSDP.optim_state_dict() should return something very similar to optimizer.state_dict()""" - # Establish reference behavior. 
- fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) - unwrapped_model = TransformerWithSharedParams(group).cuda() - try: - fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) - optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) - except TypeError: # AdaScale - fsdp_optim = optim_fn(fsdp.parameters()) - optim_unwrapped = optim_fn(unwrapped_model.parameters()) - - fsdp_optim.zero_grad() - optim_unwrapped.zero_grad() - - src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) - output = fsdp(src_ids, tgt_ids) - loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to("cuda") - fsdp.module.run_backward(loss) - fsdp_optim.step() - fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) - - if rank > 0 or fsdp.world_size == 1: - return - - output = unwrapped_model(src_ids, tgt_ids) - loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) - unwrapped_model.run_backward(loss) - optim_unwrapped.step() - unwrapped_sd = optim_unwrapped.state_dict() - - n_pars = len(list(unwrapped_model.parameters())) - assert len(fsdp._all_optimizer_states) == fsdp.world_size - torch.save(fsdp._all_optimizer_states, f"all_optim_states_world_size_{fsdp.world_size}.pt") - sd = fsdp.gather_full_optim_state_dict() - torch.save(sd, f"fsdp_consolidated_{fsdp.world_size}.pt") - - assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) - assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) - - shard_sd = fsdp.get_shard_from_optim_state_dict(sd) - assert objects_are_equal( - shard_sd, recursive_copy_to_device(fsdp_optim.state_dict(), non_blocking=False, device="cpu") - ) - def test_delayed_optim_step(self): # We use a model with a long CUDA delay right before the optimizer step. # This tests our streams logic, and that we don't start the FP32 -> FP16 diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py new file mode 100644 index 000000000..ee5722eef --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -0,0 +1,72 @@ +import functools + +from parameterized import parameterized +import torch + +from fairscale.optim.utils import recursive_copy_to_device +from fairscale.utils.testing import objects_are_equal + +from .test_fsdp import DistributedTest, TransformerWithSharedParams, assert_equal, rename_test, spawn_and_init + + +class TestOptimizerUtils(DistributedTest): + @parameterized.expand( + [ + [functools.partial(torch.optim.SGD, momentum=0.9)], + [torch.optim.SGD], + [torch.optim.Adam], + [torch.optim.Adadelta], + ], + name_func=rename_test, + ) + def test_consolidate_optimizer(self, optim_fn): + config = {"mixed_precision": True} + test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=optim_fn) + spawn_and_init(test_fn) + + @classmethod + def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD): + """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()""" + # Establish reference behavior. 
+ fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) + unwrapped_model = TransformerWithSharedParams(group).cuda() + try: + fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) + optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) + except TypeError: # AdaScale + fsdp_optim = optim_fn(fsdp.parameters()) + optim_unwrapped = optim_fn(unwrapped_model.parameters()) + + fsdp_optim.zero_grad() + optim_unwrapped.zero_grad() + + src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) + output = fsdp(src_ids, tgt_ids) + loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to("cuda") + fsdp.module.run_backward(loss) + fsdp_optim.step() + # fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) + + output = unwrapped_model(src_ids, tgt_ids) + loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) + unwrapped_model.run_backward(loss) + optim_unwrapped.step() + unwrapped_sd = optim_unwrapped.state_dict() + + n_pars = len(list(unwrapped_model.parameters())) + + # torch.save(fsdp._all_optimizer_states, f"all_optim_states_world_size_{fsdp.world_size}.pt") + sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=-1) + # assert_equal(len(fsdp._all_optimizer_states), fsdp.world_size) + torch.save(sd, f"fsdp_consolidated_{fsdp.world_size}.pt") + + assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) + assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) + + shard_sd = fsdp.get_shard_from_optim_state_dict(sd) + + original_shard_sd = fsdp_optim.state_dict() + assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"])) + assert objects_are_equal( + shard_sd, recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") + ) From 0b888fd9f80e6f8a315df38fc01428ef24a02a7a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 20 Mar 2021 14:44:35 -0400 Subject: [PATCH 11/31] Smaller delta --- tests/nn/data_parallel/test_fsdp.py | 14 +------------- .../nn/data_parallel/test_fsdp_optimizer_utils.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py index 564a062d5..9813d1142 100644 --- a/tests/nn/data_parallel/test_fsdp.py +++ b/tests/nn/data_parallel/test_fsdp.py @@ -32,10 +32,6 @@ # All helper functions called by spawn must be either @classmethod, @staticmethod -def assert_equal(a, b): - assert a == b, f"{a} != {b}" - - class DistributedTest(unittest.TestCase): def setUp(self): if torch_version() < (1, 6, 0): @@ -48,7 +44,7 @@ def setUp(self): raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") @staticmethod - def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None, save_optim=False): + def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None): model_device = next(model.parameters()).device # use SGD with momentum instead of Adam, since Adam is scale invariant # and this makes it bad for tests @@ -391,14 +387,6 @@ def _test_param_change_after_init(self, rank, group, config): assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init" - def test_named_params_ordering(self): - """Test assumption of consolidate_optimizer_state_dict""" - group = DummyProcessGroup(0, 1) - model = TransformerWithSharedParams(group) - named_pars = [p for n, p in model.named_parameters()] - for i, p in enumerate(model.parameters()): - assert p.shape == named_pars[i].shape - 
class TestSerialization(DistributedTest): @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test) diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index ee5722eef..81feff031 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -6,7 +6,11 @@ from fairscale.optim.utils import recursive_copy_to_device from fairscale.utils.testing import objects_are_equal -from .test_fsdp import DistributedTest, TransformerWithSharedParams, assert_equal, rename_test, spawn_and_init +from .test_fsdp import DistributedTest, DummyProcessGroup, TransformerWithSharedParams, rename_test, spawn_and_init + + +def assert_equal(a, b): + assert a == b, f"{a} != {b}" class TestOptimizerUtils(DistributedTest): @@ -62,7 +66,6 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) - shard_sd = fsdp.get_shard_from_optim_state_dict(sd) original_shard_sd = fsdp_optim.state_dict() @@ -70,3 +73,11 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim assert objects_are_equal( shard_sd, recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") ) + + def test_named_params_ordering(self): + """Test assumption of consolidate_optimizer_state_dict""" + group = DummyProcessGroup(0, 1) + model = TransformerWithSharedParams(group) + named_pars = [p for n, p in model.named_parameters()] + for i, p in enumerate(model.parameters()): + assert p.shape == named_pars[i].shape From a2aacd064bf3509ba51e26f4a1884a55c050c67b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 21 Mar 2021 15:31:03 -0400 Subject: [PATCH 12/31] Nesting works --- .../fully_sharded_data_parallel.py | 204 +++++++++++------- .../test_fsdp_optimizer_utils.py | 83 +++++-- 2 files changed, 188 insertions(+), 99 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 787c617fc..de01d9807 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -237,7 +237,7 @@ def __init__( self._return_full_state_dict = True @property - def module(self) -> nn.Module: + def module(self) -> Union[FlattenParamsWrapper, nn.Module]: return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance @torch.no_grad() @@ -1279,10 +1279,18 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None traceback.print_stack() raise ValueError(msg) - # Optim State dict interfaces - def consolidate_optim_state_dict( - self, optim: torch.optim.Optimizer, recipient_rank: int = 0 - ) -> List[Dict[str, Any]]: + # Optim State dict functions + def get_num_padded_from_children(self) -> List[List[int]]: + np = [] + for name, module in self.named_modules(): + assert name not in np + is_self = name == "" + if not isinstance(module, FullyShardedDataParallel): + continue + np.append(module.num_padded) + return np + + def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> None: """Update the consolidated state_dict list, one per rank. 
Arguments: @@ -1290,25 +1298,19 @@ def consolidate_optim_state_dict( -1 is a special value, which means that all ranks should have the state .. warning: This needs to be called on all replicas""" - - # Sync lr and other attributes in case its been updated - _default_device = torch.device("cuda") # NOTE(SS): we do not support param groups yet, as these seem to break FSDP - # OSS._sync_param_groups(self.param_groups, optim.param_groups) # Pull the sharded state from all the other replicas # Store all the states in order, rank by rank should_collect_state = self.rank == recipient_rank or recipient_rank == -1 should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 - print(f"rank: {self.rank}, should_collect: {should_collect_state}, should_send {should_send_state}") _all_optimizer_states: List[Dict[str, Any]] = [] for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() - sd["num_padded"] = self.num_padded # Communicate between ranks + sd["num_padded"] = self.get_num_padded_from_children() # Communicate between ranks if should_collect_state: - print(f"{rank} Saving self state keys {list(sd.keys())}") _all_optimizer_states.append( recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu")) ) @@ -1334,10 +1336,9 @@ def consolidate_optim_state_dict( recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) ) - print(f"State from rank {rank} received: {_all_optimizer_states[-1]}") - return _all_optimizer_states + self._all_optimizer_states = _all_optimizer_states - def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> Dict[str, Any]: + def gather_full_optim_state_dict(self) -> Dict[str, Any]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the sharded properties are not exposed. 
@@ -1359,69 +1360,98 @@ def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_r """ if not self.flatten_parameters: raise NotImplementedError("optim state dict requires flatten_parameters=True") - - _all_optimizer_states: List[Dict] = self.consolidate_optim_state_dict(optim, recipient_rank=recipient_rank) + if len(self._all_optimizer_states) == 0: + raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict") # Unify the shard states by concatenating tensors and unflattening params - sd0 = _all_optimizer_states[0] - assert "num_padded" in sd0 - all_num_padded = [s.pop("num_padded")[0] for s in _all_optimizer_states] - assert all_num_padded[0] == 0, f"this code assumes rank 0 param not padded {all_num_padded[0]}" - assert len(sd0["param_groups"]) == 1, "not yet supported" + combined_state = self._all_optimizer_states[0] + assert "num_padded" in combined_state + world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states] + assert len(combined_state["param_groups"]) == 1, "not yet supported" + + for param_id in combined_state["state"]: + combined_state["state"][param_id] = { + k: [v] for k, v in combined_state["state"][param_id].items() + } # so we can append + other_states = self._all_optimizer_states[1:] if self.world_size > 1 else [] + constant_state = [self._extract_constant_state(combined_state, pid) for pid in combined_state["state"]] + for rank, s in enumerate(other_states): + for param_id in s["state"]: + # Update the state, if any + for k in s["state"][param_id]: + new_entry = s["state"][param_id][k] + combined_state["state"][param_id][k].append(new_entry) + + # cleanup all_optimizer_states_list + self._all_optimizer_states = [] + + new_state_dict = {"state": {}, "param_groups": copy.deepcopy(combined_state["param_groups"])} + instance_list: List[FullyShardedDataParallel] = [ + m for m in self.modules() if isinstance(m, FullyShardedDataParallel) + ] + numels_flatten = [sum(m._param_numels) for m in instance_list] - for pg0 in sd0["param_groups"]: - for param_id in pg0["params"]: - if param_id in sd0["state"]: - sd0["state"][param_id] = {k: [v] for k, v in sd0["state"][param_id].items()} # so we can append + # loop over parameters in state. + # If they are tensors, unpad each rank properly, concatenate it, and then + # call _get_param_views. This returns multiple tensors + # each of which is a new parameter with a new "global" id. - other_states = _all_optimizer_states[1:] if self.world_size > 1 else [] - for rank, s in enumerate(other_states): - for local_pg in s["param_groups"]: - for local_param_index in local_pg["params"]: - # Update the state, if any - if local_param_index in s["state"].keys(): - for k in s["state"][local_param_index]: - new_entry = s["state"][local_param_index][k] - sd0["state"][local_param_index][k].append(new_entry) - - # Concatenate everything - for pg_id, pg0 in enumerate(sd0["param_groups"]): - n_params = 0 + local_to_global_param_id: Dict[ + int, List[int] + ] = {} # local ids are after in the current state, global_ids will be in returned state. + + global_param_id = 0 # gets incremented + for pg_id, pg0 in enumerate(combined_state["param_groups"]): for param_id in pg0["params"]: - if param_id not in sd0["state"]: + local_to_global_param_id[param_id] = [] + if param_id not in combined_state["state"]: continue - assert param_id == 0 - # This attempts to undo the work of shard_parameters. 
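The id bookkeeping introduced here is easier to see with toy numbers (two flattened instances covering 3 and 2 original parameters, an assumption for illustration): each local flat-param id fans out into several global per-parameter ids, and the reverse map is what later gets stored as param_id_map.

local_param_counts = {0: 3, 1: 2}  # local flat-param id -> number of original params it covers

local_to_global, global_to_local, next_gid = {}, {}, 0
for local_id, n_params in local_param_counts.items():
    local_to_global[local_id] = list(range(next_gid, next_gid + n_params))
    for gid in local_to_global[local_id]:
        global_to_local[gid] = local_id
    next_gid += n_params

assert local_to_global == {0: [0, 1, 2], 1: [3, 4]}
assert global_to_local == {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}  # the "param_id_map"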
- # It assumes self.flatten_parameters=True - constant_state = self._extract_constant_state(sd0, param_id) - for k, v in sd0["state"][param_id].items(): - assert isinstance(v, list), f"expected list, got {v}" - if k in constant_state: + # undo the work of shard_parameters + for k, v in combined_state["state"][param_id].items(): + if k in constant_state[param_id]: continue - - v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, all_num_padded)] + assert isinstance(v, list), f"expected list, got {k}:{v} for {param_id} at rank {self.rank}" + assert all(len(s[param_id]) == 1 for s in world_pad_info) + pad_info = [s[param_id][0] for s in world_pad_info] + assert len(pad_info) == self.world_size == len(v), f"{len(pad_info), self.world_size, len(v)}" + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] flat_buffer = torch.cat(v_unpad) - flat_buffer = self.module.get_param_views(flat_buffer) - for i, entry in enumerate(flat_buffer): - if i not in sd0["state"]: - sd0["state"][i] = {} - sd0["state"][i][k] = entry - sd0["state"][i].update(constant_state) - n_params = max(i, n_params) - - if n_params > 0: - sd0["param_groups"][pg_id]["params"] = list(range(n_params + 1)) + assert ( + numels_flatten[param_id] == flat_buffer.shape[0] == flat_buffer.numel() + ), f"{numels_flatten[param_id]},{flat_buffer.shape[0]}, {flat_buffer.numel()}" + param_views: Generator[torch.Tensor] = instance_list[param_id].get_param_views(flat_buffer) + for i, param_view in enumerate(param_views): + if i == len(local_to_global_param_id[param_id]): + # We have not seen this global param before, and make a new ID + local_to_global_param_id[param_id].append(global_param_id) + global_param_id += 1 + cur_pid = local_to_global_param_id[param_id][i] + if cur_pid not in new_state_dict["state"]: + new_state_dict["state"][cur_pid] = copy.deepcopy(constant_state[param_id]) + assert k not in new_state_dict["state"][cur_pid] + new_state_dict["state"][cur_pid][k] = param_view + + if global_param_id == 0: # stateless optimizer + num_params = sum([len(m._param_numels) for m in instance_list]) + new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params)) else: - sd0["param_groups"][pg_id]["params"] = list(range(len(self.module._param_infos))) + new_state_dict["param_groups"][pg_id]["params"] = list(range(global_param_id)) + + global_to_local_id = {} + for old_pid, global_param_id in local_to_global_param_id.items(): + for nid in global_param_id: + global_to_local_id[nid] = old_pid + new_state_dict["param_id_map"] = global_to_local_id # Make sure that the parameters are sorted in the state, as expected for a pytorch dict - sd0["state"] = dict(sorted(sd0["state"].items())) - return sd0 + new_state_dict["state"] = dict(sorted(new_state_dict["state"].items())) + return new_state_dict @staticmethod def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. 
for k, v in sd0["state"][param_id].items(): + if torch.is_tensor(v[0]): continue elif len(set(v)) == 1: @@ -1431,10 +1461,12 @@ def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: return constant_state def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict[str, Any]: + self.validate_nesting_unchanged(full_optim_state_dict) + stateless = len(full_optim_state_dict["state"]) == 0 + mlist = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] if self.flatten_parameters: full_optim_state_dict = self.flatten_optim_state_dict(full_optim_state_dict) - assert len(full_optim_state_dict["state"]) <= 1 - assert len(full_optim_state_dict["param_groups"][0]["params"]) == 1 + assert stateless or len(full_optim_state_dict["state"]) == len(mlist) # get the portion of dict associated with the shard for id, s in full_optim_state_dict["state"].items(): @@ -1447,31 +1479,47 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict[str, An return full_optim_state_dict + def validate_nesting_unchanged(self, sd): + # This should be removed, could be wasteful + if not sd["state"]: + return + n_parameters = len(list(self.parameters())) + n_params_in_optimizer = len(set(sd["param_id_map"].values())) + essay = f"Including itself, this model has {n_parameters} nested instances. When the optimizer state was saved, however, there were only {n_params_in_optimizer}" + assert n_parameters == n_params_in_optimizer, essay + @staticmethod def flatten_optim_state_dict(sd: Dict) -> Dict: - flat_params = defaultdict(list) + param_id_map = sd["param_id_map"] + npars_final = len(set(param_id_map.values())) + if sd["state"]: + new_state = {consolidated_pid: defaultdict(list) for consolidated_pid in range(npars_final)} + else: + new_state = {} constant_state = {} # self.extract_constant_state(sd, 0) - for _, buffers in sd["state"].items(): - for k, p in buffers.items(): + + # assumes sd sorted + for expanded_pid, buffers in sd["state"].items(): + consolidated_pid = param_id_map[expanded_pid] + for buffer_name, p in buffers.items(): if torch.is_tensor(p): - flat_params[k].append(p.reshape(-1)) + new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) else: assert isinstance(p, int) - constant_state[k] = p + constant_state[buffer_name] = p # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... 
At least check - - state = {0: constant_state} - for k, v in flat_params.items(): - state[0][k] = torch.cat(v) - assert state[0][k].dim() == 1, state[0][k].dim() + new_state = {k: dict(v) for k, v in new_state.items()} + for consolidated_pid, state in new_state.items(): + for buffer_name, tensors in state.items(): + new_state[consolidated_pid][buffer_name] = torch.cat(tensors) + new_state[consolidated_pid].update(constant_state) + new_sd = {"state": new_state, "param_groups": sd["param_groups"]} # Do not put empty state for stateless optimizers - sd["state"] = state if state != {0: {}} else {} - for pg_id, _ in enumerate(sd["param_groups"]): - sd["param_groups"][pg_id]["params"] = list(range(1)) + new_sd["param_groups"][pg_id]["params"] = list(range(npars_final)) - return sd + return new_sd @torch.no_grad() diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 81feff031..69b32f625 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -3,10 +3,25 @@ from parameterized import parameterized import torch +from fairscale.nn import FullyShardedDataParallel from fairscale.optim.utils import recursive_copy_to_device from fairscale.utils.testing import objects_are_equal -from .test_fsdp import DistributedTest, DummyProcessGroup, TransformerWithSharedParams, rename_test, spawn_and_init +from .test_fsdp import ( + DistributedTest, + DummyProcessGroup, + NestedWrappedModule, + TransformerWithSharedParams, + rename_test, + spawn_and_init, +) + + +def first_tensor_shape(dct): + for k, v in dct.items(): + if torch.is_tensor(v): + return v.numel() + raise ValueError("found no tensors") def assert_equal(a, b): @@ -16,24 +31,32 @@ def assert_equal(a, b): class TestOptimizerUtils(DistributedTest): @parameterized.expand( [ - [functools.partial(torch.optim.SGD, momentum=0.9)], - [torch.optim.SGD], - [torch.optim.Adam], - [torch.optim.Adadelta], + [functools.partial(torch.optim.SGD, momentum=0.9), False], + [torch.optim.SGD, False], + [torch.optim.Adam, False], + [torch.optim.Adadelta, True], ], name_func=rename_test, ) - def test_consolidate_optimizer(self, optim_fn): - config = {"mixed_precision": True} - test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=optim_fn) + def test_consolidate_optimizer(self, optim_fn, transformer): + config = {"mixed_precision": True, "flatten_parameters": True} + test_fn = functools.partial( + self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer + ) spawn_and_init(test_fn) @classmethod - def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD): + def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False): """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()""" # Establish reference behavior. 
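A condensed restatement of the "flatten back" direction implemented in flatten_optim_state_dict above, with toy tensors standing in for a real optimizer's buffers: tensor entries are regrouped under their local flat-param id via param_id_map and concatenated, while scalar entries such as Adam's step are carried over unchanged.

import torch

def flatten_state(full_state, param_id_map):
    # Regroup per-parameter tensor entries under their local flat-param id and concatenate.
    flat = {local_id: {} for local_id in set(param_id_map.values())}
    scalars = {}
    for global_id, buffers in sorted(full_state.items()):
        local_id = param_id_map[global_id]
        for name, value in buffers.items():
            if torch.is_tensor(value):
                flat[local_id].setdefault(name, []).append(value.reshape(-1))
            else:
                scalars[name] = value  # e.g. Adam's "step"; assumed identical across params
    for buffers in flat.values():
        for name, pieces in buffers.items():
            buffers[name] = torch.cat(pieces)
        buffers.update(scalars)
    return flat

full = {0: {"exp_avg": torch.zeros(3), "step": 4}, 1: {"exp_avg": torch.ones(2), "step": 4}}
flat = flatten_state(full, param_id_map={0: 0, 1: 0})
assert flat[0]["exp_avg"].shape == (5,) and flat[0]["step"] == 4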
- fsdp = self.get_wrapped_model(group, cuda_first=False, config=config) - unwrapped_model = TransformerWithSharedParams(group).cuda() + + if transformer: + fsdp = self.get_wrapped_model(group, config=config).cuda() + unwrapped_model = TransformerWithSharedParams(group).cuda() + else: + fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda() + unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda() + try: fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) @@ -44,15 +67,14 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim fsdp_optim.zero_grad() optim_unwrapped.zero_grad() - src_ids, tgt_ids = fsdp.module.get_input(torch.device("cuda")) - output = fsdp(src_ids, tgt_ids) - loss = fsdp.module.get_loss((src_ids, tgt_ids), output).to("cuda") + x = fsdp.module.get_input(torch.device("cuda")) + output = fsdp(*x) + loss = fsdp.module.get_loss(x, output).to("cuda") fsdp.module.run_backward(loss) fsdp_optim.step() - # fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) - output = unwrapped_model(src_ids, tgt_ids) - loss = unwrapped_model.get_loss((src_ids, tgt_ids), output) + output = unwrapped_model(*x) + loss = unwrapped_model.get_loss(x, output) unwrapped_model.run_backward(loss) optim_unwrapped.step() unwrapped_sd = optim_unwrapped.state_dict() @@ -60,19 +82,38 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim n_pars = len(list(unwrapped_model.parameters())) # torch.save(fsdp._all_optimizer_states, f"all_optim_states_world_size_{fsdp.world_size}.pt") - sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=-1) + fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) + # first_key = unwrapped_sd['state'][0].keys() + + if rank > 0: + return + + sd = fsdp.gather_full_optim_state_dict() + # optim_par = sum(v['square_avg'].numel() for k, v in sd.items()) # assert_equal(len(fsdp._all_optimizer_states), fsdp.world_size) torch.save(sd, f"fsdp_consolidated_{fsdp.world_size}.pt") - assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) + assert_equal( + sum([first_tensor_shape(v) for k, v in sd["state"].items()]), + sum([first_tensor_shape(v) for k, v in unwrapped_sd["state"].items()]), + ) + shard_sd = fsdp.get_shard_from_optim_state_dict(sd) original_shard_sd = fsdp_optim.state_dict() assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"])) - assert objects_are_equal( - shard_sd, recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") + assert_equal(shard_sd.keys(), original_shard_sd.keys()) + torch.save(shard_sd, f"new_shard_{fsdp.world_size}.pt") + original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") + + assert_equal( + sum([first_tensor_shape(v) for k, v in shard_sd["state"].items()]), + sum([first_tensor_shape(v) for k, v in original_shard_sd["state"].items()]), ) + if shard_sd["state"]: + assert objects_are_equal(shard_sd["state"][0], original_shard_sd["state"][0]) + assert objects_are_equal(shard_sd["state"], original_shard_sd["state"]) def test_named_params_ordering(self): """Test assumption of consolidate_optimizer_state_dict""" From 0fc045d64e070b86d59a1a4fd3c725903602ea6b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 21 Mar 2021 17:36:45 -0400 Subject: [PATCH 13/31] passing, 
lint attempt --- .../fully_sharded_data_parallel.py | 133 +++++++++--------- .../test_fsdp_optimizer_utils.py | 8 +- 2 files changed, 67 insertions(+), 74 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index de01d9807..4dbc65d1a 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from collections import defaultdict import contextlib import copy from enum import Enum, auto @@ -180,6 +179,8 @@ def __init__( self.buffer_dtype = buffer_dtype or self.compute_dtype self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb + self.num_padded: List[int] = [] + self._all_optimizer_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state if self.fp32_reduce_scatter and not self.mixed_precision: raise ValueError("fp32_reduce_scatter requires mixed_precision=True") @@ -237,7 +238,7 @@ def __init__( self._return_full_state_dict = True @property - def module(self) -> Union[FlattenParamsWrapper, nn.Module]: + def module(self) -> nn.Module: return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance @torch.no_grad() @@ -1280,15 +1281,6 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None raise ValueError(msg) # Optim State dict functions - def get_num_padded_from_children(self) -> List[List[int]]: - np = [] - for name, module in self.named_modules(): - assert name not in np - is_self = name == "" - if not isinstance(module, FullyShardedDataParallel): - continue - np.append(module.num_padded) - return np def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> None: """Update the consolidated state_dict list, one per rank. @@ -1299,7 +1291,7 @@ def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_r .. 
warning: This needs to be called on all replicas""" _default_device = torch.device("cuda") - # NOTE(SS): we do not support param groups yet, as these seem to break FSDP + # NOTE(SS): we do not support param groups yet, as they seem to break FSDP # Pull the sharded state from all the other replicas # Store all the states in order, rank by rank @@ -1309,7 +1301,7 @@ def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_r for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() - sd["num_padded"] = self.get_num_padded_from_children() # Communicate between ranks + sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)] if should_collect_state: _all_optimizer_states.append( recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu")) @@ -1364,50 +1356,37 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict") # Unify the shard states by concatenating tensors and unflattening params - combined_state = self._all_optimizer_states[0] - assert "num_padded" in combined_state world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states] - assert len(combined_state["param_groups"]) == 1, "not yet supported" - - for param_id in combined_state["state"]: - combined_state["state"][param_id] = { - k: [v] for k, v in combined_state["state"][param_id].items() - } # so we can append - other_states = self._all_optimizer_states[1:] if self.world_size > 1 else [] - constant_state = [self._extract_constant_state(combined_state, pid) for pid in combined_state["state"]] - for rank, s in enumerate(other_states): - for param_id in s["state"]: - # Update the state, if any - for k in s["state"][param_id]: - new_entry = s["state"][param_id][k] - combined_state["state"][param_id][k].append(new_entry) + + # constant_state refers to entries in sd[state][param_id] that are not tensors + + param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"]) + combined_state = self._combine_tensor_optim_state([x["state"] for x in self._all_optimizer_states]) + constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state] # cleanup all_optimizer_states_list self._all_optimizer_states = [] - new_state_dict = {"state": {}, "param_groups": copy.deepcopy(combined_state["param_groups"])} - instance_list: List[FullyShardedDataParallel] = [ - m for m in self.modules() if isinstance(m, FullyShardedDataParallel) - ] - numels_flatten = [sum(m._param_numels) for m in instance_list] + new_state_dict = {"state": {}, "param_groups": param_groups} + instance_list = self._fsdp_instances + numels_per_instance = [sum(m._param_numels) for m in instance_list] # noqa # loop over parameters in state. # If they are tensors, unpad each rank properly, concatenate it, and then # call _get_param_views. This returns multiple tensors # each of which is a new parameter with a new "global" id. - local_to_global_param_id: Dict[ - int, List[int] - ] = {} # local ids are after in the current state, global_ids will be in returned state. + local_to_global_param_id: Dict[int, List[int]] = {} + # local ids are after in the current state, global_ids will be in returned state. 
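The list comprehensions over self.modules() above depend on nn.Module.modules() yielding the module itself first and then its children depth-first, so the root wrapper sees its own metadata followed by every nested wrapped instance in a stable order. A toy wrapper (not FSDP) shows the pattern:

import torch.nn as nn

class ToyWrapper(nn.Module):
    def __init__(self, inner, num_padded):
        super().__init__()
        self.inner = inner
        self.num_padded = num_padded  # stands in for FSDP's per-shard padding info

leaf = ToyWrapper(nn.Linear(4, 4), num_padded=[3])
root = ToyWrapper(nn.Sequential(leaf, nn.Linear(4, 2)), num_padded=[0])

# Root first, then the nested instance: same order every time the tree is walked.
assert [m.num_padded for m in root.modules() if isinstance(m, ToyWrapper)] == [[0], [3]]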
global_param_id = 0 # gets incremented - for pg_id, pg0 in enumerate(combined_state["param_groups"]): + for pg_id, pg0 in enumerate(param_groups): for param_id in pg0["params"]: local_to_global_param_id[param_id] = [] - if param_id not in combined_state["state"]: + if param_id not in combined_state: continue # undo the work of shard_parameters - for k, v in combined_state["state"][param_id].items(): + for k, v in combined_state[param_id].items(): if k in constant_state[param_id]: continue assert isinstance(v, list), f"expected list, got {k}:{v} for {param_id} at rank {self.rank}" @@ -1417,9 +1396,9 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] flat_buffer = torch.cat(v_unpad) assert ( - numels_flatten[param_id] == flat_buffer.shape[0] == flat_buffer.numel() - ), f"{numels_flatten[param_id]},{flat_buffer.shape[0]}, {flat_buffer.numel()}" - param_views: Generator[torch.Tensor] = instance_list[param_id].get_param_views(flat_buffer) + numels_per_instance[param_id] == flat_buffer.shape[0] == flat_buffer.numel() + ), f"{numels_per_instance[param_id]},{flat_buffer.shape[0]}, {flat_buffer.numel()}" + param_views: Generator = instance_list[param_id].get_param_views(flat_buffer) for i, param_view in enumerate(param_views): if i == len(local_to_global_param_id[param_id]): # We have not seen this global param before, and make a new ID @@ -1432,14 +1411,14 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: new_state_dict["state"][cur_pid][k] = param_view if global_param_id == 0: # stateless optimizer - num_params = sum([len(m._param_numels) for m in instance_list]) + num_params = sum([len(m._param_numels) for m in instance_list]) # noqa new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params)) else: new_state_dict["param_groups"][pg_id]["params"] = list(range(global_param_id)) global_to_local_id = {} - for old_pid, global_param_id in local_to_global_param_id.items(): - for nid in global_param_id: + for old_pid, global_ids in local_to_global_param_id.items(): + for nid in global_ids: global_to_local_id[nid] = old_pid new_state_dict["param_id_map"] = global_to_local_id @@ -1447,10 +1426,23 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: new_state_dict["state"] = dict(sorted(new_state_dict["state"].items())) return new_state_dict + def _combine_tensor_optim_state(self, states: List[Dict]) -> Dict[int, Dict]: + combined_state = states[0] + for param_id in combined_state: + combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} + if self.world_size == 1: + return combined_state + + for rank, s in enumerate(states[1:]): + for param_id, param_state in s.items(): + for k, tensor in param_state.items(): + combined_state[param_id][k].append(tensor) + return combined_state + @staticmethod def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. 
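Restating _extract_constant_state outside the class (same logic, toy input): non-tensor entries such as Adam's integer step are not sharded, so when every rank reports the same value a single copy is kept, and disagreement is treated as an error.

import torch

def extract_constant_state(per_param_state):
    # per_param_state: dict mapping buffer name -> list with one entry per rank.
    constant = {}
    for name, values in per_param_state.items():
        if torch.is_tensor(values[0]):
            continue  # tensor state is unpadded and concatenated elsewhere
        if len(set(values)) == 1:
            constant[name] = values[0]
        else:
            raise TypeError(f"don't know how to merge optimizer entry {name} with values {values}")
    return constant

merged = extract_constant_state({"step": [10, 10], "exp_avg": [torch.zeros(2), torch.zeros(2)]})
assert merged == {"step": 10}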
- for k, v in sd0["state"][param_id].items(): + for k, v in sd0[param_id].items(): if torch.is_tensor(v[0]): continue @@ -1460,12 +1452,25 @@ def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: raise TypeError(f"Dont know how to expand optimizer param {k} with value {v}") return constant_state - def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict[str, Any]: - self.validate_nesting_unchanged(full_optim_state_dict) + @property + def _fsdp_instances(self) -> List[nn.Module]: + """Returns all fsdp modules including self.""" + assert self._is_root + return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] + + def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: + """Get the portion of the optimizer state dict associated with the shard""" + # Assert nesting is the same as it was at save time + n_instances = len(self._fsdp_instances) + n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) + msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}" + stateless = len(full_optim_state_dict["state"]) == 0 + assert stateless or (n_instances == n_local_params_in_opt), msg + stateless = len(full_optim_state_dict["state"]) == 0 - mlist = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] + mlist = self._fsdp_instances if self.flatten_parameters: - full_optim_state_dict = self.flatten_optim_state_dict(full_optim_state_dict) + full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict) assert stateless or len(full_optim_state_dict["state"]) == len(mlist) # get the portion of dict associated with the shard @@ -1479,45 +1484,37 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict) -> Dict[str, An return full_optim_state_dict - def validate_nesting_unchanged(self, sd): - # This should be removed, could be wasteful - if not sd["state"]: - return - n_parameters = len(list(self.parameters())) - n_params_in_optimizer = len(set(sd["param_id_map"].values())) - essay = f"Including itself, this model has {n_parameters} nested instances. When the optimizer state was saved, however, there were only {n_params_in_optimizer}" - assert n_parameters == n_params_in_optimizer, essay - @staticmethod - def flatten_optim_state_dict(sd: Dict) -> Dict: + def _flatten_optim_state_dict(sd: Dict) -> Dict: param_id_map = sd["param_id_map"] - npars_final = len(set(param_id_map.values())) + num_local_params = len(set(param_id_map.values())) if sd["state"]: - new_state = {consolidated_pid: defaultdict(list) for consolidated_pid in range(npars_final)} + new_state: Dict = {local_id: {} for local_id in range(num_local_params)} else: new_state = {} - constant_state = {} # self.extract_constant_state(sd, 0) + constant_state = {} # assumes sd sorted for expanded_pid, buffers in sd["state"].items(): consolidated_pid = param_id_map[expanded_pid] for buffer_name, p in buffers.items(): if torch.is_tensor(p): + if buffer_name not in new_state[consolidated_pid]: + new_state[consolidated_pid][buffer_name] = [] new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) else: - assert isinstance(p, int) + assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]" constant_state[buffer_name] = p # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... 
At least check - new_state = {k: dict(v) for k, v in new_state.items()} + for consolidated_pid, state in new_state.items(): for buffer_name, tensors in state.items(): new_state[consolidated_pid][buffer_name] = torch.cat(tensors) new_state[consolidated_pid].update(constant_state) new_sd = {"state": new_state, "param_groups": sd["param_groups"]} - # Do not put empty state for stateless optimizers for pg_id, _ in enumerate(sd["param_groups"]): - new_sd["param_groups"][pg_id]["params"] = list(range(npars_final)) + new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) return new_sd diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 69b32f625..57e8cf406 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -2,6 +2,7 @@ from parameterized import parameterized import torch +from torch.optim import SGD, Adadelta, Adam # noqa from fairscale.nn import FullyShardedDataParallel from fairscale.optim.utils import recursive_copy_to_device @@ -30,12 +31,7 @@ def assert_equal(a, b): class TestOptimizerUtils(DistributedTest): @parameterized.expand( - [ - [functools.partial(torch.optim.SGD, momentum=0.9), False], - [torch.optim.SGD, False], - [torch.optim.Adam, False], - [torch.optim.Adadelta, True], - ], + [[functools.partial(SGD, momentum=0.9), False], [SGD, False], [Adam, False], [Adadelta, True]], name_func=rename_test, ) def test_consolidate_optimizer(self, optim_fn, transformer): From 363527796f3a9dc60192ba1ac59b32a1b34abb04 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 21 Mar 2021 17:39:37 -0400 Subject: [PATCH 14/31] update test list --- tests/ci_test_list_3.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ci_test_list_3.txt b/tests/ci_test_list_3.txt index 8e6cd773e..3010b6b30 100644 --- a/tests/ci_test_list_3.txt +++ b/tests/ci_test_list_3.txt @@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py tests/nn/data_parallel/test_fsdp_input.py tests/nn/data_parallel/test_fsdp_multiple_forward.py tests/nn/data_parallel/test_fsdp_regnet.py +tests/nn/data_parallel/test_fsdp_optimizer_utils.py tests/nn/data_parallel/test_sharded_ddp_features.py tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py tests/nn/pipe/skip/test_gpipe.py From dbb426f7f755749b9937364dbaa610c67f9fae8d Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Mar 2021 10:58:24 -0400 Subject: [PATCH 15/31] mypy --- .../fully_sharded_data_parallel.py | 106 +++++++++--------- fairscale/nn/misc/flatten_params_wrapper.py | 3 +- .../test_fsdp_optimizer_utils.py | 2 +- 3 files changed, 54 insertions(+), 57 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 5520d3b4a..d7c15742f 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1404,23 +1404,14 @@ def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_r def gather_full_optim_state_dict(self) -> Dict[str, Any]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the - sharded properties are not exposed. - - - Arguments: - all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on + sharded properties are not exposed. Multiple parameter groups are not yet supported. 
Returns: a dict with two entries * state - a dict holding current optimization state. Its content differs between optimizer classes. - * param_groups - a dict containing all parameter groups - .. warning: - Returning the global state is limited to the replica which was responsible for the consolidation, - if `all_ranks` was not set to `True`. In that case, the state may also not be up to date, - depending on when `consolidate_state_dict` was last called. """ if not self.flatten_parameters: raise NotImplementedError("optim state dict requires flatten_parameters=True") @@ -1430,10 +1421,16 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: # Unify the shard states by concatenating tensors and unflattening params world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states] - # constant_state refers to entries in sd[state][param_id] that are not tensors - param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"]) - combined_state = self._combine_tensor_optim_state([x["state"] for x in self._all_optimizer_states]) + + # combined_state refers to tensor values in sd[state][param_id]. + # Here we just aggregate them into a list inside the dictionary from a list of dictionaries. + combined_state = self._combine_tensor_optim_state( + [x["state"] for x in self._all_optimizer_states], self.world_size + ) + + # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" + # we check that these are identical across workers and then take the first constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state] # cleanup all_optimizer_states_list @@ -1441,68 +1438,70 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: new_state_dict = {"state": {}, "param_groups": param_groups} instance_list = self._fsdp_instances - numels_per_instance = [sum(m._param_numels) for m in instance_list] # noqa + numels_per_instance = [sum(m._param_numels) for m in instance_list] # type: ignore # loop over parameters in state. - # If they are tensors, unpad each rank properly, concatenate it, and then - # call _get_param_views. This returns multiple tensors - # each of which is a new parameter with a new "global" id. + # Tensor state will be padded, concatenated, and then restored to their original + # shape with FlattenParamsWrapper.get_views + # get_views multiple tensors, each of which is a new parameter with a new "global" id. local_to_global_param_id: Dict[int, List[int]] = {} - # local ids are after in the current state, global_ids will be in returned state. + # local ids are in the current state, global_ids will be in returned state. 
- global_param_id = 0 # gets incremented - for pg_id, pg0 in enumerate(param_groups): - for param_id in pg0["params"]: - local_to_global_param_id[param_id] = [] - if param_id not in combined_state: + next_global_param_id = 0 # gets incremented + for pg_id, param_group in enumerate(param_groups): + for local_id in param_group["params"]: + local_to_global_param_id[local_id] = [] + if local_id not in combined_state: continue # undo the work of shard_parameters - for k, v in combined_state[param_id].items(): - if k in constant_state[param_id]: + for k, v in combined_state[local_id].items(): + if k in constant_state[local_id]: continue - assert isinstance(v, list), f"expected list, got {k}:{v} for {param_id} at rank {self.rank}" - assert all(len(s[param_id]) == 1 for s in world_pad_info) - pad_info = [s[param_id][0] for s in world_pad_info] + assert isinstance(v, list), f"expected list, got {k}:{v} for {local_id} at rank {self.rank}" + assert all(len(s[local_id]) == 1 for s in world_pad_info) # because of flatten_parameters + pad_info = [s[local_id][0] for s in world_pad_info] assert len(pad_info) == self.world_size == len(v), f"{len(pad_info), self.world_size, len(v)}" + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] flat_buffer = torch.cat(v_unpad) assert ( - numels_per_instance[param_id] == flat_buffer.shape[0] == flat_buffer.numel() - ), f"{numels_per_instance[param_id]},{flat_buffer.shape[0]}, {flat_buffer.numel()}" - param_views: Generator = instance_list[param_id].get_param_views(flat_buffer) + numels_per_instance[local_id] == flat_buffer.shape[0] == flat_buffer.numel() + ), f"{numels_per_instance[local_id]} {flat_buffer.shape[0]}, {flat_buffer.numel()}" + param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) for i, param_view in enumerate(param_views): - if i == len(local_to_global_param_id[param_id]): + if i == len(local_to_global_param_id[local_id]): # We have not seen this global param before, and make a new ID - local_to_global_param_id[param_id].append(global_param_id) - global_param_id += 1 - cur_pid = local_to_global_param_id[param_id][i] - if cur_pid not in new_state_dict["state"]: - new_state_dict["state"][cur_pid] = copy.deepcopy(constant_state[param_id]) - assert k not in new_state_dict["state"][cur_pid] - new_state_dict["state"][cur_pid][k] = param_view - - if global_param_id == 0: # stateless optimizer - num_params = sum([len(m._param_numels) for m in instance_list]) # noqa + local_to_global_param_id[local_id].append(next_global_param_id) + next_global_param_id += 1 + global_id = local_to_global_param_id[local_id][i] + if global_id not in new_state_dict["state"]: + new_state_dict["state"][global_id] = copy.deepcopy(constant_state[local_id]) + assert k not in new_state_dict["state"][global_id], f"already added {k} to new[{global_id}]" + new_state_dict["state"][global_id][k] = param_view + + if next_global_param_id == 0: # stateless optimizer + num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params)) else: - new_state_dict["param_groups"][pg_id]["params"] = list(range(global_param_id)) + new_state_dict["param_groups"][pg_id]["params"] = list(range(next_global_param_id)) global_to_local_id = {} for old_pid, global_ids in local_to_global_param_id.items(): - for nid in global_ids: - global_to_local_id[nid] = old_pid + for new_id in global_ids: + global_to_local_id[new_id] = old_pid new_state_dict["param_id_map"] = global_to_local_id # Make 
sure that the parameters are sorted in the state, as expected for a pytorch dict new_state_dict["state"] = dict(sorted(new_state_dict["state"].items())) return new_state_dict - def _combine_tensor_optim_state(self, states: List[Dict]) -> Dict[int, Dict]: + @staticmethod + def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]: combined_state = states[0] for param_id in combined_state: combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} - if self.world_size == 1: + if world_size == 1: return combined_state for rank, s in enumerate(states[1:]): @@ -1512,22 +1511,21 @@ def _combine_tensor_optim_state(self, states: List[Dict]) -> Dict[int, Dict]: return combined_state @staticmethod - def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]: + def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. - for k, v in sd0[param_id].items(): + for k, v in combined_state[param_id].items(): if torch.is_tensor(v[0]): continue elif len(set(v)) == 1: constant_state[k] = v[0] else: - raise TypeError(f"Dont know how to expand optimizer param {k} with value {v}") + raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}") return constant_state @property def _fsdp_instances(self) -> List[nn.Module]: - """Returns all fsdp modules including self.""" - assert self._is_root + """Returns all fsdp modules in self.modules() including self.""" return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: @@ -1540,10 +1538,10 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) assert stateless or (n_instances == n_local_params_in_opt), msg stateless = len(full_optim_state_dict["state"]) == 0 - mlist = self._fsdp_instances + instance_list = self._fsdp_instances if self.flatten_parameters: full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict) - assert stateless or len(full_optim_state_dict["state"]) == len(mlist) + assert stateless or len(full_optim_state_dict["state"]) == len(instance_list) # get the portion of dict associated with the shard for id, s in full_optim_state_dict["state"].items(): diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index bf4cf40dd..9db2f6ab6 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -123,8 +123,7 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None: self._unflatten_params_as_views() def get_param_views(self, flat_param: Tensor) -> Generator: - splat = flat_param.split(self._param_numels) - return (t.view(s) for (t, s) in zip(splat, self._param_shapes)) + return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None: assert self.is_flattened or flat_param is not None diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 57e8cf406..535dc39ef 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -2,7 +2,7 @@ from parameterized import parameterized import torch -from torch.optim import SGD, Adadelta, Adam # 
noqa +from torch.optim import SGD, Adadelta, Adam # type: ignore from fairscale.nn import FullyShardedDataParallel from fairscale.optim.utils import recursive_copy_to_device From f537632318e96f79423a1e065472f1d7b2d7d4bb Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Mar 2021 13:48:36 -0400 Subject: [PATCH 16/31] Simpler consolidate_optim_state_dict --- .../fully_sharded_data_parallel.py | 55 ++++++------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index d7c15742f..da314ecb4 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -21,7 +21,7 @@ from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap -from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device +from fairscale.optim.utils import calc_grad_norm, recursive_copy_to_device from fairscale.utils.containers import apply_to_tensors from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer @@ -436,6 +436,7 @@ def _shard_parameters_(self) -> None: p.data, num_padded = self._get_shard(p.data) self.num_padded.append(num_padded) free_storage_(orig_data) + assert len(self.num_padded) == len(self.params) def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: """Return the local shard of a full tensor.""" @@ -1354,53 +1355,33 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None # Optim State dict functions - def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> None: - """Update the consolidated state_dict list, one per rank. + def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None) -> None: + """Update the consolidated state_dict list, one per rank. The result is at self._all_optimizer_states - Arguments: + Args: recipient_rank (int): on which rank to materialize the full state dict. - -1 is a special value, which means that all ranks should have the state + None is a special value, which means that all ranks should have the state .. 
warning: This needs to be called on all replicas""" - _default_device = torch.device("cuda") + self._lazy_init() # NOTE(SS): we do not support param groups yet, as they seem to break FSDP - # Pull the sharded state from all the other replicas # Store all the states in order, rank by rank - should_collect_state = self.rank == recipient_rank or recipient_rank == -1 - should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1 - _all_optimizer_states: List[Dict[str, Any]] = [] + should_collect_state = recipient_rank is None or (self.rank == recipient_rank) + all_states: List[Dict[str, Any]] = [] for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() - sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)] - if should_collect_state: - _all_optimizer_states.append( - recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu")) - ) - - # Sync with other replicas - state_to_share = ( - sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device) - ) - broadcast_object( - state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device, - ) + sd["num_padded"] = [m.num_padded for m in self._fsdp_instances] else: - # Fetch the optim state from the other replicas - replica_state = broadcast_object( - torch.tensor([0], dtype=torch.uint8, device=_default_device), - src_rank=rank, - group=self.process_group, - dist_device=_default_device, - ) - - if should_collect_state: - _all_optimizer_states.append( - recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu")) - ) - - self._all_optimizer_states = _all_optimizer_states + sd = None # type: ignore + obj_lst = [sd] + torch.distributed.broadcast_object_list(obj_lst, src=rank, group=self.process_group) + if should_collect_state: + assert isinstance(obj_lst[0], dict), f"{rank}, {self.rank} {all_states}" + all_states.append(recursive_copy_to_device(obj_lst[0], non_blocking=False, device=torch.device("cpu"))) + + self._all_optimizer_states = all_states def gather_full_optim_state_dict(self) -> Dict[str, Any]: """Return the last known global optimizer state. 
The returned state is compatible with Pytorch, in that the From a04b406d3a214e8af24759dc045e115fdf3bc494 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Mar 2021 15:24:29 -0400 Subject: [PATCH 17/31] slightly cleaner --- .../fully_sharded_data_parallel.py | 81 +++++++++---------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index da314ecb4..bf5950981 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1400,9 +1400,13 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict") # Unify the shard states by concatenating tensors and unflattening params - world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states] + world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in self._all_optimizer_states] + instance_list: List[nn.Module] = self._fsdp_instances + assert all(len(s) == len(instance_list) for s in world_pad_info) + assert all(len(s[0]) == 1 for s in world_pad_info) param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"]) + assert len(param_groups) == 1 # combined_state refers to tensor values in sd[state][param_id]. # Here we just aggregate them into a list inside the dictionary from a list of dictionaries. @@ -1418,7 +1422,7 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: self._all_optimizer_states = [] new_state_dict = {"state": {}, "param_groups": param_groups} - instance_list = self._fsdp_instances + numels_per_instance = [sum(m._param_numels) for m in instance_list] # type: ignore # loop over parameters in state. @@ -1426,51 +1430,38 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: # shape with FlattenParamsWrapper.get_views # get_views multiple tensors, each of which is a new parameter with a new "global" id. - local_to_global_param_id: Dict[int, List[int]] = {} # local ids are in the current state, global_ids will be in returned state. 
- - next_global_param_id = 0 # gets incremented - for pg_id, param_group in enumerate(param_groups): - for local_id in param_group["params"]: - local_to_global_param_id[local_id] = [] - if local_id not in combined_state: + local_to_global_param_id: Dict[int, List[int]] = {} + next_global_id = 0 # gets incremented + for local_id in combined_state: + local_to_global_param_id[local_id] = [] + # undo the work of shard_parameters + for k, v in combined_state[local_id].items(): + if k in constant_state[local_id]: continue - # undo the work of shard_parameters - for k, v in combined_state[local_id].items(): - if k in constant_state[local_id]: - continue - assert isinstance(v, list), f"expected list, got {k}:{v} for {local_id} at rank {self.rank}" - assert all(len(s[local_id]) == 1 for s in world_pad_info) # because of flatten_parameters - pad_info = [s[local_id][0] for s in world_pad_info] - assert len(pad_info) == self.world_size == len(v), f"{len(pad_info), self.world_size, len(v)}" - - v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] - flat_buffer = torch.cat(v_unpad) - assert ( - numels_per_instance[local_id] == flat_buffer.shape[0] == flat_buffer.numel() - ), f"{numels_per_instance[local_id]} {flat_buffer.shape[0]}, {flat_buffer.numel()}" - param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) - for i, param_view in enumerate(param_views): - if i == len(local_to_global_param_id[local_id]): - # We have not seen this global param before, and make a new ID - local_to_global_param_id[local_id].append(next_global_param_id) - next_global_param_id += 1 - global_id = local_to_global_param_id[local_id][i] - if global_id not in new_state_dict["state"]: - new_state_dict["state"][global_id] = copy.deepcopy(constant_state[local_id]) - assert k not in new_state_dict["state"][global_id], f"already added {k} to new[{global_id}]" - new_state_dict["state"][global_id][k] = param_view - - if next_global_param_id == 0: # stateless optimizer - num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore - new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params)) - else: - new_state_dict["param_groups"][pg_id]["params"] = list(range(next_global_param_id)) - - global_to_local_id = {} - for old_pid, global_ids in local_to_global_param_id.items(): - for new_id in global_ids: - global_to_local_id[new_id] = old_pid + assert isinstance(v, list), f"got {k}: {v} for {local_id} at rank {self.rank}" + pad_info = [s[local_id][0] for s in world_pad_info] + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] + flat_buffer = torch.cat(v_unpad) + assert numels_per_instance[local_id] == flat_buffer.numel() + param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore + for i, param_view in enumerate(param_views): + if i == len(local_to_global_param_id[local_id]): # make a new ID + local_to_global_param_id[local_id].append(next_global_id) + next_global_id += 1 + global_id = local_to_global_param_id[local_id][i] + if global_id not in new_state_dict["state"]: + new_state_dict["state"][global_id] = copy.deepcopy(constant_state[local_id]) + + assert k not in new_state_dict["state"][global_id], f"already added {k} to new[{global_id}]" + new_state_dict["state"][global_id][k] = param_view + + num_params = next_global_id or sum([len(m._param_numels) for m in instance_list]) # type: ignore + new_state_dict["param_groups"][0]["params"] = list(range(num_params)) + + global_to_local_id = { + new_id: old_pid for old_pid, 
global_ids in local_to_global_param_id.items() for new_id in global_ids + } new_state_dict["param_id_map"] = global_to_local_id # Make sure that the parameters are sorted in the state, as expected for a pytorch dict From e5e91dfc75168b2288d85ec1df33c8701fdcf830 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Mar 2021 16:32:24 -0400 Subject: [PATCH 18/31] Simplified signature, helper fn for unflattening --- .../fully_sharded_data_parallel.py | 86 ++++++++++--------- .../test_fsdp_optimizer_utils.py | 12 +-- 2 files changed, 47 insertions(+), 51 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index bf5950981..0ebad54e2 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -187,7 +187,6 @@ def __init__( self.bucket_cap_mb = bucket_cap_mb self.num_padded: List[int] = [] - self._all_optimizer_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state self.compute_device = compute_device if self.fp32_reduce_scatter and not self.mixed_precision: @@ -1354,9 +1353,10 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None raise ValueError(msg) # Optim State dict functions - - def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None) -> None: - """Update the consolidated state_dict list, one per rank. The result is at self._all_optimizer_states + def _consolidate_optim_state_dict( + self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None + ) -> List[Dict]: + """Update the consolidated state_dict list, one per rank. Args: recipient_rank (int): on which rank to materialize the full state dict. @@ -1381,9 +1381,11 @@ def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_r assert isinstance(obj_lst[0], dict), f"{rank}, {self.rank} {all_states}" all_states.append(recursive_copy_to_device(obj_lst[0], non_blocking=False, device=torch.device("cpu"))) - self._all_optimizer_states = all_states + return all_states - def gather_full_optim_state_dict(self) -> Dict[str, Any]: + def gather_full_optim_state_dict( + self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0 + ) -> Optional[Dict[str, Any]]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the sharded properties are not exposed. Multiple parameter groups are not yet supported. 
@@ -1396,77 +1398,81 @@ def gather_full_optim_state_dict(self) -> Dict[str, Any]: """ if not self.flatten_parameters: raise NotImplementedError("optim state dict requires flatten_parameters=True") - if len(self._all_optimizer_states) == 0: - raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict") + world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank) + if self.rank != recipient_rank and recipient_rank is not None: + return None # Unify the shard states by concatenating tensors and unflattening params - world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in self._all_optimizer_states] + world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states] instance_list: List[nn.Module] = self._fsdp_instances assert all(len(s) == len(instance_list) for s in world_pad_info) assert all(len(s[0]) == 1 for s in world_pad_info) - param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"]) + param_groups = copy.deepcopy(world_optim_states[0]["param_groups"]) assert len(param_groups) == 1 # combined_state refers to tensor values in sd[state][param_id]. - # Here we just aggregate them into a list inside the dictionary from a list of dictionaries. - combined_state = self._combine_tensor_optim_state( - [x["state"] for x in self._all_optimizer_states], self.world_size - ) - - # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" - # we check that these are identical across workers and then take the first - constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state] - + # Here we just aggregate them into a dictionary of lists (from a list of dictionaries) + combined_state = self._combine_tensor_optim_state([x["state"] for x in world_optim_states], self.world_size) # cleanup all_optimizer_states_list - self._all_optimizer_states = [] + del world_optim_states new_state_dict = {"state": {}, "param_groups": param_groups} - numels_per_instance = [sum(m._param_numels) for m in instance_list] # type: ignore + # local ids are in the current state, global_ids will be in returned state. + unflat_state, global_to_local_id = self._unflatten_optim_state(combined_state, instance_list, world_pad_info) + + num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore + new_state_dict["param_groups"][0]["params"] = list(range(num_params)) + + new_state_dict["param_id_map"] = global_to_local_id + # Make sure that the parameters are sorted in the state, as expected for a pytorch dict + new_state_dict["state"] = dict(sorted(unflat_state.items())) + return new_state_dict + + @staticmethod + def _unflatten_optim_state( + combined_state: Dict[int, Dict], instance_list: List[nn.Module], world_pad_info: List[List[List[int]]], + ) -> Tuple[Dict[int, Dict], Dict[int, int]]: + local_to_global_param_id: Dict[int, List[int]] = {} + next_global_id = 0 # gets incremented + unflat_state = {} + pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} + + # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" + # we check that these are identical across workers and then take the first + constant_state = [FullyShardedDataParallel._extract_constant_state(combined_state, id) for id in combined_state] # loop over parameters in state. 
# Tensor state will be padded, concatenated, and then restored to their original # shape with FlattenParamsWrapper.get_views # get_views multiple tensors, each of which is a new parameter with a new "global" id. - - # local ids are in the current state, global_ids will be in returned state. - local_to_global_param_id: Dict[int, List[int]] = {} - next_global_id = 0 # gets incremented for local_id in combined_state: local_to_global_param_id[local_id] = [] # undo the work of shard_parameters for k, v in combined_state[local_id].items(): if k in constant_state[local_id]: continue - assert isinstance(v, list), f"got {k}: {v} for {local_id} at rank {self.rank}" - pad_info = [s[local_id][0] for s in world_pad_info] - v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)] + assert isinstance(v, list), f"got {k}: {v} for {local_id}" + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] flat_buffer = torch.cat(v_unpad) - assert numels_per_instance[local_id] == flat_buffer.numel() param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore for i, param_view in enumerate(param_views): if i == len(local_to_global_param_id[local_id]): # make a new ID local_to_global_param_id[local_id].append(next_global_id) next_global_id += 1 global_id = local_to_global_param_id[local_id][i] - if global_id not in new_state_dict["state"]: - new_state_dict["state"][global_id] = copy.deepcopy(constant_state[local_id]) + if global_id not in unflat_state: + unflat_state[global_id] = copy.deepcopy(constant_state[local_id]) - assert k not in new_state_dict["state"][global_id], f"already added {k} to new[{global_id}]" - new_state_dict["state"][global_id][k] = param_view - - num_params = next_global_id or sum([len(m._param_numels) for m in instance_list]) # type: ignore - new_state_dict["param_groups"][0]["params"] = list(range(num_params)) + assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]" + unflat_state[global_id][k] = param_view global_to_local_id = { new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids } - new_state_dict["param_id_map"] = global_to_local_id - # Make sure that the parameters are sorted in the state, as expected for a pytorch dict - new_state_dict["state"] = dict(sorted(new_state_dict["state"].items())) - return new_state_dict + return unflat_state, global_to_local_id @staticmethod def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]: diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 535dc39ef..dcf782524 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -75,19 +75,9 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim optim_unwrapped.step() unwrapped_sd = optim_unwrapped.state_dict() - n_pars = len(list(unwrapped_model.parameters())) - - # torch.save(fsdp._all_optimizer_states, f"all_optim_states_world_size_{fsdp.world_size}.pt") - fsdp.consolidate_optim_state_dict(fsdp_optim, recipient_rank=0) # first_key = unwrapped_sd['state'][0].keys() + sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=None) - if rank > 0: - return - - sd = fsdp.gather_full_optim_state_dict() - # optim_par = sum(v['square_avg'].numel() for k, v in sd.items()) - # assert_equal(len(fsdp._all_optimizer_states), fsdp.world_size) - torch.save(sd, 
f"fsdp_consolidated_{fsdp.world_size}.pt") assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) assert_equal( From ea9d4b5ccac0b01e5d266c3b8ce3c41e989a3aa9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Mar 2021 17:11:46 -0400 Subject: [PATCH 19/31] add todo --- fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 0ebad54e2..0af28fd72 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1562,6 +1562,7 @@ def _flatten_optim_state_dict(sd: Dict) -> Dict: new_sd = {"state": new_state, "param_groups": sd["param_groups"]} for pg_id, _ in enumerate(sd["param_groups"]): + # TODO: this list could be huge. Can we avoid materializing? new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) return new_sd From 47e7cba35263470aa4fea7889012991734d33d06 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 08:38:18 -0400 Subject: [PATCH 20/31] Give CI more time to show me a traceback --- .circleci/config.yml | 6 +++--- tests/nn/data_parallel/test_fsdp_optimizer_utils.py | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 181ca99cd..d1d54a2a5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -167,7 +167,7 @@ run_coverage: &run_coverage - run: name: Run Unit Tests With Coverage command: | - pytest --junitxml=test-results/junit.xml --verbose --timeout 60 --cov-report=xml --cov=./ + pytest --junitxml=test-results/junit.xml --verbose --timeout 120 --cov-report=xml --cov=./ #Uploading test coverage for Python code bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python @@ -233,7 +233,7 @@ run_unittests: &run_unittests name: Run all unit tests. # We run all and not stopping on failure on CPU since docker time is cheaper. command: | - pytest --junitxml=test-results/junit.xml --verbose --timeout 60 + pytest --junitxml=test-results/junit.xml --verbose --timeout 120 commands: @@ -249,7 +249,7 @@ commands: name: Run Unit Tests command: | if [ ! 
-f <> ]; then exit 1; fi - pytest --junitxml=test-results/junit.xml --verbose --timeout 60 `cat <>` + pytest --junitxml=test-results/junit.xml --verbose --timeout 120 `cat <>` # ------------------------------------------------------------------------------------- # Jobs to run diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index dcf782524..e3f3d84e2 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -1,4 +1,5 @@ import functools +import unittest from parameterized import parameterized import torch @@ -31,7 +32,7 @@ def assert_equal(a, b): class TestOptimizerUtils(DistributedTest): @parameterized.expand( - [[functools.partial(SGD, momentum=0.9), False], [SGD, False], [Adam, False], [Adadelta, True]], + [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False],], name_func=rename_test, ) def test_consolidate_optimizer(self, optim_fn, transformer): @@ -39,7 +40,8 @@ def test_consolidate_optimizer(self, optim_fn, transformer): test_fn = functools.partial( self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer ) - spawn_and_init(test_fn) + + spawn_and_init(test_fn, world_sizes=[1, min(torch.cuda.device_count(), 4)]) @classmethod def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False): @@ -76,7 +78,9 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim unwrapped_sd = optim_unwrapped.state_dict() # first_key = unwrapped_sd['state'][0].keys() - sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=None) + sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0) + if fsdp.rank > 0: + return assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) From 6cebcecd7cae974f369aed4813ab194b4ed80a71 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 09:48:55 -0400 Subject: [PATCH 21/31] Fix broadcast_object regression --- .circleci/config.yml | 6 +++--- .../nn/data_parallel/fully_sharded_data_parallel.py | 12 ++++++------ tests/nn/data_parallel/test_fsdp_optimizer_utils.py | 13 ++++++++++--- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index d1d54a2a5..181ca99cd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -167,7 +167,7 @@ run_coverage: &run_coverage - run: name: Run Unit Tests With Coverage command: | - pytest --junitxml=test-results/junit.xml --verbose --timeout 120 --cov-report=xml --cov=./ + pytest --junitxml=test-results/junit.xml --verbose --timeout 60 --cov-report=xml --cov=./ #Uploading test coverage for Python code bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python @@ -233,7 +233,7 @@ run_unittests: &run_unittests name: Run all unit tests. # We run all and not stopping on failure on CPU since docker time is cheaper. command: | - pytest --junitxml=test-results/junit.xml --verbose --timeout 120 + pytest --junitxml=test-results/junit.xml --verbose --timeout 60 commands: @@ -249,7 +249,7 @@ commands: name: Run Unit Tests command: | if [ ! 
-f <> ]; then exit 1; fi - pytest --junitxml=test-results/junit.xml --verbose --timeout 120 `cat <>` + pytest --junitxml=test-results/junit.xml --verbose --timeout 60 `cat <>` # ------------------------------------------------------------------------------------- # Jobs to run diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 0af28fd72..e98430fdd 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -21,7 +21,7 @@ from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap -from fairscale.optim.utils import calc_grad_norm, recursive_copy_to_device +from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device from fairscale.utils.containers import apply_to_tensors from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer @@ -1369,17 +1369,17 @@ def _consolidate_optim_state_dict( # Store all the states in order, rank by rank should_collect_state = recipient_rank is None or (self.rank == recipient_rank) all_states: List[Dict[str, Any]] = [] + dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device) for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() sd["num_padded"] = [m.num_padded for m in self._fsdp_instances] else: - sd = None # type: ignore - obj_lst = [sd] - torch.distributed.broadcast_object_list(obj_lst, src=rank, group=self.process_group) + sd = dummy_tensor # type: ignore + sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device) # type: ignore if should_collect_state: - assert isinstance(obj_lst[0], dict), f"{rank}, {self.rank} {all_states}" - all_states.append(recursive_copy_to_device(obj_lst[0], non_blocking=False, device=torch.device("cpu"))) + assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict" + all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu"))) return all_states diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index e3f3d84e2..7e74acb7f 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -1,5 +1,4 @@ import functools -import unittest from parameterized import parameterized import torch @@ -30,9 +29,12 @@ def assert_equal(a, b): assert a == b, f"{a} != {b}" +from time import time + + class TestOptimizerUtils(DistributedTest): @parameterized.expand( - [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False],], + [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]], name_func=rename_test, ) def test_consolidate_optimizer(self, optim_fn, transformer): @@ -41,7 +43,7 @@ def test_consolidate_optimizer(self, optim_fn, transformer): self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer ) - spawn_and_init(test_fn, world_sizes=[1, min(torch.cuda.device_count(), 4)]) + spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)]) @classmethod def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False): @@ -78,7 +80,12 @@ def 
_test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim unwrapped_sd = optim_unwrapped.state_dict() # first_key = unwrapped_sd['state'][0].keys() + tstart = time() sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0) + duration = time() - tstart + # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise + assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate" + if fsdp.rank > 0: return From 93c0857bab5f9b14cbf478f1e82fd1a919f2e540 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 10:24:43 -0400 Subject: [PATCH 22/31] Move most dictionary manipulation to fsdp_optim_utils.py --- .../nn/data_parallel/fsdp_optim_utils.py | 151 ++++++++++++++++++ .../fully_sharded_data_parallel.py | 151 +----------------- 2 files changed, 158 insertions(+), 144 deletions(-) create mode 100644 fairscale/nn/data_parallel/fsdp_optim_utils.py diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py new file mode 100644 index 000000000..d4185b8a6 --- /dev/null +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""These files are used by fsdp to help consolidate and shard optimizer states.""" +import copy +from typing import Dict, Generator, List, Tuple + +import torch + + +# This function helps shard an +def flatten_optim_state_dict(sd: Dict) -> Dict: + """Called by FSDP.get_shard_from_optim_state_dict""" + param_id_map = sd["param_id_map"] + num_local_params = len(set(param_id_map.values())) + if sd["state"]: + new_state: Dict = {local_id: {} for local_id in range(num_local_params)} + else: + new_state = {} + constant_state = {} + + # assumes sd sorted + for expanded_pid, buffers in sd["state"].items(): + consolidated_pid = param_id_map[expanded_pid] + for buffer_name, p in buffers.items(): + if torch.is_tensor(p): + if buffer_name not in new_state[consolidated_pid]: + new_state[consolidated_pid][buffer_name] = [] + new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) + else: + assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]" + constant_state[buffer_name] = p + # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check + + for consolidated_pid, state in new_state.items(): + for buffer_name, tensors in state.items(): + new_state[consolidated_pid][buffer_name] = torch.cat(tensors) + new_state[consolidated_pid].update(constant_state) + new_sd = {"state": new_state, "param_groups": sd["param_groups"]} + + for pg_id, _ in enumerate(sd["param_groups"]): + # TODO: this list could be huge. Can we avoid materializing? + new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) + + return new_sd + + +# All functions help saving the list of optimizer states, one from each rank +# build_unflat_state_dict is the interface used by FSDP +def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: + constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. 
+ for k, v in combined_state[param_id].items(): + + if torch.is_tensor(v[0]): + continue + elif len(set(v)) == 1: + constant_state[k] = v[0] + else: + raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}") + return constant_state + + +def _combine_tensor_optim_state(states: List[Dict]) -> Dict[int, Dict]: + combined_state = states[0] + for param_id in combined_state: + combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} + if len(states) == 1: + return combined_state + + for rank, s in enumerate(states[1:]): + for param_id, param_state in s.items(): + for k, tensor in param_state.items(): + combined_state[param_id][k].append(tensor) + return combined_state + + +def _unflatten_optim_state( + combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]], +) -> Tuple[Dict[int, Dict], Dict[int, int]]: + local_to_global_param_id: Dict[int, List[int]] = {} + next_global_id = 0 # gets incremented + unflat_state = {} + pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} + + # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" + # we check that these are identical across workers and then take the first + constant_state = [_extract_constant_state(combined_state, id) for id in combined_state] + + # loop over parameters in state. + # Tensor state will be padded, concatenated, and then restored to their original + # shape with FlattenParamsWrapper.get_views + # get_views multiple tensors, each of which is a new parameter with a new "global" id. + for local_id in combined_state: + local_to_global_param_id[local_id] = [] + # undo the work of shard_parameters + for k, v in combined_state[local_id].items(): + if k in constant_state[local_id]: + continue + assert isinstance(v, list), f"got {k}: {v} for {local_id}" + v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] + flat_buffer = torch.cat(v_unpad) + param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore + for i, param_view in enumerate(param_views): + if i == len(local_to_global_param_id[local_id]): # make a new ID + local_to_global_param_id[local_id].append(next_global_id) + next_global_id += 1 + global_id = local_to_global_param_id[local_id][i] + if global_id not in unflat_state: + unflat_state[global_id] = copy.deepcopy(constant_state[local_id]) + + assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]" + unflat_state[global_id][k] = param_view + + global_to_local_id = { + new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids + } + + return unflat_state, global_to_local_id + + +def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict: + world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states] + assert all(len(s) == len(instance_list) for s in world_pad_info) + assert all(len(s[0]) == 1 for s in world_pad_info) + param_groups = copy.deepcopy(world_optim_states[0]["param_groups"]) + assert len(param_groups) == 1 + # combined_state refers to tensor values in sd[state][param_id]. 
+ # Here we just aggregate them into a dictionary of lists (from a list of dictionaries) + combined_state = _combine_tensor_optim_state([x["state"] for x in world_optim_states]) + # cleanup all_optimizer_states_list + del world_optim_states + new_state_dict = {"state": {}, "param_groups": param_groups} + # local ids are in the current state, global_ids will be in returned state. + unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info) + num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore + new_state_dict["param_groups"][0]["params"] = list(range(num_params)) + new_state_dict["param_id_map"] = global_to_local_id + # Make sure that the parameters are sorted in the state, as expected for a pytorch dict + new_state_dict["state"] = dict(sorted(unflat_state.items())) + return new_state_dict + + +def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: + n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) + msg = ( + f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved " + f"there were {n_local_params_in_opt}" + ) + stateless = len(full_optim_state_dict["state"]) == 0 + assert stateless or (n_instances == n_local_params_in_opt), msg diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index e98430fdd..f591ae064 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -19,6 +19,7 @@ from torch.nn import Parameter import torch.nn.functional as F +import fairscale.nn.data_parallel.fsdp_optim_utils as ou from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device @@ -1352,7 +1353,6 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None traceback.print_stack() raise ValueError(msg) - # Optim State dict functions def _consolidate_optim_state_dict( self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None ) -> List[Dict]: @@ -1401,106 +1401,10 @@ def gather_full_optim_state_dict( world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank) if self.rank != recipient_rank and recipient_rank is not None: return None - # Unify the shard states by concatenating tensors and unflattening params - world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states] - instance_list: List[nn.Module] = self._fsdp_instances - assert all(len(s) == len(instance_list) for s in world_pad_info) - assert all(len(s[0]) == 1 for s in world_pad_info) - - param_groups = copy.deepcopy(world_optim_states[0]["param_groups"]) - assert len(param_groups) == 1 - - # combined_state refers to tensor values in sd[state][param_id]. - # Here we just aggregate them into a dictionary of lists (from a list of dictionaries) - combined_state = self._combine_tensor_optim_state([x["state"] for x in world_optim_states], self.world_size) - # cleanup all_optimizer_states_list - del world_optim_states - - new_state_dict = {"state": {}, "param_groups": param_groups} - - # local ids are in the current state, global_ids will be in returned state. 
- unflat_state, global_to_local_id = self._unflatten_optim_state(combined_state, instance_list, world_pad_info) - - num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore - new_state_dict["param_groups"][0]["params"] = list(range(num_params)) - - new_state_dict["param_id_map"] = global_to_local_id - # Make sure that the parameters are sorted in the state, as expected for a pytorch dict - new_state_dict["state"] = dict(sorted(unflat_state.items())) + new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states) return new_state_dict - @staticmethod - def _unflatten_optim_state( - combined_state: Dict[int, Dict], instance_list: List[nn.Module], world_pad_info: List[List[List[int]]], - ) -> Tuple[Dict[int, Dict], Dict[int, int]]: - local_to_global_param_id: Dict[int, List[int]] = {} - next_global_id = 0 # gets incremented - unflat_state = {} - pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} - - # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" - # we check that these are identical across workers and then take the first - constant_state = [FullyShardedDataParallel._extract_constant_state(combined_state, id) for id in combined_state] - - # loop over parameters in state. - # Tensor state will be padded, concatenated, and then restored to their original - # shape with FlattenParamsWrapper.get_views - # get_views multiple tensors, each of which is a new parameter with a new "global" id. - for local_id in combined_state: - local_to_global_param_id[local_id] = [] - # undo the work of shard_parameters - for k, v in combined_state[local_id].items(): - if k in constant_state[local_id]: - continue - assert isinstance(v, list), f"got {k}: {v} for {local_id}" - v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] - flat_buffer = torch.cat(v_unpad) - param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore - for i, param_view in enumerate(param_views): - if i == len(local_to_global_param_id[local_id]): # make a new ID - local_to_global_param_id[local_id].append(next_global_id) - next_global_id += 1 - global_id = local_to_global_param_id[local_id][i] - if global_id not in unflat_state: - unflat_state[global_id] = copy.deepcopy(constant_state[local_id]) - - assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]" - unflat_state[global_id][k] = param_view - - global_to_local_id = { - new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids - } - - return unflat_state, global_to_local_id - - @staticmethod - def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]: - combined_state = states[0] - for param_id in combined_state: - combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} - if world_size == 1: - return combined_state - - for rank, s in enumerate(states[1:]): - for param_id, param_state in s.items(): - for k, tensor in param_state.items(): - combined_state[param_id][k].append(tensor) - return combined_state - - @staticmethod - def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: - constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. 
- for k, v in combined_state[param_id].items(): - - if torch.is_tensor(v[0]): - continue - elif len(set(v)) == 1: - constant_state[k] = v[0] - else: - raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}") - return constant_state - @property def _fsdp_instances(self) -> List[nn.Module]: """Returns all fsdp modules in self.modules() including self.""" @@ -1509,64 +1413,23 @@ def _fsdp_instances(self) -> List[nn.Module]: def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: """Get the portion of the optimizer state dict associated with the shard""" # Assert nesting is the same as it was at save time - n_instances = len(self._fsdp_instances) - n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) - msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}" - stateless = len(full_optim_state_dict["state"]) == 0 - assert stateless or (n_instances == n_local_params_in_opt), msg - - stateless = len(full_optim_state_dict["state"]) == 0 instance_list = self._fsdp_instances + ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list)) if self.flatten_parameters: - full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict) - assert stateless or len(full_optim_state_dict["state"]) == len(instance_list) + full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict) + assert len(full_optim_state_dict["state"]) in (0, len(instance_list)) - # get the portion of dict associated with the shard + # get the portion of dict associated with the shard, in place for id, s in full_optim_state_dict["state"].items(): for k, v in s.items(): if torch.is_tensor(v): v_shard, _ = self._get_shard(v) else: - v_shard = v # dont partition entries that are not tensors + v_shard = v # dont shard entries that are not tensors full_optim_state_dict["state"][id][k] = v_shard return full_optim_state_dict - @staticmethod - def _flatten_optim_state_dict(sd: Dict) -> Dict: - param_id_map = sd["param_id_map"] - num_local_params = len(set(param_id_map.values())) - if sd["state"]: - new_state: Dict = {local_id: {} for local_id in range(num_local_params)} - else: - new_state = {} - constant_state = {} - - # assumes sd sorted - for expanded_pid, buffers in sd["state"].items(): - consolidated_pid = param_id_map[expanded_pid] - for buffer_name, p in buffers.items(): - if torch.is_tensor(p): - if buffer_name not in new_state[consolidated_pid]: - new_state[consolidated_pid][buffer_name] = [] - new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) - else: - assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]" - constant_state[buffer_name] = p - # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check - - for consolidated_pid, state in new_state.items(): - for buffer_name, tensors in state.items(): - new_state[consolidated_pid][buffer_name] = torch.cat(tensors) - new_state[consolidated_pid].update(constant_state) - new_sd = {"state": new_state, "param_groups": sd["param_groups"]} - - for pg_id, _ in enumerate(sd["param_groups"]): - # TODO: this list could be huge. Can we avoid materializing? 
- new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) - - return new_sd - @torch.no_grad() def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: From c93d1db313891a45a652acb3428538442a6bcd65 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 15:14:02 -0400 Subject: [PATCH 23/31] passing --- .../nn/data_parallel/fsdp_optim_utils.py | 58 ++++++++++--------- .../test_fsdp_optimizer_utils.py | 2 +- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index d4185b8a6..d76bd4dfe 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -2,14 +2,14 @@ # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -"""These files are used by fsdp to help consolidate and shard optimizer states.""" +"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states.""" import copy from typing import Dict, Generator, List, Tuple import torch -# This function helps shard an +# This function helps shard def flatten_optim_state_dict(sd: Dict) -> Dict: """Called by FSDP.get_shard_from_optim_state_dict""" param_id_map = sd["param_id_map"] @@ -29,9 +29,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: new_state[consolidated_pid][buffer_name] = [] new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) else: - assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]" constant_state[buffer_name] = p - # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check for consolidated_pid, state in new_state.items(): for buffer_name, tensors in state.items(): @@ -45,8 +43,18 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: return new_sd +def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: + n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) + msg = ( + f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved " + f"there were {n_local_params_in_opt}" + ) + stateless = len(full_optim_state_dict["state"]) == 0 + assert stateless or (n_instances == n_local_params_in_opt), msg + -# All functions help saving the list of optimizer states, one from each rank + +# All functions below here help saving the list of optimizer states, one from each rank # build_unflat_state_dict is the interface used by FSDP def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. @@ -61,7 +69,7 @@ def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id return constant_state -def _combine_tensor_optim_state(states: List[Dict]) -> Dict[int, Dict]: +def _combine_state(states: List[Dict]) -> Dict[int, Dict]: combined_state = states[0] for param_id in combined_state: combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()} @@ -87,6 +95,17 @@ def _unflatten_optim_state( # we check that these are identical across workers and then take the first constant_state = [_extract_constant_state(combined_state, id) for id in combined_state] + # If the constant state is the same as the combined state, copy it N times, no unflattening needed. 
+ if constant_state[0].keys() == combined_state[0].keys(): + num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore + unflat_state = {i: constant_state[0] for i in range(sum(num_unflat_params))} + global_to_local_id = {} + for local_id, num_unflat in enumerate(num_unflat_params): + for _ in range(num_unflat): + global_to_local_id[next_global_id] = local_id + next_global_id += 1 + return unflat_state, global_to_local_id + # loop over parameters in state. # Tensor state will be padded, concatenated, and then restored to their original # shape with FlattenParamsWrapper.get_views @@ -125,27 +144,14 @@ def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_st assert all(len(s[0]) == 1 for s in world_pad_info) param_groups = copy.deepcopy(world_optim_states[0]["param_groups"]) assert len(param_groups) == 1 - # combined_state refers to tensor values in sd[state][param_id]. - # Here we just aggregate them into a dictionary of lists (from a list of dictionaries) - combined_state = _combine_tensor_optim_state([x["state"] for x in world_optim_states]) - # cleanup all_optimizer_states_list + + # Aggregate from a list of dictionaries to a dictionary of lists + combined_state = _combine_state([x["state"] for x in world_optim_states]) del world_optim_states - new_state_dict = {"state": {}, "param_groups": param_groups} + # local ids are in the current state, global_ids will be in returned state. unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info) num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore - new_state_dict["param_groups"][0]["params"] = list(range(num_params)) - new_state_dict["param_id_map"] = global_to_local_id - # Make sure that the parameters are sorted in the state, as expected for a pytorch dict - new_state_dict["state"] = dict(sorted(unflat_state.items())) - return new_state_dict - - -def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: - n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) - msg = ( - f"Including itself, this model has {n_instances} nested instances. 
When the optimizer state was saved " - f"there were {n_local_params_in_opt}" - ) - stateless = len(full_optim_state_dict["state"]) == 0 - assert stateless or (n_instances == n_local_params_in_opt), msg + return {"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted + "param_id_map": global_to_local_id, + "param_groups": [{'params': list(range(num_params))}]} diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 7e74acb7f..6864955f4 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -22,7 +22,7 @@ def first_tensor_shape(dct): for k, v in dct.items(): if torch.is_tensor(v): return v.numel() - raise ValueError("found no tensors") + return 0 def assert_equal(a, b): From 13b053755a618d39a1c205a527bfb3bdde2e2480 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 15:14:54 -0400 Subject: [PATCH 24/31] style --- fairscale/nn/data_parallel/fsdp_optim_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index d76bd4dfe..578c5b719 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -43,6 +43,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: return new_sd + def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values())) msg = ( @@ -53,7 +54,6 @@ def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: assert stateless or (n_instances == n_local_params_in_opt), msg - # All functions below here help saving the list of optimizer states, one from each rank # build_unflat_state_dict is the interface used by FSDP def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: @@ -152,6 +152,8 @@ def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_st # local ids are in the current state, global_ids will be in returned state. 
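To make the comment above concrete: each FSDP instance contributes a single flattened parameter (one "local" id), while the returned dict is keyed by one "global" id per original, unflattened parameter. A made-up example of the mapping for two nested instances that flattened three and two parameters respectively:

    # Hypothetical sizes, not taken from the tests in this series:
    num_unflat_params = [3, 2]               # original parameter count per FSDP instance
    global_to_local_id = {0: 0, 1: 0, 2: 0,  # global ids 0-2 expand flat param 0
                          3: 1, 4: 1}        # global ids 3-4 expand flat param 1

This mapping is stored as `param_id_map` in the consolidated dict, and it is what lets `flatten_optim_state_dict` collapse the state back down when a shard is reloaded.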
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info) num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore - return {"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted - "param_id_map": global_to_local_id, - "param_groups": [{'params': list(range(num_params))}]} + return { + "state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted + "param_id_map": global_to_local_id, + "param_groups": [{"params": list(range(num_params))}], + } From 9d3dfb77dabd2d77bcffb63c702f1384fd9547ad Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 15:44:08 -0400 Subject: [PATCH 25/31] passing --- .../nn/data_parallel/fsdp_optim_utils.py | 47 ++++++++----------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index 578c5b719..18734136e 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -86,32 +86,36 @@ def _combine_state(states: List[Dict]) -> Dict[int, Dict]: def _unflatten_optim_state( combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]], ) -> Tuple[Dict[int, Dict], Dict[int, int]]: - local_to_global_param_id: Dict[int, List[int]] = {} + # local ids are the keys in the current state (combined_state), (usually fewer) + # global ids will be the keys in the unflattened state next_global_id = 0 # gets incremented - unflat_state = {} pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} + local_ids = [id for id in sorted(combined_state.keys())] # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" # we check that these are identical across workers and then take the first constant_state = [_extract_constant_state(combined_state, id) for id in combined_state] + # local corresponds to flattened, global corresponds to unflattened + num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore + global_to_local_id = {} + for local_id, num_unflat in enumerate(num_unflat_params): + for _ in range(num_unflat): + global_to_local_id[next_global_id] = local_id + next_global_id += 1 + # If the constant state is the same as the combined state, copy it N times, no unflattening needed. + unflat_state = {i: copy.deepcopy(constant_state[0]) for i in range(sum(num_unflat_params))} if constant_state[0].keys() == combined_state[0].keys(): - num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore - unflat_state = {i: constant_state[0] for i in range(sum(num_unflat_params))} - global_to_local_id = {} - for local_id, num_unflat in enumerate(num_unflat_params): - for _ in range(num_unflat): - global_to_local_id[next_global_id] = local_id - next_global_id += 1 return unflat_state, global_to_local_id + local_to_global: Dict[int, List] = {i: [] for i in local_ids} + for g, l in global_to_local_id.items(): + local_to_global[l].append(g) # loop over parameters in state. - # Tensor state will be padded, concatenated, and then restored to their original - # shape with FlattenParamsWrapper.get_views - # get_views multiple tensors, each of which is a new parameter with a new "global" id. 
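A self-contained sketch of the unpad/concatenate/re-view work the loop below performs for each tensor buffer (it imitates, rather than calls, the `get_param_views` step, and the sizes and padding are invented):

    import torch

    # Two original parameters (6 and 3 elements) were flattened into one 9-element
    # buffer, padded to 10 so it splits evenly across 2 ranks (1 pad element on rank 1).
    per_rank_buffers = [torch.arange(5.0), torch.arange(5.0)]
    pad_per_rank = [0, 1]

    unpadded = [t[:-p] if p > 0 else t for t, p in zip(per_rank_buffers, pad_per_rank)]
    flat_buffer = torch.cat(unpadded)      # full, unsharded flat buffer (9 elements)
    views = flat_buffer.split([6, 3])      # recover the per-parameter tensors
    assert [v.numel() for v in views] == [6, 3]

Each recovered view is stored under its own global id, which is how a single flat-parameter entry fans out into several per-parameter entries in the returned state.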
- for local_id in combined_state: - local_to_global_param_id[local_id] = [] + # Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views + # get_views returns multiple tensors, each of which is a new parameter with a new "global" id. + for local_id in local_ids: # undo the work of shard_parameters for k, v in combined_state[local_id].items(): if k in constant_state[local_id]: @@ -120,21 +124,10 @@ def _unflatten_optim_state( v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] flat_buffer = torch.cat(v_unpad) param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore - for i, param_view in enumerate(param_views): - if i == len(local_to_global_param_id[local_id]): # make a new ID - local_to_global_param_id[local_id].append(next_global_id) - next_global_id += 1 - global_id = local_to_global_param_id[local_id][i] - if global_id not in unflat_state: - unflat_state[global_id] = copy.deepcopy(constant_state[local_id]) - - assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]" + for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views): + assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}" unflat_state[global_id][k] = param_view - global_to_local_id = { - new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids - } - return unflat_state, global_to_local_id From 9f619b2461a67ba140ad4bfa73f03c1497d8055f Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 16:15:47 -0400 Subject: [PATCH 26/31] stateless fix --- fairscale/nn/data_parallel/fsdp_optim_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index 18734136e..77f2e375e 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -103,6 +103,8 @@ def _unflatten_optim_state( for _ in range(num_unflat): global_to_local_id[next_global_id] = local_id next_global_id += 1 + if not combined_state: + return {}, global_to_local_id # If the constant state is the same as the combined state, copy it N times, no unflattening needed. 
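The two lines added just above cover the fully stateless case, which the test matrix exercises with plain SGD: an optimizer that keeps no per-parameter buffers reports an empty `state`, so there is nothing to unflatten and only the id mapping needs to be returned. A quick standalone check of that shape:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.SGD([p], lr=0.1)     # no momentum, hence no per-parameter state
    assert opt.state_dict()["state"] == {}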
unflat_state = {i: copy.deepcopy(constant_state[0]) for i in range(sum(num_unflat_params))} From a4778b7141510165d526b3d0f868ab73feb6bca9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 23 Mar 2021 17:22:36 -0400 Subject: [PATCH 27/31] Min comments --- .../nn/data_parallel/fsdp_optim_utils.py | 4 ++-- .../fully_sharded_data_parallel.py | 22 +++++++++++-------- .../test_fsdp_optimizer_utils.py | 8 ++++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index 77f2e375e..afb124cfe 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -9,9 +9,9 @@ import torch -# This function helps shard +# This function helps shard a full optimizer state dict def flatten_optim_state_dict(sd: Dict) -> Dict: - """Called by FSDP.get_shard_from_optim_state_dict""" + """Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)""" param_id_map = sd["param_id_map"] num_local_params = len(set(param_id_map.values())) if sd["state"]: diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index f591ae064..de6a23143 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -187,7 +187,7 @@ def __init__( self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu self.bucket_cap_mb = bucket_cap_mb - self.num_padded: List[int] = [] + self.numel_padded_per_param: List[int] = [] self.compute_device = compute_device if self.fp32_reduce_scatter and not self.mixed_precision: @@ -415,7 +415,7 @@ def _shard_parameters_(self) -> None: allocate less memory for optimizer state, avoiding redundancy across data parallel workers. """ - self.num_padded = [] + self.numel_padded_per_param = [] for p in self.params: assert not hasattr(p, "_is_sharded") assert p.is_floating_point() @@ -427,16 +427,16 @@ def _shard_parameters_(self) -> None: p._orig_size = p.data.size() if not p._is_sharded: - self.num_padded.append(0) + self.numel_padded_per_param.append(0) continue p._is_sharded = True # Replace p.data with the relevant shard. orig_data = p.data p.data, num_padded = self._get_shard(p.data) - self.num_padded.append(num_padded) + self.numel_padded_per_param.append(num_padded) free_storage_(orig_data) - assert len(self.num_padded) == len(self.params) + assert len(self.numel_padded_per_param) == len(self.params) def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: """Return the local shard of a full tensor.""" @@ -1359,9 +1359,14 @@ def _consolidate_optim_state_dict( """Update the consolidated state_dict list, one per rank. Args: + recipient_rank (int): on which rank to materialize the full state dict. None is a special value, which means that all ranks should have the state + Returns: + all_states (list[dict]) the optimizer state from each rank + + .. 
warning: This needs to be called on all replicas""" self._lazy_init() # NOTE(SS): we do not support param groups yet, as they seem to break FSDP @@ -1373,7 +1378,7 @@ def _consolidate_optim_state_dict( for rank in range(self.world_size): if rank == self.rank: sd = optim.state_dict() - sd["num_padded"] = [m.num_padded for m in self._fsdp_instances] + sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances] else: sd = dummy_tensor # type: ignore sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device) # type: ignore @@ -1391,9 +1396,8 @@ def gather_full_optim_state_dict( Returns: a dict with two entries - * state - a dict holding current optimization state. Its content - differs between optimizer classes. - * param_groups - a dict containing all parameter groups + * state - a dict holding gathered optimization state, 1 entry per unflat parameter + * param_groups - a dict containing the 1 parameter group """ if not self.flatten_parameters: diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 6864955f4..2fdfd8964 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -1,4 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. import functools +from time import time from parameterized import parameterized import torch @@ -29,9 +34,6 @@ def assert_equal(a, b): assert a == b, f"{a} != {b}" -from time import time - - class TestOptimizerUtils(DistributedTest): @parameterized.expand( [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]], From c77a9f7f29be2c5f4e18fe69c06d0fb8964602dc Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Mar 2021 10:33:54 -0400 Subject: [PATCH 28/31] Min comments --- .../fully_sharded_data_parallel.py | 7 ++++++- .../data_parallel/test_fsdp_optimizer_utils.py | 18 +++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index de6a23143..32b12152d 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -19,7 +19,6 @@ from torch.nn import Parameter import torch.nn.functional as F -import fairscale.nn.data_parallel.fsdp_optim_utils as ou from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device @@ -28,6 +27,8 @@ from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer from fairscale.utils.state_dict import replace_by_prefix_ +from . 
import fsdp_optim_utils as ou + if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 @@ -1407,6 +1408,7 @@ def gather_full_optim_state_dict( return None # Unify the shard states by concatenating tensors and unflattening params new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states) + # TODO: check if this code supports nested instances with different world size return new_state_dict @property @@ -1418,6 +1420,9 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) """Get the portion of the optimizer state dict associated with the shard""" # Assert nesting is the same as it was at save time instance_list = self._fsdp_instances + assert all( + x.world_size == self.world_size for x in instance_list + ), "all nested instances must have same world size" ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list)) if self.flatten_parameters: full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict) diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index 2fdfd8964..fafab8813 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -23,7 +23,7 @@ ) -def first_tensor_shape(dct): +def first_tensor_numel(dct): for k, v in dct.items(): if torch.is_tensor(v): return v.numel() @@ -62,7 +62,7 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim try: fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,) optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01) - except TypeError: # AdaScale + except TypeError: # Adadelta fsdp_optim = optim_fn(fsdp.parameters()) optim_unwrapped = optim_fn(unwrapped_model.parameters()) @@ -81,7 +81,6 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim optim_unwrapped.step() unwrapped_sd = optim_unwrapped.state_dict() - # first_key = unwrapped_sd['state'][0].keys() tstart = time() sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0) duration = time() - tstart @@ -94,8 +93,8 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) assert_equal( - sum([first_tensor_shape(v) for k, v in sd["state"].items()]), - sum([first_tensor_shape(v) for k, v in unwrapped_sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]), ) shard_sd = fsdp.get_shard_from_optim_state_dict(sd) @@ -103,15 +102,12 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim original_shard_sd = fsdp_optim.state_dict() assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"])) assert_equal(shard_sd.keys(), original_shard_sd.keys()) - torch.save(shard_sd, f"new_shard_{fsdp.world_size}.pt") original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu") assert_equal( - sum([first_tensor_shape(v) for k, v in shard_sd["state"].items()]), - sum([first_tensor_shape(v) for k, v in original_shard_sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]), + sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]), ) - if shard_sd["state"]: - assert objects_are_equal(shard_sd["state"][0], 
original_shard_sd["state"][0]) assert objects_are_equal(shard_sd["state"], original_shard_sd["state"]) def test_named_params_ordering(self): @@ -120,4 +116,4 @@ def test_named_params_ordering(self): model = TransformerWithSharedParams(group) named_pars = [p for n, p in model.named_parameters()] for i, p in enumerate(model.parameters()): - assert p.shape == named_pars[i].shape + assert objects_are_equal(p, named_pars[i]) From aeefe693336de1d8ff62d98d1e13f91019614d37 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Mar 2021 17:53:18 -0400 Subject: [PATCH 29/31] Apply suggestions from code review Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com> --- .../nn/data_parallel/fsdp_optim_utils.py | 6 +++-- .../fully_sharded_data_parallel.py | 22 ++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index afb124cfe..34c4e71eb 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -20,7 +20,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: new_state = {} constant_state = {} - # assumes sd sorted + # Populate `new_state["state"]`. (Assuming sd is sorted) for expanded_pid, buffers in sd["state"].items(): consolidated_pid = param_id_map[expanded_pid] for buffer_name, p in buffers.items(): @@ -31,12 +31,14 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: else: constant_state[buffer_name] = p + # Now combine all tensors in each buffer using torch.cat(). for consolidated_pid, state in new_state.items(): for buffer_name, tensors in state.items(): new_state[consolidated_pid][buffer_name] = torch.cat(tensors) new_state[consolidated_pid].update(constant_state) new_sd = {"state": new_state, "param_groups": sd["param_groups"]} + # add pointers from the `params` dict. for pg_id, _ in enumerate(sd["param_groups"]): # TODO: this list could be huge. Can we avoid materializing? new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) @@ -57,7 +59,7 @@ def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: # All functions below here help saving the list of optimizer states, one from each rank # build_unflat_state_dict is the interface used by FSDP def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: - constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it. + constant_state = {} # This state is like the `step` count in Adam, not a tensor so we dont unpad or cat it. for k, v in combined_state[param_id].items(): if torch.is_tensor(v[0]): diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 32b12152d..70b868548 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1361,6 +1361,8 @@ def _consolidate_optim_state_dict( Args: + optim (Optimizer): an optimizer instance for this FSDP rank. Its state is + used in the consolidation. However, its state is not modified. recipient_rank (int): on which rank to materialize the full state dict. None is a special value, which means that all ranks should have the state @@ -1394,6 +1396,14 @@ def gather_full_optim_state_dict( ) -> Optional[Dict[str, Any]]: """Return the last known global optimizer state. 
The returned state is compatible with Pytorch, in that the sharded properties are not exposed. Multiple parameter groups are not yet supported. + + This should be called only on the root FSDP instance. + + Different world_size groups in nested FSDP instances is not supported. + Args: + optim (Optimizer): an optimizer instance for this FSDP rank. Its state is + used in the consolidation. However, its state is not modified. + recipient_rank (int): on which rank to materialize the full state dict. Returns: a dict with two entries @@ -1417,7 +1427,17 @@ def _fsdp_instances(self) -> List[nn.Module]: return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)] def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: - """Get the portion of the optimizer state dict associated with the shard""" + """Get the portion of the optimizer state dict associated with the shard + + This can be used to get the right sharded optimizer state to be loaded + into the sharded optimizer for this FSDP rank. + + Args: + full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state``, or loaded from a checkpoint. + + Returns: + (dict): a shard of the optimizer state. + """ # Assert nesting is the same as it was at save time instance_list = self._fsdp_instances assert all( From d645337209163218562356c94082fb5c1fc32324 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Mar 2021 18:10:17 -0400 Subject: [PATCH 30/31] Min comments --- .../nn/data_parallel/fsdp_optim_utils.py | 28 +++++++++---------- .../fully_sharded_data_parallel.py | 10 +++---- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index 34c4e71eb..4cda519f9 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -18,7 +18,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: new_state: Dict = {local_id: {} for local_id in range(num_local_params)} else: new_state = {} - constant_state = {} + non_tensor_state = {} # Populate `new_state["state"]`. (Assuming sd is sorted) for expanded_pid, buffers in sd["state"].items(): @@ -29,13 +29,13 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: new_state[consolidated_pid][buffer_name] = [] new_state[consolidated_pid][buffer_name].append(p.reshape(-1)) else: - constant_state[buffer_name] = p + non_tensor_state[buffer_name] = p # Now combine all tensors in each buffer using torch.cat(). for consolidated_pid, state in new_state.items(): for buffer_name, tensors in state.items(): new_state[consolidated_pid][buffer_name] = torch.cat(tensors) - new_state[consolidated_pid].update(constant_state) + new_state[consolidated_pid].update(non_tensor_state) new_sd = {"state": new_state, "param_groups": sd["param_groups"]} # add pointers from the `params` dict. @@ -58,17 +58,16 @@ def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: # All functions below here help saving the list of optimizer states, one from each rank # build_unflat_state_dict is the interface used by FSDP -def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: - constant_state = {} # This state is like the `step` count in Adam, not a tensor so we dont unpad or cat it. 
+def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict: + non_tensor_state = {} # This state is like the `step` count in Adam, not a tensor so we dont unpad or cat it. for k, v in combined_state[param_id].items(): - if torch.is_tensor(v[0]): continue elif len(set(v)) == 1: - constant_state[k] = v[0] + non_tensor_state[k] = v[0] else: - raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}") - return constant_state + raise TypeError(f"Dont know how to consolidate optimizer param {k} with values {v}") + return non_tensor_state def _combine_state(states: List[Dict]) -> Dict[int, Dict]: @@ -94,9 +93,9 @@ def _unflatten_optim_state( pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state} local_ids = [id for id in sorted(combined_state.keys())] - # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step" + # non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step". # we check that these are identical across workers and then take the first - constant_state = [_extract_constant_state(combined_state, id) for id in combined_state] + non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state] # local corresponds to flattened, global corresponds to unflattened num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore @@ -109,8 +108,8 @@ def _unflatten_optim_state( return {}, global_to_local_id # If the constant state is the same as the combined state, copy it N times, no unflattening needed. - unflat_state = {i: copy.deepcopy(constant_state[0]) for i in range(sum(num_unflat_params))} - if constant_state[0].keys() == combined_state[0].keys(): + unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))} + if non_tensor_state[0].keys() == combined_state[0].keys(): return unflat_state, global_to_local_id local_to_global: Dict[int, List] = {i: [] for i in local_ids} @@ -122,7 +121,7 @@ def _unflatten_optim_state( for local_id in local_ids: # undo the work of shard_parameters for k, v in combined_state[local_id].items(): - if k in constant_state[local_id]: + if k in non_tensor_state[local_id]: continue assert isinstance(v, list), f"got {k}: {v} for {local_id}" v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])] @@ -136,6 +135,7 @@ def _unflatten_optim_state( def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict: + """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank.""" world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states] assert all(len(s) == len(instance_list) for s in world_pad_info) assert all(len(s[0]) == 1 for s in world_pad_info) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 70b868548..16862cd9c 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1396,9 +1396,9 @@ def gather_full_optim_state_dict( ) -> Optional[Dict[str, Any]]: """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the sharded properties are not exposed. Multiple parameter groups are not yet supported. - + This should be called only on the root FSDP instance. 
- + Different world_size groups in nested FSDP instances is not supported. Args: optim (Optimizer): an optimizer instance for this FSDP rank. Its state is @@ -1428,13 +1428,13 @@ def _fsdp_instances(self) -> List[nn.Module]: def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: """Get the portion of the optimizer state dict associated with the shard - + This can be used to get the right sharded optimizer state to be loaded into the sharded optimizer for this FSDP rank. - + Args: full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state``, or loaded from a checkpoint. - + Returns: (dict): a shard of the optimizer state. """ From 75bdd3fb74ca8d5816438fb66bf4698cbc15d218 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Mar 2021 20:29:44 -0400 Subject: [PATCH 31/31] also test param groups --- fairscale/nn/data_parallel/fsdp_optim_utils.py | 3 ++- tests/nn/data_parallel/test_fsdp_optimizer_utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fairscale/nn/data_parallel/fsdp_optim_utils.py b/fairscale/nn/data_parallel/fsdp_optim_utils.py index 4cda519f9..841d7a8af 100644 --- a/fairscale/nn/data_parallel/fsdp_optim_utils.py +++ b/fairscale/nn/data_parallel/fsdp_optim_utils.py @@ -149,8 +149,9 @@ def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_st # local ids are in the current state, global_ids will be in returned state. unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info) num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore + param_groups[0]["params"] = list(range(num_params)) # This could be a large list. #TODO: is it essential return { "state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted "param_id_map": global_to_local_id, - "param_groups": [{"params": list(range(num_params))}], + "param_groups": param_groups, } diff --git a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py index fafab8813..806fc7868 100644 --- a/tests/nn/data_parallel/test_fsdp_optimizer_utils.py +++ b/tests/nn/data_parallel/test_fsdp_optimizer_utils.py @@ -108,7 +108,7 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]), sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]), ) - assert objects_are_equal(shard_sd["state"], original_shard_sd["state"]) + assert objects_are_equal(shard_sd, original_shard_sd) def test_named_params_ordering(self): """Test assumption of consolidate_optimizer_state_dict"""
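Taken together, the interface that lands at the end of this series supports a checkpointing flow along these lines (a sketch only; it assumes an initialized process group, an FSDP-wrapped `model`, a single parameter group, and an optimizer `optim` built over `model.parameters()`):

    import torch
    import torch.distributed as dist

    # On every rank: consolidate the sharded optimizer state and unflatten it.
    # Only the recipient rank is guaranteed to get the full dict back.
    full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
    if dist.get_rank() == 0:
        torch.save(full_osd, "optim_full.pt")

    # Later, on every rank: load the consolidated state and carve out the
    # local shard before handing it to this rank's optimizer.
    full_osd = torch.load("optim_full.pt", map_location="cpu")
    shard_osd = model.get_shard_from_optim_state_dict(full_osd)
    optim.load_state_dict(shard_osd)

This mirrors what the updated test asserts: the gathered dict holds one `state` entry per unflattened parameter, and the re-sharded dict matches each rank's original `optim.state_dict()`.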