Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle cat sensor interrupt on LR4 and better support around token init/refresh #172

Merged
merged 2 commits into from
Aug 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions pylitterbot/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import logging
from collections.abc import Callable
from typing import TypeVar, cast

from aiohttp import (
Expand All @@ -12,6 +13,7 @@
ClientWebSocketResponse,
)

from .event import EVENT_UPDATE
from .exceptions import LitterRobotException, LitterRobotLoginException
from .robot import Robot
from .robot.feederrobot import FEEDER_ENDPOINT, FEEDER_ROBOT_MODEL, FeederRobot
Expand All @@ -29,7 +31,10 @@ class Account:
"""Class with data and methods for interacting with a user's Litter-Robots."""

def __init__(
self, token: dict | None = None, websession: ClientSession | None = None
self,
token: dict | None = None,
websession: ClientSession | None = None,
token_update_callback: Callable[[dict | None], None] | None = None,
) -> None:
"""Initialize the account data."""
self._session = LitterRobotSession(token=token, websession=websession)
Expand All @@ -41,6 +46,12 @@ def __init__(
self._robots: list[Robot] = []
self._monitors: dict[type[Robot], WebSocketMonitor] = {}

if token_update_callback:
self._session.on(
EVENT_UPDATE,
lambda session=self._session: token_update_callback(session.tokens), # type: ignore
)

@property
def user_id(self) -> str | None:
"""Return the logged in user's id."""
Expand Down Expand Up @@ -79,7 +90,9 @@ async def connect(
"""Connect to the Litter-Robot API."""
try:
if not self.session.is_token_valid():
if username and password:
if self.session.has_refresh_token():
await self.session.refresh_token()
elif username and password:
await self.session.login(username=username, password=password)
else:
raise LitterRobotLoginException(
Expand Down
36 changes: 36 additions & 0 deletions pylitterbot/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Event handling class for pylitterbot."""
from __future__ import annotations

from abc import ABC
from collections.abc import Callable
from typing import Any

EVENT_UPDATE = "update"


class Event(ABC):
"""Abstract event class properties and methods."""

_listeners: dict[str, list[Callable]] = {}

def emit(self, event_name: str, *args: Any, **kwargs: Any) -> None:
"""Run all callbacks for an event."""
for listener in self._listeners.get(event_name, []):
try:
listener(*args, **kwargs)
except: # pragma: no cover # pylint: disable=bare-except # noqa: E722
pass

def on( # pylint: disable=invalid-name
self, event_name: str, callback: Callable
) -> Callable:
"""Register an event callback."""
listeners: list = self._listeners.setdefault(event_name, [])
listeners.append(callback)

def unsubscribe() -> None:
"""Unsubscribe listeners."""
if callback in listeners:
listeners.remove(callback)

return unsubscribe
28 changes: 2 additions & 26 deletions pylitterbot/robot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
from aiohttp import ClientWebSocketResponse
from deepdiff import DeepDiff

from ..event import EVENT_UPDATE, Event
from ..utils import to_timestamp, urljoin

if TYPE_CHECKING:
from ..account import Account

_LOGGER = logging.getLogger(__name__)

EVENT_UPDATE = "update"


class Robot:
class Robot(Event):
"""Robot base class."""

_attr_model: str
Expand All @@ -38,7 +37,6 @@ def __init__(self, data: dict, account: Account) -> None:
self._account = account

self._is_loaded = False
self._listeners: dict[str, list[Callable]] = {}

self._ws: ClientWebSocketResponse | None = None
self._ws_subscription_id: str | None = None
Expand Down Expand Up @@ -102,28 +100,6 @@ def setup_date(self) -> datetime | None:
"""Return the datetime the robot was onboarded, if any."""
return to_timestamp(self._data.get(self._data_setup_date))

def emit(self, event_name: str, *args: Any, **kwargs: Any) -> None:
"""Run all callbacks for an event."""
for listener in self._listeners.get(event_name, []):
try:
listener(*args, **kwargs)
except: # pragma: no cover # pylint: disable=bare-except # noqa: E722
pass

def on( # pylint: disable=invalid-name
self, event_name: str, callback: Callable
) -> Callable:
"""Register an event callback."""
listeners: list = self._listeners.setdefault(event_name, [])
listeners.append(callback)

def unsubscribe() -> None:
"""Unsubscribe listeners."""
if callback in listeners:
listeners.remove(callback)

return unsubscribe

@abstractmethod
async def refresh(self) -> None:
"""Refresh the robot data from the API."""
Expand Down
5 changes: 4 additions & 1 deletion pylitterbot/robot/litterrobot4.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
"robotCycleStatusIdle": LitterBoxStatus.CLEAN_CYCLE_COMPLETE,
"robotStatusCatDetect": LitterBoxStatus.CAT_DETECTED,
}
CYCLE_STATE_STATUS_MAP = {"CYCLE_STATE_PAUSE": LitterBoxStatus.PAUSED}
CYCLE_STATE_STATUS_MAP = {
"CYCLE_STATE_CAT_DETECT": LitterBoxStatus.CAT_SENSOR_INTERRUPTED,
"CYCLE_STATE_PAUSE": LitterBoxStatus.PAUSED,
}

LITTER_LEVEL_EMPTY = 500

Expand Down
56 changes: 43 additions & 13 deletions pylitterbot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,28 @@

import logging
from abc import ABC, abstractmethod
from threading import Lock
from types import TracebackType
from typing import Any, TypeVar, cast

import jwt
from aiohttp import ClientSession

from .event import EVENT_UPDATE, Event
from .exceptions import InvalidCommandException
from .utils import decode, redact, utcnow
from .utils import decode, first_value, redact, utcnow

T = TypeVar("T", bound="Session")

_LOGGER = logging.getLogger(__name__)


class Session(ABC):
class Session(Event, ABC):
"""Abstract session class."""

_token: dict | None = None
_lock = Lock()

def __init__(self, websession: ClientSession | None = None) -> None:
"""Initialize the session."""
self._websession_provided = websession is not None
Expand All @@ -32,6 +37,18 @@ def websession(self) -> ClientSession:
self._websession = ClientSession()
return self._websession

@property
def tokens(self) -> dict[str, str | None] | None:
"""Return the tokens."""
if not self._token:
return None
return {
"id_token": first_value(self._token, ("id_token", "idToken")),
"refresh_token": first_value(
self._token, ("refresh_token", "refreshToken")
),
}

async def close(self) -> None:
"""Close the session."""
if not self._websession_provided and self.websession is not None:
Expand All @@ -57,9 +74,19 @@ async def async_get_access_token(self, **kwargs: Any) -> str | None:
def is_token_valid(self) -> bool:
"""Return `True` if the token is stills valid."""

@abstractmethod
async def refresh_token(self) -> None:
async def refresh_token(self, ignore_unexpired: bool = False) -> None:
"""Refresh the access token."""
if self._token is None:
return None
with self._lock:
if not ignore_unexpired and self.is_token_valid():
return
self._token = await self._refresh_token()
self.emit(EVENT_UPDATE)

@abstractmethod
async def _refresh_token(self) -> dict:
"""Actual implementation to refresh the access token."""

async def get_bearer_authorization(self) -> str | None:
"""Get the bearer authorization."""
Expand Down Expand Up @@ -139,6 +166,7 @@ def __init__(

self._token = token
self._custom_args: dict = {}
self._lock = Lock()

def generate_args(self, url: str, **kwargs: Any) -> dict[str, Any]:
"""Generate args."""
Expand All @@ -157,7 +185,7 @@ def is_token_valid(self) -> bool:
return False
try:
jwt.decode(
self._token.get("access_token", self._token.get("idToken")),
first_value(self._token, ("id_token", "idToken")),
options={"verify_signature": False, "verify_exp": True},
leeway=-30,
)
Expand All @@ -169,7 +197,7 @@ async def async_get_access_token(self, **kwargs: Any) -> str | None:
"""Return a valid access token."""
if self._token is None or not self.is_token_valid():
return None
return self._token.get("access_token", self._token.get("idToken"))
return first_value(self._token, ("id_token", "idToken"))

async def login(self, username: str, password: str) -> None:
"""Login to the Litter-Robot api and generate a new token."""
Expand All @@ -189,23 +217,21 @@ async def login(self, username: str, password: str) -> None:
)
self._token = cast(dict, data)

async def refresh_token(self) -> None:
async def _refresh_token(self) -> dict:
"""Refresh the access token."""
if self._token is None:
return None
data = await self.post(
self.TOKEN_REFRESH_ENDPOINT,
skip_auth=True,
headers={"x-ios-bundle-identifier": "com.whisker.ios"},
params={"key": decode(self.TOKEN_KEY)},
json={
"grantType": "refresh_token",
"refreshToken": self._token.get(
"refresh_token", self._token.get("refreshToken")
"refreshToken": first_value(
self._token, ("refresh_token", "refreshToken")
),
},
)
self._token = cast(dict, data)
return cast(dict, data)

async def request(
self, method: str, url: str, **kwargs: Any
Expand All @@ -221,7 +247,11 @@ def get_user_id(self) -> str | None:
if self._token is None:
return None
user_id = jwt.decode(
self._token.get("idToken"),
first_value(self._token, ("id_token", "idToken")),
options={"verify_signature": False, "verify_exp": False},
)["mid"]
return cast(str, user_id)

def has_refresh_token(self) -> bool:
"""Return `True` if the session has a refresh token."""
return first_value(self._token, ("refresh_token", "refreshToken")) is not None
22 changes: 20 additions & 2 deletions pylitterbot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging
import re
from base64 import b64decode, b64encode
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from datetime import datetime, time, timezone
from typing import TypeVar, cast, overload
from typing import Any, TypeVar, cast, overload
from urllib.parse import urljoin as _urljoin
from warnings import warn

Expand All @@ -18,7 +18,10 @@
REDACTED = "**REDACTED**"
REDACT_FIELDS = [
"token",
"access_token",
"id_token",
"idToken",
"refresh_token",
"refreshToken",
"userId",
"userEmail",
Expand Down Expand Up @@ -133,3 +136,18 @@ def redact(data: _T) -> _T:
redacted[key] = [redact(item) for item in value]

return cast(_T, redacted)


def first_value(
data: dict | None,
keys: Iterable,
default: Any | None = None,
return_none: bool = False,
) -> Any | None:
"""Return the first valid key's value."""
if not data:
return default
for key in keys:
if key in data and ((value := data[key]) is not None or return_none):
return value
return default
2 changes: 1 addition & 1 deletion tests/test_account.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test account module."""
from __future__ import annotations
import logging

import logging
from unittest.mock import patch

import pytest
Expand Down
4 changes: 4 additions & 0 deletions tests/test_litterrobot4.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ async def test_litter_robot_4_cleaning(mock_account: Account) -> None:
{"robotCycleState": "CYCLE_STATE_PAUSE", "robotStatus": "ROBOT_CLEAN"},
LitterBoxStatus.PAUSED,
),
(
{"robotCycleState": "CYCLE_STATE_CAT_DETECT"},
LitterBoxStatus.CAT_SENSOR_INTERRUPTED,
),
({"robotStatus": "ROBOT_BONNET"}, LitterBoxStatus.BONNET_REMOVED),
({"robotStatus": "ROBOT_CAT_DETECT"}, LitterBoxStatus.CAT_DETECTED),
({"robotStatus": "ROBOT_CAT_DETECT_DELAY"}, LitterBoxStatus.CAT_SENSOR_TIMING),
Expand Down
9 changes: 6 additions & 3 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
pytestmark = pytest.mark.asyncio

EXPIRED_ACCESS_TOKEN = {
"access_token": jwt.encode(
{"exp": datetime.now(tz=timezone.utc) - timedelta(hours=1)},
"secret",
"access_token": (
token := jwt.encode(
{"exp": datetime.now(tz=timezone.utc) - timedelta(hours=1)},
"secret",
)
),
"id_token": token,
"refresh_token": "some_refresh_token",
}

Expand Down
13 changes: 12 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test utils module."""
from pylitterbot.utils import decode, encode, round_time, to_timestamp

from pylitterbot.utils import REDACTED, decode, encode, redact, round_time, to_timestamp


def test_round_time_default() -> None:
Expand All @@ -20,3 +21,13 @@ def test_encode_decode() -> None:
assert (encoded := encode(value)) == "dGVzdA=="
assert decode(encoded) == value
assert encode({value: value}) == "eyJ0ZXN0IjogInRlc3QifQ=="


def test_redact() -> None:
"""Test redacting values from a dictionary."""
assert redact({"litterRobotId": None}) == {"litterRobotId": None}
assert redact({"litterRobotId": "someId"}) == {"litterRobotId": REDACTED}

data = {"key": "value"}
assert redact(data) == data
assert redact([data, data]) == [data, data]