Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issue when removing handlers on filtered events #2690

Merged
merged 6 commits into from
Sep 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any) -> None:
self._event_handlers[event_name] = new_event_handlers

def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
"""Decorator shortcut for add_event_handler.
"""Decorator shortcut for :meth:`~ignite.engine.engine.Engine.add_event_handler`.

Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.events.Events`
Expand Down
8 changes: 8 additions & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ def remove(self) -> None:
if handler is None or engine is None:
return

if hasattr(handler, "_parent"):
handler = handler._parent() # type: ignore[attr-defined]
if handler is None:
raise RuntimeError(
"Internal error! Please fill an issue on https://github.com/pytorch/ignite/issues "
"if encounter this error. Thank you!"
)

if isinstance(self.event_name, EventsList):
for e in self.event_name:
if engine.has_event_handler(handler, e):
Expand Down
35 changes: 9 additions & 26 deletions tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ def bar(e):
engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
assert engine.has_event_handler(bar)
assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED)

engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))


def test_remove_event_handler_on_callable_events():
Expand Down Expand Up @@ -258,15 +257,9 @@ def _test(event_name, event_attr, every, true_num_calls):

engine = Engine(lambda e, b: b)

counter = [
0,
]
counter_every = [
0,
]
num_calls = [
0,
]
counter = [0]
counter_every = [0]
num_calls = [0]

@engine.on(event_name(every=every))
def assert_every(engine):
Expand Down Expand Up @@ -309,12 +302,8 @@ def _test(event_name, event_attr):
engine = Engine(lambda e, b: b)

once = 2
counter = [
0,
]
num_calls = [
0,
]
counter = [0]
num_calls = [0]

@engine.on(event_name(once=once))
def assert_once(engine):
Expand Down Expand Up @@ -350,9 +339,7 @@ def _test(event_name, event_attr, true_num_calls):

engine = Engine(lambda e, b: b)

num_calls = [
0,
]
num_calls = [0]

@engine.on(event_name(event_filter=custom_event_filter))
def assert_on_special_event(engine):
Expand Down Expand Up @@ -381,9 +368,7 @@ def custom_event_filter(engine, event):

# Check bad behaviour
engine = Engine(lambda e, b: b)
counter = [
0,
]
counter = [0]

# Modify events
Events.ITERATION_STARTED(event_filter=custom_event_filter)
Expand Down Expand Up @@ -433,9 +418,7 @@ def update_fn(engine, batch):
engine = Engine(update_fn)
engine.register_events(*CustomEvents, event_to_attr=event_to_attr)

num_calls = [
0,
]
num_calls = [0]

@engine.on(event_name(event_filter=custom_event_filter))
def assert_on_special_event(engine):
Expand Down
79 changes: 48 additions & 31 deletions tests/ignite/engine/test_event_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,73 +129,90 @@ def test_adding_multiple_event_handlers():
handler.assert_called_once_with(engine)


def test_event_removable_handle():
@pytest.mark.parametrize(
"event1, event2",
[
(Events.STARTED, Events.COMPLETED),
(Events.EPOCH_STARTED, Events.EPOCH_COMPLETED),
(Events.ITERATION_STARTED, Events.ITERATION_COMPLETED),
(Events.ITERATION_STARTED(every=2), Events.ITERATION_COMPLETED(every=2)),
],
)
def test_event_removable_handle(event1, event2):

# Removable handle removes event from engine.
engine = DummyEngine()
engine = Engine(lambda e, b: None)
handler = create_autospec(spec=lambda x: None)
assert not hasattr(handler, "_parent")

removable_handle = engine.add_event_handler(Events.STARTED, handler)
assert engine.has_event_handler(handler, Events.STARTED)
removable_handle = engine.add_event_handler(event1, handler)
assert engine.has_event_handler(handler, event1)

engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
handler.assert_any_call(engine)
num_calls = handler.call_count

removable_handle.remove()
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, event1)

# Second engine pass does not fire handle again.
engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
# Assert that handler wasn't call
assert handler.call_count == num_calls

# Removable handle can be used as a context manager
handler = create_autospec(spec=lambda x: None)

with engine.add_event_handler(Events.STARTED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
engine.run(1)
with engine.add_event_handler(event1, handler):
assert engine.has_event_handler(handler, event1)
engine.run([1, 2])

assert not engine.has_event_handler(handler, Events.STARTED)
handler.assert_called_once_with(engine)
assert not engine.has_event_handler(handler, event1)
handler.assert_any_call(engine)
num_calls = handler.call_count

engine.run(1)
handler.assert_called_once_with(engine)
engine.run([1, 2])
# Assert that handler wasn't call
assert handler.call_count == num_calls

# Removeable handle only effects a single event registration
handler = MagicMock(spec_set=True)

with engine.add_event_handler(Events.STARTED, handler):
with engine.add_event_handler(Events.COMPLETED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
assert engine.has_event_handler(handler, Events.COMPLETED)
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)
with engine.add_event_handler(event1, handler):
with engine.add_event_handler(event2, handler):
assert engine.has_event_handler(handler, event1)
assert engine.has_event_handler(handler, event2)
assert engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event2)
assert not engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event2)

# Removeable handle is re-enter and re-exitable

handler = MagicMock(spec_set=True)

remove = engine.add_event_handler(Events.STARTED, handler)
remove = engine.add_event_handler(event1, handler)

with remove:
with remove:
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event1)
assert not engine.has_event_handler(handler, event1)

# Removeable handle is a weakref, does not keep engine or event alive
def _add_in_closure():
_engine = DummyEngine()
_engine = Engine(lambda e, b: None)

def _handler(_):
pass

_handle = _engine.add_event_handler(Events.STARTED, _handler)
_handle = _engine.add_event_handler(event1, _handler)
assert _handle.engine() is _engine
assert _handle.handler() is _handler

if event1.filter is None:
assert _handle.handler() is _handler
else:
assert _handle.handler()._parent() is _handler

return _handle

Expand Down