From 49ae6d67e82f58333b7355f1ebb5560408a62c45 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 12 Sep 2024 16:28:00 -0500 Subject: [PATCH] Interpret VoiceAssistantCommand(start=False) as abort --- aioesphomeapi/client.py | 8 +++---- tests/test_client.py | 51 ++++++++++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 08161200..1a895b21 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -1281,7 +1281,7 @@ def subscribe_voice_assistant( [str, int, VoiceAssistantAudioSettingsModel, str | None], Coroutine[Any, Any, int | None], ], - handle_stop: Callable[[], Coroutine[Any, Any, None]], + handle_stop: Callable[[bool], Coroutine[Any, Any, None]], handle_audio: ( Callable[ [bytes], @@ -1302,7 +1302,7 @@ def subscribe_voice_assistant( handle_start: called when the devices requests a server to send audio data to. This callback is asynchronous and returns the port number the server is started on. - handle_stop: called when the device has stopped sending audio data and the pipeline should be closed. + handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted. handle_audio: called when a chunk of audio is sent from the device. @@ -1343,7 +1343,7 @@ def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None: # We hold a reference to the start_task in unsub function # so we don't need to add it to the background tasks. else: - self._create_background_task(handle_stop()) + self._create_background_task(handle_stop(True)) remove_callbacks = [] flags = 0 @@ -1353,7 +1353,7 @@ def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None: def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None: audio = VoiceAssistantAudioData.from_pb(msg) if audio.end: - self._create_background_task(handle_stop()) + self._create_background_task(handle_stop(False)) else: self._create_background_task(handle_audio(audio.data)) diff --git a/tests/test_client.py b/tests/test_client.py index 2cb8bb47..0f7d405f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant( send = patch_send(client) starts = [] stops = [] + aborts = [] async def handle_start( conversation_id: str, @@ -2234,8 +2235,11 @@ async def handle_start( starts.append((conversation_id, flags, audio_settings, wake_word_phrase)) return 42 - async def handle_stop() -> None: - stops.append(True) + async def handle_stop(abort: bool) -> None: + if abort: + aborts.append(True) + else: + stops.append(True) unsub = client.subscribe_voice_assistant( handle_start=handle_start, handle_stop=handle_stop @@ -2269,7 +2273,8 @@ async def handle_stop() -> None: "okay nabu", ) ] - assert stops == [] + assert not stops + assert not aborts send.assert_called_once_with(VoiceAssistantResponse(port=42)) send.reset_mock() response: message.Message = VoiceAssistantRequest( @@ -2278,7 +2283,8 @@ async def handle_stop() -> None: ) mock_data_received(protocol, generate_plaintext_packet(response)) await asyncio.sleep(0) - assert stops == [True] + assert not stops + assert aborts == [True] send.reset_mock() unsub() send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) @@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure( send = patch_send(client) starts = [] stops = [] + aborts = [] async def handle_start( conversation_id: str, @@ -2312,8 +2319,11 @@ async def handle_start( # Return None to indicate failure return None - async def handle_stop() -> None: - stops.append(True) + async def handle_stop(abort: bool) -> None: + if abort: + aborts.append(True) + else: + stops.append(True) unsub = client.subscribe_voice_assistant( handle_start=handle_start, handle_stop=handle_stop @@ -2346,7 +2356,8 @@ async def handle_stop() -> None: None, ) ] - assert stops == [] + assert not stops + assert not aborts send.assert_called_once_with(VoiceAssistantResponse(error=True)) send.reset_mock() response: message.Message = VoiceAssistantRequest( @@ -2355,7 +2366,8 @@ async def handle_stop() -> None: ) mock_data_received(protocol, generate_plaintext_packet(response)) await asyncio.sleep(0) - assert stops == [True] + assert not stops + assert aborts == [True] send.reset_mock() unsub() send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False)) @@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start( send = patch_send(client) starts = [] stops = [] + aborts = [] async def handle_start( conversation_id: str, @@ -2391,8 +2404,11 @@ async def handle_start( starts.append("never") return None - async def handle_stop() -> None: - stops.append(True) + async def handle_stop(abort: bool) -> None: + if abort: + aborts.append(True) + else: + stops.append(True) unsub = client.subscribe_voice_assistant( handle_start=handle_start, handle_stop=handle_stop @@ -2416,6 +2432,7 @@ async def handle_stop() -> None: unsub() await asyncio.sleep(0) assert not stops + assert not aborts assert starts == [ ( "theone", @@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio( send = patch_send(client) starts = [] stops = [] + aborts = [] data_received = 0 async def handle_start( @@ -2452,8 +2470,11 @@ async def handle_start( starts.append((conversation_id, flags, audio_settings, wake_word_phrase)) return 0 - async def handle_stop() -> None: - stops.append(True) + async def handle_stop(abort: bool) -> None: + if abort: + aborts.append(True) + else: + stops.append(True) async def handle_audio(data: bytes) -> None: nonlocal data_received @@ -2493,7 +2514,8 @@ async def handle_audio(data: bytes) -> None: "okay nabu", ) ] - assert stops == [] + assert not stops + assert not aborts send.assert_called_once_with(VoiceAssistantResponse(port=0)) send.reset_mock() @@ -2523,7 +2545,8 @@ async def handle_audio(data: bytes) -> None: ) mock_data_received(protocol, generate_plaintext_packet(response)) await asyncio.sleep(0) - assert stops == [True, True] + assert stops == [True] + assert aborts == [True] send.reset_mock() unsub() send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))