From 7027359088f33e1cde6ccb507b84c074b56284ef Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 2 Sep 2022 23:13:40 +0200 Subject: [PATCH] Fixed issue when removing handlers on filtered events Fixes #2684 --- ignite/engine/engine.py | 2 +- ignite/engine/events.py | 8 +++ tests/ignite/engine/test_custom_events.py | 35 +++------- tests/ignite/engine/test_event_handlers.py | 75 +++++++++++++--------- 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index dc4ab550fc4..ce70d6f168c 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -373,7 +373,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any) -> None: self._event_handlers[event_name] = new_event_handlers def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable: - """Decorator shortcut for add_event_handler. + """Decorator shortcut for :meth:`~ignite.engine.engine.Engine.add_event_handler`. Args: event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.events.Events` diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 0229bdba1d9..62f808f5e87 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -474,6 +474,14 @@ def remove(self) -> None: if handler is None or engine is None: return + if hasattr(handler, "_parent"): + handler = handler._parent() # type: ignore[attr-defined] + if handler is None: + raise RuntimeError( + "Internal error! Please fill an issue on https://github.com/pytorch/ignite/issues " + "if encounter this error. Thank you!" + ) + if isinstance(self.event_name, EventsList): for e in self.event_name: if engine.has_event_handler(handler, e): diff --git a/tests/ignite/engine/test_custom_events.py b/tests/ignite/engine/test_custom_events.py index a39978634e6..a0059722da7 100644 --- a/tests/ignite/engine/test_custom_events.py +++ b/tests/ignite/engine/test_custom_events.py @@ -218,8 +218,7 @@ def bar(e): engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar) assert engine.has_event_handler(bar) assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED) - - engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3)) + assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3)) def test_remove_event_handler_on_callable_events(): @@ -258,15 +257,9 @@ def _test(event_name, event_attr, every, true_num_calls): engine = Engine(lambda e, b: b) - counter = [ - 0, - ] - counter_every = [ - 0, - ] - num_calls = [ - 0, - ] + counter = [0] + counter_every = [0] + num_calls = [0] @engine.on(event_name(every=every)) def assert_every(engine): @@ -309,12 +302,8 @@ def _test(event_name, event_attr): engine = Engine(lambda e, b: b) once = 2 - counter = [ - 0, - ] - num_calls = [ - 0, - ] + counter = [0] + num_calls = [0] @engine.on(event_name(once=once)) def assert_once(engine): @@ -350,9 +339,7 @@ def _test(event_name, event_attr, true_num_calls): engine = Engine(lambda e, b: b) - num_calls = [ - 0, - ] + num_calls = [0] @engine.on(event_name(event_filter=custom_event_filter)) def assert_on_special_event(engine): @@ -381,9 +368,7 @@ def custom_event_filter(engine, event): # Check bad behaviour engine = Engine(lambda e, b: b) - counter = [ - 0, - ] + counter = [0] # Modify events Events.ITERATION_STARTED(event_filter=custom_event_filter) @@ -433,9 +418,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) engine.register_events(*CustomEvents, event_to_attr=event_to_attr) - num_calls = [ - 0, - ] + num_calls = [0] @engine.on(event_name(event_filter=custom_event_filter)) def assert_on_special_event(engine): diff --git a/tests/ignite/engine/test_event_handlers.py b/tests/ignite/engine/test_event_handlers.py index 0c45a958608..053091a8168 100644 --- a/tests/ignite/engine/test_event_handlers.py +++ b/tests/ignite/engine/test_event_handlers.py @@ -129,73 +129,86 @@ def test_adding_multiple_event_handlers(): handler.assert_called_once_with(engine) -def test_event_removable_handle(): +@pytest.mark.parametrize( + "event1, event2", + [ + (Events.STARTED, Events.COMPLETED), + (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED), + (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED), + (Events.ITERATION_STARTED(every=2), Events.ITERATION_COMPLETED(every=2)), + ], +) +def test_event_removable_handle(event1, event2): # Removable handle removes event from engine. - engine = DummyEngine() + engine = Engine(lambda e, b: None) handler = create_autospec(spec=lambda x: None) assert not hasattr(handler, "_parent") - removable_handle = engine.add_event_handler(Events.STARTED, handler) - assert engine.has_event_handler(handler, Events.STARTED) + removable_handle = engine.add_event_handler(event1, handler) + assert engine.has_event_handler(handler, event1) - engine.run(1) - handler.assert_called_once_with(engine) + engine.run([1, 2]) + handler.assert_any_call(engine) removable_handle.remove() - assert not engine.has_event_handler(handler, Events.STARTED) + assert not engine.has_event_handler(handler, event1) # Second engine pass does not fire handle again. - engine.run(1) - handler.assert_called_once_with(engine) + engine.run([1, 2]) + handler.assert_any_call(engine) # Removable handle can be used as a context manager handler = create_autospec(spec=lambda x: None) - with engine.add_event_handler(Events.STARTED, handler): - assert engine.has_event_handler(handler, Events.STARTED) - engine.run(1) + with engine.add_event_handler(event1, handler): + assert engine.has_event_handler(handler, event1) + engine.run([1, 2]) - assert not engine.has_event_handler(handler, Events.STARTED) - handler.assert_called_once_with(engine) + assert not engine.has_event_handler(handler, event1) + handler.assert_any_call(engine) - engine.run(1) - handler.assert_called_once_with(engine) + engine.run([1, 2]) + handler.assert_any_call(engine) # Removeable handle only effects a single event registration handler = MagicMock(spec_set=True) - with engine.add_event_handler(Events.STARTED, handler): - with engine.add_event_handler(Events.COMPLETED, handler): - assert engine.has_event_handler(handler, Events.STARTED) - assert engine.has_event_handler(handler, Events.COMPLETED) - assert engine.has_event_handler(handler, Events.STARTED) - assert not engine.has_event_handler(handler, Events.COMPLETED) - assert not engine.has_event_handler(handler, Events.STARTED) - assert not engine.has_event_handler(handler, Events.COMPLETED) + with engine.add_event_handler(event1, handler): + with engine.add_event_handler(event2, handler): + assert engine.has_event_handler(handler, event1) + assert engine.has_event_handler(handler, event2) + assert engine.has_event_handler(handler, event1) + assert not engine.has_event_handler(handler, event2) + assert not engine.has_event_handler(handler, event1) + assert not engine.has_event_handler(handler, event2) # Removeable handle is re-enter and re-exitable handler = MagicMock(spec_set=True) - remove = engine.add_event_handler(Events.STARTED, handler) + remove = engine.add_event_handler(event1, handler) with remove: with remove: - assert engine.has_event_handler(handler, Events.STARTED) - assert not engine.has_event_handler(handler, Events.STARTED) - assert not engine.has_event_handler(handler, Events.STARTED) + assert engine.has_event_handler(handler, event1) + assert not engine.has_event_handler(handler, event1) + assert not engine.has_event_handler(handler, event1) # Removeable handle is a weakref, does not keep engine or event alive def _add_in_closure(): - _engine = DummyEngine() + _engine = Engine(lambda e, b: None) def _handler(_): pass - _handle = _engine.add_event_handler(Events.STARTED, _handler) + _handle = _engine.add_event_handler(event1, _handler) assert _handle.engine() is _engine - assert _handle.handler() is _handler + + if event1.filter is None: + assert _handle.handler() is _handler + else: + assert _handle.handler()._parent() is _handler return _handle