diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f92ca68b..69051d28 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -68,6 +68,13 @@ Improvements Previously, :exc:`RuntimeError` was raised. For backwards compatibility, :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. +Bug fixes +......... + +* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't + start the connection handler anymore when ``process_request`` or + ``process_response`` returns a HTTP response. + 13.0.1 ------ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 3985bfb6..b1beb3e0 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -98,9 +98,7 @@ async def handshake( # before receiving a response, when the response cannot be parsed, or # when the response fails the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection: raise uri_or_exc from exc else: + self.connection.start_keepalive() return self.connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 228b2001..78ee760d 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -201,12 +201,11 @@ async def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return + assert connection.protocol.state is OPEN try: + connection.start_keepalive() await self.handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index ac62800d..b2671f40 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - self.logger.info("connection open") return Response(101, "Switching Protocols", headers) def process_request( @@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response: ("Content-Type", "text/plain; charset=utf-8"), ] ) - response = Response(status.value, status.phrase, headers, body) - # When reject() is called from accept(), handshake_exc is already set. - # If a user calls reject(), set handshake_exc to guarantee invariant: - # "handshake_exc is None if and only if opening handshake succeeded." - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) - self.logger.info("connection rejected (%d %s)", status.value, status.phrase) - return response + return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: """ @@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: assert self.state is CONNECTING self.state = OPEN + self.logger.info("connection open") + else: + # handshake_exc may be already set if accept() encountered an error. + # If the connection isn't open, set handshake_exc to guarantee that + # handshake_exc is None if and only if opening handshake succeeded. + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) + self.logger.info( + "connection rejected (%d %s)", + response.status_code, + response.reason_phrase, + ) + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index eb053601..0b19201a 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -23,7 +23,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection @@ -166,8 +166,9 @@ def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -569,6 +570,7 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return + assert connection.protocol.state is OPEN try: handler(connection) except Exception: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fdcbf978..47e0148a 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -145,7 +145,10 @@ async def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +163,10 @@ async def test_async_process_request_returns_response(self): async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d0d2c095..3bc6f76c 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -133,7 +133,10 @@ def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - with run_server(process_request=process_request) as server: + def handler(ws): + self.fail("handler must not run") + + with run_server(handler, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/test_server.py b/tests/test_server.py index d34c8e83..52c8a2b9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -106,10 +106,11 @@ def make_request(self): ), ) - def test_send_accept(self): + def test_send_response_after_successful_accept(self): server = ServerProtocol() + request = self.make_request() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + response = server.accept(request) self.assertIsInstance(response, Response) server.send_response(response) self.assertEqual( @@ -126,7 +127,32 @@ def test_send_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_reject(self): + def test_send_response_after_failed_accept(self): + server = ServerProtocol() + request = self.make_request() + del request.headers["Sec-WebSocket-Key"] + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + response = server.accept(request) + self.assertIsInstance(response, Response) + server.send_response(response) + self.assertEqual( + server.data_to_send(), + [ + f"HTTP/1.1 400 Bad Request\r\n" + f"Date: {DATE}\r\n" + f"Connection: close\r\n" + f"Content-Length: 94\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"\r\n" + f"Failed to open a WebSocket connection: " + f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + + def test_send_response_after_reject(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") @@ -148,6 +174,19 @@ def test_send_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) + def test_send_response_without_accept_or_reject(self): + server = ServerProtocol() + server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + self.assertEqual( + server.data_to_send(), + [ + "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + def test_accept_response(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE):