Skip to content

Commit

Permalink
Add announce service to assist_satellite
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Aug 30, 2024
1 parent eedee04 commit b806b26
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 5 deletions.
22 changes: 20 additions & 2 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import logging

import voluptuous as vol

from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
Expand All @@ -10,13 +12,14 @@

from .const import DOMAIN
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteState
from .models import AssistSatelliteEntityFeature, AssistSatelliteState

__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature",
"AssistSatelliteState",
]

_LOGGER = logging.getLogger(__name__)
Expand All @@ -30,6 +33,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await component.async_setup(config)

component.async_register_entity_service(
"announce",
vol.All(
vol.Schema(
{
vol.Optional("text"): str,
vol.Optional("media"): str,
}
),
cv.has_at_least_one_key("text", "media"),
),
"async_internal_annonuce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)

return True


Expand Down
68 changes: 67 additions & 1 deletion homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,28 @@
import time
from typing import Final

from homeassistant.components import stt
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid

from .errors import SatelliteBusyError
from .models import AssistSatelliteState

_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
Expand All @@ -43,6 +49,7 @@ class AssistSatelliteEntity(entity.Entity):
_conversation_id_time: float | None = None

_run_has_tts: bool = False
_is_announcing = False

@property
def pipeline_entity_id(self) -> str | None:
Expand All @@ -54,6 +61,65 @@ def vad_sensitivity_entity_id(self) -> str | None:
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id

async def async_internal_announce(
self,
text: str | None = None,
media_id: str | None = None,
) -> None:
"""Play an announcement on the satellite.
If media_id is not provided, text is synthesized to
audio with the selected pipeline.
Calls _internal_async_announce with media id and expects it to block
until the announcement is completed.
"""
if text is None:
text = ""

if not media_id:
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline(pipeline_entity_id)

Check failure on line 82 in homeassistant/components/assist_satellite/entity.py

View workflow job for this annotation

GitHub Actions / Check mypy

Too many arguments for "_resolve_pipeline" of "AssistSatelliteEntity" [call-arg]

Check failure on line 82 in homeassistant/components/assist_satellite/entity.py

View workflow job for this annotation

GitHub Actions / Check mypy

Name "pipeline_entity_id" is not defined [name-defined]

Check failure on line 82 in homeassistant/components/assist_satellite/entity.py

View workflow job for this annotation

GitHub Actions / Check ruff

Ruff (F821)

homeassistant/components/assist_satellite/entity.py:82:50: F821 Undefined name `pipeline_entity_id`
pipeline = async_get_pipeline(self.hass, pipeline_id)

tts_options: dict[str, Any] = {}

Check failure on line 85 in homeassistant/components/assist_satellite/entity.py

View workflow job for this annotation

GitHub Actions / Check mypy

Name "Any" is not defined [name-defined]

Check failure on line 85 in homeassistant/components/assist_satellite/entity.py

View workflow job for this annotation

GitHub Actions / Check ruff

Ruff (F821)

homeassistant/components/assist_satellite/entity.py:85:36: F821 Undefined name `Any`
if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

media_id = tts_generate_media_source_id(
self.hass,
text,
engine=pipeline.tts_engine,
language=pipeline.tts_language,
options=tts_options,
)

if media_source.is_media_source_id(media_id):
media = await media_source.async_resolve_media(
self.hass,
media_id,
None,
)
media_id = media.url

# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)

if self._is_announcing:
raise SatelliteBusyError

self._is_announcing = True

try:
# Block until announcement is finished
await self.async_announce(text, media_id)
finally:
self._is_announcing = False

async def async_announce(self, text: str, media_id: str) -> None:
"""Announce media on the satellite."""
raise NotImplementedError

async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
Expand Down
11 changes: 11 additions & 0 deletions homeassistant/components/assist_satellite/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Errors for assist satellite."""

from homeassistant.exceptions import HomeAssistantError


class AssistSatelliteError(HomeAssistantError):
"""Base class for assist satellite errors."""


class SatelliteBusyError(AssistSatelliteError):
"""Satellite is busy and cannot handle the request."""
5 changes: 5 additions & 0 deletions homeassistant/components/assist_satellite/icons.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"services": {
"announce": "mdi:bullhorn"
}
}
2 changes: 1 addition & 1 deletion homeassistant/components/assist_satellite/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}
9 changes: 8 additions & 1 deletion homeassistant/components/assist_satellite/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for assist satellite."""

from enum import StrEnum
from enum import IntFlag, StrEnum


class AssistSatelliteState(StrEnum):
Expand All @@ -17,3 +17,10 @@ class AssistSatelliteState(StrEnum):

RESPONDING = "responding"
"""Device is speaking the response."""


class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""

ANNOUNCE = 1
"""Device supports remotely triggered announcements."""
16 changes: 16 additions & 0 deletions homeassistant/components/assist_satellite/services.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
play_media:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
fields:
text:
required: false
example: "Time to wake up!"
selector:
text:
media_id:
required: false
selector:
text:
7 changes: 7 additions & 0 deletions tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
Expand All @@ -30,15 +31,21 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""

_attr_name = "Test Entity"
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE

def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []

def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)

async def async_announce(self, text: str, media_id: str) -> None:
"""Announce media on a device."""
self.announcements.append((text, media_id))


@pytest.fixture
def entity() -> MockAssistSatellite:
Expand Down
37 changes: 37 additions & 0 deletions tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from unittest.mock import patch

import pytest

from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
AudioSettings,
Expand Down Expand Up @@ -84,3 +86,38 @@ async def test_entity_state(
entity.tts_response_finished()
state = hass.states.get(ENTITY_ID)
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD


@pytest.mark.parametrize(
["service_data", "expected_params"],

Check failure on line 92 in tests/components/assist_satellite/test_entity.py

View workflow job for this annotation

GitHub Actions / Check ruff

Ruff (PT006)

tests/components/assist_satellite/test_entity.py:92:5: PT006 Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple`
[
(
{"text": "Hello"},
("Hello", "media-source://bla"),
),
(
{
"text": "Hello",
"media_id": "http://example.com/bla.mp3",
},
("Hello", "http://example.com/bla.mp3"),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
),
],
)
async def test_announce(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
service_data: dict,
expected_params: tuple[str, str],
) -> None:
"""Test announcing on a device."""
await hass.services.async_call(
"assist_satellite", "announce", service_data, blocking=True
)

assert entity.announcements[0] == expected_params

0 comments on commit b806b26

Please sign in to comment.