Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close all and counts #13

Merged
merged 5 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def create_fake_ssl_context(
return _FakeSSLContext(iface)


_global_socketpool = {}
_global_connection_managers = {}
dhalbert marked this conversation as resolved.
Show resolved Hide resolved
_global_socketpools = {}
_global_ssl_contexts = {}


Expand All @@ -113,7 +114,7 @@ def get_radio_socketpool(radio):
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
if class_name not in _global_socketpool:
if class_name not in _global_socketpools:
if class_name == "Radio":
import ssl # pylint: disable=import-outside-toplevel

Expand Down Expand Up @@ -151,10 +152,10 @@ def get_radio_socketpool(radio):
else:
raise AttributeError(f"Unsupported radio class: {class_name}")

_global_socketpool[class_name] = pool
_global_socketpools[class_name] = pool
_global_ssl_contexts[class_name] = ssl_context

return _global_socketpool[class_name]
return _global_socketpools[class_name]


def get_radio_ssl_context(radio):
Expand Down Expand Up @@ -186,10 +187,10 @@ def __init__(
self._available_socket = {}
self._open_sockets = {}

def _free_sockets(self) -> None:
def _free_sockets(self, force: bool = False) -> None:
available_sockets = []
for socket, free in self._available_socket.items():
if free:
if free or force:
available_sockets.append(socket)

for socket in available_sockets:
Expand All @@ -203,6 +204,18 @@ def _get_key_for_socket(self, socket):
except StopIteration:
return None

@property
def open_sockets(self) -> int:
"""Get the count of open sockets"""
return len(self._open_sockets)
dhalbert marked this conversation as resolved.
Show resolved Hide resolved

@property
def freeable_open_sockets(self) -> int:
"""Get the count of freeable open sockets"""
return len(
[socket for socket, free in self._available_socket.items() if free is True]
)

dhalbert marked this conversation as resolved.
Show resolved Hide resolved
def close_socket(self, socket: SocketType) -> None:
"""Close a previously opened socket."""
if socket not in self._open_sockets.values():
Expand Down Expand Up @@ -306,11 +319,25 @@ def get_socket(
# global helpers


_global_connection_manager = {}
def connection_manager_close_all(
socket_pool: Optional[SocketpoolModuleType] = None,
) -> None:
"""Close all open sockets for pool"""
if socket_pool:
keys = [socket_pool]
else:
keys = _global_connection_managers.keys()

for key in keys:
connection_manager = _global_connection_managers.get(key, None)
if connection_manager is None:
raise RuntimeError("SocketPool not managed")

connection_manager._free_sockets(force=True) # pylint: disable=protected-access


def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""Get the ConnectionManager singleton for the given pool"""
if socket_pool not in _global_connection_manager:
_global_connection_manager[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_manager[socket_pool]
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
return _global_connection_managers[socket_pool]
32 changes: 28 additions & 4 deletions examples/connectionmanager_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,38 @@

# get request session
requests = adafruit_requests.Session(pool, ssl_context)
connection_manager = adafruit_connection_manager.get_connection_manager(pool)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output of this is so:

----------------------------------------
Nothing yet opened
Open Sockets: 0
Freeable Open Sockets: 0
----------------------------------------
Fetching from http://wifitest.adafruit.com/testwifi/index.html in a context handler
Text Response This is a test of Adafruit WiFi!
If you can read this, its working :)
----------------------------------------
1 request, opened and freed
Open Sockets: 1
Freeable Open Sockets: 1
----------------------------------------
Fetching from http://wifitest.adafruit.com/testwifi/index.html not in a context handler
----------------------------------------
1 request, opened but not freed
Open Sockets: 1
Freeable Open Sockets: 0
----------------------------------------
Closing everything in the pool
----------------------------------------
Everything closed
Open Sockets: 0
Freeable Open Sockets: 0

print("-" * 40)
print("Nothing yet opened")
print(f"Open Sockets: {connection_manager.open_sockets}")
print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}")

# make request
print("-" * 40)
print(f"Fetching from {TEXT_URL}")
print(f"Fetching from {TEXT_URL} in a context handler")
with requests.get(TEXT_URL) as response:
response_text = response.text
print(f"Text Response {response_text}")

print("-" * 40)
print("1 request, opened and freed")
print(f"Open Sockets: {connection_manager.open_sockets}")
print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}")

print("-" * 40)
print(f"Fetching from {TEXT_URL} not in a context handler")
response = requests.get(TEXT_URL)
response_text = response.text
response.close()

print(f"Text Response {response_text}")
print("-" * 40)
print("1 request, opened but not freed")
print(f"Open Sockets: {connection_manager.open_sockets}")
print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}")

print("-" * 40)
print("Closing everything in the pool")
adafruit_connection_manager.connection_manager_close_all(pool)

print("-" * 40)
print("Everything closed")
print(f"Open Sockets: {connection_manager.open_sockets}")
print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}")
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module():
@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpool",
"adafruit_connection_manager._global_connection_managers",
{},
)
monkeypatch.setattr(
"adafruit_connection_manager._global_socketpools",
{},
)
monkeypatch.setattr(
Expand Down
87 changes: 87 additions & 0 deletions tests/connection_manager_close_all_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries
#
# SPDX-License-Identifier: Unlicense

""" Get Connection Manager Tests """

import mocket
import pytest

import adafruit_connection_manager


def test_connection_manager_close_all_all():
mock_pool_1 = mocket.MocketPool()
mock_pool_2 = mocket.MocketPool()
assert mock_pool_1 != mock_pool_2

connection_manager_1 = adafruit_connection_manager.get_connection_manager(
mock_pool_1
)
assert connection_manager_1.open_sockets == 0
assert connection_manager_1.freeable_open_sockets == 0
connection_manager_2 = adafruit_connection_manager.get_connection_manager(
mock_pool_2
)
assert connection_manager_2.open_sockets == 0
assert connection_manager_2.freeable_open_sockets == 0
assert len(adafruit_connection_manager._global_connection_managers) == 2

socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert connection_manager_1.open_sockets == 1
assert connection_manager_1.freeable_open_sockets == 0
assert connection_manager_2.open_sockets == 0
assert connection_manager_2.freeable_open_sockets == 0
socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert connection_manager_2.open_sockets == 1
assert connection_manager_2.freeable_open_sockets == 0

adafruit_connection_manager.connection_manager_close_all()
assert connection_manager_1.open_sockets == 0
assert connection_manager_1.freeable_open_sockets == 0
assert connection_manager_2.open_sockets == 0
assert connection_manager_2.freeable_open_sockets == 0
socket_1.close.assert_called_once()
socket_2.close.assert_called_once()


def test_connection_manager_close_all_single():
mock_pool_1 = mocket.MocketPool()
mock_pool_2 = mocket.MocketPool()
assert mock_pool_1 != mock_pool_2

connection_manager_1 = adafruit_connection_manager.get_connection_manager(
mock_pool_1
)
assert connection_manager_1.open_sockets == 0
assert connection_manager_1.freeable_open_sockets == 0
connection_manager_2 = adafruit_connection_manager.get_connection_manager(
mock_pool_2
)
assert connection_manager_2.open_sockets == 0
assert connection_manager_2.freeable_open_sockets == 0
assert len(adafruit_connection_manager._global_connection_managers) == 2

socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert connection_manager_1.open_sockets == 1
assert connection_manager_1.freeable_open_sockets == 0
assert connection_manager_2.open_sockets == 0
assert connection_manager_2.freeable_open_sockets == 0
socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert connection_manager_2.open_sockets == 1
assert connection_manager_2.freeable_open_sockets == 0

adafruit_connection_manager.connection_manager_close_all(mock_pool_1)
assert connection_manager_1.open_sockets == 0
assert connection_manager_1.freeable_open_sockets == 0
assert connection_manager_2.open_sockets == 1
assert connection_manager_2.freeable_open_sockets == 0
socket_1.close.assert_called_once()
socket_2.close.assert_not_called()


def test_connection_manager_close_all_untracked():
mock_pool_1 = mocket.MocketPool()
with pytest.raises(RuntimeError) as context:
adafruit_connection_manager.connection_manager_close_all(mock_pool_1)
assert "SocketPool not managed" in str(context)
16 changes: 16 additions & 0 deletions tests/free_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_free_socket():
mock_pool.socket.return_value = mock_socket_1

connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
assert connection_manager.open_sockets == 0
assert connection_manager.freeable_open_sockets == 0

# validate socket is tracked and not available
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
Expand All @@ -24,12 +26,16 @@ def test_free_socket():
assert socket in connection_manager._available_socket
assert connection_manager._available_socket[socket] is False
assert key in connection_manager._open_sockets
assert connection_manager.open_sockets == 1
assert connection_manager.freeable_open_sockets == 0

# validate socket is tracked and is available
connection_manager.free_socket(socket)
assert socket in connection_manager._available_socket
assert connection_manager._available_socket[socket] is True
assert key in connection_manager._open_sockets
assert connection_manager.open_sockets == 1
assert connection_manager.freeable_open_sockets == 1


def test_free_socket_not_managed():
Expand All @@ -54,26 +60,36 @@ def test_free_sockets():
]

connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
assert connection_manager.open_sockets == 0
assert connection_manager.freeable_open_sockets == 0

# validate socket is tracked and not available
socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert socket_1 == mock_socket_1
assert socket_1 in connection_manager._available_socket
assert connection_manager._available_socket[socket_1] is False
assert connection_manager.open_sockets == 1
assert connection_manager.freeable_open_sockets == 0

socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:")
assert socket_2 == mock_socket_2
assert connection_manager.open_sockets == 2
assert connection_manager.freeable_open_sockets == 0

# validate socket is tracked and is available
connection_manager.free_socket(socket_1)
assert socket_1 in connection_manager._available_socket
assert connection_manager._available_socket[socket_1] is True
assert connection_manager.open_sockets == 2
assert connection_manager.freeable_open_sockets == 1

# validate socket is no longer tracked
connection_manager._free_sockets()
assert socket_1 not in connection_manager._available_socket
assert socket_2 in connection_manager._available_socket
mock_socket_1.close.assert_called_once()
assert connection_manager.open_sockets == 1
assert connection_manager.freeable_open_sockets == 0


def test_get_key_for_socket():
Expand Down
Loading