Skip to content

Commit

Permalink
Incorporate assist satellite entity feedback (#124727)
Browse files Browse the repository at this point in the history
* Incorporate feedback

* Raise value error

* Clean up entity description

* More cleanup

* Move some things around

* Add a basic test

* Whatever

* Update CODEOWNERS

* Add tests

* Test tts response finished

* Fix test

* Wrong place

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
  • Loading branch information
synesthesiam and balloob committed Aug 30, 2024
1 parent 9b086bc commit 791c4d6
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 45 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ build.json @home-assistant/supervisor
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/tests/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN
from .entity import AssistSatelliteEntity
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteState

__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
]

_LOGGER = logging.getLogger(__name__)
Expand Down
61 changes: 35 additions & 26 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Assist satellite entity."""

from abc import abstractmethod
from collections.abc import AsyncIterable
import time
from typing import Final
Expand All @@ -15,7 +16,6 @@
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.const import EntityCategory
from homeassistant.core import Context
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
Expand All @@ -26,18 +26,16 @@
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes


class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
"""A class that describes binary sensor entities."""


class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""

entity_description = EntityDescription(
key="assist_satellite",
translation_key="assist_satellite",
entity_category=EntityCategory.CONFIG,
)
_attr_has_entity_name = True
_attr_name = None
entity_description: AssistSatelliteEntityDescription
_attr_should_poll = False
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
_attr_state: AssistSatelliteState | None = None

_conversation_id: str | None = None
_conversation_id_time: float | None = None
Expand All @@ -58,24 +56,27 @@ async def _async_accept_pipeline_from_satellite(
vad_sensitivity = vad.VadSensitivity.DEFAULT

if pipeline_entity_id:
# Resolve pipeline by name
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
if (pipeline_entity_state is not None) and (
pipeline_entity_state.state != OPTION_PREFERRED
):
if (
pipeline_entity_state := self.hass.states.get(pipeline_entity_id)
) is None:
raise ValueError("Pipeline entity not found")

if pipeline_entity_state.state != OPTION_PREFERRED:
# Resolve pipeline by name
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
pipeline_id = pipeline.id
break

if vad_sensitivity_entity_id:
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
if vad_sensitivity_state is not None:
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
if (
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
) is None:
raise ValueError("VAD sensitivity entity not found")

device_id: str | None = None
if self.registry_entry is not None:
device_id = self.registry_entry.device_id
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)

device_id = self.registry_entry.device_id if self.registry_entry else None

# Refresh context if necessary
if (
Expand Down Expand Up @@ -105,7 +106,7 @@ async def _async_accept_pipeline_from_satellite(
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self.on_pipeline_event,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
Expand All @@ -123,24 +124,32 @@ async def _async_accept_pipeline_from_satellite(
audio_settings=AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
),
start_stage=start_stage,
end_stage=end_stage,
)

@abstractmethod
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""

def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
if event.type == PipelineEventType.WAKE_WORD_START:
if event.type is PipelineEventType.WAKE_WORD_START:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
elif event.type == PipelineEventType.STT_START:
elif event.type is PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
elif event.type == PipelineEventType.INTENT_START:
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type == PipelineEventType.TTS_START:
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True
self._set_state(AssistSatelliteState.RESPONDING)
elif event.type == PipelineEventType.RUN_END:
elif event.type is PipelineEventType.RUN_END:
if not self._run_has_tts:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

self.on_pipeline_event(event)

def _set_state(self, state: AssistSatelliteState):
"""Set the entity's state."""
self._attr_state = state
Expand Down
17 changes: 8 additions & 9 deletions homeassistant/components/assist_satellite/strings.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
{
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
"entity_component": {
"_": {
"name": "Assist satellite",
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
}
Expand Down
18 changes: 13 additions & 5 deletions homeassistant/components/voip/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntity
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteState,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
Expand Down Expand Up @@ -78,6 +82,12 @@ def async_add_device(device: VoIPDevice) -> None:
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
"""Assist satellite for VoIP devices."""

entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_has_entity_name = True
_attr_name = None
_attr_state = AssistSatelliteState.LISTENING_WAKE_WORD

def __init__(
self,
hass: HomeAssistant,
Expand Down Expand Up @@ -108,8 +118,8 @@ async def async_added_to_hass(self) -> None:

async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
if self.voip_device.protocol == self:
self.voip_device.protocol = None
assert self.voip_device.protocol == self
self.voip_device.protocol = None

# -------------------------------------------------------------------------
# VoIP
Expand Down Expand Up @@ -188,8 +198,6 @@ def _clear_audio_queue(self) -> None:

def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
super().on_pipeline_event(event)

if event.type == PipelineEventType.STT_END:
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
self._processing_tone_done.clear()
Expand Down
8 changes: 4 additions & 4 deletions homeassistant/components/voip/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
"listening_wake_word": "[%key:component::assist_satellite::entity_component::_::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity_component::_::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity_component::_::state::responding%]",
"processing": "[%key:component::assist_satellite::entity_component::_::state::processing%]"
}
}
},
Expand Down
1 change: 1 addition & 0 deletions tests/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for Assist Satellite."""
104 changes: 104 additions & 0 deletions tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Test helpers for Assist Satellite."""

from unittest.mock import Mock

import pytest

from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component

from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)

TEST_DOMAIN = "test_satellite"


class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""

_attr_name = "Test Entity"

def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []

def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)


@pytest.fixture
def entity() -> MockAssistSatellite:
"""Mock Assist Satellite Entity."""
return MockAssistSatellite()


@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Mock config entry."""
entry = MockConfigEntry(domain=TEST_DOMAIN)
entry.add_to_hass(hass)
return entry


@pytest.fixture
async def init_components(
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Initialize components."""
assert await async_setup_component(hass, "homeassistant", {})

async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
return True

async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
return True

mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)

mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())

async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
async_add_entities([entity])

loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{AS_DOMAIN}", loaded_platform)

with mock_config_flow(TEST_DOMAIN, ConfigFlow):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()

return config_entry
Loading

0 comments on commit 791c4d6

Please sign in to comment.