Skip to content

Commit

Permalink
Revert "Issue #1247 (#1252)"
Browse files Browse the repository at this point in the history
This reverts commit b829473.
  • Loading branch information
vfdev-5 committed Feb 17, 2023
1 parent 3493328 commit dc7e668
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ignite/contrib/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ignite.contrib.handlers.clearml_logger import ClearMLLogger
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.contrib.handlers.neptune_logger import NeptuneLogger
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
Expand Down
125 changes: 125 additions & 0 deletions ignite/contrib/handlers/custom_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import warnings

from ignite.engine import EventEnum, Events, State


class CustomPeriodicEvent:
"""DEPRECATED. Use filtered events instead.
Handler to define a custom periodic events as a number of elapsed iterations/epochs
for an engine.
When custom periodic event is created and attached to an engine, the following events are fired:
1) K iterations is specified:
- `Events.ITERATIONS_<K>_STARTED`
- `Events.ITERATIONS_<K>_COMPLETED`
1) K epochs is specified:
- `Events.EPOCHS_<K>_STARTED`
- `Events.EPOCHS_<K>_COMPLETED`
Examples:
.. code-block:: python
from ignite.engine import Engine, Events
from ignite.contrib.handlers import CustomPeriodicEvent
# Let's define an event every 1000 iterations
cpe1 = CustomPeriodicEvent(n_iterations=1000)
cpe1.attach(trainer)
# Let's define an event every 10 epochs
cpe2 = CustomPeriodicEvent(n_epochs=10)
cpe2.attach(trainer)
@trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED)
def on_every_1000_iterations(engine):
# run a computation after 1000 iterations
# ...
print(engine.state.iterations_1000)
@trainer.on(cpe2.Events.EPOCHS_10_STARTED)
def on_every_10_epochs(engine):
# run a computation every 10 epochs
# ...
print(engine.state.epochs_10)
Args:
n_iterations (int, optional): number iterations of the custom periodic event
n_epochs (int, optional): number iterations of the custom periodic event. Argument is optional, but only one,
either n_iterations or n_epochs should defined.
"""

def __init__(self, n_iterations=None, n_epochs=None):

warnings.warn(
"CustomPeriodicEvent is deprecated since 0.4.0 and will be removed in 0.5.0. Use filtered events instead.",
DeprecationWarning,
)

if n_iterations is not None:
if not isinstance(n_iterations, int):
raise TypeError("Argument n_iterations should be an integer")
if n_iterations < 1:
raise ValueError("Argument n_iterations should be positive")

if n_epochs is not None:
if not isinstance(n_epochs, int):
raise TypeError("Argument n_epochs should be an integer")
if n_epochs < 1:
raise ValueError("Argument n_epochs should be positive")

if (n_iterations is None and n_epochs is None) or (n_iterations and n_epochs):
raise ValueError("Either n_iterations or n_epochs should be defined")

if n_iterations:
prefix = "iterations"
self.state_attr = "iteration"
self.period = n_iterations

if n_epochs:
prefix = "epochs"
self.state_attr = "epoch"
self.period = n_epochs

self.custom_state_attr = "{}_{}".format(prefix, self.period)
event_name = "{}_{}".format(prefix.upper(), self.period)
setattr(
self,
"Events",
EventEnum("Events", " ".join(["{}_STARTED".format(event_name), "{}_COMPLETED".format(event_name)])),
)

# Update State.event_to_attr
for e in self.Events:
State.event_to_attr[e] = self.custom_state_attr

# Create aliases
self._periodic_event_started = getattr(self.Events, "{}_STARTED".format(event_name))
self._periodic_event_completed = getattr(self.Events, "{}_COMPLETED".format(event_name))

def _on_started(self, engine):
setattr(engine.state, self.custom_state_attr, 0)

def _on_periodic_event_started(self, engine):
if getattr(engine.state, self.state_attr) % self.period == 1:
setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1)
engine.fire_event(self._periodic_event_started)

def _on_periodic_event_completed(self, engine):
if getattr(engine.state, self.state_attr) % self.period == 0:
engine.fire_event(self._periodic_event_completed)

def attach(self, engine):
engine.register_events(*self.Events)

engine.add_event_handler(Events.STARTED, self._on_started)
engine.add_event_handler(
getattr(Events, "{}_STARTED".format(self.state_attr.upper())), self._on_periodic_event_started
)
engine.add_event_handler(
getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())), self._on_periodic_event_completed
)
29 changes: 29 additions & 0 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import math
from typing import Any, Union
from unittest.mock import call, MagicMock

import pytest
import torch

