diff --git a/CHANGES/2767.feature b/CHANGES/2767.feature new file mode 100644 index 00000000000..99a8b4e5383 --- /dev/null +++ b/CHANGES/2767.feature @@ -0,0 +1 @@ +Add tracking signals for getting request/response bodies. \ No newline at end of file diff --git a/aiohttp/client.py b/aiohttp/client.py index baa7f0a0564..841a4fc7c96 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -309,7 +309,7 @@ async def _request(self, method, url, *, response_class=self._response_class, proxy=proxy, proxy_auth=proxy_auth, timer=timer, session=self, auto_decompress=self._auto_decompress, - ssl=ssl, proxy_headers=proxy_headers) + ssl=ssl, proxy_headers=proxy_headers, traces=traces) # connection timeout try: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index e9e45d46e17..33bd25ca7ce 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -168,7 +168,8 @@ def __init__(self, method, url, *, proxy=None, proxy_auth=None, timer=None, session=None, auto_decompress=True, ssl=None, - proxy_headers=None): + proxy_headers=None, + traces=None): if loop is None: loop = asyncio.get_event_loop() @@ -209,6 +210,9 @@ def __init__(self, method, url, *, if data or self.method not in self.GET_METHODS: self.update_transfer_encoding() self.update_expect_continue(expect100) + if traces is None: + traces = [] + self._traces = traces def is_ssl(self): return self.url.scheme in ('https', 'wss') @@ -475,7 +479,10 @@ async def send(self, conn): if self.url.raw_query_string: path += '?' + self.url.raw_query_string - writer = StreamWriter(conn.protocol, conn.transport, self.loop) + writer = StreamWriter( + conn.protocol, conn.transport, self.loop, + on_chunk_sent=self._on_chunk_request_sent + ) if self.compress: writer.enable_compression(self.compress) @@ -513,8 +520,9 @@ async def send(self, conn): self.method, self.original_url, writer=self._writer, continue100=self._continue, timer=self._timer, request_info=self.request_info, - auto_decompress=self._auto_decompress) - + auto_decompress=self._auto_decompress, + traces=self._traces, + ) self.response._post_init(self.loop, self._session) return self.response @@ -531,6 +539,10 @@ def terminate(self): self._writer.cancel() self._writer = None + async def _on_chunk_request_sent(self, chunk): + for trace in self._traces: + await trace.send_request_chunk_sent(chunk) + class ClientResponse(HeadersMixin): @@ -555,7 +567,8 @@ class ClientResponse(HeadersMixin): def __init__(self, method, url, *, writer=None, continue100=None, timer=None, - request_info=None, auto_decompress=True): + request_info=None, auto_decompress=True, + traces=None): assert isinstance(url, URL) self.method = method @@ -572,6 +585,9 @@ def __init__(self, method, url, *, self._timer = timer if timer is not None else TimerNoop() self._auto_decompress = auto_decompress self._cache = {} # reqired for @reify method decorator + if traces is None: + traces = [] + self._traces = traces @property def url(self): @@ -796,6 +812,8 @@ async def read(self): if self._content is None: try: self._content = await self.content.read() + for trace in self._traces: + await trace.send_response_chunk_received(self._content) except BaseException: self.close() raise diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index f0552a3840d..bc7201da1aa 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -16,7 +16,7 @@ class StreamWriter(AbstractStreamWriter): - def __init__(self, protocol, transport, loop): + def __init__(self, protocol, transport, loop, on_chunk_sent=None): self._protocol = protocol self._transport = transport @@ -30,6 +30,8 @@ def __init__(self, protocol, transport, loop): self._compress = None self._drain_waiter = None + self._on_chunk_sent = on_chunk_sent + @property def transport(self): return self._transport @@ -55,13 +57,16 @@ def _write(self, chunk): raise asyncio.CancelledError('Cannot write to closing transport') self._transport.write(chunk) - async def write(self, chunk, *, drain=True, LIMIT=64*1024): + async def write(self, chunk, *, drain=True, LIMIT=0x10000): """Writes chunk of data to a stream. write_eof() indicates end of stream. writer can't be used after write_eof() method being called. write() return drain future. """ + if self._on_chunk_sent is not None: + await self._on_chunk_sent(chunk) + if self._compress is not None: chunk = self._compress.compress(chunk) if not chunk: diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index b813347637d..165e68cbf9d 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -15,7 +15,8 @@ 'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams', 'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams', 'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams', - 'TraceRequestRedirectParams' + 'TraceRequestRedirectParams', + 'TraceRequestChunkSentParams', 'TraceResponseChunkReceivedParams', ) @@ -25,6 +26,8 @@ class TraceConfig: def __init__(self, trace_config_ctx_factory=SimpleNamespace): self._on_request_start = Signal(self) + self._on_request_chunk_sent = Signal(self) + self._on_response_chunk_received = Signal(self) self._on_request_end = Signal(self) self._on_request_exception = Signal(self) self._on_request_redirect = Signal(self) @@ -47,6 +50,8 @@ def trace_config_ctx(self, trace_request_ctx=None): def freeze(self): self._on_request_start.freeze() + self._on_request_chunk_sent.freeze() + self._on_response_chunk_received.freeze() self._on_request_end.freeze() self._on_request_exception.freeze() self._on_request_redirect.freeze() @@ -64,6 +69,14 @@ def freeze(self): def on_request_start(self): return self._on_request_start + @property + def on_request_chunk_sent(self): + return self._on_request_chunk_sent + + @property + def on_response_chunk_received(self): + return self._on_response_chunk_received + @property def on_request_end(self): return self._on_request_end @@ -121,6 +134,18 @@ class TraceRequestStartParams: headers = attr.ib(type=CIMultiDict) +@attr.s(frozen=True, slots=True) +class TraceRequestChunkSentParams: + """ Parameters sent by the `on_request_chunk_sent` signal""" + chunk = attr.ib(type=bytes) + + +@attr.s(frozen=True, slots=True) +class TraceResponseChunkReceivedParams: + """ Parameters sent by the `on_response_chunk_received` signal""" + chunk = attr.ib(type=bytes) + + @attr.s(frozen=True, slots=True) class TraceRequestEndParams: """ Parameters sent by the `on_request_end` signal""" @@ -213,6 +238,20 @@ async def send_request_start(self, method, url, headers): TraceRequestStartParams(method, url, headers) ) + async def send_request_chunk_sent(self, chunk): + return await self._trace_config.on_request_chunk_sent.send( + self._session, + self._trace_config_ctx, + TraceRequestChunkSentParams(chunk) + ) + + async def send_response_chunk_received(self, chunk): + return await self._trace_config.on_response_chunk_received.send( + self._session, + self._trace_config_ctx, + TraceResponseChunkReceivedParams(chunk) + ) + async def send_request_end(self, method, url, headers, response): return await self._trace_config.on_request_end.send( self._session, diff --git a/docs/tracing_reference.rst b/docs/tracing_reference.rst index 5cfaf09c527..e1f1ab9f6da 100644 --- a/docs/tracing_reference.rst +++ b/docs/tracing_reference.rst @@ -34,16 +34,26 @@ Overview exception[shape=flowchart.terminator, description="on_request_exception"]; acquire_connection[description="Connection acquiring"]; - got_response; - send_request; + headers_received; + headers_sent; + chunk_sent[description="on_request_chunk_sent"]; + chunk_received[description="on_response_chunk_received"]; start -> acquire_connection; - acquire_connection -> send_request; - send_request -> got_response; - got_response -> redirect; - got_response -> end; - redirect -> send_request; - send_request -> exception; + acquire_connection -> headers_sent; + headers_sent -> headers_received; + headers_sent -> chunk_sent; + chunk_sent -> chunk_sent; + chunk_sent -> headers_received; + headers_received -> chunk_received; + chunk_received -> chunk_received; + chunk_received -> end; + headers_received -> redirect; + headers_received -> end; + redirect -> headers_sent; + chunk_received -> exception; + chunk_sent -> exception; + headers_sent -> exception; } @@ -147,6 +157,26 @@ TraceConfig ``params`` is :class:`aiohttp.TraceRequestStartParams` instance. + .. attribute:: on_request_chunk_sent + + + Property that gives access to the signals that will be executed + when a chunk of request body is sent. + + ``params`` is :class:`aiohttp.TraceRequestChunkSentParams` instance. + + .. versionadded:: 3.1 + + .. attribute:: on_response_chunk_received + + + Property that gives access to the signals that will be executed + when a chunk of response body is received. + + ``params`` is :class:`aiohttp.TraceResponseChunkReceivedParams` instance. + + .. versionadded:: 3.1 + .. attribute:: on_request_redirect Property that gives access to the signals that will be executed when a @@ -259,6 +289,35 @@ TraceRequestStartParams Headers that will be used for the request, can be mutated. + +TraceRequestChunkSentParams +--------------------------- + +.. class:: TraceRequestChunkSentParams + + .. versionadded:: 3.1 + + See :attr:`TraceConfig.on_request_chunk_sent` for details. + + .. attribute:: chunk + + Bytes of chunk sent + + +TraceResponseChunkSentParams +---------------------------- + +.. class:: TraceResponseChunkSentParams + + .. versionadded:: 3.1 + + See :attr:`TraceConfig.on_response_chunk_received` for details. + + .. attribute:: chunk + + Bytes of chunk received + + TraceRequestEndParams --------------------- diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 1b40a3909e3..d8dfe7b54ce 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -10,6 +10,7 @@ import aiohttp from aiohttp import http from aiohttp.client_reqrep import ClientResponse, RequestInfo +from aiohttp.test_utils import make_mocked_coro @pytest.fixture @@ -613,3 +614,35 @@ def test_redirect_history_in_exception(): with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert [hist_response] == cm.value.history + + +async def test_response_read_triggers_callback(loop, session): + trace = mock.Mock() + trace.send_response_chunk_received = make_mocked_coro() + response_body = b'This is response' + + response = ClientResponse( + 'get', URL('http://def-cl-resp.org'), + traces=[trace] + ) + response._post_init(loop, session) + + def side_effect(*args, **kwargs): + fut = loop.create_future() + fut.set_result(response_body) + return fut + + response.headers = { + 'Content-Type': 'application/json;charset=cp1251'} + content = response.content = mock.Mock() + content.read.side_effect = side_effect + + res = await response.read() + assert res == response_body + assert response._connection is None + + assert trace.send_response_chunk_received.called + assert ( + trace.send_response_chunk_received.call_args == + mock.call(response_body) + ) diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 60e5fcce33f..f31e4def17b 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -1,8 +1,10 @@ import asyncio import contextlib import gc +import json import re from http.cookies import SimpleCookie +from io import BytesIO from unittest import mock import pytest @@ -457,33 +459,47 @@ def test_client_session_implicit_loop_warn(): async def test_request_tracing(loop, aiohttp_client): async def handler(request): - return web.Response() + return web.json_response({'ok': True}) app = web.Application() - app.router.add_get('/', handler) + app.router.add_post('/', handler) trace_config_ctx = mock.Mock() trace_request_ctx = {} + body = 'This is request body' + gathered_req_body = BytesIO() + gathered_res_body = BytesIO() on_request_start = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) on_request_redirect = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) on_request_end = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + async def on_request_chunk_sent(session, context, params): + gathered_req_body.write(params.chunk) + + async def on_response_chunk_received(session, context, params): + gathered_res_body.write(params.chunk) + trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) trace_config.on_request_start.append(on_request_start) trace_config.on_request_end.append(on_request_end) + trace_config.on_request_chunk_sent.append(on_request_chunk_sent) + trace_config.on_response_chunk_received.append(on_response_chunk_received) trace_config.on_request_redirect.append(on_request_redirect) session = await aiohttp_client(app, trace_configs=[trace_config]) - async with session.get('/', trace_request_ctx=trace_request_ctx) as resp: + async with session.post( + '/', data=body, trace_request_ctx=trace_request_ctx) as resp: + + await resp.json() on_request_start.assert_called_once_with( session.session, trace_config_ctx, aiohttp.TraceRequestStartParams( - hdrs.METH_GET, + hdrs.METH_POST, session.make_url('/'), CIMultiDict() ) @@ -493,13 +509,16 @@ async def handler(request): session.session, trace_config_ctx, aiohttp.TraceRequestEndParams( - hdrs.METH_GET, + hdrs.METH_POST, session.make_url('/'), CIMultiDict(), resp ) ) assert not on_request_redirect.called + assert gathered_req_body.getvalue() == body.encode('utf8') + assert gathered_res_body.getvalue() == json.dumps( + {'ok': True}).encode('utf8') async def test_request_tracing_exception(loop): diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 9b22750d18d..47d41ef3984 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -159,6 +159,18 @@ async def test_write_drain(protocol, transport, loop): assert msg.buffer_size == 0 +async def test_write_calls_callback(protocol, transport, loop): + on_chunk_sent = make_mocked_coro() + msg = http.StreamWriter( + protocol, transport, loop, + on_chunk_sent=on_chunk_sent + ) + chunk = b'1' + await msg.write(chunk) + assert on_chunk_sent.called + assert on_chunk_sent.call_args == mock.call(chunk) + + async def test_write_to_closing_transport(protocol, transport, loop): msg = http.StreamWriter(protocol, transport, loop) diff --git a/tests/test_tracing.py b/tests/test_tracing.py index bc837bf3af1..cdb22b93e9a 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -13,10 +13,12 @@ TraceDnsCacheHitParams, TraceDnsCacheMissParams, TraceDnsResolveHostEndParams, TraceDnsResolveHostStartParams, + TraceRequestChunkSentParams, TraceRequestEndParams, TraceRequestExceptionParams, TraceRequestRedirectParams, - TraceRequestStartParams) + TraceRequestStartParams, + TraceResponseChunkReceivedParams) class TestTraceConfig: @@ -41,6 +43,8 @@ def test_freeze(self): trace_config.freeze() assert trace_config.on_request_start.frozen + assert trace_config.on_request_chunk_sent.frozen + assert trace_config.on_response_chunk_received.frozen assert trace_config.on_request_end.frozen assert trace_config.on_request_exception.frozen assert trace_config.on_request_redirect.frozen @@ -63,6 +67,16 @@ class TestTrace: (Mock(), Mock(), Mock()), TraceRequestStartParams ), + ( + 'request_chunk_sent', + (Mock(), ), + TraceRequestChunkSentParams + ), + ( + 'response_chunk_received', + (Mock(), ), + TraceResponseChunkReceivedParams + ), ( 'request_end', (Mock(), Mock(), Mock(), Mock()),