Skip to content

Commit

Permalink
Fix #1985: Keep a reference to ClientSession in response object
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jun 19, 2017
1 parent 22c8937 commit 53141f3
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 62 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Changes

- Fix BadStatusLine caused by extra `CRLF` after `POST` data #1792

- Keep a reference to ClientSession in response object #1985


2.1.0 (2017-05-26)
------------------
Expand Down
10 changes: 2 additions & 8 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def _request(self, method, url, *,
compress=compress, chunked=chunked,
expect100=expect100, loop=self._loop,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer)
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self)

# connection timeout
try:
Expand Down Expand Up @@ -688,13 +689,6 @@ def __await__(self):
self._session.close()
raise

def __del__(self):
# in case of "resp = aiohttp.request(...)"
# _SessionRequestContextManager get destroyed before resp get processed
# and connection has to stay alive during this time
# ClientSession.detach just cleans up connector attribute
self._session.detach()


def request(method, url, *,
params=None,
Expand Down
10 changes: 6 additions & 4 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def __init__(self, method, url, *,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None, proxy_from_env=False,
timer=None):
timer=None, session=None):

if loop is None:
loop = asyncio.get_event_loop()

assert isinstance(url, URL), url
assert isinstance(proxy, (URL, type(None))), proxy

self._session = session
if params:
q = MultiDict(url.query)
url2 = url.with_query(params)
Expand Down Expand Up @@ -409,7 +409,7 @@ def send(self, conn):
request_info=self.request_info
)

self.response._post_init(self.loop)
self.response._post_init(self.loop, self._session)
return self.response

@asyncio.coroutine
Expand Down Expand Up @@ -446,6 +446,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_loop = None
_closed = True # to allow __del__ for non-initialized properly response
_session = None

def __init__(self, method, url, *,
writer=None, continue100=None, timer=None,
Expand Down Expand Up @@ -487,8 +488,9 @@ def _headers(self):
def request_info(self):
return self._request_info

def _post_init(self, loop):
def _post_init(self, loop, session):
self._loop = loop
self._session = session # store a reference to session #1985
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_client_protocol_readuntil_eof(loop):
proto.data_received(b'HTTP/1.1 200 Ok\r\n\r\n')

response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, mock.Mock())
yield from response.start(conn, read_until_eof=True)

assert not response.content.is_eof()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ def send(self, conn):
self.url,
writer=self._writer,
continue100=self._continue)
resp._post_init(self.loop)
resp._post_init(self.loop, mock.Mock())
self.response = resp
nonlocal called
called = True
Expand Down
101 changes: 53 additions & 48 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
from aiohttp.client_reqrep import ClientResponse, RequestInfo


@pytest.fixture
def session():
return mock.Mock()


@asyncio.coroutine
def test_http_processing_error():
def test_http_processing_error(session):
loop = mock.Mock()
request_info = mock.Mock()
response = ClientResponse(
'get', URL('http://del-cl-resp.org'), request_info=request_info)
response._post_init(loop)
response._post_init(loop, session)
loop.get_debug = mock.Mock()
loop.get_debug.return_value = True

Expand All @@ -34,10 +39,10 @@ def test_http_processing_error():
assert info.value.request_info is request_info


