diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 9ea367bb425..b4065b492bd 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -375,10 +375,13 @@ def connect(self, req): # This connection will now count towards the limit. waiters = self._waiters[key] waiters.append(fut) - yield from fut - waiters.remove(fut) - if not waiters: - del self._waiters[key] + try: + yield from fut + finally: + # remove a waiter even if it was cancelled + waiters.remove(fut) + if not waiters: + del self._waiters[key] proto = self._get(key) if proto is None: @@ -390,6 +393,13 @@ def connect(self, req): if self._closed: proto.close() raise ClientConnectionError("Connector is closed.") + except: + # signal to waiter + for waiter in self._waiters[key]: + if not waiter.done(): + waiter.set_result(None) + break + raise finally: if not self._closed: self._acquired.remove(placeholder) diff --git a/tests/test_connector.py b/tests/test_connector.py index 2cae8b8d768..68167174e68 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1173,6 +1173,100 @@ def test_force_close_and_explicit_keep_alive(loop): assert conn +@asyncio.coroutine +def test_error_on_connection(loop): + conn = aiohttp.BaseConnector(limit=1, loop=loop) + + req = mock.Mock() + req.connection_key = 'key' + proto = mock.Mock() + i = 0 + + fut = helpers.create_future(loop=loop) + exc = OSError() + + @asyncio.coroutine + def create_connection(req): + nonlocal i + i += 1 + if i == 1: + yield from fut + raise exc + elif i == 2: + return proto + + conn._create_connection = create_connection + + t1 = helpers.ensure_future(conn.connect(req), loop=loop) + t2 = helpers.ensure_future(conn.connect(req), loop=loop) + yield from asyncio.sleep(0, loop=loop) + assert not t1.done() + assert not t2.done() + assert len(conn._acquired_per_host['key']) == 1 + + fut.set_result(None) + with pytest.raises(OSError): + yield from t1 + + ret = yield from t2 + assert len(conn._acquired_per_host['key']) == 1 + + assert ret._key == 'key' + assert ret.protocol == proto + assert proto in conn._acquired + + +@asyncio.coroutine +def test_error_on_connection_with_cancelled_waiter(loop): + conn = aiohttp.BaseConnector(limit=1, loop=loop) + + req = mock.Mock() + req.connection_key = 'key' + proto = mock.Mock() + i = 0 + + fut1 = helpers.create_future(loop=loop) + fut2 = helpers.create_future(loop=loop) + exc = OSError() + + @asyncio.coroutine + def create_connection(req): + nonlocal i + i += 1 + if i == 1: + yield from fut1 + raise exc + if i == 2: + yield from fut2 + elif i == 3: + return proto + + conn._create_connection = create_connection + + t1 = helpers.ensure_future(conn.connect(req), loop=loop) + t2 = helpers.ensure_future(conn.connect(req), loop=loop) + t3 = helpers.ensure_future(conn.connect(req), loop=loop) + yield from asyncio.sleep(0, loop=loop) + assert not t1.done() + assert not t2.done() + assert len(conn._acquired_per_host['key']) == 1 + + fut1.set_result(None) + fut2.cancel() + with pytest.raises(OSError): + yield from t1 + + with pytest.raises(asyncio.CancelledError): + yield from t2 + + ret = yield from t3 + assert len(conn._acquired_per_host['key']) == 1 + + assert ret._key == 'key' + assert ret.protocol == proto + assert proto in conn._acquired + + @asyncio.coroutine def test_tcp_connector(test_client, loop): @asyncio.coroutine