-
-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Paramscheduler emahandler #2326
Changes from 2 commits
0ea2c56
de4cbd4
e1f9565
aae8f97
6046ab8
6194733
4a11472
f37866f
c42f016
930b5b7
9ff67b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
import numbers | ||
import re | ||
import warnings | ||
from bisect import bisect_right | ||
from typing import Any, List, Sequence, Tuple, Union | ||
|
||
|
@@ -13,6 +15,9 @@ class StateParamScheduler(BaseParamScheduler): | |
param_name: name of parameter to update. | ||
save_history: whether to log the parameter values to | ||
`engine.state.param_history`, (default=False). | ||
create_new: in case `param_name` already exists in `engine.state`, whether to authorize `StateParamScheduler` | ||
create a new parameter based on `StateParamScheduler` class name. | ||
By default, this option is `False` meaning that `StateParamScheduler` will override `param_name` values. | ||
|
||
Note: | ||
Parameter scheduler works independently of the internal state of the attached engine. | ||
|
@@ -23,10 +28,9 @@ class StateParamScheduler(BaseParamScheduler): | |
|
||
""" | ||
|
||
def __init__( | ||
self, param_name: str, save_history: bool = False, | ||
): | ||
def __init__(self, param_name: str, save_history: bool = False, create_new: bool = False): | ||
super(StateParamScheduler, self).__init__(param_name, save_history) | ||
self.create_new = create_new | ||
|
||
def attach( | ||
self, | ||
|
@@ -43,16 +47,35 @@ def attach( | |
|
||
""" | ||
if hasattr(engine.state, self.param_name): | ||
raise ValueError( | ||
f"Attribute: '{self.param_name}' is already defined in the Engine.state." | ||
f"This may be a conflict between multiple StateParameterScheduler handlers." | ||
f"Please choose another name." | ||
) | ||
if not self.create_new: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not quite understand this if/else clauses, especially "else" one.
What do you think ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh yes , I've misunderstood the case 3, thought we were handling the creation of state parameter in that case ... So yes Indeed, it's better and simpler ! Thanks |
||
warnings.warn( | ||
f"Attribute: '{self.param_name}' is already defined in the Engine.state. " | ||
f"{type(self).__name__} will override its values." | ||
) | ||
else: | ||
pattern = r"^(" + re.escape(self.param_name) + "_" + re.escape(type(self).__name__) + "_)([0-9]+)$" | ||
matched_params = [] | ||
for engine_params in vars(engine.state).keys(): | ||
match = re.match(pattern, engine_params) | ||
if match: | ||
matched_params.append(int(match.groups()[1])) | ||
if matched_params: | ||
new_param = self.param_name + "_" + type(self).__name__ + "_" + str(sorted(matched_params)[-1] + 1) | ||
else: | ||
new_param = self.param_name + "_" + type(self).__name__ + "_0" | ||
|
||
warnings.warn( | ||
f"Attribute: '{self.param_name}' is already defined in the Engine.state. " | ||
f"{type(self).__name__} will create a new parameter {new_param} in Engine.state." | ||
) | ||
self.param_name = new_param | ||
|
||
setattr(engine.state, self.param_name, None) | ||
|
||
if self.save_history: | ||
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore | ||
setattr(engine.state, "param_history", {}) | ||
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined] | ||
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined] | ||
|
||
engine.add_event_handler(event, self) | ||
|
||
|
@@ -147,8 +170,8 @@ def __call__(self, event_index): | |
|
||
""" | ||
|
||
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False): | ||
super(LambdaStateScheduler, self).__init__(param_name, save_history) | ||
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False): | ||
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new) | ||
|
||
if not callable(lambda_obj): | ||
raise ValueError("Expected lambda_obj to be callable.") | ||
|
@@ -199,9 +222,13 @@ class PiecewiseLinearStateScheduler(StateParamScheduler): | |
""" | ||
|
||
def __init__( | ||
self, milestones_values: List[Tuple[int, float]], param_name: str, save_history: bool = False, | ||
self, | ||
milestones_values: List[Tuple[int, float]], | ||
param_name: str, | ||
save_history: bool = False, | ||
create_new: bool = False, | ||
): | ||
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history) | ||
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history, create_new) | ||
|
||
if not isinstance(milestones_values, Sequence): | ||
raise TypeError( | ||
|
@@ -289,9 +316,9 @@ class ExpStateScheduler(StateParamScheduler): | |
""" | ||
|
||
def __init__( | ||
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, | ||
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, create_new: bool = False, | ||
): | ||
super(ExpStateScheduler, self).__init__(param_name, save_history) | ||
super(ExpStateScheduler, self).__init__(param_name, save_history, create_new) | ||
self.initial_value = initial_value | ||
self.gamma = gamma | ||
self._state_attrs += ["initial_value", "gamma"] | ||
|
@@ -337,9 +364,15 @@ class StepStateScheduler(StateParamScheduler): | |
""" | ||
|
||
def __init__( | ||
self, initial_value: float, gamma: float, step_size: int, param_name: str, save_history: bool = False, | ||
self, | ||
initial_value: float, | ||
gamma: float, | ||
step_size: int, | ||
param_name: str, | ||
save_history: bool = False, | ||
create_new: bool = False, | ||
): | ||
super(StepStateScheduler, self).__init__(param_name, save_history) | ||
super(StepStateScheduler, self).__init__(param_name, save_history, create_new) | ||
self.initial_value = initial_value | ||
self.gamma = gamma | ||
self.step_size = step_size | ||
|
@@ -386,9 +419,15 @@ class MultiStepStateScheduler(StateParamScheduler): | |
""" | ||
|
||
def __init__( | ||
self, initial_value: float, gamma: float, milestones: List[int], param_name: str, save_history: bool = False, | ||
self, | ||
initial_value: float, | ||
gamma: float, | ||
milestones: List[int], | ||
param_name: str, | ||
save_history: bool = False, | ||
create_new: bool = False, | ||
): | ||
super(MultiStepStateScheduler, self).__init__(param_name, save_history) | ||
super(MultiStepStateScheduler, self).__init__(param_name, save_history, create_new) | ||
self.initial_value = initial_value | ||
self.gamma = gamma | ||
self.milestones = milestones | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fco-dv please use double-ticks ```` instead of single-tick when addresses code. Here and above.