This library contains PyTorch implementations of the warmup schedules described in "On the adequacy of untuned warmup for adaptive optimization."
Make sure you have Python 3.7+ and PyTorch 1.1+ or 2.x. Then run the following command in the project directory:

```shell
python -m pip install .
```

or install the latest version from the Python Package Index:

```shell
pip install -U pytorch_warmup
```
- CIFAR10 - A sample script to train a ResNet20 model on the CIFAR10 dataset using an optimization algorithm with a warmup.
- EMNIST - A sample script to train a CNN model on the EMNIST dataset using the Adam algorithm with a warmup.
- Plots - A script to plot effective warmup periods as a function of β₂, and warmup schedules over time.
The Documentation provides more detailed information on this library than is covered below.
The scheduled learning rate is dampened by multiplying it by the warmup factor w(t), so the learning rate actually used at step t is w(t) times the scheduled learning rate.
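The pure-Python sketch below illustrates this multiplication for a cosine schedule dampened by a linear warmup factor. It does not use the library: the cosine formula mirrors PyTorch's CosineAnnealingLR with eta_min=0, the helper names are hypothetical, and the step indexing is simplified.

```python
import math

def cosine_lr(base_lr: float, t: int, t_max: int) -> float:
    """Scheduled LR of a cosine annealing schedule (eta_min = 0) at step t."""
    return base_lr * (1 + math.cos(math.pi * t / t_max)) / 2

def linear_warmup_factor(t: int, warmup_period: int) -> float:
    """Linear warmup factor w(t) = min(1, t / warmup_period)."""
    return min(1.0, t / warmup_period)

def dampened_lr(base_lr: float, t: int, t_max: int, warmup_period: int) -> float:
    """Effective LR: the scheduled LR multiplied by the warmup factor."""
    return linear_warmup_factor(t, warmup_period) * cosine_lr(base_lr, t, t_max)

# Hypothetical settings: base LR 0.001, 10,000 steps, 2,000-step warmup.
for t in (1, 1000, 2000, 5000, 10000):
    print(t, dampened_lr(0.001, t, 10_000, 2_000))
```

Early on, the warmup factor dominates; once t reaches the warmup period the factor is 1 and only the cosine schedule remains.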
When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used together with Adam or one of its variants (AdamW, NAdam, etc.) as follows:
```python
import torch
import pytorch_warmup as warmup

# This sample code uses the AdamW optimizer.
optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
num_steps = len(dataloader) * num_epochs
# The LR schedule initialization resets the initial LR of the optimizer.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
# The warmup schedule initialization dampens the initial LR of the optimizer.
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = ...  # compute the loss for this batch
        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
            lr_scheduler.step()
```
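To see why lr_scheduler.step() goes inside the with statement, here is a toy model of the pattern written without PyTorch. It is a hypothetical sketch, not the library's implementation: on entering the block it restores the undampened learning rates so the scheduler updates its own values, and on leaving it saves the new undampened values, advances the warmup step, and multiplies the learning rates by w(t) again. ToyOptimizer and ToyLinearWarmup are invented names.

```python
from contextlib import contextmanager

class ToyOptimizer:
    """Hypothetical stand-in for an optimizer: it only holds param_groups."""
    def __init__(self, lr: float):
        self.param_groups = [{"lr": lr}]

class ToyLinearWarmup:
    """Hypothetical linear warmup that mimics the dampening pattern."""
    def __init__(self, optimizer: ToyOptimizer, warmup_period: int):
        self.optimizer = optimizer
        self.warmup_period = warmup_period
        self.last_step = 0
        # Remember the undampened LRs, then dampen them for step 1.
        self.lrs = [g["lr"] for g in optimizer.param_groups]
        self.dampen()

    def dampen(self) -> None:
        # Advance one step and multiply each undampened LR by w(t).
        self.last_step += 1
        w = min(1.0, self.last_step / self.warmup_period)
        for group, lr in zip(self.optimizer.param_groups, self.lrs):
            group["lr"] = lr * w

    @contextmanager
    def dampening(self):
        # Restore the undampened LRs so a scheduler sees its own values.
        for group, lr in zip(self.optimizer.param_groups, self.lrs):
            group["lr"] = lr
        yield
        # Save whatever the scheduler set, then dampen for the next step.
        self.lrs = [g["lr"] for g in self.optimizer.param_groups]
        self.dampen()

opt = ToyOptimizer(lr=1.0)
warmup_sched = ToyLinearWarmup(opt, warmup_period=4)
print(opt.param_groups[0]["lr"])  # 0.25 (w = 1/4 at step 1)
with warmup_sched.dampening():
    # A toy "scheduler" step that halves the undampened LR.
    opt.param_groups[0]["lr"] *= 0.5
print(opt.param_groups[0]["lr"])  # 0.25 (undampened 0.5 times w = 2/4 at step 2)
```

If the scheduler stepped outside the block, it would act on the already dampened value and the warmup would compound with the schedule incorrectly.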
Warning: Do not initialize the warmup schedule before the learning rate schedule; create the learning rate scheduler first.
If you want to use learning rate schedule chaining, which is supported in PyTorch 1.4 or above, simply place the step() calls of all the learning rate schedulers inside the suite of the with statement:
```python
lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for batch in dataloader:
        ...
        optimizer.step()
        with warmup_scheduler.dampening():
            lr_scheduler1.step()
            lr_scheduler2.step()
```
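With chaining, each scheduler multiplies the current learning rate by its own factor, so in the chained example the undampened learning rate after n scheduler steps is base_lr · 0.9^n · 0.1^(n // 3). The plain-Python sketch below simulates the two multiplicative updates and compares them with that closed form; it assumes a base learning rate of 0.001 and is an illustration, not PyTorch code.

```python
def chained_lr(base_lr: float, n_steps: int) -> float:
    """Undampened LR after n_steps of the two chained schedulers."""
    lr = base_lr
    for step in range(1, n_steps + 1):
        lr *= 0.9            # ExponentialLR(gamma=0.9): decays every step
        if step % 3 == 0:
            lr *= 0.1        # StepLR(step_size=3, gamma=0.1): every 3rd step
    return lr

for n in range(7):
    closed_form = 0.001 * 0.9**n * 0.1**(n // 3)
    print(n, chained_lr(0.001, n), closed_form)
```

Because both updates are multiplicative, the order of the two step() calls inside the with statement does not change the result.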
If you want to start the learning rate schedule after the end of the linear warmup, delay it by the warmup period:
```python
warmup_period = 2000
num_steps = len(dataloader) * num_epochs - warmup_period
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)
for epoch in range(1, num_epochs + 1):
    for batch in dataloader:
        ...
        optimizer.step()
        with warmup_scheduler.dampening():
            if warmup_scheduler.last_step + 1 >= warmup_period:
                lr_scheduler.step()
```
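Under this approach the learning rate rises linearly for warmup_period steps and then follows the full cosine curve, which spans the remaining steps. The sketch below computes that shape in plain Python. It is a simplification that assumes the warmup factor is min(1, t / warmup_period) and that the cosine schedule is evaluated at t - warmup_period, ignoring the exact off-by-one bookkeeping of the loop above; the function name is hypothetical.

```python
import math

def delayed_cosine_lr(base_lr: float, t: int, warmup_period: int, total_steps: int) -> float:
    """LR at step t: linear warmup first, then cosine annealing over the rest."""
    t_max = total_steps - warmup_period          # corresponds to num_steps above
    if t <= warmup_period:
        return base_lr * t / warmup_period       # warmup phase; cosine not started
    progress = (t - warmup_period) / t_max
    return base_lr * (1 + math.cos(math.pi * progress)) / 2

# Hypothetical run: 10,000 total steps with a 2,000-step warmup.
for t in (1000, 2000, 6000, 10000):
    print(t, delayed_cosine_lr(0.001, t, 2000, 10000))
```

The peak learning rate is reached exactly when the warmup ends, and the cosine decay reaches zero at the last step.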
When the learning rate schedule uses the epoch number, the warmup schedule can be used as follows:
```python
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 3], gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
        if i < len(dataloader) - 1:
            with warmup_scheduler.dampening():
                pass
    with warmup_scheduler.dampening():
        lr_scheduler.step()
```
This code can be rewritten more compactly:
```python
for epoch in range(1, num_epochs + 1):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
            if i + 1 == len(dataloader):
                lr_scheduler.step()
```
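In both versions the warmup factor advances every iteration, while MultiStepLR advances only at the end of each epoch. The plain-Python sketch below computes the resulting learning rate at a given iteration for a hypothetical run (base LR 0.001, 9 epochs of 100 iterations, milestone num_epochs // 3 = 3, gamma 0.1, and an assumed 250-step linear warmup). Its step indexing is simplified and the function name is hypothetical.

```python
def epoch_based_lr(base_lr: float, epoch: int, i: int, iters_per_epoch: int,
                   milestone: int, gamma: float, warmup_period: int) -> float:
    """LR used at iteration i (0-based) of epoch `epoch` (1-based)."""
    # MultiStepLR has been stepped (epoch - 1) times when this epoch starts,
    # and it decays the LR once that count reaches the milestone.
    scheduled = base_lr * (gamma if epoch - 1 >= milestone else 1.0)
    # The warmup factor advances once per iteration across all epochs.
    t = (epoch - 1) * iters_per_epoch + i + 1
    return scheduled * min(1.0, t / warmup_period)

for epoch, i in ((1, 0), (2, 49), (3, 99), (4, 0)):
    print(epoch, i, epoch_based_lr(0.001, epoch, i, 100, 3, 0.1, 250))
```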
When you use CosineAnnealingWarmRestarts, the warmup schedule can be used as follows:
```python
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
warmup_period = 2000
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)
iters = len(dataloader)
warmup_epochs = ...  # for example, (warmup_period + iters - 1) // iters
for epoch in range(num_epochs + warmup_epochs):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
            if epoch >= warmup_epochs:
                lr_scheduler.step(epoch - warmup_epochs + i / iters)
```
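The example expression for warmup_epochs is integer ceiling division: it rounds warmup_period / iters up, so the restart schedule starts at the first epoch boundary after the warmup has ended. The sketch below checks this and computes the fractional epoch passed to lr_scheduler.step(); the iteration counts and helper names are hypothetical.

```python
import math

def warmup_epochs_for(warmup_period: int, iters: int) -> int:
    """Smallest number of whole epochs covering warmup_period iterations."""
    return (warmup_period + iters - 1) // iters

def restart_epoch_arg(epoch: int, i: int, iters: int, warmup_epochs: int) -> float:
    """Fractional epoch passed to CosineAnnealingWarmRestarts.step()."""
    return epoch - warmup_epochs + i / iters

# Hypothetical: 2,000 warmup steps and 469 iterations per epoch.
w_epochs = warmup_epochs_for(2000, 469)
print(w_epochs)                                        # 5
print(math.ceil(2000 / 469))                           # 5
print(restart_epoch_arg(w_epochs, 0, 469, w_epochs))   # 0.0
```

The restart schedule therefore sees its own epoch counter starting from 0.0, advancing by 1 / iters per iteration.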
For LinearWarmup and ExponentialWarmup, the warmup factor w(t) depends on the warmup period, which must be specified manually.
Linear warmup: w(t) = min(1, t / warmup_period)

```python
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
```

Exponential warmup: w(t) = 1 - exp(-t / warmup_period)

```python
warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000)
```
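The two formulas above can be evaluated directly. The plain-Python sketch below implements them as written (t counts warmup steps) to compare how quickly each factor approaches 1; it illustrates the formulas and is not the library's code.

```python
import math

def linear_factor(t: int, warmup_period: int) -> float:
    # w(t) = min(1, t / warmup_period)
    return min(1.0, t / warmup_period)

def exponential_factor(t: int, warmup_period: int) -> float:
    # w(t) = 1 - exp(-t / warmup_period)
    return 1.0 - math.exp(-t / warmup_period)

for t in (0, 500, 1000, 2000, 4000):
    print(t, linear_factor(t, 2000), round(exponential_factor(t, 1000), 4))
```

The linear factor reaches exactly 1 at t = warmup_period, whereas the exponential factor reaches only about 63% (1 - 1/e) after one period and approaches 1 asymptotically.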
For UntunedLinearWarmup and UntunedExponentialWarmup, the warmup period is determined by a function of Adam's beta2 parameter.
Untuned linear warmup: warmup_period = 2 / (1 - beta2)

```python
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
```

Untuned exponential warmup: warmup_period = 1 / (1 - beta2)

```python
warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer)
```
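For the common setting beta2 = 0.999, these rules give a 2,000-step linear warmup and a 1,000-step exponential warmup period, the same values used in the LinearWarmup and ExponentialWarmup examples above. The sketch below evaluates both formulas for a few beta2 values. Rounding to the nearest integer is this sketch's own choice, because floating-point evaluation of 2 / (1 - 0.999) lands slightly below 2000; the library's own rounding may differ.

```python
def untuned_linear_period(beta2: float) -> int:
    # warmup_period = 2 / (1 - beta2), rounded in this sketch
    return round(2.0 / (1.0 - beta2))

def untuned_exponential_period(beta2: float) -> int:
    # warmup_period = 1 / (1 - beta2), rounded in this sketch
    return round(1.0 / (1.0 - beta2))

for beta2 in (0.99, 0.999, 0.9999):
    print(beta2, untuned_linear_period(beta2), untuned_exponential_period(beta2))
```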
For RAdamWarmup, the warmup factor depends on Adam's beta2 parameter. For details, please refer to the Documentation or "On the Variance of the Adaptive Learning Rate and Beyond."

```python
warmup_scheduler = warmup.RAdamWarmup(optimizer)
```
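As a rough illustration, the sketch below computes the variance rectification term r_t from the RAdam paper, on which the RAdamWarmup factor is based. The library's exact step offset may differ (see the Documentation). While rho_t <= 4 the term is undefined and RAdam falls back to an un-adapted update; this sketch simply reports 0 there.

```python
import math

def radam_rectification(t: int, beta2: float) -> float:
    """RAdam variance rectification term r_t; 0 while it is undefined."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        # r_t is undefined here; this sketch reports 0.
        return 0.0
    numerator = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    denominator = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(numerator / denominator)

for t in (1, 10, 100, 1000, 10000):
    print(t, round(radam_rectification(t, 0.999), 4))
```

Since rho_t grows toward rho_inf as t increases, the factor rises monotonically toward 1, which is why it behaves like a warmup determined by beta2 alone.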
The Apex library provides an Adam optimizer tuned for CUDA devices, FusedAdam. The FusedAdam optimizer can be used together with any one of the warmup schedules above. For example:
```python
optimizer = apex.optimizers.FusedAdam(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
```
MIT License
© 2019-2024 Takenori Yamamoto