diff --git a/httpx/_models.py b/httpx/_models.py
index 713281e662..dc3a8941dd 100644
--- a/httpx/_models.py
+++ b/httpx/_models.py
@@ -47,6 +47,8 @@
     URLTypes,
 )
 from ._utils import (
+    async_drain_by_chunks,
+    drain_by_chunks,
     flatten_queryparams,
     guess_json_utf,
     is_known_encoding,
@@ -907,11 +909,14 @@ def read(self) -> bytes:
         self._content = b"".join(self.iter_bytes())
         return self._content
 
-    def iter_bytes(self) -> typing.Iterator[bytes]:
+    def iter_bytes(self, chunk_size: int = 512) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
+        yield from drain_by_chunks(self.__iter_bytes(), chunk_size)
+
+    def __iter_bytes(self) -> typing.Iterator[bytes]:
         if hasattr(self, "_content"):
             yield self._content
         else:
@@ -988,11 +993,15 @@ async def aread(self) -> bytes:
         self._content = b"".join([part async for part in self.aiter_bytes()])
         return self._content
 
-    async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_bytes(self, chunk_size: int = 512) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
+        async for chunk in async_drain_by_chunks(self.__aiter_bytes(), chunk_size):
+            yield chunk
+
+    async def __aiter_bytes(self) -> typing.AsyncIterator[bytes]:
         if hasattr(self, "_content"):
             yield self._content
         else:
diff --git a/httpx/_utils.py b/httpx/_utils.py
index aa670724cb..2b8aab4485 100644
--- a/httpx/_utils.py
+++ b/httpx/_utils.py
@@ -536,3 +536,75 @@ def __eq__(self, other: typing.Any) -> bool:
 
 def warn_deprecated(message: str) -> None:  # pragma: nocover
     warnings.warn(message, DeprecationWarning, stacklevel=2)
+
+
+def drain_by_chunks(
+    stream: typing.Iterator[bytes], chunk_size: int = 512
+) -> typing.Iterator[bytes]:
+    buffer, buffer_size = [], 0
+
+    try:
+        chunk = next(stream)
+
+        while True:
+            last_chunk_size = len(chunk)
+
+            if buffer_size + last_chunk_size < chunk_size:
+                buffer.append(chunk)
+                buffer_size += last_chunk_size
+            elif buffer_size + last_chunk_size == chunk_size:
+                buffer.append(chunk)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+            else:
+                head, tail = (
+                    chunk[: (chunk_size - buffer_size)],
+                    chunk[(chunk_size - buffer_size) :],
+                )
+
+                buffer.append(head)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+                chunk = tail
+                continue
+
+            chunk = next(stream)
+    except StopIteration:
+        if buffer:
+            yield b"".join(buffer)
+
+
+async def async_drain_by_chunks(
+    stream: typing.AsyncIterator[bytes], chunk_size: int = 512
+) -> typing.AsyncIterator[bytes]:
+    buffer, buffer_size = [], 0
+
+    try:
+        chunk = await stream.__anext__()
+
+        while True:
+            last_chunk_size = len(chunk)
+
+            if buffer_size + last_chunk_size < chunk_size:
+                buffer.append(chunk)
+                buffer_size += last_chunk_size
+            elif buffer_size + last_chunk_size == chunk_size:
+                buffer.append(chunk)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+            else:
+                head, tail = (
+                    chunk[: (chunk_size - buffer_size)],
+                    chunk[(chunk_size - buffer_size) :],
+                )
+
+                buffer.append(head)
+                yield b"".join(buffer)
+                buffer, buffer_size = [], 0
+                chunk = tail
+                continue
+
+            chunk = await stream.__anext__()
+    except StopAsyncIteration:
+        if buffer:
+            yield b"".join(buffer)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ae4b3aa96c..552f4faa59 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,6 @@
 import os
 import random
+import typing
 
 import pytest
 
@@ -7,6 +8,8 @@
 from httpx._utils import (
     NetRCInfo,
     URLPattern,
+    async_drain_by_chunks,
+    drain_by_chunks,
     get_ca_bundle_from_env,
     get_environment_proxies,
     guess_json_utf,
@@ -257,3 +260,61 @@ def test_pattern_priority():
         URLPattern("http://"),
         URLPattern("all://"),
     ]
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        [b"1", b"2", b"3"],
+        [b"1", b"abcdefghijklmnop", b"2"],
+        [b"123456", b"3"],
+        [b"1", b"23", b"456", b"7890", b"abcde", b"fghijk"],
+        [b""],
+    ],
+)
+@pytest.mark.parametrize(
+    "chunk_size",
+    [1, 2, 3, 5, 10, 11],
+)
+def test_drain_by_chunks(data, chunk_size):
+    iterator = iter(data)
+    chunk_sizes = []
+    for chunk in drain_by_chunks(iterator, chunk_size):
+        chunk_sizes.append(len(chunk))
+
+    *head, tail = chunk_sizes
+
+    assert tail <= chunk_size
+    assert [chunk_size] * len(head) == head
+
+
+async def _async_iter(data: typing.List) -> typing.AsyncIterator:
+    for chunk in data:
+        yield chunk
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "data",
+    [
+        [b"1", b"2", b"3"],
+        [b"1", b"abcdefghijklmnop", b"2"],
+        [b"123456", b"3"],
+        [b"1", b"23", b"456", b"7890", b"abcde", b"fghijk"],
+        [b""],
+    ],
+)
+@pytest.mark.parametrize(
+    "chunk_size",
+    [1, 2, 3, 5, 10, 11],
+)
+async def test_async_drain_by_chunks(data, chunk_size):
+    async_iterator = _async_iter(data)
+    chunk_sizes = []
+    async for chunk in async_drain_by_chunks(async_iterator, chunk_size):
+        chunk_sizes.append(len(chunk))
+
+    *head, tail = chunk_sizes
+
+    assert tail <= chunk_size
+    assert [chunk_size] * len(head) == head
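For reference, the rebuffering behaviour this patch introduces can be tried out directly. The snippet below is a minimal sketch and not part of the diff: it feeds an unevenly sized byte stream through the new `drain_by_chunks` helper and prints the resulting chunk sizes, showing that every chunk except possibly the last comes out at exactly `chunk_size` bytes. The sample parts and the chunk size of 5 are arbitrary illustrative values.

```python
# Minimal usage sketch (not part of the patch). It exercises the new
# drain_by_chunks helper added in httpx/_utils.py; the sample parts and
# the chunk_size of 5 are arbitrary illustrative values.
from httpx._utils import drain_by_chunks

parts = [b"1", b"23", b"456", b"7890", b"abcde", b"fghijk"]  # 21 bytes total

# Every yielded chunk is exactly chunk_size bytes, except possibly the last.
for chunk in drain_by_chunks(iter(parts), chunk_size=5):
    print(len(chunk), chunk)

# Expected output: four 5-byte chunks followed by a final 1-byte chunk:
# 5 b'12345'
# 5 b'67890'
# 5 b'abcde'
# 5 b'fghij'
# 1 b'k'
```

The same regrouping drives the new `chunk_size` argument on `Response.iter_bytes()` and `Response.aiter_bytes()`, which now wrap the raw decoding iterators with `drain_by_chunks` and `async_drain_by_chunks` respectively.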