Skip to content

Commit

Permalink
Proxy HTTP/2 support (#468)
Browse files Browse the repository at this point in the history
* Proxy HTTP/2 support

* Add tests for proxy HTTP/2 support

* Add tests for proxy HTTP/2 support

* Add tests for proxy HTTP/2 support

* Add tests for proxy HTTP/2 support
  • Loading branch information
tomchristie authored Jan 5, 2022
1 parent 2209b58 commit 5bcea8b
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 21 deletions.
64 changes: 56 additions & 8 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .._exceptions import ProxyError
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from ..backends.base import AsyncNetworkBackend
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(
max_connections: Optional[int] = 10,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: str = None,
uds: str = None,
Expand All @@ -69,6 +73,10 @@ def __init__(
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
Expand All @@ -84,6 +92,8 @@ def __init__(
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
Expand All @@ -107,6 +117,8 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)

Expand Down Expand Up @@ -177,6 +189,8 @@ def __init__(
ssl_context: ssl.SSLContext = None,
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
network_backend: AsyncNetworkBackend = None,
) -> None:
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
Expand All @@ -189,6 +203,8 @@ def __init__(
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = AsyncLock()
self._connected = False

Expand Down Expand Up @@ -224,16 +240,48 @@ async def handle_async_request(self, request: Request) -> Response:
raise ProxyError(msg)

stream = connect_response.extensions["network_stream"]
stream = await stream.start_tls(
ssl_context=self._ssl_context,
server_hostname=self._remote_origin.host.decode("ascii"),
timeout=timeout,

# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("connection.start_tls", request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)

# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection

self._connection = AsyncHTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)

self._connected = True
return await self._connection.handle_async_request(request)

Expand Down
64 changes: 56 additions & 8 deletions httpcore/_sync/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .._exceptions import ProxyError
from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from ..backends.base import NetworkBackend
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
Expand Down Expand Up @@ -46,6 +48,8 @@ def __init__(
max_connections: Optional[int] = 10,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: str = None,
uds: str = None,
Expand All @@ -69,6 +73,10 @@ def __init__(
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
Expand All @@ -84,6 +92,8 @@ def __init__(
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
Expand All @@ -107,6 +117,8 @@ def create_connection(self, origin: Origin) -> ConnectionInterface:
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)

Expand Down Expand Up @@ -177,6 +189,8 @@ def __init__(
ssl_context: ssl.SSLContext = None,
proxy_headers: Sequence[Tuple[bytes, bytes]] = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
network_backend: NetworkBackend = None,
) -> None:
self._connection: ConnectionInterface = HTTPConnection(
Expand All @@ -189,6 +203,8 @@ def __init__(
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = Lock()
self._connected = False

Expand Down Expand Up @@ -224,16 +240,48 @@ def handle_request(self, request: Request) -> Response:
raise ProxyError(msg)

stream = connect_response.extensions["network_stream"]
stream = stream.start_tls(
ssl_context=self._ssl_context,
server_hostname=self._remote_origin.host.decode("ascii"),
timeout=timeout,

# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("connection.start_tls", request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)

# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import HTTP2Connection

self._connection = HTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)

self._connected = True
return self._connection.handle_request(request)

Expand Down
2 changes: 1 addition & 1 deletion scripts/coverage
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ fi

set -x

${PREFIX}coverage report --show-missing --skip-covered --fail-under=93
${PREFIX}coverage report --show-missing --skip-covered --fail-under=100
99 changes: 97 additions & 2 deletions tests/_async/test_http_proxy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import ssl

import hpack
import hyperframe.frame
import pytest

from httpcore import AsyncHTTPProxy, Origin, ProxyError
from httpcore.backends.mock import AsyncMockBackend
from httpcore.backends.base import AsyncNetworkStream
from httpcore.backends.mock import AsyncMockBackend, AsyncMockStream


@pytest.mark.anyio
Expand Down Expand Up @@ -64,7 +69,9 @@ async def test_proxy_tunneling():
"""
network_backend = AsyncMockBackend(
[
b"HTTP/1.1 200 OK\r\n" b"\r\n",
# The initial response to the proxy CONNECT
b"HTTP/1.1 200 OK\r\n\r\n",
# The actual response from the remote server
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
Expand Down Expand Up @@ -111,6 +118,94 @@ async def test_proxy_tunneling():
)


# We need to adapt the mock backend here slightly in order to deal
# with the proxy case. We do not want the initial connection to the proxy
# to indicate an HTTP/2 connection, but we do want it to indicate HTTP/2
# once the SSL upgrade has taken place.
class HTTP1ThenHTTP2Stream(AsyncMockStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: str = None,
timeout: float = None,
) -> AsyncNetworkStream:
self._http2 = True
return self


class HTTP1ThenHTTP2Backend(AsyncMockBackend):
async def connect_tcp(
self, host: str, port: int, timeout: float = None, local_address: str = None
) -> AsyncNetworkStream:
return HTTP1ThenHTTP2Stream(list(self._buffer))


@pytest.mark.anyio
async def test_proxy_tunneling_http2():
"""
Send an HTTP/2 request via a proxy.
"""
network_backend = HTTP1ThenHTTP2Backend(
[
# The initial response to the proxy CONNECT
b"HTTP/1.1 200 OK\r\n\r\n",
# The actual response from the remote server
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
],
)

async with AsyncHTTPProxy(
proxy_url="http://localhost:8080/",
max_connections=10,
network_backend=network_backend,
http2=True,
) as proxy:
# Sending an intial request, which once complete will return to the pool, IDLE.
async with proxy.stream("GET", "https://example.com/") as response:
info = [repr(c) for c in proxy.connections]
assert info == [
"<AsyncTunnelHTTPConnection ['https://example.com:443', HTTP/2, ACTIVE, Request Count: 1]>"
]
await response.aread()

assert response.status == 200
assert response.content == b"Hello, world!"
info = [repr(c) for c in proxy.connections]
assert info == [
"<AsyncTunnelHTTPConnection ['https://example.com:443', HTTP/2, IDLE, Request Count: 1]>"
]
assert proxy.connections[0].is_idle()
assert proxy.connections[0].is_available()
assert not proxy.connections[0].is_closed()

# A connection on a tunneled proxy can only handle HTTPS requests to the same origin.
assert not proxy.connections[0].can_handle_request(
Origin(b"http", b"example.com", 80)
)
assert not proxy.connections[0].can_handle_request(
Origin(b"http", b"other.com", 80)
)
assert proxy.connections[0].can_handle_request(
Origin(b"https", b"example.com", 443)
)
assert not proxy.connections[0].can_handle_request(
Origin(b"https", b"other.com", 443)
)


@pytest.mark.anyio
async def test_proxy_tunneling_with_403():
"""
Expand Down
Loading

0 comments on commit 5bcea8b

Please sign in to comment.