Skip to content

Commit

Permalink
Fix support for Python 3.7
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreF committed Jan 1, 2024
1 parent 4102863 commit 148d7c0
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 129 deletions.
193 changes: 100 additions & 93 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,34 @@
else:
EAGAIN = errno.EAGAIN

if typing.TYPE_CHECKING:
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict

from typing_extensions import Literal

class _InPacket(TypedDict):
command: int
have_remaining: int
remaining_count: typing.List[int]
remaining_mult: int
remaining_length: int
packet: bytearray
to_process: int
pos: int

class _OutPacket(TypedDict):
command: int
mid: int
qos: int
pos: int
to_process: int
packet: bytes
info: typing.Optional["MQTTMessageInfo"]


MQTTv31 = 3
MQTTv311 = 4
MQTTv5 = 5
Expand Down Expand Up @@ -191,7 +219,7 @@ class MQTTErrorCode(enum.IntEnum):
# * None is converted to a zero-length payload (i.e. b"")
PayloadType = typing.Union[str, bytes, bytearray, int, float, None]

HTTPHeader = dict[str, str]
HTTPHeader = typing.Dict[str, str]
WebSocketHeaders = typing.Union[typing.Callable[[HTTPHeader], HTTPHeader], HTTPHeader]

SocketLike = typing.Union[socket.socket, "ssl.SSLSocket", "WebsocketWrapper"]
Expand All @@ -204,9 +232,12 @@ class MQTTErrorCode(enum.IntEnum):
CallbackOnConnectFail = typing.Callable[["Client", typing.Any], None]
CallbackOnDisconnect = typing.Union[
typing.Callable[
["Client", typing.Any, dict[str, typing.Any], ReasonCodes, Properties], None
["Client", typing.Any, typing.Dict[str, typing.Any], ReasonCodes, Properties],
None,
],
typing.Callable[
["Client", typing.Any, typing.Dict[str, typing.Any], MQTTErrorCode], None
],
typing.Callable[["Client", typing.Any, dict[str, typing.Any], MQTTErrorCode], None],
]
CallbackOnLog = typing.Callable[["Client", typing.Any, int, str], None]
CallbackOnMessage = typing.Callable[["Client", typing.Any, "MQTTMessage"], None]
Expand All @@ -215,9 +246,9 @@ class MQTTErrorCode(enum.IntEnum):
CallbackOnSocket = typing.Callable[["Client", typing.Any, SocketLike], None]
CallbackOnSubscribe = typing.Union[
typing.Callable[
["Client", typing.Any, Properties, list[ReasonCodes], Properties], None
["Client", typing.Any, Properties, typing.List[ReasonCodes], Properties], None
],
typing.Callable[["Client", typing.Any, int, tuple[int, ...]], None],
typing.Callable[["Client", typing.Any, int, typing.Tuple[int, ...]], None],
]
CallbackOnUnsubscribe = typing.Union[
typing.Callable[["Client", typing.Any, Properties, ReasonCodes], None],
Expand All @@ -228,27 +259,6 @@ class MQTTErrorCode(enum.IntEnum):
_socket = socket


class _InPacket(typing.TypedDict):
command: int
have_remaining: int
remaining_count: list[int]
remaining_mult: int
remaining_length: int
packet: bytearray
to_process: int
pos: int


class _OutPacket(typing.TypedDict):
command: int
mid: int
qos: int
pos: int
to_process: int
packet: bytes
info: typing.Optional["MQTTMessageInfo"]


class WebsocketConnectionError(ValueError):
pass

Expand Down Expand Up @@ -342,7 +352,7 @@ def topic_matches_sub(sub: str, topic: str) -> bool:
return False


def _socketpair_compat() -> tuple[socket.socket, socket.socket]:
def _socketpair_compat() -> typing.Tuple[socket.socket, socket.socket]:
"""TCP/IP socketpair including Windows support"""
listensock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP)
listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
Expand Down Expand Up @@ -665,19 +675,18 @@ def __init__(

self._username: typing.Optional[bytes] = None
self._password: typing.Optional[bytes] = None
self._in_packet = _InPacket(
{
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}
)
self._out_packet: collections.deque[_OutPacket] = collections.deque()
self._in_packet: "_InPacket" = {
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}

self._out_packet: typing.Deque["_OutPacket"] = collections.deque()
self._last_msg_in = time_func()
self._last_msg_out = time_func()
self._reconnect_min_delay = 1
Expand Down Expand Up @@ -1029,7 +1038,7 @@ def connect(
bind_address: str = "",
bind_port: int = 0,
clean_start: typing.Union[
bool, typing.Literal[3]
bool, "Literal[3]"
] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore
properties: typing.Optional[Properties] = None,
) -> MQTTErrorCode:
Expand Down Expand Up @@ -1071,7 +1080,7 @@ def connect_srv(
bind_address: str = "",
bind_port: int = 0,
clean_start: typing.Union[
bool, typing.Literal[3]
bool, "Literal[3]"
] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore
properties: typing.Optional[Properties] = None,
) -> MQTTErrorCode:
Expand Down Expand Up @@ -1134,7 +1143,7 @@ def connect_async(
bind_address: str = "",
bind_port: int = 0,
clean_start: typing.Union[
bool, typing.Literal[3]
bool, "Literal[3]"
] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore
properties: typing.Optional[Properties] = None,
) -> None:
Expand Down Expand Up @@ -1195,18 +1204,16 @@ def reconnect(self) -> MQTTErrorCode:
if self._port <= 0:
raise ValueError("Invalid port number.")

self._in_packet = _InPacket(
{
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}
)
self._in_packet = {
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}

self._out_packet = collections.deque()

Expand Down Expand Up @@ -1613,15 +1620,15 @@ def subscribe(
self,
topic: typing.Union[
str,
tuple[str, int],
tuple[str, SubscribeOptions],
list[tuple[str, int]],
list[tuple[str, SubscribeOptions]],
typing.Tuple[str, int],
typing.Tuple[str, SubscribeOptions],
typing.List[typing.Tuple[str, int]],
typing.List[typing.Tuple[str, SubscribeOptions]],
],
qos: int = 0,
options: typing.Optional[SubscribeOptions] = None,
properties: typing.Optional[Properties] = None,
) -> tuple[MQTTErrorCode, typing.Optional[int]]:
) -> typing.Tuple[MQTTErrorCode, typing.Optional[int]]:
"""Subscribe the client to one or more topics.
This function may be called in three different ways (and a further three for MQTT v5.0):
Expand Down Expand Up @@ -1767,7 +1774,7 @@ def subscribe(

def unsubscribe(
self, topic: str, properties: typing.Optional[Properties] = None
) -> tuple[MQTTErrorCode, typing.Optional[int]]:
) -> typing.Tuple[MQTTErrorCode, typing.Optional[int]]:
"""Unsubscribe the client from one or more topics.
topic: A single string, or list of strings that are the subscription
Expand Down Expand Up @@ -2836,18 +2843,16 @@ def _packet_read(self) -> MQTTErrorCode:
rc = self._packet_handle()

