diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index f3736eb6cfe..d28907ac858 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -1,3 +1,4 @@ +import sys import asyncio import warnings @@ -11,6 +12,7 @@ __all__ = ('WebSocketResponse', 'MsgType') +PY_35 = sys.version_info >= (3, 5) THRESHOLD_CONNLOST_ACCESS = 5 @@ -285,3 +287,15 @@ def receive_bytes(self): def write(self, data): raise RuntimeError("Cannot call .write() for websocket") + + if PY_35: + @asyncio.coroutine + def __aiter__(self): + return self + + @asyncio.coroutine + def __anext__(self): + msg = yield from self.receive() + if msg.tp == MsgType.close: + raise StopAsyncIteration # NOQA + return msg diff --git a/docs/web.rst b/docs/web.rst index dec17656e99..c3a003cbd93 100644 --- a/docs/web.rst +++ b/docs/web.rst @@ -415,19 +415,17 @@ using response's methods: ws = web.WebSocketResponse() await ws.prepare(request) - while not ws.closed: - msg = await ws.receive() - + async for msg in ws: if msg.tp == aiohttp.MsgType.text: if msg.data == 'close': await ws.close() else: ws.send_str(msg.data + '/answer') - elif msg.tp == aiohttp.MsgType.close: - print('websocket connection closed') elif msg.tp == aiohttp.MsgType.error: print('ws connection closed with exception %s' % ws.exception()) + + print('websocket connection closed') return ws diff --git a/tests/test_py35/test_web_websocket_35.py b/tests/test_py35/test_web_websocket_35.py new file mode 100644 index 00000000000..55c55a50894 --- /dev/null +++ b/tests/test_py35/test_web_websocket_35.py @@ -0,0 +1,50 @@ +import pytest + +import asyncio + +from aiohttp import web, websocket +from aiohttp.websocket_client import MsgType, ws_connect + + +async def create_server(loop, port, method, path, route_handler): + app = web.Application(loop=loop) + app.router.add_route(method, path, route_handler) + handler = app.make_handler(keep_alive_on=False) + return await loop.create_server(handler, '127.0.0.1', port) + + +@pytest.mark.run_loop +async def test_await(loop, unused_port): + closed = asyncio.Future(loop=loop) + + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + async for msg in ws: + assert msg.tp == MsgType.text + s = msg.data + ws.send_str(s + '/answer') + await ws.close() + closed.set_result(1) + return ws + + port = unused_port() + await create_server(loop, port, 'GET', '/', handler) # returns server + resp = await ws_connect('ws://127.0.0.1:{p}'.format(p=port), loop=loop) + + items = ['q1', 'q2', 'q3'] + for item in items: + resp._writer.send(item) + msg = await resp._reader.read() + assert msg.tp == websocket.MSG_TEXT + assert item + '/answer' == msg.data + + resp._writer.close() + + msg = await resp._reader.read() + assert msg.tp == websocket.MSG_CLOSE + assert msg.data == 1000 + assert msg.extra == '' + + await closed + await resp.close()