diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py
index 10fbf6ebf52..6fe5bfc183c 100644
--- a/aiohttp/web_urldispatcher.py
+++ b/aiohttp/web_urldispatcher.py
@@ -174,6 +174,98 @@ def url(self, *, filename, query=None):
         url = self._prefix + filename
         return self._append_query(url, query)
 
+    def _sendfile_cb(self, fut, out_fd, in_fd, offset, count, loop):
+        # Writer callback: send as much as the socket accepts right now and
+        # re-register until `count` bytes are out or the file is exhausted.
+        loop.remove_writer(out_fd)
+        if fut.cancelled():
+            # The waiter was cancelled; completing `fut` would raise.
+            return
+        try:
+            n = os.sendfile(out_fd, in_fd, offset, count)
+            if n == 0:
+                # EOF: the file is shorter than `count` (e.g. truncated
+                # after stat()); finish instead of re-registering forever.
+                n = count
+        except (BlockingIOError, InterruptedError):
+            n = 0
+        except Exception as exc:
+            fut.set_exception(exc)
+            return
+
+        if n < count:
+            loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd,
+                            offset + n, count - n, loop)
+        else:
+            fut.set_result(None)
+
+    @asyncio.coroutine
+    def _sendfile_system(self, req, resp, fobj, offset, count):
+        """
+        Write `count` bytes of `fobj` to `resp` starting from `offset` using
+        the ``sendfile`` system call.
+
+        `req` should be a :obj:`aiohttp.web.Request` instance.
+
+        `resp` should be a :obj:`aiohttp.web.StreamResponse` instance.
+
+        `fobj` should be an open file object.
+
+        `offset` should be an integer >= 0.
+
+        `count` should be an integer > 0.
+
+        TLS connections are delegated to :meth:`_sendfile_fallback`, because
+        ``sendfile`` writes to the raw socket and would bypass encryption.
+        """
+        transport = req.transport
+        if transport.get_extra_info("sslcontext"):
+            # sendfile() on the socket fd would emit plaintext over TLS.
+            yield from self._sendfile_fallback(req, resp, fobj, offset, count)
+            return
+
+        yield from resp.drain()
+
+        loop = req.app.loop
+        out_fd = transport.get_extra_info("socket").fileno()
+        in_fd = fobj.fileno()
+        fut = asyncio.Future(loop=loop)
+
+        loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd, offset,
+                        count, loop)
+
+        yield from fut
+
+    @asyncio.coroutine
+    def _sendfile_fallback(self, req, resp, fobj, offset, count):
+        """
+        Mimic the :meth:`_sendfile_system` method, but without using the
+        ``sendfile`` system call. This should be used on systems that don't
+        support the ``sendfile`` system call.
+
+        To avoid blocking the event loop & to keep memory usage low, `fobj` is
+        transferred in chunks controlled by the `chunk_size` argument to
+        :class:`StaticRoute`.
+        """
+        fobj.seek(offset)
+        chunk_size = self._chunk_size
+
+        chunk = fobj.read(chunk_size)
+        while chunk and count > chunk_size:
+            resp.write(chunk)
+            yield from resp.drain()
+            count = count - chunk_size
+            chunk = fobj.read(chunk_size)
+
+        if chunk:
+            resp.write(chunk[:count])
+            yield from resp.drain()
+
+    if hasattr(os, "sendfile"):
+        sendfile = _sendfile_system
+    else:
+        sendfile = _sendfile_fallback
+
     @asyncio.coroutine
     def handle(self, request):
         filename = request.match_info['filename']
@@ -200,20 +292,12 @@ def handle(self, request):
         resp.last_modified = st.st_mtime
 
         file_size = st.st_size
-        single_chunk = file_size < self._chunk_size
 
-        if single_chunk:
-            resp.content_length = file_size
+        resp.content_length = file_size
         resp.start(request)
 
         with open(filepath, 'rb') as f:
-            chunk = f.read(self._chunk_size)
-            if single_chunk:
-                resp.write(chunk)
-            else:
-                while chunk:
-                    resp.write(chunk)
-                    chunk = f.read(self._chunk_size)
+            yield from self.sendfile(request, resp, f, 0, file_size)
 
         return resp
 
diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py
index 8493bdeec2c..3340ffe6b48 100644
--- a/tests/test_web_functional.py
+++ b/tests/test_web_functional.py
@@ -1,6 +1,7 @@
 import asyncio
 import gc
 import json
+import os
 import os.path
 import socket
 import unittest
@@ -10,7 +11,7 @@
 from aiohttp.streams import EOF_MARKER
 
 
-class TestWebFunctional(unittest.TestCase):
+class WebFunctionalSetupMixin(unittest.TestCase):
 
     def setUp(self):
         self.handler = None
@@ -48,6 +49,9 @@ def create_server(self, method, path, handler=None):
         self.addCleanup(srv.close)
         return app, srv, url
+
+
+class TestWebFunctional(WebFunctionalSetupMixin):
+
     def test_simple_get(self):
 
         @asyncio.coroutine
