Support for chunk_size #1277

Merged · 9 commits · Nov 25, 2020
79 changes: 79 additions & 0 deletions httpx/_decoders.py
@@ -4,6 +4,7 @@
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
"""
import codecs
import io
import typing
import zlib

@@ -157,6 +158,84 @@ def flush(self) -> bytes:
return data


class ByteChunker:
"""
Handles returning byte content in fixed-size chunks.
"""

def __init__(self, chunk_size: int = None) -> None:
self._buffer = io.BytesIO()
self._chunk_size = chunk_size

def decode(self, content: bytes) -> typing.List[bytes]:
if self._chunk_size is None:
return [content]

self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []

def flush(self) -> typing.List[bytes]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []


class TextChunker:
"""
Handles returning text content in fixed-size chunks.
"""

def __init__(self, chunk_size: int = None) -> None:
self._buffer = io.StringIO()
self._chunk_size = chunk_size

def decode(self, content: str) -> typing.List[str]:
if self._chunk_size is None:
return [content]

self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []

def flush(self) -> typing.List[str]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []


class TextDecoder:
"""
Handles incrementally decoding bytes into text
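As a quick illustration of the chunker semantics above, a minimal standalone sketch (not part of the diff) of how ByteChunker buffers partial input and emits fixed-size chunks:

chunker = ByteChunker(chunk_size=5)
chunker.decode(b"Hel")      # [] -- only 3 bytes buffered, less than chunk_size
chunker.decode(b"lo, wor")  # [b"Hello", b", wor"] -- 10 bytes split into two full chunks
chunker.decode(b"ld!")      # [] -- 3 bytes buffered again
chunker.flush()             # [b"ld!"] -- the remainder is emitted on flush
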
69 changes: 51 additions & 18 deletions httpx/_models.py
@@ -17,6 +17,7 @@
from ._content_streams import ByteStream, ContentStream, encode
from ._decoders import (
SUPPORTED_DECODERS,
ByteChunker,
ContentDecoder,
IdentityDecoder,
LineDecoder,
@@ -912,19 +913,28 @@ def read(self) -> bytes:
self._content = b"".join(self.iter_bytes())
return self._content

def iter_bytes(self) -> typing.Iterator[bytes]:
def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
Review comment (Member):

Are you sure about chunk_size defaulting to None? I agree that it looks more suitable, but Requests provides defaults of chunk_size=1 or 512.

Reply (Member):

In this PR, when chunk_size=None we just return the input content unchanged, as one single big chunk.

Yes, this deviates from what Requests does, but:

  • Setting a non-None default would break backward compatibility on our side.
  • Defaulting to "transparently pass along the chunks sent by the server" is probably the most reasonable approach anyway.

That said, we'd need to document this deviation from Requests in the compatibility guide. 👍
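
For illustration, a minimal sketch of the two behaviours from the caller's side (the URL and the process() handler are placeholders, not part of this PR):

import httpx

# Default chunk_size=None: chunks are passed through at whatever size
# the server / content decoder produced them.
with httpx.stream("GET", "https://example.com/data") as response:
    for chunk in response.iter_bytes():
        process(chunk)  # hypothetical handler; chunk sizes vary

# Explicit chunk_size: every chunk except possibly the last is exactly
# 8192 bytes, matching the fixed-size behaviour of Requests' iter_content().
with httpx.stream("GET", "https://example.com/data") as response:
    for chunk in response.iter_bytes(chunk_size=8192):
        process(chunk)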

"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
chunk_size = len(self._content) if chunk_size is None else chunk_size
for i in range(0, len(self._content), chunk_size):
yield self._content[i : i + chunk_size]
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
for chunk in self.iter_raw():
yield decoder.decode(chunk)
yield decoder.flush()
for raw_bytes in self.iter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk
for chunk in chunker.flush():
yield chunk

def iter_text(self) -> typing.Iterator[str]:
"""
@@ -947,7 +957,7 @@ def iter_lines(self) -> typing.Iterator[str]:
for line in decoder.flush():
yield line

def iter_raw(self) -> typing.Iterator[bytes]:
def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@@ -958,10 +968,17 @@ def iter_raw(self) -> typing.Iterator[bytes]:

self.is_stream_consumed = True
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
for part in self._raw_stream:
self._num_bytes_downloaded += len(part)
yield part
for raw_stream_bytes in self._raw_stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk

for chunk in chunker.flush():
yield chunk

self.close()

def next(self) -> "Response":
@@ -996,19 +1013,28 @@ async def aread(self) -> bytes:
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content

async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
chunk_size = len(self._content) if chunk_size is None else chunk_size
for i in range(0, len(self._content), chunk_size):
yield self._content[i : i + chunk_size]
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
async for chunk in self.aiter_raw():
yield decoder.decode(chunk)
yield decoder.flush()
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk
for chunk in chunker.flush():
yield chunk

async def aiter_text(self) -> typing.AsyncIterator[str]:
"""
@@ -1031,7 +1057,7 @@ async def aiter_lines(self) -> typing.AsyncIterator[str]:
for line in decoder.flush():
yield line

async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@@ -1042,10 +1068,17 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]:

self.is_stream_consumed = True
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
async for part in self._raw_stream:
self._num_bytes_downloaded += len(part)
yield part
async for raw_stream_bytes in self._raw_stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk

for chunk in chunker.flush():
yield chunk

await self.aclose()

async def anext(self) -> "Response":
81 changes: 77 additions & 4 deletions tests/models/test_responses.py
@@ -227,6 +227,26 @@ def test_iter_raw():
assert raw == b"Hello, world!"


def test_iter_raw_with_chunksize():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part for part in response.iter_raw(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]

stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part for part in response.iter_raw(chunk_size=13)]
assert parts == [b"Hello, world!"]

stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part for part in response.iter_raw(chunk_size=20)]
assert parts == [b"Hello, world!"]


def test_iter_raw_increments_updates_counter():
stream = IteratorStream(iterator=streaming_body())

@@ -255,6 +275,27 @@ async def test_aiter_raw():
assert raw == b"Hello, world!"


@pytest.mark.asyncio
async def test_aiter_raw_with_chunksize():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part async for part in response.aiter_raw(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]

stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part async for part in response.aiter_raw(chunk_size=13)]
assert parts == [b"Hello, world!"]

stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)

parts = [part async for part in response.aiter_raw(chunk_size=20)]
assert parts == [b"Hello, world!"]


@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
@@ -271,17 +312,31 @@ async def test_aiter_raw_increments_updates_counter():


def test_iter_bytes():
response = httpx.Response(
200,
content=b"Hello, world!",
)
response = httpx.Response(200, content=b"Hello, world!")

content = b""
for part in response.iter_bytes():
content += part
assert content == b"Hello, world!"


def test_iter_bytes_with_chunk_size():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part for part in response.iter_bytes(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]

stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part for part in response.iter_bytes(chunk_size=13)]
assert parts == [b"Hello, world!"]

stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part for part in response.iter_bytes(chunk_size=20)]
assert parts == [b"Hello, world!"]


@pytest.mark.asyncio
async def test_aiter_bytes():
response = httpx.Response(
@@ -295,6 +350,24 @@ async def test_aiter_bytes():
assert content == b"Hello, world!"


@pytest.mark.asyncio
async def test_aiter_bytes_with_chunk_size():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part async for part in response.aiter_bytes(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]

stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part async for part in response.aiter_bytes(chunk_size=13)]
assert parts == [b"Hello, world!"]

stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
parts = [part async for part in response.aiter_bytes(chunk_size=20)]
assert parts == [b"Hello, world!"]


def test_iter_text():
response = httpx.Response(
200,