Skip to content

Commit

Permalink
Improve isolation of tests of sync implementation.
Browse files Browse the repository at this point in the history
Before this change, threads handling requests could continue running
after the end of the test. This caused spurious failures.

Specifically, a test expecting an error log could get an error log from
a previous tests. This happened sporadically on PyPy.
  • Loading branch information
aaugustin committed Sep 9, 2024
1 parent 14d9d40 commit f9cea9c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 51 deletions.
18 changes: 18 additions & 0 deletions tests/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,30 @@ def run_server(handler=handler, host="localhost", port=0, **kwargs):
with serve(handler, host, port, **kwargs) as server:
thread = threading.Thread(target=server.serve_forever)
thread.start()

# HACK: since the sync server doesn't track connections (yet), we record
# a reference to the thread handling the most recent connection, then we
# can wait for that thread to terminate when exiting the context.
handler_thread = None
original_handler = server.handler

def handler(sock, addr):
nonlocal handler_thread
handler_thread = threading.current_thread()
original_handler(sock, addr)

server.handler = handler

try:
yield server
finally:
server.shutdown()
thread.join()

# HACK: wait for the thread handling the most recent connection.
if handler_thread is not None:
handler_thread.join()


@contextlib.contextmanager
def run_unix_server(path, handler=handler, **kwargs):
Expand Down
54 changes: 3 additions & 51 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import http
import logging
import socket
import threading
import time
import unittest

from websockets.exceptions import (
Expand Down Expand Up @@ -289,50 +289,19 @@ def test_timeout_during_handshake(self):
def test_connection_closed_during_handshake(self):
"""Server reads EOF before receiving handshake request from client."""
with run_server() as server:
# Patch handler to record a reference to the thread running it.
server_thread = None
conn_received = threading.Event()
original_handler = server.handler

def handler(sock, addr):
nonlocal server_thread
server_thread = threading.current_thread()
nonlocal conn_received
conn_received.set()
original_handler(sock, addr)

server.handler = handler

with socket.create_connection(server.socket.getsockname()):
# Wait for the server to receive the connection, then close it.
conn_received.wait()

# Wait for the server thread to terminate.
server_thread.join()
time.sleep(MS)

def test_junk_handshake(self):
"""Server closes the connection when receiving non-HTTP request from client."""
with self.assertLogs("websockets.server", logging.ERROR) as logs:
with run_server() as server:
# Patch handler to record a reference to the thread running it.
server_thread = None
original_handler = server.handler

def handler(sock, addr):
nonlocal server_thread
server_thread = threading.current_thread()
original_handler(sock, addr)

server.handler = handler

with socket.create_connection(server.socket.getsockname()) as sock:
sock.send(b"HELO relay.invalid\r\n")
# Wait for the server to close the connection.
self.assertEqual(sock.recv(4096), b"")

# Wait for the server thread to terminate.
server_thread.join()

self.assertEqual(
[record.getMessage() for record in logs.records],
["opening handshake failed"],
Expand Down Expand Up @@ -360,26 +329,9 @@ def test_timeout_during_tls_handshake(self):
def test_connection_closed_during_tls_handshake(self):
"""Server reads EOF before receiving TLS handshake request from client."""
with run_server(ssl=SERVER_CONTEXT) as server:
# Patch handler to record a reference to the thread running it.
server_thread = None
conn_received = threading.Event()
original_handler = server.handler

def handler(sock, addr):
nonlocal server_thread
server_thread = threading.current_thread()
nonlocal conn_received
conn_received.set()
original_handler(sock, addr)

server.handler = handler

with socket.create_connection(server.socket.getsockname()):
# Wait for the server to receive the connection, then close it.
conn_received.wait()

# Wait for the server thread to terminate.
server_thread.join()
time.sleep(MS)


@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
Expand Down

0 comments on commit f9cea9c

Please sign in to comment.