From 35276b61b89f6517109855a99efe73d54524ea35 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 12 Feb 2023 18:25:27 -0600 Subject: [PATCH] Ensure we hold strong references to tasks see https://github.com/python/cpython/issues/88831 --- aioesphomeapi/connection.py | 20 ++++++++++++++++++-- aioesphomeapi/reconnect_logic.py | 12 +++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 6b08bc83..5925a461 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -96,6 +96,7 @@ def __init__( ) -> None: self._params = params self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop + self._on_stop_task: Optional[asyncio.Task[None]] = None self._socket: Optional[socket.socket] = None self._frame_helper: Optional[APIFrameHelper] = None self._api_version: Optional[APIVersion] = None @@ -142,6 +143,10 @@ def _cleanup(self) -> None: self._connect_task.cancel() self._connect_task = None + if self._keep_alive_task is not None: + self._keep_alive_task.cancel() + self._keep_alive_task = None + if self._frame_helper is not None: self._frame_helper.close() self._frame_helper = None @@ -151,8 +156,19 @@ def _cleanup(self) -> None: self._socket = None if self.on_stop and self._connect_complete: + + def _remove_on_stop_task(): + """Remove the stop task from the reconnect loop. + + We need to do this because the asyncio does not hold + a strong reference to the task, so it can be garbage + collected unexpectedly. + """ + self._on_stop_task = None + # Ensure on_stop is called only once - asyncio.create_task(self.on_stop()) + self._on_stop_task = asyncio.create_task(self.on_stop()) + self._on_stop_task.add_done_callback(_remove_on_stop_task) self.on_stop = None # Note: we don't explicitly cancel the ping/read task here @@ -318,7 +334,7 @@ async def _keep_alive_loop() -> None: self._report_fatal_error(err) return - asyncio.create_task(_keep_alive_loop()) + self._keep_alive_task = asyncio.create_task(_keep_alive_loop()) async def connect(self, *, login: bool) -> None: if self._connection_state != ConnectionState.INITIALIZED: diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index f6500da4..4d14588b 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -200,7 +200,17 @@ async def stop(self) -> None: await self._stop_zc_listen() def stop_callback(self) -> None: - asyncio.create_task(self.stop()) + def _remove_stop_task() -> None: + """Remove the stop task from the reconnect loop. + + We need to do this because the asyncio does not hold + a strong reference to the task, so it can be garbage + collected unexpectedly. + """ + self._stop_task = None + + self._stop_task = asyncio.create_task(self.stop()) + self._stop_task.add_done_callback(_remove_stop_task) async def _start_zc_listen(self) -> None: """Listen for mDNS records.