Skip to content

Commit

Permalink
feat: segregate connect/auth/refresh/enable device duties (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
chemelli74 committed Jul 24, 2024
1 parent c9e7d9e commit 681bd79
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 97 deletions.
106 changes: 53 additions & 53 deletions midealocal/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import IntEnum, StrEnum
from typing import Any

from .exceptions import SocketException
from .exceptions import CannotConnect, SocketException
from .message import (
MessageApplianceResponse,
MessageQueryAppliance,
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(
self._updates: list[Callable[[dict[str, Any]], None]] = []
self._unsupported_protocol: list[str] = []
self._is_run = False
self._available = True
self._available = False
self._appliance_query = True
self._refresh_interval = 30
self._heartbeat_interval = 10
Expand Down Expand Up @@ -190,67 +190,66 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]:
break
return result, msg

def connect(
self,
refresh_status: bool = True,
get_capabilities: bool = True,
) -> bool:
def _authenticate_refresh_capabilities(self) -> None:
if self._protocol == ProtocolVersion.V3:
self.authenticate()
self.refresh_status(wait_response=True)
self.get_capabilities()

def connect(self) -> bool:
"""Connect to device."""
connected = False
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
if self._protocol == ProtocolVersion.V3:
self.authenticate()
_LOGGER.debug("[%s] Authentication success", self._device_id)
if refresh_status:
self.refresh_status(wait_response=True)
if get_capabilities:
self.get_capabilities()
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
for _ in range(3):
try:
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10)
_LOGGER.debug(
"[%s] Connecting to %s:%s",
self._device_id,
self._ip_address,
self._port,
)
self._socket.connect((self._ip_address, self._port))
_LOGGER.debug("[%s] Connected", self._device_id)
connected = True
except TimeoutError:
_LOGGER.debug("[%s] Connection timed out", self._device_id)
except OSError:
_LOGGER.debug("[%s] Connection error", self._device_id)
except AuthException:
_LOGGER.debug("[%s] Authentication failed", self._device_id)
except RefreshFailed:
_LOGGER.debug("[%s] Refresh status is timed out", self._device_id)
except Exception as e:
file = None
lineno = None
if e.__traceback__:
file = e.__traceback__.tb_frame.f_globals["__file__"] # pylint: disable=E1101
lineno = e.__traceback__.tb_lineno
_LOGGER.exception(
"[%s] Unknown error : %s, %s",
self._device_id,
file,
lineno,
)
self.enable_device(connected)
return connected

def authenticate(self) -> None:
"""Authenticate to device. V3 only."""
request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST)
_LOGGER.debug("[%s] Handshaking", self._device_id)
_LOGGER.debug("[%s] Authentication handshaking", self._device_id)
if not self._socket:
self.enable_device(False)
raise SocketException
self._socket.send(request)
response = self._socket.recv(512)
if len(response) < MIN_AUTH_RESPONSE:
self.enable_device(False)
raise AuthException
response = response[8:72]
self._security.tcp_key(response, self._key)
_LOGGER.debug("[%s] Authentication success", self._device_id)

def send_message(self, data: bytes) -> None:
"""Send message."""
Expand Down Expand Up @@ -462,6 +461,7 @@ def update_all(self, status: dict[str, Any]) -> None:

def enable_device(self, available: bool = True) -> None:
"""Enable device."""
_LOGGER.debug("[%s] Enabling device", self._device_id)
self._available = available
status = {"available": available}
self.update_all(status)
Expand Down Expand Up @@ -510,14 +510,14 @@ def _check_heartbeat(self, now: float) -> None:
def run(self) -> None:
"""Run loop."""
while self._is_run:
while self._socket is None:
if self.connect(refresh_status=True) is False:
self.close_socket()
time.sleep(5)
if not self.connect():
raise CannotConnect
if not self._socket:
raise SocketException
self._authenticate_refresh_capabilities()
timeout_counter = 0
start = time.time()
self._previous_refresh = start
self._previous_heartbeat = start
self._previous_refresh = self._previous_heartbeat = start
self._socket.settimeout(1)
while True:
try:
Expand Down
4 changes: 4 additions & 0 deletions midealocal/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class CannotAuthenticate(MideaLocalError):
"""Exception raised when credentials are incorrect."""


class CannotConnect(MideaLocalError):
"""Exception raised when connection fails."""


class DataUnexpectedLength(MideaLocalError):
"""Exception raised when data length is less or more than expected."""

Expand Down
60 changes: 16 additions & 44 deletions tests/device_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Midea Local device test."""

from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -28,7 +27,7 @@ def test_fetch_v2_message() -> None:
)


class MideaDeviceTest(IsolatedAsyncioTestCase):
class MideaDeviceTest:
"""Midea device test case."""

device: MideaDevice
Expand Down Expand Up @@ -59,55 +58,28 @@ def test_initial_attributes(self) -> None:
assert self.device.model == "test_model"
assert self.device.subtype == 1

def test_connect(self) -> None:
@pytest.mark.parametrize(
("exc", "result"),
[
(TimeoutError, False),
(OSError, False),
(AuthException, False),
(RefreshFailed, False),
(None, True),
],
)
def test_connect(self, exc: Exception, result: bool) -> None:
"""Test connect."""
with (
patch("socket.socket.connect") as connect_mock,
patch.object(
self.device,
"authenticate",
side_effect=[AuthException(), None, None],
),
patch.object(
self.device,
"refresh_status",
side_effect=[RefreshFailed(), None],
),
patch.object(
self.device,
"get_capabilities",
side_effect=[None],
),
):
connect_mock.side_effect = [
TimeoutError(),
OSError(),
None,
None,
None,
None,
]
assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is False
assert self.device.available is False

assert self.device.connect(True, True) is True
assert self.device.available is True
with patch("socket.socket.connect", side_effect=exc):
assert self.device.connect() is result
assert self.device.available is result

def test_connect_generic_exception(self) -> None:
"""Test connect with generic exception."""
with patch("socket.socket.connect") as connect_mock:
connect_mock.side_effect = Exception()

assert self.device.connect(True, True) is False
assert self.device.connect() is False
assert self.device.available is False

def test_authenticate(self) -> None:
Expand Down

0 comments on commit 681bd79

Please sign in to comment.