Commit 62f62bc
Added chunk_size to Response.iter_bytes() and Response.aiter_bytes() (e…
cdeler committed Sep 9, 2020
1 parent 016e4ee commit 62f62bc
Showing 3 changed files with 144 additions and 2 deletions.
13 changes: 11 additions & 2 deletions httpx/_models.py
@@ -47,6 +47,8 @@
     URLTypes,
 )
 from ._utils import (
+    async_drain_by_chunks,
+    drain_by_chunks,
     flatten_queryparams,
     guess_json_utf,
     is_known_encoding,
@@ -907,11 +909,14 @@ def read(self) -> bytes:
         self._content = b"".join(self.iter_bytes())
         return self._content

-    def iter_bytes(self) -> typing.Iterator[bytes]:
+    def iter_bytes(self, chunk_size: int = 512) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
+        yield from drain_by_chunks(self.__iter_bytes(), chunk_size)
+
+    def __iter_bytes(self) -> typing.Iterator[bytes]:
         if hasattr(self, "_content"):
             yield self._content
         else:
@@ -988,11 +993,15 @@ async def aread(self) -> bytes:
         self._content = b"".join([part async for part in self.aiter_bytes()])
         return self._content

-    async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_bytes(self, chunk_size: int = 512) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
+        async for chunk in async_drain_by_chunks(self.__aiter_bytes(), chunk_size):
+            yield chunk
+
+    async def __aiter_bytes(self) -> typing.AsyncIterator[bytes]:
         if hasattr(self, "_content"):
             yield self._content
         else:
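For orientation, a minimal usage sketch of the parameter this diff threads through. The URL is illustrative, and httpx.stream() is assumed as the usual way to obtain a streamed Response; only iter_bytes(chunk_size=...) and aiter_bytes(chunk_size=...) come from this commit.

import httpx

# Sketch: consume the decoded body in ~1 KiB pieces via the new parameter.
# Per drain_by_chunks below, every piece except the last is exactly 1024 bytes.
with httpx.stream("GET", "https://www.example.org/") as response:
    for chunk in response.iter_bytes(chunk_size=1024):
        ...  # application logic; AsyncClient mirrors this with aiter_bytes()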
72 changes: 72 additions & 0 deletions httpx/_utils.py
@@ -536,3 +536,75 @@ def __eq__(self, other: typing.Any) -> bool:

 def warn_deprecated(message: str) -> None:  # pragma: nocover
     warnings.warn(message, DeprecationWarning, stacklevel=2)
+
+
+def drain_by_chunks(
+    stream: typing.Iterator[bytes], chunk_size: int = 512
+) -> typing.Iterator[bytes]:
+    buffer, buffer_size = [], 0
+
+    try:
+        chunk = next(stream)
+
+        while True:
+            last_chunk_size = len(chunk)
+
+            if buffer_size + last_chunk_size < chunk_size:
+                buffer.append(chunk)
+                buffer_size += last_chunk_size
+            elif buffer_size + last_chunk_size == chunk_size:
+                buffer.append(chunk)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+            else:
+                head, tail = (
+                    chunk[: (chunk_size - buffer_size)],
+                    chunk[(chunk_size - buffer_size) :],
+                )
+
+                buffer.append(head)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+                chunk = tail
+                continue
+
+            chunk = next(stream)
+    except StopIteration:
+        if buffer:
+            yield b"".join(buffer)
+
+
+async def async_drain_by_chunks(
+    stream: typing.AsyncIterator[bytes], chunk_size: int = 512
+) -> typing.AsyncIterator[bytes]:
+    buffer, buffer_size = [], 0
+
+    try:
+        chunk = await stream.__anext__()
+
+        while True:
+            last_chunk_size = len(chunk)
+
+            if buffer_size + last_chunk_size < chunk_size:
+                buffer.append(chunk)
+                buffer_size += last_chunk_size
+            elif buffer_size + last_chunk_size == chunk_size:
+                buffer.append(chunk)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+            else:
+                head, tail = (
+                    chunk[: (chunk_size - buffer_size)],
+                    chunk[(chunk_size - buffer_size) :],
+                )
+
+                buffer.append(head)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+                chunk = tail
+                continue
+
+            chunk = await stream.__anext__()
+    except StopAsyncIteration:
+        if buffer:
+            yield b"".join(buffer)
61 changes: 61 additions & 0 deletions tests/test_utils.py
@@ -1,12 +1,15 @@
 import os
 import random
+import typing

 import pytest

 import httpx
 from httpx._utils import (
     NetRCInfo,
     URLPattern,
+    async_drain_by_chunks,
+    drain_by_chunks,
     get_ca_bundle_from_env,
     get_environment_proxies,
     guess_json_utf,
@@ -257,3 +260,61 @@ def test_pattern_priority():
         URLPattern("http://"),
         URLPattern("all://"),
     ]
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        [b"1", b"2", b"3"],
+        [b"1", b"abcdefghijklmnop", b"2"],
+        [b"123456", b"3"],
+        [b"1", b"23", b"456", b"7890", b"abcde", b"fghijk"],
+        [b""],
+    ],
+)
+@pytest.mark.parametrize(
+    "chunk_size",
+    [1, 2, 3, 5, 10, 11],
+)
+def test_drain_by_chunks(data, chunk_size):
+    iterator = iter(data)
+    chunk_sizes = []
+    for chunk in drain_by_chunks(iterator, chunk_size):
+        chunk_sizes.append(len(chunk))
+
+    *head, tail = chunk_sizes
+
+    assert tail <= chunk_size
+    assert [chunk_size] * len(head) == head
+
+
+async def _async_iter(data: typing.List) -> typing.AsyncIterator:
+    for chunk in data:
+        yield chunk
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "data",
+    [
+        [b"1", b"2", b"3"],
+        [b"1", b"abcdefghijklmnop", b"2"],
+        [b"123456", b"3"],
+        [b"1", b"23", b"456", b"7890", b"abcde", b"fghijk"],
+        [b""],
+    ],
+)
+@pytest.mark.parametrize(
+    "chunk_size",
+    [1, 2, 3, 5, 10, 11],
+)
+async def test_async_drain_by_chunks(data, chunk_size):
+    async_iterator = _async_iter(data)
+    chunk_sizes = []
+    async for chunk in async_drain_by_chunks(async_iterator, chunk_size):
+        chunk_sizes.append(len(chunk))
+
+    *head, tail = chunk_sizes
+
+    assert tail <= chunk_size
+    assert [chunk_size] * len(head) == head
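To make the assertions concrete, here is one case from the parametrize matrix worked by hand: 18 input bytes at chunk_size=5 produce three full 5-byte chunks and a 3-byte tail, so head == [5, 5, 5] and tail <= chunk_size.

from httpx._utils import drain_by_chunks

data = [b"1", b"abcdefghijklmnop", b"2"]
chunks = list(drain_by_chunks(iter(data), chunk_size=5))

assert chunks == [b"1abcd", b"efghi", b"jklmn", b"op2"]
assert b"".join(chunks) == b"".join(data)  # content is preserved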
