diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index b591372..9dc8f51 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -35,7 +35,7 @@ if not sys.implementation.name == "circuitpython": - from typing import Optional, Tuple + from typing import List, Optional, Tuple from circuitpython_typing.socket import ( CircuitPythonSocketType, @@ -64,15 +64,14 @@ def connect(self, address: Tuple[str, int]) -> None: try: return self._socket.connect(address, self._mode) except RuntimeError as error: - raise OSError(errno.ENOMEM) from error + raise OSError(errno.ENOMEM, str(error)) from error class _FakeSSLContext: def __init__(self, iface: InterfaceType) -> None: self._iface = iface - # pylint: disable=unused-argument - def wrap_socket( + def wrap_socket( # pylint: disable=unused-argument self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None ) -> _FakeSSLSocket: """Return the same socket""" @@ -99,7 +98,8 @@ def create_fake_ssl_context( return _FakeSSLContext(iface) -_global_socketpool = {} +_global_connection_managers = {} +_global_socketpools = {} _global_ssl_contexts = {} @@ -113,7 +113,7 @@ def get_radio_socketpool(radio): * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ class_name = radio.__class__.__name__ - if class_name not in _global_socketpool: + if class_name not in _global_socketpools: if class_name == "Radio": import ssl # pylint: disable=import-outside-toplevel @@ -151,10 +151,10 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") - _global_socketpool[class_name] = pool + _global_socketpools[class_name] = pool _global_ssl_contexts[class_name] = ssl_context - return _global_socketpool[class_name] + return _global_socketpools[class_name] def get_radio_ssl_context(radio): @@ -183,42 +183,75 @@ def __init__( ) -> None: self._socket_pool = socket_pool # Hang onto open sockets so that we can reuse them. - self._available_socket = {} - self._open_sockets = {} - - def _free_sockets(self) -> None: - available_sockets = [] - for socket, free in self._available_socket.items(): - if free: - available_sockets.append(socket) + self._available_sockets = set() + self._key_by_managed_socket = {} + self._managed_socket_by_key = {} + def _free_sockets(self, force: bool = False) -> None: + # cloning lists since items are being removed + available_sockets = list(self._available_sockets) for socket in available_sockets: self.close_socket(socket) + if force: + open_sockets = list(self._managed_socket_by_key.values()) + for socket in open_sockets: + self.close_socket(socket) - def _get_key_for_socket(self, socket): + def _get_connected_socket( # pylint: disable=too-many-arguments + self, + addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], + host: str, + port: int, + timeout: float, + is_ssl: bool, + ssl_context: Optional[SSLContextType] = None, + ): try: - return next( - key for key, value in self._open_sockets.items() if value == socket - ) - except StopIteration: - return None + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) + except (OSError, RuntimeError) as exc: + return exc + + if is_ssl: + socket = ssl_context.wrap_socket(socket, server_hostname=host) + connect_host = host + else: + connect_host = addr_info[-1][0] + socket.settimeout(timeout) # socket read timeout + + try: + socket.connect((connect_host, port)) + except (MemoryError, OSError) as exc: + socket.close() + return exc + + return socket + + @property + def available_socket_count(self) -> int: + """Get the count of freeable open sockets""" + return len(self._available_sockets) + + @property + def managed_socket_count(self) -> int: + """Get the count of open sockets""" + return len(self._managed_socket_by_key) def close_socket(self, socket: SocketType) -> None: """Close a previously opened socket.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - key = self._get_key_for_socket(socket) socket.close() - del self._available_socket[socket] - del self._open_sockets[key] + key = self._key_by_managed_socket.pop(socket) + del self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: """Mark a previously opened socket as available so it can be reused if needed.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - self._available_socket[socket] = True + self._available_sockets.add(socket) - # pylint: disable=too-many-branches,too-many-locals,too-many-statements def get_socket( self, host: str, @@ -234,10 +267,10 @@ def get_socket( if session_id: session_id = str(session_id) key = (host, port, proto, session_id) - if key in self._open_sockets: - socket = self._open_sockets[key] - if self._available_socket[socket]: - self._available_socket[socket] = False + if key in self._managed_socket_by_key: + socket = self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) return socket raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") @@ -253,64 +286,68 @@ def get_socket( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - try_count = 0 - socket = None - last_exc = None - while try_count < 2 and socket is None: - try_count += 1 - if try_count > 1: - if any( - socket - for socket, free in self._available_socket.items() - if free is True - ): - self._free_sockets() - else: - break - - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except OSError as exc: - last_exc = exc - continue - except RuntimeError as exc: - last_exc = exc - continue - - if is_ssl: - socket = ssl_context.wrap_socket(socket, server_hostname=host) - connect_host = host - else: - connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout - - try: - socket.connect((connect_host, port)) - except MemoryError as exc: - last_exc = exc - socket.close() - socket = None - except OSError as exc: - last_exc = exc - socket.close() - socket = None - - if socket is None: - raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc - - self._available_socket[socket] = False - self._open_sockets[key] = socket - return socket + first_exception = None + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + # Got an error, if there are any available sockets, free them and try again + if self.available_socket_count: + first_exception = result + self._free_sockets() + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + last_result = f", first error: {first_exception}" if first_exception else "" + raise RuntimeError( + f"Error connecting socket: {result}{last_result}" + ) from result + + self._key_by_managed_socket[result] = key + self._managed_socket_by_key[key] = result + return result # global helpers -_global_connection_manager = {} +def connection_manager_close_all( + socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False +) -> None: + """Close all open sockets for pool""" + if socket_pool: + socket_pools = [socket_pool] + else: + socket_pools = _global_connection_managers.keys() + + for pool in socket_pools: + connection_manager = _global_connection_managers.get(pool, None) + if connection_manager is None: + raise RuntimeError("SocketPool not managed") + + connection_manager._free_sockets(force=True) # pylint: disable=protected-access + + if release_references: + radio_key = None + for radio_check, pool_check in _global_socketpools.items(): + if pool == pool_check: + radio_key = radio_check + break + + if radio_key: + if radio_key in _global_socketpools: + del _global_socketpools[radio_key] + + if radio_key in _global_ssl_contexts: + del _global_ssl_contexts[radio_key] + + if pool in _global_connection_managers: + del _global_connection_managers[pool] def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """Get the ConnectionManager singleton for the given pool""" - if socket_pool not in _global_connection_manager: - _global_connection_manager[socket_pool] = ConnectionManager(socket_pool) - return _global_connection_manager[socket_pool] + if socket_pool not in _global_connection_managers: + _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) + return _global_connection_managers[socket_pool] diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index 36f4af6..e9fb842 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -24,14 +24,38 @@ # get request session requests = adafruit_requests.Session(pool, ssl_context) +connection_manager = adafruit_connection_manager.get_connection_manager(pool) +print("-" * 40) +print("Nothing yet opened") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") # make request print("-" * 40) -print(f"Fetching from {TEXT_URL}") +print(f"Fetching from {TEXT_URL} in a context handler") +with requests.get(TEXT_URL) as response: + response_text = response.text + print(f"Text Response {response_text}") + +print("-" * 40) +print("1 request, opened and freed") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") +print("-" * 40) +print(f"Fetching from {TEXT_URL} not in a context handler") response = requests.get(TEXT_URL) -response_text = response.text -response.close() -print(f"Text Response {response_text}") print("-" * 40) +print("1 request, opened but not freed") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") + +print("-" * 40) +print("Closing everything in the pool") +adafruit_connection_manager.connection_manager_close_all(pool) + +print("-" * 40) +print("Everything closed") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") diff --git a/tests/close_socket_test.py b/tests/close_socket_test.py index 957cb94..3927181 100644 --- a/tests/close_socket_test.py +++ b/tests/close_socket_test.py @@ -21,13 +21,13 @@ def test_close_socket(): socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert key in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key # validate socket is no longer tracked connection_manager.close_socket(socket) - assert socket not in connection_manager._available_socket - assert key not in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key not in connection_manager._managed_socket_by_key def test_close_socket_not_managed(): diff --git a/tests/conftest.py b/tests/conftest.py index 08d3914..ef6c96d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module(): @pytest.fixture(autouse=True) def reset_connection_manager(monkeypatch): monkeypatch.setattr( - "adafruit_connection_manager._global_socketpool", + "adafruit_connection_manager._global_connection_managers", + {}, + ) + monkeypatch.setattr( + "adafruit_connection_manager._global_socketpools", {}, ) monkeypatch.setattr( diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py new file mode 100644 index 0000000..c0fa498 --- /dev/null +++ b/tests/connection_manager_close_all_test.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get Connection Manager Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_connection_manager_close_all_all(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 + + adafruit_connection_manager.connection_manager_close_all() + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_called_once() + + +def test_connection_manager_close_all_single(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 + + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_not_called() + + +def test_connection_manager_close_all_untracked(): + mock_pool_1 = mocket.MocketPool() + with pytest.raises(RuntimeError) as context: + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert "SocketPool not managed" in str(context) + + +def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=False + ) + + assert socket_pool_wifi in adafruit_connection_manager._global_socketpools.values() + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ssl_context_wifi in adafruit_connection_manager._global_ssl_contexts.values() + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) + + +def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=True + ) + + assert ( + socket_pool_wifi not in adafruit_connection_manager._global_socketpools.values() + ) + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ( + ssl_context_wifi + not in adafruit_connection_manager._global_ssl_contexts.values() + ) + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + not in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py index 93f34eb..666a072 100644 --- a/tests/free_socket_test.py +++ b/tests/free_socket_test.py @@ -16,20 +16,24 @@ def test_free_socket(): mock_pool.socket.return_value = mock_socket_1 connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is False - assert key in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket) - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is True - assert key in connection_manager._open_sockets + assert socket in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 1 def test_free_socket_not_managed(): @@ -54,46 +58,31 @@ def test_free_sockets(): ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") assert socket_1 == mock_socket_1 - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is False + assert socket_1 not in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") assert socket_2 == mock_socket_2 + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket_1) - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is True + assert socket_1 in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 1 # validate socket is no longer tracked connection_manager._free_sockets() - assert socket_1 not in connection_manager._available_socket - assert socket_2 in connection_manager._available_socket + assert socket_1 not in connection_manager._available_sockets + assert socket_2 not in connection_manager._available_sockets mock_socket_1.close.assert_called_once() - - -def test_get_key_for_socket(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - mock_pool.socket.return_value = mock_socket_1 - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate tracked socket has correct key - socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - key = (mocket.MOCK_HOST_1, 80, "http:", None) - assert connection_manager._get_key_for_socket(socket) == key - - -def test_get_key_for_socket_not_managed(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate untracked socket has no key - assert connection_manager._get_key_for_socket(mock_socket_1) is None + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 426f785..5c43ad1 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -19,6 +19,7 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument radio = mocket.MockRadio.Radio() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert isinstance(socket_pool, mocket.MocketPool) + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument @@ -27,6 +28,7 @@ def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument @@ -36,6 +38,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument with mock.patch("sys.implementation", return_value=[9, 0, 0]): socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_unsupported(): @@ -52,22 +55,25 @@ def test_get_radio_socketpool_returns_same_one( # pylint: disable=unused-argume socket_pool_1 = adafruit_connection_manager.get_radio_socketpool(radio) socket_pool_2 = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool_1 == socket_pool_2 + assert socket_pool_1 in adafruit_connection_manager._global_socketpools.values() def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, ssl.SSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument adafruit_esp32spi_socket_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument @@ -75,8 +81,9 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_unsupported(): @@ -90,6 +97,7 @@ def test_get_radio_ssl_context_returns_same_one( # pylint: disable=unused-argum circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts_1 = adafruit_connection_manager.get_radio_ssl_context(radio) - ssl_contexts_2 = adafruit_connection_manager.get_radio_ssl_context(radio) - assert ssl_contexts_1 == ssl_contexts_2 + ssl_context_1 = adafruit_connection_manager.get_radio_ssl_context(radio) + ssl_context_2 = adafruit_connection_manager.get_radio_ssl_context(radio) + assert ssl_context_1 == ssl_context_2 + assert ssl_context_1 in adafruit_connection_manager._global_ssl_contexts.values() diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index ea252cc..6be48f0 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -213,7 +213,7 @@ def test_get_socket_runtime_error_ties_again_only_once(): # try to get a socket that returns a RuntimeError twice with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2" in str(context) + assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -242,7 +242,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.return_value = mock_socket_1 - mock_socket_1.connect.side_effect = RuntimeError("RuntimeError ") + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError") radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) @@ -252,4 +252,4 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: 12" in str(context) + assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)