Skip to content

Commit

Permalink
feat(switch_controller): allow input_boolean and binary_sensor to be …
Browse files Browse the repository at this point in the history
…used in switch controllers
  • Loading branch information
xaviml committed Aug 8, 2020
1 parent ccd400c commit 9807211
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 29 deletions.
25 changes: 19 additions & 6 deletions apps/controllerx/cx_core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ async def handle_action(self, action_key: str) -> None:
ascii_encode=False,
)
await self.call_action(action_key)
else:
self.log(
f"🎮 Button event triggered, but not registered: `{action_key}`",
level="INFO",
ascii_encode=False,
)

async def call_action(self, action_key: str):
delay = self.action_delay[action_key]
Expand Down Expand Up @@ -283,21 +289,28 @@ def get_type_actions_mapping(self) -> TypeActionsMapping:

class TypeController(Controller, abc.ABC):
@abc.abstractmethod
def get_domain(self) -> str:
def get_domain(self) -> List[str]:
raise NotImplementedError

async def check_domain(self, entity: str) -> None:
domain = self.get_domain()
domains = self.get_domain()
if entity.startswith("group."):
entities = await self.get_state(entity, attribute="entity_id")
same_domain = all([elem.startswith(domain + ".") for elem in entities])
same_domain = all(
(
any(elem.startswith(domain + ".") for domain in domains)
for elem in entities
)
)
if not same_domain:
raise ValueError(
f"All entities from '{entity}' must be from {domain} domain (e.g. {domain}.bedroom)"
f"All entities from '{entity}' must be from one "
f"of the following domains {domains} (e.g. {domains[0]}.bedroom)"
)
elif not entity.startswith(domain + "."):
elif not any(entity.startswith(domain + ".") for domain in domains):
raise ValueError(
f"'{entity}' must be from {domain} domain (e.g. {domain}.bedroom)"
f"'{entity}' must be from one of the following domains "
f"{domains} (e.g. {domains[0]}.bedroom)"
)

async def get_entity_state(self, entity: str, attribute: str = None) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions apps/controllerx/cx_core/type/cover_controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, List
from cx_const import Cover, TypeActionsMapping
from cx_core.controller import TypeController, action
from cx_core.feature_support.cover import CoverSupport
Expand Down Expand Up @@ -32,8 +32,8 @@ async def initialize(self) -> None:

await super().initialize()

def get_domain(self) -> str:
return "cover"
def get_domain(self) -> List[str]:
return ["cover"]

def get_type_actions_mapping(self) -> TypeActionsMapping:
return {
Expand Down
8 changes: 4 additions & 4 deletions apps/controllerx/cx_core/type/light_controller.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

from cx_const import Light, TypeActionsMapping
from cx_core.color_helper import get_color_wheel
from cx_core.controller import ReleaseHoldController, TypeController, action
from cx_core.feature_support.light import LightSupport
from cx_core.stepper import Stepper
from cx_core.stepper.circular_stepper import CircularStepper
from cx_core.stepper.minmax_stepper import MinMaxStepper
from cx_core.color_helper import get_color_wheel

DEFAULT_MANUAL_STEPS = 10
DEFAULT_AUTOMATIC_STEPS = 10
Expand Down Expand Up @@ -106,8 +106,8 @@ async def initialize(self) -> None:
)
await super().initialize()

def get_domain(self) -> str:
return "light"
def get_domain(self) -> List[str]:
return ["light"]

def get_type_actions_mapping(self,) -> TypeActionsMapping:
return {
Expand Down
6 changes: 4 additions & 2 deletions apps/controllerx/cx_core/type/media_player_controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from cx_const import MediaPlayer, TypeActionsMapping
from cx_core.controller import ReleaseHoldController, TypeController, action
from cx_core.feature_support.media_player import MediaPlayerSupport
Expand All @@ -22,8 +24,8 @@ async def initialize(self) -> None:
)
await super().initialize()

def get_domain(self) -> str:
return "media_player"
def get_domain(self) -> List[str]:
return ["media_player"]

def get_type_actions_mapping(self) -> TypeActionsMapping:
return {
Expand Down
6 changes: 4 additions & 2 deletions apps/controllerx/cx_core/type/switch_controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from cx_const import Switch, TypeActionsMapping
from cx_core.controller import TypeController, action

Expand All @@ -17,8 +19,8 @@ async def initialize(self) -> None:
await self.check_domain(self.switch)
await super().initialize()

def get_domain(self) -> str:
return "switch"
def get_domain(self) -> List[str]:
return ["switch", "input_boolean", "binary_sensor"]

def get_type_actions_mapping(self) -> TypeActionsMapping:
return {
Expand Down
43 changes: 31 additions & 12 deletions tests/cx_core/type_controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,54 @@ def sut(hass_mock):


@pytest.mark.parametrize(
"entity, domain, entities, error_expected",
"entity, domains, entities, error_expected",
[
("light.kitchen", "light", [], False),
("light1.kitchen", "light", [], True,),
("media_player.kitchen", "light", [], True,),
("media_player.bedroom", "media_player", [], False),
("group.all_lights", "light", ["light.light1", "light.light2"], False),
("group.all_lights", "light", ["light1.light1", "light2.light2"], True),
("group.all", "media_player", ["media_player.test", "light.test"], True),
("light.kitchen", ["light"], [], False),
("light1.kitchen", ["light"], [], True,),
("media_player.kitchen", ["light"], [], True,),
("media_player.bedroom", ["media_player"], [], False),
("group.all_lights", ["light"], ["light.light1", "light.light2"], False),
("group.all_lights", ["light"], ["light1.light1", "light2.light2"], True),
("group.all", ["media_player"], ["media_player.test", "light.test"], True),
(
"group.all",
["switch", "input_boolean"],
["switch.switch1", "input_boolean.input_boolean1"],
False,
),
("switch.switch1", ["switch", "input_boolean"], [], False),
("switch.switch1", ["binary_sensor", "input_boolean"], [], True),
(
"group.all",
["switch", "input_boolean"],
["light.light1", "input_boolean.input_boolean1"],
True,
),
],
)
@pytest.mark.asyncio
async def test_check_domain(
sut, mocker, monkeypatch, entity, domain, entities, error_expected
sut, monkeypatch, entity, domains, entities, error_expected
):
expected_error_message = ""
if error_expected:
if entities == []:
expected_error_message = (
f"'{entity}' must be from {domain} domain (e.g. {domain}.bedroom)"
f"'{entity}' must be from one of the following domains "
f"{domains} (e.g. {domains[0]}.bedroom)"
)

else:
expected_error_message = f"All entities from '{entity}' must be from {domain} domain (e.g. {domain}.bedroom)"
expected_error_message = (
f"All entities from '{entity}' must be from one of the "
f"following domains {domains} (e.g. {domains[0]}.bedroom)"
)

async def fake_get_state(*args, **kwargs):
return entities

monkeypatch.setattr(sut, "get_state", fake_get_state)
monkeypatch.setattr(sut, "get_domain", lambda *args: domain)
monkeypatch.setattr(sut, "get_domain", lambda *args: domains)

if error_expected:
with pytest.raises(ValueError) as e:
Expand Down

0 comments on commit 9807211

Please sign in to comment.