Skip to content

Commit

Permalink
Interpret VoiceAssistantCommand(start=False) as abort (#957)
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam authored Sep 12, 2024
1 parent a2a0bbf commit 0765a87
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
8 changes: 4 additions & 4 deletions aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ def subscribe_voice_assistant(
[str, int, VoiceAssistantAudioSettingsModel, str | None],
Coroutine[Any, Any, int | None],
],
handle_stop: Callable[[], Coroutine[Any, Any, None]],
handle_stop: Callable[[bool], Coroutine[Any, Any, None]],
handle_audio: (
Callable[
[bytes],
Expand All @@ -1302,7 +1302,7 @@ def subscribe_voice_assistant(
handle_start: called when the devices requests a server to send audio data to.
This callback is asynchronous and returns the port number the server is started on.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed.
handle_stop: called when the device has stopped sending audio data and the pipeline should be closed or aborted.
handle_audio: called when a chunk of audio is sent from the device.
Expand Down Expand Up @@ -1343,7 +1343,7 @@ def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None:
# We hold a reference to the start_task in unsub function
# so we don't need to add it to the background tasks.
else:
self._create_background_task(handle_stop())
self._create_background_task(handle_stop(True))

remove_callbacks = []
flags = 0
Expand All @@ -1353,7 +1353,7 @@ def _on_voice_assistant_request(msg: VoiceAssistantRequest) -> None:
def _on_voice_assistant_audio(msg: VoiceAssistantAudio) -> None:
audio = VoiceAssistantAudioData.from_pb(msg)
if audio.end:
self._create_background_task(handle_stop())
self._create_background_task(handle_stop(False))
else:
self._create_background_task(handle_audio(audio.data))

Expand Down
51 changes: 37 additions & 14 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,6 +2224,7 @@ async def test_subscribe_voice_assistant(
send = patch_send(client)
starts = []
stops = []
aborts = []

async def handle_start(
conversation_id: str,
Expand All @@ -2234,8 +2235,11 @@ async def handle_start(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 42

async def handle_stop() -> None:
stops.append(True)
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)

unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
Expand Down Expand Up @@ -2269,7 +2273,8 @@ async def handle_stop() -> None:
"okay nabu",
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=42))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
Expand All @@ -2278,7 +2283,8 @@ async def handle_stop() -> None:
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
assert not stops
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
Expand All @@ -2301,6 +2307,7 @@ async def test_subscribe_voice_assistant_failure(
send = patch_send(client)
starts = []
stops = []
aborts = []

async def handle_start(
conversation_id: str,
Expand All @@ -2312,8 +2319,11 @@ async def handle_start(
# Return None to indicate failure
return None

async def handle_stop() -> None:
stops.append(True)
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)

unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
Expand Down Expand Up @@ -2346,7 +2356,8 @@ async def handle_stop() -> None:
None,
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(error=True))
send.reset_mock()
response: message.Message = VoiceAssistantRequest(
Expand All @@ -2355,7 +2366,8 @@ async def handle_stop() -> None:
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True]
assert not stops
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
Expand All @@ -2378,6 +2390,7 @@ async def test_subscribe_voice_assistant_cancels_long_running_handle_start(
send = patch_send(client)
starts = []
stops = []
aborts = []

async def handle_start(
conversation_id: str,
Expand All @@ -2391,8 +2404,11 @@ async def handle_start(
starts.append("never")
return None

async def handle_stop() -> None:
stops.append(True)
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)

unsub = client.subscribe_voice_assistant(
handle_start=handle_start, handle_stop=handle_stop
Expand All @@ -2416,6 +2432,7 @@ async def handle_stop() -> None:
unsub()
await asyncio.sleep(0)
assert not stops
assert not aborts
assert starts == [
(
"theone",
Expand All @@ -2441,6 +2458,7 @@ async def test_subscribe_voice_assistant_api_audio(
send = patch_send(client)
starts = []
stops = []
aborts = []
data_received = 0

async def handle_start(
Expand All @@ -2452,8 +2470,11 @@ async def handle_start(
starts.append((conversation_id, flags, audio_settings, wake_word_phrase))
return 0

async def handle_stop() -> None:
stops.append(True)
async def handle_stop(abort: bool) -> None:
if abort:
aborts.append(True)
else:
stops.append(True)

async def handle_audio(data: bytes) -> None:
nonlocal data_received
Expand Down Expand Up @@ -2493,7 +2514,8 @@ async def handle_audio(data: bytes) -> None:
"okay nabu",
)
]
assert stops == []
assert not stops
assert not aborts
send.assert_called_once_with(VoiceAssistantResponse(port=0))
send.reset_mock()

Expand Down Expand Up @@ -2523,7 +2545,8 @@ async def handle_audio(data: bytes) -> None:
)
mock_data_received(protocol, generate_plaintext_packet(response))
await asyncio.sleep(0)
assert stops == [True, True]
assert stops == [True]
assert aborts == [True]
send.reset_mock()
unsub()
send.assert_called_once_with(SubscribeVoiceAssistantRequest(subscribe=False))
Expand Down

0 comments on commit 0765a87

Please sign in to comment.