Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paramscheduler emahandler #2326

Merged
merged 11 commits into from
Nov 21, 2021
65 changes: 45 additions & 20 deletions ignite/handlers/state_param_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numbers
import warnings
from bisect import bisect_right
from typing import Any, List, Sequence, Tuple, Union

Expand All @@ -11,8 +12,9 @@ class StateParamScheduler(BaseParamScheduler):

Args:
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
save_history: whether to log the parameter values to ``engine.state.param_history``, (default=False).
create_new: whether to create ``param_name`` on ``engine.state`` taking into account whether ``param_name``
attribute already exists or not. Overrides existing attribute by default, (default=False).

Note:
Parameter scheduler works independently of the internal state of the attached engine.
Expand All @@ -23,10 +25,9 @@ class StateParamScheduler(BaseParamScheduler):

"""

def __init__(
self, param_name: str, save_history: bool = False,
):
def __init__(self, param_name: str, save_history: bool = False, create_new: bool = False):
super(StateParamScheduler, self).__init__(param_name, save_history)
self.create_new = create_new

def attach(
self,
Expand All @@ -43,11 +44,19 @@ def attach(

"""
if hasattr(engine.state, self.param_name):
raise ValueError(
f"Attribute: '{self.param_name}' is already defined in the Engine.state."
f"This may be a conflict between multiple StateParameterScheduler handlers."
f"Please choose another name."
)
if self.create_new:
raise ValueError(
f"Attribute '{self.param_name}' already exists in the engine.state. "
f"This may be a conflict between multiple handlers. "
f"Please choose another name."
)
else:
if not self.create_new:
warnings.warn(
f"Attribute '{self.param_name}' is not defined in the engine.state. "
f"{type(self).__name__} will create it. Remove this warning by setting create_new=True."
)
setattr(engine.state, self.param_name, None)

if self.save_history:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
Expand Down Expand Up @@ -147,8 +156,8 @@ def __call__(self, event_index):

"""

def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history)
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

if not callable(lambda_obj):
raise ValueError("Expected lambda_obj to be callable.")
Expand Down Expand Up @@ -199,9 +208,13 @@ class PiecewiseLinearStateScheduler(StateParamScheduler):
"""

def __init__(
self, milestones_values: List[Tuple[int, float]], param_name: str, save_history: bool = False,
self,
milestones_values: List[Tuple[int, float]],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history)
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history, create_new)

if not isinstance(milestones_values, Sequence):
raise TypeError(
Expand Down Expand Up @@ -289,9 +302,9 @@ class ExpStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False,
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, create_new: bool = False,
):
super(ExpStateScheduler, self).__init__(param_name, save_history)
super(ExpStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self._state_attrs += ["initial_value", "gamma"]
Expand Down Expand Up @@ -337,9 +350,15 @@ class StepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, step_size: int, param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
step_size: int,
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(StepStateScheduler, self).__init__(param_name, save_history)
super(StepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.step_size = step_size
Expand Down Expand Up @@ -386,9 +405,15 @@ class MultiStepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, milestones: List[int], param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
milestones: List[int],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(MultiStepStateScheduler, self).__init__(param_name, save_history)
super(MultiStepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.milestones = milestones
Expand Down
93 changes: 68 additions & 25 deletions tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import (
Expand Down Expand Up @@ -281,31 +283,6 @@ def _test(scheduler_cls, scheduler_kwargs):
_test(scheduler_cls, scheduler_kwargs)


@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
def test_state_param_asserts(scheduler_cls, scheduler_kwargs):
import re

def _test(scheduler_cls, scheduler_kwargs):
scheduler = scheduler_cls(**scheduler_kwargs)
with pytest.raises(
ValueError,
match=r"Attribute: '"
+ re.escape(scheduler_kwargs["param_name"])
+ "' is already defined in the Engine.state.This may be a conflict between multiple StateParameterScheduler"
+ " handlers.Please choose another name.",
):

trainer = Engine(lambda engine, batch: None)
event = Events.EPOCH_COMPLETED
max_epochs = 2
data = [0] * 10
scheduler.attach(trainer, event)
trainer.run(data, max_epochs=max_epochs)
scheduler.attach(trainer, event)

_test(scheduler_cls, scheduler_kwargs)


def test_torch_save_load():

lambda_state_parameter_scheduler = LambdaStateScheduler(
Expand Down Expand Up @@ -441,3 +418,69 @@ def __call__(self, event_index):
param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

engine.run([0] * 8, max_epochs=10)


def test_param_scheduler_attach_exception():
trainer = Engine(lambda e, b: None)
param_name = "state_param"

setattr(trainer.state, param_name, None)

save_history = True
create_new = True

param_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10, 0.999)],
save_history=save_history,
create_new=create_new,
)

with pytest.raises(
ValueError,
match=r"Attribute '" + re.escape(param_name) + "' already exists in the engine.state. "
r"This may be a conflict between multiple handlers. "
r"Please choose another name.",
):
param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)


def test_param_scheduler_attach_warning():
trainer = Engine(lambda e, b: None)
param_name = "state_param"
save_history = True
create_new = False

param_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10, 0.999)],
save_history=save_history,
create_new=create_new,
)

with pytest.warns(
UserWarning,
match=r"Attribute '" + re.escape(param_name) + "' is not defined in the engine.state. "
r"PiecewiseLinearStateScheduler will create it. Remove this warning by setting create_new=True.",
):
param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)


def test_param_scheduler_with_ema_handler():

from ignite.handlers import EMAHandler

model = nn.Linear(2, 1)
trainer = Engine(lambda e, b: model(b))
data = torch.rand(100, 2)

param_name = "ema_decay"

ema_handler = EMAHandler(model)
ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

ema_decay_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name, milestones_values=[(0, 0.0), (10, 0.999),], save_history=True
)
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
trainer.run(data, max_epochs=20)
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved