Skip to content

Commit

Permalink
remove StateParamScheduler tests warnings (#2334)
Browse files Browse the repository at this point in the history
Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
fco-dv and vfdev-5 authored Dec 10, 2021
1 parent bc3e06d commit 4597bbc
Showing 1 changed file with 48 additions and 21 deletions.
69 changes: 48 additions & 21 deletions tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,21 @@
config2 = (30, [(10, 0), (20, 10)], True, expected_hist2)
config3 = (
PiecewiseLinearStateScheduler,
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)]},
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)], "create_new": True},
)
config4 = (
ExpStateScheduler,
{"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99, "create_new": True},
)
config4 = (ExpStateScheduler, {"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99})
config5 = (
MultiStepStateScheduler,
{"param_name": "multistep_scheduled_param", "initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
{
"param_name": "multistep_scheduled_param",
"initial_value": 10,
"gamma": 0.99,
"milestones": [3, 6],
"create_new": True,
},
)


Expand All @@ -39,12 +48,16 @@ def __call__(self, event_index):

config6 = (
LambdaStateScheduler,
{"param_name": "custom_scheduled_param", "lambda_obj": LambdaState(initial_value=10, gamma=0.99)},
{
"param_name": "custom_scheduled_param",
"lambda_obj": LambdaState(initial_value=10, gamma=0.99),
"create_new": True,
},
)

config7 = (
StepStateScheduler,
{"param_name": "step_scheduled_param", "initial_value": 10, "gamma": 0.99, "step_size": 5},
{"param_name": "step_scheduled_param", "initial_value": 10, "gamma": 0.99, "step_size": 5, "create_new": True},
)


Expand All @@ -57,7 +70,10 @@ def test_pwlinear_scheduler_linear_increase_history(
# Testing linear increase
engine = Engine(lambda e, b: None)
pw_linear_step_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, save_history=save_history,
param_name="pwlinear_scheduled_param",
milestones_values=milestones_values,
save_history=save_history,
create_new=True,
)
pw_linear_step_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -78,7 +94,7 @@ def test_pwlinear_scheduler_step_constant(max_epochs, milestones_values):
# Testing step_constant
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -96,7 +112,7 @@ def test_pwlinear_scheduler_linear_increase(max_epochs, milestones_values, expec
# Testing linear increase
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -115,7 +131,7 @@ def test_pwlinear_scheduler_max_value(
# Testing max_value
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="linear_scheduled_param", milestones_values=milestones_values,
param_name="linear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand Down Expand Up @@ -152,7 +168,7 @@ def test_piecewiselinear_asserts():
def test_exponential_scheduler(max_epochs, initial_value, gamma):
engine = Engine(lambda e, b: None)
exp_state_parameter_scheduler = ExpStateScheduler(
param_name="exp_scheduled_param", initial_value=initial_value, gamma=gamma
param_name="exp_scheduled_param", initial_value=initial_value, gamma=gamma, create_new=True
)
exp_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -170,7 +186,11 @@ def test_step_scheduler(
):
engine = Engine(lambda e, b: None)
step_state_parameter_scheduler = StepStateScheduler(
param_name="step_scheduled_param", initial_value=initial_value, gamma=gamma, step_size=step_size
param_name="step_scheduled_param",
initial_value=initial_value,
gamma=gamma,
step_size=step_size,
create_new=True,
)
step_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -193,7 +213,11 @@ def test_multistep_scheduler(
):
engine = Engine(lambda e, b: None)
multi_step_state_parameter_scheduler = MultiStepStateScheduler(
param_name="multistep_scheduled_param", initial_value=initial_value, gamma=gamma, milestones=milestones,
param_name="multistep_scheduled_param",
initial_value=initial_value,
gamma=gamma,
milestones=milestones,
create_new=True,
)
multi_step_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -219,7 +243,7 @@ def __call__(self, event_index):
return self.initial_value * self.gamma ** (event_index % 9)

lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)
lambda_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=2)
Expand All @@ -243,7 +267,7 @@ def __init__(self, initial_value, gamma):

with pytest.raises(ValueError, match=r"Expected lambda_obj to be callable."):
lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)


Expand Down Expand Up @@ -286,7 +310,7 @@ def _test(scheduler_cls, scheduler_kwargs):
def test_torch_save_load():

lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)

torch.save(lambda_state_parameter_scheduler, "dummy_lambda_state_parameter_scheduler.pt")
Expand Down Expand Up @@ -333,6 +357,7 @@ def _test(scheduler_cls, **scheduler_kwargs):
initial_value=10,
gamma=0.99,
milestones=[3, 6],
create_new=True,
)


Expand All @@ -343,7 +368,7 @@ def test_multiple_scheduler_with_save_history():
if "save_history" in config:
del config["save_history"]
_scheduler = scheduler(**config, save_history=True)
_scheduler.attach(engine_multiple_schedulers)
_scheduler.attach(engine_multiple_schedulers,)

engine_multiple_schedulers.run([0] * 8, max_epochs=2)

Expand Down Expand Up @@ -371,7 +396,7 @@ def __init__(self, initial_value, gamma):
def __call__(self, event_index):
return self.initial_value * self.gamma ** (event_index % 9)

param_scheduler = LambdaStateScheduler(param_name="param", lambda_obj=LambdaState(10, 0.99),)
param_scheduler = LambdaStateScheduler(param_name="param", lambda_obj=LambdaState(10, 0.99), create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -382,7 +407,7 @@ def __call__(self, event_index):
engine = Engine(lambda e, b: None)

param_scheduler = PiecewiseLinearStateScheduler(
param_name="param", milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)]
param_name="param", milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)], create_new=True
)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
Expand All @@ -393,7 +418,7 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = ExpStateScheduler(param_name="param", initial_value=10, gamma=0.99)
param_scheduler = ExpStateScheduler(param_name="param", initial_value=10, gamma=0.99, create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -403,7 +428,7 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = StepStateScheduler(param_name="param", initial_value=10, gamma=0.99, step_size=5)
param_scheduler = StepStateScheduler(param_name="param", initial_value=10, gamma=0.99, step_size=5, create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -413,7 +438,9 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = MultiStepStateScheduler(param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6],)
param_scheduler = MultiStepStateScheduler(
param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6], create_new=True
)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand Down

0 comments on commit 4597bbc

Please sign in to comment.