Skip to content

Commit

Permalink
Properly set host header to ascii string in ProxyFixMiddleware.
Browse files Browse the repository at this point in the history
  • Loading branch information
apollo13 authored and pgjones committed Jan 4, 2024
1 parent bc39603 commit 3fbd5f2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/hypercorn/middleware/proxy_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def __init__(
self.trusted_hops = trusted_hops

async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
if scope["type"] in {"http", "websocket"}:
# Keep the `or` instead of `in {'http' …}` to allow type narrowing
if scope["type"] == "http" or scope["type"] == "websocket":
scope = deepcopy(scope)
headers = scope["headers"] # type: ignore
headers = scope["headers"]
client: Optional[str] = None
scheme: Optional[str] = None
host: Optional[str] = None
Expand All @@ -44,19 +45,19 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non
host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops)

if client is not None:
scope["client"] = (client, 0) # type: ignore
scope["client"] = (client, 0)

if scheme is not None:
scope["scheme"] = scheme # type: ignore
scope["scheme"] = scheme

if host is not None:
headers = [
(name, header_value)
for name, header_value in headers
if name.lower() != b"host"
]
headers.append((b"host", host))
scope["headers"] = headers # type: ignore
headers.append((b"host", host.encode()))
scope["headers"] = headers

await self.app(scope, receive, send)

Expand Down
17 changes: 12 additions & 5 deletions tests/middleware/test_proxy_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@ async def test_proxy_fix_legacy() -> None:
(b"x-forwarded-for", b"127.0.0.1"),
(b"x-forwarded-for", b"127.0.0.2"),
(b"x-forwarded-proto", b"http,https"),
(b"x-forwarded-host", b"example.com"),
],
"client": ("127.0.0.3", 80),
"server": None,
"extensions": {},
}
await app(scope, None, None)
mock.assert_called()
assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0)
assert mock.call_args[0][0]["scheme"] == "https"
scope = mock.call_args[0][0]
assert scope["client"] == ("127.0.0.2", 0)
assert scope["scheme"] == "https"
host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"]
assert host_headers == [(b"host", b"example.com")]


@pytest.mark.asyncio
Expand All @@ -52,13 +56,16 @@ async def test_proxy_fix_modern() -> None:
"query_string": b"",
"root_path": "",
"headers": [
(b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https"),
(b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https;host=example.com"),
],
"client": ("127.0.0.3", 80),
"server": None,
"extensions": {},
}
await app(scope, None, None)
mock.assert_called()
assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0)
assert mock.call_args[0][0]["scheme"] == "https"
scope = mock.call_args[0][0]
assert scope["client"] == ("127.0.0.2", 0)
assert scope["scheme"] == "https"
host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"]
assert host_headers == [(b"host", b"example.com")]

0 comments on commit 3fbd5f2

Please sign in to comment.