Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new SocketPool for ESP32SPI and WIZNET5K #11

Merged
merged 10 commits into from
Apr 30, 2024
40 changes: 28 additions & 12 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self.recv = socket.recv
self.close = socket.close
self.recv_into = socket.recv_into
# For sockets that come from software socketpools (like the esp32api), they track
# the interface and socket pool. We need to make sure the clones do as well
self._interface = getattr(socket, "_interface", None)
self._socket_pool = getattr(socket, "_socket_pool", None)

def connect(self, address: Tuple[str, int]) -> None:
"""Connect wrapper to add non-standard mode parameter"""
Expand Down Expand Up @@ -94,7 +98,10 @@ def create_fake_ssl_context(
* `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor
<https://www.adafruit.com/product/4264>`_
"""
socket_pool.set_interface(iface)
if hasattr(socket_pool, "set_interface"):
# this is to manually support legacy hardware like the fona
socket_pool.set_interface(iface)

return _FakeSSLContext(iface)


Expand All @@ -104,6 +111,13 @@ def create_fake_ssl_context(
_global_ssl_contexts = {}


def _get_radio_hash_key(radio):
try:
return hash(radio)
except TypeError:
return radio.__class__.__name__
dhalbert marked this conversation as resolved.
Show resolved Hide resolved


def get_radio_socketpool(radio):
"""Helper to get a socket pool for common boards.

Expand All @@ -113,8 +127,9 @@ def get_radio_socketpool(radio):
* Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift)
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
if class_name not in _global_socketpools:
key = _get_radio_hash_key(radio)
if key not in _global_socketpools:
class_name = radio.__class__.__name__
if class_name == "Radio":
import ssl # pylint: disable=import-outside-toplevel

Expand All @@ -124,12 +139,15 @@ def get_radio_socketpool(radio):
ssl_context = ssl.create_default_context()

elif class_name == "ESP_SPIcontrol":
import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel
import adafruit_esp32spi.adafruit_esp32spi_socketpool as socketpool # pylint: disable=import-outside-toplevel

pool = socketpool.SocketPool(radio)
ssl_context = create_fake_ssl_context(pool, radio)

elif class_name == "WIZNET5K":
import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel
import adafruit_wiznet5k.adafruit_wiznet5k_socketpool as socketpool # pylint: disable=import-outside-toplevel

pool = socketpool.SocketPool(radio)

# Note: At this time, SSL/TLS connections are not supported by older
# versions of the Wiznet5k library or on boards withouut the ssl module
Expand All @@ -141,7 +159,6 @@ def get_radio_socketpool(radio):
import ssl # pylint: disable=import-outside-toplevel

ssl_context = ssl.create_default_context()
pool.set_interface(radio)
except ImportError:
# if SSL not on board, default to fake_ssl_context
pass
Expand All @@ -152,11 +169,11 @@ def get_radio_socketpool(radio):
else:
raise AttributeError(f"Unsupported radio class: {class_name}")

_global_key_by_socketpool[pool] = class_name
_global_socketpools[class_name] = pool
_global_ssl_contexts[class_name] = ssl_context
_global_key_by_socketpool[pool] = key
_global_socketpools[key] = pool
_global_ssl_contexts[key] = ssl_context

return _global_socketpools[class_name]
return _global_socketpools[key]


def get_radio_ssl_context(radio):
Expand All @@ -168,9 +185,8 @@ def get_radio_ssl_context(radio):
* Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift)
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
get_radio_socketpool(radio)
return _global_ssl_contexts[class_name]
return _global_ssl_contexts[_get_radio_hash_key(radio)]


# main class
Expand Down
66 changes: 46 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,31 @@
import pytest


# pylint: disable=unused-argument
def set_interface(iface):
"""Helper to set the global internet interface"""
class SocketPool:
name = None

def __init__(self, *args, **kwargs):
pass

@property
def __name__(self):
return self.name


class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods
name = "adafruit_esp32spi_socketpool"


class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods
name = "adafruit_wiznet5k_socketpool"
SOCK_STREAM = 0x21


class WIZNET5K_With_SSL_SocketPool(
SocketPool
): # pylint: disable=too-few-public-methods
name = "adafruit_wiznet5k_socketpool"
SOCK_STREAM = 0x1


@pytest.fixture
Expand All @@ -25,41 +47,45 @@ def circuitpython_socketpool_module():


