Skip to content

Commit

Permalink
Refine generator runners implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jan 21, 2020
1 parent 5bd5e0f commit 1a5fe2b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 83 deletions.
14 changes: 4 additions & 10 deletions httpx/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ def __call__(
return response

if allow_redirects:
yield SyncOrAsync(
for_sync=lambda: response.read(), for_async=lambda: response.aread()
)
yield SyncOrAsync(response.read, response.aread)

request = self.build_redirect_request(request, response, context)
context["history"] = history + [response]

Expand Down Expand Up @@ -268,9 +267,7 @@ def __call__(
history: typing.List[Response] = context.get("history", [])

if auth.requires_request_body:
yield SyncOrAsync(
for_sync=lambda: request.read(), for_async=lambda: request.aread(),
)
yield SyncOrAsync(request.read, request.aread)

auth_flow = auth.auth_flow(request)
request = next(auth_flow)
Expand All @@ -281,10 +278,7 @@ def __call__(
except StopIteration:
return response
except BaseException as exc:
yield SyncOrAsync(
for_sync=lambda: response.close(),
for_async=lambda: response.aclose(),
)
yield SyncOrAsync(response.close, response.aclose)
raise exc from None
else:
response.history = list(history)
Expand Down
104 changes: 31 additions & 73 deletions httpx/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import codecs
import collections
import contextlib
import inspect
import logging
import netrc
import os
Expand All @@ -21,9 +20,6 @@
from .models import URL

T = typing.TypeVar("T")
Y = typing.TypeVar("Y")
S = typing.TypeVar("S")


_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
Expand Down Expand Up @@ -375,19 +371,7 @@ def as_network_error(*exception_classes: type) -> typing.Iterator[None]:


class SyncOrAsync:
"""
Wrapper for values that aren't necessarily syntactically equivalent in
the sync and async cases.
For example:
```python
func = SyncOrAsync(
for_sync=lambda: obj.func(...),
for_async=lambda: obj.afunc(...),
)
```
"""
"""Container for behavior that is available for use in sync or async contexts."""

def __init__(
self,
Expand All @@ -397,46 +381,18 @@ def __init__(
self.for_sync = for_sync
self.for_async = for_async


def consume_generator(gen: typing.Generator[typing.Any, typing.Any, T]) -> T:
"""
Run a generator of regular ("sync") values to completion.
Supports yielding `SyncOrAsync` instances.
"""

value: typing.Any = next(gen)
assert not inspect.isawaitable(value)

while True:
try:
value = gen.send(value)
except StopIteration as exc:
result = exc.value
assert not inspect.isawaitable(result)
return result
except BaseException as exc:
value = gen.throw(type(exc), exc, exc.__traceback__)
else:
if isinstance(value, SyncOrAsync):
value = value.for_sync()
assert not inspect.isawaitable(value)
# Make this object awaitable.
def __await__(self) -> typing.Iterator:
yield from self.for_async().__await__()


async def consume_generator_of_awaitables(
gen: typing.Generator[typing.Any, typing.Any, T]
gen: typing.Generator[typing.Awaitable, typing.Any, T]
) -> T:
"""
Run a generator to completion, awaiting all yielded values.
Any coroutine function can be converted to a generator to be passed as
an argument to this helper function by replacing `await` with `yield`.
Provided any dependencies used to compute the yielded values are switched to a sync
equivalent in the sync case, this means we can use this generator in a
sync context too (running it with `consume_generator()` instead).
So, instead of writing:
Allows transforming async/await code:
```python
async def fetch():
Expand All @@ -446,7 +402,7 @@ async def fetch():
data = await fetch()
```
we can write:
Into generator-based code, that can then be made compatible for sync usage:
```python
def fetch():
Expand All @@ -456,32 +412,34 @@ def fetch():
data = await consume_generator_of_awaitables(fetch())
```
"""
value: typing.Any = None

while True:
try:
value = await gen.send(value)
except StopIteration as exc:
return exc.value
except BaseException as exc:
value = await gen.throw(type(exc), exc, exc.__traceback__)

def unwrap(value: typing.Any) -> typing.Awaitable:
if isinstance(value, SyncOrAsync):
return value.for_async()
return value

coro = unwrap(next(gen))
assert inspect.isawaitable(coro), coro
value: typing.Any = await coro
def consume_generator(gen: typing.Generator[typing.Any, typing.Any, T]) -> T:
"""
Run a generator to completion and return the result, assuming that yielded
values are synchronous (i.e. they're not coroutines).
"""
value: typing.Any = None

while True:
assert not inspect.isawaitable(value)
if isinstance(value, SyncOrAsync):
# Unfortunately there's no equivalent to the '.__await__()' hook
# in the sync case that would allow relying on the language syntax to
# evaluate 'SyncOrAsync' values, so we can't avoid an instance check.
value = value.for_sync()

try:
coro = gen.send(value)
value = gen.send(value)
except StopIteration as exc:
result = exc.value
assert not inspect.isawaitable(result), result
return result
else:
coro = unwrap(coro)
assert inspect.isawaitable(coro), coro

try:
value = await coro
except BaseException as exc:
coro = gen.throw(type(exc), exc, exc.__traceback__)
assert inspect.isawaitable(coro), coro
value = await coro
return exc.value
except BaseException as exc:
value = gen.throw(type(exc), exc, exc.__traceback__)

0 comments on commit 1a5fe2b

Please sign in to comment.