@@ -312,209 +316,6 @@ def go():
 
         self.loop.run_until_complete(go())
 
-    def test_static_file(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            resp = yield from request('GET', url, loop=self.loop)
-            self.assertEqual(200, resp.status)
-            txt = yield from resp.text()
-            self.assertEqual('file content', txt.rstrip())
-            ct = resp.headers['CONTENT-TYPE']
-            self.assertEqual('application/octet-stream', ct)
-            self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None)
-            resp.close()
-
-            resp = yield from request('GET', url + 'fake', loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-            resp = yield from request('GET', url + '/../../', loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'data.unknown_mime_type'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_with_content_type(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname, chunk_size=16)
-
-            resp = yield from request('GET', url, loop=self.loop)
-            self.assertEqual(200, resp.status)
-            body = yield from resp.read()
-            with open(os.path.join(dirname, filename), 'rb') as f:
-                content = f.read()
-                self.assertEqual(content, body)
-            ct = resp.headers['CONTENT-TYPE']
-            self.assertEqual('image/jpeg', ct)
-            self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None)
-            resp.close()
-
-            resp = yield from request('GET', url + 'fake', loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-            resp = yield from request('GET', url + '/../../', loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'software_development_in_picture.jpg'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_with_content_encoding(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            resp = yield from request('GET', url, loop=self.loop)
-            self.assertEqual(200, resp.status)
-            body = yield from resp.read()
-            self.assertEqual(b'hello aiohttp\n', body)
-            ct = resp.headers['CONTENT-TYPE']
-            self.assertEqual('text/plain', ct)
-            encoding = resp.headers['CONTENT-ENCODING']
-            self.assertEqual('gzip', encoding)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'hello.txt.gz'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_directory_traversal_attack(self):
-
-        @asyncio.coroutine
-        def go(dirname, relpath):
-            self.assertTrue(os.path.isfile(os.path.join(dirname, relpath)))
-
-            app, _, url = yield from self.create_server('GET', '/static/')
-            app.router.add_static('/static', dirname)
-
-            url_relpath = url + relpath
-            resp = yield from request('GET', url_relpath, loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-            url_relpath2 = url + 'dir/../' + filename
-            resp = yield from request('GET', url_relpath2, loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-            url_abspath = \
-                url + os.path.abspath(os.path.join(dirname, filename))
-            resp = yield from request('GET', url_abspath, loop=self.loop)
-            self.assertEqual(404, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = '../README.rst'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_if_modified_since(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            resp = yield from request('GET', url, loop=self.loop)
-            self.assertEqual(200, resp.status)
-            lastmod = resp.headers.get('Last-Modified')
-            self.assertIsNotNone(lastmod)
-            resp.close()
-
-            resp = yield from request('GET', url, loop=self.loop,
-                                      headers={'If-Modified-Since': lastmod})
-            self.assertEqual(304, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'data.unknown_mime_type'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_if_modified_since_past_date(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            lastmod = 'Mon, 1 Jan 1990 01:01:01 GMT'
-            resp = yield from request('GET', url, loop=self.loop,
-                                      headers={'If-Modified-Since': lastmod})
-            self.assertEqual(200, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'data.unknown_mime_type'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_if_modified_since_future_date(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            lastmod = 'Fri, 31 Dec 9999 23:59:59 GMT'
-            resp = yield from request('GET', url, loop=self.loop,
-                                      headers={'If-Modified-Since': lastmod})
-            self.assertEqual(304, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'data.unknown_mime_type'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_file_if_modified_since_invalid_date(self):
-
-        @asyncio.coroutine
-        def go(dirname, filename):
-            app, _, url = yield from self.create_server(
-                'GET', '/static/' + filename
-            )
-            app.router.add_static('/static', dirname)
-
-            lastmod = 'not a valid HTTP-date'
-            resp = yield from request('GET', url, loop=self.loop,
-                                      headers={'If-Modified-Since': lastmod})
-            self.assertEqual(200, resp.status)
-            resp.close()
-
-        here = os.path.dirname(__file__)
-        filename = 'data.unknown_mime_type'
-        self.loop.run_until_complete(go(here, filename))
-
-    def test_static_route_path_existence_check(self):
-        directory = os.path.dirname(__file__)
-        web.StaticRoute(None, "/", directory)
-
-        nodirectory = os.path.join(directory, "nonexistent-uPNiOEAg5d")
-        with self.assertRaises(ValueError):
-            web.StaticRoute(None, "/", nodirectory)
-
     def test_post_form_with_duplicate_keys(self):
 
         @asyncio.coroutine
@@ -926,3 +727,235 @@ def go():
             client.close()
 
         self.loop.run_until_complete(go())
+
+
+class TestStaticFileSendfileFallback(WebFunctionalSetupMixin):
+
+    def patch_sendfile(self, add_static):
+        def f(*args, **kwargs):
+            route = add_static(*args, **kwargs)
+            route.sendfile = route._sendfile_fallback
+            return route
+        return f
+
+    @asyncio.coroutine
+    def create_server(self, method, path):
+        app, srv, url = yield from super().create_server(method, path)
+        app.router.add_static = self.patch_sendfile(app.router.add_static)
+
+        return app, srv, url
+
+    def test_static_file(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            resp = yield from request('GET', url, loop=self.loop)
+            self.assertEqual(200, resp.status)
+            txt = yield from resp.text()
+            self.assertEqual('file content', txt.rstrip())
+            ct = resp.headers['CONTENT-TYPE']
+            self.assertEqual('application/octet-stream', ct)
+            self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None)
+            resp.close()
+
+            resp = yield from request('GET', url + 'fake', loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+            resp = yield from request('GET', url + '/../../', loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'data.unknown_mime_type'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_with_content_type(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname, chunk_size=16)
+
+            resp = yield from request('GET', url, loop=self.loop)
+            self.assertEqual(200, resp.status)
+            body = yield from resp.read()
+            with open(os.path.join(dirname, filename), 'rb') as f:
+                content = f.read()
+                self.assertEqual(content, body)
+            ct = resp.headers['CONTENT-TYPE']
+            self.assertEqual('image/jpeg', ct)
+            self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None)
+            resp.close()
+
+            resp = yield from request('GET', url + 'fake', loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+            resp = yield from request('GET', url + '/../../', loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'software_development_in_picture.jpg'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_with_content_encoding(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            resp = yield from request('GET', url, loop=self.loop)
+            self.assertEqual(200, resp.status)
+            body = yield from resp.read()
+            self.assertEqual(b'hello aiohttp\n', body)
+            ct = resp.headers['CONTENT-TYPE']
+            self.assertEqual('text/plain', ct)
+            encoding = resp.headers['CONTENT-ENCODING']
+            self.assertEqual('gzip', encoding)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'hello.txt.gz'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_directory_traversal_attack(self):
+
+        @asyncio.coroutine
+        def go(dirname, relpath):
+            self.assertTrue(os.path.isfile(os.path.join(dirname, relpath)))
+
+            app, _, url = yield from self.create_server('GET', '/static/')
+            app.router.add_static('/static', dirname)
+
+            url_relpath = url + relpath
+            resp = yield from request('GET', url_relpath, loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+            url_relpath2 = url + 'dir/../' + filename
+            resp = yield from request('GET', url_relpath2, loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+            url_abspath = \
+                url + os.path.abspath(os.path.join(dirname, filename))
+            resp = yield from request('GET', url_abspath, loop=self.loop)
+            self.assertEqual(404, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = '../README.rst'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_if_modified_since(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            resp = yield from request('GET', url, loop=self.loop)
+            self.assertEqual(200, resp.status)
+            lastmod = resp.headers.get('Last-Modified')
+            self.assertIsNotNone(lastmod)
+            resp.close()
+
+            resp = yield from request('GET', url, loop=self.loop,
+                                      headers={'If-Modified-Since': lastmod})
+            self.assertEqual(304, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'data.unknown_mime_type'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_if_modified_since_past_date(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            lastmod = 'Mon, 1 Jan 1990 01:01:01 GMT'
+            resp = yield from request('GET', url, loop=self.loop,
+                                      headers={'If-Modified-Since': lastmod})
+            self.assertEqual(200, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'data.unknown_mime_type'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_if_modified_since_future_date(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            lastmod = 'Fri, 31 Dec 9999 23:59:59 GMT'
+            resp = yield from request('GET', url, loop=self.loop,
+                                      headers={'If-Modified-Since': lastmod})
+            self.assertEqual(304, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'data.unknown_mime_type'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_file_if_modified_since_invalid_date(self):
+
+        @asyncio.coroutine
+        def go(dirname, filename):
+            app, _, url = yield from self.create_server(
+                'GET', '/static/' + filename
+            )
+            app.router.add_static('/static', dirname)
+
+            lastmod = 'not a valid HTTP-date'
+            resp = yield from request('GET', url, loop=self.loop,
+                                      headers={'If-Modified-Since': lastmod})
+            self.assertEqual(200, resp.status)
+            resp.close()
+
+        here = os.path.dirname(__file__)
+        filename = 'data.unknown_mime_type'
+        self.loop.run_until_complete(go(here, filename))
+
+    def test_static_route_path_existence_check(self):
+        directory = os.path.dirname(__file__)
+        web.StaticRoute(None, "/", directory)
+
+        nodirectory = os.path.join(directory, "nonexistent-uPNiOEAg5d")
+        with self.assertRaises(ValueError):
+            web.StaticRoute(None, "/", nodirectory)
+
+
+@unittest.skipUnless(hasattr(os, "sendfile"),
+                     "sendfile system call not supported")
+class TestStaticFileSendfile(TestStaticFileSendfileFallback):
+
+    def patch_sendfile(self, add_static):
+        def f(*args, **kwargs):
+            route = add_static(*args, **kwargs)
+            route.sendfile = route._sendfile_system
+            return route
+        return f