diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 48a4e188..00000000 --- a/.coveragerc +++ /dev/null @@ -1,7 +0,0 @@ -[report] -exclude_lines = - # Have to re-enable the standard pragma - pragma: no cover - - # Don't complain if tests don't hit defensive assertion code: - raise NotImplementedError \ No newline at end of file diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 9c6a8ccc..00000000 --- a/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -ignore = E501, W503 -exclude = controllerx.py diff --git a/.gitignore b/.gitignore index a75742d0..f3f2f168 100644 --- a/.gitignore +++ b/.gitignore @@ -129,7 +129,8 @@ dmypy.json .pyre/ #VSCode -.vscode/ +.vscode/* +!.vscode/settings.json .idea # Ignoring Pipfile.lock since we support python 3.6, 3.7 and 3.8 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a766be32..4cd56955 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,36 +1,33 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.4.0 - hooks: - - id: check-json - - id: pretty-format-json - args: - - --autofix - - --indent - - '4' - repo: local hooks: + - id: isort + name: isort + entry: pipenv run isort + language: python + types: [python] + args: [] + require_serial: false - id: black name: black - entry: pipenv run black apps/controllerx tests - language: system - pass_filenames: false - always_run: true + entry: pipenv run black + language: python + types: [python] - id: flake8 name: flake8 - entry: pipenv run flake8 apps/controllerx tests - language: system - pass_filenames: false - always_run: true + entry: pipenv run flake8 + language: python + types: [python] - id: mypy name: mypy - entry: pipenv run mypy apps/controllerx - language: system + entry: pipenv run mypy apps/controllerx/ tests/ + language: python + types: [python] pass_filenames: false - always_run: true - id: pytest name: pytest entry: pipenv run pytest - language: system + language: python + types: [python] pass_filenames: false always_run: true diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..986f6932 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "files.autoSave": "afterDelay", + "files.autoSaveDelay": 100, + "python.testing.pytestEnabled": true, + "editor.formatOnSave": true, + "python.formatting.provider": "black", + "python.analysis.typeCheckingMode": "basic", + "python.analysis.diagnosticMode": "workspace", + "python.analysis.stubPath": "apps/controllerx", + "python.testing.pytestArgs": [ + "tests" + ], + "python.languageServer": "Pylance", + "python.linting.mypyEnabled": true, + "python.linting.mypyCategorySeverity.note": "Error", + "python.linting.flake8Enabled": true, + "[python]": { + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + }, +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ef6eee7f..0299cb6e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,33 +18,49 @@ New controllers need to be added into the `apps/controllerx/devices/` and you wi Note that this project will only accept the mapping that the original controller would follow with its original hub. +## Imports + +Run the following to fix imports order: + +```shell +pipenv run isort apps/controllerx/ tests/ +``` + +## Format + +Run the following to fix formatting: + +```shell +pipenv run black apps/controllerx/ tests/ +``` + ## Typing Run the following to check consistency in the typings: -``` -pipenv run mypy apps/controllerx +```shell +pipenv run mypy apps/controllerx/ tests/ ``` ## Linting Run the following to check for stylings: -``` -pipenv run flake8 apps/controllerx +```shell +pipenv run flake8 apps/controllerx/ tests/ ``` ## Test Run the following command for the tests: -``` +```shell pipenv run pytest --cov=apps ``` or the following to get a report of the missing lines to be tested: -``` +```shell pytest --cov-report term-missing --cov=apps ``` @@ -52,7 +68,7 @@ pytest --cov-report term-missing --cov=apps Once you have the code ready, pre-commit will run some checks to make sure the code follows the format and the tests did not break. If you want to run the check for all files at any point, run: -``` +```shell pipenv run pre-commit run --all-files ``` @@ -64,7 +80,7 @@ You can use the tool `commitizen` to commit based in a standard. If you are in t [Install Jekyll](https://jekyllrb.com/docs/) and run the documentation locally with: -``` +```shell cd docs bundle install bundle exec jekyll serve @@ -98,13 +114,13 @@ git checkout -b - / Thanks to the Azure Pipelines, we are able to deploy by just creating a new tag on git. So first, we will need to bump version with `commitizen` by running the following line in the `master` branch: -``` +```shell cz bump --no-verify ``` `--prerelease beta` tag can be added to create a pre-release. Note that you can also add `--dry-run` to see which version will bump without commiting anything. Then, we can directly push the tags: -``` +```shell git push origin master --tags ``` diff --git a/Pipfile b/Pipfile index 429e3209..36c7b95a 100644 --- a/Pipfile +++ b/Pipfile @@ -14,6 +14,7 @@ pre-commit = "==2.8.2" commitizen = "==2.8.0" mypy = "==0.790" flake8 = "==3.8.4" +isort = "==5.6.4" controllerx = {path = ".", editable = true} [packages] diff --git a/apps/controllerx/cx_const.py b/apps/controllerx/cx_const.py index 8279fb52..6984f3e4 100644 --- a/apps/controllerx/cx_const.py +++ b/apps/controllerx/cx_const.py @@ -1,6 +1,5 @@ from typing import Any, Awaitable, Callable, Dict, Tuple, Union - ActionFunction = Callable[..., Awaitable[Any]] TypeAction = Union[ActionFunction, Tuple, str] ActionEvent = Union[str, int] diff --git a/apps/controllerx/cx_core/controller.py b/apps/controllerx/cx_core/controller.py index ae32c874..d35d3bcb 100644 --- a/apps/controllerx/cx_core/controller.py +++ b/apps/controllerx/cx_core/controller.py @@ -9,7 +9,6 @@ Awaitable, Callable, Counter, - DefaultDict, Dict, List, Optional, @@ -24,13 +23,13 @@ from appdaemon.plugins.hass.hassapi import Hass # type: ignore from appdaemon.plugins.mqtt.mqttapi import Mqtt # type: ignore from cx_const import ActionEvent, ActionFunction, TypeAction, TypeActionsMapping - from cx_core import integration as integration_module -from cx_core.integration import Integration +from cx_core.integration import EventData, Integration Service = Tuple[str, Dict] Services = List[Service] + DEFAULT_DELAY = 350 # In milliseconds DEFAULT_ACTION_DELTA = 300 # In milliseconds DEFAULT_MULTIPLE_CLICK_DELAY = 500 # In milliseconds @@ -118,15 +117,13 @@ async def initialize(self) -> None: self.multiple_click_delay: int = self.args.get( "multiple_click_delay", DEFAULT_MULTIPLE_CLICK_DELAY ) - self.action_times: DefaultDict[str, float] = defaultdict(lambda: 0.0) - self.multiple_click_action_times: DefaultDict[str, float] = defaultdict( - lambda: 0.0 - ) + self.action_times: Dict[str, float] = defaultdict(lambda: 0.0) + self.multiple_click_action_times: Dict[str, float] = defaultdict(lambda: 0.0) self.click_counter: Counter[ActionEvent] = Counter() - self.action_delay_handles: DefaultDict[ - ActionEvent, Optional[float] - ] = defaultdict(lambda: None) - self.multiple_click_action_delay_tasks: DefaultDict[ + self.action_delay_handles: Dict[ActionEvent, Optional[float]] = defaultdict( + lambda: None + ) + self.multiple_click_action_delay_tasks: Dict[ ActionEvent, Optional[Future] ] = defaultdict(lambda: None) @@ -188,7 +185,7 @@ def get_actions_mapping(self, integration: Integration) -> TypeActionsMapping: raise ValueError(f"This controller does not support {integration.name}.") return actions_mapping - def get_list(self, entities: Union[List[T], str]) -> Union[List[T], List[str]]: + def get_list(self, entities: Union[List[str], str]) -> List[str]: if isinstance(entities, str): return [entities] return entities @@ -387,7 +384,7 @@ def get_zha_actions_mapping(self) -> Optional[TypeActionsMapping]: """ return None - def get_zha_action(self, data: Dict[Any, Any]) -> Optional[str]: + def get_zha_action(self, data: EventData) -> Optional[str]: """ This method can be override for controllers that do not support the standard extraction of the actions on cx_core/integration/zha.py @@ -398,44 +395,6 @@ def get_type_actions_mapping(self) -> TypeActionsMapping: return {} -class TypeController(Controller, abc.ABC): - @abc.abstractmethod - def get_domain(self) -> List[str]: - raise NotImplementedError - - async def check_domain(self, entity: str) -> None: - domains = self.get_domain() - if entity.startswith("group."): - entities = await self.get_state(entity, attribute="entity_id") - same_domain = all( - ( - any(elem.startswith(domain + ".") for domain in domains) - for elem in entities - ) - ) - if not same_domain: - raise ValueError( - f"All entities from '{entity}' must be from one " - f"of the following domains {domains} (e.g. {domains[0]}.bedroom)" - ) - elif not any(entity.startswith(domain + ".") for domain in domains): - raise ValueError( - f"'{entity}' must be from one of the following domains " - f"{domains} (e.g. {domains[0]}.bedroom)" - ) - - async def get_entity_state(self, entity: str, attribute: str = None) -> Any: - if entity.startswith("group."): - entities = await self.get_state(entity, attribute="entity_id") - if len(entities) == 0: - raise ValueError( - f"The group `{entity}` does not have any entities registered." - ) - entity = entities[0] - out = await self.get_state(entity, attribute=attribute) - return out - - class ReleaseHoldController(Controller, abc.ABC): DEFAULT_MAX_LOOPS = 50 diff --git a/apps/controllerx/cx_core/feature_support/__init__.py b/apps/controllerx/cx_core/feature_support/__init__.py index ce62f567..75f024ac 100644 --- a/apps/controllerx/cx_core/feature_support/__init__.py +++ b/apps/controllerx/cx_core/feature_support/__init__.py @@ -1,13 +1,21 @@ -from typing import List, Set, Union +from typing import TYPE_CHECKING, List, Optional, Set, Type, TypeVar -from cx_core.controller import TypeController +if TYPE_CHECKING: + from cx_core.type_controller import TypeController -SupportedFeatureNumber = Union[int, str] Features = List[int] SupportedFeatures = Set[int] +FeatureSupportType = TypeVar("FeatureSupportType", bound="FeatureSupport") class FeatureSupport: + + entity_id: str + controller: "TypeController" + features: Features = [] + update_supported_features: bool + _supported_features: Optional[SupportedFeatures] + @staticmethod def encode(supported_features: SupportedFeatures) -> int: number = 0 @@ -21,21 +29,29 @@ def decode(number: int, features: Features) -> SupportedFeatures: def __init__( self, - entity: str, - controller: TypeController, - features: Features, - update_supported_features: bool, + entity_id: str, + controller: "TypeController", + update_supported_features=False, ) -> None: - self.entity = entity + self.entity_id = entity_id self.controller = controller self._supported_features = None - self.features = features self.update_supported_features = update_supported_features - async def supported_features(self): + @classmethod + def instantiate( + cls: Type[FeatureSupportType], + entity_id: str, + controller: "TypeController", + update_supported_features=False, + ) -> FeatureSupportType: + return cls(entity_id, controller, update_supported_features) + + @property + async def supported_features(self) -> SupportedFeatures: if self._supported_features is None or self.update_supported_features: bitfield: str = await self.controller.get_entity_state( - self.entity, attribute="supported_features" + self.entity_id, attribute="supported_features" ) if bitfield is not None: self._supported_features = FeatureSupport.decode( @@ -43,12 +59,12 @@ async def supported_features(self): ) else: raise ValueError( - f"`supported_features` could not be read from `{self.entity}`. Entity might not be available." + f"`supported_features` could not be read from `{self.entity_id}`. Entity might not be available." ) return self._supported_features async def is_supported(self, feature: int) -> bool: - return feature in await self.supported_features() + return feature in await self.supported_features async def not_supported(self, feature: int) -> bool: - return feature not in await self.supported_features() + return feature not in await self.supported_features diff --git a/apps/controllerx/cx_core/feature_support/cover.py b/apps/controllerx/cx_core/feature_support/cover.py index 5c0b9504..11d818e2 100644 --- a/apps/controllerx/cx_core/feature_support/cover.py +++ b/apps/controllerx/cx_core/feature_support/cover.py @@ -1,15 +1,5 @@ -from cx_core.controller import TypeController from cx_core.feature_support import FeatureSupport -SUPPORT_OPEN = 1 -SUPPORT_CLOSE = 2 -SUPPORT_SET_POSITION = 4 -SUPPORT_STOP = 8 -SUPPORT_OPEN_TILT = 16 -SUPPORT_CLOSE_TILT = 32 -SUPPORT_STOP_TILT = 64 -SUPPORT_SET_TILT_POSITION = 128 - class CoverSupport(FeatureSupport): @@ -22,21 +12,13 @@ class CoverSupport(FeatureSupport): STOP_TILT = 64 SET_TILT_POSITION = 128 - def __init__( - self, entity: str, controller: TypeController, update_supported_features: bool - ) -> None: - super().__init__( - entity, - controller, - [ - CoverSupport.OPEN, - CoverSupport.CLOSE, - CoverSupport.SET_COVER_POSITION, - CoverSupport.STOP, - CoverSupport.OPEN_TILT, - CoverSupport.CLOSE_TILT, - CoverSupport.STOP_TILT, - CoverSupport.SET_TILT_POSITION, - ], - update_supported_features, - ) + features = [ + OPEN, + CLOSE, + SET_COVER_POSITION, + STOP, + OPEN_TILT, + CLOSE_TILT, + STOP_TILT, + SET_TILT_POSITION, + ] diff --git a/apps/controllerx/cx_core/feature_support/light.py b/apps/controllerx/cx_core/feature_support/light.py index 5bfc0c2c..2b39a5da 100644 --- a/apps/controllerx/cx_core/feature_support/light.py +++ b/apps/controllerx/cx_core/feature_support/light.py @@ -1,4 +1,3 @@ -from cx_core.controller import TypeController from cx_core.feature_support import FeatureSupport @@ -11,20 +10,12 @@ class LightSupport(FeatureSupport): TRANSITION = 32 WHITE_VALUE = 128 - def __init__( - self, entity: str, controller: TypeController, update_supported_features: bool - ) -> None: - super().__init__( - entity, - controller, - [ - LightSupport.BRIGHTNESS, - LightSupport.COLOR_TEMP, - LightSupport.EFFECT, - LightSupport.FLASH, - LightSupport.COLOR, - LightSupport.TRANSITION, - LightSupport.WHITE_VALUE, - ], - update_supported_features, - ) + features = [ + BRIGHTNESS, + COLOR_TEMP, + EFFECT, + FLASH, + COLOR, + TRANSITION, + WHITE_VALUE, + ] diff --git a/apps/controllerx/cx_core/feature_support/media_player.py b/apps/controllerx/cx_core/feature_support/media_player.py index 9e654780..18826226 100644 --- a/apps/controllerx/cx_core/feature_support/media_player.py +++ b/apps/controllerx/cx_core/feature_support/media_player.py @@ -1,4 +1,3 @@ -from cx_core.controller import TypeController from cx_core.feature_support import FeatureSupport @@ -20,29 +19,21 @@ class MediaPlayerSupport(FeatureSupport): SHUFFLE_SET = 32768 SELECT_SOUND_MODE = 65536 - def __init__( - self, entity: str, controller: TypeController, update_supported_features: bool - ) -> None: - super().__init__( - entity, - controller, - [ - MediaPlayerSupport.PAUSE, - MediaPlayerSupport.SEEK, - MediaPlayerSupport.VOLUME_SET, - MediaPlayerSupport.VOLUME_MUTE, - MediaPlayerSupport.PREVIOUS_TRACK, - MediaPlayerSupport.NEXT_TRACK, - MediaPlayerSupport.TURN_ON, - MediaPlayerSupport.TURN_OFF, - MediaPlayerSupport.PLAY_MEDIA, - MediaPlayerSupport.VOLUME_STEP, - MediaPlayerSupport.SELECT_SOURCE, - MediaPlayerSupport.STOP, - MediaPlayerSupport.CLEAR_PLAYLIST, - MediaPlayerSupport.PLAY, - MediaPlayerSupport.SHUFFLE_SET, - MediaPlayerSupport.SELECT_SOUND_MODE, - ], - update_supported_features, - ) + features = [ + PAUSE, + SEEK, + VOLUME_SET, + VOLUME_MUTE, + PREVIOUS_TRACK, + NEXT_TRACK, + TURN_ON, + TURN_OFF, + PLAY_MEDIA, + VOLUME_STEP, + SELECT_SOURCE, + STOP, + CLEAR_PLAYLIST, + PLAY, + SHUFFLE_SET, + SELECT_SOUND_MODE, + ] diff --git a/apps/controllerx/cx_core/integration/__init__.py b/apps/controllerx/cx_core/integration/__init__.py index 6fdab355..9a57cb5e 100644 --- a/apps/controllerx/cx_core/integration/__init__.py +++ b/apps/controllerx/cx_core/integration/__init__.py @@ -9,17 +9,19 @@ if TYPE_CHECKING: from cx_core.controller import Controller +EventData = Dict[str, Any] + class Integration(abc.ABC): + + name: str + controller: "Controller" + kwargs: Dict[str, Any] + def __init__(self, controller: "Controller", kwargs: Dict[str, Any]): - self.name = self.get_name() self.controller = controller self.kwargs = kwargs - @abc.abstractmethod - def get_name(self) -> str: - raise NotImplementedError - @abc.abstractmethod def get_actions_mapping(self) -> Optional[TypeActionsMapping]: raise NotImplementedError diff --git a/apps/controllerx/cx_core/integration/deconz.py b/apps/controllerx/cx_core/integration/deconz.py index 650a192c..cf36b4e9 100644 --- a/apps/controllerx/cx_core/integration/deconz.py +++ b/apps/controllerx/cx_core/integration/deconz.py @@ -1,22 +1,22 @@ from typing import Optional from appdaemon.plugins.hass.hassapi import Hass # type:ignore - -from cx_core.integration import Integration, TypeActionsMapping +from cx_core.integration import EventData, Integration, TypeActionsMapping class DeCONZIntegration(Integration): - def get_name(self) -> str: - return "deconz" + name = "deconz" def get_actions_mapping(self) -> Optional[TypeActionsMapping]: return self.controller.get_deconz_actions_mapping() def listen_changes(self, controller_id: str) -> None: Hass.listen_event( - self.controller, self.callback, "deconz_event", id=controller_id + self.controller, self.event_callback, "deconz_event", id=controller_id ) - async def callback(self, event_name: str, data: dict, kwargs: dict) -> None: + async def event_callback( + self, event_name: str, data: EventData, kwargs: dict + ) -> None: type_ = self.kwargs.get("type", "event") await self.controller.handle_action(data[type_]) diff --git a/apps/controllerx/cx_core/integration/mqtt.py b/apps/controllerx/cx_core/integration/mqtt.py index f441eeb2..cf9d67ca 100644 --- a/apps/controllerx/cx_core/integration/mqtt.py +++ b/apps/controllerx/cx_core/integration/mqtt.py @@ -1,14 +1,12 @@ from typing import Optional from appdaemon.plugins.mqtt.mqttapi import Mqtt # type: ignore - from cx_const import TypeActionsMapping -from cx_core.integration import Integration +from cx_core.integration import EventData, Integration class MQTTIntegration(Integration): - def get_name(self) -> str: - return "mqtt" + name = "mqtt" def get_actions_mapping(self) -> Optional[TypeActionsMapping]: return self.controller.get_z2m_actions_mapping() @@ -18,7 +16,9 @@ def listen_changes(self, controller_id: str) -> None: self.controller, self.event_callback, topic=controller_id, namespace="mqtt" ) - async def event_callback(self, event_name: str, data: dict, kwargs: dict) -> None: + async def event_callback( + self, event_name: str, data: EventData, kwargs: dict + ) -> None: self.controller.log(f"MQTT data event: {data}", level="DEBUG") if "payload" in data: await self.controller.handle_action(data["payload"]) diff --git a/apps/controllerx/cx_core/integration/state.py b/apps/controllerx/cx_core/integration/state.py index be7d3048..520aa125 100644 --- a/apps/controllerx/cx_core/integration/state.py +++ b/apps/controllerx/cx_core/integration/state.py @@ -1,14 +1,12 @@ from typing import Optional from appdaemon.plugins.hass.hassapi import Hass # type: ignore - from cx_const import TypeActionsMapping from cx_core.integration import Integration class StateIntegration(Integration): - def get_name(self) -> str: - return "state" + name = "state" def get_actions_mapping(self) -> Optional[TypeActionsMapping]: return self.controller.get_z2m_actions_mapping() @@ -16,10 +14,10 @@ def get_actions_mapping(self) -> Optional[TypeActionsMapping]: def listen_changes(self, controller_id: str) -> None: attribute = self.kwargs.get("attribute", None) Hass.listen_state( - self.controller, self.callback, controller_id, attribute=attribute + self.controller, self.state_callback, controller_id, attribute=attribute ) - async def callback( + async def state_callback( self, entity: Optional[str], attribute: Optional[str], old, new, kwargs ) -> None: await self.controller.handle_action(new) diff --git a/apps/controllerx/cx_core/integration/z2m.py b/apps/controllerx/cx_core/integration/z2m.py index ee8f3631..95c161bb 100644 --- a/apps/controllerx/cx_core/integration/z2m.py +++ b/apps/controllerx/cx_core/integration/z2m.py @@ -3,17 +3,15 @@ from appdaemon.plugins.hass.hassapi import Hass # type: ignore from appdaemon.plugins.mqtt.mqttapi import Mqtt # type: ignore - from cx_const import TypeActionsMapping -from cx_core.integration import Integration +from cx_core.integration import EventData, Integration LISTENS_TO_HA = "ha" LISTENS_TO_MQTT = "mqtt" class Z2MIntegration(Integration): - def get_name(self) -> str: - return "z2m" + name = "z2m" def get_actions_mapping(self) -> Optional[TypeActionsMapping]: return self.controller.get_z2m_actions_mapping() @@ -35,7 +33,9 @@ def listen_changes(self, controller_id: str) -> None: "`listen_to` has to be either `ha` or `mqtt`. Default is `ha`." ) - async def event_callback(self, event_name: str, data: dict, kwargs: dict) -> None: + async def event_callback( + self, event_name: str, data: EventData, kwargs: dict + ) -> None: self.controller.log(f"MQTT data event: {data}", level="DEBUG") action_key = self.kwargs.get("action_key", "action") action_group_key = self.kwargs.get("action_group_key", "action_group") diff --git a/apps/controllerx/cx_core/integration/zha.py b/apps/controllerx/cx_core/integration/zha.py index 86305350..02db43de 100644 --- a/apps/controllerx/cx_core/integration/zha.py +++ b/apps/controllerx/cx_core/integration/zha.py @@ -1,14 +1,12 @@ from typing import Optional from appdaemon.plugins.hass.hassapi import Hass # type: ignore - from cx_const import TypeActionsMapping -from cx_core.integration import Integration +from cx_core.integration import EventData, Integration class ZHAIntegration(Integration): - def get_name(self): - return "zha" + name = "zha" def get_actions_mapping(self) -> Optional[TypeActionsMapping]: return self.controller.get_zha_actions_mapping() @@ -18,7 +16,7 @@ def listen_changes(self, controller_id: str) -> None: self.controller, self.callback, "zha_event", device_ieee=controller_id ) - def get_action(self, data: dict): + def get_action(self, data: EventData) -> str: command = data["command"] args = data["args"] if isinstance(args, dict): @@ -30,7 +28,7 @@ def get_action(self, data: dict): action += "_" + "_".join(args) return action - async def callback(self, event_name: str, data: dict, kwargs: dict) -> None: + async def callback(self, event_name: str, data: EventData, kwargs: dict) -> None: action = self.controller.get_zha_action(data) if action is None: # If there is no action extracted from the controller then diff --git a/apps/controllerx/cx_core/type/cover_controller.py b/apps/controllerx/cx_core/type/cover_controller.py index 6798d701..73f21de3 100644 --- a/apps/controllerx/cx_core/type/cover_controller.py +++ b/apps/controllerx/cx_core/type/cover_controller.py @@ -1,10 +1,12 @@ -from typing import Callable, List +from typing import Callable, Type + from cx_const import Cover, TypeActionsMapping -from cx_core.controller import TypeController, action +from cx_core.controller import action from cx_core.feature_support.cover import CoverSupport +from cx_core.type_controller import Entity, TypeController -class CoverController(TypeController): +class CoverController(TypeController[Entity, CoverSupport]): """ This is the main class that controls the coveres for different devices. Type of actions: @@ -17,23 +19,24 @@ class CoverController(TypeController): - close_position (optional): The close position. Default is 0 """ + domains = ["cover"] + entity_arg = "cover" + + open_position: int + close_position: int + async def initialize(self) -> None: - self.cover = self.args["cover"] self.open_position = self.args.get("open_position", 100) self.close_position = self.args.get("close_position", 0) - update_supported_features = self.args.get("update_supported_features", False) if self.open_position < self.close_position: raise ValueError("`open_position` must be higher than `close_position`") - await self.check_domain(self.cover) - - self.supported_features = CoverSupport( - self.cover, self, update_supported_features - ) - await super().initialize() - def get_domain(self) -> List[str]: - return ["cover"] + def _get_entity_type(self) -> Type[Entity]: + return Entity + + def _get_feature_support_type(self) -> Type[CoverSupport]: + return CoverSupport def get_type_actions_mapping(self) -> TypeActionsMapping: return { @@ -46,45 +49,45 @@ def get_type_actions_mapping(self) -> TypeActionsMapping: @action async def open(self) -> None: - if await self.supported_features.is_supported(CoverSupport.SET_COVER_POSITION): + if await self.feature_support.is_supported(CoverSupport.SET_COVER_POSITION): await self.call_service( "cover/set_cover_position", - entity_id=self.cover, + entity_id=self.entity.name, position=self.open_position, ) - elif await self.supported_features.is_supported(CoverSupport.OPEN): - await self.call_service("cover/open_cover", entity_id=self.cover) + elif await self.feature_support.is_supported(CoverSupport.OPEN): + await self.call_service("cover/open_cover", entity_id=self.entity.name) else: self.log( - f"⚠️ `{self.cover}` does not support SET_COVER_POSITION or OPEN", + f"⚠️ `{self.entity.name}` does not support SET_COVER_POSITION or OPEN", level="WARNING", ascii_encode=False, ) @action async def close(self) -> None: - if await self.supported_features.is_supported(CoverSupport.SET_COVER_POSITION): + if await self.feature_support.is_supported(CoverSupport.SET_COVER_POSITION): await self.call_service( "cover/set_cover_position", - entity_id=self.cover, + entity_id=self.entity.name, position=self.close_position, ) - elif await self.supported_features.is_supported(CoverSupport.CLOSE): - await self.call_service("cover/close_cover", entity_id=self.cover) + elif await self.feature_support.is_supported(CoverSupport.CLOSE): + await self.call_service("cover/close_cover", entity_id=self.entity.name) else: self.log( - f"⚠️ `{self.cover}` does not support SET_COVER_POSITION or CLOSE", + f"⚠️ `{self.entity.name}` does not support SET_COVER_POSITION or CLOSE", level="WARNING", ascii_encode=False, ) @action async def stop(self) -> None: - await self.call_service("cover/stop_cover", entity_id=self.cover) + await self.call_service("cover/stop_cover", entity_id=self.entity.name) @action async def toggle(self, action: Callable) -> None: - cover_state = await self.get_entity_state(self.cover) + cover_state = await self.get_entity_state(self.entity.name) if cover_state == "opening" or cover_state == "closing": await self.stop() else: diff --git a/apps/controllerx/cx_core/type/light_controller.py b/apps/controllerx/cx_core/type/light_controller.py index 9c152bcc..9bb47d63 100644 --- a/apps/controllerx/cx_core/type/light_controller.py +++ b/apps/controllerx/cx_core/type/light_controller.py @@ -1,12 +1,20 @@ -from typing import Any, Dict, List, Union +import sys +from typing import Any, Dict, Optional, Type, Union from cx_const import Light, TypeActionsMapping from cx_core.color_helper import get_color_wheel -from cx_core.controller import ReleaseHoldController, TypeController, action +from cx_core.controller import ReleaseHoldController, action from cx_core.feature_support.light import LightSupport from cx_core.stepper import Stepper from cx_core.stepper.circular_stepper import CircularStepper from cx_core.stepper.minmax_stepper import MinMaxStepper +from cx_core.type_controller import Entity, TypeController + +if sys.version_info[1] < 8: + from typing_extensions import Literal +else: + from typing import Literal + DEFAULT_MANUAL_STEPS = 10 DEFAULT_AUTOMATIC_STEPS = 10 @@ -17,11 +25,21 @@ DEFAULT_MIN_COLOR_TEMP = 153 DEFAULT_MAX_COLOR_TEMP = 500 DEFAULT_TRANSITION = 300 +DEFAULT_ADD_TRANSITION = True +DEFAULT_TRANSITION_TURN_TOGGLE = False + +ColorMode = Literal["auto", "xy_color", "color_temp"] -LightEntity = Dict[str, str] +class LightEntity(Entity): + color_mode: ColorMode -class LightController(TypeController, ReleaseHoldController): + def __init__(self, name: str, color_mode: ColorMode = "auto") -> None: + super().__init__(name) + self.color_mode = color_mode + + +class LightController(TypeController[LightEntity, LightSupport], ReleaseHoldController): """ This is the main class that controls the lights for different devices. Type of actions: @@ -49,9 +67,10 @@ class LightController(TypeController, ReleaseHoldController): index_color = 0 value_attribute = None + domains = ["light"] + entity_arg = "light" + async def initialize(self) -> None: - self.light = self.get_light(self.args["light"]) - await self.check_domain(self.light["name"]) manual_steps = self.args.get("manual_steps", DEFAULT_MANUAL_STEPS) automatic_steps = self.args.get("automatic_steps", DEFAULT_AUTOMATIC_STEPS) self.min_brightness = self.args.get("min_brightness", DEFAULT_MIN_BRIGHTNESS) @@ -95,19 +114,17 @@ async def initialize(self) -> None: self.smooth_power_on = self.args.get( "smooth_power_on", self.supports_smooth_power_on() ) - self.add_transition = self.args.get("add_transition", True) + self.add_transition = self.args.get("add_transition", DEFAULT_ADD_TRANSITION) self.add_transition_turn_toggle = self.args.get( - "add_transition_turn_toggle", False - ) - update_supported_features = self.args.get("update_supported_features", False) - - self.supported_features = LightSupport( - self.light["name"], self, update_supported_features + "add_transition_turn_toggle", DEFAULT_TRANSITION_TURN_TOGGLE ) await super().initialize() - def get_domain(self) -> List[str]: - return ["light"] + def _get_entity_type(self) -> Type[LightEntity]: + return LightEntity + + def _get_feature_support_type(self) -> Type[LightSupport]: + return LightSupport def get_type_actions_mapping(self) -> TypeActionsMapping: return { @@ -306,17 +323,6 @@ def get_type_actions_mapping(self) -> TypeActionsMapping: ), } - def get_light(self, light: Union[str, dict]) -> LightEntity: - if isinstance(light, str): - return {"name": light, "color_mode": "auto"} - elif isinstance(light, dict): - color_mode = light.get("color_mode", "auto") - return {"name": light["name"], "color_mode": color_mode} - else: - raise ValueError( - f"Type {type(light)} is not supported for `light` attribute" - ) - async def call_light_service( self, service: str, turned_toggle: bool, **attributes ) -> None: @@ -326,15 +332,15 @@ async def call_light_service( if ( not self.add_transition or (turned_toggle and not self.add_transition_turn_toggle) - or await self.supported_features.not_supported(LightSupport.TRANSITION) + or await self.feature_support.not_supported(LightSupport.TRANSITION) ): del attributes["transition"] - await self.call_service(service, entity_id=self.light["name"], **attributes) + await self.call_service(service, entity_id=self.entity.name, **attributes) @action - async def on(self, light_on: bool = None, **attributes) -> None: + async def on(self, light_on: Optional[bool] = None, **attributes) -> None: if light_on is None: - light_state = await self.get_entity_state(self.light["name"]) + light_state = await self.get_entity_state(self.entity.name) light_on = light_state == "on" await self.call_light_service( "light/turn_on", turned_toggle=not light_on, **attributes @@ -352,7 +358,7 @@ async def toggle(self, **attributes) -> None: @action async def set_value( - self, attribute: str, fraction: float, light_on: bool = None + self, attribute: str, fraction: float, light_on: Optional[bool] = None ) -> None: fraction = max(0, min(fraction, 1)) stepper = self.automatic_steppers[attribute] @@ -375,11 +381,11 @@ async def toggle_min(self, attribute: str) -> None: await self.toggle(**{attribute: stepper.minmax.min}) @action - async def on_full(self, attribute: str, light_on: bool = None) -> None: + async def on_full(self, attribute: str, light_on: Optional[bool] = None) -> None: await self.set_value(attribute, 1, light_on=light_on) @action - async def on_min(self, attribute: str, light_on: bool = None) -> None: + async def on_min(self, attribute: str, light_on: Optional[bool] = None) -> None: await self.set_value(attribute, 0, light_on=light_on) @action @@ -401,19 +407,17 @@ async def sync(self) -> None: async def get_attribute(self, attribute: str) -> str: if attribute == LightController.ATTRIBUTE_COLOR: - if self.light["color_mode"] == "auto": - if await self.supported_features.is_supported(LightSupport.COLOR): + if self.entity.color_mode == "auto": + if await self.feature_support.is_supported(LightSupport.COLOR): return LightController.ATTRIBUTE_XY_COLOR - elif await self.supported_features.is_supported( - LightSupport.COLOR_TEMP - ): + elif await self.feature_support.is_supported(LightSupport.COLOR_TEMP): return LightController.ATTRIBUTE_COLOR_TEMP else: raise ValueError( "This light does not support xy_color or color_temp" ) else: - return self.light["color_mode"] + return self.entity.color_mode else: return attribute @@ -421,7 +425,7 @@ async def get_value_attribute( self, attribute: str, direction: str ) -> Union[float, int]: if self.check_smooth_power_on( - attribute, direction, await self.get_entity_state(self.light["name"]) + attribute, direction, await self.get_entity_state(self.entity.name) ): return 0 if attribute == LightController.ATTRIBUTE_XY_COLOR: @@ -431,11 +435,11 @@ async def get_value_attribute( or attribute == LightController.ATTRIBUTE_WHITE_VALUE or attribute == LightController.ATTRIBUTE_COLOR_TEMP ): - value = await self.get_entity_state(self.light["name"], attribute) + value = await self.get_entity_state(self.entity.name, attribute) if value is None: raise ValueError( f"Value for `{attribute}` attribute could not be retrieved " - f"from `{self.light['name']}`. " + f"from `{self.entity.name}`. " "Check the FAQ to know more about this error: " "https://xaviml.github.io/controllerx/faq" ) @@ -464,7 +468,7 @@ async def before_action(self, action: str, *args, **kwargs) -> bool: to_return = True if action == "click" or action == "hold": attribute, direction = args - light_state = await self.get_entity_state(self.light["name"]) + light_state = await self.get_entity_state(self.entity.name) to_return = light_state == "on" or self.check_smooth_power_on( attribute, direction, light_state ) @@ -541,7 +545,7 @@ async def change_light_state( # would be to force the loop to stop after 4 or 5 loops as a safety measure. return False if self.check_smooth_power_on( - attribute, direction, await self.get_entity_state(self.light["name"]) + attribute, direction, await self.get_entity_state(self.entity.name) ): await self.on_min(attribute, light_on=False) # # After smooth power on, the light should not brighten up. diff --git a/apps/controllerx/cx_core/type/media_player_controller.py b/apps/controllerx/cx_core/type/media_player_controller.py index 93c82673..b25ac4f5 100644 --- a/apps/controllerx/cx_core/type/media_player_controller.py +++ b/apps/controllerx/cx_core/type/media_player_controller.py @@ -1,31 +1,34 @@ -from typing import List +from typing import Type from cx_const import MediaPlayer, TypeActionsMapping -from cx_core.controller import ReleaseHoldController, TypeController, action +from cx_core.controller import ReleaseHoldController, action from cx_core.feature_support.media_player import MediaPlayerSupport from cx_core.stepper import Stepper from cx_core.stepper.circular_stepper import CircularStepper from cx_core.stepper.minmax_stepper import MinMaxStepper +from cx_core.type_controller import Entity, TypeController DEFAULT_VOLUME_STEPS = 10 -class MediaPlayerController(TypeController, ReleaseHoldController): +class MediaPlayerController( + TypeController[Entity, MediaPlayerSupport], ReleaseHoldController +): + + domains = ["media_player"] + entity_arg = "media_player" + async def initialize(self) -> None: - self.media_player = self.args["media_player"] - await self.check_domain(self.media_player) volume_steps = self.args.get("volume_steps", DEFAULT_VOLUME_STEPS) - update_supported_features = self.args.get("update_supported_features", False) self.volume_stepper = MinMaxStepper(0, 1, volume_steps) self.volume_level = 0.0 - - self.supported_features = MediaPlayerSupport( - self.media_player, self, update_supported_features - ) await super().initialize() - def get_domain(self) -> List[str]: - return ["media_player"] + def _get_entity_type(self) -> Type[Entity]: + return Entity + + def _get_feature_support_type(self) -> Type[MediaPlayerSupport]: + return MediaPlayerSupport def get_type_actions_mapping(self) -> TypeActionsMapping: return { @@ -45,12 +48,12 @@ def get_type_actions_mapping(self) -> TypeActionsMapping: @action async def change_source_list(self, direction: str) -> None: - entity_states = await self.get_entity_state(self.media_player, attribute="all") + entity_states = await self.get_entity_state(self.entity.name, attribute="all") entity_attributes = entity_states["attributes"] source_list = entity_attributes.get("source_list") if len(source_list) == 0 or source_list is None: self.log( - f"⚠️ There is no `source_list` parameter in `{self.media_player}`", + f"⚠️ There is no `source_list` parameter in `{self.entity.name}`", level="WARNING", ascii_encode=False, ) @@ -64,34 +67,34 @@ async def change_source_list(self, direction: str) -> None: new_index_source, _ = source_stepper.step(index_source, direction) await self.call_service( "media_player/select_source", - entity_id=self.media_player, + entity_id=self.entity.name, source=source_list[new_index_source], ) @action async def play(self) -> None: - await self.call_service("media_player/media_play", entity_id=self.media_player) + await self.call_service("media_player/media_play", entity_id=self.entity.name) @action async def pause(self) -> None: - await self.call_service("media_player/media_pause", entity_id=self.media_player) + await self.call_service("media_player/media_pause", entity_id=self.entity.name) @action async def play_pause(self) -> None: await self.call_service( - "media_player/media_play_pause", entity_id=self.media_player + "media_player/media_play_pause", entity_id=self.entity.name ) @action async def previous_track(self) -> None: await self.call_service( - "media_player/media_previous_track", entity_id=self.media_player + "media_player/media_previous_track", entity_id=self.entity.name ) @action async def next_track(self) -> None: await self.call_service( - "media_player/media_next_track", entity_id=self.media_player + "media_player/media_next_track", entity_id=self.entity.name ) @action @@ -111,30 +114,30 @@ async def hold(self, direction: str) -> None: async def prepare_volume_change(self) -> None: volume_level = await self.get_entity_state( - self.media_player, attribute="volume_level" + self.entity.name, attribute="volume_level" ) if volume_level is not None: self.volume_level = volume_level async def volume_change(self, direction: str) -> bool: - if await self.supported_features.is_supported(MediaPlayerSupport.VOLUME_SET): + if await self.feature_support.is_supported(MediaPlayerSupport.VOLUME_SET): self.volume_level, exceeded = self.volume_stepper.step( self.volume_level, direction ) await self.call_service( "media_player/volume_set", - entity_id=self.media_player, + entity_id=self.entity.name, volume_level=self.volume_level, ) return exceeded else: if direction == Stepper.UP: await self.call_service( - "media_player/volume_up", entity_id=self.media_player + "media_player/volume_up", entity_id=self.entity.name ) else: await self.call_service( - "media_player/volume_down", entity_id=self.media_player + "media_player/volume_down", entity_id=self.entity.name ) return False diff --git a/apps/controllerx/cx_core/type/switch_controller.py b/apps/controllerx/cx_core/type/switch_controller.py index 29a67766..4e9122d6 100644 --- a/apps/controllerx/cx_core/type/switch_controller.py +++ b/apps/controllerx/cx_core/type/switch_controller.py @@ -1,10 +1,12 @@ -from typing import List +from typing import Type from cx_const import Switch, TypeActionsMapping -from cx_core.controller import TypeController, action +from cx_core.controller import action +from cx_core.feature_support import FeatureSupport +from cx_core.type_controller import Entity, TypeController -class SwitchController(TypeController): +class SwitchController(TypeController[Entity, FeatureSupport]): """ This is the main class that controls the switches for different devices. Type of actions: @@ -14,22 +16,17 @@ class SwitchController(TypeController): - switch (required): Switch entity name """ - async def initialize(self) -> None: - self.switch = self.args["switch"] - await self.check_domain(self.switch) - await super().initialize() - - def get_domain(self) -> List[str]: - return [ - "alert", - "automation", - "cover", - "input_boolean", - "light", - "media_player", - "script", - "switch", - ] + domains = [ + "alert", + "automation", + "cover", + "input_boolean", + "light", + "media_player", + "script", + "switch", + ] + entity_arg = "switch" def get_type_actions_mapping(self) -> TypeActionsMapping: return { @@ -38,14 +35,20 @@ def get_type_actions_mapping(self) -> TypeActionsMapping: Switch.TOGGLE: self.toggle, } + def _get_entity_type(self) -> Type[Entity]: + return Entity + + def _get_feature_support_type(self) -> Type[FeatureSupport]: + return FeatureSupport + @action async def on(self) -> None: - await self.call_service("homeassistant/turn_on", entity_id=self.switch) + await self.call_service("homeassistant/turn_on", entity_id=self.entity.name) @action async def off(self) -> None: - await self.call_service("homeassistant/turn_off", entity_id=self.switch) + await self.call_service("homeassistant/turn_off", entity_id=self.entity.name) @action async def toggle(self) -> None: - await self.call_service("homeassistant/toggle", entity_id=self.switch) + await self.call_service("homeassistant/toggle", entity_id=self.entity.name) diff --git a/apps/controllerx/cx_core/type_controller.py b/apps/controllerx/cx_core/type_controller.py new file mode 100644 index 00000000..846b48bb --- /dev/null +++ b/apps/controllerx/cx_core/type_controller.py @@ -0,0 +1,86 @@ +import abc +from typing import Any, Generic, List, Optional, Type, TypeVar, Union + +from cx_core.controller import Controller +from cx_core.feature_support import FeatureSupportType + +EntityType = TypeVar("EntityType", bound="Entity") + + +class Entity: + name: str + + def __init__(self, name: str) -> None: + self.name = name + + @classmethod + def instantiate(cls: Type[EntityType], **params) -> EntityType: + return cls(**params) + + +class TypeController(Controller, abc.ABC, Generic[EntityType, FeatureSupportType]): + + domains: List[str] + entity_arg: str + entity: EntityType + feature_support: FeatureSupportType + + async def initialize(self) -> None: + self.entity = self.get_entity(self.args[self.entity_arg]) + await self.check_domain(self.entity.name) + update_supported_features = self.args.get("update_supported_features", False) + self.feature_support = self._get_feature_support_type().instantiate( + self.entity.name, self, update_supported_features + ) + await super().initialize() + + @abc.abstractmethod + def _get_entity_type(self) -> Type[EntityType]: + raise NotImplementedError + + @abc.abstractmethod + def _get_feature_support_type(self) -> Type[FeatureSupportType]: + raise NotImplementedError + + def get_entity(self, entity: Union[str, dict]) -> EntityType: + if isinstance(entity, str): + return self._get_entity_type().instantiate(name=entity) + elif isinstance(entity, dict): + return self._get_entity_type().instantiate(**entity) + else: + raise ValueError( + f"Type {type(entity)} is not supported for `{self.entity_arg}` attribute" + ) + + async def check_domain(self, entity_name: str) -> None: + if entity_name.startswith("group."): + entities = await self.get_state(entity_name, attribute="entity_id") + same_domain = all( + ( + any(elem.startswith(domain + ".") for domain in self.domains) + for elem in entities + ) + ) + if not same_domain: + raise ValueError( + f"All entities from '{entity_name}' must be from one " + f"of the following domains {self.domains} (e.g. {self.domains[0]}.bedroom)" + ) + elif not any(entity_name.startswith(domain + ".") for domain in self.domains): + raise ValueError( + f"'{entity_name}' must be from one of the following domains " + f"{self.domains} (e.g. {self.domains[0]}.bedroom)" + ) + + async def get_entity_state( + self, entity: str, attribute: Optional[str] = None + ) -> Any: + if entity.startswith("group."): + entities = await self.get_state(entity, attribute="entity_id") + if len(entities) == 0: + raise ValueError( + f"The group `{entity}` does not have any entities registered." + ) + entity = entities[0] + out = await self.get_state(entity, attribute=attribute) + return out diff --git a/apps/controllerx/cx_devices/aqara.py b/apps/controllerx/cx_devices/aqara.py index e647204d..35714693 100644 --- a/apps/controllerx/cx_devices/aqara.py +++ b/apps/controllerx/cx_devices/aqara.py @@ -1,5 +1,6 @@ from cx_const import Light, Switch, TypeActionsMapping from cx_core import LightController, SwitchController +from cx_core.integration import EventData class WXKG02LMLightController(LightController): @@ -83,7 +84,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "quadruple": Light.SET_HALF_BRIGHTNESS, } - def get_zha_action(self, data: dict) -> str: + def get_zha_action(self, data: EventData) -> str: return data["args"]["click_type"] @@ -112,7 +113,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "quadruple": Light.SET_HALF_BRIGHTNESS, } - def get_zha_action(self, data: dict) -> str: + def get_zha_action(self, data: EventData) -> str: mapping = { 1: "single", 2: "double", @@ -193,7 +194,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "rotate_right": Light.CLICK_BRIGHTNESS_UP, } - def get_zha_action(self, data: dict) -> str: + def get_zha_action(self, data: EventData) -> str: command = action = data["command"] args = data.get("args", {}) if command == "flip": diff --git a/apps/controllerx/cx_devices/legrand.py b/apps/controllerx/cx_devices/legrand.py index d4b89d6a..4676e5f4 100644 --- a/apps/controllerx/cx_devices/legrand.py +++ b/apps/controllerx/cx_devices/legrand.py @@ -1,10 +1,9 @@ -from typing import Optional - from cx_const import Light, TypeActionsMapping from cx_core import LightController +from cx_core.integration import EventData -def get_zha_action_LegrandWallController(data: dict) -> Optional[str]: +def get_zha_action_LegrandWallController(data: dict) -> str: endpoint_id = data.get("endpoint_id", 1) command = action = data["command"] args = data.get("args", {}) @@ -25,7 +24,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "1_stop": Light.RELEASE, } - def get_zha_action(self, data: dict) -> Optional[str]: + def get_zha_action(self, data: EventData) -> str: return get_zha_action_LegrandWallController(data) @@ -44,5 +43,5 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "2_stop": Light.RELEASE, } - def get_zha_action(self, data: dict) -> Optional[str]: + def get_zha_action(self, data: EventData) -> str: return get_zha_action_LegrandWallController(data) diff --git a/apps/controllerx/cx_devices/osram.py b/apps/controllerx/cx_devices/osram.py index ea96614d..9792d7d6 100644 --- a/apps/controllerx/cx_devices/osram.py +++ b/apps/controllerx/cx_devices/osram.py @@ -1,6 +1,6 @@ -from typing import Optional from cx_const import Light, TypeActionsMapping from cx_core import LightController +from cx_core.integration import EventData class OsramAC025XX00NJLightController(LightController): @@ -20,7 +20,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "2_stop": Light.RELEASE, } - def get_zha_action(self, data: dict) -> Optional[str]: + def get_zha_action(self, data: EventData) -> str: return f"{data['endpoint_id']}_{data['command']}" diff --git a/apps/controllerx/cx_devices/phillips.py b/apps/controllerx/cx_devices/phillips.py index b723525f..025badd7 100644 --- a/apps/controllerx/cx_devices/phillips.py +++ b/apps/controllerx/cx_devices/phillips.py @@ -1,6 +1,6 @@ -from typing import Any, Dict from cx_const import Light, TypeActionsMapping from cx_core import LightController +from cx_core.integration import EventData class HueDimmerController(LightController): @@ -57,7 +57,7 @@ def get_zha_actions_mapping(self) -> TypeActionsMapping: "on_short_release": Light.ON, } - def get_zha_action(self, data: Dict[Any, Any]) -> str: + def get_zha_action(self, data: EventData) -> str: return data["command"] diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 69229c91..ce746d88 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -21,7 +21,6 @@ stages: jobs: - job: Build displayName: Build job - strategy: matrix: Python36: @@ -41,11 +40,13 @@ stages: displayName: Lock dependencies - script: pipenv install --system --deploy --dev displayName: Install dependencies - - script: black apps/controllerx tests --check + - script: isort apps/controllerx/ tests/ --check + displayName: Organize imports (isort) + - script: black apps/controllerx/ tests/ --check displayName: Formatter (black) - - script: flake8 apps/controllerx tests + - script: flake8 apps/controllerx/ tests/ displayName: Styling (flake8) - - script: mypy apps/controllerx + - script: mypy apps/controllerx/ tests/ displayName: Typing (mypy) - script: pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=apps --cov-report=xml --cov-report=html displayName: Tests (pytest) diff --git a/hacs.json b/hacs.json index f67f6bb7..c3454fde 100644 --- a/hacs.json +++ b/hacs.json @@ -1,7 +1,7 @@ { "filename": "controllerx.zip", "hide_default_branch": true, - "name": "ControllerX", + "name": "🎮 ControllerX", "render_readme": true, "zip_release": true } diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 86d57a40..00000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -mock_use_standalone_module = true diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..da7180ac --- /dev/null +++ b/setup.cfg @@ -0,0 +1,30 @@ +[isort] +profile=black + +[black] +line-length = 88 +exclude = .git,.hg,.mypy_cache,.tox,_build,buck-out,build,dist + +[flake8] +ignore = E501, W503 +exclude = controllerx.py +max-line-length = 88 + +[mypy] +python_version = 3.6 +namespace_packages = True +no_implicit_optional = True + +[tool:pytest] +mock_use_standalone_module = true + +[report] +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain if tests don't hit defensive assertion code: + raise NotImplementedError + +[mypy-appdaemon.*] +ignore_missing_imports = true diff --git a/tests/conftest.py b/tests/conftest.py index fe762311..5b95a790 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ import asyncio -import appdaemon.plugins.hass.hassapi as hass -import appdaemon.plugins.mqtt.mqttapi as mqtt -import pytest +import appdaemon.plugins.hass.hassapi as hass # type: ignore +import appdaemon.plugins.mqtt.mqttapi as mqtt # type: ignore +import pytest +from _pytest.monkeypatch import MonkeyPatch +from cx_core import LightController from cx_core.controller import Controller + from tests.test_utils import fake_fn @@ -21,7 +24,7 @@ async def fake_cancel_timer(self, task): @pytest.fixture(autouse=True) -def hass_mock(monkeypatch, mocker): +def hass_mock(monkeypatch: MonkeyPatch): """ Fixture for set up the tests, mocking appdaemon functions """ @@ -37,8 +40,15 @@ def hass_mock(monkeypatch, mocker): monkeypatch.setattr(hass.Hass, "cancel_timer", fake_cancel_timer) -@pytest.fixture(autouse=True) -def fake_controller(hass_mock): +@pytest.fixture +def fake_controller() -> Controller: c = Controller() # type: ignore c.args = {} return c + + +@pytest.fixture +def fake_type_controller() -> LightController: + c = LightController() # type: ignore + c.args = {} + return c diff --git a/tests/integ_tests/integ_test.py b/tests/integ_tests/integ_test.py index ffa9a6e0..3070c6aa 100644 --- a/tests/integ_tests/integ_test.py +++ b/tests/integ_tests/integ_test.py @@ -1,10 +1,13 @@ import asyncio import glob from pathlib import Path -from tests.test_utils import get_controller +from typing import Any, Dict import pytest import yaml +from pytest_mock.plugin import MockerFixture + +from tests.test_utils import get_controller def get_integ_tests(): @@ -38,7 +41,9 @@ async def inner(entity_id, attribute=None): @pytest.mark.asyncio @pytest.mark.parametrize("config_file, data", integration_tests) -async def test_integ_configs(hass_mock, mocker, config_file, data): +async def test_integ_configs( + mocker: MockerFixture, config_file: str, data: Dict[str, Any] +): entity_state_attributes = data.get("entity_state_attributes", {}) entity_state = data.get("entity_state", None) fired_actions = data.get("fired_actions", []) diff --git a/tests/integ_tests/switch_controller/config.yaml b/tests/integ_tests/switch_controller/config.yaml new file mode 100644 index 00000000..04d6827e --- /dev/null +++ b/tests/integ_tests/switch_controller/config.yaml @@ -0,0 +1,9 @@ +example_app: + module: controllerx + class: SwitchController + controller: sensor.my_controller + integration: z2m + switch: + name: switch.my_switch + mapping: + toggle: toggle \ No newline at end of file diff --git a/tests/integ_tests/switch_controller/toggle_called_test.yaml b/tests/integ_tests/switch_controller/toggle_called_test.yaml new file mode 100644 index 00000000..0f997a98 --- /dev/null +++ b/tests/integ_tests/switch_controller/toggle_called_test.yaml @@ -0,0 +1,5 @@ +fired_actions: [toggle] +expected_calls: +- service: homeassistant/toggle + data: + entity_id: switch.my_switch diff --git a/tests/test_utils.py b/tests/test_utils.py index 783ba797..73e3c47b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,18 +1,29 @@ import importlib import os import pkgutil +from contextlib import contextmanager +from typing import TYPE_CHECKING, Callable, Generator, Optional + +import pytest +from _pytest._code.code import ExceptionInfo +from mock import MagicMock +from pytest_mock.plugin import MockerFixture + +if TYPE_CHECKING: + from cx_core.controller import Controller class IntegrationMock: - def __init__(self, name, controller, mocker): + def __init__(self, name: str, controller: "Controller", mocker: MockerFixture): self.name = name self.controller = controller - self.get_actions_mapping = mocker.stub(name="get_actions_mapping") + self.get_actions_mapping = MagicMock( + name="get_actions_mapping", return_value={} + ) self.listen_changes = mocker.stub(name="listen_changes") - super().__init__() -def fake_fn(async_=False, to_return=None): +def fake_fn(to_return=None, async_: bool = False) -> Callable: async def inner_fake_async_fn(*args, **kwargs): return to_return @@ -22,10 +33,10 @@ def inner_fake_fn(*args, **kwargs): return inner_fake_async_fn if async_ else inner_fake_fn -def get_controller(module_name, class_name): +def get_controller(module_name, class_name) -> Optional["Controller"]: module = importlib.import_module(module_name) class_ = getattr(module, class_name, None) - return class_() if class_ is not None else class_ + return class_() if class_ is not None else None def _import_modules(file_dir, package): @@ -51,6 +62,17 @@ def get_classes(file_, package_, class_, instantiate=False): subclasses = [ cls_() if instantiate else cls_ for cls_ in subclasses - if len(cls_.__subclasses__()) == 0 and package_ in cls_.__module__ + if f"{package_}." in cls_.__module__ ] return subclasses + + +@contextmanager +def wrap_exetuction( + *, error_expected: bool, exception=Exception +) -> Generator[Optional[ExceptionInfo], None, None]: + if error_expected: + with pytest.raises(exception) as err_info: + yield err_info + else: + yield None diff --git a/tests/unit_tests/cx_core/color_helper_test.py b/tests/unit_tests/cx_core/color_helper_test.py index 2dd2ed37..be200663 100644 --- a/tests/unit_tests/cx_core/color_helper_test.py +++ b/tests/unit_tests/cx_core/color_helper_test.py @@ -1,22 +1,18 @@ import pytest +from cx_core.color_helper import Colors, get_color_wheel -from cx_core.color_helper import get_color_wheel +from tests.test_utils import wrap_exetuction @pytest.mark.parametrize( "colors, error_expected", [ - ("default_color_wheel", None), - ("non_existing", ValueError), - ([(0.2, 0.3), (0.4, 0.5)], None), - (0, ValueError), + ("default_color_wheel", False), + ("non_existing", True), + ([(0.2, 0.3), (0.4, 0.5)], False), + (0, True), ], ) -def test_get_color_wheel(colors, error_expected): - - # SUT - if error_expected: - with pytest.raises(error_expected): - colors = get_color_wheel(colors) - else: +def test_get_color_wheel(colors: Colors, error_expected: bool): + with wrap_exetuction(error_expected=error_expected, exception=ValueError): colors = get_color_wheel(colors) diff --git a/tests/unit_tests/cx_core/controller_test.py b/tests/unit_tests/cx_core/controller_test.py index 5a23a9b2..a46390aa 100644 --- a/tests/unit_tests/cx_core/controller_test.py +++ b/tests/unit_tests/cx_core/controller_test.py @@ -1,19 +1,39 @@ from collections import defaultdict +from typing import Any, Dict, List, Optional, Union -import appdaemon.plugins.hass.hassapi as hass +import appdaemon.plugins.hass.hassapi as hass # type: ignore import pytest - +from cx_const import ActionEvent, ActionFunction, TypeAction, TypeActionsMapping from cx_core import integration as integration_module -from cx_core.controller import action -from tests.test_utils import IntegrationMock, fake_fn +from cx_core.controller import Controller, action +from pytest_mock.plugin import MockerFixture + +from tests.test_utils import IntegrationMock, fake_fn, wrap_exetuction + +INTEGRATION_TEST_NAME = "test" +CONTROLLER_NAME = "test_controller" @pytest.fixture -def sut(fake_controller): - fake_controller.multiple_click_actions = set() +def sut_before_init(fake_controller: Controller, mocker: MockerFixture) -> Controller: + fake_controller.args = { + "controller": CONTROLLER_NAME, + "integration": INTEGRATION_TEST_NAME, + } + integration_mock = IntegrationMock("test", fake_controller, mocker) + mocker.patch.object( + fake_controller, "get_integration", return_value=integration_mock + ) return fake_controller +@pytest.fixture +@pytest.mark.asyncio +async def sut(sut_before_init: Controller) -> Controller: + await sut_before_init.initialize() + return sut_before_init + + @pytest.mark.asyncio async def test_action_decorator(sut, mocker): stub_action = mocker.stub() @@ -31,7 +51,7 @@ async def fake_action(self): @pytest.mark.parametrize( - "controller_input, actions_input, included_actions, excluded_actions, actions_output, expect_an_error", + "controller_input, actions_input, included_actions, excluded_actions, actions_output, error_expected", [ ( ["controller_id"], @@ -109,49 +129,47 @@ async def fake_action(self): ) @pytest.mark.asyncio async def test_initialize( - sut, - mocker, - monkeypatch, - controller_input, - actions_input, - included_actions, - excluded_actions, - actions_output, - expect_an_error, + sut_before_init: Controller, + mocker: MockerFixture, + controller_input: Union[str, List[str]], + actions_input: List[str], + included_actions: Optional[List[str]], + excluded_actions: Optional[List[str]], + actions_output: List[str], + error_expected: bool, ): actions = {action: action for action in actions_input} type_actions = {action: lambda: None for action in actions_input} - sut.args["controller"] = controller_input - sut.args["integration"] = "test" + sut_before_init.args["controller"] = controller_input + integration_mock = IntegrationMock(INTEGRATION_TEST_NAME, sut_before_init, mocker) + mocker.patch.object( + sut_before_init, "get_integration", return_value=integration_mock + ) if included_actions: - sut.args["actions"] = included_actions + sut_before_init.args["actions"] = included_actions if excluded_actions: - sut.args["excluded_actions"] = excluded_actions - integration_mock = IntegrationMock("test", sut, mocker) - monkeypatch.setattr(sut, "get_integration", lambda integration: integration_mock) - monkeypatch.setattr(sut, "get_actions_mapping", lambda integration: actions) - monkeypatch.setattr(sut, "get_type_actions_mapping", lambda: type_actions) - check_ad_version = mocker.patch.object(sut, "check_ad_version") - get_actions_mapping = mocker.spy(sut, "get_actions_mapping") + sut_before_init.args["excluded_actions"] = excluded_actions + mocker.patch.object(sut_before_init, "get_actions_mapping", return_value=actions) + mocker.patch.object( + sut_before_init, "get_type_actions_mapping", return_value=type_actions + ) + get_actions_mapping = mocker.spy(sut_before_init, "get_actions_mapping") # SUT - if expect_an_error: - with pytest.raises(ValueError): - await sut.initialize() - else: - await sut.initialize() + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + await sut_before_init.initialize() - # Checks - check_ad_version.assert_called_once() + # Checks + if not error_expected: get_actions_mapping.assert_called_once() for controller_id in controller_input: integration_mock.listen_changes.assert_any_call(controller_id) assert integration_mock.listen_changes.call_count == len(controller_input) - assert list(sut.actions_mapping.keys()) == actions_output + assert list(sut_before_init.actions_mapping.keys()) == actions_output @pytest.mark.parametrize( - "mapping, merge_mapping, actions_output, expected_error", + "mapping, merge_mapping, actions_output, error_expected", [ (["action1"], None, ["action1"], False), (["action1", "action2"], None, ["action1", "action2"], False), @@ -163,31 +181,33 @@ async def test_initialize( ) @pytest.mark.asyncio async def test_merge_mapping( - sut, monkeypatch, mocker, mapping, merge_mapping, actions_output, expected_error + sut_before_init: Controller, + mocker: MockerFixture, + mapping: List[str], + merge_mapping: List[str], + actions_output: List[str], + error_expected: bool, ): actions_input = ["action1", "action2", "action3"] actions = {action: action for action in actions_input} type_actions = {action: lambda: None for action in actions_input} - sut.args["controller"] = "my_controller" - sut.args["integration"] = "test" if mapping: - sut.args["mapping"] = {item: item for item in mapping} + sut_before_init.args["mapping"] = {item: item for item in mapping} if merge_mapping: - sut.args["merge_mapping"] = {item: item for item in merge_mapping} - integration_mock = IntegrationMock("test", sut, mocker) - monkeypatch.setattr(sut, "get_integration", lambda integration: integration_mock) - monkeypatch.setattr(sut, "get_actions_mapping", lambda integration: actions) - monkeypatch.setattr(sut, "get_type_actions_mapping", lambda: type_actions) + sut_before_init.args["merge_mapping"] = {item: item for item in merge_mapping} + + mocker.patch.object(sut_before_init, "get_actions_mapping", return_value=actions) + mocker.patch.object( + sut_before_init, "get_type_actions_mapping", return_value=type_actions + ) # SUT - if expected_error: - with pytest.raises(ValueError): - await sut.initialize() - else: - await sut.initialize() + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + await sut_before_init.initialize() - # Checks - assert list(sut.actions_mapping.keys()) == actions_output + # Checks + if not error_expected: + assert list(sut_before_init.actions_mapping.keys()) == actions_output @pytest.mark.parametrize( @@ -201,7 +221,9 @@ async def test_merge_mapping( (["sensor 1", "sensor 2"], ["sensor 1", "sensor 2"]), ], ) -def test_get_list(sut, monkeypatch, test_input, expected): +def test_get_list( + sut: Controller, test_input: Union[List[str], str], expected: List[str] +): output = sut.get_list(test_input) assert output == expected @@ -222,23 +244,24 @@ def test_get_list(sut, monkeypatch, test_input, expected): (["toggle", "toggle$1", "toggle$2", "another$3"], ["toggle", "another"]), ], ) -def test_get_multiple_click_actions(sut, mapping, expected): - output = sut.get_multiple_click_actions({key: None for key in mapping}) +def test_get_multiple_click_actions( + sut: Controller, mapping: List[ActionEvent], expected: List[str] +): + output = sut.get_multiple_click_actions({key: "action" for key in mapping}) assert output == set(expected) @pytest.mark.parametrize( - "option,options,expect_an_error", + "option, options, error_expected", [ ("option1", ["option1", "option2", "option3"], False), ("option4", ["option1", "option2", "option3"], True), ], ) -def test_get_option(sut, option, options, expect_an_error): - if expect_an_error: - with pytest.raises(ValueError): - sut.get_option(option, options) - else: +def test_get_option( + sut: Controller, option: str, options: List[str], error_expected: bool +): + with wrap_exetuction(error_expected=error_expected, exception=ValueError): sut.get_option(option, options) @@ -258,29 +281,25 @@ def test_get_option(sut, option, options, expect_an_error): ], ) def test_get_integration( - sut, - mocker, - integration_input, - integration_name_expected, - args_expected, - error_expected, + fake_controller: Controller, + mocker: MockerFixture, + integration_input: Union[str, Dict[str, Any]], + integration_name_expected: str, + args_expected: Dict[str, Any], + error_expected: bool, ): get_integrations_spy = mocker.spy(integration_module, "get_integrations") - # SUT - if error_expected: - with pytest.raises(ValueError): - integration = sut.get_integration(integration_input) - else: - integration = sut.get_integration(integration_input) + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + integration = fake_controller.get_integration(integration_input) - # Checks - get_integrations_spy.assert_called_once_with(sut, args_expected) + if not error_expected: + get_integrations_spy.assert_called_once_with(fake_controller, args_expected) assert integration.name == integration_name_expected -def test_check_ad_version_throwing_error(sut, monkeypatch): - monkeypatch.setattr(sut, "get_ad_version", lambda: "3.0.0") +def test_check_ad_version_throwing_error(sut: Controller, mocker: MockerFixture): + mocker.patch.object(sut, "get_ad_version", return_value="3.0.0") with pytest.raises(ValueError) as e: sut.check_ad_version() assert str(e.value) == "Please upgrade to AppDaemon 4.x" @@ -289,20 +308,20 @@ def test_check_ad_version_throwing_error(sut, monkeypatch): def test_get_actions_mapping_happyflow(sut, monkeypatch, mocker): integration_mock = IntegrationMock("integration-test", sut, mocker) monkeypatch.setattr( - integration_mock, "get_actions_mapping", lambda: "this_is_mapping" + integration_mock, "get_actions_mapping", lambda: "this_is_a_mapping" ) mapping = sut.get_actions_mapping(integration_mock) - assert mapping == "this_is_mapping" + assert mapping == "this_is_a_mapping" -def test_get_actions_mapping_throwing_error(sut, monkeypatch, mocker): +def test_get_actions_mapping_throwing_error(sut: Controller, mocker: MockerFixture): integration_mock = IntegrationMock("integration-test", sut, mocker) - monkeypatch.setattr(integration_mock, "get_actions_mapping", lambda: None) + mocker.patch.object(integration_mock, "get_actions_mapping", return_value=None) with pytest.raises(ValueError) as e: - sut.get_actions_mapping(integration_mock) + sut.get_actions_mapping(integration_mock) # type: ignore assert str(e.value) == "This controller does not support integration-test." @@ -318,17 +337,18 @@ def test_get_actions_mapping_throwing_error(sut, monkeypatch, mocker): ) @pytest.mark.asyncio async def test_handle_action( - sut, - mocker, - actions_input, - action_called, - action_called_times, - action_delta, - expected_calls, + sut: Controller, + mocker: MockerFixture, + actions_input: List[ActionEvent], + action_called: str, + action_called_times: int, + action_delta: int, + expected_calls: int, ): sut.action_delta = action_delta sut.action_times = defaultdict(lambda: 0) - sut.actions_mapping = {action: None for action in actions_input} + actions_mapping: TypeActionsMapping = {action: "test" for action in actions_input} + sut.actions_mapping = actions_mapping call_action_patch = mocker.patch.object(sut, "call_action") # SUT @@ -349,19 +369,20 @@ async def test_handle_action( ) @pytest.mark.asyncio async def test_call_action( - sut, + sut: Controller, monkeypatch, - mocker, - delay, - handle, - cancel_timer_called, - run_in_called, - action_timer_callback_called, + mocker: MockerFixture, + delay: int, + handle: Optional[int], + cancel_timer_called: bool, + run_in_called: bool, + action_timer_callback_called: bool, ): action_key = "test" sut.actions_key_mapping = {"test": "test_action"} sut.action_delay = {action_key: delay} - sut.action_delay_handles = {action_key: handle} + action_delay_handles: Dict[ActionEvent, Optional[float]] = {action_key: handle} + sut.action_delay_handles = action_delay_handles monkeypatch.setattr(sut, "cancel_timer", fake_fn(async_=True)) monkeypatch.setattr(sut, "run_in", fake_fn(async_=True)) @@ -384,29 +405,32 @@ async def test_call_action( action_timer_callback_patch.assert_called_once_with({"action_key": action_key}) -def fake_action(): - pass - - @pytest.mark.parametrize( "test_input, expected, error_expected", [ - (fake_action, (fake_action,), False), - ((fake_action,), (fake_action,), False), - ((fake_action, "test"), (fake_action, "test"), False), + (fake_fn, (fake_fn,), False), + ((fake_fn,), (fake_fn,), False), + ((fake_fn, "test"), (fake_fn, "test"), False), ("not-list-or-function", (), True), ], ) -def test_get_action(sut, test_input, expected, error_expected): - if error_expected: - with pytest.raises(ValueError) as e: - output = sut.get_action(test_input) +def test_get_action( + sut: Controller, + test_input: TypeAction, + expected: ActionFunction, + error_expected: bool, +): + with wrap_exetuction( + error_expected=error_expected, exception=ValueError + ) as err_info: + output = sut.get_action(test_input) + + if err_info is not None: assert ( - str(e.value) + str(err_info.value) == "The action value from the action mapping should be a list or a function" ) else: - output = sut.get_action(test_input) assert output == expected @@ -415,12 +439,9 @@ def test_get_action(sut, test_input, expected, error_expected): [("test_service", {"attr1": 0.0, "attr2": "test"}), ("test_service", {})], ) @pytest.mark.asyncio -async def test_call_service(sut, mocker, service, attributes): - +async def test_call_service( + sut: Controller, mocker: MockerFixture, service: str, attributes: Dict[str, Any] +): call_service_stub = mocker.patch.object(hass.Hass, "call_service") - - # SUT await sut.call_service(service, **attributes) - - # Checker call_service_stub.assert_called_once_with(sut, service, **attributes) diff --git a/tests/unit_tests/cx_core/custom_controller_test.py b/tests/unit_tests/cx_core/custom_controller_test.py index eaf3d09c..f59dea90 100644 --- a/tests/unit_tests/cx_core/custom_controller_test.py +++ b/tests/unit_tests/cx_core/custom_controller_test.py @@ -1,73 +1,79 @@ -import pytest +from typing import Any, Dict, List, Tuple, Type +import pytest +from _pytest.monkeypatch import MonkeyPatch +from cx_const import TypeActionsMapping from cx_core import ( CallServiceController, Controller, - CustomLightController, - CustomMediaPlayerController, + CoverController, + LightController, + MediaPlayerController, + SwitchController, ) -from cx_core.custom_controller import CustomCoverController, CustomSwitchController +from cx_core.type_controller import TypeController +from pytest_mock.plugin import MockerFixture + from tests.test_utils import fake_fn @pytest.mark.parametrize( "custom_cls, mapping, action_input, mock_function, expected_calls", [ - (CustomLightController, {"action1": "on"}, "action1", "on", 1), - (CustomLightController, {"action1": "toggle"}, "action1", "toggle", 1), - (CustomLightController, {"action1": "off"}, "action1", "off", 1), + (LightController, {"action1": "on"}, "action1", "on", 1), + (LightController, {"action1": "toggle"}, "action1", "toggle", 1), + (LightController, {"action1": "off"}, "action1", "off", 1), ( - CustomLightController, + LightController, {"action1": "on_min_brightness"}, "action1", "on_min", 1, ), ( - CustomLightController, + LightController, {"action1": "hold_brightness_up"}, "action1", "hold", 1, ), ( - CustomLightController, + LightController, {"action1": "hold_brightness_up"}, "action2", "hold", 0, ), ( - CustomMediaPlayerController, + MediaPlayerController, {"action1": "play_pause"}, "action1", "play_pause", 1, ), ( - CustomMediaPlayerController, + MediaPlayerController, {"action1": "hold_volume_up"}, "action1", "hold", 1, ), - (CustomMediaPlayerController, {"action1": "release"}, "action1", "release", 1), - (CustomSwitchController, {"action1": "toggle"}, "action1", "toggle", 1), - (CustomCoverController, {"action1": "open"}, "action2", "open", 0), + (MediaPlayerController, {"action1": "release"}, "action1", "release", 1), + (SwitchController, {"action1": "toggle"}, "action1", "toggle", 1), + (CoverController, {"action1": "open"}, "action2", "open", 0), ], ) @pytest.mark.asyncio async def test_custom_controllers( - hass_mock, - monkeypatch, - mocker, - custom_cls, - mapping, - action_input, - mock_function, - expected_calls, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + custom_cls: Type[TypeController], + mapping: TypeActionsMapping, + action_input: str, + mock_function: str, + expected_calls: int, ): - sut = custom_cls() + sut = custom_cls() # type: ignore sut.args = { "controller": "test_controller", "integration": "z2m", @@ -78,13 +84,14 @@ async def test_custom_controllers( "mapping": mapping, } mocked = mocker.patch.object(sut, mock_function) - monkeypatch.setattr(sut, "get_entity_state", fake_fn(async_=True, to_return="0")) + # SUT await sut.initialize() sut.action_delta = 0 await sut.handle_action(action_input) + # Check assert mocked.call_count == expected_calls @@ -132,7 +139,11 @@ async def test_custom_controllers( ) @pytest.mark.asyncio async def test_call_service_controller( - hass_mock, monkeypatch, mocker, integration, services, expected_calls + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + integration: str, + services: List[Dict[str, Any]], + expected_calls: List[Tuple[str, Dict[str, Any]]], ): sut = CallServiceController() # type: ignore sut.args = { @@ -147,10 +158,12 @@ async def fake_call_service(self, service, **data): monkeypatch.setattr(Controller, "call_service", fake_call_service) + # SUT await sut.initialize() sut.action_delta = 0 await sut.handle_action("action") + # Checks assert call_service_stub.call_count == len(expected_calls) for expected_service, expected_data in expected_calls: call_service_stub.assert_any_call(expected_service, **expected_data) diff --git a/tests/unit_tests/cx_core/feature_support/cover_support_test.py b/tests/unit_tests/cx_core/feature_support/cover_support_test.py index 6fddb0dd..fcc8f256 100644 --- a/tests/unit_tests/cx_core/feature_support/cover_support_test.py +++ b/tests/unit_tests/cx_core/feature_support/cover_support_test.py @@ -1,6 +1,6 @@ -from cx_core.feature_support.cover import CoverSupport import pytest -from cx_core.feature_support import FeatureSupport +from cx_core.feature_support import FeatureSupport, SupportedFeatures +from cx_core.feature_support.cover import CoverSupport @pytest.mark.parametrize( @@ -28,9 +28,6 @@ (0, set()), ], ) -def test_init(number, expected_supported_features): - cover_support = CoverSupport("fake_entity", None, False) # type: ignore - cover_support._supported_features = FeatureSupport.decode( - number, cover_support.features - ) - assert cover_support._supported_features == expected_supported_features +def test_decode(number: int, expected_supported_features: SupportedFeatures): + supported_features = FeatureSupport.decode(number, CoverSupport.features) + assert supported_features == expected_supported_features diff --git a/tests/unit_tests/cx_core/feature_support/feature_support_test.py b/tests/unit_tests/cx_core/feature_support/feature_support_test.py index 6419e6ed..faccfe64 100644 --- a/tests/unit_tests/cx_core/feature_support/feature_support_test.py +++ b/tests/unit_tests/cx_core/feature_support/feature_support_test.py @@ -1,9 +1,10 @@ -from cx_core.feature_support import FeatureSupport import pytest +from cx_core.feature_support import Features, FeatureSupport, SupportedFeatures +from cx_core.type_controller import TypeController @pytest.mark.parametrize( - "number, features, expected_features", + "number, features, expected_supported_features", [ (15, [1, 2, 4, 8, 16, 32, 64], {1, 2, 4, 8}), (16, [1, 2, 4, 8, 16, 32, 64], {16}), @@ -11,9 +12,11 @@ (70, [1, 2, 4, 8, 16, 64], {2, 4, 64}), ], ) -def test_decode(number, features, expected_features): +def test_decode( + number: int, features: Features, expected_supported_features: SupportedFeatures +): supported_features = FeatureSupport.decode(number, features) - assert supported_features == expected_features + assert supported_features == expected_supported_features @pytest.mark.parametrize( @@ -26,13 +29,13 @@ def test_decode(number, features, expected_features): ({1, 2, 4, 8, 16, 64}, 95), ], ) -def test_encode(supported_features, expected_number): +def test_encode(supported_features: SupportedFeatures, expected_number: int): number = FeatureSupport.encode(supported_features) assert expected_number == number @pytest.mark.parametrize( - "number, features, feature, is_supported", + "number, features, feature, expected_is_supported", [ (15, [1, 2, 4, 8, 16, 32, 64], 16, False), (16, [1, 2, 4, 8, 16, 32, 64], 2, False), @@ -42,15 +45,22 @@ def test_encode(supported_features, expected_number): ], ) @pytest.mark.asyncio -async def test_is_supported(number, features, feature, is_supported): - feature_support = FeatureSupport("fake_entity", None, features, False) # type: ignore +async def test_is_supported( + fake_type_controller: TypeController, + number: int, + features: Features, + feature: int, + expected_is_supported: bool, +): + feature_support = FeatureSupport("fake_entity", fake_type_controller, False) + feature_support.features = features feature_support._supported_features = FeatureSupport.decode(number, features) is_supported = await feature_support.is_supported(feature) - assert is_supported == is_supported + assert is_supported == expected_is_supported @pytest.mark.parametrize( - "number, features, feature, is_supported", + "number, features, feature, expected_is_supported", [ (15, [1, 2, 4, 8, 16, 32, 64], 16, True), (16, [1, 2, 4, 8, 16, 32, 64], 2, True), @@ -60,8 +70,15 @@ async def test_is_supported(number, features, feature, is_supported): ], ) @pytest.mark.asyncio -async def test_not_supported(number, features, feature, is_supported): - feature_support = FeatureSupport("fake_entity", None, features, False) # type: ignore +async def test_not_supported( + fake_type_controller: TypeController, + number: int, + features: Features, + feature: int, + expected_is_supported: bool, +): + feature_support = FeatureSupport("fake_entity", fake_type_controller, False) + feature_support.features = features feature_support._supported_features = FeatureSupport.decode(number, features) is_supported = await feature_support.not_supported(feature) - assert is_supported == is_supported + assert is_supported == expected_is_supported diff --git a/tests/unit_tests/cx_core/feature_support/light_support_test.py b/tests/unit_tests/cx_core/feature_support/light_support_test.py index 81737e9e..77bf03d1 100644 --- a/tests/unit_tests/cx_core/feature_support/light_support_test.py +++ b/tests/unit_tests/cx_core/feature_support/light_support_test.py @@ -1,6 +1,5 @@ import pytest - -from cx_core.feature_support import FeatureSupport +from cx_core.feature_support import FeatureSupport, SupportedFeatures from cx_core.feature_support.light import LightSupport @@ -29,9 +28,6 @@ (0, set()), ], ) -def test_init(number, expected_supported_features): - light_support = LightSupport("fake_entity", None, False) # type: ignore - light_support._supported_features = FeatureSupport.decode( - number, light_support.features - ) - assert light_support._supported_features == expected_supported_features +def test_decode(number: int, expected_supported_features: SupportedFeatures): + supported_features = FeatureSupport.decode(number, LightSupport.features) + assert supported_features == expected_supported_features diff --git a/tests/unit_tests/cx_core/feature_support/media_player_support_test.py b/tests/unit_tests/cx_core/feature_support/media_player_support_test.py index 29f29d3e..ed490ecc 100644 --- a/tests/unit_tests/cx_core/feature_support/media_player_support_test.py +++ b/tests/unit_tests/cx_core/feature_support/media_player_support_test.py @@ -1,6 +1,5 @@ import pytest - -from cx_core.feature_support import FeatureSupport +from cx_core.feature_support import FeatureSupport, SupportedFeatures from cx_core.feature_support.media_player import MediaPlayerSupport @@ -30,9 +29,6 @@ (0, set()), ], ) -def test_init(number, expected_supported_features): - media_player_support = MediaPlayerSupport("fake_entity", None, False) # type: ignore - media_player_support._supported_features = FeatureSupport.decode( - number, media_player_support.features - ) - assert media_player_support._supported_features == expected_supported_features +def test_decode(number: int, expected_supported_features: SupportedFeatures): + supported_features = FeatureSupport.decode(number, MediaPlayerSupport.features) + assert supported_features == expected_supported_features diff --git a/tests/unit_tests/cx_core/integration/integration_test.py b/tests/unit_tests/cx_core/integration/integration_test.py index f5ebe10b..8f3f01c7 100644 --- a/tests/unit_tests/cx_core/integration/integration_test.py +++ b/tests/unit_tests/cx_core/integration/integration_test.py @@ -1,7 +1,8 @@ from cx_core import integration as integration_module +from cx_core.controller import Controller -def test_get_integrations(fake_controller): +def test_get_integrations(fake_controller: Controller): integrations = integration_module.get_integrations(fake_controller, {}) inteagration_names = {i.name for i in integrations} assert inteagration_names == {"z2m", "zha", "deconz", "state", "mqtt"} diff --git a/tests/unit_tests/cx_core/integration/z2m_test.py b/tests/unit_tests/cx_core/integration/z2m_test.py index 373545d6..1c41088e 100644 --- a/tests/unit_tests/cx_core/integration/z2m_test.py +++ b/tests/unit_tests/cx_core/integration/z2m_test.py @@ -1,7 +1,9 @@ -from typing import Any -import pytest +from typing import Any, Dict +import pytest +from cx_core.controller import Controller from cx_core.integration.z2m import Z2MIntegration +from pytest_mock import MockerFixture @pytest.mark.parametrize( @@ -10,18 +12,18 @@ ({"payload": '{"event_1": "action_1"}'}, "event_1", True, "action_1"), ({}, None, False, Any), ({"payload": '{"action": "action_1"}'}, None, True, "action_1"), - ({"payload": '{"event_1": "action_1"}'}, "event_2", False, Any), - ({"payload": '{"action_rate": 195}'}, "action", False, Any), + ({"payload": '{"event_1": "action_1"}'}, "event_2", False, "Any"), + ({"payload": '{"action_rate": 195}'}, "action", False, "Any"), ], ) @pytest.mark.asyncio async def test_event_callback( - fake_controller, - mocker, - data, - action_key, - handle_action_called, - expected_called_with, + fake_controller: Controller, + mocker: MockerFixture, + data: Dict, + action_key: str, + handle_action_called: bool, + expected_called_with: str, ): handle_action_patch = mocker.patch.object(fake_controller, "handle_action") z2m_integration = Z2MIntegration(fake_controller, {}) diff --git a/tests/unit_tests/cx_core/integration/zha_test.py b/tests/unit_tests/cx_core/integration/zha_test.py index 5a78e87c..a7c004f4 100644 --- a/tests/unit_tests/cx_core/integration/zha_test.py +++ b/tests/unit_tests/cx_core/integration/zha_test.py @@ -1,6 +1,9 @@ -import pytest +from typing import Dict +import pytest +from cx_core.controller import Controller from cx_core.integration.zha import ZHAIntegration +from pytest_mock.plugin import MockerFixture @pytest.mark.parametrize( @@ -31,7 +34,11 @@ ) @pytest.mark.asyncio async def test_get_integrations( - fake_controller, mocker, command, args, expected_called_with + fake_controller: Controller, + mocker: MockerFixture, + command: str, + args: Dict, + expected_called_with: str, ): data = {"command": command, "args": args} handle_action_patch = mocker.patch.object(fake_controller, "handle_action") diff --git a/tests/unit_tests/cx_core/release_hold_controller_test.py b/tests/unit_tests/cx_core/release_hold_controller_test.py index 4c464558..4f2a512a 100644 --- a/tests/unit_tests/cx_core/release_hold_controller_test.py +++ b/tests/unit_tests/cx_core/release_hold_controller_test.py @@ -1,6 +1,8 @@ import pytest - +from _pytest.monkeypatch import MonkeyPatch from cx_core.controller import Controller, ReleaseHoldController +from pytest_mock import MockerFixture + from tests.test_utils import fake_fn @@ -8,35 +10,38 @@ class FakeReleaseHoldController(ReleaseHoldController): def hold_loop(self): pass + def default_delay(self) -> int: + return 500 + @pytest.fixture -def sut(hass_mock): - c = FakeReleaseHoldController() # type: ignore - c.args = {} - c.delay = 0 - c.hold_release_toggle = False - return c +def sut_before_init(mocker: MockerFixture) -> FakeReleaseHoldController: + controller = FakeReleaseHoldController() # type: ignore + controller.args = {} + mocker.patch.object(Controller, "initialize") + mocker.patch.object(controller, "sleep") + return controller +@pytest.fixture @pytest.mark.asyncio -async def test_initialize(sut, monkeypatch): - monkeypatch.setattr(Controller, "initialize", fake_fn(async_=True)) - monkeypatch.setattr(sut, "default_delay", lambda: 500) - monkeypatch.setattr(sut, "sleep", lambda time: None) - # SUT - await sut.initialize() +async def sut(sut_before_init: FakeReleaseHoldController) -> FakeReleaseHoldController: + await sut_before_init.initialize() + return sut_before_init - assert sut.delay == 500 + +@pytest.mark.asyncio +async def test_initialize( + sut_before_init: FakeReleaseHoldController, mocker: MockerFixture +): + await sut_before_init.initialize() + assert sut_before_init.delay == 500 @pytest.mark.asyncio -async def test_release(sut): +async def test_release(sut: FakeReleaseHoldController): sut.on_hold = True - - # SUT await sut.release() - - # Checks assert not sut.on_hold @@ -46,19 +51,18 @@ async def test_release(sut): ) @pytest.mark.asyncio async def test_hold( - sut, monkeypatch, mocker, on_hold_input, hold_release_toogle, expected_calls + sut: FakeReleaseHoldController, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + on_hold_input: bool, + hold_release_toogle: bool, + expected_calls: int, ): sut.on_hold = on_hold_input sut.hold_release_toggle = hold_release_toogle - - async def fake_hold_loop(): - return True - - monkeypatch.setattr(sut, "hold_loop", fake_hold_loop) + monkeypatch.setattr(sut, "hold_loop", fake_fn(to_return=True, async_=True)) hold_loop_patch = mocker.patch.object(sut, "hold_loop") - # SUT await sut.hold() - # Checks assert hold_loop_patch.call_count == expected_calls diff --git a/tests/unit_tests/cx_core/stepper/circular_stepper_test.py b/tests/unit_tests/cx_core/stepper/circular_stepper_test.py index 94eb179f..bf6159e4 100644 --- a/tests/unit_tests/cx_core/stepper/circular_stepper_test.py +++ b/tests/unit_tests/cx_core/stepper/circular_stepper_test.py @@ -1,7 +1,9 @@ -import pytest +from typing import Tuple -from cx_core.stepper.circular_stepper import CircularStepper +import pytest from cx_core.stepper import Stepper +from cx_core.stepper.circular_stepper import CircularStepper +from typing_extensions import Literal @pytest.mark.parametrize( @@ -19,11 +21,13 @@ ((0, 10), 4, 5, Stepper.UP, 6), ], ) -def test_minmax_stepper(minmax, value, steps, direction, expected_value): +def test_minmax_stepper( + minmax: Tuple[int, int], + value: int, + steps: int, + direction: Literal["up", "down"], + expected_value: int, +): stepper = CircularStepper(*minmax, steps) - - # SUT new_value, _ = stepper.step(value, direction) - - # Checks assert new_value == expected_value diff --git a/tests/unit_tests/cx_core/stepper/minmax_stepper_test.py b/tests/unit_tests/cx_core/stepper/minmax_stepper_test.py index 27be1f32..a8d6709b 100644 --- a/tests/unit_tests/cx_core/stepper/minmax_stepper_test.py +++ b/tests/unit_tests/cx_core/stepper/minmax_stepper_test.py @@ -1,7 +1,9 @@ -import pytest +from typing import Tuple -from cx_core.stepper.minmax_stepper import MinMaxStepper +import pytest from cx_core.stepper import Stepper +from cx_core.stepper.minmax_stepper import MinMaxStepper +from typing_extensions import Literal @pytest.mark.parametrize( @@ -103,12 +105,12 @@ ], ) def test_minmax_stepper_get_direction( - minmax, - value, - direction, - previous_direction, - expected_direction, - expected_new_previous_direction, + minmax: Tuple[int, int], + value: int, + direction: str, + previous_direction: str, + expected_direction: str, + expected_new_previous_direction: str, ): stepper = MinMaxStepper(*minmax, 10) stepper.previous_direction = previous_direction @@ -137,13 +139,16 @@ def test_minmax_stepper_get_direction( ], ) def test_minmax_stepper_step( - minmax, value, steps, direction, expected_value, expected_exceeded + minmax: Tuple[int, int], + value: int, + steps: int, + direction: Literal["up", "down"], + expected_value: int, + expected_exceeded: bool, ): stepper = MinMaxStepper(*minmax, steps) - # SUT new_value, exceeded = stepper.step(value, direction) - # Checks assert new_value == expected_value assert exceeded == expected_exceeded diff --git a/tests/unit_tests/cx_core/stepper/stepper_test.py b/tests/unit_tests/cx_core/stepper/stepper_test.py index 5b042d2c..d0a44a57 100644 --- a/tests/unit_tests/cx_core/stepper/stepper_test.py +++ b/tests/unit_tests/cx_core/stepper/stepper_test.py @@ -1,11 +1,12 @@ -import pytest +from typing import Tuple, Union +import pytest from cx_core.stepper import Stepper class FakeStepper(Stepper): - def step(self, value, direction): - pass + def step(self, value: float, direction: str) -> Tuple[Union[int, float], bool]: + return 0, True @pytest.mark.parametrize( @@ -19,14 +20,14 @@ def step(self, value, direction): (Stepper.TOGGLE, Stepper.TOGGLE_DOWN, Stepper.TOGGLE_UP), ], ) -def test_get_direction(direction_input, previous_direction, expected_direction): +def test_get_direction( + direction_input: str, previous_direction: str, expected_direction: str +): stepper = FakeStepper() stepper.previous_direction = previous_direction - # SUT direction_output = stepper.get_direction(0, direction_input) - # Checks assert direction_output == expected_direction @@ -39,11 +40,7 @@ def test_get_direction(direction_input, previous_direction, expected_direction): (Stepper.TOGGLE_DOWN, -1), ], ) -def test_sign(direction_input, expected_sign): +def test_sign(direction_input: str, expected_sign: int): stepper = FakeStepper() - - # SUT sign_output = stepper.sign(direction_input) - - # Checks assert sign_output == expected_sign diff --git a/tests/unit_tests/cx_core/type/cover_controller_test.py b/tests/unit_tests/cx_core/type/cover_controller_test.py index 4693197f..2e19ef97 100644 --- a/tests/unit_tests/cx_core/type/cover_controller_test.py +++ b/tests/unit_tests/cx_core/type/cover_controller_test.py @@ -1,20 +1,34 @@ -from cx_core.feature_support.cover import CoverSupport -import pytest +from typing import Any, Dict, Set -from cx_core.controller import TypeController +import pytest +from _pytest.monkeypatch import MonkeyPatch from cx_core import CoverController -from tests.test_utils import fake_fn +from cx_core.controller import Controller +from cx_core.feature_support.cover import CoverSupport +from cx_core.type_controller import TypeController +from pytest_mock.plugin import MockerFixture + +from tests.test_utils import fake_fn, wrap_exetuction + +ENTITY_NAME = "cover.test" @pytest.fixture @pytest.mark.asyncio -async def sut(hass_mock, mocker): - c = CoverController() # type: ignore +async def sut_before_init(mocker: MockerFixture) -> CoverController: + controller = CoverController() # type: ignore mocker.patch.object(TypeController, "initialize") - c.cover = "cover.test" - c.open_position = 100 - c.close_position = 0 - return c + return controller + + +@pytest.fixture +@pytest.mark.asyncio +async def sut(mocker: MockerFixture) -> CoverController: + controller = CoverController() # type: ignore + mocker.patch.object(Controller, "initialize") + controller.args = {"cover": ENTITY_NAME} + await controller.initialize() + return controller @pytest.mark.parametrize( @@ -29,20 +43,22 @@ async def sut(hass_mock, mocker): ) @pytest.mark.asyncio async def test_initialize( - sut, monkeypatch, open_position, close_position, error_expected + sut_before_init: CoverController, + open_position: int, + close_position: int, + error_expected: bool, ): - sut.args = { - "cover": "cover.test2", + sut_before_init.args = { "open_position": open_position, "close_position": close_position, } - monkeypatch.setattr(sut, "get_entity_state", fake_fn(async_=True, to_return="0")) - if error_expected: - with pytest.raises(ValueError): - await sut.initialize() - else: - await sut.initialize() - assert sut.cover == "cover.test2" + + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + await sut_before_init.initialize() + + if not error_expected: + assert sut_before_init.open_position == open_position + assert sut_before_init.close_position == close_position @pytest.mark.parametrize( @@ -59,12 +75,19 @@ async def test_initialize( ], ) @pytest.mark.asyncio -async def test_open(sut, mocker, supported_features, expected_service): - sut.supported_features = CoverSupport(sut.cover, sut, False) - sut.supported_features._supported_features = set(supported_features) +async def test_open( + sut: CoverController, + mocker: MockerFixture, + supported_features: Set[int], + expected_service: str, +): + sut.feature_support._supported_features = set(supported_features) called_service_patch = mocker.patch.object(sut, "call_service") + await sut.open() + if expected_service is not None: + expected_attributes: Dict[str, Any] if expected_service == "cover/open_cover": expected_attributes = {"entity_id": "cover.test"} elif expected_service == "cover/set_cover_position": @@ -92,12 +115,19 @@ async def test_open(sut, mocker, supported_features, expected_service): ], ) @pytest.mark.asyncio -async def test_close(sut, mocker, supported_features, expected_service): - sut.supported_features = CoverSupport(sut.cover, sut, False) - sut.supported_features._supported_features = set(supported_features) +async def test_close( + sut: CoverController, + mocker: MockerFixture, + supported_features: Set[int], + expected_service: str, +): + sut.feature_support._supported_features = set(supported_features) called_service_patch = mocker.patch.object(sut, "call_service") + await sut.close() + if expected_service is not None: + expected_attributes: Dict[str, Any] if expected_service == "cover/close_cover": expected_attributes = {"entity_id": "cover.test"} elif expected_service == "cover/set_cover_position": @@ -112,11 +142,13 @@ async def test_close(sut, mocker, supported_features, expected_service): @pytest.mark.asyncio -async def test_stop(sut, mocker): +async def test_stop(sut: CoverController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.stop() + called_service_patch.assert_called_once_with( - "cover/stop_cover", entity_id=sut.cover + "cover/stop_cover", entity_id=ENTITY_NAME ) @@ -125,16 +157,25 @@ async def test_stop(sut, mocker): [("opening", True), ("closing", True), ("open", False), ("close", False)], ) @pytest.mark.asyncio -async def test_toggle(sut, monkeypatch, mocker, cover_state, stop_expected): +async def test_toggle( + sut: CoverController, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + cover_state: str, + stop_expected: bool, +): called_service_patch = mocker.patch.object(sut, "call_service") open_patch = mocker.patch.object(sut, "open") monkeypatch.setattr( sut, "get_entity_state", fake_fn(async_=True, to_return=cover_state) ) + await sut.toggle(open_patch) + if stop_expected: called_service_patch.assert_called_once_with( - "cover/stop_cover", entity_id=sut.cover + "cover/stop_cover", entity_id=ENTITY_NAME ) + open_patch.assert_not_called() else: open_patch.assert_called_once() diff --git a/tests/unit_tests/cx_core/type/light_controller_test.py b/tests/unit_tests/cx_core/type/light_controller_test.py index 6f8f4851..77a50a80 100644 --- a/tests/unit_tests/cx_core/type/light_controller_test.py +++ b/tests/unit_tests/cx_core/type/light_controller_test.py @@ -1,75 +1,88 @@ -from cx_core.color_helper import get_color_wheel -import pytest +from typing import Any, Dict, Set, Tuple, Type, Union +import pytest +from _pytest.monkeypatch import MonkeyPatch from cx_core import LightController, ReleaseHoldController +from cx_core.controller import Controller from cx_core.feature_support.light import LightSupport from cx_core.stepper import Stepper from cx_core.stepper.circular_stepper import CircularStepper from cx_core.stepper.minmax_stepper import MinMaxStepper -from tests.test_utils import fake_fn +from cx_core.type.light_controller import ColorMode, LightEntity +from pytest_mock.plugin import MockerFixture +from typing_extensions import Literal + +from tests.test_utils import fake_fn, wrap_exetuction + +ENTITY_NAME = "light.test" @pytest.fixture -def sut(hass_mock, monkeypatch): - c = LightController() # type: ignore - c.args = {} - c.delay = 0 - c.light = {"name": "light"} - c.on_hold = False +@pytest.mark.asyncio +async def sut_before_init(mocker: MockerFixture) -> LightController: + controller = LightController() # type: ignore + controller.args = {} + mocker.patch.object(Controller, "initialize") + return controller - monkeypatch.setattr(c, "get_entity_state", fake_fn(async_=True, to_return="0")) - return c + +@pytest.fixture +@pytest.mark.asyncio +async def sut(mocker: MockerFixture) -> LightController: + controller = LightController() # type: ignore + mocker.patch.object(Controller, "initialize") + controller.args = {"light": ENTITY_NAME} + await controller.initialize() + return controller @pytest.mark.parametrize( - "light_input, light_output, error_expected", + "light_input, expected_name, expected_color_mode, error_expected", [ - ("light.kitchen", {"name": "light.kitchen", "color_mode": "auto"}, False), + ("light.kitchen", "light.kitchen", "auto", False), ( {"name": "light.kitchen", "color_mode": "auto"}, - {"name": "light.kitchen", "color_mode": "auto"}, + "light.kitchen", + "auto", False, ), ( {"name": "light.kitchen"}, - {"name": "light.kitchen", "color_mode": "auto"}, + "light.kitchen", + "auto", False, ), ( {"name": "light.kitchen", "color_mode": "color_temp"}, - {"name": "light.kitchen", "color_mode": "color_temp"}, + "light.kitchen", + "color_temp", False, ), - (0.0, None, True), + (0.0, None, None, True), ], ) @pytest.mark.asyncio -async def test_initialize_and_get_light( - sut, monkeypatch, mocker, light_input, light_output, error_expected +async def test_initialize( + sut_before_init: LightController, + light_input: Union[str, Dict[str, str]], + expected_name: str, + expected_color_mode: str, + error_expected: bool, ): - super_initialize_stub = mocker.stub() - - async def fake_super_initialize(self): - super_initialize_stub() - - monkeypatch.setattr(ReleaseHoldController, "initialize", fake_super_initialize) - - sut.args["light"] = light_input + sut_before_init.args["light"] = light_input # SUT - if error_expected: - with pytest.raises(ValueError): - await sut.initialize() - else: - await sut.initialize() + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + await sut_before_init.initialize() - # Checks - super_initialize_stub.assert_called_once() - assert sut.light == light_output + # Checks + if not error_expected: + assert sut_before_init.entity.name == expected_name + assert sut_before_init.entity.color_mode == expected_color_mode @pytest.mark.parametrize( - "attribute_input, color_mode, supported_features, attribute_expected, throws_error", + "attribute_input, color_mode, supported_features, expected_attribute, error_expected", [ ("color", "auto", {LightSupport.COLOR}, "xy_color", False), ("color", "auto", {LightSupport.COLOR_TEMP}, "color_temp", False), @@ -102,53 +115,47 @@ async def fake_super_initialize(self): @pytest.mark.asyncio async def test_get_attribute( sut: LightController, - monkeypatch, - attribute_input, - color_mode, - supported_features, - attribute_expected, - throws_error, + attribute_input: str, + color_mode: ColorMode, + supported_features: Set[int], + expected_attribute: str, + error_expected: bool, ): - sut.supported_features = LightSupport("fake_entity", sut, False) - sut.supported_features._supported_features = supported_features - sut.light = {"name": "light", "color_mode": color_mode} + sut.feature_support._supported_features = supported_features + sut.entity = LightEntity(name=ENTITY_NAME, color_mode=color_mode) - # SUT - if throws_error: - with pytest.raises(ValueError): - await sut.get_attribute(attribute_input) - else: + with wrap_exetuction(error_expected=error_expected, exception=ValueError): output = await sut.get_attribute(attribute_input) - # Checks - assert output == attribute_expected + if not error_expected: + assert output == expected_attribute @pytest.mark.parametrize( "attribute_input, direction_input, light_state, expected_output, error_expected", [ - ("xy_color", None, None, 0, False), - ("brightness", None, None, 3.0, False), - ("brightness", None, None, "3.0", False), - ("brightness", None, None, "3", False), - ("brightness", None, None, "error", True), - ("color_temp", None, None, 1, False), - ("xy_color", None, None, 0, False), - ("brightness", None, None, None, True), - ("color_temp", None, None, None, True), - ("not_a_valid_attribute", None, None, None, True), + ("xy_color", Stepper.DOWN, "any", 0, False), + ("brightness", Stepper.DOWN, "any", 3.0, False), + ("brightness", Stepper.DOWN, "any", "3.0", False), + ("brightness", Stepper.DOWN, "any", "3", False), + ("color_temp", Stepper.DOWN, "any", 1, False), + ("xy_color", Stepper.DOWN, "any", 0, False), ("brightness", Stepper.UP, "off", 0, False), + ("brightness", Stepper.DOWN, "any", "error", True), + ("brightness", Stepper.DOWN, "any", None, True), + ("color_temp", Stepper.DOWN, "any", None, True), + ("not_a_valid_attribute", Stepper.DOWN, "any", None, True), ], ) @pytest.mark.asyncio async def test_get_value_attribute( - sut, - monkeypatch, - attribute_input, - direction_input, - light_state, - expected_output, - error_expected, + sut: LightController, + monkeypatch: MonkeyPatch, + attribute_input: str, + direction_input: Literal["up", "down"], + light_state: str, + expected_output: Union[int, float, str], + error_expected: bool, ): sut.smooth_power_on = True @@ -159,19 +166,15 @@ async def fake_get_entity_state(entity, attribute=None): monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) - # SUT - if error_expected: - with pytest.raises(ValueError): - await sut.get_value_attribute(attribute_input, direction_input) - else: + with wrap_exetuction(error_expected=error_expected, exception=ValueError): output = await sut.get_value_attribute(attribute_input, direction_input) - # Checks + if not error_expected: assert output == float(expected_output) @pytest.mark.parametrize( - "old, attribute, direction, stepper, light_state, smooth_power_on, expected_stop, expected_value_attribute", + "old, attribute, direction, stepper, light_state, smooth_power_on, stop_expected, expected_value_attribute", [ ( 50, @@ -209,39 +212,30 @@ async def fake_get_entity_state(entity, attribute=None): @pytest.mark.asyncio async def test_change_light_state( sut: LightController, - mocker, - monkeypatch, - old, - attribute, - direction, - stepper, - light_state, - smooth_power_on, - expected_stop, - expected_value_attribute, + mocker: MockerFixture, + monkeypatch: MonkeyPatch, + old: int, + attribute: str, + direction: Literal["up", "down"], + stepper: MinMaxStepper, + light_state: str, + smooth_power_on: bool, + stop_expected: bool, + expected_value_attribute: int, ): - async def fake_get_entity_state(*args, **kwargs): - return light_state - called_service_patch = mocker.patch.object(sut, "call_service") sut.smooth_power_on = smooth_power_on sut.value_attribute = old sut.manual_steppers = {attribute: stepper} sut.automatic_steppers = {attribute: stepper} - sut.transition = 300 - sut.add_transition = True - sut.add_transition_turn_toggle = False - sut.supported_features = LightSupport("fake_entity", sut, False) - sut.supported_features._supported_features = set() - sut.color_wheel = get_color_wheel("default_color_wheel") - - monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) + sut.feature_support._supported_features = set() + monkeypatch.setattr( + sut, "get_entity_state", fake_fn(to_return=light_state, async_=True) + ) - # SUT stop = await sut.change_light_state(old, attribute, direction, stepper, "hold") - # Checks - assert stop == expected_stop + assert stop == stop_expected assert sut.value_attribute == expected_value_attribute called_service_patch.assert_called() @@ -288,26 +282,25 @@ async def fake_get_entity_state(*args, **kwargs): @pytest.mark.asyncio async def test_call_light_service( sut: LightController, - mocker, - attributes_input, - transition_support, - turned_toggle, - add_transition, - add_transition_turn_toggle, - attributes_expected, + mocker: MockerFixture, + attributes_input: Dict[str, str], + transition_support: bool, + turned_toggle: bool, + add_transition: bool, + add_transition_turn_toggle: bool, + attributes_expected: Dict[str, str], ): called_service_patch = mocker.patch.object(sut, "call_service") sut.transition = 300 sut.add_transition = add_transition sut.add_transition_turn_toggle = add_transition_turn_toggle supported_features = {LightSupport.TRANSITION} if transition_support else set() - sut.supported_features = LightSupport("fake_entity", sut, False) - sut.supported_features._supported_features = supported_features + sut.feature_support._supported_features = supported_features await sut.call_light_service( "test_service", turned_toggle=turned_toggle, **attributes_input ) called_service_patch.assert_called_once_with( - "test_service", entity_id=sut.light["name"], **attributes_expected + "test_service", entity_id=ENTITY_NAME, **attributes_expected ) @@ -317,9 +310,12 @@ async def test_call_light_service( ) @pytest.mark.asyncio async def test_on( - sut, mocker, monkeypatch, light_on, light_state, expected_turned_toggle + sut: LightController, + mocker: MockerFixture, + light_on: bool, + light_state: str, + expected_turned_toggle: bool, ): - monkeypatch.setattr(sut, "call_light_service", fake_fn(async_=True)) mocker.patch.object( sut, "get_entity_state", fake_fn(async_=True, to_return=light_state) ) @@ -327,30 +323,31 @@ async def test_on( attributes = {"test": 0} await sut.on(light_on=light_on, **attributes) + call_light_service_patch.assert_called_once_with( "light/turn_on", turned_toggle=expected_turned_toggle, **attributes ) @pytest.mark.asyncio -async def test_off(sut, mocker, monkeypatch): - monkeypatch.setattr(sut, "call_light_service", fake_fn(async_=True)) +async def test_off(sut: LightController, mocker: MockerFixture): call_light_service_patch = mocker.patch.object(sut, "call_light_service") attributes = {"test": 0} await sut.off(**attributes) + call_light_service_patch.assert_called_once_with( "light/turn_off", turned_toggle=True, **attributes ) @pytest.mark.asyncio -async def test_toggle(sut, mocker, monkeypatch): - monkeypatch.setattr(sut, "call_light_service", fake_fn(async_=True)) +async def test_toggle(sut: LightController, mocker: MockerFixture): call_light_service_patch = mocker.patch.object(sut, "call_light_service") attributes = {"test": 0} await sut.toggle(**attributes) + call_light_service_patch.assert_called_once_with( "light/toggle", turned_toggle=True, **attributes ) @@ -365,17 +362,21 @@ async def test_toggle(sut, mocker, monkeypatch): ], ) @pytest.mark.asyncio -async def test_toggle_full(sut, mocker, attribute, stepper, expected_attribute_value): - sut.light = {"name": "test_light"} - sut.transition = 300 - sut.add_transition = False +async def test_toggle_full( + sut: LightController, + mocker: MockerFixture, + attribute: str, + stepper: MinMaxStepper, + expected_attribute_value: int, +): call_service_patch = mocker.patch.object(sut, "call_service") sut.automatic_steppers = {attribute: stepper} await sut.toggle_full(attribute) + call_service_patch.assert_called_once_with( "light/toggle", - **{"entity_id": "test_light", attribute: expected_attribute_value} + **{"entity_id": ENTITY_NAME, attribute: expected_attribute_value} ) @@ -388,17 +389,21 @@ async def test_toggle_full(sut, mocker, attribute, stepper, expected_attribute_v ], ) @pytest.mark.asyncio -async def test_toggle_min(sut, mocker, attribute, stepper, expected_attribute_value): - sut.light = {"name": "test_light"} - sut.transition = 300 - sut.add_transition = False +async def test_toggle_min( + sut: LightController, + mocker: MockerFixture, + attribute: str, + stepper: MinMaxStepper, + expected_attribute_value: int, +): call_service_patch = mocker.patch.object(sut, "call_service") sut.automatic_steppers = {attribute: stepper} await sut.toggle_min(attribute) + call_service_patch.assert_called_once_with( "light/toggle", - **{"entity_id": "test_light", attribute: expected_attribute_value} + **{"entity_id": ENTITY_NAME, attribute: expected_attribute_value} ) @@ -416,49 +421,49 @@ async def test_toggle_min(sut, mocker, attribute, stepper, expected_attribute_va ) @pytest.mark.asyncio async def test_set_value( - sut, mocker, stepper_cls, min_max, fraction, expected_calls, expected_value + sut: LightController, + mocker: MockerFixture, + stepper_cls: Type[Union[MinMaxStepper, CircularStepper]], + min_max: Tuple[int, int], + fraction: float, + expected_calls: int, + expected_value: int, ): attribute = "test_attribute" on_patch = mocker.patch.object(sut, "on") stepper = stepper_cls(min_max[0], min_max[1], 1) sut.automatic_steppers = {attribute: stepper} - # SUT await sut.set_value(attribute, fraction, light_on=False) - # Checks assert on_patch.call_count == expected_calls if expected_calls > 0: on_patch.assert_called_with(light_on=False, **{attribute: expected_value}) @pytest.mark.asyncio -async def test_on_full(sut, mocker): +async def test_on_full(sut: LightController, mocker: MockerFixture): attribute = "test_attribute" max_ = 10 on_patch = mocker.patch.object(sut, "on") stepper = MinMaxStepper(1, max_, 10) sut.automatic_steppers = {attribute: stepper} - # SUT await sut.on_full(attribute, light_on=False) - # Checks on_patch.assert_called_once_with(light_on=False, **{attribute: max_}) @pytest.mark.asyncio -async def test_on_min(sut, mocker): +async def test_on_min(sut: LightController, mocker: MockerFixture): attribute = "test_attribute" min_ = 1 on_patch = mocker.patch.object(sut, "on") stepper = MinMaxStepper(min_, 10, 10) sut.automatic_steppers = {attribute: stepper} - # SUT await sut.on_min(attribute, light_on=False) - # Checks on_patch.assert_called_once_with(light_on=False, **{attribute: min_}) @@ -473,19 +478,15 @@ async def test_on_min(sut, mocker): @pytest.mark.asyncio async def test_sync( sut: LightController, - monkeypatch, - mocker, - max_brightness, - color_attribute, - expected_attributes, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + max_brightness: int, + color_attribute: str, + expected_attributes: Dict[str, Any], ): sut.max_brightness = max_brightness - sut.light = {"name": "test_light"} - sut.transition = 300 - sut.add_transition = True sut.add_transition_turn_toggle = True - sut.supported_features = LightSupport("fake_entity", sut, False) - sut.supported_features._supported_features = {LightSupport.TRANSITION} + sut.feature_support._supported_features = {LightSupport.TRANSITION} async def fake_get_attribute(*args, **kwargs): if color_attribute == "error": @@ -493,13 +494,14 @@ async def fake_get_attribute(*args, **kwargs): return color_attribute monkeypatch.setattr(sut, "get_attribute", fake_get_attribute) + monkeypatch.setattr(sut, "get_entity_state", fake_fn(async_=True, to_return="on")) called_service_patch = mocker.patch.object(sut, "call_service") await sut.sync() called_service_patch.assert_called_once_with( "light/turn_on", - entity_id="test_light", + entity_id=ENTITY_NAME, **{"transition": 0.3, **expected_attributes} ) @@ -508,44 +510,38 @@ async def fake_get_attribute(*args, **kwargs): "attribute_input, direction_input, light_state, smooth_power_on, expected_calls", [ (LightController.ATTRIBUTE_BRIGHTNESS, Stepper.UP, "off", True, 1), - ("color_temp", Stepper.UP, "off", True, 0), - ("color_temp", Stepper.UP, "on", True, 1), + (LightController.ATTRIBUTE_COLOR_TEMP, Stepper.UP, "off", True, 0), + (LightController.ATTRIBUTE_COLOR_TEMP, Stepper.UP, "on", True, 1), ], ) @pytest.mark.asyncio async def test_click( - sut, - monkeypatch, - mocker, - attribute_input, - direction_input, - light_state, - smooth_power_on, - expected_calls, + sut: LightController, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + attribute_input: str, + direction_input: Literal["up", "down"], + light_state: Literal["on", "off"], + smooth_power_on: bool, + expected_calls: int, ): value_attribute = 10 - - async def fake_get_entity_state(*args, **kwargs): - return light_state - - async def fake_get_value_attribute(*args, **kwargs): - return value_attribute - - async def fake_get_attribute(*args, **kwargs): - return attribute_input - - monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) - monkeypatch.setattr(sut, "get_value_attribute", fake_get_value_attribute) - monkeypatch.setattr(sut, "get_attribute", fake_get_attribute) + monkeypatch.setattr( + sut, "get_entity_state", fake_fn(to_return=light_state, async_=True) + ) + monkeypatch.setattr( + sut, "get_value_attribute", fake_fn(to_return=value_attribute, async_=True) + ) + monkeypatch.setattr( + sut, "get_attribute", fake_fn(to_return=attribute_input, async_=True) + ) change_light_state_patch = mocker.patch.object(sut, "change_light_state") sut.smooth_power_on = smooth_power_on stepper = MinMaxStepper(1, 10, 10) sut.manual_steppers = {attribute_input: stepper} - # SUT await sut.click(attribute_input, direction_input) - # Checks assert change_light_state_patch.call_count == expected_calls @@ -576,41 +572,35 @@ async def fake_get_attribute(*args, **kwargs): ) @pytest.mark.asyncio async def test_hold( - sut, - monkeypatch, - mocker, - attribute_input, - direction_input, - previous_direction, - light_state, - smooth_power_on, - expected_calls, - expected_direction, + sut: LightController, + monkeypatch: MonkeyPatch, + mocker: MockerFixture, + attribute_input: str, + direction_input: str, + previous_direction: str, + light_state: Literal["on", "off"], + smooth_power_on: bool, + expected_calls: int, + expected_direction: str, ): value_attribute = 10 - - async def fake_get_entity_state(*args, **kwargs): - return light_state - - async def fake_get_value_attribute(*args, **kwargs): - return value_attribute - - async def fake_get_attribute(*args, **kwargs): - return attribute_input - - monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) - monkeypatch.setattr(sut, "get_value_attribute", fake_get_value_attribute) - monkeypatch.setattr(sut, "get_attribute", fake_get_attribute) + monkeypatch.setattr( + sut, "get_entity_state", fake_fn(to_return=light_state, async_=True) + ) + monkeypatch.setattr( + sut, "get_value_attribute", fake_fn(to_return=value_attribute, async_=True) + ) + monkeypatch.setattr( + sut, "get_attribute", fake_fn(to_return=attribute_input, async_=True) + ) sut.smooth_power_on = smooth_power_on stepper = MinMaxStepper(1, 10, 10) stepper.previous_direction = previous_direction sut.automatic_steppers = {attribute_input: stepper} super_hold_patch = mocker.patch.object(ReleaseHoldController, "hold") - # SUT await sut.hold(attribute_input, direction_input) - # Checks assert super_hold_patch.call_count == expected_calls if expected_calls > 0: super_hold_patch.assert_called_with(attribute_input, expected_direction) @@ -618,7 +608,9 @@ async def fake_get_attribute(*args, **kwargs): @pytest.mark.parametrize("value_attribute", [10, None]) @pytest.mark.asyncio -async def test_hold_loop(sut, mocker, value_attribute): +async def test_hold_loop( + sut: LightController, mocker: MockerFixture, value_attribute: int +): attribute = "test_attribute" direction = Stepper.UP sut.value_attribute = value_attribute @@ -626,7 +618,6 @@ async def test_hold_loop(sut, mocker, value_attribute): stepper = MinMaxStepper(1, 10, 10) sut.automatic_steppers = {attribute: stepper} - # SUT exceeded = await sut.hold_loop(attribute, direction) if value_attribute is None: diff --git a/tests/unit_tests/cx_core/type/media_player_controller_test.py b/tests/unit_tests/cx_core/type/media_player_controller_test.py index 8ee13ca0..3010e1bc 100644 --- a/tests/unit_tests/cx_core/type/media_player_controller_test.py +++ b/tests/unit_tests/cx_core/type/media_player_controller_test.py @@ -1,113 +1,123 @@ -import pytest +from typing import List +import pytest +from _pytest.monkeypatch import MonkeyPatch from cx_core import MediaPlayerController, ReleaseHoldController +from cx_core.controller import Controller from cx_core.feature_support.media_player import MediaPlayerSupport from cx_core.stepper import Stepper +from pytest_mock.plugin import MockerFixture +from typing_extensions import Literal + from tests.test_utils import fake_fn - -@pytest.fixture -@pytest.mark.asyncio -async def sut(monkeypatch, hass_mock, mocker): - c = MediaPlayerController() # type: ignore - c.args = {} - c.delay = 0 - c.media_player = "test" - c.on_hold = False - mocker.patch.object(ReleaseHoldController, "initialize") - c.args["media_player"] = "media_player.test" - monkeypatch.setattr(c, "get_entity_state", fake_fn(async_=True, to_return="0")) - await c.initialize() - return c +ENTITY_NAME = "media_player.test" +@pytest.fixture @pytest.mark.asyncio -async def test_initialize(sut): - await sut.initialize() - assert sut.media_player == "media_player.test" +async def sut(mocker: MockerFixture) -> MediaPlayerController: + controller = MediaPlayerController() # type: ignore + mocker.patch.object(Controller, "initialize") + controller.args = {"media_player": ENTITY_NAME} + await controller.initialize() + return controller @pytest.mark.asyncio -async def test_play_pause(sut, mocker): +async def test_play_pause(sut: MediaPlayerController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.play_pause() + called_service_patch.assert_called_once_with( - "media_player/media_play_pause", entity_id=sut.media_player + "media_player/media_play_pause", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_play(sut, mocker): +async def test_play(sut: MediaPlayerController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.play() + called_service_patch.assert_called_once_with( - "media_player/media_play", entity_id=sut.media_player + "media_player/media_play", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_pause(sut, mocker): +async def test_pause(sut: MediaPlayerController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.pause() + called_service_patch.assert_called_once_with( - "media_player/media_pause", entity_id=sut.media_player + "media_player/media_pause", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_previous_track(sut, mocker): +async def test_previous_track(sut: MediaPlayerController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.previous_track() + called_service_patch.assert_called_once_with( - "media_player/media_previous_track", entity_id=sut.media_player + "media_player/media_previous_track", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_next_track(sut, mocker): +async def test_next_track(sut: MediaPlayerController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") + await sut.next_track() + called_service_patch.assert_called_once_with( - "media_player/media_next_track", entity_id=sut.media_player + "media_player/media_next_track", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_volume_up(sut, mocker, monkeypatch): - async def fake_get_entity_state(entity, attribute=None): - return 0.5 - - monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) - sut.supported_features._supported_features = [MediaPlayerSupport.VOLUME_SET] - +async def test_volume_up( + sut: MediaPlayerController, mocker: MockerFixture, monkeypatch: MonkeyPatch +): + monkeypatch.setattr(sut, "get_entity_state", fake_fn(async_=True, to_return=0.5)) + sut.feature_support._supported_features = {MediaPlayerSupport.VOLUME_SET} called_service_patch = mocker.patch.object(sut, "call_service") + await sut.volume_up() + called_service_patch.assert_called_once_with( - "media_player/volume_set", entity_id=sut.media_player, volume_level=0.6 + "media_player/volume_set", entity_id=ENTITY_NAME, volume_level=0.6 ) @pytest.mark.asyncio -async def test_volume_down(sut, mocker, monkeypatch): - async def fake_get_entity_state(entity, attribute=None): - return 0.5 - - monkeypatch.setattr(sut, "get_entity_state", fake_get_entity_state) - sut.supported_features._supported_features = [MediaPlayerSupport.VOLUME_SET] - +async def test_volume_down( + sut: MediaPlayerController, mocker: MockerFixture, monkeypatch: MonkeyPatch +): + monkeypatch.setattr(sut, "get_entity_state", fake_fn(async_=True, to_return=0.5)) + sut.feature_support._supported_features = {MediaPlayerSupport.VOLUME_SET} called_service_patch = mocker.patch.object(sut, "call_service") + await sut.volume_down() + called_service_patch.assert_called_once_with( - "media_player/volume_set", entity_id=sut.media_player, volume_level=0.4 + "media_player/volume_set", entity_id=ENTITY_NAME, volume_level=0.4 ) @pytest.mark.asyncio -async def test_hold(sut, mocker): +async def test_hold(sut: MediaPlayerController, mocker: MockerFixture): direction = "test_direction" - mocker.patch.object(sut, "prepare_volume_change") + prepare_volume_change_patch = mocker.patch.object(sut, "prepare_volume_change") super_hold_patch = mocker.patch.object(ReleaseHoldController, "hold") + await sut.hold(direction) + + prepare_volume_change_patch.assert_called_once() super_hold_patch.assert_called_once_with(direction) @@ -122,29 +132,30 @@ async def test_hold(sut, mocker): ) @pytest.mark.asyncio async def test_hold_loop( - sut, - mocker, - monkeypatch, - direction_input, + sut: MediaPlayerController, + mocker: MockerFixture, + direction_input: Literal["up", "down"], volume_set_support, volume_level, expected_volume_level, ): called_service_patch = mocker.patch.object(sut, "call_service") - sut.supported_features._supported_features = ( - [MediaPlayerSupport.VOLUME_SET] if volume_set_support else [] + sut.feature_support._supported_features = ( + {MediaPlayerSupport.VOLUME_SET} if volume_set_support else set() ) sut.volume_level = volume_level + await sut.hold_loop(direction_input) + if volume_set_support: called_service_patch.assert_called_once_with( "media_player/volume_set", - entity_id=sut.media_player, + entity_id=ENTITY_NAME, volume_level=expected_volume_level, ) else: called_service_patch.assert_called_once_with( - f"media_player/volume_{direction_input}", entity_id=sut.media_player + f"media_player/volume_{direction_input}", entity_id=ENTITY_NAME ) @@ -164,14 +175,14 @@ async def test_hold_loop( ) @pytest.mark.asyncio async def test_change_source_list( - sut, - mocker, - monkeypatch, - direction_input, - source_list, - active_source, - expected_calls, - expected_source, + sut: MediaPlayerController, + mocker: MockerFixture, + monkeypatch: MonkeyPatch, + direction_input: Literal["up", "down"], + source_list: List[str], + active_source: str, + expected_calls: int, + expected_source: str, ): called_service_patch = mocker.patch.object(sut, "call_service") @@ -191,6 +202,6 @@ async def fake_get_entity_state(entity, attribute=None): if expected_calls > 0: called_service_patch.assert_called_once_with( "media_player/select_source", - entity_id=sut.media_player, + entity_id=ENTITY_NAME, source=expected_source, ) diff --git a/tests/unit_tests/cx_core/type/switch_controller_test.py b/tests/unit_tests/cx_core/type/switch_controller_test.py index f0fa5a99..5ae7a1e4 100644 --- a/tests/unit_tests/cx_core/type/switch_controller_test.py +++ b/tests/unit_tests/cx_core/type/switch_controller_test.py @@ -1,47 +1,41 @@ import pytest - from cx_core import SwitchController -from cx_core.controller import TypeController +from cx_core.type_controller import Entity +from pytest_mock.plugin import MockerFixture + +ENTITY_NAME = "switch.test" @pytest.fixture @pytest.mark.asyncio -async def sut(hass_mock, mocker): +async def sut(): c = SwitchController() # type: ignore - mocker.patch.object(TypeController, "initialize") - c.args = {"switch": "switch.test"} - await c.initialize() + c.entity = Entity(ENTITY_NAME) return c @pytest.mark.asyncio -async def test_initialize(sut): - await sut.initialize() - assert sut.switch == "switch.test" - - -@pytest.mark.asyncio -async def test_turn_on(sut, mocker): +async def test_turn_on(sut: SwitchController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") await sut.on() called_service_patch.assert_called_once_with( - "homeassistant/turn_on", entity_id=sut.switch + "homeassistant/turn_on", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_turn_off(sut, mocker): +async def test_turn_off(sut: SwitchController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") await sut.off() called_service_patch.assert_called_once_with( - "homeassistant/turn_off", entity_id=sut.switch + "homeassistant/turn_off", entity_id=ENTITY_NAME ) @pytest.mark.asyncio -async def test_toggle(sut, mocker): +async def test_toggle(sut: SwitchController, mocker: MockerFixture): called_service_patch = mocker.patch.object(sut, "call_service") await sut.toggle() called_service_patch.assert_called_once_with( - "homeassistant/toggle", entity_id=sut.switch + "homeassistant/toggle", entity_id=ENTITY_NAME ) diff --git a/tests/unit_tests/cx_core/type/type_test.py b/tests/unit_tests/cx_core/type/type_test.py index eaaf7802..9c162072 100644 --- a/tests/unit_tests/cx_core/type/type_test.py +++ b/tests/unit_tests/cx_core/type/type_test.py @@ -1,28 +1,40 @@ -from cx_core import Controller +from typing import Type + +import pytest +from cx_const import TypeActionsMapping from cx_core import type as type_module +from cx_core.type_controller import TypeController +from pytest_mock.plugin import MockerFixture + from tests.test_utils import get_classes -def check_mapping(mapping): +def check_mapping(mapping: TypeActionsMapping) -> None: if mapping is None: return - for k, v in mapping.items(): - if not (callable(v) or type(v) == tuple): + for v in mapping.values(): + if not (callable(v) or isinstance(v, tuple)): raise ValueError("The value mapping should be a callable or a tuple") - if type(v) == "tuple": + if isinstance(v, tuple): if len(v) == 0: raise ValueError( "The tuple should contain at least 1 element, the function" ) - fn, *args = v + fn, *_ = v if not callable(fn): raise ValueError("The first element of the tuple should be a callable") -def test_devices(hass_mock): - controller_types = get_classes( - type_module.__file__, type_module.__package__, Controller, instantiate=True - ) - for controller_type in controller_types: - mappings = controller_type.get_type_actions_mapping() - check_mapping(mappings) +controller_types = get_classes( + type_module.__file__, type_module.__package__, TypeController +) + + +@pytest.mark.parametrize("controller_type", controller_types) +def test_type_actions_mapping( + mocker: MockerFixture, controller_type: Type[TypeController] +): + controller = controller_type() # type: ignore + # mocker.patch.object(TypeController, "initialize") + mappings = controller.get_type_actions_mapping() + check_mapping(mappings) diff --git a/tests/unit_tests/cx_core/type_controller_test.py b/tests/unit_tests/cx_core/type_controller_test.py index 9d5bfd1d..9d66b9e7 100644 --- a/tests/unit_tests/cx_core/type_controller_test.py +++ b/tests/unit_tests/cx_core/type_controller_test.py @@ -1,22 +1,82 @@ +from typing import Any, Dict, List, Type + import pytest +from _pytest.monkeypatch import MonkeyPatch +from cx_core.controller import Controller +from cx_core.feature_support import FeatureSupport +from cx_core.type_controller import Entity, TypeController +from pytest_mock.plugin import MockerFixture + +from tests.test_utils import fake_fn, wrap_exetuction + +ENTITY_ARG = "my_entity" +ENTITY_NAME = "domain_1.test" +DEFAULT_ATTR_TEST = "my_default" + + +class MyEntity(Entity): + attr_test: str + + def __init__(self, name: str, attr_test: str = DEFAULT_ATTR_TEST) -> None: + super().__init__(name) + self.attr_test = attr_test + -from cx_core.controller import TypeController +class MyFeatureSupport(FeatureSupport): + features = [1, 2, 3, 4] -class FakeTypeController(TypeController): - def get_domain(self): - return "domain" +class MyTypeController(TypeController[MyEntity, MyFeatureSupport]): + + domains = ["domain_1", "domain_2"] + entity_arg = ENTITY_ARG + + def _get_entity_type(self) -> Type[MyEntity]: + return MyEntity + + def _get_feature_support_type(self) -> Type[MyFeatureSupport]: + return MyFeatureSupport @pytest.fixture -def sut(hass_mock): - c = FakeTypeController() # type: ignore - c.args = {} - return c +def sut_before_init(mocker: MockerFixture) -> MyTypeController: + controller = MyTypeController() # type: ignore + controller.args = {ENTITY_ARG: ENTITY_NAME} + mocker.patch.object(Controller, "initialize") + return controller + + +@pytest.fixture +@pytest.mark.asyncio +async def sut(sut_before_init: MyTypeController) -> MyTypeController: + await sut_before_init.initialize() + return sut_before_init -# All entities from '{entity}' must be from {domain} domain (e.g. {domain}.bedroom) -# '{entity}' must be from {domain} domain (e.g. {domain}.bedroom) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "args, error_expected", + [ + ({ENTITY_ARG: ENTITY_NAME}, False), + ({ENTITY_ARG: {"name": ENTITY_NAME, "attr_test": "my_attr"}}, False), + ({ENTITY_ARG: {"name": ENTITY_NAME}}, False), + ({ENTITY_ARG: "non_existing_domain.my_entity"}, True), + ], +) +async def test_initialize( + sut_before_init: MyTypeController, args: Dict[str, Any], error_expected: bool +): + sut_before_init.args = args + + with wrap_exetuction(error_expected=error_expected, exception=ValueError): + await sut_before_init.initialize() + + if not error_expected: + assert sut_before_init.entity.name == ENTITY_NAME + if isinstance(args[ENTITY_ARG], dict): + assert sut_before_init.entity.attr_test == args[ENTITY_ARG].get( + "attr_test", DEFAULT_ATTR_TEST + ) @pytest.mark.parametrize( @@ -47,8 +107,14 @@ def sut(hass_mock): ) @pytest.mark.asyncio async def test_check_domain( - sut, monkeypatch, entity, domains, entities, error_expected + sut: MyTypeController, + monkeypatch: MonkeyPatch, + entity: str, + domains: List[str], + entities: List[str], + error_expected: bool, ): + sut.domains = domains expected_error_message = "" if error_expected: if entities == []: @@ -63,19 +129,16 @@ async def test_check_domain( f"following domains {domains} (e.g. {domains[0]}.bedroom)" ) - async def fake_get_state(*args, **kwargs): - return entities + monkeypatch.setattr(sut, "get_state", fake_fn(to_return=entities, async_=True)) - monkeypatch.setattr(sut, "get_state", fake_get_state) - monkeypatch.setattr(sut, "get_domain", lambda *args: domains) - - if error_expected: - with pytest.raises(ValueError) as e: - await sut.check_domain(entity) - assert str(e.value) == expected_error_message - else: + with wrap_exetuction( + error_expected=error_expected, exception=ValueError + ) as err_info: await sut.check_domain(entity) + if err_info is not None: + assert str(err_info.value) == expected_error_message + @pytest.mark.parametrize( "entity_input, entities, expected_calls", @@ -87,7 +150,12 @@ async def fake_get_state(*args, **kwargs): ) @pytest.mark.asyncio async def test_get_entity_state( - sut, mocker, monkeypatch, entity_input, entities, expected_calls + sut: MyTypeController, + mocker: MockerFixture, + monkeypatch: MonkeyPatch, + entity_input: str, + entities: List[str], + expected_calls: int, ): stub_get_state = mocker.stub() @@ -97,14 +165,10 @@ async def fake_get_state(entity, attribute=None): monkeypatch.setattr(sut, "get_state", fake_get_state) - # SUT - if expected_calls is None: - with pytest.raises(ValueError): - await sut.get_entity_state(entity_input, "attribute_test") - else: + with wrap_exetuction(error_expected=expected_calls is None, exception=ValueError): await sut.get_entity_state(entity_input, "attribute_test") - # Checks + if expected_calls is not None: if expected_calls == 1: stub_get_state.assert_called_once_with( entity_input, attribute="attribute_test" diff --git a/tests/unit_tests/cx_devices/aqara_test.py b/tests/unit_tests/cx_devices/aqara_test.py index 3b8df270..3cb959e1 100644 --- a/tests/unit_tests/cx_devices/aqara_test.py +++ b/tests/unit_tests/cx_devices/aqara_test.py @@ -1,4 +1,5 @@ import pytest +from cx_core.integration import EventData from cx_devices.aqara import ( MFKZQ01LMLightController, WXKG01LMLightController, @@ -18,7 +19,7 @@ ({"command": "rotate_right"}, "rotate_right"), ], ) -def test_zha_action_MFKZQ01LMLightController(data, expected_action): +def test_zha_action_MFKZQ01LMLightController(data: EventData, expected_action: str): sut = MFKZQ01LMLightController() # type: ignore action = sut.get_zha_action(data) assert action == expected_action @@ -34,7 +35,7 @@ def test_zha_action_MFKZQ01LMLightController(data, expected_action): ({"command": "click", "args": {"click_type": "furious"}}, "furious"), ], ) -def test_zha_action_WXKG01LMLightController(data, expected_action): +def test_zha_action_WXKG01LMLightController(data: EventData, expected_action: str): sut = WXKG01LMLightController() # type: ignore action = sut.get_zha_action(data) assert action == expected_action @@ -50,7 +51,7 @@ def test_zha_action_WXKG01LMLightController(data, expected_action): ({"args": {"value": 4}}, "quadruple"), ], ) -def test_zha_action_WXKG11LMLightController(data, expected_action): - sut = WXKG11LMLightController() +def test_zha_action_WXKG11LMLightController(data: EventData, expected_action: str): + sut = WXKG11LMLightController() # type: ignore action = sut.get_zha_action(data) assert action == expected_action diff --git a/tests/unit_tests/cx_devices/devices_test.py b/tests/unit_tests/cx_devices/devices_test.py index d0bbb0f3..864772a2 100644 --- a/tests/unit_tests/cx_devices/devices_test.py +++ b/tests/unit_tests/cx_devices/devices_test.py @@ -1,23 +1,30 @@ -import pytest -from tests.test_utils import get_classes, get_controller +from typing import Callable, KeysView, List, Optional, Type import cx_devices as devices_module +import pytest +from cx_const import ActionEvent, TypeActionsMapping from cx_core import Controller from cx_core.controller import ReleaseHoldController +from tests.test_utils import get_classes, get_controller + -def check_mapping(mapping, all_possible_actions, device): +def check_mapping( + mapping: Optional[TypeActionsMapping], + all_possible_actions: KeysView[ActionEvent], + device: Controller, +) -> None: device_name = device.__class__.__name__ if mapping is None: return - if issubclass(device.__class__, ReleaseHoldController): + if isinstance(device, ReleaseHoldController): delay = device.default_delay() if delay < 0: raise ValueError( f"`default_delay` should be a positive integer and the value is `{delay}`. " f"Device class: {device_name}" ) - for k, v in mapping.items(): + for v in mapping.values(): if not isinstance(v, str): raise ValueError( "The value from the mapping should be a string, matching " @@ -29,7 +36,7 @@ def check_mapping(mapping, all_possible_actions, device): if v not in all_possible_actions: raise ValueError( f"{device_name}: `{v}` not found in the list of possible action from the controller. " - + f"The possible actions are: {all_possible_actions}" + + f"The possible actions are: {list(all_possible_actions)}" ) @@ -39,8 +46,8 @@ def check_mapping(mapping, all_possible_actions, device): @pytest.mark.parametrize("device_class", devices_classes) -def test_devices(hass_mock, device_class): - device = device_class() +def test_devices(device_class: Type[Controller]): + device = device_class() # type: ignore # We first check that all devices are importable from controllerx module device_from_controllerx = get_controller("controllerx", device_class.__name__) @@ -49,10 +56,8 @@ def test_devices(hass_mock, device_class): ), f"'{device_class.__name__}' not importable from controllerx.py" type_actions_mapping = device.get_type_actions_mapping() - if type_actions_mapping is None: - return - possible_actions = list(type_actions_mapping.keys()) - integration_mappings_funcs = [ + possible_actions = type_actions_mapping.keys() + integration_mappings_funcs: List[Callable[[], Optional[TypeActionsMapping]]] = [ device.get_z2m_actions_mapping, device.get_deconz_actions_mapping, device.get_zha_actions_mapping, diff --git a/tests/unit_tests/cx_devices/legrand_test.py b/tests/unit_tests/cx_devices/legrand_test.py index 24fad88d..d99737d4 100644 --- a/tests/unit_tests/cx_devices/legrand_test.py +++ b/tests/unit_tests/cx_devices/legrand_test.py @@ -1,5 +1,5 @@ import pytest - +from cx_core.integration import EventData from cx_devices.legrand import get_zha_action_LegrandWallController @@ -16,6 +16,6 @@ ({"endpoint_id": 2, "command": "stop"}, "2_stop"), ], ) -def test_get_zha_action_LegrandWallController(data, expected_action): +def test_get_zha_action_LegrandWallController(data: EventData, expected_action: str): action = get_zha_action_LegrandWallController(data) assert action == expected_action diff --git a/tests/unit_tests/cx_devices/phillips_test.py b/tests/unit_tests/cx_devices/phillips_test.py index de8b84a6..f94bdd98 100644 --- a/tests/unit_tests/cx_devices/phillips_test.py +++ b/tests/unit_tests/cx_devices/phillips_test.py @@ -1,4 +1,5 @@ import pytest +from cx_core.integration import EventData from cx_devices.phillips import HueDimmerController @@ -10,7 +11,7 @@ ({"command": "off_hold"}, "off_hold"), ], ) -def test_zha_action_HueDimmerController(data, expected_action): +def test_zha_action_HueDimmerController(data: EventData, expected_action: str): sut = HueDimmerController() # type: ignore action = sut.get_zha_action(data) assert action == expected_action