From caa6bdb2484bf821b52a65322efa98a889b593de Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 20 Feb 2017 21:51:01 -0800 Subject: [PATCH] refactor MultipartWriter to use Payload --- aiohttp/__init__.py | 6 +- aiohttp/client_reqrep.py | 85 +++--- aiohttp/formdata.py | 122 ++++++++ aiohttp/helpers.py | 155 +++------- aiohttp/multipart.py | 492 +++++++++++------------------ aiohttp/payload.py | 48 ++- tests/test_client_functional.py | 54 +++- tests/test_client_request.py | 4 +- tests/test_formdata.py | 77 +++++ tests/test_helpers.py | 79 ++--- tests/test_multipart.py | 526 +++++++++++++------------------- 11 files changed, 786 insertions(+), 862 deletions(-) create mode 100644 aiohttp/formdata.py create mode 100644 tests/test_formdata.py diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 542bbae28d7..c0ee9ee514a 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -4,6 +4,7 @@ from . import hdrs # noqa from .client import * # noqa +from .formdata import * # noqa from .helpers import * # noqa from .http_message import HttpVersion, HttpVersion10, HttpVersion11 # noqa from .http_websocket import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa @@ -25,11 +26,12 @@ __all__ = (client.__all__ + # noqa + formdata.__all__ + # noqa helpers.__all__ + # noqa - streams.__all__ + # noqa + multipart.__all__ + # noqa payload.__all__ + # noqa payload_streamer.__all__ + # noqa - multipart.__all__ + # noqa + streams.__all__ + # noqa ('hdrs', 'FileSender', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', 'WSMsgType', 'MsgType', 'WSCloseCode', diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 18c4fce5ca3..557f4fc1c83 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -12,10 +12,10 @@ import aiohttp from . import hdrs, helpers, http, payload +from .formdata import FormData from .helpers import PY_35, HeadersMixin, SimpleCookie, _TimeServiceTimeoutNoop from .http import HttpMessage from .log import client_logger -from .multipart import MultipartWriter from .streams import FlowControlStreamReader try: @@ -217,71 +217,54 @@ def update_auth(self, auth): self.headers[hdrs.AUTHORIZATION] = auth.encode() - def update_body_from_data(self, data, skip_auto_headers): - if not data: - return - - try: - self.body = payload.PAYLOAD_REGISTRY.get(data) - - # enable chunked encoding if needed - if not self.chunked: - if hdrs.CONTENT_LENGTH not in self.headers: - size = self.body.size - if size is None: - self.chunked = True - else: - if hdrs.CONTENT_LENGTH not in self.headers: - self.headers[hdrs.CONTENT_LENGTH] = str(size) - - # set content-type - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = self.body.content_type - - # copy payload headers - if self.body.headers: - for (key, value) in self.body.headers.items(): - if key not in self.headers: - self.headers[key] = value - - except payload.LookupError: - pass - else: + def update_body_from_data(self, body, skip_auto_headers): + if not body: return - if asyncio.iscoroutine(data): + if asyncio.iscoroutine(body): warnings.warn( 'coroutine as data object is deprecated, ' 'use aiohttp.streamer #1664', DeprecationWarning, stacklevel=2) - self.body = data + self.body = body if (hdrs.CONTENT_LENGTH not in self.headers and self.chunked is None): self.chunked = True - elif isinstance(data, MultipartWriter): - self.body = data.serialize() - self.headers.update(data.headers) - self.chunked = True + return - else: - if not isinstance(data, helpers.FormData): - data = helpers.FormData(data) + # FormData + if isinstance(body, FormData): + body = body(self.encoding) - self.body = data(self.encoding) + try: + body = payload.PAYLOAD_REGISTRY.get(body) + except payload.LookupError: + body = FormData(body)(self.encoding) - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = data.content_type + self.body = body - if data.is_multipart: - self.chunked = True - else: - if (hdrs.CONTENT_LENGTH not in self.headers and - not self.chunked): - self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) + # enable chunked encoding if needed + if not self.chunked: + if hdrs.CONTENT_LENGTH not in self.headers: + size = body.size + if size is None: + self.chunked = True + else: + if hdrs.CONTENT_LENGTH not in self.headers: + self.headers[hdrs.CONTENT_LENGTH] = str(size) + + # set content-type + if (hdrs.CONTENT_TYPE not in self.headers and + hdrs.CONTENT_TYPE not in skip_auto_headers): + self.headers[hdrs.CONTENT_TYPE] = body.content_type + + # copy payload headers + if body.headers: + for (key, value) in body.headers.items(): + if key not in self.headers: + self.headers[key] = value def update_transfer_encoding(self): """Analyze transfer-encoding header.""" diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py new file mode 100644 index 00000000000..86550952366 --- /dev/null +++ b/aiohttp/formdata.py @@ -0,0 +1,122 @@ +import io +from urllib.parse import urlencode + +from multidict import MultiDict, MultiDictProxy + +from . import hdrs, multipart, payload +from .helpers import guess_filename + +__all__ = ('FormData',) + + +class FormData: + """Helper class for multipart/form-data and + application/x-www-form-urlencoded body generation.""" + + def __init__(self, fields=(), quote_fields=True): + self._writer = multipart.MultipartWriter('form-data') + self._fields = [] + self._is_multipart = False + self._quote_fields = quote_fields + + if isinstance(fields, dict): + fields = list(fields.items()) + elif not isinstance(fields, (list, tuple)): + fields = (fields,) + self.add_fields(*fields) + + def add_field(self, name, value, *, content_type=None, filename=None, + content_transfer_encoding=None): + + if isinstance(value, io.IOBase): + self._is_multipart = True + elif isinstance(value, (bytes, bytearray, memoryview)): + if filename is None and content_transfer_encoding is None: + filename = name + + type_options = MultiDict({'name': name}) + if filename is not None and not isinstance(filename, str): + raise TypeError('filename must be an instance of str. ' + 'Got: %s' % filename) + if filename is None and isinstance(value, io.IOBase): + filename = guess_filename(value, name) + if filename is not None: + type_options['filename'] = filename + self._is_multipart = True + + headers = {} + if content_type is not None: + if not isinstance(content_type, str): + raise TypeError('content_type must be an instance of str. ' + 'Got: %s' % content_type) + headers[hdrs.CONTENT_TYPE] = content_type + self._is_multipart = True + if content_transfer_encoding is not None: + if not isinstance(content_transfer_encoding, str): + raise TypeError('content_transfer_encoding must be an instance' + ' of str. Got: %s' % content_transfer_encoding) + headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding + self._is_multipart = True + + self._fields.append((type_options, headers, value)) + + def add_fields(self, *fields): + to_add = list(fields) + + while to_add: + rec = to_add.pop(0) + + if isinstance(rec, io.IOBase): + k = guess_filename(rec, 'unknown') + self.add_field(k, rec) + + elif isinstance(rec, (MultiDictProxy, MultiDict)): + to_add.extend(rec.items()) + + elif isinstance(rec, (list, tuple)) and len(rec) == 2: + k, fp = rec + self.add_field(k, fp) + + else: + raise TypeError('Only io.IOBase, multidict and (name, file) ' + 'pairs allowed, use .add_field() for passing ' + 'more complex parameters, got {!r}' + .format(rec)) + + def _gen_form_urlencoded(self, encoding): + # form data (x-www-form-urlencoded) + data = [] + for type_options, _, value in self._fields: + data.append((type_options['name'], value)) + + return payload.BytesPayload( + urlencode(data, doseq=True).encode(encoding), + content_type='application/x-www-form-urlencoded') + + def _gen_form_data(self, encoding): + """Encode a list of fields using the multipart/form-data MIME format""" + for dispparams, headers, value in self._fields: + if hdrs.CONTENT_TYPE in headers: + part = payload.get_payload( + value, content_type=headers[hdrs.CONTENT_TYPE], + headers=headers, encoding=encoding) + else: + part = payload.get_payload( + value, headers=headers, encoding=encoding) + if dispparams: + part.set_content_disposition( + 'form-data', quote_fields=self._quote_fields, **dispparams + ) + # FIXME cgi.FieldStorage doesn't likes body parts with + # Content-Length which were sent via chunked transfer encoding + part.headers.pop(hdrs.CONTENT_LENGTH, None) + + self._writer.append_payload(part) + + return self._writer + + def __call__(self, encoding): + if self._is_multipart: + return self._gen_form_data(encoding) + else: + return self._gen_form_urlencoded(encoding) diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 98a749409ab..6da1edb41ba 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -7,7 +7,6 @@ import datetime import functools import heapq -import io import os import re import sys @@ -16,10 +15,9 @@ from functools import total_ordering from pathlib import Path from time import gmtime -from urllib.parse import urlencode +from urllib.parse import quote from async_timeout import timeout -from multidict import MultiDict, MultiDictProxy from . import hdrs @@ -38,7 +36,7 @@ from .backport_cookies import SimpleCookie # noqa -__all__ = ('BasicAuth', 'create_future', 'FormData', 'parse_mimetype', +__all__ = ('BasicAuth', 'create_future', 'parse_mimetype', 'Timeout', 'ensure_future', 'noop') @@ -46,6 +44,13 @@ Timeout = timeout NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) +CHAR = set(chr(i) for i in range(0, 128)) +CTL = set(chr(i) for i in range(0, 32)) | {chr(127), } +SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', + '?', '=', '{', '}', ' ', chr(9)} +TOKEN = CHAR ^ CTL ^ SEPARATORS + + if sys.version_info < (3, 5): noop = tuple else: @@ -112,121 +117,6 @@ def create_future(loop): return asyncio.Future(loop=loop) -class FormData: - """Helper class for multipart/form-data and - application/x-www-form-urlencoded body generation.""" - - def __init__(self, fields=(), quote_fields=True): - from . import multipart - self._writer = multipart.MultipartWriter('form-data') - self._fields = [] - self._is_multipart = False - self._quote_fields = quote_fields - - if isinstance(fields, dict): - fields = list(fields.items()) - elif not isinstance(fields, (list, tuple)): - fields = (fields,) - self.add_fields(*fields) - - @property - def is_multipart(self): - return self._is_multipart - - @property - def content_type(self): - if self._is_multipart: - return self._writer.headers[hdrs.CONTENT_TYPE] - else: - return 'application/x-www-form-urlencoded' - - def add_field(self, name, value, *, content_type=None, filename=None, - content_transfer_encoding=None): - - if isinstance(value, io.IOBase): - self._is_multipart = True - elif isinstance(value, (bytes, bytearray, memoryview)): - if filename is None and content_transfer_encoding is None: - filename = name - - type_options = MultiDict({'name': name}) - if filename is not None and not isinstance(filename, str): - raise TypeError('filename must be an instance of str. ' - 'Got: %s' % filename) - if filename is None and isinstance(value, io.IOBase): - filename = guess_filename(value, name) - if filename is not None: - type_options['filename'] = filename - self._is_multipart = True - - headers = {} - if content_type is not None: - if not isinstance(content_type, str): - raise TypeError('content_type must be an instance of str. ' - 'Got: %s' % content_type) - headers[hdrs.CONTENT_TYPE] = content_type - self._is_multipart = True - if content_transfer_encoding is not None: - if not isinstance(content_transfer_encoding, str): - raise TypeError('content_transfer_encoding must be an instance' - ' of str. Got: %s' % content_transfer_encoding) - headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding - self._is_multipart = True - - self._fields.append((type_options, headers, value)) - - def add_fields(self, *fields): - to_add = list(fields) - - while to_add: - rec = to_add.pop(0) - - if isinstance(rec, io.IOBase): - k = guess_filename(rec, 'unknown') - self.add_field(k, rec) - - elif isinstance(rec, (MultiDictProxy, MultiDict)): - to_add.extend(rec.items()) - - elif isinstance(rec, (list, tuple)) and len(rec) == 2: - k, fp = rec - self.add_field(k, fp) - - else: - raise TypeError('Only io.IOBase, multidict and (name, file) ' - 'pairs allowed, use .add_field() for passing ' - 'more complex parameters, got {!r}' - .format(rec)) - - def _gen_form_urlencoded(self, encoding): - # form data (x-www-form-urlencoded) - data = [] - for type_options, _, value in self._fields: - data.append((type_options['name'], value)) - - data = urlencode(data, doseq=True) - return data.encode(encoding) - - def _gen_form_data(self, *args, **kwargs): - """Encode a list of fields using the multipart/form-data MIME format""" - for dispparams, headers, value in self._fields: - part = self._writer.append(value, headers) - if dispparams: - part.set_content_disposition( - 'form-data', quote_fields=self._quote_fields, **dispparams - ) - # FIXME cgi.FieldStorage doesn't likes body parts with - # Content-Length which were sent via chunked transfer encoding - part.headers.pop(hdrs.CONTENT_LENGTH, None) - yield from self._writer.serialize() - - def __call__(self, encoding): - if self._is_multipart: - return self._gen_form_data(encoding) - else: - return self._gen_form_urlencoded(encoding) - - def parse_mimetype(mimetype): """Parses a MIME type into its components. @@ -271,6 +161,33 @@ def guess_filename(obj, default=None): return default +def content_disposition_header(disptype, quote_fields=True, **params): + """Sets ``Content-Disposition`` header. + + :param str disptype: Disposition type: inline, attachment, form-data. + Should be valid extension token (see RFC 2183) + :param dict params: Disposition params + """ + if not disptype or not (TOKEN > set(disptype)): + raise ValueError('bad content disposition type {!r}' + ''.format(disptype)) + + value = disptype + if params: + lparams = [] + for key, val in params.items(): + if not key or not (TOKEN > set(key)): + raise ValueError('bad content disposition parameter' + ' {!r}={!r}'.format(key, val)) + qval = quote(val, '') if quote_fields else val + lparams.append((key, '"%s"' % qval)) + if key == 'filename': + lparams.append(('filename*', "utf-8''" + qval)) + sparams = '; '.join('='.join(pair) for pair in lparams) + value = '; '.join((value, sparams)) + return value + + class AccessLogger: """Helper object to log access. diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 933f47e6686..7952905c681 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,38 +1,28 @@ import asyncio import base64 import binascii -import io import json -import mimetypes -import os import re import uuid import warnings import zlib from collections import Mapping, Sequence, deque -from pathlib import Path -from urllib.parse import parse_qsl, quote, unquote, urlencode +from urllib.parse import parse_qsl, unquote, urlencode from multidict import CIMultiDict from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) -from .helpers import PY_35, PY_352, parse_mimetype +from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype from .http import HttpParser +from .payload import (BytesPayload, LookupError, Payload, StringPayload, + get_payload) -__all__ = ('MultipartReader', 'MultipartWriter', - 'BodyPartReader', 'BodyPartWriter', +__all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader', 'BadContentDispositionHeader', 'BadContentDispositionParam', 'parse_content_disposition', 'content_disposition_filename') -CHAR = set(chr(i) for i in range(0, 128)) -CTL = set(chr(i) for i in range(0, 32)) | {chr(127), } -SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', - '?', '=', '{', '}', ' ', chr(9)} -TOKEN = CHAR ^ CTL ^ SEPARATORS - - class BadContentDispositionHeader(RuntimeWarning): pass @@ -153,33 +143,6 @@ def content_disposition_filename(params): return value -def content_disposition_header(disptype, quote_fields=True, **params): - """Sets ``Content-Disposition`` header. - - :param str disptype: Disposition type: inline, attachment, form-data. - Should be valid extension token (see RFC 2183) - :param dict params: Disposition params - """ - if not disptype or not (TOKEN > set(disptype)): - raise ValueError('bad content disposition type {!r}' - ''.format(disptype)) - - value = disptype - if params: - lparams = [] - for key, val in params.items(): - if not key or not (TOKEN > set(key)): - raise ValueError('bad content disposition parameter' - ' {!r}={!r}'.format(key, val)) - qval = quote(val, '') if quote_fields else val - lparams.append((key, '"%s"' % qval)) - if key == 'filename': - lparams.append(('filename*', "utf-8''" + qval)) - sparams = '; '.join('='.join(pair) for pair in lparams) - value = '; '.join((value, sparams)) - return value - - class MultipartResponseWrapper(object): """Wrapper around the :class:`MultipartBodyReader` to take care about underlying connection and close it when it needs in.""" @@ -695,253 +658,22 @@ def _maybe_release_last_part(self): self._last_part = None -class BodyPartWriter(object): - """Multipart writer for single body part.""" - - def __init__(self, obj, headers=None, *, chunk_size=8192): - if isinstance(obj, MultipartWriter): - if headers is not None: - obj.headers.update(headers) - headers = obj.headers - elif headers is None: - headers = CIMultiDict() - elif not isinstance(headers, CIMultiDict): - headers = CIMultiDict(headers) - - self.obj = obj - self.headers = headers - self._chunk_size = chunk_size - self._fill_headers_with_defaults() - - self._serialize_map = { - bytes: self._serialize_bytes, - str: self._serialize_str, - io.IOBase: self._serialize_io, - MultipartWriter: self._serialize_multipart, - ('application', 'json'): self._serialize_json, - ('application', 'x-www-form-urlencoded'): self._serialize_form - } - self._validate_obj(obj, headers) - - def _validate_obj(self, obj, headers): - mtype, stype, *_ = parse_mimetype(headers.get(CONTENT_TYPE)) - if (mtype, stype) in self._serialize_map: - return - for key in self._serialize_map: - if isinstance(key, tuple): - continue - if isinstance(obj, key): - return - else: - raise TypeError('unexpected body part value type %r' % type(obj)) - - def _fill_headers_with_defaults(self): - if CONTENT_TYPE not in self.headers: - content_type = self._guess_content_type(self.obj) - if content_type is not None: - self.headers[CONTENT_TYPE] = content_type - - if CONTENT_LENGTH not in self.headers: - content_length = self._guess_content_length(self.obj) - if content_length is not None: - self.headers[CONTENT_LENGTH] = str(content_length) - - if CONTENT_DISPOSITION not in self.headers: - filename = self._guess_filename(self.obj) - if filename is not None: - self.set_content_disposition('attachment', filename=filename) - - def _guess_content_length(self, obj): - if isinstance(obj, bytes): - return len(obj) - elif isinstance(obj, str): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - charset = params.get('charset', 'us-ascii') - return len(obj.encode(charset)) - elif isinstance(obj, io.StringIO): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - charset = params.get('charset', 'us-ascii') - return len(obj.getvalue().encode(charset)) - obj.tell() - elif isinstance(obj, io.BytesIO): - return len(obj.getvalue()) - obj.tell() - elif isinstance(obj, io.IOBase): - try: - return os.fstat(obj.fileno()).st_size - obj.tell() - except (AttributeError, OSError): - return None - else: - return None - - def _guess_content_type(self, obj, default='application/octet-stream'): - if hasattr(obj, 'name'): - name = getattr(obj, 'name') - return mimetypes.guess_type(name)[0] - elif isinstance(obj, (str, io.StringIO)): - return 'text/plain; charset=utf-8' - else: - return default - - def _guess_filename(self, obj): - if isinstance(obj, io.IOBase): - name = getattr(obj, 'name', None) - if name is not None: - return Path(name).name - - def serialize(self): - """Yields byte chunks for body part.""" - - has_encoding = ( - CONTENT_ENCODING in self.headers and - self.headers[CONTENT_ENCODING] != 'identity' or - CONTENT_TRANSFER_ENCODING in self.headers - ) - if has_encoding: - # since we're following streaming approach which doesn't assumes - # any intermediate buffers, we cannot calculate real content length - # with the specified content encoding scheme. So, instead of lying - # about content length and cause reading issues, we have to strip - # this information. - self.headers.pop(CONTENT_LENGTH, None) - - if self.headers: - yield b'\r\n'.join( - b': '.join(map(lambda i: i.encode('latin1'), item)) - for item in self.headers.items() - ) - yield b'\r\n\r\n' - yield from self._maybe_encode_stream(self._serialize_obj()) - yield b'\r\n' - - def _serialize_obj(self): - obj = self.obj - mtype, stype, *_ = parse_mimetype(self.headers.get(CONTENT_TYPE)) - serializer = self._serialize_map.get((mtype, stype)) - if serializer is not None: - return serializer(obj) - - for key in self._serialize_map: - if not isinstance(key, tuple) and isinstance(obj, key): - return self._serialize_map[key](obj) - return self._serialize_default(obj) - - def _serialize_bytes(self, obj): - yield obj - - def _serialize_str(self, obj): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - yield obj.encode(params.get('charset', 'us-ascii')) - - def _serialize_io(self, obj): - while True: - chunk = obj.read(self._chunk_size) - if not chunk: - break - if isinstance(chunk, str): - yield from self._serialize_str(chunk) - else: - yield from self._serialize_bytes(chunk) - - def _serialize_multipart(self, obj): - yield from obj.serialize() - - def _serialize_json(self, obj): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - yield json.dumps(obj).encode(params.get('charset', 'utf-8')) - - def _serialize_form(self, obj): - if isinstance(obj, Mapping): - obj = list(obj.items()) - return self._serialize_str(urlencode(obj, doseq=True)) - - def _serialize_default(self, obj): - raise TypeError('unknown body part type %r' % type(obj)) - - def _maybe_encode_stream(self, stream): - if CONTENT_ENCODING in self.headers: - stream = self._apply_content_encoding(stream) - if CONTENT_TRANSFER_ENCODING in self.headers: - stream = self._apply_content_transfer_encoding(stream) - yield from stream - - def _apply_content_encoding(self, stream): - encoding = self.headers[CONTENT_ENCODING].lower() - if encoding == 'identity': - yield from stream - elif encoding in ('deflate', 'gzip'): - if encoding == 'gzip': - zlib_mode = 16 + zlib.MAX_WBITS - else: - zlib_mode = -zlib.MAX_WBITS - zcomp = zlib.compressobj(wbits=zlib_mode) - for chunk in stream: - yield zcomp.compress(chunk) - else: - yield zcomp.flush() - else: - raise RuntimeError('unknown content encoding: {}' - ''.format(encoding)) - - def _apply_content_transfer_encoding(self, stream): - encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower() - if encoding == 'base64': - buffer = bytearray() - while True: - if buffer: - div, mod = divmod(len(buffer), 3) - chunk, buffer = buffer[:div * 3], buffer[div * 3:] - if chunk: - yield base64.b64encode(chunk) - chunk = next(stream, None) - if not chunk: - if buffer: - yield base64.b64encode(buffer[:]) - return - buffer.extend(chunk) - elif encoding == 'quoted-printable': - for chunk in stream: - yield binascii.b2a_qp(chunk) - elif encoding == 'binary': - yield from stream - else: - raise RuntimeError('unknown content transfer encoding: {}' - ''.format(encoding)) - - def set_content_disposition(self, disptype, quote_fields=True, **params): - """Sets ``Content-Disposition`` header. - - :param str disptype: Disposition type: inline, attachment, form-data. - Should be valid extension token (see RFC 2183) - :param dict params: Disposition params - """ - self.headers[CONTENT_DISPOSITION] = content_disposition_header( - disptype, quote_fields=quote_fields, **params) - - @property - def filename(self): - """Returns filename specified in Content-Disposition header or ``None`` - if missed.""" - _, params = parse_content_disposition( - self.headers.get(CONTENT_DISPOSITION)) - return content_disposition_filename(params) - - -class MultipartWriter(object): +class MultipartWriter(Payload): """Multipart body writer.""" - #: Body part reader class for non multipart/* content types. - part_writer_cls = BodyPartWriter - def __init__(self, subtype='mixed', boundary=None): boundary = boundary if boundary is not None else uuid.uuid4().hex try: - boundary.encode('us-ascii') + self._boundary = boundary.encode('us-ascii') except UnicodeEncodeError: raise ValueError('boundary should contains ASCII only chars') - self.headers = CIMultiDict() - self.headers[CONTENT_TYPE] = 'multipart/{}; boundary="{}"'.format( - subtype, boundary - ) - self.parts = [] + ctype = 'multipart/{}; boundary="{}"'.format(subtype, boundary) + + super().__init__(None, content_type=ctype) + + self._parts = [] + self._headers = CIMultiDict() + self._headers[CONTENT_TYPE] = self.content_type def __enter__(self): return self @@ -950,53 +682,191 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def __iter__(self): - return iter(self.parts) + return iter(self._parts) def __len__(self): - return len(self.parts) + return len(self._parts) @property def boundary(self): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - return params['boundary'].encode('us-ascii') + return self._boundary def append(self, obj, headers=None): - """Adds a new body part to multipart writer.""" - if isinstance(obj, self.part_writer_cls): - if headers: + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Payload): + if obj.headers is not None: obj.headers.update(headers) - self.parts.append(obj) + else: + obj._headers = headers + self.append_payload(obj) else: - if not headers: - headers = CIMultiDict() - self.parts.append(self.part_writer_cls(obj, headers)) - return self.parts[-1] + try: + self.append_payload(get_payload(obj, headers=headers)) + except LookupError: + raise TypeError + + def append_payload(self, payload): + """Adds a new body part to multipart writer.""" + # content-type + if CONTENT_TYPE not in payload.headers: + payload.headers[CONTENT_TYPE] = payload.content_type + + # compression + encoding = payload.headers.get(CONTENT_ENCODING, '').lower() + if encoding and encoding not in ('deflate', 'gzip', 'identity'): + raise RuntimeError('unknown content encoding: {}'.format(encoding)) + if encoding == 'identity': + encoding = None + + # te encoding + te_encoding = payload.headers.get( + CONTENT_TRANSFER_ENCODING, '').lower() + if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'): + raise RuntimeError('unknown content transfer encoding: {}' + ''.format(te_encoding)) + if te_encoding == 'binary': + te_encoding = None + + # size + size = payload.size + if size is not None and not (encoding or te_encoding): + payload.headers[CONTENT_LENGTH] = str(size) + + # render headers + headers = ''.join( + [k + ': ' + v + '\r\n' for k, v in payload.headers.items()] + ).encode('utf-8') + b'\r\n' + + self._parts.append((payload, headers, encoding, te_encoding)) def append_json(self, obj, headers=None): """Helper to append JSON part.""" - if not headers: + if headers is None: headers = CIMultiDict() - headers[CONTENT_TYPE] = 'application/json' - return self.append(obj, headers) + + *_, params = parse_mimetype(headers.get(CONTENT_TYPE)) + charset = params.get('charset', 'utf-8') + + data = json.dumps(obj).encode(charset) + self.append_payload( + BytesPayload( + data, headers=headers, content_type='application/json')) def append_form(self, obj, headers=None): """Helper to append form urlencoded part.""" - if not headers: - headers = CIMultiDict() - headers[CONTENT_TYPE] = 'application/x-www-form-urlencoded' assert isinstance(obj, (Sequence, Mapping)) - return self.append(obj, headers) - def serialize(self): - """Yields multipart byte chunks.""" - if not self.parts: - yield b'' + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Mapping): + obj = list(obj.items()) + data = urlencode(obj, doseq=True) + + return self.append_payload( + StringPayload(data, headers=headers, + content_type='application/x-www-form-urlencoded')) + + @property + def size(self): + """Size of the payload.""" + if not self._parts: + return 0 + + total = 0 + for part, headers, encoding, te_encoding in self._parts: + if encoding or te_encoding or part.size is None: + return None + + total += ( + 2 + len(self._boundary) + 2 + # b'--'+self._boundary+b'\r\n' + part.size + len(headers) + + 2 # b'\r\n' + ) + + total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' + return total + + @asyncio.coroutine + def write(self, writer): + """Write body.""" + if not self._parts: return - for part in self.parts: - yield b'--' + self.boundary + b'\r\n' - yield from part.serialize() - else: - yield b'--' + self.boundary + b'--\r\n' + for part, headers, encoding, te_encoding in self._parts: + yield from writer.write(b'--' + self._boundary + b'\r\n') + yield from writer.write(headers) + + if encoding or te_encoding: + w = MultipartPayloadWriter(writer) + if encoding: + w.enable_compression(encoding) + if te_encoding: + w.enable_encoding(te_encoding) + yield from part.write(w) + yield from w.write_eof() + else: + yield from part.write(writer) + + yield from writer.write(b'\r\n') + + yield from writer.write(b'--' + self._boundary + b'--\r\n') - yield b'' + +class MultipartPayloadWriter: + + def __init__(self, writer): + self._writer = writer + self._encoding = None + self._compress = None + + def enable_encoding(self, encoding): + if encoding == 'base64': + self._encoding = encoding + self._encoding_buffer = bytearray() + elif encoding == 'quoted-printable': + self._encoding = 'quoted-printable' + + def enable_compression(self, encoding='deflate'): + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + self._compress = zlib.compressobj(wbits=zlib_mode) + + @asyncio.coroutine + def write_eof(self): + if self._compress is not None: + chunk = self._compress.flush() + if chunk: + self._compress = None + yield from self.write(chunk) + + if self._encoding == 'base64': + if self._encoding_buffer: + yield from self._writer.write(base64.b64encode( + self._encoding_buffer)) + + @asyncio.coroutine + def write(self, chunk): + if self._compress is not None: + if chunk: + chunk = self._compress.compress(chunk) + if not chunk: + return + + if self._encoding == 'base64': + self._encoding_buffer.extend(chunk) + + if self._encoding_buffer: + buffer = self._encoding_buffer + div, mod = divmod(len(buffer), 3) + enc_chunk, self._encoding_buffer = ( + buffer[:div * 3], buffer[div * 3:]) + if enc_chunk: + enc_chunk = base64.b64encode(enc_chunk) + yield from self._writer.write(enc_chunk) + elif self._encoding == 'quoted-printable': + yield from self._writer.write(binascii.b2a_qp(chunk)) + else: + yield from self._writer.write(chunk) diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 2247b327ca3..9924448ef45 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -7,8 +7,8 @@ from multidict import CIMultiDict from . import hdrs -from .helpers import guess_filename -from .multipart import content_disposition_header +from .helpers import (content_disposition_header, guess_filename, + parse_mimetype, sentinel) from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader __all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload', @@ -54,13 +54,19 @@ class Payload(ABC): _content_type = 'application/octet-stream' def __init__(self, value, *, headers=None, - content_type=None, filename=None, encoding='utf-8'): + content_type=sentinel, filename=None, encoding=None): self._value = value self._encoding = encoding - self._content_type = content_type self._filename = filename if headers is not None: self._headers = CIMultiDict(headers) + if content_type is sentinel and hdrs.CONTENT_TYPE in headers: + content_type = headers[hdrs.CONTENT_TYPE] + + if content_type is sentinel: + content_type = None + + self._content_type = content_type @property def size(self): @@ -116,10 +122,12 @@ def __init__(self, value, *args, **kwargs): assert isinstance(value, (bytes, bytearray, memoryview)), \ "value argument must be byte-ish (%r)" % type(value) + if 'content_type' not in kwargs: + kwargs['content_type'] = 'application/octet-stream' + super().__init__(value, *args, **kwargs) self._size = len(value) - self._content_type = 'application/octet-stream' @asyncio.coroutine def write(self, writer): @@ -128,8 +136,15 @@ def write(self, writer): class StringPayload(BytesPayload): - def __init__(self, value, *args, encoding='utf-8', **kwargs): - super().__init__(value.encode(encoding), *args, **kwargs) + def __init__(self, value, *args, + content_type='text/plain; charset=utf-8', **kwargs): + + *_, params = parse_mimetype(content_type) + charset = params.get('charset', 'utf-8') + kwargs['encoding'] = charset + + super().__init__( + value.encode(charset), content_type=content_type, *args, **kwargs) class IOBasePayload(Payload): @@ -155,6 +170,16 @@ def write(self, writer): class StringIOPayload(IOBasePayload): + def __init__(self, value, *args, + content_type='text/plain; charset=utf-8', **kwargs): + *_, params = parse_mimetype(content_type) + charset = params.get('charset', 'utf-8') + + super().__init__( + value, + content_type=content_type, + encoding=charset, *args, **kwargs) + @asyncio.coroutine def write(self, writer): chunk = self._value.read(DEFAULT_LIMIT) @@ -165,7 +190,14 @@ def write(self, writer): self._value.close() -class TextIOPayload(Payload): +class TextIOPayload(IOBasePayload): + + @property + def size(self): + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except OSError: + return None @asyncio.coroutine def write(self, writer): diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index a8d854de96f..42767ff4dee 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1014,9 +1014,32 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_charset(loop, test_client): + @asyncio.coroutine + def handler(request): + mp = yield from request.multipart() + part = yield from mp.next() + text = yield from part.text() + return web.Response(text=text) + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + form = aiohttp.FormData() + form.add_field('name', 'текст', content_type='text/plain; charset=koi8-r') + + resp = yield from client.post('/', data=form) + assert 200 == resp.status + content = yield from resp.text() + assert content == 'текст' + resp.close() + + +@pytest.mark.xfail +@asyncio.coroutine +def test_POST_DATA_with_charset_post(loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() @@ -1036,14 +1059,13 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_context_transfer_encoding(loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() assert data['name'] == b'text' # should it be str? - return web.Response() + return web.Response(body=data['name']) app = web.Application(loop=loop) app.router.add_post('/', handler) @@ -1059,6 +1081,32 @@ def handler(request): resp.close() +@pytest.mark.xfail +@asyncio.coroutine +def test_POST_DATA_with_content_type_context_transfer_encoding( + loop, test_client): + @asyncio.coroutine + def handler(request): + data = yield from request.post() + assert data['name'] == 'text' # should it be str? + return web.Response(body=data['name']) + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + form = aiohttp.FormData() + form.add_field('name', 'text', + content_type='text/plain', + content_transfer_encoding='base64') + + resp = yield from client.post('/', data=form) + assert 200 == resp.status + content = yield from resp.text() + assert content == 'text' + resp.close() + + @asyncio.coroutine def test_POST_MultiDict(loop, test_client): @asyncio.coroutine diff --git a/tests/test_client_request.py b/tests/test_client_request.py index afc1d3a6a4f..7dd10fe1554 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -540,7 +540,7 @@ def test_post_data(loop): data={'life': '42'}, loop=loop) resp = req.send(mock.Mock(acquire=acquire)) assert '/' == req.url.path - assert b'life=42' == req.body + assert b'life=42' == req.body._value assert 'application/x-www-form-urlencoded' ==\ req.headers['CONTENT-TYPE'] yield from req.close() @@ -580,7 +580,7 @@ def test_get_with_data(loop): meth, URL('http://python.org/'), data={'life': '42'}, loop=loop) assert '/' == req.url.path - assert b'life=42' == req.body + assert b'life=42' == req.body._value yield from req.close() diff --git a/tests/test_formdata.py b/tests/test_formdata.py new file mode 100644 index 00000000000..c2e8e667f68 --- /dev/null +++ b/tests/test_formdata.py @@ -0,0 +1,77 @@ +import asyncio +from unittest import mock + +import pytest + +from aiohttp.formdata import FormData + + +@pytest.fixture +def buf(): + return bytearray() + + +@pytest.fixture +def writer(buf): + writer = mock.Mock() + + def write(chunk): + buf.extend(chunk) + return () + + writer.write.side_effect = write + return writer + + +def test_invalid_formdata_params(): + with pytest.raises(TypeError): + FormData('asdasf') + + +def test_invalid_formdata_params2(): + with pytest.raises(TypeError): + FormData('as') # 2-char str is not allowed + + +def test_invalid_formdata_content_type(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', 'bar', content_type=invalid_val) + + +def test_invalid_formdata_filename(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', 'bar', filename=invalid_val) + + +def test_invalid_formdata_content_transfer_encoding(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', + 'bar', + content_transfer_encoding=invalid_val) + + +@asyncio.coroutine +def test_formdata_field_name_is_quoted(buf, writer): + form = FormData() + form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") + payload = form("ascii") + yield from payload.write(writer) + assert b'name="emails%5B%5D"' in buf + + +@asyncio.coroutine +def test_formdata_field_name_is_not_quoted(buf, writer): + form = FormData(quote_fields=False) + form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") + payload = form("ascii") + yield from payload.write(writer) + assert b'name="emails[]"' in buf diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 529ed4a6ade..a396283bf5c 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -101,58 +101,9 @@ def test_basic_auth_decode_bad_base64(): helpers.BasicAuth.decode('Basic bmtpbTpwd2Q') -def test_invalid_formdata_params(): - with pytest.raises(TypeError): - helpers.FormData('asdasf') - - -def test_invalid_formdata_params2(): - with pytest.raises(TypeError): - helpers.FormData('as') # 2-char str is not allowed - - -def test_invalid_formdata_content_type(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', 'bar', content_type=invalid_val) - - -def test_invalid_formdata_filename(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', 'bar', filename=invalid_val) - - -def test_invalid_formdata_content_transfer_encoding(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', - 'bar', - content_transfer_encoding=invalid_val) - # ------------- access logger ------------------------- -def test_formdata_field_name_is_quoted(): - form = helpers.FormData() - form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") - res = b"".join(form("ascii")) - assert b'name="emails%5B%5D"' in res - - -def test_formdata_field_name_is_not_quoted(): - form = helpers.FormData(quote_fields=False) - form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") - res = b"".join(form("ascii")) - assert b'name="emails[]"' in res - - def test_access_logger_format(): log_format = '%T {%{SPAM}e} "%{ETag}o" %X {X} %%P %{FOO_TEST}e %{FOO1}e' mock_logger = mock.Mock() @@ -547,3 +498,33 @@ def test_eq(self): def test_le(self): l = helpers.FrozenList([1]) assert l < [2] + + +# -------------------------------- ContentDisposition ------------------- + +def test_content_disposition(): + assert (helpers.content_disposition_header('attachment', foo='bar') == + 'attachment; foo="bar"') + + +def test_content_disposition_bad_type(): + with pytest.raises(ValueError): + helpers.content_disposition_header('foo bar') + with pytest.raises(ValueError): + helpers.content_disposition_header('—Ç–µ—Å—Ç') + with pytest.raises(ValueError): + helpers.content_disposition_header('foo\x00bar') + with pytest.raises(ValueError): + helpers.content_disposition_header('') + + +def test_set_content_disposition_bad_param(): + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'foo bar': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'—Ç–µ—Å—Ç': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', + **{'foo\x00bar': 'baz'}) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 35f2b757bf5..52f2bd2e4b8 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,13 +1,14 @@ import asyncio import functools import io -import os import unittest import zlib from unittest import mock +import pytest + import aiohttp.multipart -from aiohttp import helpers +from aiohttp import helpers, payload from aiohttp.hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) from aiohttp.helpers import parse_mimetype @@ -17,6 +18,28 @@ from aiohttp.streams import StreamReader +@pytest.fixture +def buf(): + return bytearray() + + +@pytest.fixture +def stream(buf): + writer = mock.Mock() + + def write(chunk): + buf.extend(chunk) + return () + + writer.write.side_effect = write + return writer + + +@pytest.fixture +def writer(): + return aiohttp.multipart.MultipartWriter(boundary=':') + + def run_in_loop(f): @functools.wraps(f) def wrapper(testcase, *args, **kwargs): @@ -723,318 +746,189 @@ def test_reading_skips_prelude(self): self.assertFalse(second.at_eof()) -class BodyPartWriterTestCase(unittest.TestCase): +@asyncio.coroutine +def test_writer(writer): + assert writer.size == 0 + assert writer.boundary == b':' + + +@asyncio.coroutine +def test_writer_serialize_io_chunk(buf, stream, writer): + flo = io.BytesIO(b'foobarbaz') + writer.append(flo) + yield from writer.write(stream) + print(buf) + assert (buf == b'--:\r\nContent-Type: application/octet-stream' + b'\r\nContent-Length: 9\r\n\r\nfoobarbaz\r\n--:--\r\n') + + +@asyncio.coroutine +def test_writer_serialize_json(buf, stream, writer): + writer.append_json({'привет': 'мир'}) + yield from writer.write(stream) + assert (b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' + b' "\\u043c\\u0438\\u0440"}' in buf) + + +@asyncio.coroutine +def test_writer_serialize_form(buf, stream, writer): + data = [('foo', 'bar'), ('foo', 'baz'), ('boo', 'zoo')] + writer.append_form(data) + yield from writer.write(stream) + + assert (b'foo=bar&foo=baz&boo=zoo' in buf) + - def setUp(self): - self.part = aiohttp.multipart.BodyPartWriter(b'') - - def test_guess_content_length(self): - self.part.headers[CONTENT_TYPE] = 'text/plain; charset=utf-8' - self.assertIsNone(self.part._guess_content_length({})) - self.assertIsNone(self.part._guess_content_length(object())) - self.assertEqual(3, - self.part._guess_content_length(io.BytesIO(b'foo'))) - self.assertEqual(3, - self.part._guess_content_length(io.StringIO('foo'))) - self.assertEqual(6, - self.part._guess_content_length(io.StringIO('мяу'))) - self.assertEqual(3, self.part._guess_content_length(b'bar')) - self.assertEqual(12, self.part._guess_content_length('пассед')) - with open(__file__, 'rb') as f: - self.assertEqual(os.fstat(f.fileno()).st_size, - self.part._guess_content_length(f)) - - def test_guess_content_type(self): - default = 'application/octet-stream' - self.assertEqual(default, self.part._guess_content_type(b'foo')) - self.assertEqual('text/plain; charset=utf-8', - self.part._guess_content_type('foo')) - - here = os.path.dirname(__file__) - filename = os.path.join(here, 'aiohttp.png') - - with open(filename, 'rb') as f: - self.assertEqual('image/png', - self.part._guess_content_type(f)) - - def test_guess_filename(self): - class Named: - name = 'foo' - self.assertIsNone(self.part._guess_filename({})) - self.assertIsNone(self.part._guess_filename(object())) - self.assertIsNone(self.part._guess_filename(io.BytesIO(b'foo'))) - self.assertIsNone(self.part._guess_filename(Named())) - with open(__file__, 'rb') as f: - self.assertEqual(os.path.basename(f.name), - self.part._guess_filename(f)) - - def test_autoset_content_disposition(self): - self.part.obj = open(__file__, 'rb') - self.addCleanup(self.part.obj.close) - self.part._fill_headers_with_defaults() - self.assertIn(CONTENT_DISPOSITION, self.part.headers) - fname = os.path.basename(self.part.obj.name) - self.assertEqual( - 'attachment; filename="{0}"; filename*=utf-8\'\'{0}'.format(fname), - self.part.headers[CONTENT_DISPOSITION]) +@asyncio.coroutine +def test_writer_serialize_form_dict(buf, stream, writer): + data = {'hello': 'мир'} + writer.append_form(data) + yield from writer.write(stream) - def test_set_content_disposition(self): - self.part.set_content_disposition('attachment', foo='bar') - self.assertEqual( - 'attachment; foo="bar"', - self.part.headers[CONTENT_DISPOSITION]) + assert (b'hello=%D0%BC%D0%B8%D1%80' in buf) - def test_set_content_disposition_bad_type(self): - with self.assertRaises(ValueError): - self.part.set_content_disposition('foo bar') - with self.assertRaises(ValueError): - self.part.set_content_disposition('тест') - with self.assertRaises(ValueError): - self.part.set_content_disposition('foo\x00bar') - with self.assertRaises(ValueError): - self.part.set_content_disposition('') - def test_set_content_disposition_bad_param(self): - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'foo bar': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'тест': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', - **{'foo\x00bar': 'baz'}) - - def test_serialize_bytes(self): - self.assertEqual(b'foo', next(self.part._serialize_bytes(b'foo'))) - - def test_serialize_str(self): - self.assertEqual(b'foo', next(self.part._serialize_str('foo'))) - - def test_serialize_str_custom_encoding(self): - self.part.headers[CONTENT_TYPE] = \ - 'text/plain;charset=cp1251' - self.assertEqual('привет'.encode('cp1251'), - next(self.part._serialize_str('привет'))) - - def test_serialize_io(self): - self.assertEqual(b'foo', - next(self.part._serialize_io(io.BytesIO(b'foo')))) - self.assertEqual(b'foo', - next(self.part._serialize_io(io.StringIO('foo')))) - - def test_serialize_io_chunk(self): - flo = io.BytesIO(b'foobarbaz') - self.part._chunk_size = 3 - self.assertEqual([b'foo', b'bar', b'baz'], - list(self.part._serialize_io(flo))) - - def test_serialize_json(self): - self.assertEqual(b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' - b' "\\u043c\\u0438\\u0440"}', - next(self.part._serialize_json({'привет': 'мир'}))) - - def test_serialize_form(self): - data = [('foo', 'bar'), ('foo', 'baz'), ('boo', 'zoo')] - self.assertEqual(b'foo=bar&foo=baz&boo=zoo', - next(self.part._serialize_form(data))) - - def test_serialize_form_dict(self): - data = {'hello': 'мир'} - self.assertEqual(b'hello=%D0%BC%D0%B8%D1%80', - next(self.part._serialize_form(data))) - - def test_serialize_multipart(self): - multipart = aiohttp.multipart.MultipartWriter(boundary=':') - multipart.append('foo-bar-baz') - multipart.append_json({'test': 'passed'}) - multipart.append_form({'test': 'passed'}) - multipart.append_form([('one', 1), ('two', 2)]) - sub_multipart = aiohttp.multipart.MultipartWriter(boundary='::') - sub_multipart.append('nested content') - sub_multipart.headers['X-CUSTOM'] = 'test' - multipart.append(sub_multipart) - self.assertEqual( - [b'--:\r\n', - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 11', - b'\r\n\r\n', - b'foo-bar-baz', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/json', - b'\r\n\r\n', - b'{"test": "passed"}', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/x-www-form-urlencoded', - b'\r\n\r\n', - b'test=passed', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/x-www-form-urlencoded', - b'\r\n\r\n', - b'one=1&two=2', - b'\r\n', - - b'--:\r\n', - b'Content-Type: multipart/mixed; boundary="::"\r\nX-Custom: test', - b'\r\n\r\n', - b'--::\r\n', - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 14', - b'\r\n\r\n', - b'nested content', - b'\r\n', - b'--::--\r\n', - b'', - b'\r\n', - b'--:--\r\n', - b''], - list(self.part._serialize_multipart(multipart)) - ) - - def test_serialize_default(self): - with self.assertRaises(TypeError): - self.part.obj = object() - list(self.part.serialize()) - with self.assertRaises(TypeError): - next(self.part._serialize_default(object())) - - def test_serialize_with_content_encoding_gzip(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'gzip'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: gzip\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - result = b''.join(stream) - - decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) - data = decompressor.decompress(result) - self.assertEqual(b'Time to Relax!', data) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_deflate(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'deflate'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: deflate\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n' - self.assertEqual(thing, b''.join(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_identity(self): - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00' - part = aiohttp.multipart.BodyPartWriter( - thing, {CONTENT_ENCODING: 'identity'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: identity\r\n' - b'Content-Type: application/octet-stream\r\n' - b'Content-Length: 16', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(thing, next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_unknown(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'snappy'}) - with self.assertRaises(RuntimeError): - list(part.serialize()) - - def test_serialize_with_content_transfer_encoding_base64(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'base64'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: base64\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'VGltZSB0byBSZWxh', next(stream)) - self.assertEqual(b'eCE=', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_io_with_content_transfer_encoding_base64(self): - part = aiohttp.multipart.BodyPartWriter( - io.BytesIO(b'Time to Relax!'), - {CONTENT_TRANSFER_ENCODING: 'base64'}) - part._chunk_size = 6 - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: base64\r\n' - b'Content-Type: application/octet-stream', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'VGltZSB0', next(stream)) - self.assertEqual(b'byBSZWxh', next(stream)) - self.assertEqual(b'eCE=', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_transfer_encoding_quote_printable(self): - part = aiohttp.multipart.BodyPartWriter( - 'Привет, мир!', {CONTENT_TRANSFER_ENCODING: 'quoted-printable'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: quoted-printable\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' - b' =D0=BC=D0=B8=D1=80!', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_transfer_encoding_binary(self): - part = aiohttp.multipart.BodyPartWriter( - 'Привет, мир!'.encode('utf-8'), - {CONTENT_TRANSFER_ENCODING: 'binary'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: binary\r\n' - b'Content-Type: application/octet-stream', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) +@asyncio.coroutine +def test_writer_write(buf, stream, writer): + writer.append('foo-bar-baz') + writer.append_json({'test': 'passed'}) + writer.append_form({'test': 'passed'}) + writer.append_form([('one', 1), ('two', 2)]) - self.assertEqual(b'\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82,' - b' \xd0\xbc\xd0\xb8\xd1\x80!', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) + sub_multipart = aiohttp.multipart.MultipartWriter(boundary='::') + sub_multipart.append('nested content') + sub_multipart.headers['X-CUSTOM'] = 'test' + writer.append(sub_multipart) + yield from writer.write(stream) - def test_serialize_with_content_transfer_encoding_unknown(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) - with self.assertRaises(RuntimeError): - list(part.serialize()) + assert ( + (b'--:\r\n' + b'Content-Type: text/plain; charset=utf-8\r\n' + b'Content-Length: 11\r\n\r\n' + b'foo-bar-baz' + b'\r\n' - def test_filename(self): - self.part.set_content_disposition('related', filename='foo.html') - self.assertEqual('foo.html', self.part.filename) + b'--:\r\n' + b'Content-Type: application/json\r\n' + b'Content-Length: 18\r\n\r\n' + b'{"test": "passed"}' + b'\r\n' + + b'--:\r\n' + b'Content-Type: application/x-www-form-urlencoded\r\n' + b'Content-Length: 11\r\n\r\n' + b'test=passed' + b'\r\n' + + b'--:\r\n' + b'Content-Type: application/x-www-form-urlencoded\r\n' + b'Content-Length: 11\r\n\r\n' + b'one=1&two=2' + b'\r\n' + + b'--:\r\n' + b'Content-Type: multipart/mixed; boundary="::"\r\n' + b'X-Custom: test\r\nContent-Length: 93\r\n\r\n' + b'--::\r\n' + b'Content-Type: text/plain; charset=utf-8\r\n' + b'Content-Length: 14\r\n\r\n' + b'nested content\r\n' + b'--::--\r\n' + b'\r\n' + b'--:--\r\n') == bytes(buf)) - def test_wrap_multipart(self): - writer = aiohttp.multipart.MultipartWriter(boundary=':') - part = aiohttp.multipart.BodyPartWriter(writer) - self.assertEqual(part.headers, writer.headers) - part.headers['X-Custom'] = 'test' - self.assertEqual(part.headers, writer.headers) + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_gzip(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'gzip'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: gzip\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) + data = decompressor.decompress(message.split(b'\r\n')[0]) + data += decompressor.flush() + assert b'Time to Relax!' == data + + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_deflate(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'deflate'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: deflate\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n' + assert thing == message + + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_identity(buf, stream, writer): + thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00' + writer.append(thing, {CONTENT_ENCODING: 'identity'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: identity\r\n' + b'Content-Type: application/octet-stream\r\n' + b'Content-Length: 16' == headers) + + assert thing == message.split(b'\r\n')[0] + + +def test_writer_serialize_with_content_encoding_unknown(buf, stream, writer): + with pytest.raises(RuntimeError): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'snappy'}) + + +@asyncio.coroutine +def test_writer_with_content_transfer_encoding_base64(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'base64'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n' + b'Content-Type: text/plain; charset=utf-8' == + headers) + + assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0] + + +@asyncio.coroutine +def test_writer_content_transfer_encoding_quote_printable(buf, stream, writer): + writer.append('Привет, мир!', + {CONTENT_TRANSFER_ENCODING: 'quoted-printable'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' + b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0]) + + +def test_writer_content_transfer_encoding_unknown(buf, stream, writer): + with pytest.raises(RuntimeError): + writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) class MultipartWriterTestCase(unittest.TestCase): def setUp(self): + self.buf = bytearray() + self.stream = mock.Mock() + + def write(chunk): + self.buf.extend(chunk) + return () + + self.stream.write.side_effect = write + self.writer = aiohttp.multipart.MultipartWriter(boundary=':') def test_default_subtype(self): @@ -1061,52 +955,50 @@ def test_append(self): self.assertEqual(0, len(self.writer)) self.writer.append('hello, world!') self.assertEqual(1, len(self.writer)) - self.assertIsInstance(self.writer.parts[0], - self.writer.part_writer_cls) + self.assertIsInstance(self.writer._parts[0][0], payload.Payload) def test_append_with_headers(self): self.writer.append('hello, world!', {'x-foo': 'bar'}) self.assertEqual(1, len(self.writer)) - self.assertIn('x-foo', self.writer.parts[0].headers) - self.assertEqual(self.writer.parts[0].headers['x-foo'], 'bar') + self.assertIn('x-foo', self.writer._parts[0][0].headers) + self.assertEqual(self.writer._parts[0][0].headers['x-foo'], 'bar') def test_append_json(self): self.writer.append_json({'foo': 'bar'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'application/json') def test_append_part(self): - part = aiohttp.multipart.BodyPartWriter('test', - {CONTENT_TYPE: 'text/plain'}) + part = payload.get_payload( + 'test', headers={CONTENT_TYPE: 'text/plain'}) self.writer.append(part, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_json_overrides_content_type(self): self.writer.append_json({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] - self.assertEqual(part.headers[CONTENT_TYPE], 'application/json') + part = self.writer._parts[0][0] + self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_form(self): self.writer.append_form({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] - self.assertEqual(part.headers[CONTENT_TYPE], - 'application/x-www-form-urlencoded') + part = self.writer._parts[0][0] + self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_multipart(self): subwriter = aiohttp.multipart.MultipartWriter(boundary=':') subwriter.append_json({'foo': 'bar'}) self.writer.append(subwriter, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - def test_serialize(self): - self.assertEqual([b''], list(self.writer.serialize())) + def test_write(self): + self.assertEqual([], list(self.writer.write(self.stream))) def test_with(self): with aiohttp.multipart.MultipartWriter(boundary=':') as writer: