Skip to content

Commit

Permalink
Request tracing refactorized ownership of the signals
Browse files Browse the repository at this point in the history
Signals ownership is spread to each object that is in charge to
send that specific signal, however to allow the user use only the
`ClientSession` this class proxies these signals that are really not
owned by it.
  • Loading branch information
pfreixes committed Oct 24, 2017
1 parent a8da61d commit 3af66fc
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 46 deletions.
38 changes: 20 additions & 18 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
self._on_request_createconn_start = FuncSignal()
self._on_request_createconn_end = FuncSignal()
self._on_request_redirect = FuncSignal()

self._on_request_headers_sent = FuncSignal()
self._on_request_content_sent = FuncSignal()

# TODO: not implemented yet
self._on_request_content_chunk_sent = FuncSignal()
self._on_request_headers_received = FuncSignal()
self._on_request_content_chunk_received = FuncSignal()
Expand Down Expand Up @@ -296,14 +295,17 @@ def _request(self, method, url, *,
session=self, auto_decompress=self._auto_decompress,
verify_ssl=verify_ssl, fingerprint=fingerprint,
ssl_context=ssl_context, proxy_headers=proxy_headers,
on_headers_sent=self.on_request_headers_sent,
on_content_sent=self.on_request_content_sent,
on_content_chunk_sent=self.on_request_content_chunk_sent, # noqa
trace_context=trace_context)

# connection timeout
try:
with CeilTimeout(self._conn_timeout, loop=self._loop):
conn = yield from self._connector.connect(
req,
session_tracing=(self, trace_context)
trace_context=trace_context
)
except asyncio.TimeoutError as exc:
raise ServerTimeoutError(
Expand Down Expand Up @@ -702,36 +704,36 @@ def on_request_start(self):
return self._on_request_start

@property
def on_request_redirect_start(self):
return self._on_request_start
def on_request_redirect(self):
return self._on_request_redirect

@property
def on_request_redirect_end(self):
return self._on_request_start
def on_request_end(self):
return self._on_request_end

@property
def on_request_exception(self):
return self._on_request_exception

# connector signals

@property
def on_request_queued_start(self):
return self._on_request_queued_start
return self._connector.on_queued_start

@property
def on_request_queued_end(self):
return self._on_request_queued_end
return self._connector.on_queued_end

@property
def on_request_createconn_start(self):
return self._on_request_createconn_start
return self._connector.on_createconn_start

@property
def on_request_createconn_end(self):
return self._on_request_createconn_end
return self._connector.on_createconn_end

@property
def on_request_end(self):
return self._on_request_end

@property
def on_request_exception(self):
return self._on_request_exception
# req resp signals

@property
def on_request_headers_sent(self):
Expand Down
15 changes: 12 additions & 3 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def __init__(self, method, url, *,
proxy=None, proxy_auth=None,
timer=None, session=None, auto_decompress=True,
verify_ssl=None, fingerprint=None, ssl_context=None,
proxy_headers=None, trace_context=None):
proxy_headers=None, on_headers_sent=None,
on_content_sent=None, on_content_chunk_sent=None,
trace_context=None):

