Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional HTTP/2 #121

Merged
merged 4 commits into from
Aug 2, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._types import URL, Headers, Origin, TimeoutDict
Expand All @@ -10,8 +10,7 @@
ConnectionState,
NewConnectionRequired,
)
from .http2 import AsyncHTTP2Connection
from .http11 import AsyncHTTP11Connection
from .http import AsyncBaseHTTPConnection

logger = get_logger(__name__)

Expand All @@ -32,7 +31,7 @@ def __init__(
if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

self.connection: Union[None, AsyncHTTP11Connection, AsyncHTTP2Connection] = None
self.connection: Optional[AsyncBaseHTTPConnection] = None
self.is_http11 = False
self.is_http2 = False
self.connect_failed = False
Expand Down Expand Up @@ -110,11 +109,15 @@ def _create_connection(self, socket: AsyncSocketStream) -> None:
"create_connection socket=%r http_version=%r", socket, http_version
)
if http_version == "HTTP/2":
from .http2 import AsyncHTTP2Connection

self.is_http2 = True
self.connection = AsyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
from .http11 import AsyncHTTP11Connection

self.is_http11 = True
self.connection = AsyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
Expand All @@ -126,7 +129,7 @@ def state(self) -> ConnectionState:
return ConnectionState.CLOSED
elif self.connection is None:
return ConnectionState.PENDING
return self.connection.state
return self.connection.get_state()

def is_connection_dropped(self) -> bool:
return self.connection is not None and self.connection.is_connection_dropped()
Expand All @@ -138,9 +141,8 @@ def mark_as_ready(self) -> None:
async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
await self.connection.start_tls(hostname, timeout)
self.socket = await self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
self.socket = self.connection.socket

async def aclose(self) -> None:
async with self.request_lock:
Expand Down
35 changes: 35 additions & 0 deletions httpcore/_async/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from .._backends.auto import AsyncSocketStream
from .._types import TimeoutDict
from .base import AsyncHTTPTransport, ConnectionState


class AsyncBaseHTTPConnection(AsyncHTTPTransport):
def info(self) -> str:
raise NotImplementedError() # pragma: nocover

def get_state(self) -> ConnectionState:
"""
Return the current state.
"""
raise NotImplementedError() # pragma: nocover

def mark_as_ready(self) -> None:
"""
The connection has been acquired from the pool, and the state
should reflect that.
"""
raise NotImplementedError() # pragma: nocover

def is_connection_dropped(self) -> bool:
"""
Return 'True' if the connection has been dropped by the remote end.
"""
raise NotImplementedError() # pragma: nocover

async def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> AsyncSocketStream:
"""
Upgrade the underlying socket to TLS.
"""
raise NotImplementedError() # pragma: nocover
13 changes: 10 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .._exceptions import ProtocolError, map_exceptions
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import AsyncByteStream, AsyncHTTPTransport, ConnectionState
from .base import AsyncByteStream, ConnectionState
from .http import AsyncBaseHTTPConnection

H11Event = Union[
h11.Request,
Expand All @@ -21,7 +22,7 @@
logger = get_logger(__name__)


class AsyncHTTP11Connection(AsyncHTTPTransport):
class AsyncHTTP11Connection(AsyncBaseHTTPConnection):
READ_NUM_BYTES = 4096

def __init__(
Expand All @@ -40,6 +41,9 @@ def __repr__(self) -> str:
def info(self) -> str:
return f"HTTP/1.1, {self.state.name}"

def get_state(self) -> ConnectionState:
return self.state

def mark_as_ready(self) -> None:
if self.state == ConnectionState.IDLE:
self.state = ConnectionState.READY
Expand Down Expand Up @@ -72,9 +76,12 @@ async def request(
)
return (http_version, status_code, reason_phrase, headers, stream)

async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
async def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> AsyncSocketStream:
timeout = {} if timeout is None else timeout
self.socket = await self.socket.start_tls(hostname, self.ssl_context, timeout)
return self.socket

async def _send_request(
self, method: bytes, url: URL, headers: Headers, timeout: TimeoutDict,
Expand Down
19 changes: 10 additions & 9 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
AsyncByteStream,
AsyncHTTPTransport,
ConnectionState,
NewConnectionRequired,
)
from .base import AsyncByteStream, ConnectionState, NewConnectionRequired
from .http import AsyncBaseHTTPConnection

logger = get_logger(__name__)

Expand All @@ -29,7 +25,7 @@ def get_reason_phrase(status_code: int) -> bytes:
return b""


class AsyncHTTP2Connection(AsyncHTTPTransport):
class AsyncHTTP2Connection(AsyncBaseHTTPConnection):
READ_NUM_BYTES = 4096
CONFIG = H2Configuration(validate_inbound_headers=False)

Expand Down Expand Up @@ -84,8 +80,13 @@ def max_streams_semaphore(self) -> AsyncSemaphore:
)
return self._max_streams_semaphore

async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass
async def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> AsyncSocketStream:
raise NotImplementedError("TLS upgrade not supported on HTTP/2 connections.")

def get_state(self) -> ConnectionState:
return self.state

def mark_as_ready(self) -> None:
if self.state == ConnectionState.IDLE:
Expand Down
16 changes: 9 additions & 7 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._types import URL, Headers, Origin, TimeoutDict
Expand All @@ -10,8 +10,7 @@
ConnectionState,
NewConnectionRequired,
)
from .http2 import SyncHTTP2Connection
from .http11 import SyncHTTP11Connection
from .http import SyncBaseHTTPConnection

