Initial support for ZeRO optimizer state sharding (#1259)
Summary: FairseqOSS will work with any optimizer and dtype.

TODO (future PRs):
* support reduce instead of all_reduce
* support gradient sharding
* support parameter sharding

Pull Request resolved: fairinternal/fairseq-py#1259

Test Plan: Verified that checkpoint save and restore work. Verified that grad_norm, loss, and ppl are identical with and without sharding enabled.
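For context, a minimal hypothetical sketch of what optimizer state sharding (ZeRO stage 1) does; this is not the FairseqOSS implementation. Each rank builds its wrapped optimizer (e.g. Adam) over only a shard of the parameters, steps that shard, and broadcasts the updated values so all replicas stay in sync. The `ShardedOptimizer` name and the round-robin partitioning are illustrative assumptions.

```python
# Hypothetical sketch of ZeRO stage-1 optimizer state sharding; NOT the
# FairseqOSS code. Assumes torch.distributed is already initialized and
# that every rank owns at least one parameter.
import torch
import torch.distributed as dist


class ShardedOptimizer:
    """Keep optimizer state (e.g. Adam's exp_avg/exp_avg_sq) for only a
    1/world_size shard of the parameters on each rank."""

    def __init__(self, params, optim_cls, **optim_kwargs):
        self.params = list(params)
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        # Round-robin parameter assignment; a real implementation would
        # balance shards by number of elements.
        self.my_params = [p for i, p in enumerate(self.params)
                          if i % self.world_size == self.rank]
        # The wrapped optimizer only ever sees this rank's shard, so its
        # state tensors cover roughly 1/world_size of the model.
        self.optim = optim_cls(self.my_params, **optim_kwargs)

    def step(self):
        # The data-parallel backward pass has already all_reduced gradients,
        # so every rank holds full, identical grads; the TODO above is to use
        # reduce so each rank only receives the grads its shard needs.
        self.optim.step()
        # Broadcast each updated parameter from its owner to the other ranks.
        for i, p in enumerate(self.params):
            dist.broadcast(p.data, src=i % self.world_size)

    def zero_grad(self):
        # Clear grads for ALL parameters, not just this rank's shard.
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()
```

Since sharding changes only where optimizer state lives, not the update math, grad_norm, loss, and ppl are expected to match the unsharded run exactly, which is what the Before/After logs below show.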
Before:

```
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
    --save-dir checkpoints/transformer_wikitext-103 \
    --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 \
    --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt \
    --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 \
    --sample-break-mode none --max-tokens 2048 --update-freq 16 \
    --max-update 50000 --memory-efficient-fp16 --no-progress-bar \
    --log-interval 1 --seed 4 --max-epoch 1 --max-update 50
...
2020-08-27 22:24:51 | INFO | train_inner | epoch 001: 49 / 394 loss=18.84, ppl=469411, wps=269226, ups=1.03, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68
2020-08-27 22:24:52 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=256992, ups=0.98, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69
2020-08-27 22:24:53 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=259178, ups=0.99, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70
2020-08-27 22:24:54 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=257710, ups=0.98, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71
2020-08-27 22:24:55 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403794, wps=269279, ups=1.03, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72
2020-08-27 22:24:56 | INFO | train_inner | epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=264616, ups=1.01, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73
2020-08-27 22:24:56 | INFO | fairseq_cli.train | begin save checkpoint
2020-08-27 22:24:56 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2020-08-27 22:24:56 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 264825 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73
2020-08-27 22:24:56 | INFO | fairseq_cli.train | done training in 72.2 seconds
```

After:

```
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
    --save-dir checkpoints/transformer_wikitext-103 \
    --arch transformer_lm --share-decoder-input-output-embed --dropout 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 \
    --clip-norm 0.0 --lr 0.0005 --lr-scheduler inverse_sqrt \
    --warmup-updates 4000 --warmup-init-lr 1e-07 --tokens-per-sample 512 \
    --sample-break-mode none --max-tokens 2048 --update-freq 16 \
    --max-update 50000 --memory-efficient-fp16 --no-progress-bar \
    --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --zero-sharding os
...
2020-08-27 22:22:55 | INFO | train_inner | epoch 001: 49 / 394 loss=18.84, ppl=469411, wps=267663, ups=1.02, wpb=262144, bsz=512, num_updates=45, lr=5.72388e-06, gnorm=5.769, loss_scale=8, train_wall=1, wall=68
2020-08-27 22:22:56 | INFO | train_inner | epoch 001: 50 / 394 loss=18.787, ppl=452312, wps=252797, ups=0.96, wpb=262144, bsz=512, num_updates=46, lr=5.84885e-06, gnorm=5.512, loss_scale=8, train_wall=1, wall=69
2020-08-27 22:22:57 | INFO | train_inner | epoch 001: 51 / 394 loss=18.74, ppl=437735, wps=267692, ups=1.02, wpb=262144, bsz=512, num_updates=47, lr=5.97383e-06, gnorm=5.298, loss_scale=8, train_wall=1, wall=70
2020-08-27 22:22:58 | INFO | train_inner | epoch 001: 52 / 394 loss=18.683, ppl=420727, wps=267507, ups=1.02, wpb=262144, bsz=512, num_updates=48, lr=6.0988e-06, gnorm=5.094, loss_scale=8, train_wall=1, wall=71
2020-08-27 22:22:59 | INFO | train_inner | epoch 001: 53 / 394 loss=18.623, ppl=403794, wps=254410, ups=0.97, wpb=262144, bsz=512, num_updates=49, lr=6.22378e-06, gnorm=4.893, loss_scale=8, train_wall=1, wall=72
2020-08-27 22:23:00 | INFO | train_inner | epoch 001: 54 / 394 loss=18.574, ppl=390255, wps=268234, ups=1.02, wpb=262144, bsz=512, num_updates=50, lr=6.34875e-06, gnorm=4.684, loss_scale=8, train_wall=1, wall=73
2020-08-27 22:23:00 | INFO | fairseq_cli.train | begin save checkpoint
2020-08-27 22:23:00 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
2020-08-27 22:23:00 | INFO | train | epoch 001 | loss 19.736 | ppl 873122 | wps 263570 | ups 1.01 | wpb 262144 | bsz 512 | num_updates 50 | lr 6.34875e-06 | gnorm 8.898 | loss_scale 8 | train_wall 66 | wall 73
2020-08-27 22:23:00 | INFO | fairseq_cli.train | done training in 72.3 seconds
```
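Checkpointing is the subtle part the test plan exercises: with state sharded, no single rank holds the full optimizer state dict. A hypothetical sketch of one way to save and restore it, assuming the `ShardedOptimizer` sketch above and an initialized process group; this is not the code this commit adds.

```python
# Hypothetical checkpointing sketch for sharded optimizer state; NOT the
# code added by this commit. Assumes torch.distributed is initialized.
import torch
import torch.distributed as dist


def save_sharded_optimizer(path, sharded_optim):
    # No rank has the full state, so gather every shard's state_dict
    # (arbitrary picklable objects) before writing to disk from rank 0.
    shards = [None] * dist.get_world_size()
    dist.all_gather_object(shards, sharded_optim.optim.state_dict())
    if dist.get_rank() == 0:
        torch.save({"optimizer_shards": shards}, path)


def load_sharded_optimizer(path, sharded_optim):
    # On restore, each rank reloads only the shard it owns.
    shards = torch.load(path, map_location="cpu")["optimizer_shards"]
    sharded_optim.optim.load_state_dict(shards[dist.get_rank()])
```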
## Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes # (issue).

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding!

Reviewed By: myleott

Differential Revision: D23432082

Pulled By: msbaines

fbshipit-source-id: 6a020b25e36a3d9283582b7d89a6a53038e5b181