diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 128954dc52..0e27dd2635 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -140,6 +140,10 @@ abstraction. .. autofunction:: socket_stream_pair +.. autofunction:: open_tcp_stream + +.. autofunction:: open_ssl_over_tcp_stream + SSL / TLS support ~~~~~~~~~~~~~~~~~ diff --git a/trio/__init__.py b/trio/__init__.py index be94d9a873..5a4da0da8b 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -56,6 +56,12 @@ from ._path import * __all__ += _path.__all__ +from ._open_tcp_stream import * +__all__ += _open_tcp_stream.__all__ + +from ._ssl_stream_helpers import * +__all__ += _ssl_stream_helpers.__all__ + # Imported by default from . import socket from . import abc diff --git a/trio/_open_tcp_stream.py b/trio/_open_tcp_stream.py new file mode 100644 index 0000000000..3df5c45ad0 --- /dev/null +++ b/trio/_open_tcp_stream.py @@ -0,0 +1,294 @@ +from contextlib import contextmanager + +import trio +from trio.socket import getaddrinfo, SOCK_STREAM, socket + +__all__ = ["open_tcp_stream"] + +# Implementation of RFC 6555 "Happy eyeballs" +# https://tools.ietf.org/html/rfc6555 +# +# Basically, the problem here is that if we want to connect to some host, and +# DNS returns multiple IP addresses, then we don't know which of them will +# actually work -- it can happen that some of them are reachable, and some of +# them are not. One particularly common situation where this happens is on a +# host that thinks it has ipv6 connectivity, but really doesn't. But in +# principle this could happen for any kind of multi-home situation (e.g. the +# route to one mirror is down but another is up). +# +# The naive algorithm (e.g. the stdlib's socket.create_connection) would be to +# pick one of the IP addresses and try to connect; if that fails, try the +# next; etc. The problem with this is that TCP is stubborn, and if the first +# address is a blackhole then it might take a very long time (tens of seconds) +# before that connection attempt fails. +# +# That's where RFC 6555 comes in. It tells us that what we do is: +# - get the list of IPs from getaddrinfo, trusting the order it gives us (with +# one exception noted in section 5.4) +# - start a connection attempt to the first IP +# - when this fails OR if it's still going after DELAY seconds, then start a +# connection attempt to the second IP +# - when this fails OR if it's still going after another DELAY seconds, then +# start a connection attempt to the third IP +# - ... repeat until we run out of IPs. +# +# Our implementation is similarly straightforward: we spawn a chain of tasks, +# where each one (a) waits until the previous connection has failed or DELAY +# seconds have passed, (b) spawns the next task, (c) attempts to connect. As +# soon as any task crashes or succeeds, we cancel all the tasks and return. +# +# Note: this currently doesn't attempt to cache any results, so if you make +# multiple connections to the same host it'll re-run the happy-eyeballs +# algorithm each time. RFC 6555 is pretty confusing about whether this is +# allowed. Section 4 describes an algorithm that attempts ipv4 and ipv6 +# simultaneously, and then says "The client MUST cache information regarding +# the outcome of each connection attempt, and it uses that information to +# avoid thrashing the network with subsequent attempts." Then section 4.2 says +# "implementations MUST prefer the first IP address family returned by the +# host's address preference policy, unless implementing a stateful +# algorithm". Here "stateful" means "one that caches information about +# previous attempts". So my reading of this is that IF you're starting ipv4 +# and ipv6 at the same time then you MUST cache the result for ~ten minutes, +# but IF you're "preferring" one protocol by trying it first (like we are), +# then you don't need to cache. +# +# Caching is quite tricky: to get it right you need to do things like detect +# when the network interfaces are reconfigured, and if you get it wrong then +# connection attempts basically just don't work. So we don't even try. + +# "Firefox and Chrome use 300 ms" +# https://tools.ietf.org/html/rfc6555#section-6 +# Though +# https://www.researchgate.net/profile/Vaibhav_Bajpai3/publication/304568993_Measuring_the_Effects_of_Happy_Eyeballs/links/5773848e08ae6f328f6c284c/Measuring-the-Effects-of-Happy-Eyeballs.pdf +# claims that Firefox actually uses 0 ms, unless an about:config option is +# toggled and then it uses 250 ms. +DEFAULT_DELAY = 0.300 + +# How should we call getaddrinfo? In particular, should we use AI_ADDRCONFIG? +# +# The idea of AI_ADDRCONFIG is that it only returns addresses that might +# work. E.g., if getaddrinfo knows that you don't have any IPv6 connectivity, +# then it doesn't return any IPv6 addresses. And this is kinda nice, because +# it means maybe you can skip sending AAAA requests entirely. But in practice, +# it doesn't really work right. +# +# - on Linux/glibc, empirically, the default is to return all addresses, and +# with AI_ADDRCONFIG then it only returns IPv6 addresses if there is at least +# one non-loopback IPv6 address configured... but this can be a link-local +# address, so in practice I guess this is basically always configured if IPv6 +# is enabled at all. OTOH if you pass in "::1" as the target address with +# AI_ADDRCONFIG and there's no *external* IPv6 address configured, you get an +# error. So AI_ADDRCONFIG mostly doesn't do anything, even when you would want +# it to, and when it does do something it might break things that would have +# worked. +# +# - on Windows 10, empirically, if no IPv6 address is configured then by +# default they are also suppressed from getaddrinfo (flags=0 and +# flags=AI_ADDRCONFIG seem to do the same thing). If you pass AI_ALL, then you +# get the full list. +# ...except for localhost! getaddrinfo("localhost", "80") gives me ::1, even +# though there's no ipv6 and other queries only return ipv4. +# If you pass in and IPv6 IP address as the target address, then that's always +# returned OK, even with AI_ADDRCONFIG set and no IPv6 configured. +# +# But I guess other versions of windows messed this up, judging from these bug +# reports: +# https://bugs.chromium.org/p/chromium/issues/detail?id=5234 +# https://bugs.chromium.org/p/chromium/issues/detail?id=32522#c50 +# +# So basically the options are either to use AI_ADDRCONFIG and then add some +# complicated special cases to work around its brokenness, or else don't use +# AI_ADDRCONFIG and accept that sometimes on legacy/misconfigured networks +# we'll waste 300 ms trying to connect to a blackholed destination. +# +# Twisted and Tornado always uses default flags. I think we'll do the same. + +@contextmanager +def close_on_error(obj): + try: + yield obj + except: + obj.close() + raise + + +def reorder_for_rfc_6555_section_5_4(targets): + # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address + # families (e.g. IPv4 and IPv6), then you should make sure that your first + # and second attempts use different families: + # + # https://tools.ietf.org/html/rfc6555#section-5.4 + # + # This function post-processes the results from getaddrinfo, in-place, to + # satisfy this requirement. + for i in range(1, len(targets)): + if targets[i][0] != targets[0][0]: + # Found the first entry with a different address family; move it + # so that it becomes the second item on the list. + if i != 1: + targets.insert(1, targets.pop(i)) + break + + +def format_host_port(host, port): + if ":" in host: + return "[{}]:{}".format(host, port) + else: + return "{}:{}".format(host, port) + + +# Twisted's HostnameEndpoint has a good set of configurables: +# https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.HostnameEndpoint.html +# +# - per-connection timeout +# this doesn't seem useful -- we let you set a timeout on the whole thing +# using trio's normal mechanisms, and that seems like enough +# - delay between attempts +# - bind address (but not port!) +# they *don't* support multiple address bindings, like giving the ipv4 and +# ipv6 addresses of the host. +# I think maybe our semantics should be: we accept a list of bind addresses, +# and we bind to the first one that is compatible with the +# connection attempt we want to make, and if none are compatible then we +# don't try to connect to that target. +# +# XX TODO: implement bind address support +# +# Actually, the best option is probably to be explicit: {AF_INET: "...", +# AF_INET6: "..."} +# this might be simpler after +async def open_tcp_stream( + host, port, + *, + # No trailing comma b/c bpo-9232 (fixed in py36) + happy_eyeballs_delay=DEFAULT_DELAY + ): + """Connect to the given host and port over TCP. + + If the given ``host`` has multiple IP addresses associated with it, then + we have a problem: which one do we use? One approach would be to attempt + to connect to the first one, and then if that fails, attempt to connect to + the second one ... until we've tried all of them. The problem with this is + that if the first IP address is unreachable (for example, because it's an + IPv6 address and our network discards IPv6 packets), then we might end up + waiting tens of seconds for the first connection attempt to timeout before + we try the second address. Another approach would be to attempt to connect + to all of the addresses at the same time, and then use whichever address + succeeds first. This will be fast, but it creates a lot of unnecessary + network load. + + This function strikes a balance between these two extremes: it works its + way through the available addresses in sequence, like the first approach; + but, if an attempt hasn't succeeded or failed after + ``happy_eyeballs_delay`` seconds, then it gets impatient and starts the + next connection attempt in parallel. As soon as any one connection attempt + succeeds, all the other attempts are cancelled. This way most connections + involve minimal network load, but if one of the addresses is unreachable + then it doesn't slow us down too much. + + This is a "happy eyeballs" algorithm, and roughly matches what Chrome + does; see `RFC 6555 `__. + + Args: + host (bytes or str): The host to connect to. Can be an IPv4 address, + IPv6 address, or a hostname. + port (int): The port to connect to. + happy_eyeballs_delay (float): How many seconds to wait for each + connection attempt to succeed or fail before getting impatient and + starting another one in parallel. Set to :obj:`math.inf` if you want + to limit to only one connection attempt at a time (like + :func:`socket.create_connection`). Default: 0.3 (300 ms). + + Returns: + SocketStream: a :class:`~trio.abc.Stream` connected to the given server. + + Raises: + OSError: if the connection fails. + + See also: + open_ssl_over_tcp_stream + + """ + + if happy_eyeballs_delay is None: + happy_eyeballs_delay = DEFAULT_DELAY + + targets = await getaddrinfo(host, port, type=SOCK_STREAM) + + # I don't think this can actually happen -- if there are no results, + # getaddrinfo should have raised OSError instead of returning an empty + # list. But let's be paranoid and handle it anyway: + if not targets: + msg = "no results found for hostname lookup: {}".format( + format_host_port(host, port)) + raise OSError(msg) + + reorder_for_rfc_6555_section_5_4(targets) + + targets_iter = iter(targets) + + # This list records all the connection failures that we ignored. + oserrors = [] + + # It's possible for multiple connection attempts to succeed at the ~same + # time; this list records all successful connections. + winning_sockets = [] + + # Sleep for the given amount of time, then kick off the next task and + # start a connection attempt. On failure, expedite the next task; on + # success, kill everything. Possible outcomes: + # - records a failure in oserrors, returns None + # - records a connected socket in winning_sockets, returns None + # - crash (raises an unexpected exception) + async def attempt_connect(nursery, previous_attempt_failed): + # Wait until either the previous attempt failed, or the timeout + # expires (unless this is the first invocation, in which case we just + # go ahead). + if previous_attempt_failed is not None: + with trio.move_on_after(happy_eyeballs_delay): + await previous_attempt_failed.wait() + + # Claim our target. + try: + *socket_args, _, target_sockaddr = next(targets_iter) + except StopIteration: + return + + # Then kick off the next attempt. + this_attempt_failed = trio.Event() + nursery.spawn(attempt_connect, nursery, this_attempt_failed) + + # Then make this invocation's attempt + try: + with close_on_error(socket(*socket_args)) as sock: + await sock.connect(target_sockaddr) + except OSError as exc: + # This connection attempt failed, but the next one might + # succeed. Save the error for later so we can report it if + # everything fails, and tell the next attempt that it should go + # ahead (if it hasn't already). + oserrors.append(exc) + this_attempt_failed.set() + else: + # Success! Save the winning socket and cancel all outstanding + # connection attempts. + winning_sockets.append(sock) + nursery.cancel_scope.cancel() + + # Kick off the chain of connection attempts. + async with trio.open_nursery() as nursery: + nursery.spawn(attempt_connect, nursery, None) + + # All connection attempts complete, and no unexpected errors escaped. So + # at this point the oserrors and winning_sockets lists are filled in. + + if winning_sockets: + first_prize = winning_sockets.pop(0) + for sock in winning_sockets: + sock.close() + return trio.SocketStream(first_prize) + else: + assert len(oserrors) == len(targets) + msg = "all attempts to connect to {} failed".format( + format_host_port(host, port)) + raise OSError(msg) from trio.MultiError(oserrors) diff --git a/trio/_ssl_stream_helpers.py b/trio/_ssl_stream_helpers.py new file mode 100644 index 0000000000..56b35ad9d5 --- /dev/null +++ b/trio/_ssl_stream_helpers.py @@ -0,0 +1,65 @@ +import trio + +from ._open_tcp_stream import DEFAULT_DELAY + +__all__ = ["open_ssl_over_tcp_stream"] + +# It might have been nice to take a ssl_protocols= argument here to set up +# NPN/ALPN, but to do this we have to mutate the context object, which is OK +# if it's one we created, but not OK if it's one that was passed in... and +# the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use +# a specially configured SSLContext anyway! I also thought maybe we could copy +# the given SSLContext and then mutate the copy, but it's no good: +# copy.copy(SSLContext) seems to succeed, but the state is not transferred! +# For example, with CPython 3.5, we have: +# ctx = ssl.create_default_context() +# assert ctx.check_hostname == True +# assert copy.copy(ctx).check_hostname == False +# So... let's punt on that for now. Hopefully we'll be getting a new Python +# TLS API soon and can revisit this then. +async def open_ssl_over_tcp_stream( + host, + port, + *, + https_compatible=False, + ssl_context=None, + # No trailing comma b/c bpo-9232 (fixed in py36) + happy_eyeballs_delay=DEFAULT_DELAY + ): + """Make a TLS-encrypted Connection to the given host and port over TCP. + + This is a convenience wrapper that calls :func:`open_tcp_stream` and + wraps the result in an :class:`~trio.ssl.SSLStream`. + + This function does not perform the TLS handshake; you can do it + manually by calling :meth:`~trio.ssl.SSLStream.do_handshake`, or else + it will be performed automatically the first time you send or receive + data. + + Args: + host (bytes or str): The host to connect to. We require the server + to have a TLS certificate valid for this hostname. + port (int): The port to connect to. + https_compatible (bool): Set this to True if you're connecting to a web + server. See :class:`~trio.ssl.SSLStream` for details. Default: + False. + ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to + use. If None (the default), :func:`ssl.create_default_context` + will be called to create a context. + happy_eyeballs_delay (float): See :func:`open_tcp_stream`. + + Returns: + trio.ssl.SSLStream: the encrypted connection to the server. + + """ + tcp_stream = await trio.open_tcp_stream( + host, port, happy_eyeballs_delay=happy_eyeballs_delay, + ) + if ssl_context is None: + ssl_context = trio.ssl.create_default_context() + return trio.ssl.SSLStream( + tcp_stream, + ssl_context, + server_hostname=host, + https_compatible=https_compatible, + ) diff --git a/trio/ssl.py b/trio/ssl.py index 0c719a302f..322abe57bf 100644 --- a/trio/ssl.py +++ b/trio/ssl.py @@ -306,7 +306,7 @@ class SSLStream(_Stream): :class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and attributes are re-exported as methods and attributes on this class. - This also means that you register a SNI callback using + This also means that if you register a SNI callback using :meth:`~ssl.SSLContext.set_servername_callback`, then the first argument your callback receives will be a :class:`ssl.SSLObject`. diff --git a/trio/tests/conftest.py b/trio/tests/conftest.py index 2a79b3bec8..b8169c0c3a 100644 --- a/trio/tests/conftest.py +++ b/trio/tests/conftest.py @@ -17,6 +17,10 @@ def pytest_addoption(parser): def mock_clock(): return MockClock() +@pytest.fixture +def autojump_clock(): + return MockClock(autojump_threshold=0) + # FIXME: split off into a package (or just make part of trio's public # interface?), with config file to enable? and I guess a mark option too; I # guess it's useful with the class- and file-level marking machinery (where diff --git a/trio/tests/test_open_tcp_stream.py b/trio/tests/test_open_tcp_stream.py new file mode 100644 index 0000000000..f0eb464504 --- /dev/null +++ b/trio/tests/test_open_tcp_stream.py @@ -0,0 +1,439 @@ +import pytest + +import attr + +import trio +from trio._open_tcp_stream import ( + reorder_for_rfc_6555_section_5_4, close_on_error, open_tcp_stream, + format_host_port, +) + +def test_close_on_error(): + class CloseMe: + closed = False + + def close(self): + self.closed = True + + with close_on_error(CloseMe()) as c: + pass + assert not c.closed + + with pytest.raises(RuntimeError): + with close_on_error(CloseMe()) as c: + raise RuntimeError + assert c.closed + + +def test_reorder_for_rfc_6555_section_5_4(): + def fake4(i): + return (trio.socket.AF_INET, + trio.socket.SOCK_STREAM, + trio.socket.IPPROTO_TCP, + "", + ("10.0.0.{}".format(i), 80)) + + def fake6(i): + return (trio.socket.AF_INET6, + trio.socket.SOCK_STREAM, + trio.socket.IPPROTO_TCP, + "", + ("::{}".format(i), 80)) + + for fake in fake4, fake6: + # No effect on homogenous lists + targets = [fake(0), fake(1), fake(2)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake(0), fake(1), fake(2)] + + # Single item lists also OK + targets = [fake(0)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake(0)] + + # If the list starts out with different families in positions 0 and 1, + # then it's left alone + orig = [fake4(0), fake6(0), fake4(1), fake6(1)] + targets = list(orig) + reorder_for_rfc_6555_section_5_4(targets) + assert targets == orig + + # If not, it's reordered + targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] + + +def test_format_host_port(): + assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" + assert format_host_port("example.com", 443) == "example.com:443" + assert format_host_port("::1", "http") == "[::1]:http" + + +# Make sure we can connect to localhost using real kernel sockets +async def test_open_tcp_stream_real_socket_smoketest(): + listen_sock = trio.socket.socket() + listen_sock.bind(("127.0.0.1", 0)) + _, listen_port = listen_sock.getsockname() + listen_sock.listen(1) + client_stream = await open_tcp_stream("127.0.0.1", listen_port) + server_sock, _ = await listen_sock.accept() + await client_stream.send_all(b"x") + assert await server_sock.recv(1) == b"x" + client_stream.forceful_close() + server_sock.close() + + +# Now, thorough tests using fake sockets + +@attr.s +class FakeSocket: + scenario = attr.ib() + family = attr.ib() + type = attr.ib() + proto = attr.ib() + + ip = attr.ib(default=None) + port = attr.ib(default=None) + succeeded = attr.ib(default=False) + closed = attr.ib(default=False) + + async def connect(self, sockaddr): + self.ip = sockaddr[0] + self.port = sockaddr[1] + assert self.ip not in self.scenario.sockets + self.scenario.sockets[self.ip] = self + self.scenario.connect_times[self.ip] = trio.current_time() + delay, result = self.scenario.ip_dict[self.ip] + await trio.sleep(delay) + if result == "error": + raise OSError("sorry") + self.succeeded = True + + def close(self): + self.closed = True + + # Some stubs to stop SocketStream from complaining: + def setsockopt(self, *args, **kwargs): + pass + + def getpeername(self): + return (self.ip, self.port) + + +class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): + def __init__(self, port, ip_list, ipv6_supported): + # ip_list have to be unique + ip_order = [ip for (ip, _, _) in ip_list] + assert len(set(ip_order)) == len(ip_list) + ip_dict = {} + for ip, delay, result in ip_list: + assert 0 <= delay + assert result in ["error", "success"] + ip_dict[ip] = (delay, result) + + self.port = port + self.ip_order = ip_order + self.ip_dict = ip_dict + self.ipv6_supported = ipv6_supported + self.socket_count = 0 + self.sockets = {} + self.connect_times = {} + + def socket(self, family, type, proto): + if not self.ipv6_supported and family == trio.socket.AF_INET6: + raise OSError("pretending not to support ipv6") + self.socket_count += 1 + return FakeSocket(self, family, type, proto) + + def is_trio_socket(self, obj): + return isinstance(obj, FakeSocket) + + def _ip_to_gai_entry(self, ip): + if ":" in ip: + family = trio.socket.AF_INET6 + sockaddr = (ip, self.port, 0, 0) + else: + family = trio.socket.AF_INET + sockaddr = (ip, self.port) + return (family, + trio.socket.SOCK_STREAM, + trio.socket.IPPROTO_TCP, + "", + sockaddr) + + async def getaddrinfo(self, host, port, family, type, proto, flags): + assert host == b"test.example.com" + assert port == self.port + assert family == trio.socket.AF_UNSPEC + assert type == trio.socket.SOCK_STREAM + assert proto == 0 + assert flags == 0 + return [self._ip_to_gai_entry(ip) for ip in self.ip_order] + + async def getnameinfo(self, sockaddr, flags): # pragma: no cover + raise NotImplementedError + + def check(self, succeeded): + # sockets only go into self.sockets when connect is called; make sure + # all the sockets that were created did in fact go in there. + assert self.socket_count == len(self.sockets) + + for ip, socket in self.sockets.items(): + assert ip in self.ip_dict + if socket is not succeeded: + assert socket.closed + assert socket.port == self.port + + +async def run_scenario( + # The port to connect to + port, + # A list of + # (ip, delay, result) + # tuples, where delay is in seconds and result is "success" or "error" + # The ip's will be returned from getaddrinfo in this order, and then + # connect() calls to them will have the given result. + ip_list, + *, + # If False, AF_INET6 sockets error out on creation, before connect is + # even called. + ipv6_supported=True, + # Normally, we return (winning_sock, scenario object) + # If this is True, we require there to be an exception, and return + # (exception, scenario object) + expect_error=(), + **kwargs +): + scenario = Scenario(port, ip_list, ipv6_supported) + trio.socket.set_custom_hostname_resolver(scenario) + trio.socket.set_custom_socket_factory(scenario) + + try: + stream = await open_tcp_stream("test.example.com", port, **kwargs) + assert expect_error == () + scenario.check(stream.socket) + return (stream.socket, scenario) + except AssertionError: # pragma: no cover + raise + except expect_error as exc: + scenario.check(None) + return (exc, scenario) + +async def test_one_host_quick_success(autojump_clock): + sock, scenario = await run_scenario( + 80, [("1.2.3.4", 0.123, "success")]) + assert sock.ip == "1.2.3.4" + assert trio.current_time() == 0.123 + + +async def test_one_host_slow_success(autojump_clock): + sock, scenario = await run_scenario( + 81, [("1.2.3.4", 100, "success")]) + assert sock.ip == "1.2.3.4" + assert trio.current_time() == 100 + + +async def test_one_host_quick_fail(autojump_clock): + exc, scenario = await run_scenario( + 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError) + assert isinstance(exc, OSError) + assert trio.current_time() == 0.123 + + +async def test_one_host_slow_fail(autojump_clock): + exc, scenario = await run_scenario( + 83, [("1.2.3.4", 100, "error")], expect_error=OSError) + assert isinstance(exc, OSError) + assert trio.current_time() == 100 + + +# With the default 0.300 second delay, the third attempt will win +async def test_basic_fallthrough(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 1, "success"), + ("2.2.2.2", 1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + ) + assert sock.ip == "3.3.3.3" + assert trio.current_time() == (0.300 + 0.300 + 0.2) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.300, + "3.3.3.3": 0.600, + } + + +async def test_early_success(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 1, "success"), + ("2.2.2.2", 0.1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + ) + assert sock.ip == "2.2.2.2" + assert trio.current_time() == (0.300 + 0.1) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.300, + # 3.3.3.3 was never even started + } + + +# With a 0.450 second delay, the first attempt will win +async def test_custom_delay(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 1, "success"), + ("2.2.2.2", 1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + happy_eyeballs_delay=0.450, + ) + assert sock.ip == "1.1.1.1" + assert trio.current_time() == 1 + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.450, + "3.3.3.3": 0.900, + } + + +async def test_custom_errors_expedite(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 0.1, "error"), + ("2.2.2.2", 0.2, "error"), + ("3.3.3.3", 10, "success"), + ("4.4.4.4", 0.3, "success"), + ], + ) + assert sock.ip == "4.4.4.4" + assert trio.current_time() == (0.1 + 0.2 + 0.3 + 0.3) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.1, + "3.3.3.3": 0.1 + 0.2, + "4.4.4.4": 0.1 + 0.2 + 0.3, + } + + +async def test_all_fail(autojump_clock): + exc, scenario = await run_scenario( + 80, + [("1.1.1.1", 0.1, "error"), + ("2.2.2.2", 0.2, "error"), + ("3.3.3.3", 10, "error"), + ("4.4.4.4", 0.3, "error"), + ], + expect_error=OSError, + ) + assert isinstance(exc, OSError) + assert isinstance(exc.__cause__, trio.MultiError) + assert len(exc.__cause__.exceptions) == 4 + assert trio.current_time() == (0.1 + 0.2 + 10) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.1, + "3.3.3.3": 0.1 + 0.2, + "4.4.4.4": 0.1 + 0.2 + 0.3, + } + + +async def test_multi_success(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 0.5, "error"), + ("2.2.2.2", 10, "success"), + ("3.3.3.3", 10 - 1, "success"), + ("4.4.4.4", 10 - 2, "success"), + ("5.5.5.5", 0.5, "error"), + ], + happy_eyeballs_delay=1, + ) + assert not scenario.sockets["1.1.1.1"].succeeded + assert scenario.sockets["2.2.2.2"].succeeded + assert scenario.sockets["3.3.3.3"].succeeded + assert scenario.sockets["4.4.4.4"].succeeded + assert not scenario.sockets["5.5.5.5"].succeeded + assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"] + assert trio.current_time() == (0.5 + 10) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.5, + "3.3.3.3": 1.5, + "4.4.4.4": 2.5, + "5.5.5.5": 3.5, + } + + +async def test_does_reorder(autojump_clock): + sock, scenario = await run_scenario( + 80, + [("1.1.1.1", 10, "error"), + # This would win if we tried it first... + ("2.2.2.2", 1, "success"), + # But in fact we try this first, because of section 5.4 + ("::3", 0.5, "success"), + ], + happy_eyeballs_delay=1, + ) + assert sock.ip == "::3" + assert trio.current_time() == 1 + 0.5 + assert scenario.connect_times == { + "1.1.1.1": 0, + "::3": 1, + } + + +async def test_handles_no_ipv6(autojump_clock): + sock, scenario = await run_scenario( + 80, + # Here the ipv6 addresses fail at socket creation time, so the connect + # configuration doesn't matter + [("::1", 0, "success"), + ("2.2.2.2", 10, "success"), + ("::3", 0, "success"), + ("4.4.4.4", 0.1, "success"), + ], + happy_eyeballs_delay=1, + ipv6_supported=False, + ) + assert sock.ip == "4.4.4.4" + assert trio.current_time() == 1 + 0.1 + assert scenario.connect_times == { + "2.2.2.2": 0, + "4.4.4.4": 1.0, + } + + +async def test_no_hosts(autojump_clock): + exc, scenario = await run_scenario(80, [], expect_error=OSError) + assert "no results found" in str(exc) + + +async def test_cancel(autojump_clock): + with trio.move_on_after(5) as cancel_scope: + exc, scenario = await run_scenario( + 80, + [("1.1.1.1", 10, "success"), + ("2.2.2.2", 10, "success"), + ("3.3.3.3", 10, "success"), + ("4.4.4.4", 10, "success"), + ], + expect_error=trio.MultiError, + ) + # What comes out should be 1 or more Cancelled errors that all belong + # to this cancel_scope; this is the easiest way to check that + raise exc + assert cancel_scope.cancelled_caught + + assert trio.current_time() == 5 + + # This should have been called already, but just to make sure, since the + # exception-handling logic in run_scenario is a bit complicated and the + # main thing we care about here is that all the sockets were cleaned up. + scenario.check(succeeded=False) diff --git a/trio/tests/test_ssl_helpers.py b/trio/tests/test_ssl_helpers.py new file mode 100644 index 0000000000..d062869499 --- /dev/null +++ b/trio/tests/test_ssl_helpers.py @@ -0,0 +1,128 @@ +import pytest + +import attr + +import trio +import trio.testing +from .._util import acontextmanager +from .test_ssl import CLIENT_CTX, SERVER_CTX + +from .._ssl_stream_helpers import open_ssl_over_tcp_stream + +# this would be much simpler with a real fake network +# or just having trustme support for IP addresses so I could try connecting to +# 127.0.0.1 + +# Need to at least check making a successful connection, and making +# connections that fail CA and hostname validation. +# +# Also custom context and https_compatible I guess, though there isn't a whole +# lot that could go wrong here. Probably don't need to test +# happy_eyeballs_delay separately. + +@attr.s +class FakeSocket: + stream = attr.ib() + + async def connect(self, sockaddr): + pass + + async def sendall(self, data): + await self.stream.send_all(data) + + async def recv(self, max_bytes): + return await self.stream.receive_some(max_bytes) + + def close(self): + self.stream.forceful_close() + + # Stubs to make SocketStream happy: + def setsockopt(self, *args, **kwargs): + pass + + def getpeername(self, *args): + pass + + type = trio.socket.SOCK_STREAM + did_shutdown_SHUT_WR = False + + +# No matter who you connect to, you end up talking to an echo server with a +# cert for trio-test-1.example.com. +@attr.s +class FakeNetwork(trio.abc.HostnameResolver, trio.abc.SocketFactory): + nursery = attr.ib() + + async def getaddrinfo(self, *args): + return [(trio.socket.AF_INET, + trio.socket.SOCK_STREAM, + trio.socket.IPPROTO_TCP, + "", + ("1.1.1.1", 443))] + + async def getnameinfo(self, *args): # pragma: no cover + raise NotImplementedError + + def is_trio_socket(self, obj): + return isinstance(obj, FakeSocket) + + def socket(self, family, type, proto): + client_stream, server_stream = trio.testing.memory_stream_pair() + self.nursery.spawn(self.echo_server, server_stream) + return FakeSocket(client_stream) + + async def echo_server(self, raw_server_stream): + ssl_server_stream = trio.ssl.SSLStream( + raw_server_stream, + SERVER_CTX, + server_side=True, + ) + while True: + data = await ssl_server_stream.receive_some(10000) + if not data: + break + await ssl_server_stream.send_all(data) + + +async def test_open_ssl_over_tcp_stream(): + async with trio.open_nursery() as nursery: + network = FakeNetwork(nursery) + trio.socket.set_custom_hostname_resolver(network) + trio.socket.set_custom_socket_factory(network) + + # We don't have the right trust set up + # (checks that ssl_context=None is doing some validation) + stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80) + with pytest.raises(trio.BrokenStreamError): + await stream.do_handshake() + + # We have the trust but not the hostname + # (checks custom ssl_context + hostname checking) + stream = await open_ssl_over_tcp_stream( + "xyzzy.example.org", 80, ssl_context=CLIENT_CTX, + ) + with pytest.raises(trio.BrokenStreamError): + await stream.do_handshake() + + # This one should work! + stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", 80, + ssl_context=CLIENT_CTX, + ) + await stream.send_all(b"x") + assert await stream.receive_some(1) == b"x" + await stream.graceful_close() + + # Check https_compatible settings are being passed through + assert not stream._https_compatible + stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", 80, + ssl_context=CLIENT_CTX, + https_compatible=True, + # also, smoke test happy_eyeballs_delay + happy_eyeballs_delay=1, + ) + assert stream._https_compatible + + # We've left abandoned server tasks behind; clean them up. + nursery.cancel_scope.cancel()