diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index ffe32729..4e369f6d 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Optional, Union +from typing import Optional import trio @@ -22,7 +22,7 @@ def none_as_inf(value: Optional[float]) -> float: class SocketStream(AsyncSocketStream): - def __init__(self, stream: Union[trio.SocketStream, trio.SSLStream]) -> None: + def __init__(self, stream: trio.abc.Stream) -> None: self.stream = stream self.read_lock = trio.Lock() self.write_lock = trio.Lock() @@ -43,7 +43,9 @@ async def start_tls( trio.BrokenResourceError: ConnectError, } ssl_stream = trio.SSLStream( - self.stream, ssl_context=ssl_context, server_hostname=hostname + self.stream, + ssl_context=ssl_context, + server_hostname=hostname.decode("ascii"), ) with map_exceptions(exc_map): @@ -85,7 +87,7 @@ def is_connection_dropped(self) -> bool: stream = self.stream # Peek through any SSLStream wrappers to get the underlying SocketStream. - while hasattr(stream, "transport_stream"): + while isinstance(stream, trio.SSLStream): stream = stream.transport_stream assert isinstance(stream, trio.SocketStream) @@ -147,11 +149,11 @@ async def open_tcp_stream( with map_exceptions(exc_map): with trio.fail_after(connect_timeout): - stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port) + stream: trio.abc.Stream = await trio.open_tcp_stream(hostname, port) if ssl_context is not None: stream = trio.SSLStream( - stream, ssl_context, server_hostname=hostname + stream, ssl_context, server_hostname=hostname.decode("ascii") ) await stream.do_handshake() diff --git a/requirements.txt b/requirements.txt index 8ae23a71..cd235471 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ # Optionals trio +trio-typing # Docs mkdocs diff --git a/setup.cfg b/setup.cfg index 939e3cab..5cff3486 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ max-line-length = 88 [mypy] disallow_untyped_defs = True ignore_missing_imports = True +plugins = trio_typing.plugin [tool:isort] combine_as_imports = True