Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paramscheduler emahandler #2326

Merged
merged 11 commits into from
Nov 21, 2021
77 changes: 58 additions & 19 deletions ignite/handlers/state_param_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numbers
import re
import warnings
from bisect import bisect_right
from typing import Any, List, Sequence, Tuple, Union

Expand All @@ -13,6 +15,9 @@ class StateParamScheduler(BaseParamScheduler):
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
create_new: in case `param_name` already exists in `engine.state`, whether to authorize `StateParamScheduler`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fco-dv please use double-ticks ```` instead of single-tick when addresses code. Here and above.

Suggested change
create_new: in case `param_name` already exists in `engine.state`, whether to authorize `StateParamScheduler`
create_new: in case ``param_name`` already exists in ``engine.state``, whether to authorize ``StateParamScheduler``

create a new parameter based on `StateParamScheduler` class name.
By default, this option is `False` meaning that `StateParamScheduler` will override `param_name` values.

Note:
Parameter scheduler works independently of the internal state of the attached engine.
Expand All @@ -23,10 +28,9 @@ class StateParamScheduler(BaseParamScheduler):

"""

def __init__(
self, param_name: str, save_history: bool = False,
):
def __init__(self, param_name: str, save_history: bool = False, create_new: bool = False):
super(StateParamScheduler, self).__init__(param_name, save_history)
self.create_new = create_new

def attach(
self,
Expand All @@ -43,16 +47,35 @@ def attach(

"""
if hasattr(engine.state, self.param_name):
raise ValueError(
f"Attribute: '{self.param_name}' is already defined in the Engine.state."
f"This may be a conflict between multiple StateParameterScheduler handlers."
f"Please choose another name."
)
if not self.create_new:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not quite understand this if/else clauses, especially "else" one.
I was thinking about the following logic:

  1. param_name NOT in engine.state AND create_new is True => simply create new attribute
  2. param_name NOT in engine.state AND create_new is False => warn that we will create new attribute, but to remove this warning, create_new should be True
  3. param_name in engine.state AND create_new is True => raise error as we can not create new attribute as it already exists in the state.
  4. param_name in engine.state AND create_new is False => silently override existing attribute

What do you think ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes , I've misunderstood the case 3, thought we were handling the creation of state parameter in that case ... So yes Indeed, it's better and simpler ! Thanks

warnings.warn(
f"Attribute: '{self.param_name}' is already defined in the Engine.state. "
f"{type(self).__name__} will override its values."
)
else:
pattern = r"^(" + re.escape(self.param_name) + "_" + re.escape(type(self).__name__) + "_)([0-9]+)$"
matched_params = []
for engine_params in vars(engine.state).keys():
match = re.match(pattern, engine_params)
if match:
matched_params.append(int(match.groups()[1]))
if matched_params:
new_param = self.param_name + "_" + type(self).__name__ + "_" + str(sorted(matched_params)[-1] + 1)
else:
new_param = self.param_name + "_" + type(self).__name__ + "_0"

warnings.warn(
f"Attribute: '{self.param_name}' is already defined in the Engine.state. "
f"{type(self).__name__} will create a new parameter {new_param} in Engine.state."
)
self.param_name = new_param

setattr(engine.state, self.param_name, None)

if self.save_history:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined]
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined]

engine.add_event_handler(event, self)

Expand Down Expand Up @@ -147,8 +170,8 @@ def __call__(self, event_index):

"""

def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history)
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

if not callable(lambda_obj):
raise ValueError("Expected lambda_obj to be callable.")
Expand Down Expand Up @@ -199,9 +222,13 @@ class PiecewiseLinearStateScheduler(StateParamScheduler):
"""

def __init__(
self, milestones_values: List[Tuple[int, float]], param_name: str, save_history: bool = False,
self,
milestones_values: List[Tuple[int, float]],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history)
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history, create_new)

if not isinstance(milestones_values, Sequence):
raise TypeError(
Expand Down Expand Up @@ -289,9 +316,9 @@ class ExpStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False,
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, create_new: bool = False,
):
super(ExpStateScheduler, self).__init__(param_name, save_history)
super(ExpStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self._state_attrs += ["initial_value", "gamma"]
Expand Down Expand Up @@ -337,9 +364,15 @@ class StepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, step_size: int, param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
step_size: int,
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(StepStateScheduler, self).__init__(param_name, save_history)
super(StepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.step_size = step_size
Expand Down Expand Up @@ -386,9 +419,15 @@ class MultiStepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, milestones: List[int], param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
milestones: List[int],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(MultiStepStateScheduler, self).__init__(param_name, save_history)
super(MultiStepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.milestones = milestones
Expand Down
49 changes: 49 additions & 0 deletions tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def _test(scheduler_cls, scheduler_kwargs):
_test(scheduler_cls, scheduler_kwargs)


@pytest.mark.skip
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
def test_state_param_asserts(scheduler_cls, scheduler_kwargs):
import re
Expand Down Expand Up @@ -414,3 +415,51 @@ def __call__(self, event_index):
param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

engine.run([0] * 8, max_epochs=10)


@pytest.mark.parametrize("create_new", [True, False])
def test_param_scheduler_with_ema_handler(create_new):
import torch.nn as nn

from ignite.handlers import EMAHandler

model = nn.Linear(2, 1)
trainer = Engine(lambda e, b: model(b))
data = torch.rand(100, 2)

param_name = "ema_decay"
save_history = True
ema_handler = EMAHandler(model)
ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

ema_decay_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10 * len(data), 0.999)],
save_history=save_history,
create_new=create_new,
)

if create_new:
with pytest.warns(
UserWarning,
match=r"Attribute: 'ema_decay' is already defined in the Engine.state. "
r"PiecewiseLinearStateScheduler will create a new parameter ema_decay_PiecewiseLinearStateScheduler_0 in"
r" Engine.state.",
):
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
else:
with pytest.warns(
UserWarning,
match=r"Attribute: 'ema_decay' is already defined in the Engine.state. "
r"PiecewiseLinearStateScheduler will override its values.",
):
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)

trainer.run(data, max_epochs=20)
if create_new:
assert "ema_decay_PiecewiseLinearStateScheduler_0" in vars(trainer.state).keys()
assert "ema_decay_PiecewiseLinearStateScheduler_0" in trainer.state.param_history
assert "ema_decay" in vars(trainer.state).keys()
else:
assert "ema_decay" in vars(trainer.state).keys()
assert "ema_decay" in trainer.state.param_history