diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index cb31f1f89b3..a1c08711c16 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -489,7 +489,8 @@ def make_mocked_request(method, path, headers=None, *, transport=sentinel, payload=sentinel, sslcontext=None, - secure_proxy_ssl_header=None): + secure_proxy_ssl_header=None, + client_max_size=1024**2): """Creates mocked web.Request testing purposes. Useful in unit tests, when spinning full web server is overkill or @@ -547,9 +548,11 @@ def timeout(*args, **kw): loop = mock.Mock() loop.create_future.return_value = () - req = Request(message, payload, protocol, - time_service, task, loop=loop, - secure_proxy_ssl_header=secure_proxy_ssl_header) + req = Request(message, payload, + protocol, time_service, task, + loop=loop, + secure_proxy_ssl_header=secure_proxy_ssl_header, + client_max_size=client_max_size) match_info = UrlMappingMatchInfo({}, mock.Mock()) match_info.add_app(app) diff --git a/aiohttp/web.py b/aiohttp/web.py index f392ea02539..45c7d46f139 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -37,7 +37,9 @@ class Application(MutableMapping): def __init__(self, *, logger=web_logger, loop=None, - router=None, middlewares=(), handler_args=None, debug=...): + + router=None, middlewares=(), handler_args=None, debug=..., + client_max_size=1024**2): if loop is None: loop = asyncio.get_event_loop() if router is None: @@ -64,6 +66,7 @@ def __init__(self, *, logger=web_logger, loop=None, self._on_startup = Signal(self) self._on_shutdown = Signal(self) self._on_cleanup = Signal(self) + self._client_max_size = client_max_size # MutableMapping API @@ -238,7 +241,8 @@ def _make_request(self, message, payload, protocol, message, payload, protocol, protocol._time_service, protocol._request_handler, loop=self._loop, - secure_proxy_ssl_header=self._secure_proxy_ssl_header) + secure_proxy_ssl_header=self._secure_proxy_ssl_header, + client_max_size=self._client_max_size) @asyncio.coroutine def _handle(self, request): diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 61aa4b54a7e..3c951dfcbde 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -17,6 +17,7 @@ from yarl import URL from . import hdrs, multipart + from .helpers import HeadersMixin, SimpleCookie, reify, sentinel from .protocol import (RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11, PayloadWriter) @@ -50,7 +51,8 @@ class BaseRequest(collections.MutableMapping, HeadersMixin): hdrs.METH_TRACE, hdrs.METH_DELETE} def __init__(self, message, payload, protocol, time_service, task, *, - loop=None, secure_proxy_ssl_header=None): + loop=None, secure_proxy_ssl_header=None, + client_max_size=1024**2): self._loop = loop self._message = message self._protocol = protocol @@ -70,6 +72,7 @@ def __init__(self, message, payload, protocol, time_service, task, *, self._state = {} self._cache = {} self._task = task + self._client_max_size = client_max_size self.rel_url = message.url @@ -372,6 +375,11 @@ def read(self): while True: chunk = yield from self._payload.readany() body.extend(chunk) + if self._client_max_size \ + and len(body) >= self._client_max_size: + # local import to avoid circular imports + from aiohttp import web_exceptions + raise web_exceptions.HTTPRequestEntityTooLarge if not chunk: break self._read_bytes = bytes(body) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index e91bee18e7d..1706453754f 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1205,3 +1205,72 @@ def handler(request): resp = yield from client.get('/') assert 200 == resp.status + + +@asyncio.coroutine +def test_app_max_client_size(loop, test_client): + + @asyncio.coroutine + def handler(request): + yield from request.post() + return web.Response(body=b'ok') + + max_size = 1024**2 + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + data = {"long_string": max_size * 'x' + 'xxx'} + resp = yield from client.post('/', data=data) + assert 413 == resp.status + resp_text = yield from resp.text() + assert 'Request Entity Too Large' in resp_text + + +@asyncio.coroutine +def test_app_max_client_size_adjusted(loop, test_client): + + @asyncio.coroutine + def handler(request): + yield from request.post() + return web.Response(body=b'ok') + + default_max_size = 1024**2 + custom_max_size = default_max_size * 2 + app = web.Application(loop=loop, client_max_size=custom_max_size) + app.router.add_post('/', handler) + client = yield from test_client(app) + data = {'long_string': default_max_size * 'x' + 'xxx'} + resp = yield from client.post('/', data=data) + assert 200 == resp.status + resp_text = yield from resp.text() + assert 'ok' == resp_text + too_large_data = {'log_string': custom_max_size * 'x' + "xxx"} + resp = yield from client.post('/', data=too_large_data) + assert 413 == resp.status + resp_text = yield from resp.text() + assert 'Request Entity Too Large' in resp_text + + +@asyncio.coroutine +def test_app_max_client_size_none(loop, test_client): + + @asyncio.coroutine + def handler(request): + yield from request.post() + return web.Response(body=b'ok') + + default_max_size = 1024**2 + custom_max_size = None + app = web.Application(loop=loop, client_max_size=custom_max_size) + app.router.add_post('/', handler) + client = yield from test_client(app) + data = {'long_string': default_max_size * 'x' + 'xxx'} + resp = yield from client.post('/', data=data) + assert 200 == resp.status + resp_text = yield from resp.text() + assert 'ok' == resp_text + too_large_data = {'log_string': default_max_size * 2 * 'x'} + resp = yield from client.post('/', data=too_large_data) + assert 200 == resp.status + resp_text = yield from resp.text() + assert resp_text == 'ok' diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 644d8694836..0aca2acd531 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -9,6 +9,7 @@ from aiohttp.protocol import HttpVersion from aiohttp.streams import StreamReader from aiohttp.test_utils import make_mocked_request +from aiohttp.web_exceptions import HTTPRequestEntityTooLarge @pytest.fixture @@ -314,3 +315,45 @@ def test_cannot_clone_after_read(loop): yield from req.read() with pytest.raises(RuntimeError): req.clone() + + +@asyncio.coroutine +def test_make_too_big_request(loop): + payload = StreamReader(loop=loop) + large_file = 1024 ** 2 * b'x' + too_large_file = large_file + b'x' + payload.feed_data(too_large_file) + payload.feed_eof() + req = make_mocked_request('POST', '/', payload=payload) + with pytest.raises(HTTPRequestEntityTooLarge) as err: + yield from req.read() + + assert err.value.status_code == 413 + + +@asyncio.coroutine +def test_make_too_big_request_adjust_limit(loop): + payload = StreamReader(loop=loop) + large_file = 1024 ** 2 * b'x' + too_large_file = large_file + b'x' + payload.feed_data(too_large_file) + payload.feed_eof() + max_size = 1024**2 + 2 + req = make_mocked_request('POST', '/', payload=payload, + client_max_size=max_size) + txt = yield from req.read() + assert len(txt) == 1024**2 + 1 + + +@asyncio.coroutine +def test_make_too_big_request_limit_None(loop): + payload = StreamReader(loop=loop) + large_file = 1024 ** 2 * b'x' + too_large_file = large_file + b'x' + payload.feed_data(too_large_file) + payload.feed_eof() + max_size = None + req = make_mocked_request('POST', '/', payload=payload, + client_max_size=max_size) + txt = yield from req.read() + assert len(txt) == 1024**2 + 1