-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement base protocol class #2986
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Implement base protocol class to avoid a dependency from internal | ||
``asyncio.streams.FlowControlMixin`` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import asyncio | ||
|
||
from .log import internal_logger | ||
|
||
|
||
class BaseProtocol(asyncio.Protocol): | ||
def __init__(self, loop=None): | ||
if loop is None: | ||
self._loop = asyncio.get_event_loop() | ||
else: | ||
self._loop = loop | ||
self._paused = False | ||
self._drain_waiter = None | ||
self._connection_lost = False | ||
self.transport = None | ||
|
||
def pause_writing(self): | ||
assert not self._paused | ||
self._paused = True | ||
if self._loop.get_debug(): | ||
internal_logger.debug("%r pauses writing", self) | ||
|
||
def resume_writing(self): | ||
assert self._paused | ||
self._paused = False | ||
if self._loop.get_debug(): | ||
internal_logger.debug("%r resumes writing", self) | ||
|
||
waiter = self._drain_waiter | ||
if waiter is not None: | ||
self._drain_waiter = None | ||
if not waiter.done(): | ||
waiter.set_result(None) | ||
|
||
def connection_made(self, transport): | ||
self.transport = transport | ||
|
||
def connection_lost(self, exc): | ||
self._connection_lost = True | ||
# Wake up the writer if currently paused. | ||
self.transport = None | ||
if not self._paused: | ||
return | ||
waiter = self._drain_waiter | ||
if waiter is None: | ||
return | ||
self._drain_waiter = None | ||
if waiter.done(): | ||
return | ||
if exc is None: | ||
waiter.set_result(None) | ||
else: | ||
waiter.set_exception(exc) | ||
|
||
async def _drain_helper(self): | ||
if self._connection_lost: | ||
raise ConnectionResetError('Connection lost') | ||
if not self._paused: | ||
return | ||
waiter = self._drain_waiter | ||
assert waiter is None or waiter.cancelled() | ||
waiter = self._loop.create_future() | ||
self._drain_waiter = waiter | ||
await waiter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import asyncio | ||
from contextlib import suppress | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from aiohttp.base_protocol import BaseProtocol | ||
|
||
|
||
def test_loop(loop): | ||
asyncio.set_event_loop(None) | ||
pr = BaseProtocol(loop=loop) | ||
assert pr._loop is loop | ||
|
||
|
||
def test_default_loop(loop): | ||
asyncio.set_event_loop(loop) | ||
pr = BaseProtocol() | ||
assert pr._loop is loop | ||
|
||
|
||
def test_pause_writing(loop): | ||
pr = BaseProtocol(loop=loop) | ||
assert not pr._paused | ||
pr.pause_writing() | ||
assert pr._paused | ||
|
||
|
||
def test_resume_writing_no_waiters(loop): | ||
pr = BaseProtocol(loop=loop) | ||
pr.pause_writing() | ||
assert pr._paused | ||
pr.resume_writing() | ||
assert not pr._paused | ||
|
||
|
||
def test_connection_made(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
assert pr.transport is None | ||
pr.connection_made(tr) | ||
assert pr.transport is not None | ||
|
||
|
||
def test_connection_lost_not_paused(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
assert not pr._connection_lost | ||
pr.connection_lost(None) | ||
assert pr.transport is None | ||
assert pr._connection_lost | ||
|
||
|
||
def test_connection_lost_paused_without_waiter(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
assert not pr._connection_lost | ||
pr.pause_writing() | ||
pr.connection_lost(None) | ||
assert pr.transport is None | ||
assert pr._connection_lost | ||
|
||
|
||
async def test_drain_lost(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.connection_lost(None) | ||
with pytest.raises(ConnectionResetError): | ||
await pr._drain_helper() | ||
|
||
|
||
async def test_drain_not_paused(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
assert pr._drain_waiter is None | ||
await pr._drain_helper() | ||
assert pr._drain_waiter is None | ||
|
||
|
||
async def test_resume_drain_waited(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.pause_writing() | ||
|
||
t = loop.create_task(pr._drain_helper()) | ||
await asyncio.sleep(0) | ||
|
||
assert pr._drain_waiter is not None | ||
pr.resume_writing() | ||
assert (await t) is None | ||
assert pr._drain_waiter is None | ||
|
||
|
||
async def test_lost_drain_waited_ok(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.pause_writing() | ||
|
||
t = loop.create_task(pr._drain_helper()) | ||
await asyncio.sleep(0) | ||
|
||
assert pr._drain_waiter is not None | ||
pr.connection_lost(None) | ||
assert (await t) is None | ||
assert pr._drain_waiter is None | ||
|
||
|
||
async def test_lost_drain_waited_exception(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.pause_writing() | ||
|
||
t = loop.create_task(pr._drain_helper()) | ||
await asyncio.sleep(0) | ||
|
||
assert pr._drain_waiter is not None | ||
exc = RuntimeError() | ||
pr.connection_lost(exc) | ||
with pytest.raises(RuntimeError) as cm: | ||
await t | ||
assert cm.value is exc | ||
assert pr._drain_waiter is None | ||
|
||
|
||
async def test_lost_drain_cancelled(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.pause_writing() | ||
|
||
fut = loop.create_future() | ||
|
||
async def wait(): | ||
fut.set_result(None) | ||
await pr._drain_helper() | ||
|
||
t = loop.create_task(wait()) | ||
await fut | ||
t.cancel() | ||
|
||
assert pr._drain_waiter is not None | ||
pr.connection_lost(None) | ||
with suppress(asyncio.CancelledError): | ||
await t | ||
assert pr._drain_waiter is None | ||
|
||
|
||
async def test_resume_drain_cancelled(loop): | ||
pr = BaseProtocol(loop=loop) | ||
tr = mock.Mock() | ||
pr.connection_made(tr) | ||
pr.pause_writing() | ||
|
||
fut = loop.create_future() | ||
|
||
async def wait(): | ||
fut.set_result(None) | ||
await pr._drain_helper() | ||
|
||
t = loop.create_task(wait()) | ||
await fut | ||
t.cancel() | ||
|
||
assert pr._drain_waiter is not None | ||
pr.resume_writing() | ||
with suppress(asyncio.CancelledError): | ||
await t | ||
assert pr._drain_waiter is None |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want transport as public attribute or instead public read-only property? I think transport is only valid to set via
connection_made
method, not somehow adhoc'y.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm trying don't touch transport attr, many aiohttp codes use it internally.
Tests modify
.transport
wildly as well.Better to deprecate
.protocol
at all in our public API.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