Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option to disable automatic client response body decompression #2110

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ clean:
@rm -f .develop
@rm -f .flake
@rm -f .install-deps
@rm -rf aiohttp.egg-info

doc:
@make -C docs html SPHINXOPTS="-W -E"
Expand Down
13 changes: 8 additions & 5 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ cdef class HttpParser:
object _payload
bint _payload_error
object _payload_exception
object _last_error
object _last_error
bint _auto_decompress

Py_buffer py_buf

Expand All @@ -80,7 +81,7 @@ cdef class HttpParser:
object protocol, object loop, object timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
response_with_body=True):
response_with_body=True, auto_decompress=True):
cparser.http_parser_init(self._cparser, mode)
self._cparser.data = <void*>self
self._cparser.content_length = 0
Expand All @@ -106,6 +107,7 @@ cdef class HttpParser:
self._max_field_size = max_field_size
self._response_with_body = response_with_body
self._upgraded = False
self._auto_decompress = auto_decompress

self._csettings.on_url = cb_on_url
self._csettings.on_status = cb_on_status
Expand Down Expand Up @@ -194,7 +196,7 @@ cdef class HttpParser:
payload = EMPTY_PAYLOAD

self._payload = payload
if encoding is not None:
if encoding is not None and self._auto_decompress:
self._payload = DeflateBuffer(payload, encoding)

if not self._response_with_body:
Expand Down Expand Up @@ -301,10 +303,11 @@ cdef class HttpResponseParserC(HttpParser):
def __init__(self, protocol, loop, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
response_with_body=True, read_until_eof=False):
response_with_body=True, read_until_eof=False,
auto_decompress=True):
self._init(cparser.HTTP_RESPONSE, protocol, loop, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body)
payload_exception, response_with_body, auto_decompress)


cdef int cb_on_message_begin(cparser.http_parser* parser) except -1:
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
ws_response_class=ClientWebSocketResponse,
version=http.HttpVersion11,
cookie_jar=None, connector_owner=True, raise_for_status=False,
read_timeout=sentinel, conn_timeout=None):
read_timeout=sentinel, conn_timeout=None,
auto_decompress=True):

implicit_loop = False
if loop is None:
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
else DEFAULT_TIMEOUT)
self._conn_timeout = conn_timeout
self._raise_for_status = raise_for_status
self._auto_decompress = auto_decompress

# Convert to list of tuples
if headers:
Expand Down Expand Up @@ -223,7 +225,7 @@ def _request(self, method, url, *,
expect100=expect100, loop=self._loop,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self)
session=self, auto_decompress=self._auto_decompress)

# connection timeout
try:
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ def set_parser(self, parser, payload):
def set_response_params(self, *, timer=None,
skip_payload=False,
skip_status_codes=(),
read_until_eof=False):
read_until_eof=False,
auto_decompress=True):
self._skip_payload = skip_payload
self._skip_status_codes = skip_status_codes
self._read_until_eof = read_until_eof
self._parser = HttpResponseParser(
self, self._loop, timer=timer,
payload_exception=ClientPayloadError,
read_until_eof=read_until_eof)
read_until_eof=read_until_eof,
auto_decompress=auto_decompress)

if self._tail:
data, self._tail = self._tail, b''
Expand Down
12 changes: 8 additions & 4 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, method, url, *,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None, proxy_from_env=False,
timer=None, session=None):
timer=None, session=None, auto_decompress=True):

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -88,6 +88,7 @@ def __init__(self, method, url, *,
self.length = None
self.response_class = response_class or ClientResponse
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand Down Expand Up @@ -406,7 +407,8 @@ def send(self, conn):
self.response = self.response_class(
self.method, self.original_url,
writer=self._writer, continue100=self._continue, timer=self._timer,
request_info=self.request_info
request_info=self.request_info,
auto_decompress=self._auto_decompress
)

self.response._post_init(self.loop, self._session)
Expand Down Expand Up @@ -450,7 +452,7 @@ class ClientResponse(HeadersMixin):

def __init__(self, method, url, *,
writer=None, continue100=None, timer=None,
request_info=None):
request_info=None, auto_decompress=True):
assert isinstance(url, URL)