@pytest.fixture
def adafruit_esp32spi_socket_module():
def adafruit_esp32spi_socketpool_module():
esp32spi_module = type(sys)("adafruit_esp32spi")
esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket")
esp32spi_socket_module.set_interface = set_interface
esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool")
esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool
sys.modules["adafruit_esp32spi"] = esp32spi_module
sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module
sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = (
esp32spi_socket_module
)
yield
del sys.modules["adafruit_esp32spi"]
del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"]
del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"]


@pytest.fixture
def adafruit_wiznet5k_socket_module():
def adafruit_wiznet5k_socketpool_module():
wiznet5k_module = type(sys)("adafruit_wiznet5k")
wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket")
wiznet5k_socket_module.set_interface = set_interface
wiznet5k_socket_module.SOCK_STREAM = 0x21
wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool")
wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool
sys.modules["adafruit_wiznet5k"] = wiznet5k_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = (
wiznet5k_socketpool_module
)
yield
del sys.modules["adafruit_wiznet5k"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"]


@pytest.fixture
def adafruit_wiznet5k_with_ssl_socket_module():
def adafruit_wiznet5k_with_ssl_socketpool_module():
wiznet5k_module = type(sys)("adafruit_wiznet5k")
wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket")
wiznet5k_socket_module.set_interface = set_interface
wiznet5k_socket_module.SOCK_STREAM = 1
wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool")
wiznet5k_socketpool_module.SocketPool = WIZNET5K_With_SSL_SocketPool
sys.modules["adafruit_wiznet5k"] = wiznet5k_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = (
wiznet5k_socketpool_module
)
yield
del sys.modules["adafruit_wiznet5k"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"]


@pytest.fixture(autouse=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/connection_manager_close_all_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_connection_manager_close_all_untracked():


def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_connection_manager_close_all_single_release_references_false( # pylint


def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down
2 changes: 1 addition & 1 deletion tests/get_connection_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_get_connection_manager():


def test_different_connection_manager_different_pool( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down
24 changes: 18 additions & 6 deletions tests/get_radio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
import adafruit_connection_manager


def test__get_radio_hash_key():
radio = mocket.MockRadio.Radio()
assert adafruit_connection_manager._get_radio_hash_key(radio) == hash(radio)


def test__get_radio_hash_key_not_hashable():
radio = mocket.MockRadio.Radio()

with mock.patch("builtins.hash", side_effect=TypeError()):
assert adafruit_connection_manager._get_radio_hash_key(radio) == "Radio"


def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument
circuitpython_socketpool_module,
):
Expand All @@ -23,21 +35,21 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument


def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
radio = mocket.MockRadio.ESP_SPIcontrol()
socket_pool = adafruit_connection_manager.get_radio_socketpool(radio)
assert socket_pool.__name__ == "adafruit_esp32spi_socket"
assert socket_pool.__name__ == "adafruit_esp32spi_socketpool"
assert socket_pool in adafruit_connection_manager._global_socketpools.values()


def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", return_value=[9, 0, 0]):
socket_pool = adafruit_connection_manager.get_radio_socketpool(radio)
assert socket_pool.__name__ == "adafruit_wiznet5k_socket"
assert socket_pool.__name__ == "adafruit_wiznet5k_socketpool"
assert socket_pool in adafruit_connection_manager._global_socketpools.values()


Expand Down Expand Up @@ -68,7 +80,7 @@ def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument


def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
radio = mocket.MockRadio.ESP_SPIcontrol()
ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio)
Expand All @@ -77,7 +89,7 @@ def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument


def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", return_value=[9, 0, 0]):
Expand Down
4 changes: 2 additions & 2 deletions tests/get_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_get_socket_runtime_error_ties_again_only_once():


def test_fake_ssl_context_connect( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand All @@ -237,7 +237,7 @@ def test_fake_ssl_context_connect( # pylint: disable=unused-argument


def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand Down
6 changes: 3 additions & 3 deletions tests/ssl_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def test_connect_esp32spi_https( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_connect_wifi_https( # pylint: disable=unused-argument


def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
mock_pool = mocket.MocketPool()
radio = mocket.MockRadio.WIZNET5K()
Expand All @@ -66,7 +66,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen


def test_connect_wiznet5k_https_supported( # pylint: disable=unused-argument
adafruit_wiznet5k_with_ssl_socket_module,
adafruit_wiznet5k_with_ssl_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", (None, WIZNET5K_SSL_SUPPORT_VERSION)):
Expand Down
Loading