# WSGI app shared by the chunked-encoding tests. It is handed to the
# ``dev_server`` fixture as source text, so it must stay a raw string: the
# ``\n`` below has to reach the generated app file as an escape sequence.
_CHUNKED_APP = r'''
from werkzeug.wrappers import Request
def app(environ, start_response):
    assert environ['HTTP_TRANSFER_ENCODING'] == 'chunked'
    assert environ.get('wsgi.input_terminated', False)
    request = Request(environ)
    assert request.mimetype == 'multipart/form-data'
    assert request.files['file'].read() == b'This is a test\n'
    assert request.form['type'] == 'text/plain'
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'YES']
'''


def _post_chunked_fixture(port, extra_headers=()):
    """POST ``tests/res/chunked.txt`` to ``127.0.0.1:port`` as a chunked
    multipart request and return ``(status, body)`` of the response.

    :param port: port of the server under test.
    :param extra_headers: iterable of ``(name, value)`` pairs that are sent
                          after ``Transfer-Encoding`` and before
                          ``Content-Type``.
    """
    if sys.version_info[0] == 2:
        from httplib import HTTPConnection
    else:
        from http.client import HTTPConnection

    testfile = os.path.join(os.path.dirname(__file__), 'res', 'chunked.txt')
    with open(testfile, 'rb') as f:
        payload = f.read()

    conn = HTTPConnection('127.0.0.1', port)
    try:
        conn.connect()
        conn.putrequest('POST', '/', skip_host=1, skip_accept_encoding=1)
        conn.putheader('Accept', 'text/plain')
        conn.putheader('Transfer-Encoding', 'chunked')
        for name, value in extra_headers:
            conn.putheader(name, value)
        conn.putheader(
            'Content-Type',
            'multipart/form-data; boundary='
            '--------------------------898239224156930639461866')
        conn.endheaders()

        # The fixture file already contains the chunk framing, so it is sent
        # verbatim instead of letting the client encode it.
        conn.send(payload)

        res = conn.getresponse()
        return res.status, res.read()
    finally:
        # Close the socket even if the request fails half way through.
        conn.close()


def test_chunked_encoding(dev_server):
    server = dev_server(_CHUNKED_APP)

    status, body = _post_chunked_fixture(server.port)
    assert status == 200
    assert body == b'YES'


def test_chunked_encoding_with_content_length(dev_server):
    server = dev_server(_CHUNKED_APP)

    # Content-Length is invalid for chunked, but some libraries might send it
    status, body = _post_chunked_fixture(
        server.port, extra_headers=[('Content-Length', '372')])
    assert status == 200
    assert body == b'YES'
class DechunkedInput(io.RawIOBase):
    """An input stream that handles Transfer-Encoding 'chunked'.

    Wraps a binary file-like object that yields the chunk-framed request
    body and exposes the decoded payload through the raw IO interface.

    :param rfile: binary stream providing ``readline()`` and ``read(n)``.
    """

    def __init__(self, rfile):
        self._rfile = rfile
        # Set once the terminating zero-sized chunk has been consumed.
        self._done = False
        # Number of payload bytes still unread in the current chunk.
        self._len = 0

    def readable(self):
        return True

    def read_chunk_len(self):
        """Read one chunk header line and return the chunk size.

        Chunk extensions (anything after the first ``;``) are ignored.

        :raises IOError: if the size is not valid hexadecimal or negative.
        """
        try:
            line = self._rfile.readline().decode('latin1')
            # A header may look like ``1a;name=value``; only the part before
            # the first semicolon is the size.
            _len = int(line.split(';', 1)[0].strip(), 16)
        except ValueError:
            raise IOError('Invalid chunk header')
        if _len < 0:
            raise IOError('Negative chunk length not allowed')
        return _len

    def readinto(self, buf):
        """Fill ``buf`` with decoded payload and return the byte count.

        Never writes more than ``len(buf)`` bytes. Returns ``0`` once the
        final chunk has been consumed.

        :raises IOError: on a malformed header, a missing chunk terminator
                         or if the underlying stream ends inside a chunk.
        """
        capacity = len(buf)
        read = 0
        while not self._done and read < capacity:
            if self._len == 0:
                # This is the first chunk or we fully consumed the previous
                # one. Read the length of the next chunk
                self._len = self.read_chunk_len()

            if self._len == 0:
                # Found the final chunk of size 0. The stream is now exhausted,
                # but there is still a final newline that should be consumed
                self._done = True

            if self._len > 0:
                # Only request what still fits into the buffer: bytes written
                # earlier in this call already occupy buf[:read]. Asking for
                # more would overflow a memoryview or grow a bytearray.
                n = min(capacity - read, self._len)
                data = self._rfile.read(n)
                if not data:
                    raise IOError('Unexpected end of chunked data')
                # Account for what was actually delivered so that a short
                # read simply leaves the rest of the chunk for the next turn.
                got = len(data)
                buf[read:read + got] = data
                self._len -= got
                read += got

            if self._len == 0:
                # Skip the terminating newline of a chunk that has been fully
                # consumed. This also applies to the 0-sized final chunk
                terminator = self._rfile.readline()
                if terminator not in (b'\n', b'\r\n', b'\r'):
                    raise IOError('Missing chunk terminating newline')

        return read