self.method = method
Expand All @@ -465,6 +467,7 @@ def __init__(self, method, url, *,
self._history = ()
self._request_info = request_info
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress

@property
def url(self):
Expand Down Expand Up @@ -550,7 +553,8 @@ def start(self, connection, read_until_eof=False):
timer=self._timer,
skip_payload=self.method.lower() == 'head',
skip_status_codes=(204, 304),
read_until_eof=read_until_eof)
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress)

with self._timer:
while True:
Expand Down
18 changes: 12 additions & 6 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(self, protocol=None, loop=None,
max_line_size=8190, max_headers=32768, max_field_size=8190,
timer=None, code=None, method=None, readall=False,
payload_exception=None,
response_with_body=True, read_until_eof=False):
response_with_body=True, read_until_eof=False,
auto_decompress=True):
self.protocol = protocol
self.loop = loop
self.max_line_size = max_line_size
Expand All @@ -78,6 +79,7 @@ def __init__(self, protocol=None, loop=None,
self._upgraded = False
self._payload = None
self._payload_parser = None
self._auto_decompress = auto_decompress

def feed_eof(self):
if self._payload_parser is not None:
Expand Down Expand Up @@ -162,7 +164,8 @@ def feed_data(self, data,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=self.readall,
response_with_body=self.response_with_body)
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
elif method == METH_CONNECT:
Expand All @@ -171,7 +174,8 @@ def feed_data(self, data,
self._upgraded = True
self._payload_parser = HttpPayloadParser(
payload, method=msg.method,
compression=msg.compression, readall=True)
compression=msg.compression, readall=True,
auto_decompress=self._auto_decompress)
else:
if (getattr(msg, 'code', 100) >= 199 and
length is None and self.read_until_eof):
Expand All @@ -182,7 +186,8 @@ def feed_data(self, data,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=True,
response_with_body=self.response_with_body)
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
Expand Down Expand Up @@ -432,18 +437,19 @@ class HttpPayloadParser:
def __init__(self, payload,
length=None, chunked=False, compression=None,
code=None, method=None,
readall=False, response_with_body=True):
readall=False, response_with_body=True, auto_decompress=True):
self.payload = payload

self._length = 0
self._type = ParseState.PARSE_NONE
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b''
self._auto_decompress = auto_decompress
self.done = False

# payload decompression wrapper
if (response_with_body and compression):
if response_with_body and compression and self._auto_decompress:
payload = DeflateBuffer(payload, compression)

# payload parser
Expand Down
1 change: 1 addition & 0 deletions changes/2110.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add ability to disable automatic response decompression
7 changes: 6 additions & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ The client session supports the context manager protocol for self closing.
cookie_jar=None, read_timeout=None, \
conn_timeout=None, \
raise_for_status=False, \
connector_owner=True)
connector_owner=True, \
auto_decompress=True)

The class for creating client sessions and making requests.

Expand Down Expand Up @@ -138,6 +139,10 @@ The client session supports the context manager protocol for self closing.

.. versionadded:: 2.1

:param bool auto_decompress: Automatically decompress response body

.. versionadded:: 2.3

.. attribute:: closed

``True`` if the session has been closed, ``False`` otherwise.
Expand Down
62 changes: 53 additions & 9 deletions tests/test_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import gzip
from unittest import mock

import pytest
Expand All @@ -14,11 +15,20 @@
teardown_test_loop, unittest_run_loop)


def _create_example_app():
_hello_world_str = "Hello, world"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this rate you can even generate a random string. Like:

uuid.uuid4().hex

I would actually even put the str, bytes and gzip inside the app instance:

app = web.Application()
app["body_str"] = uuid.uuid4().hex
app["body_bytes"] = app["body_str"].encode('utf-8')
app["body_gz"] = gzip.compress(app["body_bytes"])

_hello_world_bytes = _hello_world_str.encode('utf-8')
_hello_world_gz = gzip.compress(_hello_world_bytes)


