Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local_address= kwarg to open_tcp_stream #1644

Merged
merged 2 commits into from
Jun 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions newsfragments/275.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
`trio.open_tcp_stream` has a new ``local_address=`` keyword argument
that can be used on machines with multiple IP addresses to control
which IP is used for the outgoing connection.
70 changes: 65 additions & 5 deletions trio/_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,7 @@ def format_host_port(host, port):
# AF_INET6: "..."}
# this might be simpler after
async def open_tcp_stream(
host,
port,
*,
# No trailing comma b/c bpo-9232 (fixed in py36)
happy_eyeballs_delay=DEFAULT_DELAY,
host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None,
):
"""Connect to the given host and port over TCP.

Expand Down Expand Up @@ -205,13 +201,30 @@ async def open_tcp_stream(
Args:
host (str or bytes): The host to connect to. Can be an IPv4 address,
IPv6 address, or a hostname.

port (int): The port to connect to.

happy_eyeballs_delay (float): How many seconds to wait for each
connection attempt to succeed or fail before getting impatient and
starting another one in parallel. Set to `math.inf` if you want
to limit to only one connection attempt at a time (like
:func:`socket.create_connection`). Default: 0.25 (250 ms).

local_address (None or str): The local IP address or hostname to use as
the source for outgoing connections. If ``None``, we let the OS pick
the source IP.

This is useful in some exotic networking configurations where your
host has multiple IP addresses, and you want to force the use of a
specific one.

Note that if you pass an IPv4 ``local_address``, then you won't be
able to connect to IPv6 hosts, and vice-versa. If you want to take
advantage of this to force the use of IPv4 or IPv6 without
specifying an exact source address, you can use the IPv4 wildcard
address ``local_address="0.0.0.0"``, or the IPv6 wildcard address
``local_address="::"``.

Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given server.

Expand Down Expand Up @@ -269,6 +282,53 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed):
sock = socket(*socket_args)
open_sockets.add(sock)

if local_address is not None:
# TCP connections are identified by a 4-tuple:
#
# (local IP, local port, remote IP, remote port)
#
# So if a single local IP wants to make multiple connections
# to the same (remote IP, remote port) pair, then those
# connections have to use different local ports, or else TCP
# won't be able to tell them apart. OTOH, if you have multiple
# connections to different remote IP/ports, then those
# connections can share a local port.
#
# Normally, when you call bind(), the kernel will immediately
# assign a specific local port to your socket. At this point
# the kernel doesn't know which (remote IP, remote port)
# you're going to use, so it has to pick a local port that
# *no* other connection is using. That's the only way to
# guarantee that this local port will be usable later when we
# call connect(). (Alternatively, you can set SO_REUSEADDR to
# allow multiple nascent connections to share the same port,
# but then connect() might fail with EADDRNOTAVAIL if we get
# unlucky and our TCP 4-tuple ends up colliding with another
# unrelated connection.)
#
# So calling bind() before connect() works, but it disables
# sharing of local ports. This is inefficient: it makes you
# more likely to run out of local ports.
#
# But on some versions of Linux, we can re-enable sharing of
# local ports by setting a special flag. This flag tells
# bind() to only bind the IP, and not the port. That way,
# connect() is allowed to pick the port, and it can do a
# better job of it because it knows the remote IP/port.
try:
sock.setsockopt(
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1
)
except (OSError, AttributeError):
pass
try:
await sock.bind((local_address, 0))
except OSError:
raise OSError(
f"local_address={local_address!r} is incompatible "
f"with remote address {sockaddr}"
)

await sock.connect(sockaddr)

# Success! Save the winning socket and cancel all outstanding
Expand Down
6 changes: 6 additions & 0 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,9 @@
TCP_NOTSENT_LOWAT = 0x201
elif _sys.platform == "linux":
TCP_NOTSENT_LOWAT = 25

# Fallback definition: if nothing earlier in this module has already provided
# IP_BIND_ADDRESS_NO_PORT, define it ourselves -- but only on Linux, where the
# option number is 24. On every other platform the name is left undefined.
# NOTE(review): presumably the name is missing when the running Python's
# socket module does not export it -- confirm against the imports above.
if "IP_BIND_ADDRESS_NO_PORT" not in globals() and _sys.platform == "linux":
    IP_BIND_ADDRESS_NO_PORT = 24
58 changes: 58 additions & 0 deletions trio/tests/test_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
import sys
import socket

import attr

Expand Down Expand Up @@ -112,6 +114,62 @@ async def test_open_tcp_stream_input_validation():
await open_tcp_stream("127.0.0.1", b"80")


def can_bind_127_0_0_2():
    """Report whether a socket can be bound to the address 127.0.0.2.

    Returns:
      bool: ``False`` if ``bind`` raises :class:`OSError`; otherwise whether
      the bound socket's reported host is exactly ``"127.0.0.2"``.
    """
    probe_host = "127.0.0.2"
    with socket.socket() as probe:
        try:
            # Port 0 asks the OS to choose any free port.
            probe.bind((probe_host, 0))
        except OSError:
            return False
        bound_host = probe.getsockname()[0]
    return bound_host == probe_host


async def test_local_address_real():
    """Exercise ``open_tcp_stream(..., local_address=...)`` on real sockets.

    A listener is bound to 127.0.0.1 and three things are checked:

    * connecting with an explicit loopback ``local_address`` makes both ends
      report that address as the client's source;
    * using the IPv6 wildcard ``"::"`` against the IPv4 listener raises
      :class:`OSError`;
    * using the IPv4 wildcard ``"0.0.0.0"`` connects successfully.
    """
    with trio.socket.socket() as listener:
        await listener.bind(("127.0.0.1", 0))
        listener.listen()

        # A proper test of local_address needs more than one bindable local
        # address. Where 127.0.0.2 can be bound (most Linux systems route all
        # of 127.*.*.* through loopback), use it so the source address really
        # differs from the default. Elsewhere, fall back to 127.0.0.1: that
        # is indistinguishable from omitting local_address=, but it still
        # smoke-tests that the local_address= code path doesn't crash.
        local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"

        async with await open_tcp_stream(
            *listener.getsockname(), local_address=local_address
        ) as stream:
            # The client socket must be bound to the address we asked for.
            assert stream.socket.getsockname()[0] == local_address
            # Where the option exists, it should have been switched on.
            if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
                assert stream.socket.getsockopt(
                    trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
                )

            peer_sock, peer_addr = await listener.accept()
            await stream.aclose()
            peer_sock.close()
            # The listener must see the connection coming from that address.
            assert peer_addr[0] == local_address

        # An IPv6 wildcard local_address cannot be combined with an IPv4
        # destination, so this connection attempt has to fail.
        with pytest.raises(OSError):
            await open_tcp_stream(*listener.getsockname(), local_address="::")

        # The IPv4 wildcard, on the other hand, is compatible and must work.
        async with await open_tcp_stream(
            *listener.getsockname(), local_address="0.0.0.0"
        ) as stream:
            peer_sock, peer_addr = await listener.accept()
            peer_sock.close()
            assert peer_addr == stream.socket.getsockname()


# Now, thorough tests using fake sockets


Expand Down