from ignite.contrib.handlers import CustomPeriodicEvent
from ignite.contrib.handlers.base_logger import (
BaseLogger,
BaseOptimizerParamsHandler,
Expand Down Expand Up @@ -259,6 +261,33 @@ def update_fn(engine, batch):
mock_log_handler.assert_called_with(trainer, logger, event)
assert mock_log_handler.call_count == n_calls

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_iterations = 10
cpe1 = CustomPeriodicEvent(n_iterations=n_iterations)
n = len(data) * n_epochs / n_iterations
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe1.Events.ITERATIONS_10_STARTED, ns, cpe1)
_test(cpe1.Events.ITERATIONS_10_COMPLETED, nf, cpe1)

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_iterations = 15
cpe2 = CustomPeriodicEvent(n_iterations=n_iterations)
n = len(data) * n_epochs / n_iterations
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe2.Events.ITERATIONS_15_STARTED, ns, cpe2)
_test(cpe2.Events.ITERATIONS_15_COMPLETED, nf, cpe2)

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_custom_epochs = 2
cpe3 = CustomPeriodicEvent(n_epochs=n_custom_epochs)
n = n_epochs / n_custom_epochs
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe3.Events.EPOCHS_2_STARTED, ns, cpe3)
_test(cpe3.Events.EPOCHS_2_COMPLETED, nf, cpe3)


@pytest.mark.parametrize(
"event, n_calls",
Expand Down
133 changes: 133 additions & 0 deletions tests/ignite/contrib/handlers/test_custom_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import math

import pytest

from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
from ignite.engine import Engine


def test_bad_input():

with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"):
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
CustomPeriodicEvent(n_iterations="a")
with pytest.raises(ValueError, match="Argument n_iterations should be positive"):
CustomPeriodicEvent(n_iterations=0)
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
CustomPeriodicEvent(n_iterations=10.0)
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
CustomPeriodicEvent(n_epochs="a")
with pytest.raises(ValueError, match="Argument n_epochs should be positive"):
CustomPeriodicEvent(n_epochs=0)
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
CustomPeriodicEvent(n_epochs=10.0)
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
CustomPeriodicEvent()
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
CustomPeriodicEvent(n_iterations=1, n_epochs=2)


def test_new_events():
def update(*args, **kwargs):
pass

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
engine = Engine(update)
cpe = CustomPeriodicEvent(n_iterations=5)
cpe.attach(engine)

assert hasattr(cpe, "Events")
assert hasattr(cpe.Events, "ITERATIONS_5_STARTED")
assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED")

assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED")
assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED")

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_epochs=5)
cpe.attach(engine)

assert hasattr(cpe, "Events")
assert hasattr(cpe.Events, "EPOCHS_5_STARTED")
assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED")

assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED")
assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED")


def test_integration_iterations():
def _test(n_iterations, max_epochs, n_iters_per_epoch):
def update(*args, **kwargs):
pass

engine = Engine(update)
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_iterations=n_iterations)
cpe.attach(engine)
data = list(range(n_iters_per_epoch))

custom_period = [0]
n_calls_iter_started = [0]
n_calls_iter_completed = [0]

event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations))

@engine.on(event_started)
def on_my_event_started(engine):
assert (engine.state.iteration - 1) % n_iterations == 0
custom_period[0] += 1
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
assert custom_iter == custom_period[0]
n_calls_iter_started[0] += 1

event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations))

@engine.on(event_completed)
def on_my_event_ended(engine):
assert engine.state.iteration % n_iterations == 0
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
assert custom_iter == custom_period[0]
n_calls_iter_completed[0] += 1

engine.run(data, max_epochs=max_epochs)

n = len(data) * max_epochs / n_iterations
nf = math.floor(n)
assert custom_period[0] == n_calls_iter_started[0]
assert n_calls_iter_started[0] == nf + 1 if nf < n else nf
assert n_calls_iter_completed[0] == nf

_test(3, 5, 16)
_test(4, 5, 16)
_test(5, 5, 16)
_test(300, 50, 1000)


def test_integration_epochs():
def update(*args, **kwargs):
pass

engine = Engine(update)

n_epochs = 3
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_epochs=n_epochs)
cpe.attach(engine)
data = list(range(16))

custom_period = [1]

@engine.on(cpe.Events.EPOCHS_3_STARTED)
def on_my_epoch_started(engine):
assert (engine.state.epoch - 1) % n_epochs == 0
assert engine.state.epochs_3 == custom_period[0]

@engine.on(cpe.Events.EPOCHS_3_COMPLETED)
def on_my_epoch_ended(engine):
assert engine.state.epoch % n_epochs == 0
assert engine.state.epochs_3 == custom_period[0]
custom_period[0] += 1

engine.run(data, max_epochs=10)

assert custom_period[0] == 4
13 changes: 12 additions & 1 deletion tests/ignite/contrib/handlers/test_tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from packaging.version import Version

from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers import CustomPeriodicEvent, ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan
from ignite.metrics import RunningAverage
Expand Down Expand Up @@ -475,6 +475,17 @@ def test_pbar_wrong_events_order():
pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))


def test_pbar_on_custom_events(capsys):

engine = Engine(update_fn)
pbar = ProgressBar()
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_iterations=15)

with pytest.raises(ValueError, match=r"not in allowed events for this engine"):
pbar.attach(engine, event_name=cpe.Events.ITERATIONS_15_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)


def test_pbar_with_nan_input():
def update(engine, batch):
x = batch
Expand Down

0 comments on commit dc7e668

Please sign in to comment.