def _create_example_app():
@asyncio.coroutine
def hello(request):
return web.Response(body=b"Hello, world")
return web.Response(body=_hello_world_bytes)

@asyncio.coroutine
def gzip_hello(request):
return web.Response(body=_hello_world_gz,
headers={'Content-Encoding': 'gzip'})

@asyncio.coroutine
def websocket_handler(request):
Expand All @@ -36,12 +46,13 @@ def websocket_handler(request):

@asyncio.coroutine
def cookie_handler(request):
resp = web.Response(body=b"Hello, world")
resp = web.Response(body=_hello_world_bytes)
resp.set_cookie('cookie', 'val')
return resp

app = web.Application()
app.router.add_route('*', '/', hello)
app.router.add_route('*', '/gzip_hello', gzip_hello)
app.router.add_route('*', '/websocket', websocket_handler)
app.router.add_route('*', '/cookie', cookie_handler)
return app
Expand All @@ -58,7 +69,40 @@ def test_get_route():
resp = yield from client.request("GET", "/")
assert resp.status == 200
text = yield from resp.text()
assert "Hello, world" in text
assert _hello_world_str == text

loop.run_until_complete(test_get_route())


def test_auto_gzip_decompress():
with loop_context() as loop:
app = _create_example_app()
with _TestClient(_TestServer(app, loop=loop), loop=loop) as client:

@asyncio.coroutine
def test_get_route():
nonlocal client
resp = yield from client.request("GET", "/gzip_hello")
assert resp.status == 200
data = yield from resp.read()
assert data == _hello_world_bytes

loop.run_until_complete(test_get_route())


def test_noauto_gzip_decompress():
with loop_context() as loop:
app = _create_example_app()
with _TestClient(_TestServer(app, loop=loop), loop=loop,
auto_decompress=False) as client:

@asyncio.coroutine
def test_get_route():
nonlocal client
resp = yield from client.request("GET", "/gzip_hello")
assert resp.status == 200
data = yield from resp.read()
assert data == _hello_world_gz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test coroutine is duplicated at a lot of different places. Maybe you can make a helper outside the test and provide the client instance in argument. I think it would be more readable. What do you think?

async def _assertBody(client, url, body):
    resp = await ...
    assert resp.status == 200
    data = await ...
    assert data == body

Copy link
Contributor Author

@thehesiod thehesiod Jul 27, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been avoiding too much refactoring because @asvetlov has frowned upon that in the past in PRs that implement new features. He's wanted these types of changes in multiple PRs. If he gives the go-ahead I can do these suggested changes.


loop.run_until_complete(test_get_route())

Expand All @@ -73,7 +117,7 @@ def test_get_route():
resp = yield from client.request("GET", "/")
assert resp.status == 200
text = yield from resp.text()
assert "Hello, world" in text
assert _hello_world_str == text

loop.run_until_complete(test_get_route())

Expand Down Expand Up @@ -102,15 +146,15 @@ def test_example_with_loop(self):
request = yield from self.client.request("GET", "/")
assert request.status == 200
text = yield from request.text()
assert "Hello, world" in text
assert _hello_world_str == text

def test_example(self):
@asyncio.coroutine
def test_get_route():
resp = yield from self.client.request("GET", "/")
assert resp.status == 200
text = yield from resp.text()
assert "Hello, world" in text
assert _hello_world_str == text

self.loop.run_until_complete(test_get_route())

Expand Down Expand Up @@ -141,7 +185,7 @@ def test_get_route():
resp = yield from test_client.request("GET", "/")
assert resp.status == 200
text = yield from resp.text()
assert "Hello, world" in text
assert _hello_world_str == text

loop.run_until_complete(test_get_route())

Expand Down Expand Up @@ -176,7 +220,7 @@ def test_test_client_methods(method, loop, test_client):
resp = yield from getattr(test_client, method)("/")
assert resp.status == 200
text = yield from resp.text()
assert "Hello, world" in text
assert _hello_world_str == text


@asyncio.coroutine
Expand Down