# Free data and reset values
self._in_packet = _InPacket(
{
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}
)
self._in_packet = {
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": bytearray(b""),
"to_process": 0,
"pos": 0,
}

with self._msgtime_mutex:
self._last_msg_in = time_func()
Expand Down Expand Up @@ -3317,9 +3322,11 @@ def _send_disconnect(
def _send_subscribe(
self,
dup: int,
topics: typing.Sequence[tuple[bytes, typing.Union[SubscribeOptions, int]]],
topics: typing.Sequence[
typing.Tuple[bytes, typing.Union[SubscribeOptions, int]]
],
properties: typing.Optional[Properties] = None,
) -> tuple[MQTTErrorCode, int]:
) -> typing.Tuple[MQTTErrorCode, int]:
remaining_length = 2
if self._protocol == MQTTv5:
if properties is None:
Expand Down Expand Up @@ -3359,9 +3366,9 @@ def _send_subscribe(
def _send_unsubscribe(
self,
dup: int,
topics: list[bytes],
topics: typing.List[bytes],
properties: typing.Optional[Properties] = None,
) -> tuple[MQTTErrorCode, int]:
) -> typing.Tuple[MQTTErrorCode, int]:
remaining_length = 2
if self._protocol == MQTTv5:
if properties is None:
Expand Down Expand Up @@ -3471,17 +3478,15 @@ def _packet_queue(
qos: int,
info: typing.Optional[MQTTMessageInfo] = None,
) -> MQTTErrorCode:
mpkt = _OutPacket(
{
"command": command,
"mid": mid,
"qos": qos,
"pos": 0,
"to_process": len(packet),
"packet": packet,
"info": info,
}
)
mpkt: "_OutPacket" = {
"command": command,
"mid": mid,
"qos": qos,
"pos": 0,
"to_process": len(packet),
"packet": packet,
"info": info,
}

self._out_packet.append(mpkt)

Expand Down Expand Up @@ -3702,7 +3707,7 @@ def _handle_connack(self) -> MQTTErrorCode:
else:
return MQTTErrorCode.MQTT_ERR_PROTOCOL

def _handle_disconnect(self) -> typing.Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]:
def _handle_disconnect(self) -> "Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]":
packet_type = DISCONNECT >> 4
reasonCode = properties = None
if self._in_packet["remaining_length"] > 2:
Expand All @@ -3719,7 +3724,7 @@ def _handle_disconnect(self) -> typing.Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]:

return MQTTErrorCode.MQTT_ERR_SUCCESS

def _handle_suback(self) -> typing.Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]:
def _handle_suback(self) -> "Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]":
self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK")
pack_format = f"!H{len(self._in_packet['packet']) - 2}s"
(mid, packet) = struct.unpack(pack_format, self._in_packet["packet"])
Expand Down Expand Up @@ -3957,7 +3962,9 @@ def _handle_unsuback(self) -> MQTTErrorCode:
for c in packet[props_len:]:
reasoncodes_list.append(ReasonCodes(UNSUBACK >> 4, identifier=c))

reasoncodes: typing.Union[ReasonCodes, list[ReasonCodes]] = reasoncodes_list
reasoncodes: typing.Union[
ReasonCodes, typing.List[ReasonCodes]
] = reasoncodes_list
if len(reasoncodes_list) == 1:
reasoncodes = reasoncodes_list[0]

Expand Down Expand Up @@ -4031,7 +4038,7 @@ def _do_on_publish(self, mid: int) -> MQTTErrorCode:
return MQTTErrorCode.MQTT_ERR_SUCCESS

def _handle_pubackcomp(
self, cmd: typing.Union[typing.Literal["PUBACK"], typing.Literal["PUBCOMP"]]
self, cmd: typing.Union["Literal['PUBACK']", "Literal['PUBCOMP']"]
) -> MQTTErrorCode:
if self._protocol == MQTTv5:
if self._in_packet["remaining_length"] < 2:
Expand Down Expand Up @@ -4156,7 +4163,7 @@ def check(t, a) -> bool: # type: ignore[no-untyped-def]
else:
return False

def _get_proxy(self) -> typing.Optional[dict[str, typing.Any]]:
def _get_proxy(self) -> typing.Optional[typing.Dict[str, typing.Any]]:
if socks is None:
return None

Expand Down
Loading

0 comments on commit 148d7c0

Please sign in to comment.