if verify_ssl is False and ssl_context is not None:
raise ValueError(
Expand Down Expand Up @@ -117,6 +119,10 @@ def __init__(self, method, url, *,
self._auto_decompress = auto_decompress
self._verify_ssl = verify_ssl
self._ssl_context = ssl_context

self._on_headers_sent = on_headers_sent
self._on_content_sent = on_content_sent
self._on_content_chunk_sent = on_content_chunk_sent
self._trace_context = trace_context

if loop.get_debug():
Expand Down Expand Up @@ -401,7 +407,9 @@ def write_bytes(self, writer, conn):
for chunk in self.body:
writer.write(chunk)

self._session.on_request_content_sent.send(self._trace_context)
if self._on_content_sent is not None:
self._on_content_sent.send(self._trace_context)

yield from writer.write_eof()
except OSError as exc:
new_exc = ClientOSError(
Expand Down Expand Up @@ -464,7 +472,8 @@ def send(self, conn):
self.method, path, self.version)
writer.write_headers(status_line, self.headers)

self._session.on_request_headers_sent.send(self._trace_context)
if self._on_headers_sent is not None:
self._on_headers_sent.send(self._trace_context)

self._writer = helpers.ensure_future(
self.write_bytes(writer, conn), loop=self.loop)
Expand Down
45 changes: 31 additions & 14 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hashlib import md5, sha1, sha256
from itertools import cycle, islice
from time import monotonic
from types import MappingProxyType
from types import MappingProxyType, SimpleNamespace

from . import hdrs, helpers
from .client_exceptions import (ClientConnectionError,
Expand All @@ -23,6 +23,7 @@
from .locks import EventResultOrError
from .log import client_logger
from .resolver import DefaultResolver
from .signals import FuncSignal


try:
Expand Down Expand Up @@ -200,6 +201,11 @@ def __init__(self, *, keepalive_timeout=sentinel,
self._cleanup_closed_transports = []
self._cleanup_closed()

self._on_queued_start = FuncSignal()
self._on_queued_end = FuncSignal()
self._on_createconn_start = FuncSignal()
self._on_createconn_end = FuncSignal()

def __del__(self, _warnings=warnings):
if self._closed:
return
Expand Down Expand Up @@ -347,11 +353,11 @@ def closed(self):
return self._closed

@asyncio.coroutine
def connect(self, req, session_tracing=None):
def connect(self, req, trace_context=None):
"""Get from pool or create new connection."""

if session_tracing:
session, trace_context = session_tracing
if trace_context is None:
trace_context = SimpleNamespace()

key = req.connection_key

Expand Down Expand Up @@ -379,8 +385,8 @@ def connect(self, req, session_tracing=None):
# This connection will now count towards the limit.
waiters = self._waiters[key]
waiters.append(fut)
if session_tracing:
session.on_request_queued_start.send(trace_context)

self.on_queued_start.send(trace_context)

try:
yield from fut
Expand All @@ -390,18 +396,15 @@ def connect(self, req, session_tracing=None):
if not waiters:
del self._waiters[key]

if session_tracing:
session.on_request_queued_end.send(trace_context)
self.on_queued_end.send(trace_context)

proto = self._get(key)
if proto is None:
placeholder = _TransportPlaceholder()
self._acquired.add(placeholder)
self._acquired_per_host[key].add(placeholder)

if session_tracing:
session.on_request_createconn_start.send(
trace_context)
self.on_createconn_start.send(trace_context)

try:
proto = yield from self._create_connection(req)
Expand All @@ -420,9 +423,7 @@ def connect(self, req, session_tracing=None):
self._acquired.remove(placeholder)
self._acquired_per_host[key].remove(placeholder)

if session_tracing:
session.on_request_createconn_end.send(
trace_context)
self.on_createconn_end.send(trace_context)

self._acquired.add(proto)
self._acquired_per_host[key].add(proto)
Expand Down Expand Up @@ -520,6 +521,22 @@ def _release(self, key, protocol, *, should_close=False):
def _create_connection(self, req):
raise NotImplementedError()

@property
def on_queued_start(self):
return self._on_queued_start

@property
def on_queued_end(self):
return self._on_queued_end

@property
def on_createconn_start(self):
return self._on_createconn_start

@property
def on_createconn_end(self):
return self._on_createconn_end


_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)

Expand Down
12 changes: 12 additions & 0 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,15 @@ def new_headers(trace_context, method, host, port, headers):

yield from session.get('http://example.com')
assert MyClientRequest.headers['foo'] == 'bar'


@asyncio.coroutine
def test_request_tracing_proxies_connector_signals(loop):
connector = TCPConnector(loop=loop)
session = aiohttp.ClientSession(connector=connector, loop=loop)
assert id(session.on_request_queued_start) == id(connector.on_queued_start)
assert id(session.on_request_queued_end) == id(connector.on_queued_end)
assert id(session.on_request_createconn_start) ==\
id(connector.on_createconn_start)
assert id(session.on_request_createconn_end) ==\
id(connector.on_createconn_end)
29 changes: 18 additions & 11 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,22 +591,25 @@ def test_connect(loop):


@asyncio.coroutine
def test_connect_request_tracing(loop):
def test_connect_tracing(loop):
proto = mock.Mock()
proto.is_connected.return_value = True
session = mock.Mock()
trace_context = mock.Mock()
on_createconn_start = mock.Mock()
on_createconn_end = mock.Mock()

req = ClientRequest('GET', URL('http://host:80'), loop=loop)

conn = aiohttp.BaseConnector(loop=loop)
conn.on_createconn_start.append(on_createconn_start)
conn.on_createconn_end.append(on_createconn_end)
conn._create_connection = mock.Mock()
conn._create_connection.return_value = helpers.create_future(loop)
conn._create_connection.return_value.set_result(proto)

yield from conn.connect(req, session_tracing=(session, trace_context))
session.on_request_createconn_start.send.assert_called_with(trace_context)
session.on_request_createconn_end.send.assert_called_with(trace_context)
yield from conn.connect(req, trace_context=trace_context)
on_createconn_start.assert_called_with(trace_context)
on_createconn_end.assert_called_with(trace_context)


@asyncio.coroutine
Expand Down Expand Up @@ -904,32 +907,36 @@ def f():


@asyncio.coroutine
def test_connect_queued_operation_request_tracing(loop, key):
def test_connect_queued_operation_tracing(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True
trace_context = mock.Mock()
on_queued_start = mock.Mock()
on_queued_end = mock.Mock()

req = ClientRequest('GET', URL('http://localhost1:80'),
loop=loop,
response_class=mock.Mock())

conn = aiohttp.BaseConnector(loop=loop, limit=1)
conn.on_queued_start.append(on_queued_start)
conn.on_queued_end.append(on_queued_end)
conn._conns[key] = [(proto, loop.time())]
conn._create_connection = mock.Mock()
conn._create_connection.return_value = helpers.create_future(loop)
conn._create_connection.return_value.set_result(proto)

connection1 = yield from conn.connect(req)
connection1 = yield from conn.connect(req, trace_context=trace_context)

@asyncio.coroutine
def f():
session = mock.Mock()
trace_context = mock.Mock()
connection2 = yield from conn.connect(
req,
session_tracing=(session, trace_context)
trace_context=trace_context
)
session.on_request_queued_start.send.assert_called_with(trace_context)
session.on_request_queued_end.send.assert_called_with(trace_context)
on_queued_start.assert_called_with(trace_context)
on_queued_end.assert_called_with(trace_context)
connection2.release()

task = helpers.ensure_future(f(), loop=loop)
Expand Down

0 comments on commit 3af66fc

Please sign in to comment.