Skip to content

Commit

Permalink
Modernize test ASGI applications
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Nov 18, 2020
1 parent 634aec9 commit bcf26b1
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 123 deletions.
28 changes: 8 additions & 20 deletions tests/middleware/test_trace_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,23 @@
}


async def app(scope, receive, send):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})


@pytest.mark.skipif(
sys.platform.startswith("win") or platform.python_implementation() == "PyPy",
reason="Skipping test on Windows and PyPy",
)
def test_trace_logging(capsys):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
log_config=test_logging_config,
Expand All @@ -92,21 +89,12 @@ def install_signal_handlers(self):
)
@pytest.mark.parametrize("http_protocol", [("h11"), ("httptools")])
def test_access_logging(capsys, http_protocol):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(
app=App,
app=app,
loop="asyncio",
http=http_protocol,
limit_max_requests=1,
Expand Down
32 changes: 10 additions & 22 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,13 @@ async def get_data(url):

@pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS)
def test_missing_handshake(protocol_cls):
class App:
def __init__(self, scope):
pass

async def __call__(self, receive, send):
pass
async def app(app, receive, send):
pass

async def connect(url):
await websockets.connect(url)

with run_server(App, protocol_cls=protocol_cls) as url:
with run_server(app, protocol_cls=protocol_cls) as url:
loop = asyncio.new_event_loop()
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
loop.run_until_complete(connect(url))
Expand All @@ -336,17 +332,13 @@ async def connect(url):

@pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS)
def test_send_before_handshake(protocol_cls):
class App:
def __init__(self, scope):
pass

async def __call__(self, receive, send):
await send({"type": "websocket.send", "text": "123"})
async def app(scope, receive, send):
await send({"type": "websocket.send", "text": "123"})

async def connect(url):
await websockets.connect(url)

with run_server(App, protocol_cls=protocol_cls) as url:
with run_server(app, protocol_cls=protocol_cls) as url:
loop = asyncio.new_event_loop()
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
loop.run_until_complete(connect(url))
Expand All @@ -356,19 +348,15 @@ async def connect(url):

@pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS)
def test_duplicate_handshake(protocol_cls):
class App:
def __init__(self, scope):
pass

async def __call__(self, receive, send):
await send({"type": "websocket.accept"})
await send({"type": "websocket.accept"})
async def app(scope, receive, send):
await send({"type": "websocket.accept"})
await send({"type": "websocket.accept"})

async def connect(url):
async with websockets.connect(url) as websocket:
_ = await websocket.recv()

with run_server(App, protocol_cls=protocol_cls) as url:
with run_server(app, protocol_cls=protocol_cls) as url:
loop = asyncio.new_event_loop()
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
loop.run_until_complete(connect(url))
Expand Down
20 changes: 8 additions & 12 deletions tests/test_default_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@
from uvicorn import Config, Server


class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 200, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def app(scope, receive, send):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 200, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})


class CustomServer(Server):
Expand All @@ -22,7 +18,7 @@ def install_signal_handlers(self):


def test_default_default_headers():
config = Config(app=App, loop="asyncio", limit_max_requests=1)
config = Config(app=app, loop="asyncio", limit_max_requests=1)
server = CustomServer(config=config)
thread = threading.Thread(target=server.run)
thread.start()
Expand All @@ -37,7 +33,7 @@ def test_default_default_headers():

def test_override_server_header():
config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("Server", "over-ridden")],
Expand All @@ -56,7 +52,7 @@ def test_override_server_header():

def test_override_server_header_multiple_times():
config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("Server", "over-ridden"), ("Server", "another-value")],
Expand All @@ -78,7 +74,7 @@ def test_override_server_header_multiple_times():

def test_add_additional_header():
config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("X-Additional", "new-value")],
Expand Down
53 changes: 14 additions & 39 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from uvicorn.main import Server


async def app(scope, receive, send):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})


@pytest.mark.parametrize(
"host, url",
[
Expand All @@ -18,20 +24,11 @@
],
)
def test_run(host, url):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(app=App, host=host, loop="asyncio", limit_max_requests=1)
config = Config(app=app, host=host, loop="asyncio", limit_max_requests=1)
server = CustomServer(config=config)
thread = threading.Thread(target=server.run)
thread.start()
Expand All @@ -43,20 +40,11 @@ def install_signal_handlers(self):


def test_run_multiprocess():
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(app=App, loop="asyncio", workers=2, limit_max_requests=1)
config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1)
server = CustomServer(config=config)
thread = threading.Thread(target=server.run)
thread.start()
Expand All @@ -68,20 +56,11 @@ def install_signal_handlers(self):


def test_run_reload():
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(app=App, loop="asyncio", reload=True, limit_max_requests=1)
config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1)
server = CustomServer(config=config)
thread = threading.Thread(target=server.run)
thread.start()
Expand All @@ -93,20 +72,16 @@ def install_signal_handlers(self):


def test_run_with_shutdown():
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
while True:
time.sleep(1)
async def app(scope, receive, send):
assert scope["type"] == "http"
while True:
time.sleep(1)

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(app=App, loop="asyncio", workers=2, limit_max_requests=1)
config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1)
server = CustomServer(config=config)
sock = config.bind_socket()
exc = True
Expand Down
39 changes: 9 additions & 30 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,22 @@ def no_ssl_verification(session=requests.Session):
session.request = old_request


async def app(scope, receive, send):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})


@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
)
def test_run(tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_path,
Expand All @@ -64,21 +61,12 @@ def install_signal_handlers(self):
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
)
def test_run_chain(tls_certificate_pem_path):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_certfile=tls_certificate_pem_path,
Expand All @@ -100,21 +88,12 @@ def install_signal_handlers(self):
def test_run_password(
tls_ca_certificate_pem_path, tls_ca_certificate_private_key_encrypted_path
):
class App:
def __init__(self, scope):
if scope["type"] != "http":
raise Exception()

async def __call__(self, receive, send):
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})

class CustomServer(Server):
def install_signal_handlers(self):
pass

config = Config(
app=App,
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_encrypted_path,
Expand Down

0 comments on commit bcf26b1

Please sign in to comment.