Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add announce service to Assist Satellite #124927

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import logging

import voluptuous as vol

from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
Expand All @@ -10,13 +12,14 @@

from .const import DOMAIN
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .models import AssistSatelliteState
from .models import AssistSatelliteEntityFeature, AssistSatelliteState

__all__ = [
"DOMAIN",
"AssistSatelliteState",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature",
"AssistSatelliteState",
]

_LOGGER = logging.getLogger(__name__)
Expand All @@ -30,6 +33,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await component.async_setup(config)

component.async_register_entity_service(
"announce",
vol.All(
cv.make_entity_service_schema(
{
vol.Optional("message"): str,
vol.Optional("media_id"): str,
}
),
cv.has_at_least_one_key("message", "media_id"),
),
"async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)

return True


Expand Down
72 changes: 70 additions & 2 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@
from abc import abstractmethod
from collections.abc import AsyncIterable
import time
from typing import Final
from typing import Any, Final

from homeassistant.components import stt
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid

from .errors import SatelliteBusyError
from .models import AssistSatelliteState

_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
Expand All @@ -43,6 +49,7 @@
_conversation_id_time: float | None = None

_run_has_tts: bool = False
_is_announcing = False

@property
def pipeline_entity_id(self) -> str | None:
Expand All @@ -54,6 +61,67 @@
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id

async def async_internal_announce(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check the supported features here and raise an exception?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. That is done by the service call handler. The service is registered in __init__.py with the correct feature flag.

self,
message: str | None = None,
media_id: str | None = None,
) -> None:
"""Play an announcement on the satellite.

If media_id is not provided, message is synthesized to
audio with the selected pipeline.

Calls async_announce with media id.
"""
if message is None:
message = ""

if not media_id:
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)

tts_options: dict[str, Any] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note for future us: we will need to put the correct format here once we have a supported media format coming from the satellite. Otherwise it will always be MP3.

if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

Check warning on line 86 in homeassistant/components/assist_satellite/entity.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/assist_satellite/entity.py#L86

Added line #L86 was not covered by tests

media_id = tts_generate_media_source_id(
self.hass,
message,
engine=pipeline.tts_engine,
language=pipeline.tts_language,
options=tts_options,
)

if media_source.is_media_source_id(media_id):
media = await media_source.async_resolve_media(
self.hass,
media_id,
None,
)
media_id = media.url

# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)

if self._is_announcing:
raise SatelliteBusyError

Check warning on line 108 in homeassistant/components/assist_satellite/entity.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/assist_satellite/entity.py#L108

Added line #L108 was not covered by tests

self._is_announcing = True

try:
# Block until announcement is finished
await self.async_announce(message, media_id)
finally:
self._is_announcing = False

async def async_announce(self, message: str, media_id: str) -> None:
"""Announce media on the satellite.

Should block until the announcement is done playing.
"""
raise NotImplementedError

async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
Expand Down
11 changes: 11 additions & 0 deletions homeassistant/components/assist_satellite/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Errors for assist satellite."""

from homeassistant.exceptions import HomeAssistantError


class AssistSatelliteError(HomeAssistantError):
"""Base class for assist satellite errors."""


class SatelliteBusyError(AssistSatelliteError):
"""Satellite is busy and cannot handle the request."""
12 changes: 12 additions & 0 deletions homeassistant/components/assist_satellite/icons.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"entity_component": {
"_": {
"default": "mdi:account-voice"
}
},
"services": {
"announce": {
"service": "mdi:bullhorn"
}
}
}
2 changes: 1 addition & 1 deletion homeassistant/components/assist_satellite/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}
9 changes: 8 additions & 1 deletion homeassistant/components/assist_satellite/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Models for assist satellite."""

from enum import StrEnum
from enum import IntFlag, StrEnum


class AssistSatelliteState(StrEnum):
Expand All @@ -17,3 +17,10 @@ class AssistSatelliteState(StrEnum):

RESPONDING = "responding"
"""Device is speaking the response."""


class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""

ANNOUNCE = 1
"""Device supports remotely triggered announcements."""
16 changes: 16 additions & 0 deletions homeassistant/components/assist_satellite/services.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
announce:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
fields:
message:
required: false
example: "Time to wake up!"
selector:
text:
media_id:
required: false
selector:
text:
17 changes: 17 additions & 0 deletions homeassistant/components/assist_satellite/strings.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"title": "Assist satellite",
"entity_component": {
"_": {
"name": "Assist satellite",
Expand All @@ -9,5 +10,21 @@
"processing": "Processing"
}
}
},
"services": {
"announce": {
"name": "Announce",
"description": "Let the satellite announce a message.",
"fields": {
"message": {
"name": "Message",
"description": "The message to announce."
},
"media_id": {
"name": "Media ID",
"description": "The media ID to announce instead of using text-to-speech."
}
}
}
}
}
13 changes: 11 additions & 2 deletions tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
Expand All @@ -30,15 +31,21 @@
"""Mock Assist Satellite Entity."""

_attr_name = "Test Entity"
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE

def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []

def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)

async def async_announce(self, text: str, media_id: str) -> None:

Check warning on line 45 in tests/components/assist_satellite/conftest.py

View workflow job for this annotation

GitHub Actions / Check pylint on tests

W0237: Parameter 'message' has been renamed to 'text' in overriding 'MockAssistSatellite.async_announce' method (arguments-renamed)
"""Announce media on a device."""
self.announcements.append((text, media_id))


@pytest.fixture
def entity() -> MockAssistSatellite:
Expand All @@ -56,7 +63,9 @@

@pytest.fixture
async def init_components(
hass: HomeAssistant, config_entry: ConfigEntry, entity: MockAssistSatellite
hass: HomeAssistant,
config_entry: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Initialize components."""
assert await async_setup_component(hass, "homeassistant", {})
Expand Down Expand Up @@ -91,7 +100,7 @@
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
"""Set up test satellite platform via config entry."""
async_add_entities([entity])

loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
Expand Down
64 changes: 64 additions & 0 deletions tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

from unittest.mock import patch

import pytest

from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import Context, HomeAssistant
Expand Down Expand Up @@ -84,3 +89,62 @@ async def test_entity_state(
entity.tts_response_finished()
state = hass.states.get(ENTITY_ID)
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD


@pytest.mark.parametrize(
("service_data", "expected_params"),
[
(
{"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"),
),
(
{
"message": "Hello",
"media_id": "http://example.com/bla.mp3",
},
("Hello", "http://example.com/bla.mp3"),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
),
],
)
async def test_announce(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
service_data: dict,
expected_params: tuple[str, str],
) -> None:
"""Test announcing on a device."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
tts_engine="tts.mock_entity",
tts_language="en",
)

with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
),
):
await hass.services.async_call(
"assist_satellite",
"announce",
service_data,
target={"entity_id": "assist_satellite.test_entity"},
blocking=True,
)

assert entity.announcements[0] == expected_params
Loading