Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local address support. #100

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional, Tuple, Union

from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .._types import URL, Headers, Origin, SocketAddress, TimeoutDict
from .._utils import get_logger, url_to_origin
from .base import (
AsyncByteStream,
Expand All @@ -23,11 +23,15 @@ def __init__(
http2: bool = False,
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
family: int = 0,
local_addr: SocketAddress = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.family = family
self.local_addr = local_addr

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -90,7 +94,7 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
hostname, port, ssl_context, timeout, self.family, self.local_addr
)
except Exception:
self.connect_failed = True
Expand Down
15 changes: 13 additions & 2 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .._backends.auto import AsyncLock, AsyncSemaphore, AutoBackend
from .._exceptions import PoolTimeout
from .._threadlock import ThreadLock
from .._types import URL, Headers, Origin, TimeoutDict
from .._types import URL, Headers, Origin, SocketAddress, TimeoutDict
from .._utils import get_logger, origin_to_url_string, url_to_origin
from .base import (
AsyncByteStream,
Expand Down Expand Up @@ -76,6 +76,9 @@ class AsyncConnectionPool(AsyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **family** - `int` - Address family to use, defaults to 0.
* **local_addr** - `Optional[SocketAddress]` - Local address to connect
from; requires family
"""

def __init__(
Expand All @@ -85,12 +88,16 @@ def __init__(
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
family: int = 0,
local_addr: SocketAddress = None,
):
self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._family = family
self._local_addr = local_addr
self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = AutoBackend()
Expand Down Expand Up @@ -141,7 +148,11 @@ async def request(

if connection is None:
connection = AsyncHTTPConnection(
origin=origin, http2=self._http2, ssl_context=self._ssl_context,
origin=origin,
http2=self._http2,
ssl_context=self._ssl_context,
family=self._family,
local_addr=self._local_addr,
)
logger.trace("created connection=%r", connection)
await self._add_to_pool(connection, timeout=timeout)
Expand Down
9 changes: 7 additions & 2 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._types import SocketAddress, TimeoutDict
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

SSL_MONKEY_PATCH_APPLIED = False
Expand Down Expand Up @@ -222,13 +222,18 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
family: int,
local_addr: Optional[SocketAddress],
) -> SocketStream:
host = hostname.decode("ascii")
connect_timeout = timeout.get("connect")
exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError}
with map_exceptions(exc_map):
stream_reader, stream_writer = await asyncio.wait_for(
asyncio.open_connection(host, port, ssl=ssl_context), connect_timeout,
asyncio.open_connection(
host, port, ssl=ssl_context, family=family, local_addr=local_addr
),
connect_timeout,
)
return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import sniffio

from .._types import TimeoutDict
from .._types import SocketAddress, TimeoutDict
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

# The following line is imported from the _sync modules
Expand Down Expand Up @@ -34,8 +34,12 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
family: int,
local_addr: Optional[SocketAddress],
) -> AsyncSocketStream:
return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout, family, local_addr
)

def create_lock(self) -> AsyncLock:
return self.backend.create_lock()
Expand Down
4 changes: 3 additions & 1 deletion httpcore/_backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from types import TracebackType
from typing import Optional, Type

from .._types import TimeoutDict
from .._types import SocketAddress, TimeoutDict


class AsyncSocketStream:
Expand Down Expand Up @@ -76,6 +76,8 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
family: int,
local_addr: Optional[SocketAddress],
) -> AsyncSocketStream:
raise NotImplementedError() # pragma: no cover

Expand Down
13 changes: 11 additions & 2 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._types import SocketAddress, TimeoutDict


class SyncSocketStream:
Expand Down Expand Up @@ -125,13 +125,22 @@ def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
family: int,
local_addr: Optional[SocketAddress],
) -> SyncSocketStream:
address = (hostname.decode("ascii"), port)
connect_timeout = timeout.get("connect")
exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError}

with map_exceptions(exc_map):
sock = socket.create_connection(address, connect_timeout)
if family != 0 and local_addr is None:
if family == socket.AF_INET:
local_addr = ("0.0.0.0", 0)
elif family == socket.AF_INET6:
local_addr = ("::", 0)
else:
raise NotImplementedError()
sock = socket.create_connection(address, connect_timeout, local_addr) # type: ignore
if ssl_context is not None:
sock = ssl_context.wrap_socket(
sock, server_hostname=hostname.decode("ascii")
Expand Down
6 changes: 5 additions & 1 deletion httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._types import SocketAddress, TimeoutDict
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream


Expand Down Expand Up @@ -140,7 +140,11 @@ async def open_tcp_stream(
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
family: int,
local_addr: Optional[SocketAddress],
) -> AsyncSocketStream:
if family != 0 or local_addr:
raise NotImplementedError()
connect_timeout = none_as_inf(timeout.get("connect"))
exc_map = {
trio.TooSlowError: ConnectTimeout,
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional, Tuple, Union

from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .._types import URL, Headers, Origin, SocketAddress, TimeoutDict
from .._utils import get_logger, url_to_origin
from .base import (
SyncByteStream,
Expand All @@ -23,11 +23,15 @@ def __init__(
http2: bool = False,
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
family: int = 0,
local_addr: SocketAddress = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.family = family
self.local_addr = local_addr

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -90,7 +94,7 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
ssl_context = self.ssl_context if scheme == b"https" else None
try:
return self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
hostname, port, ssl_context, timeout, self.family, self.local_addr
)
except Exception:
self.connect_failed = True
Expand Down
15 changes: 13 additions & 2 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .._backends.auto import SyncLock, SyncSemaphore, SyncBackend
from .._exceptions import PoolTimeout
from .._threadlock import ThreadLock
from .._types import URL, Headers, Origin, TimeoutDict
from .._types import URL, Headers, Origin, SocketAddress, TimeoutDict
from .._utils import get_logger, origin_to_url_string, url_to_origin
from .base import (
SyncByteStream,
Expand Down Expand Up @@ -76,6 +76,9 @@ class SyncConnectionPool(SyncHTTPTransport):
* **keepalive_expiry** - `Optional[float]` - The maximum time to allow
before closing a keep-alive connection.
* **http2** - `bool` - Enable HTTP/2 support.
* **family** - `int` - Address family to use, defaults to 0.
* **local_addr** - `Optional[SocketAddress]` - Local address to connect
from; requires family
"""

def __init__(
Expand All @@ -85,12 +88,16 @@ def __init__(
max_keepalive: int = None,
keepalive_expiry: float = None,
http2: bool = False,
family: int = 0,
local_addr: SocketAddress = None,
):
self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive = max_keepalive
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._family = family
self._local_addr = local_addr
self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = SyncBackend()
Expand Down Expand Up @@ -141,7 +148,11 @@ def request(

if connection is None:
connection = SyncHTTPConnection(
origin=origin, http2=self._http2, ssl_context=self._ssl_context,
origin=origin,
http2=self._http2,
ssl_context=self._ssl_context,
family=self._family,
local_addr=self._local_addr,
)
logger.trace("created connection=%r", connection)
self._add_to_pool(connection, timeout=timeout)
Expand Down
1 change: 1 addition & 0 deletions httpcore/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
URL = Tuple[bytes, bytes, Optional[int], bytes]
Headers = List[Tuple[bytes, bytes]]
TimeoutDict = Dict[str, Optional[float]]
SocketAddress = Tuple[Union[str, bytes, bytearray], int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be reasonable to keep this as tightly constrained as possible?
Presumably Tuple[bytes, int] is sufficient right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If source-port is removed, this type isn't needed at all, as the address could just be a bytes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, yup.

42 changes: 42 additions & 0 deletions tests/async_tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import socket
import ssl
import typing

Expand Down Expand Up @@ -204,6 +205,47 @@ async def test_http_proxy(
assert reason == b"OK"


@pytest.mark.parametrize("family", [socket.AF_INET])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth just setting it below?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, this originally tested AF_INET6 also, but the test machine doesn't support IPv6.

@pytest.mark.asyncio
# This doesn't run with trio, since trio doesn't support family.
async def test_http_request_family(family: int,) -> None:
async with httpcore.AsyncConnectionPool(family=family) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
)
body = await read_body(stream)

assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
assert len(http._connections[url[:3]]) == 1 # type: ignore


@pytest.mark.parametrize("local_addr", ["0.0.0.0"])
@pytest.mark.asyncio
# This doesn't run with trio, since trio doesn't support local_addr.
async def test_http_request_local_addr(local_addr: str) -> None:
family = socket.AF_INET6 if ":" in local_addr else socket.AF_INET
async with httpcore.AsyncConnectionPool(
family=family, local_addr=(local_addr, 0)
) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
)
body = await read_body(stream)

assert http_version == b"HTTP/1.1"
assert status_code == 200
assert reason == b"OK"
assert len(http._connections[url[:3]]) == 1 # type: ignore


# mitmproxy does not support forwarding HTTPS requests
@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "TUNNEL_ONLY"])
@pytest.mark.usefixtures("async_environment")
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import pytest
import trustme

from mitmproxy import options, proxy
from mitmproxy.tools.dump import DumpMaster

Expand Down
Loading