
[NeurIPS 2023] The PyTorch Implementation of Scheduled (Stable) Weight Decay.


zeke-xie/stable-weight-decay-regularization


Scheduled(Stable)-Weight-Decay-Regularization

The PyTorch Implementation of Scheduled (Stable) Weight Decay.

The algorithms were first proposed in our arXiv paper.

A formal version with major revisions and a theoretical mechanism, "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective", has been accepted at NeurIPS 2023.

Why Scheduled (Stable) Weight Decay?

We propose the Scheduled (Stable) Weight Decay (SWD) method to mitigate the overlooked large-gradient-norm pitfalls of weight decay as implemented in modern deep learning libraries.

  • SWD can penalize large gradient norms in the final phase of training.

  • SWD usually makes significant improvements over both L2 regularization and decoupled weight decay.

  • Simply fixing weight decay in Adam with SWD, which adds no extra hyperparameters, can usually outperform complex Adam variants that have more hyperparameters.
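The two baselines named above differ only in where the decay term enters an Adam step. The following is a minimal, framework-free sketch on plain Python lists (the function `adam_step` and its `decoupled` flag are hypothetical illustrations, not part of `swd_optim`). It shows the textbook forms of L2 regularization (decay folded into the gradient) and decoupled, AdamW-style weight decay (weights shrunk directly); it does not implement SWD itself.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=5e-4, decoupled=False):
    """One Adam step on flat lists of floats; returns (theta, m, v).

    decoupled=False -> L2 regularization (decay added to the gradient).
    decoupled=True  -> decoupled weight decay (AdamW-style shrinkage).
    """
    if not decoupled:
        # L2 regularization: weight_decay * theta joins the gradient, so it is
        # later rescaled per coordinate by 1 / (sqrt(v_hat) + eps).
        grad = [g + weight_decay * p for g, p in zip(grad, theta)]
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    bc1 = 1 - beta1 ** t  # bias corrections for the moment estimates
    bc2 = 1 - beta2 ** t
    new_theta = []
    for p, mi, vi in zip(theta, m, v):
        m_hat, v_hat = mi / bc1, vi / bc2
        if decoupled:
            # Decoupled decay: shrink every weight by the same factor lr * weight_decay.
            p = p - lr * weight_decay * p
        p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
        new_theta.append(p)
    return new_theta, m, v


# With a zero loss gradient, compare how far a weight of 1.0 decays in one step.
theta_l2, _, _ = adam_step([1.0], [0.0], [0.0], [0.0], t=1, decoupled=False)
theta_dw, _, _ = adam_step([1.0], [0.0], [0.0], [0.0], t=1, decoupled=True)
print(theta_l2, theta_dw)
```

In this toy case the L2 variant moves the weight by roughly `lr`, because Adam's normalization cancels the magnitude of `weight_decay`, while the decoupled variant shrinks it by exactly `lr * weight_decay`.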

The environment is as follows:

Python 3.7.3

PyTorch >= 1.4.0
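If you want to fail fast on an older PyTorch, a small version check can run at start-up. This is a hypothetical helper sketch (`parse_version` and `meets_minimum` are not part of `swd_optim`); it compares only the numeric major.minor.patch fields and ignores local build tags such as `+cu117` and pre-release suffixes.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse 'X.Y.Z[suffix][+local]' into an integer (X, Y, Z) tuple."""
    core = version.split("+")[0]  # drop local tags like '+cu117'
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:  # keep leading digits only, e.g. '0a0' -> '0'
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # treat '1.4' as '1.4.0'
        parts.append(0)
    return parts[0], parts[1], parts[2]


def meets_minimum(installed: str, minimum: str = "1.4.0") -> bool:
    """True if the installed version is at least the required minimum."""
    return parse_version(installed) >= parse_version(minimum)
```

Usage, assuming PyTorch is installed: `import torch; assert meets_minimum(torch.__version__, "1.4.0")`.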

Usage

You may use it as a standard PyTorch optimizer.

import swd_optim

# `net` is your model (any torch.nn.Module)
optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
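To illustrate what the `lr`, `betas`, `eps` and `weight_decay` arguments could do inside a scheduled-decay Adam step, here is a framework-free sketch. The scheduling rule is an assumption made for illustration: the decoupled decay rate `lr * weight_decay` is divided by the square root of the mean bias-corrected second moment over all parameters, so one scalar rate is shared by every coordinate and it weakens when gradients are large. The function `adams_like_step` is hypothetical, omits `amsgrad`, and is not the code in `swd_optim`; check the `AdamS` source for the exact rule.

```python
import math


def adams_like_step(params, grads, state, lr=1e-3, betas=(0.9, 0.999),
                    eps=1e-8, weight_decay=5e-4):
    """One Adam step with an ASSUMED scheduled decoupled decay.

    params, grads: flat lists of floats covering all parameters in the group.
    state: dict with lists 'm', 'v' and integer step count 't'.
    """
    beta1, beta2 = betas
    state["t"] += 1
    t = state["t"]
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(state["m"], grads)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(state["v"], grads)]
    state["m"], state["v"] = m, v
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # Assumed rule: normalize the decay by sqrt(mean(v_hat)), a scalar that
    # grows with the gradient norm. eps here only guards against division by
    # zero and is also an assumption of this sketch.
    v_bar = sum(v_hat) / len(v_hat)
    decay = lr * weight_decay / (math.sqrt(v_bar) + eps)
    return [p - decay * p - lr * mh / (math.sqrt(vh) + eps)
            for p, mh, vh in zip(params, m_hat, v_hat)]


def fresh_state(n):
    return {"m": [0.0] * n, "v": [0.0] * n, "t": 0}


state_small, state_large = fresh_state(2), fresh_state(2)
out_small = adams_like_step([1.0, 1.0], [0.1, 0.1], state_small)
out_large = adams_like_step([1.0, 1.0], [10.0, 10.0], state_large)
print(out_small, out_large)
```

Under this assumed rule, the Adam part of the step is about `lr` in both calls, but the larger gradients yield a smaller decay term, so the weights in `out_large` shrink slightly less than those in `out_small`.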

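Test performance

Test error (%, lower is better) on CIFAR-10 and CIFAR-100.

| Dataset | Model | AdamS | SGD M | Adam | AMSGrad | AdamW | AdaBound | Padam | Yogi | RAdam |
|---|---|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | ResNet18 | 4.91 ± 0.04 | 5.01 ± 0.03 | 6.53 ± 0.03 | 6.16 ± 0.18 | 5.08 ± 0.07 | 5.65 ± 0.08 | 5.12 ± 0.04 | 5.87 ± 0.12 | 6.01 ± 0.10 |
| CIFAR-10 | VGG16 | 6.09 ± 0.11 | 6.42 ± 0.02 | 7.31 ± 0.25 | 7.14 ± 0.14 | 6.48 ± 0.13 | 6.76 ± 0.12 | 6.15 ± 0.06 | 6.90 ± 0.22 | 6.56 ± 0.04 |
| CIFAR-100 | DenseNet121 | 20.52 ± 0.26 | 19.81 ± 0.33 | 25.11 ± 0.15 | 24.43 ± 0.09 | 21.55 ± 0.14 | 22.69 ± 0.15 | 21.10 ± 0.23 | 22.15 ± 0.36 | 22.27 ± 0.22 |
| CIFAR-100 | GoogLeNet | 21.05 ± 0.18 | 21.21 ± 0.29 | 26.12 ± 0.33 | 25.53 ± 0.17 | 21.29 ± 0.17 | 23.18 ± 0.31 | 21.82 ± 0.17 | 24.24 ± 0.16 | 22.23 ± 0.15 |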
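The table can be summarized programmatically. The sketch below copies only the mean test errors from the rows above (the helper names `best_optimizer` and `gap_vs` are hypothetical) and reports, for each dataset and model, the optimizer with the lowest error and how much lower AdamS is than a chosen baseline.

```python
OPTIMIZERS = ["AdamS", "SGD M", "Adam", "AMSGrad", "AdamW",
              "AdaBound", "Padam", "Yogi", "RAdam"]

# Mean test errors (%) from the table, in the column order of OPTIMIZERS.
MEAN_TEST_ERROR = {
    ("CIFAR-10", "ResNet18"): [4.91, 5.01, 6.53, 6.16, 5.08, 5.65, 5.12, 5.87, 6.01],
    ("CIFAR-10", "VGG16"): [6.09, 6.42, 7.31, 7.14, 6.48, 6.76, 6.15, 6.90, 6.56],
    ("CIFAR-100", "DenseNet121"): [20.52, 19.81, 25.11, 24.43, 21.55, 22.69, 21.10, 22.15, 22.27],
    ("CIFAR-100", "GoogLeNet"): [21.05, 21.21, 26.12, 25.53, 21.29, 23.18, 21.82, 24.24, 22.23],
}


def best_optimizer(errors):
    """Name of the optimizer with the lowest test error (lower is better)."""
    best_index = min(range(len(errors)), key=errors.__getitem__)
    return OPTIMIZERS[best_index]


def gap_vs(errors, baseline):
    """Baseline error minus AdamS error; positive means AdamS is better."""
    adams = errors[OPTIMIZERS.index("AdamS")]
    return round(errors[OPTIMIZERS.index(baseline)] - adams, 2)


for (dataset, model), errors in MEAN_TEST_ERROR.items():
    print(dataset, model, best_optimizer(errors), gap_vs(errors, "AdamW"))
```

This comparison uses means only; the ± spreads in the table should be taken into account before calling a difference significant.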

Citing

If you use Scheduled (Stable) Weight Decay in your work, please cite "On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective".

@inproceedings{xie2023onwd,
    title={On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective},
    author={Xie, Zeke and Xu, Zhiqiang and Zhang, Jingzhao and Sato, Issei and Sugiyama, Masashi},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023}
}
