From 24088f6c92b83b8ca98169a3c36c166fc23cf680 Mon Sep 17 00:00:00 2001 From: Joakim Saario Date: Sat, 10 Oct 2020 01:56:21 +0200 Subject: [PATCH] Support HTTPX params argument --- respx/api.py | 25 ++++++++++- respx/models.py | 89 ++++++++++++++++++---------------------- respx/transports.py | 19 ++++++++- tests/test_api.py | 56 ++++++++++++++++++++++++- tests/test_transports.py | 33 +++++++-------- 5 files changed, 153 insertions(+), 69 deletions(-) diff --git a/respx/api.py b/respx/api.py index 2dbb32a..1aa5849 100644 --- a/respx/api.py +++ b/respx/api.py @@ -1,7 +1,14 @@ from typing import Callable, Optional, Pattern, Union, overload from .mocks import MockTransport -from .models import CallList, ContentDataTypes, DefaultType, HeaderTypes, RequestPattern +from .models import ( + CallList, + ContentDataTypes, + DefaultType, + HeaderTypes, + QueryParamTypes, + RequestPattern, +) mock = MockTransport(assert_all_called=False) @@ -49,6 +56,7 @@ def add( method: Union[str, Callable], url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -60,6 +68,7 @@ def add( return mock.add( method, url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -72,6 +81,7 @@ def add( def get( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -82,6 +92,7 @@ def get( global mock return mock.get( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -94,6 +105,7 @@ def get( def post( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -104,6 +116,7 @@ def post( global mock return mock.post( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -116,6 +129,7 @@ def post( def put( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -126,6 +140,7 @@ def put( global mock return mock.put( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -138,6 +153,7 @@ def put( def patch( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -148,6 +164,7 @@ def patch( global mock return mock.patch( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -160,6 +177,7 @@ def patch( def delete( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -170,6 +188,7 @@ def delete( global mock return mock.delete( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -182,6 +201,7 @@ def delete( def head( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -192,6 +212,7 @@ def head( global mock return mock.head( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -204,6 +225,7 @@ def head( def options( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -214,6 +236,7 @@ def options( global mock return mock.options( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, diff --git a/respx/models.py b/respx/models.py index 881ea3e..e52763d 100644 --- a/respx/models.py +++ b/respx/models.py @@ -1,6 +1,5 @@ import inspect import re -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -19,7 +18,7 @@ Union, ) from unittest import mock -from urllib.parse import urljoin, urlparse +from urllib.parse import urljoin import httpx from httpcore import AsyncByteStream, SyncByteStream @@ -65,14 +64,42 @@ DefaultType = TypeVar("DefaultType", bound=Any) -Regex = type(re.compile("")) Kwargs = Dict[str, Any] -URLPatternTypes = Union[str, Pattern[str], URL] +URLPatternTypes = Union[str, Pattern[str], URL, httpx.URL] JSONTypes = Union[str, List, Dict] ContentDataTypes = Union[bytes, str, JSONTypes, Callable, Exception] +QueryParamTypes = Union[bytes, str, List[Tuple[str, Any]], Dict[str, Any]] -istype = lambda t, o: isinstance(o, t) -isregex = partial(istype, Regex) + +def build_url( + url: URLPatternTypes, *, base: str = "", params: Optional[QueryParamTypes] = None +) -> Union[httpx.URL, Pattern[str]]: + if not url: + if params is not None: + raise ValueError("Params cannot be used with empty url.") + return None + + if isinstance(url, Pattern): + if params is not None: + if r"\?" in url.pattern and params is not None: + raise ValueError( + "Request url pattern contains a query string, which is not " + "supported in conjuction with params argument." + ) + query_params = str(httpx.QueryParams(params)) + url = re.compile(url.pattern + re.escape(fr"?{query_params}")) + return re.compile(urljoin(base, url.pattern)) + if isinstance(url, str): + url = urljoin(base, url) + + try: + return httpx.URL(url, params=params) + except TypeError: + raise ValueError( + "Request url pattern must be str or compiled regex, got {}.".format( + type(url).__name__ + ) + ) def decode_request(request: Request) -> httpx.Request: @@ -287,6 +314,7 @@ def __init__( self, method: Union[str, Callable], url: Optional[URLPatternTypes], + params: Optional[QueryParamTypes] = None, response: Optional[ResponseTemplate] = None, pass_through: bool = False, alias: Optional[str] = None, @@ -301,7 +329,7 @@ def __init__( self._match_func = method else: self.method = method.upper() - self.set_url(url, base=base_url) + self.url = build_url(url or "", base=base_url, params=params) self.pass_through = pass_through self.response = response or ResponseTemplate() @@ -320,43 +348,6 @@ def call_count(self) -> int: def calls(self) -> CallList: return CallList.from_unittest_call_list(self.stats.call_args_list) - def get_url(self) -> Optional[URLPatternTypes]: - return self._url - - def set_url( - self, url: Optional[URLPatternTypes], base: Optional[str] = None - ) -> None: - url = url or None - if url is None: - url = base - elif isinstance(url, str): - url = url if base is None else urljoin(base, url) - parsed_url = urlparse(url) - if not parsed_url.path: - url = parsed_url._replace(path="/").geturl() - elif isinstance(url, tuple): - url = self.build_url(url) - elif isregex(url): - url = url if base is None else re.compile(urljoin(base, url.pattern)) - else: - raise ValueError( - "Request url pattern must be str or compiled regex, got {}.".format( - type(url).__name__ - ) - ) - self._url = url - - url = property(get_url, set_url) - - def build_url(self, parts: URL) -> str: - scheme, host, port, full_path = parts - port_str = ( - "" - if not port or port == {b"https": 443, b"http": 80}[scheme] - else f":{port}" - ) - return f"{scheme.decode()}://{host.decode()}{port_str}{full_path.decode()}" - def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]: """ Matches request with configured pattern; @@ -383,13 +374,13 @@ def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]: if self.method != request_method.decode(): return None - request_url = self.build_url(_request_url) - if not self._url: + if not self.url: matches = True - elif isinstance(self._url, str): - matches = self._url == request_url + elif isinstance(self.url, httpx.URL): + matches = self.url.raw == _request_url else: - match = self._url.match(request_url) + request_url = build_url(_request_url) + match = self.url.match(str(request_url)) if match: matches = True url_params = match.groupdict() diff --git a/respx/transports.py b/respx/transports.py index 5e85fca..41e202f 100644 --- a/respx/transports.py +++ b/respx/transports.py @@ -16,6 +16,7 @@ DefaultType, Headers, HeaderTypes, + QueryParamTypes, Request, RequestPattern, ResponseTemplate, @@ -91,6 +92,7 @@ def add( method: Union[str, Callable, RequestPattern], url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -111,7 +113,8 @@ def add( pattern = RequestPattern( method, url, - response, + params=params, + response=response, pass_through=pass_through, alias=alias, base_url=self._base_url, @@ -127,6 +130,7 @@ def get( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -137,6 +141,7 @@ def get( return self.add( "GET", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -149,6 +154,7 @@ def post( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -159,6 +165,7 @@ def post( return self.add( "POST", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -171,6 +178,7 @@ def put( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -181,6 +189,7 @@ def put( return self.add( "PUT", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -193,6 +202,7 @@ def patch( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -203,6 +213,7 @@ def patch( return self.add( "PATCH", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -215,6 +226,7 @@ def delete( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -225,6 +237,7 @@ def delete( return self.add( "DELETE", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -237,6 +250,7 @@ def head( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -247,6 +261,7 @@ def head( return self.add( "HEAD", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -259,6 +274,7 @@ def options( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -269,6 +285,7 @@ def options( return self.add( "OPTIONS", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, diff --git a/tests/test_api.py b/tests/test_api.py index 03d25bf..f4df2c9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -68,12 +68,13 @@ async def test_http_methods(client): @pytest.mark.parametrize( "url,pattern", [ - ("https://foo.bar", "https://foo.bar"), + ("https://foo.bar", "https://foo.bar/"), ("https://foo.bar/baz/", None), ("https://foo.bar/baz/", ""), ("https://foo.bar/baz/", "https://foo.bar/baz/"), ("https://foo.bar/baz/", re.compile(r"^https://foo.bar/\w+/$")), - ("https://foo.bar/baz/", (b"https", b"foo.bar", 443, b"/baz/")), + ("https://foo.bar/baz/", (b"https", b"foo.bar", None, b"/baz/")), + ("https://foo.bar:443/baz/", (b"https", b"foo.bar", 443, b"/baz/")), ], ) async def test_url_match(client, url, pattern): @@ -446,3 +447,54 @@ def test_pop(): respx.get("https://foo.bar/", alias="foobar") request_pattern = respx.pop("foobar") assert request_pattern.url == "https://foo.bar/" + + +@respx.mock +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url,params,call_url,call_params", + [ + ( + "https://foo/", + "foo=bar&foo=foo", + "https://foo/", + [("foo", "bar"), ("foo", "foo")], + ), + ( + "https://foo/", + b"foo=bar&foo=foo", + "https://foo/", + [("foo", "bar"), ("foo", "foo")], + ), + ( + "https://foo/", + [("foo", "bar"), ("foo", "foo")], + "https://foo/?foo=bar&foo=foo", + None, + ), + ("https://foo/", {"foo": "bar"}, "https://foo/?foo=bar", None), + ("https://foo/", {"foo": "bar"}, "https://foo/", {"foo": "bar"}), + ( + "https://foo/?baz=buz", + {"foo": "bar"}, + "https://foo/?baz=buz", + {"foo": "bar"}, + ), + ( + re.compile(r"https://foo/(?P\w+)/"), + {"foo": "bar"}, + "https://foo/baz/?foo=bar", + None, + ), + ( + re.compile(r"https://foo/(?P\w+)/"), + {"foo": "bar"}, + "https://foo/baz/", + {"foo": "bars"}, + ), + ], +) +async def test_params(client, url, params, call_url, call_params): + respx.get(url, params=params, content="spam spam") + response = await client.get(call_url, params=call_params) + assert response.text == "spam spam" diff --git a/tests/test_transports.py b/tests/test_transports.py index c96669f..3b34ed8 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -61,22 +61,23 @@ async def test_transport_assertions(): @pytest.mark.asyncio async def test_httpcore_request(): async with MockTransport() as transport: - transport.add("GET", "https://foo.bar/", content="foobar") - with httpcore.SyncConnectionPool() as http: - (status_code, headers, stream, ext) = http.request( - method=b"GET", url=(b"https", b"foo.bar", 443, b"/") - ) - - body = b"".join([chunk for chunk in stream]) - assert body == b"foobar" - - async with httpcore.AsyncConnectionPool() as http: - (status_code, headers, stream, ext) = await http.arequest( - method=b"GET", url=(b"https", b"foo.bar", 443, b"/") - ) - - body = b"".join([chunk async for chunk in stream]) - assert body == b"foobar" + for url, port in [("https://foo.bar/", None), ("https://foo.bar:443/", 443)]: + transport.add("GET", url, content="foobar") + with httpcore.SyncConnectionPool() as http: + (status_code, headers, stream, ext) = http.request( + method=b"GET", url=(b"https", b"foo.bar", port, b"/") + ) + + body = b"".join([chunk for chunk in stream]) + assert body == b"foobar" + + async with httpcore.AsyncConnectionPool() as http: + (status_code, headers, stream, ext) = await http.arequest( + method=b"GET", url=(b"https", b"foo.bar", port, b"/") + ) + + body = b"".join([chunk async for chunk in stream]) + assert body == b"foobar" @pytest.mark.asyncio