Skip to content

Commit

Permalink
clean up code, keep the breaking change but add parameter to revert it
Browse files Browse the repository at this point in the history
  • Loading branch information
Wade Roberts committed Jan 20, 2024
1 parent e75d572 commit 364f28e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
16 changes: 10 additions & 6 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ def __init__(
root_path: str = "",
*,
app_state: dict[str, typing.Any],
client: typing.Optional[typing.List[typing.Union[str, int]]],
scope_client: typing.Optional[typing.List[typing.Union[str, int]]],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client
self.scope_client = scope_client

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -270,7 +270,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"client": self.scope_client,
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
Expand All @@ -288,7 +288,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": self.client,
"client": self.scope_client,
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
Expand Down Expand Up @@ -402,7 +402,7 @@ def __init__(
cookies: httpx._types.CookieTypes | None = None,
headers: typing.Dict[str, str] | None = None,
follow_redirects: bool = True,
client: typing.Optional[typing.List[typing.Union[str, int]]] = ["testclient", 50000],
scope_client: typing.Optional[typing.Tuple[str, int]] = None,
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
Expand All @@ -420,7 +420,11 @@ def __init__(
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
scope_client=(
[scope_client[0], scope_client[1]]
if scope_client is not None
else scope_client
),
)
if headers is None:
headers = {}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,34 @@ async def asgi(receive: Receive, send: Send):
assert data == {"message": "test"}


@pytest.mark.parametrize("scope_client", (None, ["testclient", 50000]))
def test_scope_client(scope_client, anyio_backend_name, anyio_backend_options):
async def app(scope, receive, send):
client = scope.get("client")
host = None
port = None
if client is not None:
host, port = client
response = JSONResponse({"host": host, "port": port})
await response(scope, receive, send)

test_client = TestClient(
app,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
scope_client=scope_client,
)
test_response = test_client.get("/")

if scope_client is None:
assert test_response.json() == {"host": None, "port": None}
else:
assert test_response.json() == {
"host": scope_client[0],
"port": scope_client[1],
}


def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
def app(scope: Scope):
async def asgi(receive: Receive, send: Send):
Expand Down

0 comments on commit 364f28e

Please sign in to comment.