diff --git a/CHANGELOG.md b/CHANGELOG.md index 593d5f9..de32e28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## 0.7.0 (2020-XX-XX) +- Fixed issue when threads are blocked on close while reading, #11 by @richard78917 + ## 0.6.0 (2020-01-08) - Added Python 3.8.* support, #10 diff --git a/pynats/client.py b/pynats/client.py index d35d023..19951b3 100644 --- a/pynats/client.py +++ b/pynats/client.py @@ -18,7 +18,7 @@ import pkg_resources -from pynats.exceptions import NATSInvalidResponse, NATSUnexpectedResponse +from pynats.exceptions import NATSInvalidResponse, NATSUnexpectedResponse, NATSSocketError from pynats.nuid import NUID __all__ = ("NATSSubscription", "NATSMessage", "NATSClient") @@ -154,6 +154,7 @@ def connect(self) -> None: self._recv(INFO_RE) def close(self) -> None: + self._socket.shutdown(socket.SHUT_RDWR) self._socket_file.close() self._socket.close() @@ -279,7 +280,11 @@ def _readline(self, *, size: int = None) -> bytes: read = io.BytesIO() while True: - line = cast(bytes, self._socket_file.readline()) + raw_bytes = self._socket_file.readline() + if not raw_bytes: + raise NATSSocketError("unable to read from socket") + + line = cast(bytes, raw_bytes) read.write(line) if size is not None: diff --git a/pynats/exceptions.py b/pynats/exceptions.py index e053035..9eb1080 100644 --- a/pynats/exceptions.py +++ b/pynats/exceptions.py @@ -15,3 +15,8 @@ class NATSInvalidResponse(NATSError): def __init__(self, line: bytes, *args, **kwargs) -> None: self.line = line super().__init__() + +class NATSSocketError(NATSError): + def __init__(self, line: bytes, *args, **kwargs) -> None: + self.line = line + super().__init__() diff --git a/tests/test_client.py b/tests/test_client.py index d4fadbe..d796533 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ import pytest from pynats import NATSClient +from pynats.exceptions import NATSSocketError @pytest.fixture @@ -179,3 +180,23 @@ def test_request_timeout(nats_url): with NATSClient(nats_url, socket_timeout=2) as client: with pytest.raises(socket.timeout): client.request("test-subject") + +def test_graceful_shutdown(nats_url): + def worker(client, connected_event): + client.connect() + connected_event.set() + try: + client.wait() + except NATSSocketError: + assert True + except Exception: + assert False, "unexpected Exception raised" + + client = NATSClient(nats_url) + connected_event = threading.Event() + thread = threading.Thread(target=worker, args=[client, connected_event]) + thread.start() + assert connected_event.wait(5), "unable to connect" + client.close() + thread.join(5) + assert not thread.is_alive(), "thread did not finish"