logger = get_logger(__name__)

Expand All @@ -32,7 +31,7 @@ def __init__(
if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

self.connection: Union[None, SyncHTTP11Connection, SyncHTTP2Connection] = None
self.connection: Optional[SyncBaseHTTPConnection] = None
self.is_http11 = False
self.is_http2 = False
self.connect_failed = False
Expand Down Expand Up @@ -110,11 +109,15 @@ def _create_connection(self, socket: SyncSocketStream) -> None:
"create_connection socket=%r http_version=%r", socket, http_version
)
if http_version == "HTTP/2":
from .http2 import SyncHTTP2Connection

self.is_http2 = True
self.connection = SyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
from .http11 import SyncHTTP11Connection

self.is_http11 = True
self.connection = SyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
Expand All @@ -126,7 +129,7 @@ def state(self) -> ConnectionState:
return ConnectionState.CLOSED
elif self.connection is None:
return ConnectionState.PENDING
return self.connection.state
return self.connection.get_state()

def is_connection_dropped(self) -> bool:
return self.connection is not None and self.connection.is_connection_dropped()
Expand All @@ -138,9 +141,8 @@ def mark_as_ready(self) -> None:
def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
self.connection.start_tls(hostname, timeout)
self.socket = self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
self.socket = self.connection.socket

def close(self) -> None:
with self.request_lock:
Expand Down
35 changes: 35 additions & 0 deletions httpcore/_sync/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from .._backends.auto import SyncSocketStream
from .._types import TimeoutDict
from .base import SyncHTTPTransport, ConnectionState


class SyncBaseHTTPConnection(SyncHTTPTransport):
def info(self) -> str:
raise NotImplementedError() # pragma: nocover

def get_state(self) -> ConnectionState:
"""
Return the current state.
"""
raise NotImplementedError() # pragma: nocover

def mark_as_ready(self) -> None:
"""
The connection has been acquired from the pool, and the state
should reflect that.
"""
raise NotImplementedError() # pragma: nocover

def is_connection_dropped(self) -> bool:
"""
Return 'True' if the connection has been dropped by the remote end.
"""
raise NotImplementedError() # pragma: nocover

def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> SyncSocketStream:
"""
Upgrade the underlying socket to TLS.
"""
raise NotImplementedError() # pragma: nocover
13 changes: 10 additions & 3 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .._exceptions import ProtocolError, map_exceptions
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import SyncByteStream, SyncHTTPTransport, ConnectionState
from .base import SyncByteStream, ConnectionState
from .http import SyncBaseHTTPConnection

H11Event = Union[
h11.Request,
Expand All @@ -21,7 +22,7 @@
logger = get_logger(__name__)


class SyncHTTP11Connection(SyncHTTPTransport):
class SyncHTTP11Connection(SyncBaseHTTPConnection):
READ_NUM_BYTES = 4096

def __init__(
Expand All @@ -40,6 +41,9 @@ def __repr__(self) -> str:
def info(self) -> str:
return f"HTTP/1.1, {self.state.name}"

def get_state(self) -> ConnectionState:
return self.state

def mark_as_ready(self) -> None:
if self.state == ConnectionState.IDLE:
self.state = ConnectionState.READY
Expand Down Expand Up @@ -72,9 +76,12 @@ def request(
)
return (http_version, status_code, reason_phrase, headers, stream)

def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> SyncSocketStream:
timeout = {} if timeout is None else timeout
self.socket = self.socket.start_tls(hostname, self.ssl_context, timeout)
return self.socket

def _send_request(
self, method: bytes, url: URL, headers: Headers, timeout: TimeoutDict,
Expand Down
19 changes: 10 additions & 9 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from .._exceptions import PoolTimeout, ProtocolError
from .._types import URL, Headers, TimeoutDict
from .._utils import get_logger
from .base import (
SyncByteStream,
SyncHTTPTransport,
ConnectionState,
NewConnectionRequired,
)
from .base import SyncByteStream, ConnectionState, NewConnectionRequired
from .http import SyncBaseHTTPConnection

logger = get_logger(__name__)

Expand All @@ -29,7 +25,7 @@ def get_reason_phrase(status_code: int) -> bytes:
return b""


class SyncHTTP2Connection(SyncHTTPTransport):
class SyncHTTP2Connection(SyncBaseHTTPConnection):
READ_NUM_BYTES = 4096
CONFIG = H2Configuration(validate_inbound_headers=False)

Expand Down Expand Up @@ -84,8 +80,13 @@ def max_streams_semaphore(self) -> SyncSemaphore:
)
return self._max_streams_semaphore

def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
pass
def start_tls(
self, hostname: bytes, timeout: TimeoutDict = None
) -> SyncSocketStream:
raise NotImplementedError("TLS upgrade not supported on HTTP/2 connections.")

def get_state(self) -> ConnectionState:
return self.state

def mark_as_ready(self) -> None:
if self.state == ConnectionState.IDLE:
Expand Down