Fix #813 (#1072)

Merged: 7 commits, May 29, 2020
198 changes: 111 additions & 87 deletions ignite/contrib/handlers/param_scheduler.py
@@ -1,9 +1,11 @@
import math
import numbers
+ import tempfile
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import copy
+ from pathlib import Path
from typing import List, Optional, Union

import torch
@@ -449,7 +451,16 @@ def __init__(self, schedulers, durations, save_history=False):

self.schedulers = schedulers
self.durations = durations
- super(ConcatScheduler, self).__init__(optimizer=_get_fake_optimizer(), param_name="", save_history=save_history)

+ self.optimizer = self.schedulers[0].optimizer
+ if not (all(id(s.optimizer) == id(self.optimizer) for s in self.schedulers)):
+     raise ValueError("schedulers should be related to same optimizer")
+
+ # schedulers should have save_history synced with this ConcatScheduler
+ for s in schedulers:
+     s.save_history = save_history
+
+ super(ConcatScheduler, self).__init__(optimizer=self.optimizer, param_name="", save_history=save_history)

self._scheduler_index = 0
self._current_scheduler = None
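In practice, the new constructor contract looks like this. A minimal usage sketch, assuming the contrib API of this era (LinearCyclicalScheduler and CosineAnnealingScheduler live in this same module; the optimizer and values are made up, not from the PR):

import torch
from ignite.contrib.handlers.param_scheduler import (
    ConcatScheduler,
    CosineAnnealingScheduler,
    LinearCyclicalScheduler,
)

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)

# Both sub-schedulers must drive the same optimizer now.
warmup = LinearCyclicalScheduler(optimizer, "lr", start_value=0.0, end_value=0.1, cycle_size=20)
anneal = CosineAnnealingScheduler(optimizer, "lr", start_value=0.1, end_value=0.001, cycle_size=40)

# durations: number of events for each scheduler except the last one.
scheduler = ConcatScheduler(schedulers=[warmup, anneal], durations=[10])
assert scheduler.optimizer is optimizer
# A sub-scheduler built on a different optimizer would raise
# ValueError("schedulers should be related to same optimizer").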
@@ -510,7 +521,6 @@ def __call__(self, engine, name=None):
if self._current_duration == 0:
self._scheduler_index += 1
self._setup_scheduler()

self._current_scheduler(engine, name)
self._current_duration -= 1

@@ -549,22 +559,40 @@ def simulate_values(cls, num_events, schedulers, durations, param_names=None, **kwargs):
"""
if param_names is not None and not isinstance(param_names, (list, tuple)):
raise ValueError("Argument param_names should be list or tuple of strings")
- output = []
-
- # Need to copy all schedulers otherwise unsafe
- copy_schedulers = [_replicate_scheduler(s) for s in schedulers]
-
- scheduler = cls(copy_schedulers, durations, save_history=False)
- if param_names is None:
-     param_names = [scheduler.param_name]
- for i in range(num_events):
-     scheduler(engine=None)
-     values = [i]
-     for param_name in param_names:
-         params = [p[param_name] for p in scheduler.optimizer_param_groups]
-         values = values + params
-     output.append(values)
- return output
+ # This scheduler uses `ParamScheduler` which
+ # should be replicated in order to simulate LR values and
+ # not perturb original scheduler.
+ with tempfile.TemporaryDirectory() as tmpdirname:
+     cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
+     objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)}
+     # all schedulers should be related to the same optimizer
+     objs["optimizer"] = schedulers[0].optimizer.state_dict()
+
+     torch.save(objs, cache_filepath.as_posix())
+
+     # do not save_history
+     for s in schedulers:
+         s.save_history = False
+
+     output = []
+     scheduler = cls(schedulers=schedulers, save_history=False, durations=durations, **kwargs)
+     if param_names is None:
+         param_names = [scheduler.param_name]
+     for i in range(num_events):
+         scheduler(engine=None)
+         values = [i]
+         for param_name in param_names:
+             params = [p[param_name] for p in scheduler.optimizer_param_groups]
+             values = values + params
+         output.append(values)
+
+     objs = torch.load(cache_filepath.as_posix())
+     for i, s in enumerate(schedulers):
+         s.load_state_dict(objs["lr_scheduler_{}".format(i)])
+         s.optimizer.load_state_dict(objs["optimizer"])
+
+ return output
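The reworked simulate_values above checkpoints each scheduler's state_dict (plus the shared optimizer's) to a temporary file, runs the dry simulation on the real objects, then reloads the saved state. A small self-contained usage sketch (the toy optimizer and cycle values are assumptions, not taken from the PR):

import torch
from ignite.contrib.handlers.param_scheduler import ConcatScheduler, LinearCyclicalScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
s1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.0, end_value=0.1, cycle_size=20)
s2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.0, cycle_size=20)

# Returns [[event_index, lr], ...]; s1, s2 and the optimizer get their
# state_dicts saved before the dry run and reloaded after it.
values = ConcatScheduler.simulate_values(num_events=30, schedulers=[s1, s2], durations=[15])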


class LRScheduler(ParamScheduler):
@@ -591,7 +619,7 @@ class LRScheduler(ParamScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
"""

- def __init__(self, lr_scheduler, save_history=False, **kwds):
+ def __init__(self, lr_scheduler, save_history=False, **kwargs):

if not isinstance(lr_scheduler, _LRScheduler):
raise TypeError(
@@ -635,33 +663,36 @@ def simulate_values(cls, num_events, lr_scheduler, **kwargs):
list of pairs: [event_index, value]

"""

+ if not isinstance(lr_scheduler, _LRScheduler):
+     raise TypeError(
+         "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
+         "but given {}".format(type(lr_scheduler))
+     )

# This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
# should be replicated in order to simulate LR values and
# not perturb original scheduler.
- copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
- values = []
- scheduler = cls(save_history=False, lr_scheduler=copy_lr_scheduler)
- for i in range(num_events):
-     params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
-     values.append([i] + params)
-     scheduler(engine=None)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+     cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
+     obj = {
+         "lr_scheduler": lr_scheduler.state_dict(),
+         "optimizer": lr_scheduler.optimizer.state_dict(),
+     }
+     torch.save(obj, cache_filepath.as_posix())

- return values
+     values = []
+     scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
+     for i in range(num_events):
+         params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
+         values.append([i] + params)
+         scheduler(engine=None)

- @staticmethod
- def _replicate_lr_scheduler(lr_scheduler):
-     if not isinstance(lr_scheduler, _LRScheduler):
-         raise TypeError("lr_scheduler should inherit of _LRScheduler, got {}".format(type(lr_scheduler)))
-     lr_scheduler_cls = lr_scheduler.__class__
-     dummy_optimizer = _replicate_optimizer(lr_scheduler.optimizer)
-     for group in dummy_optimizer.param_groups:
-         group.setdefault("initial_lr", group["lr"])
-     kwargs = lr_scheduler.state_dict()
-     for k in [_k for _k in kwargs.keys() if "_" == _k[0]] + ["base_lrs", "last_epoch"]:
-         del kwargs[k]
-     copy_lr_scheduler = lr_scheduler_cls(optimizer=dummy_optimizer, **kwargs)
-     copy_lr_scheduler.load_state_dict(lr_scheduler.state_dict())
-     return copy_lr_scheduler
+     obj = torch.load(cache_filepath.as_posix())
+     lr_scheduler.load_state_dict(obj["lr_scheduler"])
+     lr_scheduler.optimizer.load_state_dict(obj["optimizer"])

+ return values
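LRScheduler.simulate_values now uses the same pattern: the wrapped torch scheduler and its optimizer are snapshot to a temp file and restored afterwards, replacing the fragile _replicate_lr_scheduler path. A usage sketch (the StepLR settings are illustrative assumptions):

import torch
from ignite.contrib.handlers.param_scheduler import LRScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Dry-run 15 events; torch_lr_scheduler and optimizer come back unmodified,
# so the same instances can still be wrapped and attached to a trainer.
values = LRScheduler.simulate_values(num_events=15, lr_scheduler=torch_lr_scheduler)
scheduler = LRScheduler(torch_lr_scheduler)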


def create_lr_scheduler_with_warmup(
@@ -741,7 +772,7 @@ def create_lr_scheduler_with_warmup(
if init_lr != param_group_warmup_end_value:
milestones_values.append((warmup_duration, init_lr))

- lr_scheduler = LRScheduler(lr_scheduler)
+ lr_scheduler = LRScheduler(lr_scheduler, save_history=save_history)
else:
init_lr = lr_scheduler.get_param()
if init_lr == param_group_warmup_end_value:
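The one-line change above forwards save_history to the wrapped LRScheduler, so the post-warmup phase records parameter history too. A hedged sketch of the call it affects (the warmup numbers and ExponentialLR settings are made up):

import torch
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

# Before the fix, save_history=True was silently dropped when wrapping the
# torch scheduler; now the post-warmup LRScheduler phase records history too.
scheduler = create_lr_scheduler_with_warmup(
    torch_lr_scheduler,
    warmup_start_value=0.0,
    warmup_end_value=0.1,
    warmup_duration=10,
    save_history=True,
)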
@@ -913,16 +944,29 @@ def __init__(self, schedulers: List[ParamScheduler], names: Optional[List[str]]
self.schedulers = schedulers
self.names = names

- optimizer = self.schedulers[0].optimizer
- if not (all(id(s.optimizer) == id(optimizer) for s in schedulers)):
+ self.optimizer = self.schedulers[0].optimizer
+ if not (all(id(s.optimizer) == id(self.optimizer) for s in schedulers)):
raise ValueError("schedulers should be related to same optimizer")

- super(ParamGroupScheduler, self).__init__(optimizer=optimizer, param_name="lr", save_history=save_history)
+ # schedulers should have save_history synced with this ParamGroupScheduler
+ for s in schedulers:
+     s.save_history = save_history
+
+ super(ParamGroupScheduler, self).__init__(optimizer=self.optimizer, param_name="lr", save_history=save_history)

def __call__(self, engine, name=None):
for scheduler, name in zip(self.schedulers, self.names):
scheduler(engine, name)

+ @property
+ def save_history(self):
+     return self.schedulers[0].save_history
+
+ @save_history.setter
+ def save_history(self, value):
+     for s in self.schedulers:
+         s.save_history = value

def get_param(self) -> Union[List[float], float]:
return [scheduler.get_param() for scheduler in self.schedulers]
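With the new property/setter pair, toggling save_history on a ParamGroupScheduler fans out to every wrapped scheduler. A sketch, assuming param_group_index support in the contrib ParamScheduler of this era (an assumption, not shown in this diff):

import torch
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, ParamGroupScheduler

t1 = torch.zeros([1], requires_grad=True)
t2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [t1], "lr": 0.1}, {"params": [t2], "lr": 0.01}])

s1 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 0.1, 20, param_group_index=0)
s2 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 0.01, 20, param_group_index=1)

scheduler = ParamGroupScheduler(schedulers=[s1, s2], names=["lr_group_0", "lr_group_1"])
scheduler.save_history = True  # the setter propagates to s1 and s2
assert s1.save_history and s2.save_history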

@@ -981,35 +1025,31 @@ def simulate_values(cls, num_events, schedulers, **kwargs):
list of pairs: [event_index, value]

"""
- copy_lr_schedulers = [_replicate_scheduler(s) for s in schedulers]
- values = []
- scheduler = cls(schedulers=copy_lr_schedulers)
- for i in range(num_events):
-     scheduler(engine=None)
-     params = scheduler.get_param()
-     values.append([i] + params)
- return values

+ # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
+ # should be replicated in order to simulate LR values and
+ # not perturb original scheduler.
+ with tempfile.TemporaryDirectory() as tmpdirname:
+     cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
+     objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)}
+     # all schedulers should be related to the same optimizer
+     objs["optimizer"] = schedulers[0].optimizer.state_dict()
+
+     torch.save(objs, cache_filepath.as_posix())

- def _replicate_scheduler(scheduler, save_history=False):
-     if isinstance(scheduler, LRScheduler):
-         return LRScheduler(LRScheduler._replicate_lr_scheduler(scheduler.lr_scheduler), save_history=save_history)
-     elif isinstance(scheduler, ConcatScheduler):
-         copy_schedulers = [_replicate_scheduler(s, save_history=save_history) for s in scheduler.schedulers]
-         return ConcatScheduler(copy_schedulers, durations=scheduler.durations, save_history=save_history)
-     elif isinstance(scheduler, ParamGroupScheduler):
-         copy_optimizer = _replicate_optimizer(scheduler.optimizer)
-         copy_schedulers = [_replicate_scheduler(s, save_history=save_history) for s in scheduler.schedulers]
-         for s in copy_schedulers:
-             s.optimizer = copy_optimizer
-         return ParamGroupScheduler(schedulers=copy_schedulers, names=scheduler.names, save_history=save_history)
-     elif isinstance(scheduler, ParamScheduler):
-         new_scheduler = copy(scheduler)
-         new_scheduler.optimizer = _replicate_optimizer(new_scheduler.optimizer)
-         new_scheduler.save_history = save_history
-         return new_scheduler
-     else:
-         raise TypeError("Unknown scheduler type {}".format(type(scheduler)))
+     values = []
+     scheduler = cls(schedulers=schedulers, **kwargs)
+     for i in range(num_events):
+         params = scheduler.get_param()
+         values.append([i] + params)
+         scheduler(engine=None)

+     objs = torch.load(cache_filepath.as_posix())
+     for i, s in enumerate(schedulers):
+         s.load_state_dict(objs["lr_scheduler_{}".format(i)])
+         s.optimizer.load_state_dict(objs["optimizer"])

+ return values
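All three simulate_values rewrites in this diff share one idiom: serialize state to a temp file, dry-run, restore. Distilled as a standalone sketch (this helper is illustrative only, not part of the PR; it assumes an ignite-style scheduler callable with state_dict/load_state_dict):

import tempfile
from pathlib import Path

import torch

def simulate_without_side_effects(scheduler, optimizer, num_events, param_name="lr"):
    # Snapshot state_dicts so the dry run cannot perturb training state.
    with tempfile.TemporaryDirectory() as tmpdirname:
        cache = Path(tmpdirname) / "cache.pt"
        torch.save({"scheduler": scheduler.state_dict(), "optimizer": optimizer.state_dict()}, cache.as_posix())

        values = []
        for i in range(num_events):
            scheduler(engine=None)  # ignite param schedulers are engine callbacks
            values.append([i] + [g[param_name] for g in optimizer.param_groups])

        objs = torch.load(cache.as_posix())
        scheduler.load_state_dict(objs["scheduler"])
        optimizer.load_state_dict(objs["optimizer"])
    return values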


def _get_fake_optimizer(optimizer_cls=None, **kwargs):
@@ -1018,19 +1058,3 @@ def _get_fake_optimizer(optimizer_cls=None, **kwargs):
optimizer_cls = torch.optim.SGD
kwargs["lr"] = 0.01
return optimizer_cls([t], **kwargs)


- def _replicate_optimizer(optimizer):
-     cls = optimizer.__class__
-     defaults = copy(optimizer.defaults)
-     if not isinstance(defaults["lr"], numbers.Real):
-         defaults["lr"] = 0.01
-     param_groups = optimizer.param_groups
-     param_groups = [
-         # do no copy params
-         {k: v for k, v in pg.items() if k != "params"}
-         for pg in param_groups
-     ]
-     for pg in param_groups:
-         pg["params"] = torch.zeros([1], requires_grad=True)
-     return cls(param_groups, **defaults)