[FSDP][feature] optimizer state dict save and load #537
@@ -21,7 +21,7 @@
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
@@ -88,8 +88,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
    # Wraps layer in FSDP by default if within context
    self.l1 = wrap(torch.nn.Linear(5, 5))
    assert isinstance(self.l1, FSDP)
@@ -185,6 +185,9 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.num_padded: List[int] = []
self._all_optimizer_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
@@ -412,6 +415,7 @@ def _shard_parameters_(self) -> None:
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.num_padded = []
for p in self.params:
    assert not hasattr(p, "_is_sharded")
    assert p.is_floating_point()
@@ -423,16 +427,18 @@ def _shard_parameters_(self) -> None:
p._orig_size = p.data.size()

if not p._is_sharded:
    self.num_padded.append(0)
    continue
p._is_sharded = True

# Replace p.data with the relevant shard.
orig_data = p.data
p.data = self._get_shard(p.data)
p.data, num_padded = self._get_shard(p.data)
self.num_padded.append(num_padded)
free_storage_(orig_data)

def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
    """Return the local shard of a given full tensor."""
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(self.world_size))
    while len(chunks) < self.world_size:
@@ -445,7 +451,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
shard = chunks[self.rank].clone()
if num_to_pad > 0:
    shard = F.pad(shard, [0, num_to_pad])
return shard
return shard, num_to_pad

def extra_repr(self) -> str:
    return (
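For reference, a minimal standalone sketch of the chunk-and-pad scheme that `_get_shard` now implements; the `rank` and `world_size` arguments stand in for the values FSDP reads from its process group (this is an illustration, not part of the diff):

    import torch
    import torch.nn.functional as F

    def get_shard(tensor: torch.Tensor, rank: int, world_size: int):
        # Shard with torch.chunk so the layout matches all-gather/reduce-scatter.
        chunks = list(torch.flatten(tensor).chunk(world_size))
        while len(chunks) < world_size:
            chunks.append(chunks[0].new_empty(0))
        # Pad the local shard up to the size of the largest chunk.
        num_to_pad = chunks[0].numel() - chunks[rank].numel()
        shard = chunks[rank].clone()
        if num_to_pad > 0:
            shard = F.pad(shard, [0, num_to_pad])
        return shard, num_to_pad

    shard, num_padded = get_shard(torch.arange(10.0), rank=3, world_size=4)
    assert shard.numel() == 3 and num_padded == 2  # chunks are 3,3,3,1; rank 3 pads by 2

The per-rank `num_padded` values recorded in `_shard_parameters_` are exactly these pad counts; the optimizer-state gathering below uses them to strip the padding again.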
@@ -684,7 +690,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
if not volatile:
    # Copy any changes made to the full params back into
    # the corresponding local shards.
    local_shard = self._get_shard(full_tensor)
    local_shard, _ = self._get_shard(full_tensor)
    p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
    free_storage_(full_tensor)
@@ -1346,6 +1352,244 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
traceback.print_stack()
raise ValueError(msg)

# Optim State dict functions

Review comment: I considered moving these to a separate …
def consolidate_optim_state_dict(self, optim: torch.optim.Optimizer, recipient_rank: int = 0) -> None:
    """Update the consolidated state_dict list, one per rank.

    Arguments:
        recipient_rank (int): on which rank to materialize the full state dict.
            -1 is a special value, which means that all ranks should have the state.

    .. warning: This needs to be called on all replicas"""
    _default_device = torch.device("cuda")
    # NOTE(SS): we do not support param groups yet, as they seem to break FSDP
    # Pull the sharded state from all the other replicas
    # Store all the states in order, rank by rank
    should_collect_state = self.rank == recipient_rank or recipient_rank == -1
    should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
    _all_optimizer_states: List[Dict[str, Any]] = []
    for rank in range(self.world_size):

Review comment: there might be complications here when nested FSDP instances have different world_size, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? add a TODO?

Reply: added TODO in the caller
        if rank == self.rank:
            sd = optim.state_dict()
            sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
            if should_collect_state:
                _all_optimizer_states.append(
                    recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu"))
                )

            # Sync with other replicas
            state_to_share = (
                sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device)
            )
            broadcast_object(
                state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device,
            )
        else:
            # Fetch the optim state from the other replicas
            replica_state = broadcast_object(
                torch.tensor([0], dtype=torch.uint8, device=_default_device),
                src_rank=rank,
                group=self.process_group,
                dist_device=_default_device,
            )

            if should_collect_state:
                _all_optimizer_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )

Review comment: can this be rearranged to remove some duplication? Something like:

    for rank in range(self.world_size):
        if rank == self.rank:
            state = optim.state_dict()
            state["num_padded"] = ...
            state = broadcast_object(state, src_rank=rank, ...)
        else:
            state = broadcast_object(None, src_rank=rank, ...)
        if should_collect_state:
            _all_optimizer_states.append(recursive_copy_to_device(state, device=torch.device("cpu")))

Reply: Just copy pasted this func from OSS. I think the reason for the extra append is to save useless communication from …

Reply: I have the simplified implem working with …

    self._all_optimizer_states = _all_optimizer_states
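As a rough usage sketch (not part of this diff; `fsdp_model`, `optim`, and the file name are placeholders), saving a consolidated optimizer state would look something like:

    fsdp_model.consolidate_optim_state_dict(optim, recipient_rank=0)  # collective: run on every rank
    if torch.distributed.get_rank() == 0:
        full_osd = fsdp_model.gather_full_optim_state_dict()
        torch.save(full_osd, "optim_state_full.pt")

The consolidation step is the only collective; `gather_full_optim_state_dict` (below) then works purely on the locally cached `_all_optimizer_states`.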
def gather_full_optim_state_dict(self) -> Dict[str, Any]:
    """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
    sharded properties are not exposed.

    Arguments:
        all_ranks (bool): materialize the state on all ranks. In that case, `.state_dict()` needs to be called on

    Returns:
        a dict with two entries
            * state - a dict holding current optimization state. Its content
                differs between optimizer classes.

            * param_groups - a dict containing all parameter groups

    .. warning:
        Returning the global state is limited to the replica which was responsible for the consolidation,
        if `all_ranks` was not set to `True`. In that case, the state may also not be up to date,
        depending on when `consolidate_state_dict` was last called.
    """
    if not self.flatten_parameters:
        raise NotImplementedError("optim state dict requires flatten_parameters=True")
    if len(self._all_optimizer_states) == 0:
        raise ValueError("You must call consolidate_optim_state_dict before gather_full_optim_state_dict")
    # Unify the shard states by concatenating tensors and unflattening params
    world_pad_info: List[List[int]] = [s.pop("num_padded") for s in self._all_optimizer_states]

    # constant_state refers to entries in sd[state][param_id] that are not tensors

    param_groups = copy.deepcopy(self._all_optimizer_states[0]["param_groups"])
    combined_state = self._combine_tensor_optim_state([x["state"] for x in self._all_optimizer_states])
    constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state]

    # cleanup all_optimizer_states_list
    self._all_optimizer_states = []

    new_state_dict = {"state": {}, "param_groups": param_groups}
    instance_list = self._fsdp_instances
    numels_per_instance = [sum(m._param_numels) for m in instance_list]  # noqa

    # loop over parameters in state.
    # If they are tensors, unpad each rank properly, concatenate it, and then
    # call _get_param_views. This returns multiple tensors
    # each of which is a new parameter with a new "global" id.

    local_to_global_param_id: Dict[int, List[int]] = {}
    # local ids are the ids in the current (sharded) state, global ids will be the ids in the returned state.

    global_param_id = 0  # gets incremented
    for pg_id, pg0 in enumerate(param_groups):
        for param_id in pg0["params"]:
            local_to_global_param_id[param_id] = []
            if param_id not in combined_state:
                continue
            # undo the work of shard_parameters
            for k, v in combined_state[param_id].items():
                if k in constant_state[param_id]:
                    continue
                assert isinstance(v, list), f"expected list, got {k}:{v} for {param_id} at rank {self.rank}"
                assert all(len(s[param_id]) == 1 for s in world_pad_info)
                pad_info = [s[param_id][0] for s in world_pad_info]
                assert len(pad_info) == self.world_size == len(v), f"{len(pad_info), self.world_size, len(v)}"
                v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)]
                flat_buffer = torch.cat(v_unpad)
                assert (
                    numels_per_instance[param_id] == flat_buffer.shape[0] == flat_buffer.numel()
                ), f"{numels_per_instance[param_id]},{flat_buffer.shape[0]}, {flat_buffer.numel()}"
                param_views: Generator = instance_list[param_id].get_param_views(flat_buffer)
                for i, param_view in enumerate(param_views):
                    if i == len(local_to_global_param_id[param_id]):
                        # We have not seen this global param before, and make a new ID
                        local_to_global_param_id[param_id].append(global_param_id)
                        global_param_id += 1
                    cur_pid = local_to_global_param_id[param_id][i]
                    if cur_pid not in new_state_dict["state"]:
                        new_state_dict["state"][cur_pid] = copy.deepcopy(constant_state[param_id])
                    assert k not in new_state_dict["state"][cur_pid]
                    new_state_dict["state"][cur_pid][k] = param_view

        if global_param_id == 0:  # stateless optimizer
            num_params = sum([len(m._param_numels) for m in instance_list])  # noqa
            new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params))

Review comment: this list could be quite large, right? I guess this only affects SGD w/o momentum, but I wonder if there's a more compact way. Let's not worry about it for now, but perhaps put a note or TODO to make it more efficient

Reply: Are you talking about …

        else:
            new_state_dict["param_groups"][pg_id]["params"] = list(range(global_param_id))

    global_to_local_id = {}
    for old_pid, global_ids in local_to_global_param_id.items():
        for nid in global_ids:
            global_to_local_id[nid] = old_pid

    new_state_dict["param_id_map"] = global_to_local_id
    # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
    new_state_dict["state"] = dict(sorted(new_state_dict["state"].items()))
    return new_state_dict
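The unpad-and-concatenate step above can be pictured with a tiny standalone example (made-up values, not part of the diff):

    import torch

    v = [torch.arange(4.0), torch.arange(4.0)]  # per-rank shards of one optimizer buffer
    pad_info = [0, 2]                           # rank 1 padded its shard with 2 elements
    v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info)]
    flat_buffer = torch.cat(v_unpad)
    assert flat_buffer.numel() == 4 + 2         # the original flat parameter had 6 elements

`get_param_views` then splits `flat_buffer` back into the original per-parameter shapes, which is where the new "global" parameter ids come from.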
def _combine_tensor_optim_state(self, states: List[Dict]) -> Dict[int, Dict]:
    combined_state = states[0]
    for param_id in combined_state:
        combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
    if self.world_size == 1:
        return combined_state

    for rank, s in enumerate(states[1:]):
        for param_id, param_state in s.items():
            for k, tensor in param_state.items():
                combined_state[param_id][k].append(tensor)
    return combined_state

@staticmethod
def _extract_constant_state(sd0: Dict, param_id: int) -> Dict[str, Any]:
    constant_state = {}  # This state is like step in Adam, not a tensor so we dont unpad or cat it.
    for k, v in sd0[param_id].items():
        if torch.is_tensor(v[0]):
            continue
        elif len(set(v)) == 1:
            constant_state[k] = v[0]
        else:
            raise TypeError(f"Dont know how to expand optimizer param {k} with value {v}")
    return constant_state
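For intuition (illustrative values, not part of the diff): with Adam-style state that `_combine_tensor_optim_state` has already grouped into per-rank lists, the extraction behaves like this:

    import torch

    combined = {0: {"step": [10, 10], "exp_avg": [torch.zeros(3), torch.zeros(3)]}}
    # _extract_constant_state(combined, 0) returns {"step": 10}: tensor entries are
    # skipped, and non-tensor entries must agree across ranks or a TypeError is raised.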
@property
def _fsdp_instances(self) -> List[nn.Module]:
    """Returns all fsdp modules including self."""
    assert self._is_root
    return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Get the portion of the optimizer state dict associated with the shard"""
    # Assert nesting is the same as it was at save time
    n_instances = len(self._fsdp_instances)
    n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
    msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}"
    stateless = len(full_optim_state_dict["state"]) == 0
    assert stateless or (n_instances == n_local_params_in_opt), msg

    stateless = len(full_optim_state_dict["state"]) == 0
    mlist = self._fsdp_instances
    if self.flatten_parameters:

Review comment: does this assume all inner FSDP instances also have flatten == True?

Reply: Yes, will assert

        full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict)
        assert stateless or len(full_optim_state_dict["state"]) == len(mlist)

    # get the portion of dict associated with the shard
    for id, s in full_optim_state_dict["state"].items():
        for k, v in s.items():
            if torch.is_tensor(v):
                v_shard, _ = self._get_shard(v)
            else:
                v_shard = v  # dont partition entries that are not tensors
            full_optim_state_dict["state"][id][k] = v_shard

    return full_optim_state_dict
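Conversely, a rough sketch of the load path on each rank (placeholders as in the save sketch above; not part of the diff):

    full_osd = torch.load("optim_state_full.pt", map_location="cpu")
    sharded_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
    optim.load_state_dict(sharded_osd)

Since `_get_shard` is reused here, each rank ends up with optimizer tensors chunked and padded exactly like its flattened parameter shard.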
@staticmethod
def _flatten_optim_state_dict(sd: Dict) -> Dict:
    param_id_map = sd["param_id_map"]
    num_local_params = len(set(param_id_map.values()))
    if sd["state"]:
        new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
    else:
        new_state = {}
    constant_state = {}

    # assumes sd sorted
    for expanded_pid, buffers in sd["state"].items():
        consolidated_pid = param_id_map[expanded_pid]
        for buffer_name, p in buffers.items():
            if torch.is_tensor(p):
                if buffer_name not in new_state[consolidated_pid]:
                    new_state[consolidated_pid][buffer_name] = []
                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
            else:
                assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
                constant_state[buffer_name] = p
                # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

    for consolidated_pid, state in new_state.items():
        for buffer_name, tensors in state.items():
            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
        new_state[consolidated_pid].update(constant_state)
    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

    for pg_id, _ in enumerate(sd["param_groups"]):
        new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

    return new_sd
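To make the id bookkeeping concrete (illustrative values only, not part of the diff): `param_id_map`, written by `gather_full_optim_state_dict`, maps every expanded (unflattened) parameter id back to the flat-parameter id it was carved out of, and `_flatten_optim_state_dict` inverts that mapping:

    param_id_map = {0: 0, 1: 0, 2: 1}  # views 0 and 1 came from flat param 0, view 2 from flat param 1
    num_local_params = len(set(param_id_map.values()))  # 2 flat parameters to rebuild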
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
@@ -122,15 +122,16 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
    # register the views as plain attributes
    self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
    return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
def get_param_views(self, flat_param: Tensor) -> Generator:

Review comment: since this is becoming a public method, can you please: …

    splat = flat_param.split(self._param_numels)
    return (t.view(s) for (t, s) in zip(splat, self._param_shapes))

def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
    assert self.is_flattened or flat_param is not None
    self.is_flattened = False
    flat_param = flat_param if flat_param is not None else self.flat_param

    ps = self._get_param_views(flat_param)
    ps = self.get_param_views(flat_param)
    for (m, n), p in zip(self._param_infos, ps):
        if hasattr(m, n):
            delattr(m, n)
@@ -144,7 +145,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
def _unflatten_params_as_views(self) -> None:
    assert self.is_flattened
    ps = self._get_param_views(self.flat_param)
    ps = self.get_param_views(self.flat_param)
    for (m, n), p in zip(self._param_infos, ps):
        setattr(m, n, p)  # This will set as plain attr
    for (m, n, shared_m, shared_n) in self._shared_param_infos:
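A quick standalone illustration of what `get_param_views` computes, with made-up shapes (not part of the diff):

    import torch

    param_shapes = [(2, 3), (4,)]
    param_numels = [6, 4]
    flat_param = torch.arange(10.0)
    views = [t.view(s) for t, s in zip(flat_param.split(param_numels), param_shapes)]
    assert views[0].shape == (2, 3) and views[1].shape == (4,)

This is the inverse of flattening, and it is what lets the gathered optimizer buffers above be handed back out as one entry per original parameter.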
Review comment: Thanks for fixing the doc here!