Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement base protocol class #2986

Merged
merged 5 commits into from
May 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/2986.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Implement base protocol class to avoid a dependency from internal
``asyncio.streams.FlowControlMixin``
64 changes: 64 additions & 0 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio

from .log import internal_logger


class BaseProtocol(asyncio.Protocol):
def __init__(self, loop=None):
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self._paused = False
self._drain_waiter = None
self._connection_lost = False
self.transport = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want transport as public attribute or instead public read-only property? I think transport is only valid to set via connection_made method, not somehow adhoc'y.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying don't touch transport attr, many aiohttp codes use it internally.
Tests modify .transport wildly as well.
Better to deprecate .protocol at all in our public API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


def pause_writing(self):
assert not self._paused
self._paused = True
if self._loop.get_debug():
internal_logger.debug("%r pauses writing", self)

def resume_writing(self):
assert self._paused
self._paused = False
if self._loop.get_debug():
internal_logger.debug("%r resumes writing", self)

waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)

def connection_made(self, transport):
self.transport = transport

def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
12 changes: 3 additions & 9 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import asyncio
import asyncio.streams
from contextlib import suppress

from .base_protocol import BaseProtocol
from .client_exceptions import (ClientOSError, ClientPayloadError,
ServerDisconnectedError)
from .http import HttpResponseParser
from .streams import EMPTY_PAYLOAD, DataQueue


class ResponseHandler(DataQueue, asyncio.streams.FlowControlMixin):
class ResponseHandler(BaseProtocol, DataQueue):
"""Helper class to adapt between Protocol and StreamReader."""

def __init__(self, *, loop=None):
asyncio.streams.FlowControlMixin.__init__(self, loop=loop)
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop=loop)

self.transport = None
self._should_close = False

self._message = None
Expand Down Expand Up @@ -56,9 +54,6 @@ def close(self):
def is_connected(self):
return self.transport is not None

def connection_made(self, transport):
self.transport = transport

def connection_lost(self, exc):
if self._payload_parser is not None:
with suppress(Exception):
Expand All @@ -81,7 +76,6 @@ def connection_lost(self, exc):
# we do it anyway below
self.set_exception(exc)

self.transport = None
self._should_close = True
self._parser = None
self._message = None
Expand Down
9 changes: 2 additions & 7 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import yarl

from . import helpers, http
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout
from .http import HttpProcessingError, HttpRequestParser, StreamWriter
from .log import access_logger, server_logger
Expand All @@ -35,7 +36,7 @@ class PayloadAccessError(Exception):
"""Payload was accesed after responce was sent."""


class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol):
class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.

RequestHandler handles incoming HTTP request. It reads request line,
Expand Down Expand Up @@ -93,8 +94,6 @@ def __init__(self, manager, *, loop=None,

super().__init__(loop=loop)

self._loop = loop if loop is not None else asyncio.get_event_loop()

self._manager = manager
self._request_handler = manager.request_handler
self._request_factory = manager.request_factory
Expand All @@ -121,7 +120,6 @@ def __init__(self, manager, *, loop=None,
max_headers=max_headers,
payload_exception=RequestPayloadError)

self.transport = None
self._reading_paused = False

self.logger = logger
Expand Down Expand Up @@ -177,8 +175,6 @@ async def shutdown(self, timeout=15.0):
def connection_made(self, transport):
super().connection_made(transport)

self.transport = transport

if self._tcp_keepalive:
tcp_keepalive(transport)

Expand All @@ -196,7 +192,6 @@ def connection_lost(self, exc):
self._request_factory = None
self._request_handler = None
self._request_parser = None
self.transport = None

if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
Expand Down
175 changes: 175 additions & 0 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import asyncio
from contextlib import suppress
from unittest import mock

import pytest

from aiohttp.base_protocol import BaseProtocol


def test_loop(loop):
asyncio.set_event_loop(None)
pr = BaseProtocol(loop=loop)
assert pr._loop is loop


def test_default_loop(loop):
asyncio.set_event_loop(loop)
pr = BaseProtocol()
assert pr._loop is loop


def test_pause_writing(loop):
pr = BaseProtocol(loop=loop)
assert not pr._paused
pr.pause_writing()
assert pr._paused


def test_resume_writing_no_waiters(loop):
pr = BaseProtocol(loop=loop)
pr.pause_writing()
assert pr._paused
pr.resume_writing()
assert not pr._paused


def test_connection_made(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
assert pr.transport is None
pr.connection_made(tr)
assert pr.transport is not None


def test_connection_lost_not_paused(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost


def test_connection_lost_paused_without_waiter(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
pr.pause_writing()
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost


async def test_drain_lost(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.connection_lost(None)
with pytest.raises(ConnectionResetError):
await pr._drain_helper()


async def test_drain_not_paused(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert pr._drain_waiter is None
await pr._drain_helper()
assert pr._drain_waiter is None


async def test_resume_drain_waited(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.pause_writing()

t = loop.create_task(pr._drain_helper())
await asyncio.sleep(0)

assert pr._drain_waiter is not None
pr.resume_writing()
assert (await t) is None
assert pr._drain_waiter is None


async def test_lost_drain_waited_ok(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.pause_writing()

t = loop.create_task(pr._drain_helper())
await asyncio.sleep(0)

assert pr._drain_waiter is not None
pr.connection_lost(None)
assert (await t) is None
assert pr._drain_waiter is None


async def test_lost_drain_waited_exception(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.pause_writing()

t = loop.create_task(pr._drain_helper())
await asyncio.sleep(0)

assert pr._drain_waiter is not None
exc = RuntimeError()
pr.connection_lost(exc)
with pytest.raises(RuntimeError) as cm:
await t
assert cm.value is exc
assert pr._drain_waiter is None


async def test_lost_drain_cancelled(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.pause_writing()

fut = loop.create_future()

async def wait():
fut.set_result(None)
await pr._drain_helper()

t = loop.create_task(wait())
await fut
t.cancel()

assert pr._drain_waiter is not None
pr.connection_lost(None)
with suppress(asyncio.CancelledError):
await t
assert pr._drain_waiter is None


async def test_resume_drain_cancelled(loop):
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
pr.pause_writing()

fut = loop.create_future()

async def wait():
fut.set_result(None)
await pr._drain_helper()

t = loop.create_task(wait())
await fut
t.cancel()

assert pr._drain_waiter is not None
pr.resume_writing()
with suppress(asyncio.CancelledError):
await t
assert pr._drain_waiter is None