diff --git a/HISTORY.rst b/HISTORY.rst index 09ec4253..2236a97d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,11 @@ History ======= +* Switch from ``urlparse()`` to ``urlsplit()`` for URL parsing, reducing the middleware runtime up to 5%. + This changes the type passed to ``origin_found_in_white_lists()``, so if you have subclassed the middleware to override this method, you should check it is compatible (it most likely is). + + Thanks to Thibaut Decombe in `PR #793 <https://github.com/adamchainz/django-cors-headers/pull/793>`__. + 3.13.0 (2022-06-05) ------------------- diff --git a/src/corsheaders/checks.py b/src/corsheaders/checks.py index 45d4461b..e56e1086 100644 --- a/src/corsheaders/checks.py +++ b/src/corsheaders/checks.py @@ -3,7 +3,7 @@ import re from collections.abc import Sequence from typing import Any -from urllib.parse import urlparse +from urllib.parse import urlsplit from django.apps import AppConfig from django.conf import settings @@ -87,7 +87,7 @@ def check_settings(app_configs: list[AppConfig], **kwargs: Any) -> list[CheckMes for origin in conf.CORS_ALLOWED_ORIGINS: if origin in special_origin_values: continue - parsed = urlparse(origin) + parsed = urlsplit(origin) if parsed.scheme == "" or parsed.netloc == "": errors.append( Error( @@ -104,7 +104,7 @@ def check_settings(app_configs: list[AppConfig], **kwargs: Any) -> list[CheckMes else: # Only do this check in this case because if the scheme is not # provided, netloc ends up in path - for part in ("path", "params", "query", "fragment"): + for part in ("path", "query", "fragment"): if getattr(parsed, part) != "": errors.append( Error( diff --git a/src/corsheaders/middleware.py b/src/corsheaders/middleware.py index 381f8875..226a4ba9 100644 --- a/src/corsheaders/middleware.py +++ b/src/corsheaders/middleware.py @@ -2,7 +2,7 @@ import re from typing import Any -from urllib.parse import ParseResult, urlparse +from urllib.parse import SplitResult, urlsplit from django.http import HttpRequest, HttpResponse from django.utils.cache import 
patch_vary_headers @@ -61,7 +61,7 @@ def _https_referer_replace(self, request: HttpRequest) -> None: and "ORIGINAL_HTTP_REFERER" not in request.META ): - url = urlparse(origin) + url = urlsplit(origin) if ( not conf.CORS_ALLOW_ALL_ORIGINS and not self.origin_found_in_white_lists(origin, url) @@ -137,7 +137,7 @@ def process_response( return response try: - url = urlparse(origin) + url = urlsplit(origin) except ValueError: return response @@ -169,7 +169,7 @@ def process_response( return response - def origin_found_in_white_lists(self, origin: str, url: ParseResult) -> bool: + def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool: return ( (origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS) or self._url_in_whitelist(url) @@ -191,8 +191,8 @@ def check_signal(self, request: HttpRequest) -> bool: signal_responses = check_request_enabled.send(sender=None, request=request) return any(return_value for function, return_value in signal_responses) - def _url_in_whitelist(self, url: ParseResult) -> bool: - origins = [urlparse(o) for o in conf.CORS_ALLOWED_ORIGINS] + def _url_in_whitelist(self, url: SplitResult) -> bool: + origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS] return any( origin.scheme == url.scheme and origin.netloc == url.netloc for origin in origins