The PyTorch Implementation of Scheduled (Stable) Weight Decay.
The algorithms were first proposed in our arXiv paper. A formally revised version with the theoretical mechanism, "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective", was accepted at NeurIPS 2023.

We propose the Scheduled (Stable) Weight Decay (SWD) method to mitigate the overlooked large-gradient-norm pitfalls of weight decay in modern deep learning libraries:

- SWD can penalize large gradient norms in the final phase of training (see the sketch after this list).
- SWD usually yields significant improvements over both L2 regularization and decoupled weight decay.
- Simply fixing weight decay in Adam via SWD, with no extra hyperparameter, usually outperforms complex Adam variants that have more hyperparameters.
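To make the idea concrete, below is a minimal sketch of the kind of update an SWD-style optimizer performs: an Adam-like step followed by decoupled weight decay whose strength is rescaled each step by a global statistic of the second moments, so the effective decay adapts to the overall gradient scale. The function name and the specific normalizer (the mean of the second-moment buffers) are illustrative assumptions, not necessarily the exact AdamS formula; see `swd_optim` for the actual optimizers.

```python
import torch

@torch.no_grad()
def adam_step_with_scaled_decay(params, exp_avgs, exp_avg_sqs, step,
                                lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                                weight_decay=5e-4):
    """Illustrative sketch only: Adam-style step plus decoupled weight decay
    rescaled by a global second-moment statistic (assumed normalizer).
    Assumes p.grad is already populated for every parameter."""
    beta1, beta2 = betas
    for p, m, v in zip(params, exp_avgs, exp_avg_sqs):
        g = p.grad
        m.mul_(beta1).add_(g, alpha=1 - beta1)         # first-moment estimate
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second-moment estimate
    # Global statistic over all second-moment buffers; the effective decay
    # strength lr * weight_decay / denom then tracks the overall gradient scale.
    denom = torch.mean(torch.cat([v.flatten() for v in exp_avg_sqs])).sqrt()
    denom = denom.clamp_min(eps)
    for p, m, v in zip(params, exp_avgs, exp_avg_sqs):
        m_hat = m / (1 - beta1 ** step)
        v_hat = v / (1 - beta2 ** step)
        p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)  # Adam update
        p.mul_(1 - lr * weight_decay / denom)                 # rescaled decoupled decay
```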
Requirements:

- Python 3.7.3
- PyTorch >= 1.4.0
You may use it as a standard PyTorch optimizer:

```python
import swd_optim

optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
```
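For context, the sketch below shows how the optimizer drops into an ordinary training loop; `net`, `train_loader`, and `num_epochs` are assumed to be defined elsewhere, and the cosine learning-rate schedule is just one common choice.

```python
import torch
import torch.nn as nn
import swd_optim

# Assumes net, train_loader, and num_epochs are defined elsewhere.
criterion = nn.CrossEntropyLoss()
optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999),
                            eps=1e-08, weight_decay=5e-4, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()  # step the learning-rate schedule once per epoch
```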
Test errors (%) on CIFAR-10 and CIFAR-100 (mean ± std):

Dataset | Model | AdamS | SGD M | Adam | AMSGrad | AdamW | AdaBound | Padam | Yogi | RAdam |
---|---|---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet18 | 4.91±0.04 | 5.01±0.03 | 6.53±0.03 | 6.16±0.18 | 5.08±0.07 | 5.65±0.08 | 5.12±0.04 | 5.87±0.12 | 6.01±0.10 |
CIFAR-10 | VGG16 | 6.09±0.11 | 6.42±0.02 | 7.31±0.25 | 7.14±0.14 | 6.48±0.13 | 6.76±0.12 | 6.15±0.06 | 6.90±0.22 | 6.56±0.04 |
CIFAR-100 | DenseNet121 | 20.52±0.26 | 19.81±0.33 | 25.11±0.15 | 24.43±0.09 | 21.55±0.14 | 22.69±0.15 | 21.10±0.23 | 22.15±0.36 | 22.27±0.22 |
CIFAR-100 | GoogLeNet | 21.05±0.18 | 21.21±0.29 | 26.12±0.33 | 25.53±0.17 | 21.29±0.17 | 23.18±0.31 | 21.82±0.17 | 24.24±0.16 | 22.23±0.15 |
If you use Scheduled (Stable) Weight Decay in your work, please cite "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective".
```
@inproceedings{xie2023onwd,
  title={On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective},
  author={Xie, Zeke and Xu, Zhiqiang and Zhang, Jingzhao and Sato, Issei and Sugiyama, Masashi},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}
```