diff --git a/CHANGES/2762.feature b/CHANGES/2762.feature new file mode 100644 index 00000000000..add0440fc7e --- /dev/null +++ b/CHANGES/2762.feature @@ -0,0 +1 @@ +Make ``writer.write_headers`` a coroutine. \ No newline at end of file diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 4c3ea6c997d..4c7b8dca01c 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -505,7 +505,7 @@ async def send(self, conn): # status + headers status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( self.method, path, self.version) - writer.write_headers(status_line, self.headers) + await writer.write_headers(status_line, self.headers) self._writer = self.loop.create_task(self.write_bytes(writer, conn)) diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 9e25ddc81b8..83212ca9a88 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -91,7 +91,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024): return noop() - def write_headers(self, status_line, headers, SEP=': ', END='\r\n'): + async def write_headers(self, status_line, headers, SEP=': ', END='\r\n'): """Write request/response status and headers.""" # status + headers headers = status_line + ''.join( diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 3804faeb211..0253daf80e7 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -463,7 +463,6 @@ def make_mocked_request(method, path, headers=None, *, version=HttpVersion(1, 1), closing=False, app=None, writer=sentinel, - payload_writer=sentinel, protocol=sentinel, transport=sentinel, payload=sentinel, @@ -509,14 +508,12 @@ def make_mocked_request(method, path, headers=None, *, if writer is sentinel: writer = mock.Mock() + writer.write_headers = make_mocked_coro(None) + writer.write = make_mocked_coro(None) + writer.write_eof = make_mocked_coro(None) + writer.drain = make_mocked_coro(None) writer.transport = transport - if payload_writer is sentinel: - payload_writer = mock.Mock() - payload_writer.write = make_mocked_coro(None) - payload_writer.write_eof = make_mocked_coro(None) - payload_writer.drain = make_mocked_coro(None) - protocol.transport = transport protocol.writer = writer @@ -524,7 +521,7 @@ def make_mocked_request(method, path, headers=None, *, payload = mock.Mock() req = Request(message, payload, - protocol, payload_writer, task, loop, + protocol, writer, task, loop, client_max_size=client_max_size) match_info = UrlMappingMatchInfo( diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index e5fd254849e..fb07712965a 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -297,19 +297,19 @@ async def prepare(self, request): return self._payload_writer await request._prepare_hook(self) - return self._start(request) - - def _start(self, request, - HttpVersion10=HttpVersion10, - HttpVersion11=HttpVersion11, - CONNECTION=hdrs.CONNECTION, - DATE=hdrs.DATE, - SERVER=hdrs.SERVER, - CONTENT_TYPE=hdrs.CONTENT_TYPE, - CONTENT_LENGTH=hdrs.CONTENT_LENGTH, - SET_COOKIE=hdrs.SET_COOKIE, - SERVER_SOFTWARE=SERVER_SOFTWARE, - TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING): + return await self._start(request) + + async def _start(self, request, + HttpVersion10=HttpVersion10, + HttpVersion11=HttpVersion11, + CONNECTION=hdrs.CONNECTION, + DATE=hdrs.DATE, + SERVER=hdrs.SERVER, + CONTENT_TYPE=hdrs.CONTENT_TYPE, + CONTENT_LENGTH=hdrs.CONTENT_LENGTH, + SET_COOKIE=hdrs.SET_COOKIE, + SERVER_SOFTWARE=SERVER_SOFTWARE, + TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING): self._req = request keep_alive = self._keep_alive @@ -364,7 +364,7 @@ def _start(self, request, # status line status_line = 'HTTP/{}.{} {} {}\r\n'.format( version[0], version[1], self._status, self._reason) - writer.write_headers(status_line, headers) + await writer.write_headers(status_line, headers) return writer @@ -594,7 +594,7 @@ async def write_eof(self): else: await super().write_eof() - def _start(self, request): + async def _start(self, request): if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: if not self._body_payload: if self._body is not None: @@ -602,7 +602,7 @@ def _start(self, request): else: self._headers[hdrs.CONTENT_LENGTH] = '0' - return super()._start(request) + return await super()._start(request) def _do_start_compression(self, coding): if self._body_payload or self._chunked: diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 74f9f352dad..bbabf2d463a 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -17,6 +17,7 @@ from aiohttp import BaseConnector, hdrs, payload from aiohttp.client_reqrep import (ClientRequest, ClientResponse, Fingerprint, _merge_ssl_params) +from aiohttp.test_utils import make_mocked_coro @pytest.fixture @@ -668,6 +669,7 @@ async def test_content_encoding(loop, conn): req = ClientRequest('post', URL('http://python.org/'), data='foo', compress='deflate', loop=loop) with mock.patch('aiohttp.client_reqrep.StreamWriter') as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() resp = await req.send(conn) assert req.headers['TRANSFER-ENCODING'] == 'chunked' assert req.headers['CONTENT-ENCODING'] == 'deflate' @@ -693,6 +695,7 @@ async def test_content_encoding_header(loop, conn): 'post', URL('http://python.org/'), data='foo', headers={'Content-Encoding': 'deflate'}, loop=loop) with mock.patch('aiohttp.client_reqrep.StreamWriter') as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() resp = await req.send(conn) assert not m_writer.return_value.enable_compression.called @@ -732,6 +735,7 @@ async def test_chunked_explicit(loop, conn): req = ClientRequest( 'post', URL('http://python.org/'), chunked=True, loop=loop) with mock.patch('aiohttp.client_reqrep.StreamWriter') as m_writer: + m_writer.return_value.write_headers = make_mocked_coro() resp = await req.send(conn) assert 'chunked' == req.headers['TRANSFER-ENCODING'] diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 0c280bd5eef..93c9cd76f60 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -24,7 +24,7 @@ def append(data=b''): buf.extend(data) return helpers.noop() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): headers = status_line + ''.join( [k + ': ' + v + '\r\n' for k, v in headers.items()]) headers = headers.encode('utf-8') + b'\r\n' @@ -39,7 +39,7 @@ def write_headers(status_line, headers): app._debug = False app.on_response_prepare = signals.Signal(app) app.on_response_prepare.freeze() - req = make_mocked_request(method, path, app=app, payload_writer=writer) + req = make_mocked_request(method, path, app=app, writer=writer) return req diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 987cdb20da6..ba2f8d1bd09 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -45,7 +45,7 @@ def buffer_data(chunk): def write(chunk): buf.extend(chunk) - def write_headers(status_line, headers): + async def write_headers(status_line, headers): headers = status_line + ''.join( [k + ': ' + v + '\r\n' for k, v in headers.items()]) headers = headers.encode('utf-8') + b'\r\n' @@ -235,7 +235,7 @@ def test_last_modified_reset(): async def test_start(): - req = make_request('GET', '/', payload_writer=mock.Mock()) + req = make_request('GET', '/') resp = StreamResponse() assert resp.keep_alive is None @@ -274,7 +274,7 @@ def test_enable_chunked_encoding_with_content_length(): async def test_chunk_size(): - req = make_request('GET', '/', payload_writer=mock.Mock()) + req = make_request('GET', '/') resp = StreamResponse() assert not resp.chunked @@ -300,7 +300,7 @@ async def test_chunked_encoding_forbidden_for_http_10(): async def test_compression_no_accept(): - req = make_request('GET', '/', payload_writer=mock.Mock()) + req = make_request('GET', '/') resp = StreamResponse() assert not resp.chunked @@ -313,7 +313,7 @@ async def test_compression_no_accept(): async def test_force_compression_no_accept_backwards_compat(): - req = make_request('GET', '/', payload_writer=mock.Mock()) + req = make_request('GET', '/') resp = StreamResponse() assert not resp.chunked @@ -327,7 +327,7 @@ async def test_force_compression_no_accept_backwards_compat(): async def test_force_compression_false_backwards_compat(): - req = make_request('GET', '/', payload_writer=mock.Mock()) + req = make_request('GET', '/') resp = StreamResponse() assert not resp.compression @@ -421,13 +421,13 @@ async def test_change_content_length_if_compression_enabled(): async def test_set_content_length_if_compression_enabled(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH in headers assert headers[hdrs.CONTENT_LENGTH] == '26' assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) resp = Response(body=b'answer') resp.enable_compression(ContentCoding.gzip) @@ -440,12 +440,12 @@ def write_headers(status_line, headers): async def test_remove_content_length_if_compression_enabled_http11(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH not in headers assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked' writer.write_headers.side_effect = write_headers - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) resp = StreamResponse() resp.content_length = 123 resp.enable_compression(ContentCoding.gzip) @@ -456,13 +456,13 @@ def write_headers(status_line, headers): async def test_remove_content_length_if_compression_enabled_http10(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH not in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request('GET', '/', version=HttpVersion10, - payload_writer=writer) + writer=writer) resp = StreamResponse() resp.content_length = 123 resp.enable_compression(ContentCoding.gzip) @@ -473,13 +473,13 @@ def write_headers(status_line, headers): async def test_force_compression_identity(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request('GET', '/', - payload_writer=writer) + writer=writer) resp = StreamResponse() resp.content_length = 123 resp.enable_compression(ContentCoding.identity) @@ -490,28 +490,28 @@ def write_headers(status_line, headers): async def test_force_compression_identity_response(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert headers[hdrs.CONTENT_LENGTH] == "6" assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request('GET', '/', - payload_writer=writer) + writer=writer) resp = Response(body=b'answer') resp.enable_compression(ContentCoding.identity) await resp.prepare(req) assert resp.content_length == 6 -async def test_remove_content_length_if_compression_enabled_on_payload_http11(): # noqa +async def test_rm_content_length_if_compression_enabled_on_payload_http11(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH not in headers assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked' writer.write_headers.side_effect = write_headers - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) payload = BytesPayload(b'answer', headers={"X-Test-Header": "test"}) resp = Response(body=payload) assert resp.content_length == 6 @@ -521,16 +521,16 @@ def write_headers(status_line, headers): assert resp.content_length is None -async def test_remove_content_length_if_compression_enabled_on_payload_http10(): # noqa +async def test_rm_content_length_if_compression_enabled_on_payload_http10(): writer = mock.Mock() - def write_headers(status_line, headers): + async def write_headers(status_line, headers): assert hdrs.CONTENT_LENGTH not in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request('GET', '/', version=HttpVersion10, - payload_writer=writer) + writer=writer) resp = Response(body=BytesPayload(b'answer')) resp.enable_compression(ContentCoding.gzip) await resp.prepare(req) @@ -788,7 +788,7 @@ def test_response_ctor(): assert 'CONTENT-LENGTH' not in resp.headers -def test_ctor_with_headers_and_status(): +async def test_ctor_with_headers_and_status(): resp = Response(body=b'body', status=201, headers={'Age': '12', 'DATE': 'date'}) @@ -796,7 +796,8 @@ def test_ctor_with_headers_and_status(): assert b'body' == resp.body assert resp.headers['AGE'] == '12' - resp._start(mock.Mock(version=HttpVersion11)) + req = make_mocked_request('GET', '/') + await resp._start(req) assert 4 == resp.content_length assert resp.headers['CONTENT-LENGTH'] == '4' @@ -816,7 +817,7 @@ def test_ctor_text_body_combined(): Response(body=b'123', text='test text') -def test_ctor_text(): +async def test_ctor_text(): resp = Response(text='test text') assert 200 == resp.status @@ -829,7 +830,8 @@ def test_ctor_text(): assert resp.text == 'test text' resp.headers['DATE'] = 'date' - resp._start(mock.Mock(version=HttpVersion11)) + req = make_mocked_request('GET', '/', version=HttpVersion11) + await resp._start(req) assert resp.headers['CONTENT-LENGTH'] == '9' @@ -889,7 +891,7 @@ def test_ctor_both_charset_param_and_header(): charset='koi8-r') -def test_assign_nonbyteish_body(): +async def test_assign_nonbyteish_body(): resp = Response(body=b'data') with pytest.raises(ValueError): @@ -898,7 +900,8 @@ def test_assign_nonbyteish_body(): assert 4 == resp.content_length resp.headers['DATE'] = 'date' - resp._start(mock.Mock(version=HttpVersion11)) + req = make_mocked_request('GET', '/', version=HttpVersion11) + await resp._start(req) assert resp.headers['CONTENT-LENGTH'] == '4' assert 4 == resp.content_length @@ -919,7 +922,7 @@ def test_response_set_content_length(): async def test_send_headers_for_empty_body(buf, writer): - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) resp = Response() await resp.prepare(req) @@ -933,7 +936,7 @@ async def test_send_headers_for_empty_body(buf, writer): async def test_render_with_body(buf, writer): - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) resp = Response(body=b'data') await resp.prepare(req) @@ -951,7 +954,7 @@ async def test_render_with_body(buf, writer): async def test_send_set_cookie_header(buf, writer): resp = Response() resp.cookies['name'] = 'value' - req = make_request('GET', '/', payload_writer=writer) + req = make_request('GET', '/', writer=writer) await resp.prepare(req) await resp.write_eof() @@ -966,16 +969,17 @@ async def test_send_set_cookie_header(buf, writer): async def test_consecutive_write_eof(): - payload_writer = mock.Mock() - payload_writer.write_eof = make_mocked_coro() - req = make_request('GET', '/', payload_writer=payload_writer) + writer = mock.Mock() + writer.write_eof = make_mocked_coro() + writer.write_headers = make_mocked_coro() + req = make_request('GET', '/', writer=writer) data = b'data' resp = Response(body=data) await resp.prepare(req) await resp.write_eof() await resp.write_eof() - payload_writer.write_eof.assert_called_once_with(data) + writer.write_eof.assert_called_once_with(data) def test_set_text_with_content_type(): diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index ef91ef7f09a..649ac66e4b4 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -21,16 +21,6 @@ def app(loop): return ret -@pytest.fixture -def writer(loop): - writer = mock.Mock() - writer.drain.return_value = loop.create_future() - writer.drain.return_value.set_result(None) - writer.write_eof.return_value = loop.create_future() - writer.write_eof.return_value.set_result(None) - return writer - - @pytest.fixture def protocol(): ret = mock.Mock() @@ -39,7 +29,7 @@ def protocol(): @pytest.fixture -def make_request(app, protocol, writer): +def make_request(app, protocol): def maker(method, path, headers=None, protocols=False): if headers is None: headers = CIMultiDict( @@ -54,7 +44,7 @@ def maker(method, path, headers=None, protocols=False): return make_mocked_request( method, path, headers, - app=app, protocol=protocol, payload_writer=writer, + app=app, protocol=protocol, loop=app.loop) return maker @@ -301,7 +291,7 @@ async def test_pong_closed(make_request, mocker): assert ws_logger.warning.called -async def test_close_idempotent(make_request, writer): +async def test_close_idempotent(make_request): req = make_request('GET', '/') ws = WebSocketResponse() await ws.prepare(req)