Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Client.auth setter #1185

Merged
merged 2 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@

::: httpx.Client
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send close
:members: headers cookies params auth request get head options post put patch delete build_request send close

## `AsyncClient`

::: httpx.AsyncClient
:docstring:
:members: headers cookies params request get head options post put patch delete build_request send aclose
:members: headers cookies params auth request get head options post put patch delete build_request send aclose


## `Response`
Expand Down
45 changes: 33 additions & 12 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
):
self._base_url = self._enforce_trailing_slash(URL(base_url))

self.auth = auth
self._auth = self._build_auth(auth)
self._params = QueryParams(params)
self._headers = Headers(headers)
self._cookies = Cookies(cookies)
Expand Down Expand Up @@ -117,6 +117,21 @@ def timeout(self) -> Timeout:
def timeout(self, timeout: TimeoutTypes) -> None:
self._timeout = Timeout(timeout)

@property
def auth(self) -> typing.Optional[Auth]:
"""
Authentication class used when none is passed at the request-level.

See also [Authentication][0].

[0]: /quickstart/#authentication
"""
return self._auth

@auth.setter
def auth(self, auth: AuthTypes) -> None:
self._auth = self._build_auth(auth)

@property
def base_url(self) -> URL:
"""
Expand Down Expand Up @@ -284,19 +299,25 @@ def _merge_queryparams(
return merged_queryparams
return params

def _build_auth(
def _build_auth(self, auth: AuthTypes) -> typing.Optional[Auth]:
if auth is None:
return None
elif isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
else:
raise TypeError('Invalid "auth" argument.')

def _build_request_auth(
self, request: Request, auth: typing.Union[AuthTypes, UnsetType] = UNSET
) -> Auth:
auth = self.auth if isinstance(auth, UnsetType) else auth
auth = self._auth if isinstance(auth, UnsetType) else self._build_auth(auth)

if auth is not None:
if isinstance(auth, tuple):
return BasicAuth(username=auth[0], password=auth[1])
elif isinstance(auth, Auth):
return auth
elif callable(auth):
return FunctionAuth(func=auth)
raise TypeError('Invalid "auth" argument.')
return auth

username, password = request.url.username, request.url.password
if username or password:
Expand Down Expand Up @@ -667,7 +688,7 @@ def send(
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)

auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)

response = self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
Expand Down Expand Up @@ -1269,7 +1290,7 @@ async def send(
"""
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)

auth = self._build_auth(request, auth)
auth = self._build_request_auth(request, auth)

response = await self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
Expand Down
33 changes: 27 additions & 6 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
URL,
AsyncClient,
Auth,
BasicAuth,
Client,
DigestAuth,
ProtocolError,
Expand Down Expand Up @@ -310,14 +311,34 @@ async def test_auth_hidden_header() -> None:


@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
async def test_auth_property() -> None:
client = AsyncClient(transport=AsyncMockTransport())
assert client.auth is None

client.auth = ("tomchristie", "password123") # type: ignore
assert isinstance(client.auth, BasicAuth)

url = "https://example.org/"
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}


@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
with pytest.raises(TypeError):
client = AsyncClient(
transport=AsyncMockTransport(),
auth="not a tuple, not a callable", # type: ignore
)

client = AsyncClient(transport=AsyncMockTransport())

with pytest.raises(TypeError):
await client.get(auth="not a tuple, not a callable") # type: ignore

with pytest.raises(TypeError):
await client.get(url)
client.auth = "not a tuple, not a callable" # type: ignore


@pytest.mark.asyncio
Expand Down