def test_del():
def test_del(session):
loop = mock.Mock()
response = ClientResponse('get', URL('http://del-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
loop.get_debug = mock.Mock()
loop.get_debug.return_value = True

Expand All @@ -53,9 +58,9 @@ def test_del():
connection.release.assert_called_with()


def test_close(loop):
def test_close(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response._closed = False
response._connection = mock.Mock()
response.close()
Expand All @@ -64,25 +69,25 @@ def test_close(loop):
response.close()


def test_wait_for_100_1(loop):
def test_wait_for_100_1(loop, session):
response = ClientResponse(
'get', URL('http://python.org'), continue100=object())
response._post_init(loop)
response._post_init(loop, session)
assert response._continue is not None
response.close()


def test_wait_for_100_2(loop):
def test_wait_for_100_2(loop, session):
response = ClientResponse(
'get', URL('http://python.org'))
response._post_init(loop)
response._post_init(loop, session)
assert response._continue is None
response.close()


def test_repr(loop):
def test_repr(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response.status = 200
response.reason = 'Ok'
assert '<ClientResponse(http://def-cl-resp.org) [200 Ok]>'\
Expand All @@ -109,9 +114,9 @@ def test_url_obj_deprecated():


@asyncio.coroutine
def test_read_and_release_connection(loop):
def test_read_and_release_connection(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -126,9 +131,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_read_and_release_connection_with_error(loop):
def test_read_and_release_connection_with_error(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
content = response.content = mock.Mock()
content.read.return_value = helpers.create_future(loop)
content.read.return_value.set_exception(ValueError)
Expand All @@ -139,9 +144,9 @@ def test_read_and_release_connection_with_error(loop):


@asyncio.coroutine
def test_release(loop):
def test_release(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
fut = helpers.create_future(loop)
fut.set_result(b'')
content = response.content = mock.Mock()
Expand All @@ -152,13 +157,13 @@ def test_release(loop):


@asyncio.coroutine
def test_release_on_del(loop):
def test_release_on_del(loop, session):
connection = mock.Mock()
connection.protocol.upgraded = False

def run(conn):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response._closed = False
response._connection = conn

Expand All @@ -168,9 +173,9 @@ def run(conn):


@asyncio.coroutine
def test_response_eof(loop):
def test_response_eof(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response._closed = False
conn = response._connection = mock.Mock()
conn.protocol.upgraded = False
Expand All @@ -181,9 +186,9 @@ def test_response_eof(loop):


@asyncio.coroutine
def test_response_eof_upgraded(loop):
def test_response_eof_upgraded(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

conn = response._connection = mock.Mock()
conn.protocol.upgraded = True
Expand All @@ -194,9 +199,9 @@ def test_response_eof_upgraded(loop):


@asyncio.coroutine
def test_response_eof_after_connection_detach(loop):
def test_response_eof_after_connection_detach(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response._closed = False
conn = response._connection = mock.Mock()
conn.protocol = None
Expand All @@ -207,9 +212,9 @@ def test_response_eof_after_connection_detach(loop):


@asyncio.coroutine
def test_text(loop):
def test_text(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -227,9 +232,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_text_bad_encoding(loop):
def test_text_bad_encoding(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -250,9 +255,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_text_custom_encoding(loop):
def test_text_custom_encoding(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -272,9 +277,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_text_detect_encoding(loop):
def test_text_detect_encoding(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -292,9 +297,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_text_after_read(loop):
def test_text_after_read(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -312,9 +317,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_json(loop):
def test_json(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -332,9 +337,9 @@ def side_effect(*args, **kwargs):


@asyncio.coroutine
def test_json_custom_loader(loop):
def test_json_custom_loader(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response.headers = {
'Content-Type': 'application/json;charset=cp1251'}
response._content = b'data'
Expand All @@ -347,9 +352,9 @@ def custom(content):


@asyncio.coroutine
def test_json_no_content(loop):
def test_json_no_content(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response.headers = {
'Content-Type': 'data/octet-stream'}
response._content = b''
Expand All @@ -364,9 +369,9 @@ def test_json_no_content(loop):


@asyncio.coroutine
def test_json_override_encoding(loop):
def test_json_override_encoding(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

def side_effect(*args, **kwargs):
fut = helpers.create_future(loop)
Expand All @@ -386,19 +391,19 @@ def side_effect(*args, **kwargs):


@pytest.mark.xfail
def test_override_flow_control(loop):
def test_override_flow_control(loop, session):
class MyResponse(ClientResponse):
flow_control_class = aiohttp.StreamReader
response = MyResponse('get', URL('http://my-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)
response._connection = mock.Mock()
assert isinstance(response.content, aiohttp.StreamReader)
response.close()


def test_get_encoding_unknown(loop):
def test_get_encoding_unknown(loop, session):
response = ClientResponse('get', URL('http://def-cl-resp.org'))
response._post_init(loop)
response._post_init(loop, session)

response.headers = {'Content-Type': 'application/json'}
with mock.patch('aiohttp.client_reqrep.chardet') as m_chardet:
Expand Down

0 comments on commit 53141f3

Please sign in to comment.