diff --git a/midealocal/device.py b/midealocal/device.py index 10d92ded..ad15bf18 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -8,7 +8,7 @@ from enum import IntEnum, StrEnum from typing import Any -from .exceptions import SocketException +from .exceptions import CannotConnect, SocketException from .message import ( MessageApplianceResponse, MessageQueryAppliance, @@ -140,7 +140,7 @@ def __init__( self._updates: list[Callable[[dict[str, Any]], None]] = [] self._unsupported_protocol: list[str] = [] self._is_run = False - self._available = True + self._available = False self._appliance_query = True self._refresh_interval = 30 self._heartbeat_interval = 10 @@ -190,67 +190,66 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]: break return result, msg - def connect( - self, - refresh_status: bool = True, - get_capabilities: bool = True, - ) -> bool: + def _authenticate_refresh_capabilities(self) -> None: + if self._protocol == ProtocolVersion.V3: + self.authenticate() + self.refresh_status(wait_response=True) + self.get_capabilities() + + def connect(self) -> bool: """Connect to device.""" connected = False - try: - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.settimeout(10) - _LOGGER.debug( - "[%s] Connecting to %s:%s", - self._device_id, - self._ip_address, - self._port, - ) - self._socket.connect((self._ip_address, self._port)) - _LOGGER.debug("[%s] Connected", self._device_id) - if self._protocol == ProtocolVersion.V3: - self.authenticate() - _LOGGER.debug("[%s] Authentication success", self._device_id) - if refresh_status: - self.refresh_status(wait_response=True) - if get_capabilities: - self.get_capabilities() - connected = True - except TimeoutError: - _LOGGER.debug("[%s] Connection timed out", self._device_id) - except OSError: - _LOGGER.debug("[%s] Connection error", self._device_id) - except AuthException: - _LOGGER.debug("[%s] Authentication failed", self._device_id) - except RefreshFailed: - _LOGGER.debug("[%s] Refresh status is timed out", self._device_id) - except Exception as e: - file = None - lineno = None - if e.__traceback__: - file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 - lineno = e.__traceback__.tb_lineno - _LOGGER.exception( - "[%s] Unknown error : %s, %s", - self._device_id, - file, - lineno, - ) + for _ in range(3): + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.settimeout(10) + _LOGGER.debug( + "[%s] Connecting to %s:%s", + self._device_id, + self._ip_address, + self._port, + ) + self._socket.connect((self._ip_address, self._port)) + _LOGGER.debug("[%s] Connected", self._device_id) + connected = True + except TimeoutError: + _LOGGER.debug("[%s] Connection timed out", self._device_id) + except OSError: + _LOGGER.debug("[%s] Connection error", self._device_id) + except AuthException: + _LOGGER.debug("[%s] Authentication failed", self._device_id) + except RefreshFailed: + _LOGGER.debug("[%s] Refresh status is timed out", self._device_id) + except Exception as e: + file = None + lineno = None + if e.__traceback__: + file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101 + lineno = e.__traceback__.tb_lineno + _LOGGER.exception( + "[%s] Unknown error : %s, %s", + self._device_id, + file, + lineno, + ) self.enable_device(connected) return connected def authenticate(self) -> None: """Authenticate to device. V3 only.""" request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST) - _LOGGER.debug("[%s] Handshaking", self._device_id) + _LOGGER.debug("[%s] Authentication handshaking", self._device_id) if not self._socket: + self.enable_device(False) raise SocketException self._socket.send(request) response = self._socket.recv(512) if len(response) < MIN_AUTH_RESPONSE: + self.enable_device(False) raise AuthException response = response[8:72] self._security.tcp_key(response, self._key) + _LOGGER.debug("[%s] Authentication success", self._device_id) def send_message(self, data: bytes) -> None: """Send message.""" @@ -462,6 +461,7 @@ def update_all(self, status: dict[str, Any]) -> None: def enable_device(self, available: bool = True) -> None: """Enable device.""" + _LOGGER.debug("[%s] Enabling device", self._device_id) self._available = available status = {"available": available} self.update_all(status) @@ -510,14 +510,14 @@ def _check_heartbeat(self, now: float) -> None: def run(self) -> None: """Run loop.""" while self._is_run: - while self._socket is None: - if self.connect(refresh_status=True) is False: - self.close_socket() - time.sleep(5) + if not self.connect(): + raise CannotConnect + if not self._socket: + raise SocketException + self._authenticate_refresh_capabilities() timeout_counter = 0 start = time.time() - self._previous_refresh = start - self._previous_heartbeat = start + self._previous_refresh = self._previous_heartbeat = start self._socket.settimeout(1) while True: try: diff --git a/midealocal/exceptions.py b/midealocal/exceptions.py index f2c7ba95..7cd5c74d 100644 --- a/midealocal/exceptions.py +++ b/midealocal/exceptions.py @@ -11,6 +11,10 @@ class CannotAuthenticate(MideaLocalError): """Exception raised when credentials are incorrect.""" +class CannotConnect(MideaLocalError): + """Exception raised when connection fails.""" + + class DataUnexpectedLength(MideaLocalError): """Exception raised when data length is less or more than expected.""" diff --git a/tests/device_test.py b/tests/device_test.py index 491813be..88b8ba03 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -1,6 +1,5 @@ """Midea Local device test.""" -from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch import pytest @@ -28,7 +27,7 @@ def test_fetch_v2_message() -> None: ) -class MideaDeviceTest(IsolatedAsyncioTestCase): +class MideaDeviceTest: """Midea device test case.""" device: MideaDevice @@ -59,55 +58,28 @@ def test_initial_attributes(self) -> None: assert self.device.model == "test_model" assert self.device.subtype == 1 - def test_connect(self) -> None: + @pytest.mark.parametrize( + ("exc", "result"), + [ + (TimeoutError, False), + (OSError, False), + (AuthException, False), + (RefreshFailed, False), + (None, True), + ], + ) + def test_connect(self, exc: Exception, result: bool) -> None: """Test connect.""" - with ( - patch("socket.socket.connect") as connect_mock, - patch.object( - self.device, - "authenticate", - side_effect=[AuthException(), None, None], - ), - patch.object( - self.device, - "refresh_status", - side_effect=[RefreshFailed(), None], - ), - patch.object( - self.device, - "get_capabilities", - side_effect=[None], - ), - ): - connect_mock.side_effect = [ - TimeoutError(), - OSError(), - None, - None, - None, - None, - ] - assert self.device.connect(True, True) is False - assert self.device.available is False - - assert self.device.connect(True, True) is False - assert self.device.available is False - - assert self.device.connect(True, True) is False - assert self.device.available is False - - assert self.device.connect(True, True) is False - assert self.device.available is False - - assert self.device.connect(True, True) is True - assert self.device.available is True + with patch("socket.socket.connect", side_effect=exc): + assert self.device.connect() is result + assert self.device.available is result def test_connect_generic_exception(self) -> None: """Test connect with generic exception.""" with patch("socket.socket.connect") as connect_mock: connect_mock.side_effect = Exception() - assert self.device.connect(True, True) is False + assert self.device.connect() is False assert self.device.available is False def test_authenticate(self) -> None: