Skip to content

Commit

Permalink
Fixed issue when removing handlers on filtered events
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Sep 2, 2022
1 parent c12ad57 commit 7027359
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 58 deletions.
2 changes: 1 addition & 1 deletion ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any) -> None:
self._event_handlers[event_name] = new_event_handlers

def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
"""Decorator shortcut for add_event_handler.
"""Decorator shortcut for :meth:`~ignite.engine.engine.Engine.add_event_handler`.
Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.events.Events`
Expand Down
8 changes: 8 additions & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ def remove(self) -> None:
if handler is None or engine is None:
return

if hasattr(handler, "_parent"):
handler = handler._parent() # type: ignore[attr-defined]
if handler is None:
raise RuntimeError(
"Internal error! Please fill an issue on https://github.com/pytorch/ignite/issues "
"if encounter this error. Thank you!"
)

if isinstance(self.event_name, EventsList):
for e in self.event_name:
if engine.has_event_handler(handler, e):
Expand Down
35 changes: 9 additions & 26 deletions tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ def bar(e):
engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
assert engine.has_event_handler(bar)
assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED)

engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))


def test_remove_event_handler_on_callable_events():
Expand Down Expand Up @@ -258,15 +257,9 @@ def _test(event_name, event_attr, every, true_num_calls):

engine = Engine(lambda e, b: b)

counter = [
0,
]
counter_every = [
0,
]
num_calls = [
0,
]
counter = [0]
counter_every = [0]
num_calls = [0]

@engine.on(event_name(every=every))
def assert_every(engine):
Expand Down Expand Up @@ -309,12 +302,8 @@ def _test(event_name, event_attr):
engine = Engine(lambda e, b: b)

once = 2
counter = [
0,
]
num_calls = [
0,
]
counter = [0]
num_calls = [0]

@engine.on(event_name(once=once))
def assert_once(engine):
Expand Down Expand Up @@ -350,9 +339,7 @@ def _test(event_name, event_attr, true_num_calls):

engine = Engine(lambda e, b: b)

num_calls = [
0,
]
num_calls = [0]

@engine.on(event_name(event_filter=custom_event_filter))
def assert_on_special_event(engine):
Expand Down Expand Up @@ -381,9 +368,7 @@ def custom_event_filter(engine, event):

# Check bad behaviour
engine = Engine(lambda e, b: b)
counter = [
0,
]
counter = [0]

# Modify events
Events.ITERATION_STARTED(event_filter=custom_event_filter)
Expand Down Expand Up @@ -433,9 +418,7 @@ def update_fn(engine, batch):
engine = Engine(update_fn)
engine.register_events(*CustomEvents, event_to_attr=event_to_attr)

num_calls = [
0,
]
num_calls = [0]

@engine.on(event_name(event_filter=custom_event_filter))
def assert_on_special_event(engine):
Expand Down
75 changes: 44 additions & 31 deletions tests/ignite/engine/test_event_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,73 +129,86 @@ def test_adding_multiple_event_handlers():
handler.assert_called_once_with(engine)


def test_event_removable_handle():
@pytest.mark.parametrize(
"event1, event2",
[
(Events.STARTED, Events.COMPLETED),
(Events.EPOCH_STARTED, Events.EPOCH_COMPLETED),
(Events.ITERATION_STARTED, Events.ITERATION_COMPLETED),
(Events.ITERATION_STARTED(every=2), Events.ITERATION_COMPLETED(every=2)),
],
)
def test_event_removable_handle(event1, event2):

# Removable handle removes event from engine.
engine = DummyEngine()
engine = Engine(lambda e, b: None)
handler = create_autospec(spec=lambda x: None)
assert not hasattr(handler, "_parent")

removable_handle = engine.add_event_handler(Events.STARTED, handler)
assert engine.has_event_handler(handler, Events.STARTED)
removable_handle = engine.add_event_handler(event1, handler)
assert engine.has_event_handler(handler, event1)

engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
handler.assert_any_call(engine)

removable_handle.remove()
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, event1)

# Second engine pass does not fire handle again.
engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
handler.assert_any_call(engine)

# Removable handle can be used as a context manager
handler = create_autospec(spec=lambda x: None)

with engine.add_event_handler(Events.STARTED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
engine.run(1)
with engine.add_event_handler(event1, handler):
assert engine.has_event_handler(handler, event1)
engine.run([1, 2])

assert not engine.has_event_handler(handler, Events.STARTED)
handler.assert_called_once_with(engine)
assert not engine.has_event_handler(handler, event1)
handler.assert_any_call(engine)

engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
handler.assert_any_call(engine)

# Removeable handle only effects a single event registration
handler = MagicMock(spec_set=True)

with engine.add_event_handler(Events.STARTED, handler):
with engine.add_event_handler(Events.COMPLETED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
assert engine.has_event_handler(handler, Events.COMPLETED)
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)
with engine.add_event_handler(event1, handler):
with engine.add_event_handler(event2, handler):
assert engine.has_event_handler(handler, event1)
assert engine.has_event_handler(handler, event2)
assert engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event2)
assert not engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event2)

# Removeable handle is re-enter and re-exitable

handler = MagicMock(spec_set=True)

remove = engine.add_event_handler(Events.STARTED, handler)
remove = engine.add_event_handler(event1, handler)

with remove:
with remove:
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event1)

# Removeable handle is a weakref, does not keep engine or event alive
def _add_in_closure():
_engine = DummyEngine()
_engine = Engine(lambda e, b: None)

def _handler(_):
pass

_handle = _engine.add_event_handler(Events.STARTED, _handler)
_handle = _engine.add_event_handler(event1, _handler)
assert _handle.engine() is _engine
assert _handle.handler() is _handler

if event1.filter is None:
assert _handle.handler() is _handler
else:
assert _handle.handler()._parent() is _handler

return _handle

Expand Down

0 comments on commit 7027359

Please sign in to comment.