diff --git a/docs/advanced.md b/docs/advanced.md index b2a07df371..0f0b2ddf72 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth): ... ``` +If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`. + +```python +import asyncio +import threading +import httpx + + +class MyCustomAuth(httpx.Auth): + def __init__(self): + self._sync_lock = threading.RLock() + self._async_lock = asyncio.Lock() + + def sync_get_token(self): + with self._sync_lock: + ... + + def sync_auth_flow(self, request): + token = self.sync_get_token() + request.headers["Authorization"] = f"Token {token}" + yield request + + async def async_get_token(self): + async with self._async_lock: + ... + + async def async_auth_flow(self, request): + token = await self.async_get_token() + request.headers["Authorization"] = f"Token {token}" + yield request +``` + +If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`. + +```python +import httpx +import sync_only_library + + +class MyCustomAuth(httpx.Auth): + def sync_auth_flow(self, request): + token = sync_only_library.get_token(...) + request.headers["Authorization"] = f"Token {token}" + yield request + + async def async_auth_flow(self, request): + raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient") +``` + ## SSL certificates When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA). diff --git a/httpx/_auth.py b/httpx/_auth.py index eb110dea3a..439f337fbf 100644 --- a/httpx/_auth.py +++ b/httpx/_auth.py @@ -17,6 +17,11 @@ class Auth: To implement a custom authentication scheme, subclass `Auth` and override the `.auth_flow()` method. + + If the authentication scheme does I/O such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.sync_auth_flow()` + and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized + implementations that will be used by `Client` and `AsyncClient` respectively. """ requires_request_body = False @@ -46,6 +51,56 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non """ yield request + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow synchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + request.read() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + response.read() + + try: + request = flow.send(response) + except StopIteration: + break + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + await request.aread() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break + class FunctionAuth(Auth): """ diff --git a/httpx/_client.py b/httpx/_client.py index 0b67a78ddd..61a862bde3 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -785,15 +785,12 @@ def _send_handling_auth( auth: Auth, timeout: Timeout, ) -> Response: - if auth.requires_request_body: - request.read() - - auth_flow = auth.auth_flow(request) + auth_flow = auth.sync_auth_flow(request) request = next(auth_flow) + while True: response = self._send_single_request(request, timeout) - if auth.requires_response_body: - response.read() + try: next_request = auth_flow.send(response) except StopIteration: @@ -1409,18 +1406,15 @@ async def _send_handling_auth( auth: Auth, timeout: Timeout, ) -> Response: - if auth.requires_request_body: - await request.aread() + auth_flow = auth.async_auth_flow(request) + request = await auth_flow.__anext__() - auth_flow = auth.auth_flow(request) - request = next(auth_flow) while True: response = await self._send_single_request(request, timeout) - if auth.requires_response_body: - await response.aread() + try: - next_request = auth_flow.send(response) - except StopIteration: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: return response except BaseException as exc: await response.aclose() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index a08c3292fd..c6c6d979ac 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,5 +1,12 @@ +""" +Integration tests for authentication. + +Unit tests for auth classes also exist in tests/test_auth.py +""" +import asyncio import hashlib import os +import threading import typing import httpcore @@ -183,6 +190,31 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non yield request +class SyncOrAsyncAuth(Auth): + """ + A mock authentication scheme that uses a different implementation for the + sync and async cases. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + with self._lock: + request.headers["Authorization"] = "sync-auth" + yield request + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + async with self._async_lock: + request.headers["Authorization"] = "async-auth" + yield request + + @pytest.mark.asyncio async def test_basic_auth() -> None: url = "https://example.org/" @@ -664,3 +696,34 @@ def test_sync_auth_reads_response_body() -> None: assert response.status_code == 200 assert response.json() == {"auth": '{"auth": "xyz"}'} + + +@pytest.mark.asyncio +async def test_async_auth() -> None: + """ + Test that we can use an auth implementation specific to the async case, to + support cases that require performing I/O or using concurrency primitives (such + as checking a disk-based cache or fetching a token from a remote auth server). + """ + url = "https://example.org/" + auth = SyncOrAsyncAuth() + + async with httpx.AsyncClient(transport=AsyncMockTransport()) as client: + response = await client.get(url, auth=auth) + + assert response.status_code == 200 + assert response.json() == {"auth": "async-auth"} + + +def test_sync_auth() -> None: + """ + Test that we can use an auth implementation specific to the sync case. + """ + url = "https://example.org/" + auth = SyncOrAsyncAuth() + + with httpx.Client(transport=SyncMockTransport()) as client: + response = client.get(url, auth=auth) + + assert response.status_code == 200 + assert response.json() == {"auth": "sync-auth"} diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000000..20c666a88c --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,63 @@ +""" +Unit tests for auth classes. + +Integration tests also exist in tests/client/test_auth.py +""" +import pytest + +import httpx + + +def test_basic_auth(): + auth = httpx.BasicAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should include a basic auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert request.headers["Authorization"].startswith("Basic") + + # No other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response) + + +def test_digest_auth_with_200(): + auth = httpx.DigestAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should not include an auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert "Authorization" not in request.headers + + # If a 200 response is returned, then no other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response) + + +def test_digest_auth_with_401(): + auth = httpx.DigestAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should not include an auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert "Authorization" not in request.headers + + # If a 401 response is returned, then a digest auth request is made. + headers = { + "WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."' + } + response = httpx.Response( + content=b"Auth required", status_code=401, headers=headers + ) + request = flow.send(response) + assert request.headers["Authorization"].startswith("Digest") + + # No other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response)