Skip to content

Commit

Permalink
Minor refactor assist satellite (#124912)
Browse files Browse the repository at this point in the history
* Extract entity ID for pipeline to property

* Add super calls

* Mock tts
  • Loading branch information
balloob authored and synesthesiam committed Aug 30, 2024
1 parent 791c4d6 commit eedee04
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 38 deletions.
81 changes: 51 additions & 30 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.core import Context
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
Expand All @@ -36,46 +36,32 @@ class AssistSatelliteEntity(entity.Entity):
entity_description: AssistSatelliteEntityDescription
_attr_should_poll = False
_attr_state: AssistSatelliteState | None = None
_attr_pipeline_entity_id: str | None = None
_attr_vad_sensitivity_entity_id: str | None = None

_conversation_id: str | None = None
_conversation_id_time: float | None = None

_run_has_tts: bool = False

async def _async_accept_pipeline_from_satellite(
@property
def pipeline_entity_id(self) -> str | None:
"""Entity ID of the pipeline to use for the next conversation."""
return self._attr_pipeline_entity_id

@property
def vad_sensitivity_entity_id(self) -> str | None:
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id

async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
pipeline_entity_id: str | None = None,
vad_sensitivity_entity_id: str | None = None,
wake_word_phrase: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
pipeline_id: str | None = None
vad_sensitivity = vad.VadSensitivity.DEFAULT

if pipeline_entity_id:
if (
pipeline_entity_state := self.hass.states.get(pipeline_entity_id)
) is None:
raise ValueError("Pipeline entity not found")

if pipeline_entity_state.state != OPTION_PREFERRED:
# Resolve pipeline by name
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
pipeline_id = pipeline.id
break

if vad_sensitivity_entity_id:
if (
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
) is None:
raise ValueError("VAD sensitivity entity not found")

vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

device_id = self.registry_entry.device_id if self.registry_entry else None

# Refresh context if necessary
Expand Down Expand Up @@ -116,13 +102,13 @@ async def _async_accept_pipeline_from_satellite(
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=pipeline_id,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output="wav",
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
Expand All @@ -132,6 +118,7 @@ async def _async_accept_pipeline_from_satellite(
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""

@callback
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
if event.type is PipelineEventType.WAKE_WORD_START:
Expand All @@ -150,11 +137,45 @@ def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:

self.on_pipeline_event(event)

@callback
def _set_state(self, state: AssistSatelliteState):
"""Set the entity's state."""
self._attr_state = state
self.async_write_ha_state()

@callback
def tts_response_finished(self) -> None:
"""Tell entity that the text-to-speech response has finished playing."""
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

@callback
def _resolve_pipeline(self) -> str | None:
"""Resolve pipeline from select entity to id."""
if not (pipeline_entity_id := self.pipeline_entity_id):
return None

if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
raise RuntimeError("Pipeline entity not found")

if pipeline_entity_state.state != OPTION_PREFERRED:
# Resolve pipeline by name
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
return pipeline.id

return None

@callback
def _resolve_vad_sensitivity(self) -> float:
"""Resolve VAD sensitivity from select entity to enum."""
vad_sensitivity = vad.VadSensitivity.DEFAULT

if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
if (
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
) is None:
raise RuntimeError("VAD sensitivity entity not found")

vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

return vad.VadSensitivity.to_seconds(vad_sensitivity)
20 changes: 13 additions & 7 deletions homeassistant/components/voip/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,24 @@ def __init__(
self._tones = tones
self._processing_tone_done = asyncio.Event()

@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
return self.voip_device.get_pipeline_entity_id(self.hass)

@property
def vad_sensitivity_entity_id(self) -> str | None:
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
return self.voip_device.get_vad_sensitivity_entity_id(self.hass)

async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
self.voip_device.protocol = self

async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
await super().async_will_remove_from_hass()
assert self.voip_device.protocol == self
self.voip_device.protocol = None

Expand Down Expand Up @@ -155,16 +167,10 @@ async def _run_pipeline(
# Run pipeline with a timeout
_LOGGER.debug("Starting pipeline")
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
await self.async_accept_pipeline_from_satellite(
audio_stream=queue_to_iterable(
self._audio_queue, timeout=self._audio_chunk_timeout
),
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
self.hass
),
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
self.hass
),
)

if self._pipeline_had_error:
Expand Down
2 changes: 1 addition & 1 deletion tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def test_entity_state(
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
) as mock_start_pipeline:
await entity._async_accept_pipeline_from_satellite(audio_stream)
await entity.async_accept_pipeline_from_satellite(audio_stream)

assert mock_start_pipeline.called
kwargs = mock_start_pipeline.call_args[1]
Expand Down
3 changes: 3 additions & 0 deletions tests/components/voip/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from homeassistant.setup import async_setup_component

from tests.common import MockConfigEntry
from tests.components.tts.conftest import (
mock_tts_cache_dir_fixture_autouse, # noqa: F401
)


@pytest.fixture(autouse=True)
Expand Down

0 comments on commit eedee04

Please sign in to comment.