Initial support for ZeRO optimizer state sharding (#1259)
Summary:
FairseqOSS will work with any optimizer and dtype.

TODO(future PR):
* support reduce instead of all_reduce
* support gradient sharding
* support parameter sharding
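
For a rough sense of what optimizer state (OS) sharding buys: with Adam, the optimizer keeps two FP32 moment buffers per parameter (plus an FP32 master weight in FP16 training), and ZeRO-OS splits that state across data-parallel ranks. A back-of-the-envelope sketch with illustrative numbers that are not from this PR:

```python
# Rough memory arithmetic for ZeRO optimizer state (OS) sharding with Adam.
# Assumes two FP32 moment buffers (exp_avg, exp_avg_sq) per parameter; the
# FP32 master weights kept by FP16 training would add another 4 bytes/param.
def adam_state_gib(num_params: int, world_size: int = 1) -> float:
    bytes_per_param = 2 * 4  # exp_avg + exp_avg_sq in FP32
    return bytes_per_param * num_params / world_size / 2**30

print(adam_state_gib(250_000_000))                # ~1.86 GiB on every rank, unsharded
print(adam_state_gib(250_000_000, world_size=8))  # ~0.23 GiB per rank with OS sharding
```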

Pull Request resolved: fairinternal/fairseq-py#1259

Test Plan:
Verified that checkpoint save and restore work.

Verified that grad_norm, loss, and ppl are identical with and without
sharding enabled.

Before:

$ fairseq-train --task language_modeling   data-bin/wikitext-103   --save-dir checkpoints/transformer_wikitext-103   --arch transformer_lm --share-decoder-input-output-embed   --dropout 0.1   --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0   --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07   --tokens-per-sample 512 --sample-break-mode none   --max-tokens 2048 --update-freq 16   --max-update 50000  --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50
...
2020-08-27 22:24:51 | INFO | train_inner | epoch 001:     49 / 394 loss=18.84, ppl=469411, wps=269226, ups=1.03, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68
2020-08-27 22:24:52 | INFO | train_inner | epoch 001:     50 / 394 loss=18.787, ppl=452312, wps=256992, ups=0.98, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69
2020-08-27 22:24:53 | INFO | train_inner | epoch 001:     51 / 394 loss=18.74, ppl=437735, wps=259178, ups=0.99, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70
2020-08-27 22:24:54 | INFO | train_inner | epoch 001:     52 / 394 loss=18.683, ppl=420727, wps=257710, ups=0.98, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71
2020-08-27 22:24:55 | INFO | train_inner | epoch 001:     53 / 394 loss=18.623, ppl=403794, wps=269279, ups=1.03, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72
2020-08-27 22:24:56 | INFO | train_inner | epoch 001:     54 / 394 loss=18.574, ppl=390255, wps=264616, ups=1.01, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73
2020-08-27 22:24:56 | INFO | fairseq_cli.train | begin save checkpoint
2020-08-27 22:24:56 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2020-08-27 22:24:56 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 264825 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73
2020-08-27 22:24:56 | INFO | fairseq_cli.train | done training in 72.2 seconds

After:

$ fairseq-train --task language_modeling   data-bin/wikitext-103   --save-dir checkpoints/transformer_wikitext-103   --arch transformer_lm --share-decoder-input-output-embed   --dropout 0.1   --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0   --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07   --tokens-per-sample 512 --sample-break-mode none   --max-tokens 2048 --update-freq 16   --max-update 50000  --memory-efficient-fp16 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --zero-sharding os
...
2020-08-27 22:22:55 | INFO | train_inner | epoch 001:     49 / 394 loss=18.84, ppl=469411, wps=267663, ups=1.02, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68
2020-08-27 22:22:56 | INFO | train_inner | epoch 001:     50 / 394 loss=18.787, ppl=452312, wps=252797, ups=0.96, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69
2020-08-27 22:22:57 | INFO | train_inner | epoch 001:     51 / 394 loss=18.74, ppl=437735, wps=267692, ups=1.02, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70
2020-08-27 22:22:58 | INFO | train_inner | epoch 001:     52 / 394 loss=18.683, ppl=420727, wps=267507, ups=1.02, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71
2020-08-27 22:22:59 | INFO | train_inner | epoch 001:     53 / 394 loss=18.623, ppl=403794, wps=254410, ups=0.97, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72
2020-08-27 22:23:00 | INFO | train_inner | epoch 001:     54 / 394 loss=18.574, ppl=390255, wps=268234, ups=1.02, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73
2020-08-27 22:23:00 | INFO | fairseq_cli.train | begin save checkpoint
2020-08-27 22:23:00 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2020-08-27 22:23:00 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 263570 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73
2020-08-27 22:23:00 | INFO | fairseq_cli.train | done training in 72.3 seconds

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Reviewed By: myleott

Differential Revision: D23432082

Pulled By: msbaines

fbshipit-source-id: 6a020b25e36a3d9283582b7d89a6a53038e5b181
mandeeplearning authored and facebook-github-bot committed Sep 2, 2020
1 parent 251c869 commit 5d7ed6a
Showing 7 changed files with 93 additions and 14 deletions.
7 changes: 6 additions & 1 deletion fairseq/checkpoint_utils.py
@@ -32,7 +32,12 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
         best_function = max if args.maximize_best_checkpoint_metric else min
         save_checkpoint.best = best_function(val_loss, prev_best)
 
-    if args.no_save or not trainer.is_data_parallel_master:
+    if args.no_save:
+        return
+
+    trainer.consolidate_optimizer()
+
+    if not trainer.is_data_parallel_master:
         return
 
     def is_better(a, b):
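
The reason for splitting the early return: with sharded optimizer state, gathering the full state is a collective operation, so every rank must call `consolidate_optimizer()` before only the data-parallel master writes the file. A minimal sketch of that pattern (illustrative names, assuming fairscale's `consolidate_state_dict(recipient_rank=...)` signature):

```python
import torch
import torch.distributed as dist

def save_sharded_checkpoint(model, oss_optimizer, path, rank):
    # Collective: every rank sends its optimizer-state shard to rank 0.
    # Calling this only on rank 0 would hang the other ranks.
    oss_optimizer.consolidate_state_dict(recipient_rank=0)
    if rank == 0:
        torch.save(
            {"model": model.state_dict(), "optimizer": oss_optimizer.state_dict()},
            path,
        )
    dist.barrier()  # keep ranks aligned before training continues
```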
2 changes: 2 additions & 0 deletions fairseq/optim/__init__.py
@@ -10,12 +10,14 @@
 from fairseq.optim.fairseq_optimizer import FairseqOptimizer
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
 from fairseq.optim.bmuf import FairseqBMUF  # noqa
+from fairseq.optim.shard import shard_
 
 
 __all__ = [
     'FairseqOptimizer',
     'FP16Optimizer',
     'MemoryEfficientFP16Optimizer',
+    'shard_',
 ]
 
 
9 changes: 9 additions & 0 deletions fairseq/optim/fairseq_optimizer.py
@@ -28,6 +28,15 @@ def optimizer(self):
             raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
         return self._optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        """Reset optimizer instance."""
+        if not hasattr(self, '_optimizer'):
+            raise NotImplementedError
+        if not isinstance(self._optimizer, torch.optim.Optimizer):
+            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
+        self._optimizer = optimizer
+
     @property
     def optimizer_config(self):
         """
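
The setter is what makes `shard_` possible: `optimizer.optimizer = FairseqOSS(...)` assigns through this property, and without a setter Python would reject the assignment. A minimal illustration (not fairseq code):

```python
class WithoutSetter:
    """Read-only property: assignment fails, which is why the diff adds a setter."""

    def __init__(self):
        self._optimizer = "adam"

    @property
    def optimizer(self):
        return self._optimizer


holder = WithoutSetter()
try:
    holder.optimizer = "sharded-adam"
except AttributeError as exc:
    print(exc)  # e.g. "can't set attribute" / "property ... has no setter"
```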
35 changes: 22 additions & 13 deletions fairseq/optim/fp16_optimizer.py
@@ -232,6 +232,10 @@ def build_optimizer(cls, args, params):
     def optimizer(self):
         return self.fp32_optimizer.optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        self.fp32_optimizer.optimizer = optimizer
+
     @property
     def optimizer_config(self):
         return self.fp32_optimizer.optimizer_config
@@ -279,19 +283,20 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
         # params are FP16 while the optimizer state is FP32 and we don't want
         # to cast. A workaround is to manually copy back the original state
         # after the optimizer has been loaded.
-        groups = self.optimizer.param_groups
-        saved_groups = state_dict['param_groups']
-        id_map = {
-            old_id: p
-            for old_id, p in zip(
-                chain(*(g['params'] for g in saved_groups)),
-                chain(*(g['params'] for g in groups))
-            )
-        }
-        for k, v in state_dict['state'].items():
-            if k in id_map:
-                param = id_map[k]
-                self.optimizer.state[param] = v
+        if not getattr(self.optimizer, 'disable_mem_eff_fp16_loading_hack', False):
+            groups = self.optimizer.param_groups
+            saved_groups = state_dict['param_groups']
+            id_map = {
+                old_id: p
+                for old_id, p in zip(
+                    chain(*(g['params'] for g in saved_groups)),
+                    chain(*(g['params'] for g in groups))
+                )
+            }
+            for k, v in state_dict['state'].items():
+                if k in id_map:
+                    param = id_map[k]
+                    self.optimizer.state[param] = v
 
     def backward(self, loss):
         """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
@@ -412,6 +417,10 @@ def build_optimizer(cls, args, params):
     def optimizer(self):
         return self.wrapped_optimizer.optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        self.wrapped_optimizer.optimizer = optimizer
+
     @property
     def optimizer_config(self):
         return self.wrapped_optimizer.optimizer_config
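
`MemoryEfficientFP16Optimizer.load_state_dict` normally copies the loaded FP32 state back onto the FP16 parameters; the new `getattr` guard lets a wrapped optimizer opt out of that copy-back, presumably because OSS restores its own sharded state. A small sketch of the duck-typed opt-out flag (illustrative classes only):

```python
class PlainOptimizer:
    pass  # no flag: the copy-back hack still runs

class ShardedOptimizer:
    disable_mem_eff_fp16_loading_hack = True  # opt out, as FairseqOSS does

def needs_copy_back(optimizer) -> bool:
    return not getattr(optimizer, "disable_mem_eff_fp16_loading_hack", False)

assert needs_copy_back(PlainOptimizer())
assert not needs_copy_back(ShardedOptimizer())
```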
33 changes: 33 additions & 0 deletions fairseq/optim/shard.py
@@ -0,0 +1,33 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+try:
+    from fairscale.optim import OSS
+    _has_fairscale = True
+except ImportError:
+    _has_fairscale = False
+
+
+def shard_(args, optimizer):
+    if not _has_fairscale:
+        raise ImportError(
+            '\n\nPlease install the fairscale package:'
+            '\n\n  pip install fairscale'
+        )
+
+    class FairseqOSS(OSS):
+        @property
+        def disable_mem_eff_fp16_loading_hack(self):
+            return True
+
+        def __getattr__(self, name):
+            if name.startswith("supports") and hasattr(self.optim, name):
+                return getattr(self.optim, name)
+            raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name))
+
+    torch_optimizer = optimizer.optimizer
+    optim_cls = type(torch_optimizer)
+    optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, **optimizer.optimizer_config)
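
In effect, `shard_` rebuilds the same optimizer class under fairscale's `OSS`, reusing the existing param groups and hyperparameters, so each rank keeps Adam (or any other optimizer) state only for its own shard. A hedged, single-process sketch of that wrapping (gloo backend, world size 1; names are illustrative and not part of this PR):

```python
import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# OSS needs an initialized process group even for a single process.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)
adam = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.98))

# Same construction as shard_: existing param groups + original optimizer
# class + its config, with OSS owning the (sharded) optimizer state.
sharded = OSS(adam.param_groups, type(adam), lr=5e-4, betas=(0.9, 0.98))

model(torch.randn(2, 16)).sum().backward()
sharded.step()  # steps the local shard, then broadcasts updated parameters
```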
4 changes: 4 additions & 0 deletions fairseq/options.py
@@ -448,6 +448,10 @@ def add_distributed_training_args(parser, default_world_size=None):
                        help='number of GPUs in each node. An allreduce operation across GPUs in '
                             'a node is very fast. Hence, we do allreduce across GPUs in a node, '
                             'and gossip across different nodes')
+    # Add argument for ZeRO sharding of OptimizerState(os), gradients(g) and parameters(p)
+    group.add_argument('--zero-sharding', default='none', type=str,
+                       choices=['none', 'os'],
+                       help='ZeRO sharding')
     # fmt: on
     return group
 
17 changes: 17 additions & 0 deletions fairseq/trainer.py
@@ -214,11 +214,28 @@ def _build_optimizer(self):
         if self.args.use_bmuf:
             self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
 
+        if self.args.zero_sharding == 'os':
+            if (self.args.fp16
+                    and not self.args.memory_efficient_fp16
+                    and not self.args.memory_efficient_bf16
+                    ) and not self.args.fp16_no_flatten_grads:
+                raise ValueError(
+                    "ZeRO is incompatible with fp16 and flattened grads. "
+                    "Please use --fp16-no-flatten-grads"
+                )
+            else:
+                optim.shard_(self.args, self._optimizer)
+
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
         self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
         self._lr_scheduler.step_update(0)
 
+    def consolidate_optimizer(self):
+        """For OSS, we need to consolidate the state dict."""
+        if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
+            self.optimizer.optimizer.consolidate_state_dict()
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if self.is_data_parallel_master:  # only save one checkpoint
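
The flatten-grads check exists because fairseq's default FP16 path fuses all FP32 gradients into one flat buffer, while OSS assigns optimizer state per parameter tensor; once everything is a single tensor, there is nothing left to shard per rank. A tiny illustration of the difference (conceptual only, not the actual fairseq buffers):

```python
import torch

params = [torch.randn(3, requires_grad=True), torch.randn(5, requires_grad=True)]
for p in params:
    p.grad = torch.randn_like(p)

# Per-parameter grads: a sharded optimizer can own params[0] on one rank and
# params[1] on another, each with its own optimizer state.
print([p.grad.shape for p in params])  # [torch.Size([3]), torch.Size([5])]

# Flattened grads: one fused buffer with no per-parameter boundaries left for
# the shards to attach optimizer state to.
flat = torch.cat([p.grad.reshape(-1) for p in params])
print(flat.shape)                      # torch.Size([8])
```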
