Skip to content

Commit

Permalink
Swap auth/redirects ordering (#1267)
Browse files Browse the repository at this point in the history
* Internal refactoring to swap auth/redirects ordering

* Test for auth with cross domain redirect
  • Loading branch information
tomchristie authored Sep 10, 2020
1 parent 016e4ee commit 4d950e5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
64 changes: 34 additions & 30 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,12 @@ def send(

auth = self._build_request_auth(request, auth)

response = self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
response = self._send_handling_auth(
request,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
history=[],
)

if not stream:
Expand All @@ -740,23 +744,17 @@ def send(
def _send_handling_redirects(
self,
request: Request,
auth: Auth,
timeout: Timeout,
allow_redirects: bool = True,
history: typing.List[Response] = None,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
if history is None:
history = []

while True:
if len(history) > self.max_redirects:
raise TooManyRedirects(
"Exceeded maximum allowed redirects.", request=request
)

response = self._send_handling_auth(
request, auth=auth, timeout=timeout, history=history
)
response = self._send_single_request(request, timeout)
response.history = list(history)

if not response.is_redirect:
Expand All @@ -771,7 +769,6 @@ def _send_handling_redirects(
response.call_next = functools.partial(
self._send_handling_redirects,
request=request,
auth=auth,
timeout=timeout,
allow_redirects=False,
history=history,
Expand All @@ -781,16 +778,21 @@ def _send_handling_redirects(
def _send_handling_auth(
self,
request: Request,
history: typing.List[Response],
auth: Auth,
timeout: Timeout,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
auth_flow = auth.sync_auth_flow(request)
request = next(auth_flow)

while True:
response = self._send_single_request(request, timeout)

response = self._send_handling_redirects(
request,
timeout=timeout,
allow_redirects=allow_redirects,
history=history,
)
try:
next_request = auth_flow.send(response)
except StopIteration:
Expand Down Expand Up @@ -1346,8 +1348,12 @@ async def send(

auth = self._build_request_auth(request, auth)

response = await self._send_handling_redirects(
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects
response = await self._send_handling_auth(
request,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
history=[],
)

if not stream:
Expand All @@ -1361,23 +1367,17 @@ async def send(
async def _send_handling_redirects(
self,
request: Request,
auth: Auth,
timeout: Timeout,
allow_redirects: bool = True,
history: typing.List[Response] = None,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
if history is None:
history = []

while True:
if len(history) > self.max_redirects:
raise TooManyRedirects(
"Exceeded maximum allowed redirects.", request=request
)

response = await self._send_handling_auth(
request, auth=auth, timeout=timeout, history=history
)
response = await self._send_single_request(request, timeout)
response.history = list(history)

if not response.is_redirect:
Expand All @@ -1392,7 +1392,6 @@ async def _send_handling_redirects(
response.call_next = functools.partial(
self._send_handling_redirects,
request=request,
auth=auth,
timeout=timeout,
allow_redirects=False,
history=history,
Expand All @@ -1402,16 +1401,21 @@ async def _send_handling_redirects(
async def _send_handling_auth(
self,
request: Request,
history: typing.List[Response],
auth: Auth,
timeout: Timeout,
allow_redirects: bool,
history: typing.List[Response],
) -> Response:
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.__anext__()

while True:
response = await self._send_single_request(request, timeout)

response = await self._send_handling_redirects(
request,
timeout=timeout,
allow_redirects=allow_redirects,
history=history,
)
try:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:
Expand Down
10 changes: 9 additions & 1 deletion tests/client/test_redirects.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def test_redirect_loop():
client.get("https://example.org/redirect_loop")


def test_cross_domain_redirect():
def test_cross_domain_redirect_with_auth_header():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.com/cross_domain"
headers = {"Authorization": "abc"}
Expand All @@ -332,6 +332,14 @@ def test_cross_domain_redirect():
assert "authorization" not in response.json()["headers"]


def test_cross_domain_redirect_with_auth():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.com/cross_domain"
response = client.get(url, auth=("user", "pass"))
assert response.url == "https://example.org/cross_domain_target"
assert "authorization" not in response.json()["headers"]


def test_same_domain_redirect():
client = httpx.Client(transport=SyncMockTransport())
url = "https://example.org/cross_domain"
Expand Down

0 comments on commit 4d950e5

Please sign in to comment.