diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 7ffb723..b32f2c2 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -1047,7 +1047,7 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]: rcs = [] while True: - rc = self._wait_for_msg() + rc = self._wait_for_msg(timeout=timeout) if rc is not None: rcs.append(rc) if self.get_monotonic_time() - stamp > timeout: @@ -1056,11 +1056,13 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]: return rcs if rcs else None - def _wait_for_msg(self) -> Optional[int]: + def _wait_for_msg(self, timeout: Optional[float] = None) -> Optional[int]: # pylint: disable = too-many-return-statements """Reads and processes network events. Return the packet type or None if there is nothing to be received. + + :param float timeout: return after this timeout, in seconds. """ # CPython socket module contains a timeout attribute if hasattr(self._socket_pool, "timeout"): @@ -1070,7 +1072,7 @@ def _wait_for_msg(self) -> Optional[int]: return None else: # socketpool, esp32spi try: - res = self._sock_exact_recv(1) + res = self._sock_exact_recv(1, timeout=timeout) except OSError as error: if error.errno in (errno.ETIMEDOUT, errno.EAGAIN): # raised by a socket timeout if 0 bytes were present @@ -1139,7 +1141,9 @@ def _decode_remaining_length(self) -> int: return n sh += 7 - def _sock_exact_recv(self, bufsize: int) -> bytearray: + def _sock_exact_recv( + self, bufsize: int, timeout: Optional[float] = None + ) -> bytearray: """Reads _exact_ number of bytes from the connected socket. Will only return bytearray with the exact number of bytes requested. @@ -1150,6 +1154,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray: bytes is returned or trigger a timeout exception. :param int bufsize: number of bytes to receive + :param float timeout: timeout, in seconds. Defaults to keep_alive :return: byte array """ stamp = self.get_monotonic_time() @@ -1161,7 +1166,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray: to_read = bufsize - recv_len if to_read < 0: raise MMQTTException(f"negative number of bytes to read: {to_read}") - read_timeout = self.keep_alive + read_timeout = timeout if timeout is not None else self.keep_alive mv = mv[recv_len:] while to_read > 0: recv_len = self._sock.recv_into(mv, to_read) diff --git a/tests/test_loop.py b/tests/test_loop.py index 6a762fd..ccca924 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -21,9 +21,9 @@ class Loop(TestCase): INITIAL_RCS_VAL = 42 rcs_val = INITIAL_RCS_VAL - def fake_wait_for_msg(self): + def fake_wait_for_msg(self, timeout=1): """_wait_for_msg() replacement. Sleeps for 1 second and returns an integer.""" - time.sleep(1) + time.sleep(timeout) retval = self.rcs_val self.rcs_val += 1 return retval @@ -62,7 +62,7 @@ def test_loop_basic(self) -> None: # Check the return value. assert rcs is not None - assert len(rcs) > 1 + assert len(rcs) >= 1 expected_rc = self.INITIAL_RCS_VAL for ret_code in rcs: assert ret_code == expected_rc