From 624ec138b2064f0d2eebcb2191c12f1898e43ea9 Mon Sep 17 00:00:00 2001 From: Erik Date: Sun, 28 Jan 2024 11:17:51 +0100 Subject: [PATCH] Add type annotations to socket_client.py --- mypy.ini | 2 +- pychromecast/socket_client.py | 285 +++++++++++++++++++--------------- 2 files changed, 162 insertions(+), 125 deletions(-) diff --git a/mypy.ini b/mypy.ini index 8d64b0f07..cd3d64138 100644 --- a/mypy.ini +++ b/mypy.ini @@ -19,4 +19,4 @@ disallow_untyped_decorators = true disallow_untyped_defs = true warn_return_any = true warn_unreachable = true -files = pychromecast/config.py, pychromecast/const.py, pychromecast/dial.py, pychromecast/discovery.py, pychromecast/error.py, pychromecast/models.py, pychromecast/response_handler.py, pychromecast/controllers/__init__.py, pychromecast/controllers/media.py, pychromecast/controllers/multizone.py, pychromecast/controllers/receiver.py +files = pychromecast/config.py, pychromecast/const.py, pychromecast/dial.py, pychromecast/discovery.py, pychromecast/error.py, pychromecast/models.py, pychromecast/response_handler.py, pychromecast/socket_client.py, pychromecast/controllers/__init__.py, pychromecast/controllers/media.py, pychromecast/controllers/multizone.py, pychromecast/controllers/receiver.py diff --git a/pychromecast/socket_client.py b/pychromecast/socket_client.py index 7d742bc9b..7f8f54b20 100644 --- a/pychromecast/socket_client.py +++ b/pychromecast/socket_client.py @@ -6,6 +6,7 @@ Without him this would not have been possible. """ # pylint: disable=too-many-lines +from __future__ import annotations import abc from dataclasses import dataclass @@ -20,18 +21,24 @@ from collections import defaultdict from struct import pack, unpack -from .controllers import BaseController +import zeroconf + +from .controllers import CallbackType, BaseController from .controllers.media import MediaController -from .controllers.receiver import ReceiverController +from .controllers.receiver import CastStatus, CastStatusListener, ReceiverController from .const import MESSAGE_TYPE, REQUEST_ID, SESSION_ID from .dial import get_host_from_service from .error import ( ChromecastConnectionError, + ControllerNotRegistered, UnsupportedNamespace, NotConnected, PyChromecastStopped, ) -from .generated import cast_channel_pb2 + +# pylint: disable-next=no-name-in-module +from .generated.cast_channel_pb2 import CastMessage +from .models import HostServiceInfo, MDNSServiceInfo NS_CONNECTION = "urn:x-cast:com.google.cast.tp.connection" NS_HEARTBEAT = "urn:x-cast:com.google.cast.tp.heartbeat" @@ -63,15 +70,15 @@ HB_PONG_TIME = 10 POLL_TIME_BLOCKING = 5.0 POLL_TIME_NON_BLOCKING = 0.01 -TIMEOUT_TIME = 30 -RETRY_TIME = 5 +TIMEOUT_TIME = 30.0 +RETRY_TIME = 5.0 class InterruptLoop(Exception): """The chromecast has been manually stopped.""" -def _dict_from_message_payload(message): +def _dict_from_message_payload(message: CastMessage) -> dict: """Parses a PB2 message as a JSON dict.""" try: data = json.loads(message.payload_utf8) @@ -94,7 +101,10 @@ def _dict_from_message_payload(message): return {} -def _message_to_string(message, data=None): +def _message_to_string( + message: CastMessage, + data: dict | None = None, +) -> str: """Gives a string representation of a PB2 message.""" if data is None: data = _dict_from_message_payload(message) @@ -123,19 +133,20 @@ class ConnectionStatus: """Connection status container.""" status: str - address: NetworkAddress + address: NetworkAddress | None + service: HostServiceInfo | MDNSServiceInfo | None class ConnectionStatusListener(abc.ABC): """Listener for receiving connection status events.""" @abc.abstractmethod - def new_connection_status(self, status: ConnectionStatus): + def new_connection_status(self, status: ConnectionStatus) -> None: """Updated connection status.""" # pylint: disable-next=too-many-instance-attributes -class SocketClient(threading.Thread): +class SocketClient(threading.Thread, CastStatusListener): """ Class to interact with a Chromecast through a socket. @@ -167,13 +178,13 @@ class SocketClient(threading.Thread): def __init__( self, *, - cast_type, - tries, - timeout, - retry_wait, - services, - zconf, - ): + cast_type: str, + tries: int | None, + timeout: float | None, + retry_wait: float | None, + services: set[HostServiceInfo | MDNSServiceInfo], + zconf: zeroconf.Zeroconf | None, + ) -> None: super().__init__() self.daemon = True @@ -183,7 +194,7 @@ def __init__( self._force_recon = False self.cast_type = cast_type - self.fn = None # pylint:disable=invalid-name + self.fn: str | None = None # pylint:disable=invalid-name self.tries = tries self.timeout = timeout or TIMEOUT_TIME self.retry_wait = retry_wait or RETRY_TIME @@ -198,21 +209,20 @@ def __init__( # socketpair used to interrupt the worker thread self.socketpair = socket.socketpair() - self.app_namespaces = [] - self.destination_id = None - self.session_id = None + self.app_namespaces: list[str] = [] + self.destination_id: str | None = None + self.session_id: str | None = None self._request_id = 0 - # dict mapping requestId to callback functions - self._request_callbacks = {} - self._open_channels = [] + self._request_callbacks: dict[int, CallbackType] = {} + self._open_channels: list[str] = [] self.connecting = True self.first_connection = True - self.socket = None + self.socket: socket.socket | ssl.SSLSocket | None = None # dict mapping namespace on Controller objects - self._handlers = defaultdict(set) - self._connection_listeners = [] + self._handlers: dict[str, set[BaseController]] = defaultdict(set) + self._connection_listeners: list[ConnectionStatusListener] = [] self.receiver_controller = ReceiverController(cast_type) self.media_controller = MediaController() @@ -225,9 +235,9 @@ def __init__( self.receiver_controller.register_status_listener(self) - def initialize_connection( + def initialize_connection( # pylint:disable=too-many-statements, too-many-branches self, - ): # pylint:disable=too-many-statements, too-many-branches + ) -> None: """Initialize a socket to a Chromecast, retrying as necessary.""" tries = self.tries @@ -250,9 +260,12 @@ def initialize_connection( retry_log_fun = self.logger.error # Dict keeping track of individual retry delay for each named service - retries = {} + retries: dict[HostServiceInfo | MDNSServiceInfo, dict[str, float]] = {} - def mdns_backoff(service, retry): + def mdns_backoff( + service: HostServiceInfo | MDNSServiceInfo, + retry: dict[str, float], + ) -> None: """Exponentional backoff for service name mdns lookups.""" now = time.time() retry["next_retry"] = now + retry["delay"] @@ -278,7 +291,8 @@ def mdns_backoff(service, retry): continue try: if self.socket is not None: - self.socket.close() + # If we retry connecting, we need to clean up the socket again + self.socket.close() # type: ignore[unreachable] self.socket = None self.socket = new_socket() @@ -287,6 +301,7 @@ def mdns_backoff(service, retry): ConnectionStatus( CONNECTION_STATUS_CONNECTING, NetworkAddress(self.host, self.port), + None, ) ) # Resolve the service name. @@ -298,7 +313,8 @@ def mdns_backoff(service, retry): if host and port: if service_info: try: - self.fn = service_info.properties[b"fn"].decode("utf-8") + # Mypy does not understand that we catch errors, ignore it + self.fn = service_info.properties[b"fn"].decode("utf-8") # type: ignore[union-attr] except (AttributeError, KeyError, UnicodeError): pass self.logger.debug( @@ -323,7 +339,8 @@ def mdns_backoff(service, retry): self._report_connection_status( ConnectionStatus( CONNECTION_STATUS_FAILED_RESOLVE, - NetworkAddress(service, None), + None, + service, ) ) mdns_backoff(service, retry) @@ -348,6 +365,7 @@ def mdns_backoff(service, retry): ConnectionStatus( CONNECTION_STATUS_CONNECTED, NetworkAddress(self.host, self.port), + None, ) ) self.receiver_controller.update_status() @@ -390,26 +408,18 @@ def mdns_backoff(service, retry): ConnectionStatus( CONNECTION_STATUS_FAILED, NetworkAddress(self.host, self.port), + None, ) ) - if service is not None: - retry_log_fun( - "[%s(%s):%s] Failed to connect to service %s, retrying in %.1fs", - self.fn or "", - self.host, - self.port, - service, - retry["delay"], - ) - mdns_backoff(service, retry) - else: - retry_log_fun( - "[%s(%s):%s] Failed to connect, retrying in %.1fs", - self.fn or "", - self.host, - self.port, - self.retry_wait, - ) + retry_log_fun( + "[%s(%s):%s] Failed to connect to service %s, retrying in %.1fs", + self.fn or "", + self.host, + self.port, + service, + retry["delay"], + ) + mdns_backoff(service, retry) retry_log_fun = self.logger.debug # Only sleep if we have another retry remaining @@ -436,7 +446,7 @@ def mdns_backoff(service, retry): ) raise ChromecastConnectionError("Failed to connect") - def connect(self): + def connect(self) -> None: """Connect socket connection to Chromecast device. Must only be called if the worker thread will not be started. @@ -446,12 +456,14 @@ def connect(self): except ChromecastConnectionError: self._report_connection_status( ConnectionStatus( - CONNECTION_STATUS_DISCONNECTED, NetworkAddress(self.host, self.port) + CONNECTION_STATUS_DISCONNECTED, + NetworkAddress(self.host, self.port), + None, ) ) return - def disconnect(self): + def disconnect(self) -> None: """Disconnect socket connection to Chromecast device""" self.stop.set() try: @@ -461,13 +473,13 @@ def disconnect(self): # The socketpair may already be closed during shutdown, ignore it pass - def register_handler(self, handler: BaseController): + def register_handler(self, handler: BaseController) -> None: """Register a new namespace handler.""" self._handlers[handler.namespace].add(handler) handler.registered(self) - def unregister_handler(self, handler: BaseController): + def unregister_handler(self, handler: BaseController) -> None: """Register a new namespace handler.""" if ( handler.namespace in self._handlers @@ -477,18 +489,18 @@ def unregister_handler(self, handler: BaseController): handler.unregistered() - def new_cast_status(self, cast_status): + def new_cast_status(self, status: CastStatus) -> None: """Called when a new cast status has been received.""" - new_channel = self.destination_id != cast_status.transport_id + new_channel = self.destination_id != status.transport_id - if new_channel: + if new_channel and self.destination_id is not None: self.disconnect_channel(self.destination_id) - self.app_namespaces = cast_status.namespaces - self.destination_id = cast_status.transport_id - self.session_id = cast_status.session_id + self.app_namespaces = status.namespaces + self.destination_id = status.transport_id + self.session_id = status.session_id - if new_channel: + if new_channel and self.destination_id is not None: # If any of the namespaces of the new app are supported # we will automatically connect to it to receive updates for namespace in self.app_namespaces: @@ -497,14 +509,14 @@ def new_cast_status(self, cast_status): for handler in set(self._handlers[namespace]): handler.channel_connected() - def _gen_request_id(self): + def _gen_request_id(self) -> int: """Generates a unique request id.""" self._request_id += 1 return self._request_id @property - def is_connected(self): + def is_connected(self) -> bool: """ Returns True if the client is connected, False if it is stopped (or trying to connect). @@ -512,21 +524,23 @@ def is_connected(self): return not self.connecting @property - def is_stopped(self): + def is_stopped(self) -> bool: """ Returns True if the connection has been stopped, False if it is running. """ return self.stop.is_set() - def run(self): + def run(self) -> None: """Connect to the cast and start polling the socket.""" try: self.initialize_connection() except ChromecastConnectionError: self._report_connection_status( ConnectionStatus( - CONNECTION_STATUS_DISCONNECTED, NetworkAddress(self.host, self.port) + CONNECTION_STATUS_DISCONNECTED, + NetworkAddress(self.host, self.port), + None, ) ) return @@ -550,7 +564,7 @@ def run(self): # Clean up self._cleanup() - def run_once(self, timeout=POLL_TIME_NON_BLOCKING): + def run_once(self, timeout: float = POLL_TIME_NON_BLOCKING) -> int: """ Use run_once() in your own main loop after you receive something on the socket (get_socket()). @@ -563,6 +577,9 @@ def run_once(self, timeout=POLL_TIME_NON_BLOCKING): except ChromecastConnectionError: return 1 + # A connection has been established at this point by self._check_connection + assert self.socket is not None + # poll the socket, as well as the socketpair to allow us to be interrupted rlist = [self.socket, self.socketpair[0]] try: @@ -589,8 +606,8 @@ def run_once(self, timeout=POLL_TIME_NON_BLOCKING): self._force_recon = True return 0 - # read messages from chromecast - message = data = None + # read message from chromecast + message = None if self.socket in can_read and not self._force_recon: try: message = self._read_message() @@ -643,18 +660,18 @@ def run_once(self, timeout=POLL_TIME_NON_BLOCKING): self._route_message(message, data) if REQUEST_ID in data and data[REQUEST_ID] in self._request_callbacks: - self._request_callbacks.pop(data[REQUEST_ID], None)(True, data) + self._request_callbacks.pop(data[REQUEST_ID])(True, data) return 0 - def get_socket(self): + def get_socket(self) -> socket.socket | ssl.SSLSocket | None: """ Returns the socket of the connection to use it in you own main loop. """ return self.socket - def _check_connection(self): + def _check_connection(self) -> bool: """ Checks if the connection is active, and if not reconnect @@ -687,7 +704,7 @@ def _check_connection(self): self.disconnect_channel(channel) self._report_connection_status( ConnectionStatus( - CONNECTION_STATUS_LOST, NetworkAddress(self.host, self.port) + CONNECTION_STATUS_LOST, NetworkAddress(self.host, self.port), None ) ) try: @@ -697,7 +714,7 @@ def _check_connection(self): return False return True - def _route_message(self, message, data: dict): + def _route_message(self, message: CastMessage, data: dict) -> None: """Route message to any handlers on the message namespace""" # route message to handlers if message.namespace in self._handlers: @@ -747,7 +764,7 @@ def _route_message(self, message, data: dict): _message_to_string(message, data), ) - def _cleanup(self): + def _cleanup(self) -> None: """Cleanup open channels and handlers""" for channel in self._open_channels: try: @@ -771,7 +788,9 @@ def _cleanup(self): ) self._report_connection_status( ConnectionStatus( - CONNECTION_STATUS_DISCONNECTED, NetworkAddress(self.host, self.port) + CONNECTION_STATUS_DISCONNECTED, + NetworkAddress(self.host, self.port), + None, ) ) @@ -780,7 +799,7 @@ def _cleanup(self): self.connecting = True - def _report_connection_status(self, status): + def _report_connection_status(self, status: ConnectionStatus) -> None: """Report a change in the connection status to any listeners""" for listener in self._connection_listeners: try: @@ -802,8 +821,11 @@ def _report_connection_status(self, status): self.port, ) - def _read_bytes_from_socket(self, msglen): + def _read_bytes_from_socket(self, msglen: int) -> bytes: """Read bytes from the socket.""" + # It is a programming error if this is called when we don't have a socket + assert self.socket is not None + chunks = [] bytes_recd = 0 while bytes_recd < msglen: @@ -815,7 +837,7 @@ def _read_bytes_from_socket(self, msglen): raise socket.error("socket connection broken") chunks.append(chunk) bytes_recd += len(chunk) - except socket.timeout: + except TimeoutError: self.logger.debug( "[%s(%s):%s] timeout in : _read_bytes_from_socket", self.fn or "", @@ -825,7 +847,7 @@ def _read_bytes_from_socket(self, msglen): continue return b"".join(chunks) - def _read_message(self): + def _read_message(self) -> CastMessage: """Reads a message from the socket and converts it to a message.""" # first 4 bytes is Big-Endian payload length payload_info = self._read_bytes_from_socket(4) @@ -834,7 +856,7 @@ def _read_message(self): # now read the payload payload = self._read_bytes_from_socket(read_len) - message = cast_channel_pb2.CastMessage() # pylint: disable=no-member + message = CastMessage() message.ParseFromString(payload) return message @@ -842,15 +864,15 @@ def _read_message(self): # pylint: disable=too-many-arguments, too-many-branches def send_message( self, - destination_id, - namespace, - data, + destination_id: str, + namespace: str, + data: dict, *, - inc_session_id=False, - callback_function=None, - no_add_request_id=False, - force=False, - ): + inc_session_id: bool = False, + callback_function: CallbackType | None = None, + no_add_request_id: bool = False, + force: bool = False, + ) -> None: """Send a message to the Chromecast.""" # namespace is a string containing namespace @@ -860,7 +882,6 @@ def send_message( # If channel is not open yet, connect to it. self._ensure_channel_connected(destination_id) - request_id = None if not no_add_request_id: request_id = self._gen_request_id() data[REQUEST_ID] = request_id @@ -868,16 +889,14 @@ def send_message( if inc_session_id: data[SESSION_ID] = self.session_id - msg = cast_channel_pb2.CastMessage() # pylint: disable=no-member + msg = CastMessage() msg.protocol_version = msg.CASTV2_1_0 msg.source_id = self.source_id msg.destination_id = destination_id - msg.payload_type = ( - cast_channel_pb2.CastMessage.STRING # pylint: disable=no-member - ) + msg.payload_type = CastMessage.STRING msg.namespace = namespace - msg.payload_utf8 = _json_to_payload(data) + msg.payload_utf8 = _json_to_payload(data) # type: ignore[assignment] # prepend message with Big-Endian 4 byte payload size be_size = pack(">I", msg.ByteSize()) @@ -897,6 +916,9 @@ def send_message( callback_function(False, None) raise PyChromecastStopped("Socket client's thread is stopped.") if not self.connecting and not self._force_recon: + # We have a socket + assert self.socket is not None + try: if callback_function: if not no_add_request_id: @@ -907,7 +929,8 @@ def send_message( except socket.error: if callback_function: callback_function(False, None) - self._request_callbacks.pop(request_id, None) + if not no_add_request_id: + self._request_callbacks.pop(request_id, None) self._force_recon = True self.logger.info( "[%s(%s):%s] Error writing to socket.", @@ -922,13 +945,13 @@ def send_message( def send_platform_message( self, - namespace, - message, + namespace: str, + message: dict, *, - inc_session_id=False, - callback_function=None, - no_add_request_id=False, - ): + inc_session_id: bool = False, + callback_function: CallbackType | None = None, + no_add_request_id: bool = False, + ) -> None: """Helper method to send a message to the platform.""" return self.send_message( PLATFORM_DESTINATION_ID, @@ -941,13 +964,13 @@ def send_platform_message( def send_app_message( self, - namespace, - message, + namespace: str, + message: dict, *, - inc_session_id=False, - callback_function=None, - no_add_request_id=False, - ): + inc_session_id: bool = False, + callback_function: CallbackType | None = None, + no_add_request_id: bool = False, + ) -> None: """Helper method to send a message to current running app.""" if namespace not in self.app_namespaces: raise UnsupportedNamespace( @@ -955,6 +978,11 @@ def send_app_message( f"Supported are {', '.join(self.app_namespaces)}" ) + if self.destination_id is None: + raise NotConnected( + "Attempting send a message when destination_id is not set" + ) + return self.send_message( self.destination_id, namespace, @@ -964,13 +992,13 @@ def send_app_message( no_add_request_id=no_add_request_id, ) - def register_connection_listener(self, listener: ConnectionStatusListener): + def register_connection_listener(self, listener: ConnectionStatusListener) -> None: """Register a connection listener for when the socket connection changes. Listeners will be called with listener.new_connection_status(status)""" self._connection_listeners.append(listener) - def _ensure_channel_connected(self, destination_id): + def _ensure_channel_connected(self, destination_id: str) -> None: """Ensure we opened a channel to destination_id.""" if destination_id not in self._open_channels: self._open_channels.append(destination_id) @@ -994,7 +1022,7 @@ def _ensure_channel_connected(self, destination_id): no_add_request_id=True, ) - def disconnect_channel(self, destination_id): + def disconnect_channel(self, destination_id: str) -> None: """Disconnect a channel with destination_id.""" if destination_id in self._open_channels: try: @@ -1016,7 +1044,7 @@ def disconnect_channel(self, destination_id): self.handle_channel_disconnected() - def handle_channel_disconnected(self): + def handle_channel_disconnected(self) -> None: """Handles a channel being disconnected.""" for namespace in self.app_namespaces: if namespace in self._handlers: @@ -1031,15 +1059,18 @@ def handle_channel_disconnected(self): class ConnectionController(BaseController): """Controller to respond to connection messages.""" - def __init__(self): + def __init__(self) -> None: super().__init__(NS_CONNECTION) - def receive_message(self, message, data: dict): + def receive_message(self, message: CastMessage, data: dict) -> bool: """ Called when a message is received. data is message.payload_utf8 interpreted as a JSON dict. """ + if self._socket_client is None: + raise ControllerNotRegistered + if self._socket_client.is_stopped: return True @@ -1058,17 +1089,20 @@ def receive_message(self, message, data: dict): class HeartbeatController(BaseController): """Controller to respond to heartbeat messages.""" - def __init__(self): + def __init__(self) -> None: super().__init__(NS_HEARTBEAT, target_platform=True) - self.last_ping = 0 + self.last_ping = 0.0 self.last_pong = time.time() - def receive_message(self, _message, data: dict): + def receive_message(self, _message: CastMessage, data: dict) -> bool: """ Called when a heartbeat message is received. data is message.payload_utf8 interpreted as a JSON dict. """ + if self._socket_client is None: + raise ControllerNotRegistered + if self._socket_client.is_stopped: return True @@ -1094,8 +1128,11 @@ def receive_message(self, _message, data: dict): return False - def ping(self): + def ping(self) -> None: """Send a ping message.""" + if self._socket_client is None: + raise ControllerNotRegistered + self.last_ping = time.time() try: self.send_message({MESSAGE_TYPE: TYPE_PING}) @@ -1104,11 +1141,11 @@ def ping(self): "Chromecast is disconnected. Cannot ping until reconnected." ) - def reset(self): + def reset(self) -> None: """Reset expired counter.""" self.last_pong = time.time() - def is_expired(self): + def is_expired(self) -> bool: """Indicates if connection has expired.""" if time.time() - self.last_ping > HB_PING_TIME: self.ping() @@ -1116,7 +1153,7 @@ def is_expired(self): return (time.time() - self.last_pong) > HB_PING_TIME + HB_PONG_TIME -def new_socket(): +def new_socket() -> socket.socket: """ Create a new socket with OS-specific parameters