diff --git a/tests/middleware/test_trace_logging.py b/tests/middleware/test_trace_logging.py index 71ef82398..762b71fd1 100644 --- a/tests/middleware/test_trace_logging.py +++ b/tests/middleware/test_trace_logging.py @@ -48,26 +48,23 @@ } +async def app(scope, receive, send): + assert scope["type"] == "http" + await send({"type": "http.response.start", "status": 204, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + @pytest.mark.skipif( sys.platform.startswith("win") or platform.python_implementation() == "PyPy", reason="Skipping test on Windows and PyPy", ) def test_trace_logging(capsys): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, log_config=test_logging_config, @@ -92,21 +89,12 @@ def install_signal_handlers(self): ) @pytest.mark.parametrize("http_protocol", [("h11"), ("httptools")]) def test_access_logging(capsys, http_protocol): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass config = Config( - app=App, + app=app, loop="asyncio", http=http_protocol, limit_max_requests=1, diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 98445de2a..71b81fcd2 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -316,17 +316,13 @@ async def get_data(url): @pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS) def test_missing_handshake(protocol_cls): - class App: - def __init__(self, scope): - pass - - async def __call__(self, receive, send): - pass + async def app(app, receive, send): + pass async def connect(url): await websockets.connect(url) - with run_server(App, protocol_cls=protocol_cls) as url: + with run_server(app, protocol_cls=protocol_cls) as url: loop = asyncio.new_event_loop() with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: loop.run_until_complete(connect(url)) @@ -336,17 +332,13 @@ async def connect(url): @pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS) def test_send_before_handshake(protocol_cls): - class App: - def __init__(self, scope): - pass - - async def __call__(self, receive, send): - await send({"type": "websocket.send", "text": "123"}) + async def app(scope, receive, send): + await send({"type": "websocket.send", "text": "123"}) async def connect(url): await websockets.connect(url) - with run_server(App, protocol_cls=protocol_cls) as url: + with run_server(app, protocol_cls=protocol_cls) as url: loop = asyncio.new_event_loop() with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: loop.run_until_complete(connect(url)) @@ -356,19 +348,15 @@ async def connect(url): @pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS) def test_duplicate_handshake(protocol_cls): - class App: - def __init__(self, scope): - pass - - async def __call__(self, receive, send): - await send({"type": "websocket.accept"}) - await send({"type": "websocket.accept"}) + async def app(scope, receive, send): + await send({"type": "websocket.accept"}) + await send({"type": "websocket.accept"}) async def connect(url): async with websockets.connect(url) as websocket: _ = await websocket.recv() - with run_server(App, protocol_cls=protocol_cls) as url: + with run_server(app, protocol_cls=protocol_cls) as url: loop = asyncio.new_event_loop() with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info: loop.run_until_complete(connect(url)) diff --git a/tests/test_default_headers.py b/tests/test_default_headers.py index b3c029efe..b9ed98188 100644 --- a/tests/test_default_headers.py +++ b/tests/test_default_headers.py @@ -6,14 +6,10 @@ from uvicorn import Config, Server -class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 200, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) +async def app(scope, receive, send): + assert scope["type"] == "http" + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) class CustomServer(Server): @@ -22,7 +18,7 @@ def install_signal_handlers(self): def test_default_default_headers(): - config = Config(app=App, loop="asyncio", limit_max_requests=1) + config = Config(app=app, loop="asyncio", limit_max_requests=1) server = CustomServer(config=config) thread = threading.Thread(target=server.run) thread.start() @@ -37,7 +33,7 @@ def test_default_default_headers(): def test_override_server_header(): config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, headers=[("Server", "over-ridden")], @@ -56,7 +52,7 @@ def test_override_server_header(): def test_override_server_header_multiple_times(): config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, headers=[("Server", "over-ridden"), ("Server", "another-value")], @@ -78,7 +74,7 @@ def test_override_server_header_multiple_times(): def test_add_additional_header(): config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, headers=[("X-Additional", "new-value")], diff --git a/tests/test_main.py b/tests/test_main.py index a04178f14..629a8de84 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -9,6 +9,12 @@ from uvicorn.main import Server +async def app(scope, receive, send): + assert scope["type"] == "http" + await send({"type": "http.response.start", "status": 204, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + @pytest.mark.parametrize( "host, url", [ @@ -18,20 +24,11 @@ ], ) def test_run(host, url): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass - config = Config(app=App, host=host, loop="asyncio", limit_max_requests=1) + config = Config(app=app, host=host, loop="asyncio", limit_max_requests=1) server = CustomServer(config=config) thread = threading.Thread(target=server.run) thread.start() @@ -43,20 +40,11 @@ def install_signal_handlers(self): def test_run_multiprocess(): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass - config = Config(app=App, loop="asyncio", workers=2, limit_max_requests=1) + config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1) server = CustomServer(config=config) thread = threading.Thread(target=server.run) thread.start() @@ -68,20 +56,11 @@ def install_signal_handlers(self): def test_run_reload(): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass - config = Config(app=App, loop="asyncio", reload=True, limit_max_requests=1) + config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1) server = CustomServer(config=config) thread = threading.Thread(target=server.run) thread.start() @@ -93,20 +72,16 @@ def install_signal_handlers(self): def test_run_with_shutdown(): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - while True: - time.sleep(1) + async def app(scope, receive, send): + assert scope["type"] == "http" + while True: + time.sleep(1) class CustomServer(Server): def install_signal_handlers(self): pass - config = Config(app=App, loop="asyncio", workers=2, limit_max_requests=1) + config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1) server = CustomServer(config=config) sock = config.bind_socket() exc = True diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d64d828ab..5bd5dc73d 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -25,25 +25,22 @@ def no_ssl_verification(session=requests.Session): session.request = old_request +async def app(scope, receive, send): + assert scope["type"] == "http" + await send({"type": "http.response.start", "status": 204, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + @pytest.mark.skipif( sys.platform.startswith("win"), reason="Skipping SSL test on Windows" ) def test_run(tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, ssl_keyfile=tls_ca_certificate_private_key_path, @@ -64,21 +61,12 @@ def install_signal_handlers(self): sys.platform.startswith("win"), reason="Skipping SSL test on Windows" ) def test_run_chain(tls_certificate_pem_path): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, ssl_certfile=tls_certificate_pem_path, @@ -100,21 +88,12 @@ def install_signal_handlers(self): def test_run_password( tls_ca_certificate_pem_path, tls_ca_certificate_private_key_encrypted_path ): - class App: - def __init__(self, scope): - if scope["type"] != "http": - raise Exception() - - async def __call__(self, receive, send): - await send({"type": "http.response.start", "status": 204, "headers": []}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - class CustomServer(Server): def install_signal_handlers(self): pass config = Config( - app=App, + app=app, loop="asyncio", limit_max_requests=1, ssl_keyfile=tls_ca_certificate_private_key_encrypted_path,