[FSDP][feature] optimizer state dict save and load #537

Merged · 33 commits · Mar 25, 2021 (changes shown from 30 commits)

Commits (33)
ee088bb  consolidate works (sshleifer, Mar 19, 2021)
ad7df24  cat (sshleifer, Mar 19, 2021)
ed7526a  Unpad before cat (sshleifer, Mar 19, 2021)
ed75c59  update params list (sshleifer, Mar 19, 2021)
44158f7  simple case passing (sshleifer, Mar 19, 2021)
f82f3b6  found other bug (sshleifer, Mar 19, 2021)
1022e1e  Broken tests for other optimizers (sshleifer, Mar 19, 2021)
75119c2  boom boom (sshleifer, Mar 20, 2021)
89947a4  Merge branch 'master' into fsdp-gather-optimizer (sshleifer, Mar 20, 2021)
8dcf0a8  remove oss changes (sshleifer, Mar 20, 2021)
2caf928  passing besides mypy (sshleifer, Mar 20, 2021)
0b888fd  Smaller delta (sshleifer, Mar 20, 2021)
a2aacd0  Nesting works (sshleifer, Mar 21, 2021)
0fc045d  passing, lint attempt (sshleifer, Mar 21, 2021)
d859734  merge master (sshleifer, Mar 21, 2021)
3635277  update test list (sshleifer, Mar 21, 2021)
dbb426f  mypy (sshleifer, Mar 22, 2021)
f537632  Simpler consolidate_optim_state_dict (sshleifer, Mar 22, 2021)
a04b406  slightly cleaner (sshleifer, Mar 22, 2021)
e5e91df  Simplified signature, helper fn for unflattening (sshleifer, Mar 22, 2021)
ea9d4b5  add todo (sshleifer, Mar 22, 2021)
47e7cba  Give CI more time to show me a traceback (sshleifer, Mar 23, 2021)
6cebcec  Fix broadcast_object regression (sshleifer, Mar 23, 2021)
93c0857  Move most dictionary manipulation to fsdp_optim_utils.py (sshleifer, Mar 23, 2021)
c93d1db  passing (sshleifer, Mar 23, 2021)
13b0537  style (sshleifer, Mar 23, 2021)
9d3dfb7  passing (sshleifer, Mar 23, 2021)
9f619b2  stateless fix (sshleifer, Mar 23, 2021)
a4778b7  Min comments (sshleifer, Mar 23, 2021)
c77a9f7  Min comments (sshleifer, Mar 24, 2021)
aeefe69  Apply suggestions from code review (sshleifer, Mar 24, 2021)
d645337  Min comments (sshleifer, Mar 24, 2021)
75bdd3f  also test param groups (sshleifer, Mar 25, 2021)
154 changes: 154 additions & 0 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -0,0 +1,154 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple

import torch


# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
else:
new_state = {}
constant_state = {}

# assumes sd sorted
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
else:
constant_state[buffer_name] = p

for consolidated_pid, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(constant_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

return new_sd


def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
msg = (
f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
f"there were {n_local_params_in_opt}"
)
stateless = len(full_optim_state_dict["state"]) == 0
assert stateless or (n_instances == n_local_params_in_opt), msg


# All functions below here help saving the list of optimizer states, one from each rank
# build_unflat_state_dict is the interface used by FSDP
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
constant_state = {} # This state is like step in Adam, not a tensor, so we don't unpad or cat it.
for k, v in combined_state[param_id].items():

if torch.is_tensor(v[0]):
continue
elif len(set(v)) == 1:
constant_state[k] = v[0]
else:
raise TypeError(f"Don't know how to expand optimizer param {k} with values {v}")
return constant_state


def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state


def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}
local_ids = [id for id in sorted(combined_state.keys())]

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [_extract_constant_state(combined_state, id) for id in combined_state]

# local corresponds to flattened, global corresponds to unflattened
num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_unflat_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
next_global_id += 1
if not combined_state:
return {}, global_to_local_id

# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(constant_state[0]) for i in range(sum(num_unflat_params))}
if constant_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id

local_to_global: Dict[int, List] = {i: [] for i in local_ids}
for g, l in global_to_local_id.items():
local_to_global[l].append(g)
# loop over parameters in state.
# Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views
# get_views returns multiple tensors, each of which is a new parameter with a new "global" id.
for local_id in local_ids:
# undo the work of shard_parameters
for k, v in combined_state[local_id].items():
if k in constant_state[local_id]:
continue
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view

return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:

Contributor (review comment): add a docstring?

world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1

# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
del world_optim_states

# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
return {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": [{"params": list(range(num_params))}],
}
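
To make the flatten step concrete, here is a small, hypothetical sketch of flatten_optim_state_dict (as added in this PR) applied to a toy "full" Adam state dict. It is not part of the PR; it assumes the module is importable as fairscale.nn.data_parallel.fsdp_optim_utils and that three unflattened parameters all map onto a single flattened FSDP parameter via param_id_map.

import torch
from fairscale.nn.data_parallel import fsdp_optim_utils as ou

# Toy "full" (unflattened) optimizer state: three original parameters that all
# map onto local flat param 0.
full_sd = {
    "state": {
        0: {"exp_avg": torch.zeros(4), "step": 10},
        1: {"exp_avg": torch.zeros(2), "step": 10},
        2: {"exp_avg": torch.zeros(3), "step": 10},
    },
    "param_id_map": {0: 0, 1: 0, 2: 0},
    "param_groups": [{"lr": 0.1, "params": [0, 1, 2]}],
}

flat_sd = ou.flatten_optim_state_dict(full_sd)

# Tensor state is concatenated per flat param; non-tensor state ("step") is
# carried over unchanged; param_groups now index the flat params.
assert flat_sd["state"][0]["exp_avg"].numel() == 4 + 2 + 3
assert flat_sd["state"][0]["step"] == 10
assert flat_sd["param_groups"][0]["params"] == [0]

This is the transformation get_shard_from_optim_state_dict applies before re-sharding each tensor with _get_shard.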
109 changes: 101 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
@@ -21,12 +21,14 @@

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

from . import fsdp_optim_utils as ou

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401

@@ -88,8 +90,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)

Contributor (review comment): Thanks for fixing the doc here!

with enable_wrap(**fsdp_params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(self.l1, FSDP)
@@ -185,6 +187,8 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb

self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
@@ -412,6 +416,7 @@ def _shard_parameters_(self) -> None:
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.numel_padded_per_param = []
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
@@ -423,16 +428,19 @@
p._orig_size = p.data.size()

if not p._is_sharded:
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True

# Replace p.data with the relevant shard.
orig_data = p.data
p.data = self._get_shard(p.data)
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)

def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
"""Return the local shard of a given full tensor."""
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
@@ -445,7 +453,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard
return shard, num_to_pad

def extra_repr(self) -> str:
return (
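
As a standalone illustration of the _get_shard logic a few lines above, a plain-PyTorch sketch (not the FSDP class itself) of the chunk-and-pad behaviour and the new num_to_pad return value:

import torch
import torch.nn.functional as F

def get_shard(tensor: torch.Tensor, rank: int, world_size: int):
    """Mirror of FSDP._get_shard: flatten, chunk, and right-pad the local chunk."""
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad

full = torch.arange(10.0)
for r in range(4):
    shard, pad = get_shard(full, r, 4)
    print(r, shard.tolist(), pad)
# Ranks 0..2 get 3 real elements with pad 0; rank 3 gets [9.0, 0.0, 0.0] with pad 2.
# These pad counts are what _shard_parameters_ records in numel_padded_per_param,
# and what the optimizer-state consolidation later uses to unpad before concatenating.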
@@ -684,7 +692,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard = self._get_shard(full_tensor)
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
@@ -1346,6 +1354,91 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
"""Update the consolidated state_dict list, one per rank.

Contributor (review comment): should this be called only on the root FSDP instance?

Contributor Author (sshleifer): Yes; more specifically, it should be called on the instance whose parameters were passed to the optimizer, i.e. optimizer(model.parameters()). Are there other cases?

Args:

recipient_rank (Optional[int]): on which rank to materialize the full state dict.
None is a special value, which means that all ranks should have the state.

Returns:
all_states (list[dict]): the optimizer state from each rank


.. warning: This needs to be called on all replicas"""
self._lazy_init()
# NOTE(SS): we do not support param groups yet, as they seem to break FSDP
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):

Contributor (review comment): there might be complications here when nested FSDP instances have different world_size, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? Add a TODO?

Contributor Author (sshleifer): added a TODO in the caller.

if rank == self.rank:
sd = optim.state_dict()
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device) # type: ignore
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))

return all_states

def gather_full_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.

Returns:
a dict with two entries
* state - a dict holding gathered optimization state, 1 entry per unflat parameter
* param_groups - a dict containing the 1 parameter group

"""
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
if self.rank != recipient_rank and recipient_rank is not None:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
# TODO: check if this code supports nested instances with different world size
return new_state_dict

@property
def _fsdp_instances(self) -> List[nn.Module]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
assert all(
x.world_size == self.world_size for x in instance_list
), "all nested instances must have same world size"
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
if self.flatten_parameters:

Contributor (review comment): does this assume all inner FSDP instances also have flatten == True?

Contributor Author (sshleifer): Yes, will assert.

full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (0, len(instance_list))

# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
v_shard, _ = self._get_shard(v)
else:
v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard

return full_optim_state_dict


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
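
Putting the two new public methods together, a hypothetical end-to-end save/restore flow might look like the sketch below. It is not taken from the PR; it assumes a CUDA machine with NCCL, a shared filesystem, launching via something like torchrun --nproc_per_node=2 demo.py, and flatten_parameters left at its default of True.

import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
rank = dist.get_rank()

model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so the optimizer actually has state to consolidate.
model(torch.randn(2, 8, device="cuda")).sum().backward()
optim.step()

# Rank 0 receives the consolidated, unflattened state; other ranks get None.
full_sd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
if rank == 0:
    torch.save(full_sd, "optim_full.pt")
dist.barrier()

# On restore, every rank loads the full dict and extracts only its own shard.
loaded = torch.load("optim_full.pt", map_location="cpu")
shard_sd = model.get_shard_from_optim_state_dict(loaded)
optim.load_state_dict(shard_sd)

The test file added to the CI list below, tests/nn/data_parallel/test_fsdp_optimizer_utils.py, covers this gather/re-shard round trip.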
6 changes: 3 additions & 3 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
# register the views as plain attributes
self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
def get_param_views(self, flat_param: Tensor) -> Generator:

Contributor (review comment): since this is becoming a public method, can you please:
1. add a docstring with proper documentation
2. assert flat_param is valid before using it?

return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
assert self.is_flattened or flat_param is not None
self.is_flattened = False
flat_param = flat_param if flat_param is not None else self.flat_param

ps = self._get_param_views(flat_param)
ps = self.get_param_views(flat_param)
for (m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
@@ -144,7 +144,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:

def _unflatten_params_as_views(self) -> None:
assert self.is_flattened
ps = self._get_param_views(self.flat_param)
ps = self.get_param_views(self.flat_param)
for (m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
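
Since _get_param_views is promoted to a public get_param_views here, a brief usage sketch (hypothetical, not from the PR) of how a flat buffer is split back into per-parameter views:

import torch
from fairscale.nn.misc import FlattenParamsWrapper

module = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
wrapped = FlattenParamsWrapper(module)

# Any buffer with the same numel as flat_param can be split back into views
# with the original parameter shapes. This is what the optimizer-state
# unflattening does with the concatenated, unpadded per-parameter buffers.
flat_buffer = torch.zeros_like(wrapped.flat_param)
views = list(wrapped.get_param_views(flat_buffer))
print([tuple(v.shape) for v in views])  # e.g. [(3, 4), (3,), (2, 3), (2,)]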
1 change: 1 addition & 0 deletions tests/ci_test_list_3.txt
@@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py