From dfc902726a9438231facb05d597151575e34695e Mon Sep 17 00:00:00 2001 From: Joakim Saario Date: Fri, 25 Sep 2020 01:57:45 +0200 Subject: [PATCH] Add support for `params` parameter --- respx/api.py | 24 +++++++++++++++++++++++- respx/models.py | 19 +++++++++++++++++-- respx/transports.py | 17 +++++++++++++++++ tests/test_api.py | 18 ++++++++++++++++++ 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/respx/api.py b/respx/api.py index 86a9c60..325a3f7 100644 --- a/respx/api.py +++ b/respx/api.py @@ -1,7 +1,13 @@ from typing import Callable, Optional, Pattern, Union, overload from .mocks import MockTransport -from .models import ContentDataTypes, DefaultType, HeaderTypes, RequestPattern +from .models import ( + ContentDataTypes, + DefaultType, + HeaderTypes, + QueryParamTypes, + RequestPattern, +) mock = MockTransport(assert_all_called=False) @@ -49,6 +55,7 @@ def add( method: Union[str, Callable], url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -60,6 +67,7 @@ def add( return mock.add( method, url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -72,6 +80,7 @@ def add( def get( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -82,6 +91,7 @@ def get( global mock return mock.get( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -94,6 +104,7 @@ def get( def post( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -104,6 +115,7 @@ def post( global mock return mock.post( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -116,6 +128,7 @@ def post( def put( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -126,6 +139,7 @@ def put( global mock return mock.put( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -138,6 +152,7 @@ def put( def patch( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -148,6 +163,7 @@ def patch( global mock return mock.patch( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -160,6 +176,7 @@ def patch( def delete( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -170,6 +187,7 @@ def delete( global mock return mock.delete( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -182,6 +200,7 @@ def delete( def head( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -192,6 +211,7 @@ def head( global mock return mock.head( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -204,6 +224,7 @@ def head( def options( url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -214,6 +235,7 @@ def options( global mock return mock.options( url=url, + params=params, status_code=status_code, content=content, content_type=content_type, diff --git a/respx/models.py b/respx/models.py index 4e3301a..488c590 100644 --- a/respx/models.py +++ b/respx/models.py @@ -63,6 +63,7 @@ URLPatternTypes = Union[str, Pattern[str], URL] JSONTypes = Union[str, List, Dict] ContentDataTypes = Union[bytes, str, JSONTypes, Callable, Exception] +QueryParamTypes = Union[bytes, str, List[Tuple[str, Any]], Dict[str, Any]] istype = lambda t, o: isinstance(o, t) isregex = partial(istype, Regex) @@ -262,6 +263,7 @@ def __init__( self, method: Union[str, Callable], url: Optional[URLPatternTypes], + params: Optional[QueryParamTypes] = None, response: Optional[ResponseTemplate] = None, pass_through: bool = False, alias: Optional[str] = None, @@ -276,7 +278,7 @@ def __init__( self._match_func = method else: self.method = method.upper() - self.set_url(url, base=base_url) + self.set_url(url, base=base_url, params=params) self.pass_through = pass_through self.response = response or ResponseTemplate() @@ -301,26 +303,39 @@ def get_url(self) -> Optional[URLPatternTypes]: return self._url def set_url( - self, url: Optional[URLPatternTypes], base: Optional[str] = None + self, + url: Optional[URLPatternTypes], + base: Optional[str] = None, + params: Optional[QueryParamTypes] = None, ) -> None: + params = str(httpx.QueryParams(params)) url = url or None if url is None: url = base + if base is not None and params: + url = url + f"?{params}" elif isinstance(url, str): url = url if base is None else urljoin(base, url) parsed_url = urlparse(url) if not parsed_url.path: url = parsed_url._replace(path="/").geturl() + if url and params: + url = url + f"?{params}" elif isinstance(url, tuple): url = self.build_url(url) + if url and params: + url = url + f"?{params}" elif isregex(url): url = url if base is None else re.compile(urljoin(base, url.pattern)) + if params: + url = re.compile(url.pattern + re.escape(f"?{params}")) else: raise ValueError( "Request url pattern must be str or compiled regex, got {}.".format( type(url).__name__ ) ) + self._url = url url = property(get_url, set_url) diff --git a/respx/transports.py b/respx/transports.py index 6e55660..27e6d20 100644 --- a/respx/transports.py +++ b/respx/transports.py @@ -15,6 +15,7 @@ DefaultType, Headers, HeaderTypes, + QueryParamTypes, Request, RequestPattern, Response, @@ -91,6 +92,7 @@ def add( method: Union[str, Callable, RequestPattern], url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -111,6 +113,7 @@ def add( pattern = RequestPattern( method, url, + params=params, response, pass_through=pass_through, alias=alias, @@ -127,6 +130,7 @@ def get( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -137,6 +141,7 @@ def get( return self.add( "GET", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -149,6 +154,7 @@ def post( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -159,6 +165,7 @@ def post( return self.add( "POST", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -171,6 +178,7 @@ def put( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -181,6 +189,7 @@ def put( return self.add( "PUT", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -193,6 +202,7 @@ def patch( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -203,6 +213,7 @@ def patch( return self.add( "PATCH", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -215,6 +226,7 @@ def delete( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -225,6 +237,7 @@ def delete( return self.add( "DELETE", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -237,6 +250,7 @@ def head( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -247,6 +261,7 @@ def head( return self.add( "HEAD", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, @@ -259,6 +274,7 @@ def options( self, url: Optional[Union[str, Pattern]] = None, *, + params: Optional[QueryParamTypes] = None, status_code: Optional[int] = None, content: Optional[ContentDataTypes] = None, content_type: Optional[str] = None, @@ -269,6 +285,7 @@ def options( return self.add( "OPTIONS", url=url, + params=params, status_code=status_code, content=content, content_type=content_type, diff --git a/tests/test_api.py b/tests/test_api.py index 03d25bf..3cb21a1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -446,3 +446,21 @@ def test_pop(): respx.get("https://foo.bar/", alias="foobar") request_pattern = respx.pop("foobar") assert request_pattern.url == "https://foo.bar/" + + +@respx.mock +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url,params,call_url", + [ + ("https://foo/", "foo=bar&foo=foo", "https://foo/"), + ("https://foo/", b"foo=bar&foo=foo", "https://foo/"), + ("https://foo/", [("foo", "bar"), ("foo", "foo")], "https://foo/"), + ("https://foo/", {"foo": "bar"}, "https://foo/"), + (re.compile(r"https://foo/(?P\w+)/"), {"foo": "bar"}, "https://foo/baz/"), + ], +) +async def test_params(client, url, params, call_url): + respx.get(url, params=params, content="spam spam") + response = await client.get(call_url, params=params) + assert response.text == "spam spam"