Skip to content

Commit

Permalink
Add Client.auth setter (#1185)
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca authored Aug 17, 2020
1 parent 34ba0e1 commit cb620e6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 20 deletions.
4 changes: 2 additions & 2 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@

::: httpx.Client
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send close
:members: headers cookies params auth request get head options post put patch delete build_request send close

## `AsyncClient`

::: httpx.AsyncClient
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send aclose
:members: headers cookies params auth request get head options post put patch delete build_request send aclose


## `Response`
Expand Down
45 changes: 33 additions & 12 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
):
self._base_url = self._enforce_trailing_slash(URL(base_url))

self.auth = auth
self._auth = self._build_auth(auth)
self._params = QueryParams(params)
self._headers = Headers(headers)
self._cookies = Cookies(cookies)
Expand Down Expand Up @@ -117,6 +117,21 @@ def timeout(self) -> Timeout:
def timeout(self, timeout: TimeoutTypes) -> None:
self._timeout = Timeout(timeout)

@property
def auth(self) -> typing.Optional[Auth]:
"""
Authentication class used when none is passed at the request-level.
See also [Authentication][0].
[0]: /quickstart/#authentication
"""
return self._auth

@auth.setter
def auth(self, auth: AuthTypes) -> None:
self._auth = self._build_auth(auth)

@property
def base_url(self) -> URL:
"""
Expand Down Expand Up @@ -284,19 +299,25 @@ def _merge_queryparams(
return merged_queryparams
return params

def _build_auth(
def _build_auth(self, auth: AuthTypes) -> typing.Optional[Auth]:
if auth is None:
return None
elif isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
else:
raise TypeError('Invalid "auth" argument.')

def _build_request_auth(
self, request: Request, auth: typing.Union[AuthTypes, UnsetType] = UNSET
) -> Auth:
auth = self.auth if isinstance(auth, UnsetType) else auth
auth = self._auth if isinstance(auth, UnsetType) else self._build_auth(auth)

if auth is not None:
if isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
raise TypeError('Invalid "auth" argument.')
return auth

username, password = request.url.username, request.url.password
if username or password:
Expand Down Expand Up @@ -667,7 +688,7 @@ def send(
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)

auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)

response = self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
Expand Down Expand Up @@ -1269,7 +1290,7 @@ async def send(
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)

auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)

response = await self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
Expand Down
33 changes: 27 additions & 6 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
URL,
AsyncClient,
Auth,
BasicAuth,
Client,
DigestAuth,
ProtocolError,
Expand Down Expand Up @@ -310,14 +311,34 @@ async def test_auth_hidden_header() -> None:


@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
async def test_auth_property() -> None:
client = AsyncClient(transport=AsyncMockTransport())
assert client.auth is None

client.auth = ("tomchristie", "password123") # type: ignore
assert isinstance(client.auth, BasicAuth)

url = "https://example.org/"
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}


@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
with pytest.raises(TypeError):
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)

client = AsyncClient(transport=AsyncMockTransport())

with pytest.raises(TypeError):
await client.get(auth="not a tuple, not a callable") # type: ignore

with pytest.raises(TypeError):
await client.get(url)
client.auth = "not a tuple, not a callable" # type: ignore


@pytest.mark.asyncio
Expand Down

0 comments on commit cb620e6

Please sign in to comment.