Fixed parameter scheduler bug with CosineAnnealingWarmRestarts
#2938
Conversation
CosineAnnealingWarmRestarts
Thanks for the updates @AlexanderChaptykov
I left a few suggestions on how to improve the PR.
        assert warm_lrs[warm_steps:] == cosine_lrs
    else:
        assert (np.linspace(warm_start, lr, warm_steps).round(3) == np.array(warm_lrs[:warm_steps]).round(3)).all()
        assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs
We need this because the lrs are shifted by one step when warmup_end_value == None (see the sketch after the suggested test below).
Let's make the test as follows:

@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
    lr = 0.2
    steps = 200
    warm_steps = 50
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )
    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps:] == cosine_lrs
    else:
        np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs
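For reference, a minimal sketch of the shift discussed above, assuming the fixed behavior in this PR and small illustrative values (this is not the parametrized test itself): when warmup_end_value is None, the warmup ends at the optimizer's base lr, which is also the first value CosineAnnealingWarmRestarts emits, so the combined sequence is offset by one step.

# Minimal sketch of the one-step shift when warmup_end_value is None.
# Assumption: illustrative values only; the handler is called as scheduler(None)
# per step, as in the test above.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup

lr = 0.2
optimizer = torch.optim.SGD([torch.zeros([1], requires_grad=True)], lr=lr)
scheduler = create_lr_scheduler_with_warmup(
    CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2),
    warmup_start_value=0.023,
    warmup_end_value=None,  # warmup ends at the optimizer lr
    warmup_duration=5,
)

lrs = []
for _ in range(15):
    scheduler(None)
    lrs.append(optimizer.param_groups[0]["lr"])

# lrs[4] is both the last warmup value and the cosine scheduler's first value,
# hence the `warm_lrs[warm_steps - 1 : -1] == cosine_lrs` check in the test.
print(lrs)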
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


def plot(warmup_end_value):
    lr = 0.2
    warm_steps = 5
    steps = 100
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )
    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value != lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")
        plt.show()
    else:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value == lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")
        plt.show()


plot(None)
plot(0.26)
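As a usage note, outside of this plotting script the returned scheduler is normally attached to an ignite Engine instead of being called manually with None. A minimal sketch under that assumption (the train_step body is a placeholder, not part of this PR):

# Minimal usage sketch: attach the combined warmup + cosine scheduler to a trainer.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

optimizer = torch.optim.SGD([torch.zeros([1], requires_grad=True)], lr=0.2)
scheduler = create_lr_scheduler_with_warmup(
    CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2),
    warmup_start_value=0.023,
    warmup_end_value=None,
    warmup_duration=5,
)


def train_step(engine, batch):
    # Placeholder training step: real code would do forward/backward/optimizer.step().
    pass


trainer = Engine(train_step)
# The handler updates the optimizer lr at the start of every iteration.
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(range(10), max_epochs=20)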
LGTM, thanks @AlexanderChaptykov for working on this issue!
Fixes #2910
Description:
Check list:
Plotting learning rates: