diff --git a/.gitignore b/.gitignore index 46955f612..36f36703b 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ test_uwsgi_failed venv/ .vscode .mypy_cache/ +.dmypy.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7297d8adc..457500c6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.1.0 + rev: v2.7.4 hooks: - id: pyupgrade args: ["--py36-plus"] - repo: https://github.com/asottile/reorder_python_imports - rev: v2.1.0 + rev: v2.3.6 hooks: - id: reorder-python-imports name: Reorder Python imports (src, tests) @@ -16,18 +16,18 @@ repos: files: "^examples/" args: ["--application-directories", "examples"] - repo: https://github.com/ambv/black - rev: 19.10b0 + rev: 20.8b1 hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.7.9 + rev: 3.8.4 hooks: - id: flake8 additional_dependencies: - flake8-bugbear - flake8-implicit-str-concat - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 + rev: v3.3.0 hooks: - id: check-byte-order-marker - id: trailing-whitespace diff --git a/MANIFEST.in b/MANIFEST.in index 4cd9e1967..86d1bee1f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,7 @@ include CHANGES.rst include tox.ini include requirements/*.txt include src/werkzeug/py.typed +include src/werkzeug/*.pyi graft artwork graft docs prune docs/_build diff --git a/requirements/dev.txt b/requirements/dev.txt index b46cb015d..851f160a0 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -4,63 +4,145 @@ # # pip-compile requirements/dev.in # -alabaster==0.7.12 # via sphinx -appdirs==1.4.4 # via virtualenv -attrs==19.3.0 # via pytest -babel==2.8.0 # via sphinx -certifi==2020.4.5.1 # via requests -cffi==1.14.0 # via cryptography -cfgv==3.1.0 # via pre-commit -chardet==3.0.4 # via requests -click==7.1.2 # via pip-tools -cryptography==3.3.1 # via -r requirements/tests.in -distlib==0.3.0 # via virtualenv -docutils==0.16 # via sphinx -filelock==3.0.12 # via tox, virtualenv -greenlet==0.4.17 # via -r requirements/tests.in -identify==1.4.15 # via pre-commit -idna==2.9 # via requests -imagesize==1.2.0 # via sphinx -iniconfig==1.0.0 # via pytest -jinja2==2.11.2 # via sphinx -markupsafe==1.1.1 # via jinja2 -mypy-extensions==0.4.3 # via mypy -mypy==0.782 # via -r requirements/typing.in -nodeenv==1.3.5 # via pre-commit -packaging==20.3 # via pallets-sphinx-themes, pytest, sphinx, tox -pallets-sphinx-themes==1.2.3 # via -r requirements/docs.in -pip-tools==5.5.0 # via -r requirements/dev.in -pluggy==0.13.1 # via pytest, tox -pre-commit==2.9.3 # via -r requirements/dev.in -psutil==5.7.0 # via pytest-xprocess -py==1.9.0 # via pytest, tox -pycparser==2.20 # via cffi -pygments==2.6.1 # via sphinx -pyparsing==2.4.7 # via packaging -pytest-timeout==1.4.2 # via -r requirements/tests.in -pytest-xprocess==0.17.0 # via -r requirements/tests.in -pytest==6.1.2 # via -r requirements/tests.in, pytest-timeout, pytest-xprocess -pytz==2020.1 # via babel -pyyaml==5.3.1 # via pre-commit -requests==2.24.0 # via sphinx -six==1.14.0 # via cryptography, packaging, tox, virtualenv -snowballstemmer==2.0.0 # via sphinx -sphinx-issues==1.2.0 # via -r requirements/docs.in -sphinx==3.4.1 # via -r requirements/docs.in, pallets-sphinx-themes, sphinx-issues, sphinxcontrib-log-cabinet -sphinxcontrib-applehelp==1.0.2 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==1.0.3 # via sphinx -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-log-cabinet==1.0.1 # via -r requirements/docs.in -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.4 # via sphinx -toml==0.10.1 # via pre-commit, pytest, tox -tox==3.20.1 # via -r requirements/dev.in -typed-ast==1.4.1 # via mypy -typing-extensions==3.7.4.3 # via mypy -urllib3==1.25.9 # via requests -virtualenv==20.0.20 # via pre-commit, tox -watchdog==1.0.2 # via -r requirements/tests.in +alabaster==0.7.12 + # via sphinx +appdirs==1.4.4 + # via virtualenv +attrs==20.3.0 + # via pytest +babel==2.9.0 + # via sphinx +certifi==2020.12.5 + # via requests +cffi==1.14.4 + # via cryptography +cfgv==3.2.0 + # via pre-commit +chardet==4.0.0 + # via requests +click==7.1.2 + # via pip-tools +cryptography==3.3.1 + # via -r requirements/tests.in +distlib==0.3.1 + # via virtualenv +docutils==0.16 + # via sphinx +filelock==3.0.12 + # via + # tox + # virtualenv +greenlet==0.4.17 + # via -r requirements/tests.in +identify==1.5.12 + # via pre-commit +idna==2.10 + # via requests +imagesize==1.2.0 + # via sphinx +iniconfig==1.1.1 + # via pytest +jinja2==2.11.2 + # via sphinx +markupsafe==1.1.1 + # via jinja2 +mypy-extensions==0.4.3 + # via mypy +mypy==0.790 + # via -r requirements/typing.in +nodeenv==1.5.0 + # via pre-commit +packaging==20.8 + # via + # pallets-sphinx-themes + # pytest + # sphinx + # tox +pallets-sphinx-themes==1.2.3 + # via -r requirements/docs.in +pip-tools==5.5.0 + # via -r requirements/dev.in +pluggy==0.13.1 + # via + # pytest + # tox +pre-commit==2.9.3 + # via -r requirements/dev.in +psutil==5.8.0 + # via pytest-xprocess +py==1.10.0 + # via + # pytest + # tox +pycparser==2.20 + # via cffi +pygments==2.7.3 + # via sphinx +pyparsing==2.4.7 + # via packaging +pytest-timeout==1.4.2 + # via -r requirements/tests.in +pytest-xprocess==0.17.0 + # via -r requirements/tests.in +pytest==6.1.2 + # via + # -r requirements/tests.in + # pytest-timeout + # pytest-xprocess +pytz==2020.5 + # via babel +pyyaml==5.3.1 + # via pre-commit +requests==2.25.1 + # via sphinx +six==1.15.0 + # via + # cryptography + # tox + # virtualenv +snowballstemmer==2.0.0 + # via sphinx +sphinx-issues==1.2.0 + # via -r requirements/docs.in +sphinx==3.4.3 + # via + # -r requirements/docs.in + # pallets-sphinx-themes + # sphinx-issues + # sphinxcontrib-log-cabinet +sphinxcontrib-applehelp==1.0.2 + # via sphinx +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==1.0.3 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-log-cabinet==1.0.1 + # via -r requirements/docs.in +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.4 + # via sphinx +toml==0.10.2 + # via + # pre-commit + # pytest + # tox +tox==3.21.0 + # via -r requirements/dev.in +typed-ast==1.4.2 + # via mypy +typing-extensions==3.7.4.3 + # via mypy +urllib3==1.26.2 + # via requests +virtualenv==20.3.0 + # via + # pre-commit + # tox +watchdog==1.0.2 + # via -r requirements/tests.in # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/requirements/docs.txt b/requirements/docs.txt index bd1d7d5a3..eb188a7ee 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -4,33 +4,64 @@ # # pip-compile requirements/docs.in # -alabaster==0.7.12 # via sphinx -babel==2.8.0 # via sphinx -certifi==2020.4.5.1 # via requests -chardet==3.0.4 # via requests -docutils==0.16 # via sphinx -idna==2.9 # via requests -imagesize==1.2.0 # via sphinx -jinja2==2.11.2 # via sphinx -markupsafe==1.1.1 # via jinja2 -packaging==20.3 # via pallets-sphinx-themes, sphinx -pallets-sphinx-themes==1.2.3 # via -r requirements/docs.in -pygments==2.6.1 # via sphinx -pyparsing==2.4.7 # via packaging -pytz==2020.1 # via babel -requests==2.24.0 # via sphinx -six==1.14.0 # via packaging -snowballstemmer==2.0.0 # via sphinx -sphinx-issues==1.2.0 # via -r requirements/docs.in -sphinx==3.4.1 # via -r requirements/docs.in, pallets-sphinx-themes, sphinx-issues, sphinxcontrib-log-cabinet -sphinxcontrib-applehelp==1.0.2 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==1.0.3 # via sphinx -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-log-cabinet==1.0.1 # via -r requirements/docs.in -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.4 # via sphinx -urllib3==1.25.9 # via requests +alabaster==0.7.12 + # via sphinx +babel==2.9.0 + # via sphinx +certifi==2020.12.5 + # via requests +chardet==4.0.0 + # via requests +docutils==0.16 + # via sphinx +idna==2.10 + # via requests +imagesize==1.2.0 + # via sphinx +jinja2==2.11.2 + # via sphinx +markupsafe==1.1.1 + # via jinja2 +packaging==20.8 + # via + # pallets-sphinx-themes + # sphinx +pallets-sphinx-themes==1.2.3 + # via -r requirements/docs.in +pygments==2.7.3 + # via sphinx +pyparsing==2.4.7 + # via packaging +pytz==2020.5 + # via babel +requests==2.25.1 + # via sphinx +snowballstemmer==2.0.0 + # via sphinx +sphinx-issues==1.2.0 + # via -r requirements/docs.in +sphinx==3.4.3 + # via + # -r requirements/docs.in + # pallets-sphinx-themes + # sphinx-issues + # sphinxcontrib-log-cabinet +sphinxcontrib-applehelp==1.0.2 + # via sphinx +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==1.0.3 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-log-cabinet==1.0.1 + # via -r requirements/docs.in +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.4 + # via sphinx +urllib3==1.26.2 + # via requests # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements/tests.txt b/requirements/tests.txt index a7d5a9af2..727e7b77d 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -4,20 +4,40 @@ # # pip-compile requirements/tests.in # -attrs==19.3.0 # via pytest -cffi==1.14.0 # via cryptography -cryptography==3.3.1 # via -r requirements/tests.in -greenlet==0.4.17 # via -r requirements/tests.in -iniconfig==1.0.0 # via pytest -packaging==20.3 # via pytest -pluggy==0.13.1 # via pytest -psutil==5.7.0 # via pytest-xprocess -py==1.9.0 # via pytest -pycparser==2.20 # via cffi -pyparsing==2.4.7 # via packaging -pytest-timeout==1.4.2 # via -r requirements/tests.in -pytest-xprocess==0.17.0 # via -r requirements/tests.in -pytest==6.1.2 # via -r requirements/tests.in, pytest-timeout, pytest-xprocess -six==1.14.0 # via cryptography, packaging -toml==0.10.1 # via pytest -watchdog==1.0.2 # via -r requirements/tests.in +attrs==20.3.0 + # via pytest +cffi==1.14.4 + # via cryptography +cryptography==3.3.1 + # via -r requirements/tests.in +greenlet==0.4.17 + # via -r requirements/tests.in +iniconfig==1.1.1 + # via pytest +packaging==20.8 + # via pytest +pluggy==0.13.1 + # via pytest +psutil==5.8.0 + # via pytest-xprocess +py==1.10.0 + # via pytest +pycparser==2.20 + # via cffi +pyparsing==2.4.7 + # via packaging +pytest-timeout==1.4.2 + # via -r requirements/tests.in +pytest-xprocess==0.17.0 + # via -r requirements/tests.in +pytest==6.1.2 + # via + # -r requirements/tests.in + # pytest-timeout + # pytest-xprocess +six==1.15.0 + # via cryptography +toml==0.10.2 + # via pytest +watchdog==1.0.2 + # via -r requirements/tests.in diff --git a/requirements/typing.in b/requirements/typing.in index f0aa93ac8..d76706eae 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1 +1,2 @@ mypy +pytest diff --git a/requirements/typing.txt b/requirements/typing.txt index a169bd4c3..2ce00e002 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -4,7 +4,27 @@ # # pip-compile requirements/typing.in # -mypy-extensions==0.4.3 # via mypy -mypy==0.782 # via -r requirements/typing.in -typed-ast==1.4.1 # via mypy -typing-extensions==3.7.4.3 # via mypy +attrs==20.3.0 + # via pytest +iniconfig==1.1.1 + # via pytest +mypy-extensions==0.4.3 + # via mypy +mypy==0.790 + # via -r requirements/typing.in +packaging==20.8 + # via pytest +pluggy==0.13.1 + # via pytest +py==1.10.0 + # via pytest +pyparsing==2.4.7 + # via packaging +pytest==6.2.1 + # via -r requirements/typing.in +toml==0.10.2 + # via pytest +typed-ast==1.4.2 + # via mypy +typing-extensions==3.7.4.3 + # via mypy diff --git a/setup.cfg b/setup.cfg index bf0f3f2ee..d770a98b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,39 +83,25 @@ per-file-ignores = [mypy] allow_redefinition = True disallow_subclassing_any = True +# warn_return_any = True # disallow_untyped_defs = True +# disallow_incomplete_defs = True +no_implicit_optional = True +local_partial_types = True +# no_implicit_reexport = True strict_equality = True -strict_optional = False warn_redundant_casts = True warn_unused_configs = True warn_unused_ignores = True -[mypy-_pytest.*] -ignore_missing_imports = True - -[mypy-bar.*] -ignore_missing_imports = True - [mypy-cryptography.*] ignore_missing_imports = True [mypy-greenlet.*] ignore_missing_imports = True -[mypy-pytest.*] -ignore_missing_imports = True - -[mypy-requests_unixsocket.*] -ignore_missing_imports = True - -[mypy-testsuite_app.*] -ignore_missing_imports = True - [mypy-watchdog.*] ignore_missing_imports = True -[mypy-winreg.*] -ignore_missing_imports = True - [mypy-xprocess.*] ignore_missing_imports = True diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 8fccfaa11..574a0fc92 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -4,39 +4,28 @@ import re import string import sys +import typing +import typing as t from datetime import date from datetime import datetime from itertools import chain -from typing import Any -from typing import AnyStr -from typing import Callable -from typing import Dict -from typing import Iterator -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union +from time import struct_time from weakref import WeakKeyDictionary -if TYPE_CHECKING: - from werkzeug.wrappers.base_request import BaseRequest # noqa: F401 - from werkzeug.wrappers.request import Request # noqa: F401 - from werkzeug.wrappers.response import Response # noqa: F401 +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment + from .wrappers.base_request import BaseRequest # noqa: F401 - -_logger = None +_logger: t.Optional[logging.Logger] = None _signature_cache = WeakKeyDictionary() # type: ignore _epoch_ord = date(1970, 1, 1).toordinal() -_legal_cookie_chars = f"{string.ascii_letters}{string.digits}/=!#$%&'*+-.^_`|~:".encode( - "ascii" +_legal_cookie_chars = frozenset( + c.encode("ascii") + for c in f"{string.ascii_letters}{string.digits}/=!#$%&'*+-.^_`|~:" ) -_cookie_quoting_map = { - b",": b"\\054", - b";": b"\\073", - b'"': b'\\"', - b"\\": b"\\\\", -} +_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"} for _i in chain(range(32), range(127, 256)): _cookie_quoting_map[_i.to_bytes(1, sys.byteorder)] = f"\\{_i:03o}".encode("latin1") @@ -69,7 +58,17 @@ def __reduce__(self): _missing = _Missing() -def _make_encode_wrapper(reference: Optional[AnyStr],) -> Callable[[str], AnyStr]: +@typing.overload +def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]: + ... + + +@typing.overload +def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: + ... + + +def _make_encode_wrapper(reference): """Create a function that will be called with a string argument. If the reference is bytes, values will be encoded to bytes. """ @@ -79,7 +78,7 @@ def _make_encode_wrapper(reference: Optional[AnyStr],) -> Callable[[str], AnyStr return operator.methodcaller("encode", "latin1") -def _check_str_tuple(value: Tuple[AnyStr, ...]) -> None: +def _check_str_tuple(value: t.Tuple[t.AnyStr, ...]) -> None: """Ensure tuple items are all strings or all bytes.""" if not value: return @@ -90,10 +89,11 @@ def _check_str_tuple(value: Tuple[AnyStr, ...]) -> None: raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})") +_default_encoding = sys.getdefaultencoding() + + def _to_bytes( - x: Union[str, bytes], - charset: str = sys.getdefaultencoding(), # noqa: B008 - errors: str = "strict", + x: t.Union[str, bytes], charset: str = _default_encoding, errors: str = "strict" ) -> bytes: if x is None or isinstance(x, bytes): return x @@ -107,20 +107,36 @@ def _to_bytes( raise TypeError("Expected bytes") +@typing.overload +def _to_str( # type: ignore + x: None, + charset: t.Optional[str] = ..., + errors: str = ..., + allow_none_charset: bool = ..., +) -> None: + ... + + +@typing.overload def _to_str( - x: Optional[Union[str, int, bytes]], - charset: Optional[str] = sys.getdefaultencoding(), # noqa: B008 - errors: str = "strict", - allow_none_charset: bool = False, -) -> Optional[str]: + x: t.Any, + charset: t.Optional[str] = ..., + errors: str = ..., + allow_none_charset: bool = ..., +) -> str: + ... + + +def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=False): if x is None or isinstance(x, str): return x if not isinstance(x, bytes): return str(x) - if charset is None and allow_none_charset: - return x # type: ignore + if charset is None: + if allow_none_charset: + return x return x.decode(charset, errors) @@ -132,14 +148,15 @@ def _wsgi_decoding_dance( def _wsgi_encoding_dance( - s: Union[str, bytes], charset: str = "utf-8", errors: str = "replace" + s: str, charset: str = "utf-8", errors: str = "replace" ) -> str: - if isinstance(s, str): - s = s.encode(charset) - return s.decode("latin1", errors) + if isinstance(s, bytes): + return s.decode("latin1", errors) + return s.encode(charset).decode("latin1", errors) -def _get_environ(obj: Any) -> Dict[str, Any]: + +def _get_environ(obj: t.Union["WSGIEnvironment", "BaseRequest"]) -> "WSGIEnvironment": env = getattr(obj, "environ", obj) assert isinstance( env, dict @@ -147,7 +164,7 @@ def _get_environ(obj: Any) -> Dict[str, Any]: return env -def _has_level_handler(logger): +def _has_level_handler(logger: logging.Logger) -> bool: """Check if there is a handler in the logging chain that will handle the given logger's effective level. """ @@ -161,12 +178,12 @@ def _has_level_handler(logger): if not current.propagate: break - current = current.parent + current = current.parent # type: ignore return False -def _log(type, message, *args, **kwargs): +def _log(type: str, message: str, *args, **kwargs) -> None: """Log a message to the 'werkzeug' logger. The logger is created the first time it is needed. If there is no @@ -264,7 +281,7 @@ def parse(args, kwargs): return parse -def _date_to_unix(arg: Union[datetime, tuple, int]) -> int: +def _date_to_unix(arg: t.Union[datetime, int, float, struct_time]) -> int: """Converts a timetuple, integer or datetime object into the seconds from epoch in utc. """ @@ -280,24 +297,22 @@ def _date_to_unix(arg: Union[datetime, tuple, int]) -> int: return seconds -class _DictAccessorProperty: +_TAccessorValue = t.TypeVar("_TAccessorValue") + + +class _DictAccessorProperty(t.Generic[_TAccessorValue]): """Baseclass for `environ_property` and `header_property`.""" - name: Any - default: Any - load_func: Any - dump_func: Any - __doc__: Any - read_only: Any = False + read_only = False def __init__( self, name: str, - default: Optional[Any] = None, - load_func: Optional[Any] = None, - dump_func: Optional[Any] = None, - read_only: Optional[Any] = None, - doc: Optional[Any] = None, + default: t.Optional[_TAccessorValue] = None, + load_func: t.Optional[t.Callable[[str], _TAccessorValue]] = None, + dump_func: t.Optional[t.Callable[[_TAccessorValue], str]] = None, + read_only: t.Optional[bool] = None, + doc: t.Optional[str] = None, ) -> None: self.name = name self.default = default @@ -307,37 +322,54 @@ def __init__( self.read_only = read_only self.__doc__ = doc + def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]: + raise NotImplementedError + + @typing.overload def __get__( - self, - obj: Union["Response", "Request", "BaseRequest"], - type: Optional[Any] = None, - ) -> Any: - if obj is None: + self, instance: None, owner: type + ) -> "_DictAccessorProperty[_TAccessorValue]": + ... + + @typing.overload + def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: + ... + + def __get__(self, instance, owner): + if instance is None: return self - storage = self.lookup(obj) # type: ignore + + storage = self.lookup(instance) + if self.name not in storage: return self.default - rv = storage[self.name] + + value = storage[self.name] + if self.load_func is not None: try: - rv = self.load_func(rv) + return self.load_func(value) except (ValueError, TypeError): - rv = self.default - return rv + return self.default + + return value - def __set__(self, obj: object, value: object) -> None: + def __set__(self, instance: t.Any, value: _TAccessorValue) -> None: if self.read_only: raise AttributeError("read only property") + if self.dump_func is not None: - value = self.dump_func(value) - self.lookup(obj)[self.name] = value # type: ignore + self.lookup(instance)[self.name] = self.dump_func(value) + else: + self.lookup(instance)[self.name] = value - def __delete__(self, obj): + def __delete__(self, instance: t.Any) -> None: if self.read_only: raise AttributeError("read only property") - self.lookup(obj).pop(self.name, None) - def __repr__(self): + self.lookup(instance).pop(self.name, None) + + def __repr__(self) -> str: return f"<{type(self).__name__} {self.name}>" @@ -395,7 +427,7 @@ def _cookie_unquote(b: bytes) -> bytes: return bytes(rv) -def _cookie_parse_impl(b: bytes) -> Iterator[Tuple[bytes, bytes]]: +def _cookie_parse_impl(b: bytes) -> t.Iterator[t.Tuple[bytes, bytes]]: """Lowlevel cookie parsing facility that operates on bytes.""" i = 0 n = len(b) @@ -414,7 +446,7 @@ def _cookie_parse_impl(b: bytes) -> Iterator[Tuple[bytes, bytes]]: def _encode_idna(domain: str) -> bytes: # If we're given bytes, make sure they fit into ASCII - if not isinstance(domain, str): + if isinstance(domain, bytes): domain.decode("ascii") return domain @@ -425,36 +457,42 @@ def _encode_idna(domain: str) -> bytes: pass # Otherwise encode each part separately - parts = domain.split(".") - for idx, part in enumerate(parts): - parts[idx] = part.encode("idna") # type: ignore - return b".".join(parts) # type: ignore + return b".".join(p.encode("idna") for p in domain.split(".")) -def _decode_idna(domain: Union[str, bytes]) -> Union[str, bytes]: - # If the input is a string try to encode it to ascii to - # do the idna decoding. if that fails because of an - # unicode error, then we already have a decoded idna domain +def _decode_idna(domain: t.Union[str, bytes]) -> str: + # If the input is a string try to encode it to ascii to do the idna + # decoding. If that fails because of a unicode error, then we + # already have a decoded idna domain. if isinstance(domain, str): try: domain = domain.encode("ascii") except UnicodeError: - return domain + return domain # type: ignore - # Decode each part separately. If a part fails, try to - # decode it with ascii and silently ignore errors. This makes - # most sense because the idna codec does not have error handling - parts = domain.split(b".") - for idx, part in enumerate(parts): + # Decode each part separately. If a part fails, try to decode it + # with ascii and silently ignore errors. This makes sense because + # the idna codec does not have error handling. + def decode_part(part: bytes) -> str: try: - parts[idx] = part.decode("idna") # type: ignore + return part.decode("idna") except UnicodeError: - parts[idx] = part.decode("ascii", "ignore") # type: ignore + return part.decode("ascii", "ignore") + + return ".".join(decode_part(p) for p in domain.split(b".")) + + +@typing.overload +def _make_cookie_domain(domain: None) -> None: + ... + - return ".".join(parts) # type: ignore +@typing.overload +def _make_cookie_domain(domain: str) -> bytes: + ... -def _make_cookie_domain(domain: Optional[str]) -> Optional[bytes]: +def _make_cookie_domain(domain): if domain is None: return None domain = _encode_idna(domain) @@ -471,7 +509,7 @@ def _make_cookie_domain(domain: Optional[str]) -> Optional[bytes]: ) -def _easteregg(app: Optional[Any] = None) -> Callable: +def _easteregg(app: t.Optional["WSGIApplication"] = None) -> "WSGIApplication": """Like the name says. But who knows how it works?""" def bzzzzzzz(gyver): diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 98a3fd673..ab0966b8a 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -23,7 +23,7 @@ del prefix -def _iter_module_paths(): +def _iter_module_paths() -> t.Iterator[str]: """Find the filesystem paths associated with imported modules.""" # List is in case the value is modified by the app while updating. for module in list(sys.modules.values()): @@ -48,7 +48,9 @@ def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None: paths.difference_update(fnmatch.filter(paths, pattern)) -def _find_stat_paths(extra_files, exclude_patterns): +def _find_stat_paths( + extra_files: t.Set[str], exclude_patterns: t.Set[str] +) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Returns imported module files, Python files under non-system paths. Extra files and Python files under extra directories can also be scanned. @@ -88,7 +90,9 @@ def _find_stat_paths(extra_files, exclude_patterns): return paths -def _find_watchdog_paths(extra_files, exclude_patterns): +def _find_watchdog_paths( + extra_files: t.Set[str], exclude_patterns: t.Set[str] +) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Looks at the same sources as the stat reloader, but watches everything under directories instead of individual files. @@ -110,8 +114,8 @@ def _find_watchdog_paths(extra_files, exclude_patterns): return _find_common_roots(dirs) -def _find_common_roots(paths): - root = {} +def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: + root: t.Dict[str, dict] = {} for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): node = root @@ -134,7 +138,7 @@ def _walk(node, path): return rv -def _get_args_for_reloading(): +def _get_args_for_reloading() -> t.List[str]: """Determine how the script was executed, and return the args needed to execute it again in a new process. """ @@ -179,7 +183,7 @@ def _get_args_for_reloading(): else: if os.path.isfile(py_script): # Rewritten by Python from "-m script" to "/path/to/script.py". - py_module = __main__.__package__ + py_module = t.cast(str, __main__.__package__) name = os.path.splitext(os.path.basename(py_script))[0] if name != "__main__": @@ -202,12 +206,12 @@ def __init__( extra_files: t.Optional[t.Iterable[str]] = None, exclude_patterns: t.Optional[t.Iterable[str]] = None, interval: t.Union[int, float] = 1, - ): - self.extra_files = {os.path.abspath(x) for x in extra_files or ()} - self.exclude_patterns = set(exclude_patterns or ()) + ) -> None: + self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()} + self.exclude_patterns: t.Set[str] = set(exclude_patterns or ()) self.interval = interval - def __enter__(self): + def __enter__(self) -> "ReloaderLoop": """Do any setup, then run one step of the watch to populate the initial filesystem state. """ @@ -218,7 +222,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): """Clean up any resources associated with the reloader.""" pass - def run(self): + def run(self) -> None: """Continually run the watch step, sleeping for the configured interval after each step. """ @@ -226,13 +230,13 @@ def run(self): self.run_step() time.sleep(self.interval) - def run_step(self): + def run_step(self) -> None: """Run one step for watching the filesystem. Called once to set up initial state, then repeatedly to update it. """ pass - def restart_with_reloader(self): + def restart_with_reloader(self) -> int: """Spawn a new Python interpreter with the same arguments as the current one, but running the reloader thread. """ @@ -246,11 +250,11 @@ def restart_with_reloader(self): if exit_code != 3: return exit_code - def trigger_reload(self, filename): + def trigger_reload(self, filename: str) -> None: self.log_reload(filename) sys.exit(3) - def log_reload(self, filename): + def log_reload(self, filename: str) -> None: filename = os.path.abspath(filename) _log("info", f" * Detected change in {filename!r}, reloading") @@ -258,11 +262,11 @@ def log_reload(self, filename): class StatReloaderLoop(ReloaderLoop): name = "stat" - def __enter__(self): - self.mtimes = {} + def __enter__(self) -> ReloaderLoop: + self.mtimes: t.Dict[str, float] = {} return super().__enter__() - def run_step(self): + def run_step(self) -> None: for name in chain(_find_stat_paths(self.extra_files, self.exclude_patterns)): try: mtime = os.stat(name).st_mtime @@ -280,14 +284,14 @@ def run_step(self): class WatchdogReloaderLoop(ReloaderLoop): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: from watchdog.observers import Observer from watchdog.events import PatternMatchingEventHandler super().__init__(*args, **kwargs) trigger_reload = self.trigger_reload - class EventHandler(PatternMatchingEventHandler): + class EventHandler(PatternMatchingEventHandler): # type: ignore def on_any_event(self, event): trigger_reload(event.src_path) @@ -315,15 +319,15 @@ def on_any_event(self, event): ) self.should_reload = False - def trigger_reload(self, filename): + def trigger_reload(self, filename: str) -> None: # This is called inside an event handler, which means throwing # SystemExit has no effect. # https://github.com/gorakhargosh/watchdog/issues/294 self.should_reload = True self.log_reload(filename) - def __enter__(self): - self.watches = {} + def __enter__(self) -> ReloaderLoop: + self.watches: t.Dict[str, t.Any] = {} self.observer.start() return super().__enter__() @@ -331,14 +335,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.observer.stop() self.observer.join() - def run(self): + def run(self) -> None: while not self.should_reload: self.run_step() time.sleep(self.interval) sys.exit(3) - def run_step(self): + def run_step(self) -> None: to_delete = set(self.watches) for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns): @@ -395,7 +399,7 @@ def ensure_echo_on(): def run_with_reloader( - main_func, + main_func: t.Callable[[], None], extra_files: t.Optional[t.Iterable[str]] = None, exclude_patterns: t.Optional[t.Iterable[str]] = None, interval: t.Union[int, float] = 1, diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py index 3e98623e9..8ff359744 100644 --- a/src/werkzeug/datastructures.py +++ b/src/werkzeug/datastructures.py @@ -2,51 +2,24 @@ import codecs import mimetypes import re -from collections.abc import Container -from collections.abc import Iterable +from collections.abc import Collection from collections.abc import MutableSet from copy import deepcopy -from io import BufferedReader -from io import BufferedWriter from io import BytesIO -from io import StringIO from itertools import repeat from os import fspath -from typing import Any -from typing import Callable -from typing import Dict -from typing import Hashable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Set -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import TypeVar -from typing import Union from . import exceptions from ._internal import _make_encode_wrapper from ._internal import _missing from .filesystem import get_filesystem_encoding -from .types import T -from .types import UnicodeEncodable -from .types import WSGIEnvironment - -if TYPE_CHECKING: - from datetime import datetime # noqa: F401 - from pathlib import PosixPath # noqa: F401 - from tempfile import SpooledTemporaryFile # noqa: F401 - from werkzeug._internal import _Missing # noqa: F401 def is_immutable(self): raise TypeError(f"{type(self).__name__!r} objects are immutable") -def iter_multi_items(mapping: Union[Mapping, Iterable]) -> Iterator[Any]: +def iter_multi_items(mapping): """Iterates over the items of a mapping yielding keys and values without dropping any from more complex structures. """ @@ -73,10 +46,10 @@ class ImmutableListMixin: _hash_cache = None - def __hash__(self) -> int: + def __hash__(self): if self._hash_cache is not None: return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) # type: ignore + rv = self._hash_cache = hash(tuple(self)) return rv def __reduce_ex__(self, protocol): @@ -88,7 +61,8 @@ def __delitem__(self, key): def __iadd__(self, other): is_immutable(self) - __imul__ = __iadd__ + def __imul__(self, other): + is_immutable(self) def __setitem__(self, key, value): is_immutable(self) @@ -96,7 +70,8 @@ def __setitem__(self, key, value): def append(self, item): is_immutable(self) - remove = append + def remove(self, item): + is_immutable(self) def extend(self, iterable): is_immutable(self) @@ -110,11 +85,11 @@ def pop(self, index=-1): def reverse(self): is_immutable(self) - def sort(self, cmp=None, key=None, reverse=None): + def sort(self, key=None, reverse=False): is_immutable(self) -class ImmutableList(ImmutableListMixin, list): # type: ignore +class ImmutableList(ImmutableListMixin, list): """An immutable :class:`list`. .. versionadded:: 0.5 @@ -148,7 +123,7 @@ def __reduce_ex__(self, protocol): def _iter_hashitems(self): return self.items() - def __hash__(self) -> int: + def __hash__(self): if self._hash_cache is not None: return self._hash_cache rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) @@ -184,16 +159,11 @@ class ImmutableMultiDictMixin(ImmutableDictMixin): :private: """ - def __reduce_ex__( - self, protocol: int - ) -> Union[ - Tuple[Type["ImmutableMultiDict"], Tuple[List[Any]]], - Tuple[Type["ImmutableMultiDict"], Tuple[List[Tuple[str, str]]]], - ]: - return type(self), (list(self.items(multi=True)),) # type: ignore + def __reduce_ex__(self, protocol): + return type(self), (list(self.items(multi=True)),) - def _iter_hashitems(self) -> Iterator[Any]: - return self.items(multi=True) # type: ignore + def _iter_hashitems(self): + return self.items(multi=True) def add(self, key, value): is_immutable(self) @@ -211,7 +181,20 @@ def setlistdefault(self, key, default_list=None): is_immutable(self) -class UpdateDictMixin: +def _calls_update(name): + def oncall(self, *args, **kw): + rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) + + if self.on_update is not None: + self.on_update(self) + + return rv + + oncall.__name__ = name + return oncall + + +class UpdateDictMixin(dict): """Makes dicts call `self.on_update` on modifications. .. versionadded:: 0.5 @@ -219,41 +202,30 @@ class UpdateDictMixin: :private: """ - on_update: Callable = None - - def calls_update(name: str): # type: ignore # noqa: B902 - def oncall(self, *args, **kw): - rv = getattr(super(), name)(*args, **kw) - if self.on_update is not None: - self.on_update(self) - return rv + on_update = None - oncall.__name__ = name - return oncall - - def setdefault(self, key: str, default: None = None) -> Optional[str]: - modified = key not in self # type: ignore - rv = super().setdefault(key, default) # type: ignore + def setdefault(self, key, default=None): + modified = key not in self + rv = super().setdefault(key, default) if modified and self.on_update is not None: self.on_update(self) return rv - def pop(self, key: str, default: Optional["_Missing"] = _missing) -> Optional[str]: - modified = key in self # type: ignore + def pop(self, key, default=_missing): + modified = key in self if default is _missing: - rv = super().pop(key) # type: ignore + rv = super().pop(key) else: - rv = super().pop(key, default) # type: ignore + rv = super().pop(key, default) if modified and self.on_update is not None: self.on_update(self) return rv - __setitem__ = calls_update("__setitem__") - __delitem__ = calls_update("__delitem__") - clear = calls_update("clear") - popitem = calls_update("popitem") - update = calls_update("update") - del calls_update + __setitem__ = _calls_update("__setitem__") + __delitem__ = _calls_update("__delitem__") + clear = _calls_update("clear") + popitem = _calls_update("popitem") + update = _calls_update("update") class TypeConversionDict(dict): @@ -264,12 +236,7 @@ class TypeConversionDict(dict): .. versionadded:: 0.5 """ - def get( - self, - key: str, - default: Optional[T] = None, - type: Optional[Callable[[Any], T]] = None, - ) -> Optional[Union[Any, T]]: + def get(self, key, default=None, type=None): """Return the default value if the requested data doesn't exist. If `type` is provided and is a callable it should convert the value, return it or raise a :exc:`ValueError` if that is not possible. In @@ -302,23 +269,21 @@ def get( return rv -class ImmutableTypeConversionDict( # type: ignore - ImmutableDictMixin, TypeConversionDict -): +class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): """Works like a :class:`TypeConversionDict` but does not support modifications. .. versionadded:: 0.5 """ - def copy(self) -> TypeConversionDict: + def copy(self): """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return TypeConversionDict(self) - def __copy__(self) -> "ImmutableTypeConversionDict": + def __copy__(self): return self @@ -363,7 +328,7 @@ class MultiDict(TypeConversionDict): or `None`. """ - def __init__(self, mapping: Optional[Any] = None) -> None: + def __init__(self, mapping=None): if isinstance(mapping, MultiDict): dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) elif isinstance(mapping, dict): @@ -378,19 +343,19 @@ def __init__(self, mapping: Optional[Any] = None) -> None: tmp[key] = value dict.__init__(self, tmp) else: - tmp = {} # type: ignore + tmp = {} for key, value in mapping or (): tmp.setdefault(key, []).append(value) dict.__init__(self, tmp) - def __getstate__(self) -> Dict[bytes, Union[List[int], List[bytes]]]: - return dict(self.lists()) # type: ignore + def __getstate__(self): + return dict(self.lists()) - def __setstate__(self, value: Dict[Any, Any]) -> None: + def __setstate__(self, value): dict.clear(self) dict.update(self, value) - def __getitem__(self, key: Hashable) -> Any: + def __getitem__(self, key): """Return the first data value for this key; raises KeyError if not found. @@ -404,7 +369,7 @@ def __getitem__(self, key: Hashable) -> Any: return lst[0] raise exceptions.BadRequestKeyError(key) - def __setitem__(self, key: Hashable, value: Any) -> None: + def __setitem__(self, key, value): """Like :meth:`add` but removes an existing key first. :param key: the key for the value. @@ -412,7 +377,7 @@ def __setitem__(self, key: Hashable, value: Any) -> None: """ dict.__setitem__(self, key, [value]) - def add(self, key: Hashable, value: Any) -> None: + def add(self, key, value): """Adds a new value for the key. .. versionadded:: 0.6 @@ -422,9 +387,7 @@ def add(self, key: Hashable, value: Any) -> None: """ dict.setdefault(self, key, []).append(value) - def getlist( - self, key: Hashable, type: Optional[Callable[[Any], T]] = None - ) -> List[Union[Any, T]]: + def getlist(self, key, type=None): """Return the list of items for a given key. If that key is not in the `MultiDict`, the return value will be an empty list. Just like `get`, `getlist` accepts a `type` parameter. All items will be converted @@ -450,7 +413,7 @@ def getlist( pass return result - def setlist(self, key: Hashable, new_list: List[Any]) -> None: + def setlist(self, key, new_list): """Remove the old values for a key and add new ones. Note that the list you pass the values in will be shallow-copied before it is inserted in the dictionary. @@ -468,7 +431,7 @@ def setlist(self, key: Hashable, new_list: List[Any]) -> None: """ dict.__setitem__(self, key, list(new_list)) - def setdefault(self, key: Hashable, default: Optional[T] = None) -> Union[Any, T]: + def setdefault(self, key, default=None): """Returns the value for the key if it is in the dict, otherwise it returns `default` and sets that value for `key`. @@ -482,9 +445,7 @@ def setdefault(self, key: Hashable, default: Optional[T] = None) -> Union[Any, T default = self[key] return default - def setlistdefault( - self, key: Hashable, default_list: Optional[List[T]] = None - ) -> List[T]: + def setlistdefault(self, key, default_list=None): """Like `setdefault` but sets multiple values. The list returned is not a copy, but the list that is actually used internally. This means that you can put new values into the dict by appending items @@ -508,17 +469,7 @@ def setlistdefault( default_list = dict.__getitem__(self, key) return default_list - def items( # type: ignore - self, multi: bool = False - ) -> Iterator[ - Union[ - Tuple[str, str], - Tuple[str, int], - Tuple[bytes, int], - Tuple[bytes, bytes], - Tuple[str, "FileStorage"], - ] - ]: + def items(self, multi=False): """Return an iterator of ``(key, value)`` pairs. :param multi: If set to `True` the iterator returned will have a pair @@ -532,13 +483,13 @@ def items( # type: ignore else: yield key, values[0] - def lists(self,) -> Iterator[Tuple[Hashable, List[Any]]]: + def lists(self): """Return a iterator of ``(key, values)`` pairs, where values is the list of all values associated with the key.""" for key, values in dict.items(self): yield key, list(values) - def values(self) -> Iterator[Any]: # type: ignore + def values(self): """Returns an iterator of the first value on every key's value list.""" for values in dict.values(self): yield values[0] @@ -553,15 +504,15 @@ def listvalues(self): """ return dict.values(self) - def copy(self) -> Union["MultiDict", "OrderedMultiDict"]: + def copy(self): """Return a shallow copy of this object.""" return self.__class__(self) - def deepcopy(self, memo: None = None) -> Union["MultiDict", "OrderedMultiDict"]: + def deepcopy(self, memo=None): """Return a deep copy of this object.""" return self.__class__(deepcopy(self.to_dict(flat=False), memo)) - def to_dict(self, flat: bool = True) -> Dict[Hashable, Any]: + def to_dict(self, flat=True): """Return the contents as regular dict. If `flat` is `True` the returned dict will only have the first item present, if `flat` is `False` all values will be returned as lists. @@ -575,7 +526,7 @@ def to_dict(self, flat: bool = True) -> Dict[Hashable, Any]: return dict(self.items()) return dict(self.lists()) - def update(self, other_dict: Mapping) -> None: # type: ignore + def update(self, mapping): """update() extends rather than replaces existing key lists: >>> a = MultiDict({'x': 1}) @@ -593,12 +544,10 @@ def update(self, other_dict: Mapping) -> None: # type: ignore >>> y MultiDict([]) """ - for key, value in iter_multi_items(other_dict): + for key, value in iter_multi_items(mapping): MultiDict.add(self, key, value) - def pop( # type: ignore - self, key: str, default: Union["_Missing", int] = _missing - ) -> int: + def pop(self, key, default=_missing): """Pop the first item for a list on the dict. Afterwards the key is removed from the dict, so additional values are discarded: @@ -621,10 +570,10 @@ def pop( # type: ignore return lst[0] except KeyError: if default is not _missing: - return default # type: ignore + return default raise exceptions.BadRequestKeyError(key) - def popitem(self) -> Tuple[Any, Any]: + def popitem(self): """Pop an item from the dict.""" try: item = dict.popitem(self) @@ -636,7 +585,7 @@ def popitem(self) -> Tuple[Any, Any]: except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) - def poplist(self, key: Hashable) -> List[Any]: + def poplist(self, key): """Pop the list for a key from the dict. If the key is not in the dict an empty list is returned. @@ -646,22 +595,20 @@ def poplist(self, key: Hashable) -> List[Any]: """ return dict.pop(self, key, []) - def popitemlist(self) -> Tuple[Hashable, List[Any]]: + def popitemlist(self): """Pop a ``(key, list)`` tuple from the dict.""" try: return dict.popitem(self) except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) - def __copy__(self) -> Union["MultiDict", "OrderedMultiDict"]: + def __copy__(self): return self.copy() - def __deepcopy__( - self, memo: Dict[Any, Any] - ) -> Union["MultiDict", "OrderedMultiDict"]: - return self.deepcopy(memo=memo) # type: ignore + def __deepcopy__(self, memo): + return self.deepcopy(memo=memo) - def __repr__(self) -> str: + def __repr__(self): return f"{type(self).__name__}({list(self.items(multi=True))!r})" @@ -674,32 +621,27 @@ class _omd_bucket: __slots__ = ("prev", "key", "value", "next") - def __init__( - self, - omd: Union["ImmutableOrderedMultiDict", "OrderedMultiDict"], - key: Hashable, - value: Any, - ) -> None: + def __init__(self, omd, key, value): self.prev = omd._last_bucket self.key = key self.value = value self.next = None if omd._first_bucket is None: - omd._first_bucket = self # type: ignore + omd._first_bucket = self if omd._last_bucket is not None: omd._last_bucket.next = self - omd._last_bucket = self # type: ignore + omd._last_bucket = self - def unlink(self, omd: "OrderedMultiDict") -> None: + def unlink(self, omd): if self.prev: self.prev.next = self.next if self.next: self.next.prev = self.prev if omd._first_bucket is self: - omd._first_bucket = self.next # type: ignore + omd._first_bucket = self.next if omd._last_bucket is self: - omd._last_bucket = self.prev # type: ignore + omd._last_bucket = self.prev class OrderedMultiDict(MultiDict): @@ -718,13 +660,13 @@ class OrderedMultiDict(MultiDict): the internal bucket objects are exposed. """ - def __init__(self, mapping: Optional[Any] = None) -> None: + def __init__(self, mapping=None): dict.__init__(self) self._first_bucket = self._last_bucket = None if mapping is not None: OrderedMultiDict.update(self, mapping) - def __eq__(self, other: object) -> bool: + def __eq__(self, other): if not isinstance(other, MultiDict): return NotImplemented if isinstance(other, OrderedMultiDict): @@ -751,21 +693,7 @@ def __eq__(self, other: object) -> bool: __hash__ = None - def __reduce_ex__( - self, protocol: int - ) -> Tuple[ - Type["OrderedMultiDict"], - Tuple[ - List[ - Union[ - Tuple[str, str], - Tuple[str, int], - Tuple[bytes, int], - Tuple[bytes, bytes], - ] - ] - ], - ]: + def __reduce_ex__(self, protocol): return type(self), (list(self.items(multi=True)),) def __getstate__(self): @@ -776,46 +704,42 @@ def __setstate__(self, values): for key, value in values: self.add(key, value) - def __getitem__(self, key: object) -> object: + def __getitem__(self, key): if key in self: return dict.__getitem__(self, key)[0].value raise exceptions.BadRequestKeyError(key) - def __setitem__(self, key: Hashable, value: Any) -> None: + def __setitem__(self, key, value): self.poplist(key) self.add(key, value) - def __delitem__(self, key: str) -> None: + def __delitem__(self, key): self.pop(key) - def keys(self) -> Iterator[Any]: # type: ignore + def keys(self): return (key for key, value in self.items()) __iter__ = keys - def values(self) -> Iterator[Any]: # type: ignore + def values(self): return (value for key, value in self.items()) - def items( # type: ignore - self, multi: bool = False - ) -> Iterator[ - Union[Tuple[str, str], Tuple[str, int], Tuple[bytes, int], Tuple[bytes, bytes]] - ]: + def items(self, multi=False): ptr = self._first_bucket if multi: while ptr is not None: yield ptr.key, ptr.value ptr = ptr.next else: - returned_keys: Set[Any] = set() + returned_keys = set() while ptr is not None: if ptr.key not in returned_keys: returned_keys.add(ptr.key) yield ptr.key, ptr.value ptr = ptr.next - def lists(self) -> Iterator[Tuple[Hashable, List[Any]]]: - returned_keys: Set[Any] = set() + def lists(self): + returned_keys = set() ptr = self._first_bucket while ptr is not None: if ptr.key not in returned_keys: @@ -827,10 +751,10 @@ def listvalues(self): for _key, values in self.lists(): yield values - def add(self, key: Hashable, value: Any) -> None: + def add(self, key, value): dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) - def getlist(self, key: Hashable, type: Optional[Callable] = None) -> List[Any]: + def getlist(self, key, type=None): try: rv = dict.__getitem__(self, key) except KeyError: @@ -845,7 +769,7 @@ def getlist(self, key: Hashable, type: Optional[Callable] = None) -> List[Any]: pass return result - def setlist(self, key: Hashable, new_list: List[Any]) -> None: + def setlist(self, key, new_list): self.poplist(key) for value in new_list: self.add(key, value) @@ -853,19 +777,17 @@ def setlist(self, key: Hashable, new_list: List[Any]) -> None: def setlistdefault(self, key, default_list=None): raise TypeError("setlistdefault is unsupported for ordered multi dicts") - def update(self, mapping: Mapping) -> None: # type: ignore + def update(self, mapping): for key, value in iter_multi_items(mapping): OrderedMultiDict.add(self, key, value) - def poplist(self, key: Hashable) -> List[int]: + def poplist(self, key): buckets = dict.pop(self, key, ()) for bucket in buckets: bucket.unlink(self) return [x.value for x in buckets] - def pop( - self, key: str, default: Optional[Union[Any, "_Missing"]] = _missing - ) -> Optional[Any]: + def pop(self, key, default=_missing): try: buckets = dict.pop(self, key) except KeyError: @@ -876,7 +798,7 @@ def pop( bucket.unlink(self) return buckets[0].value - def popitem(self) -> Tuple[Hashable, Any]: + def popitem(self): try: key, buckets = dict.popitem(self) except KeyError as e: @@ -885,7 +807,7 @@ def popitem(self) -> Tuple[Hashable, Any]: bucket.unlink(self) return key, buckets[0].value - def popitemlist(self) -> Tuple[str, List[int]]: + def popitemlist(self): try: key, buckets = dict.popitem(self) except KeyError as e: @@ -895,11 +817,13 @@ def popitemlist(self) -> Tuple[str, List[int]]: return key, [x.value for x in buckets] -def _options_header_vkw(value: str, kw: Dict[str, str]) -> str: - return dump_options_header(value, {k.replace("_", "-"): v for k, v in kw.items()}) +def _options_header_vkw(value, kw): + return http.dump_options_header( + value, {k.replace("_", "-"): v for k, v in kw.items()} + ) -def _unicodify_header_value(value: UnicodeEncodable) -> str: +def _unicodify_header_value(value): if isinstance(value, bytes): value = value.decode("latin-1") if not isinstance(value, str): @@ -941,9 +865,7 @@ class Headers: was an API that does not support the changes to the encoding model. """ - _list: List[Any] - - def __init__(self, defaults: Optional[Any] = None) -> None: + def __init__(self, defaults=None): self._list = [] if defaults is not None: if isinstance(defaults, (list, Headers)): @@ -951,7 +873,7 @@ def __init__(self, defaults: Optional[Any] = None) -> None: else: self.extend(defaults) - def __getitem__(self, key: Union[str, int, slice], _get_mode: bool = False) -> Any: + def __getitem__(self, key, _get_mode=False): if not _get_mode: if isinstance(key, int): return self._list[key] @@ -970,23 +892,17 @@ def __getitem__(self, key: Union[str, int, slice], _get_mode: bool = False) -> A raise KeyError() raise exceptions.BadRequestKeyError(key) - def __eq__(self, other: object) -> bool: + def __eq__(self, other): def lowered(item): return (item[0].lower(),) + item[1:] return other.__class__ is self.__class__ and set( - map(lowered, other._list) # type: ignore + map(lowered, other._list) ) == set(map(lowered, self._list)) __hash__ = None - def get( - self, - key: Hashable, - default: Optional[Any] = None, - type: Optional[Callable[[Any], Any]] = None, - as_bytes: bool = False, - ) -> Optional[Any]: + def get(self, key, default=None, type=None, as_bytes=False): """Return the default value if the requested data doesn't exist. If `type` is provided and is a callable it should convert the value, return it or raise a :exc:`ValueError` if that is not possible. In @@ -1010,7 +926,7 @@ def get( :param as_bytes: return bytes instead of strings. """ try: - rv = self.__getitem__(key, _get_mode=True) # type: ignore + rv = self.__getitem__(key, _get_mode=True) except KeyError: return default if as_bytes: @@ -1022,12 +938,7 @@ def get( except ValueError: return default - def getlist( - self, - key: UnicodeEncodable, - type: Optional[Type[int]] = None, - as_bytes: bool = False, - ) -> Union[List[int], List[str]]: + def getlist(self, key, type=None, as_bytes=False): """Return the list of items for a given key. If that key is not in the :class:`Headers`, the return value will be an empty list. Just like :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will @@ -1043,7 +954,7 @@ def getlist( :return: a :class:`list` of all the values for the key. :param as_bytes: return bytes instead of strings. """ - ikey = key.lower() # type: ignore + ikey = key.lower() result = [] for k, v in self: if k.lower() == ikey: @@ -1057,7 +968,7 @@ def getlist( result.append(v) return result - def get_all(self, name: str) -> List[Any]: + def get_all(self, name): """Return a list of all the values for the named field. This method is compatible with the :mod:`wsgiref` @@ -1065,13 +976,13 @@ def get_all(self, name: str) -> List[Any]: """ return self.getlist(name) - def items(self, lower: bool = False) -> Iterator[Tuple[str, str]]: + def items(self, lower=False): for key, value in self: if lower: key = key.lower() yield key, value - def keys(self, lower: bool = False) -> Iterator[str]: + def keys(self, lower=False): for key, _ in self.items(lower): yield key @@ -1079,7 +990,7 @@ def values(self): for _, value in self.items(): yield value - def extend(self, *args, **kwargs) -> None: + def extend(self, *args, **kwargs): """Extend headers in this object with items from another object containing header items as well as keyword arguments. @@ -1103,31 +1014,25 @@ def extend(self, *args, **kwargs) -> None: for key, value in iter_multi_items(kwargs): self.add(key, value) - def __delitem__( - self, key: Union[str, int, slice], _index_operation: bool = True - ) -> None: + def __delitem__(self, key, _index_operation=True): if _index_operation and isinstance(key, (int, slice)): del self._list[key] return - key = key.lower() # type: ignore + key = key.lower() new = [] for k, v in self._list: if k.lower() != key: new.append((k, v)) self._list[:] = new - def remove(self, key: UnicodeEncodable) -> None: + def remove(self, key): """Remove a key. :param key: The key to be removed. """ - return self.__delitem__(key, _index_operation=False) # type: ignore + return self.__delitem__(key, _index_operation=False) - def pop( - self, - key: Optional[Union[str, int]] = None, - default: Optional[Union[int, "_Missing"]] = _missing, - ) -> Optional[Any]: + def pop(self, key=None, default=_missing): """Removes and returns a key or index. :param key: The key to be popped. If this is an integer the item at @@ -1153,10 +1058,10 @@ def popitem(self): """Removes a key or index and returns a (key, value) item.""" return self.pop() - def __contains__(self, key: Union[str, int]) -> bool: + def __contains__(self, key): """Check if a key is present.""" try: - self.__getitem__(key, _get_mode=True) # type: ignore + self.__getitem__(key, _get_mode=True) except KeyError: return False return True @@ -1167,12 +1072,10 @@ def __iter__(self): """Yield ``(key, value)`` tuples.""" return iter(self._list) - def __len__(self) -> int: + def __len__(self): return len(self._list) - def add( - self, _key: UnicodeEncodable, _value: Union[UnicodeEncodable], **kw - ) -> None: + def add(self, _key, _value, **kw): """Add a new header tuple to the list. Keyword arguments can specify additional parameters for the header @@ -1189,13 +1092,13 @@ def add( keyword arguments were added for :mod:`wsgiref` compatibility. """ if kw: - _value = _options_header_vkw(_value, kw) # type: ignore + _value = _options_header_vkw(_value, kw) _key = _unicodify_header_value(_key) _value = _unicodify_header_value(_value) self._validate_value(_value) self._list.append((_key, _value)) - def _validate_value(self, value: Any) -> None: + def _validate_value(self, value): if not isinstance(value, str): raise TypeError("Value should be a string.") if "\n" in value or "\r" in value: @@ -1216,9 +1119,7 @@ def clear(self): """Clears all headers.""" del self._list[:] - def set( - self, _key: UnicodeEncodable, _value: Union[UnicodeEncodable, str], **kw - ) -> None: + def set(self, _key, _value, **kw): """Remove all header tuples for `key` and add a new one. The newly added key either appears at the end of the list if there was no entry or replaces the first one. @@ -1234,7 +1135,7 @@ def set( :param value: The value to be inserted. """ if kw: - _value = _options_header_vkw(_value, kw) # type: ignore + _value = _options_header_vkw(_value, kw) _key = _unicodify_header_value(_key) _value = _unicodify_header_value(_value) self._validate_value(_value) @@ -1253,9 +1154,7 @@ def set( return self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] - def setlist( - self, key: UnicodeEncodable, values: List[Union[UnicodeEncodable, str]] - ) -> None: + def setlist(self, key, values): """Remove any existing values for a header and add new ones. :param key: The header key to set. @@ -1272,7 +1171,7 @@ def setlist( else: self.remove(key) - def setdefault(self, key: UnicodeEncodable, default: Any) -> Any: + def setdefault(self, key, default): """Return the first value for the key if it is in the headers, otherwise set the header to the value given by ``default`` and return that. @@ -1282,12 +1181,12 @@ def setdefault(self, key: UnicodeEncodable, default: Any) -> Any: headers. """ if key in self: - return self[key] # type: ignore + return self[key] self.set(key, default) return default - def setlistdefault(self, key: UnicodeEncodable, default: List[Any]) -> List[Any]: + def setlistdefault(self, key, default): """Return the list of values for the key if it is in the headers, otherwise set the header to the list of values given by ``default`` and return that. @@ -1306,7 +1205,7 @@ def setlistdefault(self, key: UnicodeEncodable, default: List[Any]) -> List[Any] return self.getlist(key) - def __setitem__(self, key: Union[UnicodeEncodable, slice], value: Any) -> None: + def __setitem__(self, key, value): """Like :meth:`set` but also supports index/slice based setting.""" if isinstance(key, (slice, int)): if isinstance(key, int): @@ -1324,7 +1223,7 @@ def __setitem__(self, key: Union[UnicodeEncodable, slice], value: Any) -> None: else: self.set(key, value) - def update(self, *args, **kwargs) -> None: + def update(self, *args, **kwargs): """Replace headers in this object with items from another headers object and keyword arguments. @@ -1345,11 +1244,11 @@ def update(self, *args, **kwargs) -> None: if isinstance(mapping, (Headers, MultiDict)): for key in mapping.keys(): - self.setlist(key, mapping.getlist(key)) # type: ignore + self.setlist(key, mapping.getlist(key)) elif isinstance(mapping, dict): for key, value in mapping.items(): if isinstance(value, (list, tuple)): - self.setlist(key, value) # type: ignore + self.setlist(key, value) else: self.set(key, value) else: @@ -1358,24 +1257,24 @@ def update(self, *args, **kwargs) -> None: for key, value in kwargs.items(): if isinstance(value, (list, tuple)): - self.setlist(key, value) # type: ignore + self.setlist(key, value) else: self.set(key, value) - def to_wsgi_list(self) -> List[Tuple[str, str]]: + def to_wsgi_list(self): """Convert the headers into a list suitable for WSGI. :return: list """ return list(self) - def copy(self) -> "Headers": + def copy(self): return self.__class__(self._list) def __copy__(self): return self.copy() - def __str__(self) -> str: + def __str__(self): """Returns formatted headers suitable for HTTP transmission.""" strs = [] for key, value in self.to_wsgi_list(): @@ -1387,9 +1286,6 @@ def __repr__(self): return f"{type(self).__name__}({list(self)!r})" -AnyHeaders = TypeVar("AnyHeaders", List[Tuple[str, str]], Headers) - - class ImmutableHeadersMixin: """Makes a :class:`Headers` immutable. We do not mark them as hashable though since the only usecase for this datastructure @@ -1406,19 +1302,19 @@ def __delitem__(self, key, **kwargs): def __setitem__(self, key, value): is_immutable(self) - def set(self, key, value): + def set(self, _key, _value, **kw): is_immutable(self) - def setlist(self, key, value): + def setlist(self, key, values): is_immutable(self) - def add(self, item): + def add(self, _key, _value, **kw): is_immutable(self) - def add_header(self, item): + def add_header(self, _key, _value, **_kw): is_immutable(self) - def remove(self, item): + def remove(self, key): is_immutable(self) def extend(self, *args, **kwargs): @@ -1430,7 +1326,7 @@ def update(self, *args, **kwargs): def insert(self, pos, value): is_immutable(self) - def pop(self, index=-1): + def pop(self, key=None, default=_missing): is_immutable(self) def popitem(self): @@ -1443,7 +1339,7 @@ def setlistdefault(self, key, default): is_immutable(self) -class EnvironHeaders(ImmutableHeadersMixin, Headers): # type: ignore +class EnvironHeaders(ImmutableHeadersMixin, Headers): """Read only version of the headers from a WSGI environment. This provides the same interface as `Headers` and is constructed from a WSGI environment. @@ -1454,7 +1350,7 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): # type: ignore HTTP exceptions. """ - def __init__(self, environ: WSGIEnvironment) -> None: + def __init__(self, environ): self.environ = environ def __eq__(self, other): @@ -1462,12 +1358,7 @@ def __eq__(self, other): __hash__ = None - def __getitem__( # type: ignore - # note: the type annotation for "key" might need to be adjusted - self, - key: Union[str, int], - _get_mode: bool = False, - ) -> str: + def __getitem__(self, key, _get_mode=False): # _get_mode is a no-op for this class as there is no index but # used because get() calls it. if not isinstance(key, str): @@ -1477,12 +1368,12 @@ def __getitem__( # type: ignore return _unicodify_header_value(self.environ[key]) return _unicodify_header_value(self.environ[f"HTTP_{key}"]) - def __len__(self) -> int: + def __len__(self): # the iter is necessary because otherwise list calls our # len which would call list again and so forth. return len(list(iter(self))) - def __iter__(self) -> Iterator[Tuple[str, str]]: + def __iter__(self): for key, value in self.environ.items(): if key.startswith("HTTP_") and key not in ( "HTTP_CONTENT_TYPE", @@ -1493,16 +1384,13 @@ def __iter__(self) -> Iterator[Tuple[str, str]]: _unicodify_header_value(value), ) elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: - yield ( - key.replace("_", "-").title(), - _unicodify_header_value(value), - ) + yield (key.replace("_", "-").title(), _unicodify_header_value(value)) def copy(self): raise TypeError(f"cannot create {type(self).__name__!r} copies") -class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): # type: ignore +class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` instances as sequence and it will combine the return values of all wrapped dicts: @@ -1525,30 +1413,23 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): # type: ignore exceptions. """ - dicts: List[Any] - def __reduce_ex__(self, protocol): return type(self), (self.dicts,) - def __init__(self, dicts: Optional[Iterable] = None,) -> None: + def __init__(self, dicts=None): self.dicts = list(dicts) or [] @classmethod - def fromkeys(cls): + def fromkeys(cls, keys, value=None): raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys") - def __getitem__(self, key: Hashable) -> Any: + def __getitem__(self, key): for d in self.dicts: if key in d: return d[key] raise exceptions.BadRequestKeyError(key) - def get( - self, - key: Hashable, - default: T = None, - type: Optional[Callable[[Any], T]] = None, - ) -> Union[Any, T]: + def get(self, key, default=None, type=None): for d in self.dicts: if key in d: if type is not None: @@ -1559,30 +1440,26 @@ def get( return d[key] return default - def getlist( - self, key: Hashable, type: Optional[Callable[[Any], T]] = None - ) -> Union[List[Any], List[T]]: + def getlist(self, key, type=None): rv = [] for d in self.dicts: rv.extend(d.getlist(key, type)) return rv - def _keys_impl(self) -> Set[Any]: + def _keys_impl(self): """This function exists so __len__ can be implemented more efficiently, saving one list creation from an iterator. """ - rv: Set[Any] = set() + rv = set() rv.update(*self.dicts) return rv - def keys(self) -> Set[Any]: # type: ignore + def keys(self): return self._keys_impl() - __iter__ = keys # type: ignore + __iter__ = keys - def items( # type: ignore - self, multi: bool = False - ) -> Iterator[Tuple[Hashable, Any]]: + def items(self, multi=False): found = set() for d in self.dicts: for key, value in d.items(multi): @@ -1592,21 +1469,21 @@ def items( # type: ignore found.add(key) yield key, value - def values(self) -> Any: + def values(self): for _key, value in self.items(): yield value - def lists(self) -> List[Any]: # type: ignore - rv: Any = {} + def lists(self): + rv = {} for d in self.dicts: for key, values in d.lists(): rv.setdefault(key, []).extend(values) return list(rv.items()) - def listvalues(self) -> Iterator[Any]: + def listvalues(self): return (x[1] for x in self.lists()) - def copy(self) -> MultiDict: + def copy(self): """Return a shallow mutable copy of this object. This returns a :class:`MultiDict` representing the data at the @@ -1633,10 +1510,10 @@ def to_dict(self, flat=True): rv.update(d.to_dict(flat)) return rv - def __len__(self) -> int: + def __len__(self): return len(self._keys_impl()) - def __contains__(self, key: object) -> bool: + def __contains__(self, key): for d in self.dicts: if key in d: return True @@ -1656,13 +1533,7 @@ class FileMultiDict(MultiDict): .. versionadded:: 0.5 """ - def add_file( - self, - name: str, - file: BytesIO, - filename: Optional[str] = None, - content_type: None = None, - ) -> None: + def add_file(self, name, file, filename=None, content_type=None): """Adds a new file to the dict. `file` can be a file name or a :class:`file`-like or a :class:`FileStorage` object. @@ -1680,15 +1551,14 @@ def add_file( file = open(file, "rb") if filename and content_type is None: content_type = ( - mimetypes.guess_type(filename)[0] # type: ignore - or "application/octet-stream" + mimetypes.guess_type(filename)[0] or "application/octet-stream" ) value = FileStorage(file, filename, name, content_type) self.add(name, value) -class ImmutableDict(ImmutableDictMixin, dict): # type: ignore +class ImmutableDict(ImmutableDictMixin, dict): """An immutable :class:`dict`. .. versionadded:: 0.5 @@ -1697,53 +1567,51 @@ class ImmutableDict(ImmutableDictMixin, dict): # type: ignore def __repr__(self): return f"{type(self).__name__}({dict.__repr__(self)})" - def copy(self) -> Dict[str, Any]: + def copy(self): """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return dict(self) - def __copy__(self) -> "ImmutableDict": + def __copy__(self): return self -class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): # type: ignore +class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): """An immutable :class:`MultiDict`. .. versionadded:: 0.5 """ - def copy(self) -> MultiDict: + def copy(self): """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return MultiDict(self) - def __copy__(self) -> "ImmutableMultiDict": + def __copy__(self): return self -class ImmutableOrderedMultiDict( # type: ignore - ImmutableMultiDictMixin, OrderedMultiDict -): +class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): """An immutable :class:`OrderedMultiDict`. .. versionadded:: 0.6 """ - def _iter_hashitems(self) -> enumerate: + def _iter_hashitems(self): return enumerate(self.items(multi=True)) - def copy(self) -> OrderedMultiDict: + def copy(self): """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return OrderedMultiDict(self) - def __copy__(self) -> "ImmutableOrderedMultiDict": + def __copy__(self): return self @@ -1783,16 +1651,7 @@ class Accept(ImmutableList): """ - def __init__( - self, - values: Optional[ - Union[ - List[Tuple[str, int]], - Tuple, - List[Union[Tuple[str, int], Tuple[str, float]]], - ] - ] = (), - ) -> None: + def __init__(self, values=()): if values is None: list.__init__(self) self.provided = False @@ -1802,21 +1661,19 @@ def __init__( else: self.provided = True values = sorted( - values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True, + values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True ) list.__init__(self, values) - def _specificity(self, value: str) -> Tuple[bool]: + def _specificity(self, value): """Returns a tuple describing the value's specificity.""" return (value != "*",) - def _value_matches(self, value: str, item: str) -> bool: + def _value_matches(self, value, item): """Check if a value matches a given accept item.""" return item == "*" or item.lower() == value.lower() - def __getitem__( # type: ignore - self, key: Union[str, int] - ) -> Union[float, Tuple[str, int], int, Tuple[str, float]]: + def __getitem__(self, key): """Besides index lookup (getting item n) you can also pass it a string to get the quality for the item. If the item is not in the list, the returned quality is ``0``. @@ -1825,7 +1682,7 @@ def __getitem__( # type: ignore return self.quality(key) return list.__getitem__(self, key) - def quality(self, key: str) -> Union[float, int]: + def quality(self, key): """Returns the quality of the key. .. versionadded:: 0.6 @@ -1837,9 +1694,9 @@ def quality(self, key: str) -> Union[float, int]: return quality return 0 - def __contains__(self, value: object) -> bool: + def __contains__(self, value): for item, _quality in self: - if self._value_matches(value, item): # type: ignore + if self._value_matches(value, item): return True return False @@ -1847,7 +1704,7 @@ def __repr__(self): pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) return f"{type(self).__name__}([{pairs_str}])" - def index(self, key: Any) -> int: # type: ignore + def index(self, key): """Get the position of an entry or raise :exc:`ValueError`. :param key: The key to be looked up. @@ -1863,7 +1720,7 @@ def index(self, key: Any) -> int: # type: ignore raise ValueError(key) return list.index(self, key) - def find(self, key: str) -> int: + def find(self, key): """Get the position of an entry or return -1. :param key: The key to be looked up. @@ -1873,12 +1730,12 @@ def find(self, key: str) -> int: except ValueError: return -1 - def values(self) -> Iterator[str]: + def values(self): """Iterate over all values.""" for item in self: yield item[0] - def to_header(self) -> str: + def to_header(self): """Convert the header set into an HTTP header string.""" result = [] for value, quality in self: @@ -1890,18 +1747,14 @@ def to_header(self) -> str: def __str__(self): return self.to_header() - def _best_single_match( - self, match: str - ) -> Optional[Union[Tuple[str, int], Tuple[str, float]]]: + def _best_single_match(self, match): for client_item, quality in self: if self._value_matches(match, client_item): # self is sorted by specificity descending, we can exit return client_item, quality return None - def best_match( - self, matches: List[str], default: Optional[str] = None - ) -> Optional[str]: + def best_match(self, matches, default=None): """Returns the best match from a list of possible matches based on the specificity and quality of the client. If two items have the same quality and specificity, the one is returned that comes first. @@ -1923,21 +1776,21 @@ def best_match( # better quality or same quality but more specific => better match if quality > best_quality or specificity > best_specificity: result = server_item - best_quality = quality # type: ignore + best_quality = quality best_specificity = specificity return result @property - def best(self) -> str: # type: ignore + def best(self): """The best match as value.""" if self: - return self[0][0] # type: ignore + return self[0][0] _mime_split_re = re.compile(r"/|(?:\s*;\s*)") -def _normalize_mime(value: str) -> List[str]: +def _normalize_mime(value): return _mime_split_re.split(value.lower()) @@ -1946,12 +1799,10 @@ class MIMEAccept(Accept): mimetypes. """ - def _specificity( # type: ignore - self, value: str - ) -> Tuple[bool, ...]: + def _specificity(self, value): return tuple(x != "*" for x in _mime_split_re.split(value)) - def _value_matches(self, value: str, item: str) -> bool: + def _value_matches(self, value, item): # item comes from the client, can't match if it's invalid. if "/" not in item: return False @@ -2012,7 +1863,7 @@ def accept_json(self): _locale_delim_re = re.compile(r"[_-]") -def _normalize_lang(value: str) -> List[str]: +def _normalize_lang(value): """Process a language tag for matching.""" return _locale_delim_re.split(value.lower()) @@ -2020,12 +1871,10 @@ def _normalize_lang(value: str) -> List[str]: class LanguageAccept(Accept): """Like :class:`Accept` but with normalization for language tags.""" - def _value_matches(self, value: str, item: str) -> bool: + def _value_matches(self, value, item): return item == "*" or _normalize_lang(value) == _normalize_lang(item) - def best_match( - self, matches: List[str], default: Optional[str] = None - ) -> Optional[str]: + def best_match(self, matches, default=None): """Given a list of supported values, finds the best match from the list of accepted values. @@ -2076,7 +1925,7 @@ def best_match( class CharsetAccept(Accept): """Like :class:`Accept` but with normalization for charsets.""" - def _value_matches(self, value: str, item: str) -> bool: + def _value_matches(self, value, item): def _normalize(name): try: return codecs.lookup(name).name @@ -2097,7 +1946,7 @@ def cache_property(key, empty, type): ) -class _CacheControl(UpdateDictMixin, dict): # type: ignore +class _CacheControl(UpdateDictMixin, dict): """Subclass of a dict that stores values for a Cache-Control header. It has accessors for all the cache-control directives specified in RFC 2616. The class does not differentiate between request and response directives. @@ -2134,25 +1983,12 @@ class _CacheControl(UpdateDictMixin, dict): # type: ignore max_age = cache_property("max-age", -1, int) no_transform = cache_property("no-transform", None, None) - def __init__( - self, - values: Optional[ - Union[ - Tuple, List[Tuple[str, str]], Dict[str, Optional[str]], Dict[str, None], - ] - ] = (), - on_update: Optional[Callable] = None, - ) -> None: + def __init__(self, values=(), on_update=None): dict.__init__(self, values or ()) - self.on_update = on_update # type: ignore + self.on_update = on_update self.provided = values is not None - def _get_cache_value( - self, - key: str, - empty: Optional[Union[str, int]], - type: Optional[Union[Type[int], Type[bool]]], - ) -> Optional[Union[bool, str, int]]: + def _get_cache_value(self, key, empty, type): """Used internally by the accessor properties.""" if type is bool: return key in self @@ -2168,12 +2004,7 @@ def _get_cache_value( return value return None - def _set_cache_value( - self, - key: str, - value: Optional[Union[bool, int]], - type: Optional[Union[Type[int], Type[bool]]], - ) -> None: + def _set_cache_value(self, key, value, type): """Used internally by the accessor properties.""" if type is bool: if value: @@ -2188,24 +2019,26 @@ def _set_cache_value( else: self[key] = value - def _del_cache_value(self, key: str) -> None: + def _del_cache_value(self, key): """Used internally by the accessor properties.""" if key in self: del self[key] - def to_header(self) -> str: + def to_header(self): """Convert the stored values into a cache control header.""" - return dump_header(self) + return http.dump_header(self) def __str__(self): return self.to_header() - def __repr__(self) -> str: + def __repr__(self): kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) return f"<{type(self).__name__} {kv_str}>" + cache_property = staticmethod(cache_property) -class RequestCacheControl(ImmutableDictMixin, _CacheControl): # type: ignore + +class RequestCacheControl(ImmutableDictMixin, _CacheControl): """A cache control for requests. This is immutable and gives access to all the request-relevant cache control headers. @@ -2247,11 +2080,6 @@ class ResponseCacheControl(_CacheControl): immutable = cache_property("immutable", None, bool) -# attach cache_property to the _CacheControl as staticmethod -# so that others can reuse it. -_CacheControl.cache_property = staticmethod(cache_property) # type: ignore - - def csp_property(key): """Return a new property object for a content security policy header. Useful if you want to add support for a csp extension in a @@ -2265,7 +2093,7 @@ def csp_property(key): ) -class ContentSecurityPolicy(UpdateDictMixin, dict): # type: ignore +class ContentSecurityPolicy(UpdateDictMixin, dict): """Subclass of a dict that stores values for a Content Security Policy header. It has accessors for all the level 3 policies. @@ -2308,20 +2136,16 @@ class ContentSecurityPolicy(UpdateDictMixin, dict): # type: ignore style_src_elem = csp_property("style-src-elem") worker_src = csp_property("worker-src") - def __init__( - self, - values: Union[Tuple, List[Tuple[str, str]]] = (), - on_update: Optional[Callable] = None, - ) -> None: + def __init__(self, values=(), on_update=None): dict.__init__(self, values or ()) - self.on_update = on_update # type: ignore + self.on_update = on_update self.provided = values is not None - def _get_value(self, key: str) -> Optional[str]: + def _get_value(self, key): """Used internally by the accessor properties.""" return self.get(key) - def _set_value(self, key: str, value: str) -> None: + def _set_value(self, key, value): """Used internally by the accessor properties.""" if value is None: self.pop(key, None) @@ -2333,9 +2157,9 @@ def _del_value(self, key): if key in self: del self[key] - def to_header(self) -> str: + def to_header(self): """Convert the stored values into a cache control header.""" - return dump_csp_header(self) + return http.dump_csp_header(self) def __str__(self): return self.to_header() @@ -2345,18 +2169,14 @@ def __repr__(self): return f"<{type(self).__name__} {kv_str}>" -class CallbackDict(UpdateDictMixin, dict): # type: ignore +class CallbackDict(UpdateDictMixin, dict): """A dict that calls a function passed every time something is changed. The function is passed the dict instance. """ - def __init__( - self, - initial: Optional[Dict[str, str]] = None, - on_update: Optional[Callable] = None, - ) -> None: + def __init__(self, initial=None, on_update=None): dict.__init__(self, initial or ()) - self.on_update = on_update # type: ignore + self.on_update = on_update def __repr__(self): return f"<{type(self).__name__} {dict.__repr__(self)}>" @@ -2375,18 +2195,16 @@ class HeaderSet(MutableSet): HeaderSet(['foo', 'bar', 'baz']) """ - def __init__( - self, headers: Optional[List[str]] = None, on_update: Optional[Callable] = None, - ) -> None: + def __init__(self, headers=None, on_update=None): self._headers = list(headers or ()) self._set = {x.lower() for x in self._headers} self.on_update = on_update - def add(self, header: str) -> None: + def add(self, header): """Add a new header to the set.""" self.update((header,)) - def remove(self, header: str) -> None: + def remove(self, header): """Remove a header from the set. This raises an :exc:`KeyError` if the header is not in the set. @@ -2407,7 +2225,7 @@ def remove(self, header: str) -> None: if self.on_update is not None: self.on_update(self) - def update(self, iterable: Union[List[str], Tuple[str]]) -> None: + def update(self, iterable): """Add all the headers from the iterable to the set. :param iterable: updates the set with the items from the iterable. @@ -2422,17 +2240,17 @@ def update(self, iterable: Union[List[str], Tuple[str]]) -> None: if inserted_any and self.on_update is not None: self.on_update(self) - def discard(self, header: str) -> None: + def discard(self, header): """Like :meth:`remove` but ignores errors. :param header: the header to be discarded. """ try: - return self.remove(header) + self.remove(header) except KeyError: pass - def find(self, header: str) -> int: + def find(self, header): """Return the index of the header in the set or return -1 if not found. :param header: the header to be looked up. @@ -2443,7 +2261,7 @@ def find(self, header: str) -> int: return idx return -1 - def index(self, header: str) -> int: + def index(self, header): """Return the index of the header in the set or raise an :exc:`IndexError`. @@ -2454,14 +2272,14 @@ def index(self, header: str) -> int: raise IndexError(header) return rv - def clear(self) -> None: + def clear(self): """Clear the set.""" self._set.clear() del self._headers[:] if self.on_update is not None: self.on_update(self) - def as_set(self, preserve_casing: bool = False) -> Set[str]: + def as_set(self, preserve_casing=False): """Return the set as real python set type. When calling this, all the items are converted to lowercase and the ordering is lost. @@ -2474,9 +2292,9 @@ def as_set(self, preserve_casing: bool = False) -> Set[str]: return set(self._headers) return set(self._set) - def to_header(self) -> str: + def to_header(self): """Convert the header set into an HTTP header string.""" - return ", ".join(map(quote_header_value, self._headers)) + return ", ".join(map(http.quote_header_value, self._headers)) def __getitem__(self, idx): return self._headers[idx] @@ -2495,16 +2313,16 @@ def __setitem__(self, idx, value): if self.on_update is not None: self.on_update(self) - def __contains__(self, header: object) -> bool: - return header.lower() in self._set # type: ignore + def __contains__(self, header): + return header.lower() in self._set - def __len__(self) -> int: + def __len__(self): return len(self._set) def __iter__(self): return iter(self._headers) - def __nonzero__(self): + def __bool__(self): return bool(self._set) def __str__(self): @@ -2514,17 +2332,12 @@ def __repr__(self): return f"{type(self).__name__}({self._headers!r})" -class ETags(Container, Iterable): +class ETags(Collection): """A set that can be used to check if one etag is present in a collection of etags. """ - def __init__( - self, - strong_etags: Optional[List[str]] = None, - weak_etags: Optional[List[str]] = None, - star_tag: bool = False, - ) -> None: + def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): if not star_tag and strong_etags: self._strong = frozenset(strong_etags) else: @@ -2541,19 +2354,19 @@ def as_set(self, include_weak=False): rv.update(self._weak) return rv - def is_weak(self, etag: str) -> bool: + def is_weak(self, etag): """Check if an etag is weak.""" return etag in self._weak - def is_strong(self, etag: str) -> bool: + def is_strong(self, etag): """Check if an etag is strong.""" return etag in self._strong - def contains_weak(self, etag: str) -> bool: + def contains_weak(self, etag): """Check if an etag is part of the set including weak and strong tags.""" return self.is_weak(etag) or self.contains(etag) - def contains(self, etag: str) -> bool: + def contains(self, etag): """Check if an etag is part of the set ignoring weak tags. It is also possible to use the ``in`` operator. """ @@ -2561,16 +2374,16 @@ def contains(self, etag: str) -> bool: return True return self.is_strong(etag) - def contains_raw(self, etag: str) -> bool: + def contains_raw(self, etag): """When passed a quoted tag it will check if this tag is part of the set. If the tag is weak it is checked against weak and strong tags, otherwise strong only.""" - etag, weak = unquote_etag(etag) + etag, weak = http.unquote_etag(etag) if weak: return self.contains_weak(etag) return self.contains(etag) - def to_header(self) -> str: + def to_header(self): """Convert the etags set into a HTTP header string.""" if self.star_tag: return "*" @@ -2578,31 +2391,30 @@ def to_header(self) -> str: [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak] ) - def __call__( - self, etag: Optional[str] = None, data: None = None, include_weak: bool = False, - ) -> bool: + def __call__(self, etag=None, data=None, include_weak=False): if [etag, data].count(None) != 1: raise TypeError("either tag or data required, but at least one") if etag is None: - etag = generate_etag(data) + etag = http.generate_etag(data) if include_weak: if etag in self._weak: return True return etag in self._strong - def __bool__(self) -> bool: + def __bool__(self): return bool(self.star_tag or self._strong or self._weak) - __nonzero__ = __bool__ - def __str__(self): return self.to_header() + def __len__(self): + return len(self._strong) + def __iter__(self): return iter(self._strong) - def __contains__(self, etag: object) -> bool: - return self.contains(etag) # type: ignore + def __contains__(self, etag): + return self.contains(etag) def __repr__(self): return f"<{type(self).__name__} {str(self)!r}>" @@ -2616,21 +2428,19 @@ class IfRange: .. versionadded:: 0.7 """ - def __init__( - self, etag: Optional[str] = None, date: Optional["datetime"] = None - ) -> None: + def __init__(self, etag=None, date=None): #: The etag parsed and unquoted. Ranges always operate on strong #: etags so the weakness information is not necessary. self.etag = etag #: The date in parsed format or `None`. self.date = date - def to_header(self) -> str: + def to_header(self): """Converts the object back into an HTTP header.""" if self.date is not None: - return http_date(self.date) + return http.http_date(self.date) if self.etag is not None: - return quote_etag(self.etag) + return http.quote_etag(self.etag) return "" def __str__(self): @@ -2653,16 +2463,7 @@ class Range: .. versionadded:: 0.7 """ - def __init__( - self, - units: str, - ranges: Union[ - List[Union[Tuple[int, int], Tuple[int, None]]], - List[Tuple[None, int]], - List[Tuple[int, int]], - List[Tuple[int, None]], - ], - ) -> None: + def __init__(self, units, ranges): #: The units of this range. Usually "bytes". self.units = units #: A list of ``(begin, end)`` tuples for the range header provided. @@ -2673,7 +2474,7 @@ def __init__( if start is None or (end is not None and (start < 0 or start >= end)): raise ValueError(f"{(start, end)} is not a valid range.") - def range_for_length(self, length: int) -> Optional[Tuple[int, int]]: + def range_for_length(self, length): """If the range is for bytes, the length is not None and there is exactly one range and it is satisfiable it returns a ``(start, stop)`` tuple, otherwise `None`. @@ -2685,11 +2486,11 @@ def range_for_length(self, length: int) -> Optional[Tuple[int, int]]: end = length if start < 0: start += length - if is_byte_range_valid(start, end, length): + if http.is_byte_range_valid(start, end, length): return start, min(end, length) return None - def make_content_range(self, length: int) -> Optional["ContentRange"]: + def make_content_range(self, length): """Creates a :class:`~werkzeug.datastructures.ContentRange` object from the current range and given content length. """ @@ -2698,7 +2499,7 @@ def make_content_range(self, length: int) -> Optional["ContentRange"]: return ContentRange(self.units, rng[0], rng[1], length) return None - def to_header(self) -> str: + def to_header(self): """Converts the object back into an HTTP header.""" ranges = [] for begin, end in self.ranges: @@ -2708,7 +2509,7 @@ def to_header(self) -> str: ranges.append(f"{begin}-{end - 1}") return f"{self.units}={','.join(ranges)}" - def to_content_range_header(self, length: int) -> str: + def to_content_range_header(self, length): """Converts the object into `Content-Range` HTTP header, based on given length """ @@ -2724,35 +2525,29 @@ def __repr__(self): return f"<{type(self).__name__} {str(self)!r}>" +def _callback_property(name): + def fget(self): + return getattr(self, name) + + def fset(self, value): + setattr(self, name, value) + if self.on_update is not None: + self.on_update(self) + + return property(fget, fset) + + class ContentRange: """Represents the content range header. .. versionadded:: 0.7 """ - def __init__( - self, - units: str, - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - on_update: Optional[Callable] = None, - ) -> None: - assert is_byte_range_valid(start, stop, length), "Bad range provided" + def __init__(self, units, start, stop, length=None, on_update=None): + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" self.on_update = on_update self.set(start, stop, length, units) - def _callback_property(name: str): # type: ignore # noqa: B902 - def fget(self): - return getattr(self, name) - - def fset(self, value): - setattr(self, name, value) - if self.on_update is not None: - self.on_update(self) - - return property(fget, fset) - #: The units to use, usually "bytes" units = _callback_property("_units") #: The start point of the range or `None`. @@ -2762,17 +2557,10 @@ def fset(self, value): stop = _callback_property("_stop") #: The length of the range or `None`. length = _callback_property("_length") - del _callback_property - def set( - self, - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - units: Optional[str] = "bytes", - ) -> None: + def set(self, start, stop, length=None, units="bytes"): """Simple method to update the ranges.""" - assert is_byte_range_valid(start, stop, length), "Bad range provided" + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" self._units = units self._start = start self._stop = stop @@ -2780,13 +2568,13 @@ def set( if self.on_update is not None: self.on_update(self) - def unset(self) -> None: + def unset(self): """Sets the units to `None` which indicates that the header should no longer be used. """ self.set(None, None, units=None) - def to_header(self) -> str: + def to_header(self): if self.units is None: return "" if self.length is None: @@ -2797,11 +2585,9 @@ def to_header(self) -> str: return f"{self.units} */{length}" return f"{self.units} {self.start}-{self.stop - 1}/{length}" - def __nonzero__(self) -> bool: + def __bool__(self): return self.units is not None - __bool__ = __nonzero__ - def __str__(self): return self.to_header() @@ -2809,7 +2595,7 @@ def __repr__(self): return f"<{type(self).__name__} {str(self)!r}>" -class Authorization(ImmutableDictMixin, dict): # type: ignore +class Authorization(ImmutableDictMixin, dict): """Represents an ``Authorization`` header sent by the client. This is returned by @@ -2821,38 +2607,38 @@ class Authorization(ImmutableDictMixin, dict): # type: ignore This object became immutable. """ - def __init__(self, auth_type: str, data: Optional[Dict[str, str]] = None) -> None: + def __init__(self, auth_type, data=None): dict.__init__(self, data or {}) self.type = auth_type @property - def username(self) -> str: + def username(self): """The username transmitted. This is set for both basic and digest auth all the time. """ return self.get("username") @property - def password(self) -> str: + def password(self): """When the authentication type is basic this is the password transmitted by the client, else `None`. """ return self.get("password") @property - def realm(self) -> str: + def realm(self): """This is the server realm sent back for HTTP digest auth.""" return self.get("realm") @property - def nonce(self) -> str: + def nonce(self): """The nonce the server sent for digest auth, sent back by the client. A nonce should be unique for every 401 response for HTTP digest auth. """ return self.get("nonce") @property - def uri(self) -> str: + def uri(self): """The URI from Request-URI of the Request-Line; duplicated because proxies are allowed to change the Request-Line in transit. HTTP digest auth only. @@ -2860,14 +2646,14 @@ def uri(self) -> str: return self.get("uri") @property - def nc(self) -> str: + def nc(self): """The nonce count value transmitted by clients if a qop-header is also transmitted. HTTP digest auth only. """ return self.get("nc") @property - def cnonce(self) -> str: + def cnonce(self): """If the server sent a qop-header in the ``WWW-Authenticate`` header, the client has to provide this value for HTTP digest auth. See the RFC for more details. @@ -2875,14 +2661,14 @@ def cnonce(self) -> str: return self.get("cnonce") @property - def response(self) -> str: + def response(self): """A string of 32 hex digits computed as defined in RFC 2617, which proves that the user knows a password. Digest auth only. """ return self.get("response") @property - def opaque(self) -> str: + def opaque(self): """The opaque header from the server returned unchanged by the client. It is recommended that this string be base64 or hexadecimal data. Digest auth only. @@ -2890,14 +2676,14 @@ def opaque(self) -> str: return self.get("opaque") @property - def qop(self) -> str: + def qop(self): """Indicates what "quality of protection" the client has applied to the message for HTTP digest auth. Note that this is a single token, not a quoted list of alternatives as in WWW-Authenticate. """ return self.get("qop") - def to_header(self) -> str: + def to_header(self): """Convert to a string value for an ``Authorization`` header. .. versionadded:: 2.0 @@ -2910,29 +2696,57 @@ def to_header(self) -> str: return f"Basic {value}" if self.type == "digest": - return f"Digest {dump_header(self)}" + return f"Digest {http.dump_header(self)}" raise ValueError(f"Unsupported type {self.type!r}.") -class WWWAuthenticate(UpdateDictMixin, dict): # type: ignore +def auth_property(name, doc=None): + """A static helper function for Authentication subclasses to add + extra authentication system properties onto a class:: + + class FooAuthenticate(WWWAuthenticate): + special_realm = auth_property('special_realm') + + For more information have a look at the sourcecode to see how the + regular properties (:attr:`realm` etc.) are implemented. + """ + + def _set_value(self, value): + if value is None: + self.pop(name, None) + else: + self[name] = str(value) + + return property(lambda x: x.get(name), _set_value, doc=doc) + + +def _set_property(name, doc=None): + def fget(self): + def on_update(header_set): + if not header_set and name in self: + del self[name] + elif header_set: + self[name] = header_set.to_header() + + return http.parse_set_header(self.get(name), on_update) + + return property(fget, doc=doc) + + +class WWWAuthenticate(UpdateDictMixin, dict): """Provides simple access to `WWW-Authenticate` headers.""" #: list of keys that require quoting in the generated header _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"]) - def __init__( - self, - auth_type: Optional[str] = None, - values: Optional[Dict[str, str]] = None, - on_update: Optional[Callable] = None, - ) -> None: + def __init__(self, auth_type=None, values=None, on_update=None): dict.__init__(self, values or ()) if auth_type: self["__auth_type__"] = auth_type - self.on_update = on_update # type: ignore + self.on_update = on_update - def set_basic(self, realm: str = "authentication required") -> None: + def set_basic(self, realm="authentication required"): """Clear the auth info and enable basic auth.""" dict.clear(self) dict.update(self, {"__auth_type__": "basic", "realm": realm}) @@ -2940,20 +2754,14 @@ def set_basic(self, realm: str = "authentication required") -> None: self.on_update(self) def set_digest( - self, - realm: str, - nonce: str, - qop: Union[Tuple[str, str], Tuple[str]] = ("auth",), - opaque: None = None, - algorithm: None = None, - stale: bool = False, - ) -> None: + self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False + ): """Clear the auth info and enable digest auth.""" d = { "__auth_type__": "digest", "realm": realm, "nonce": nonce, - "qop": dump_header(qop), + "qop": http.dump_header(qop), } if stale: d["stale"] = "TRUE" @@ -2966,58 +2774,23 @@ def set_digest( if self.on_update: self.on_update(self) - def to_header(self) -> str: + def to_header(self): """Convert the stored values into a WWW-Authenticate header.""" d = dict(self) auth_type = d.pop("__auth_type__", None) or "basic" kv_items = ( - (k, quote_header_value(v, allow_token=k not in self._require_quoting),) + (k, http.quote_header_value(v, allow_token=k not in self._require_quoting)) for k, v in d.items() ) kv_string = ", ".join([f"{k}={v}" for k, v in kv_items]) return f"{auth_type.title()} {kv_string}" - def __str__(self) -> str: + def __str__(self): return self.to_header() def __repr__(self): return f"<{type(self).__name__} {self.to_header()!r}>" - def auth_property( # type: ignore - name: str, doc: Optional[str] = None # noqa: B902 - ): - """A static helper function for subclasses to add extra authentication - system properties onto a class:: - - class FooAuthenticate(WWWAuthenticate): - special_realm = auth_property('special_realm') - - For more information have a look at the sourcecode to see how the - regular properties (:attr:`realm` etc.) are implemented. - """ - - def _set_value(self, value): - if value is None: - self.pop(name, None) - else: - self[name] = str(value) - - return property(lambda x: x.get(name), _set_value, doc=doc) - - def _set_property( # type: ignore - name: str, doc: Optional[str] = None # noqa: B902 - ): - def fget(self): - def on_update(header_set): - if not header_set and name in self: - del self[name] - elif header_set: - self[name] = header_set.to_header() - - return parse_set_header(self.get(name), on_update) - - return property(fget, doc=doc) - type = auth_property( "__auth_type__", doc="""The type of the auth mechanism. HTTP currently specifies @@ -3082,8 +2855,7 @@ def stale(self, value): else: self["stale"] = "TRUE" if value else "FALSE" - auth_property = staticmethod(auth_property) # type: ignore - del _set_property + auth_property = staticmethod(auth_property) class FileStorage: @@ -3096,15 +2868,13 @@ class FileStorage: def __init__( self, - stream: Optional[ - Union[StringIO, BytesIO, BufferedReader, "SpooledTemporaryFile"] - ] = None, - filename: Optional[str] = None, - name: Optional[str] = None, - content_type: Optional[str] = None, - content_length: None = None, - headers: Optional[Headers] = None, - ) -> None: + stream=None, + filename=None, + name=None, + content_type=None, + content_length=None, + headers=None, + ): self.name = name self.stream = stream or BytesIO() @@ -3121,9 +2891,7 @@ def __init__( # Make sure the filename is not bytes. This might happen if # the file was opened from the bytes API. if isinstance(filename, bytes): - filename = filename.decode( # type: ignore - get_filesystem_encoding(), "replace" - ) + filename = filename.decode(get_filesystem_encoding(), "replace") self.filename = filename if headers is None: @@ -3134,12 +2902,12 @@ def __init__( if content_length is not None: headers["Content-Length"] = str(content_length) - def _parse_content_type(self) -> None: + def _parse_content_type(self): if not hasattr(self, "_parsed_content_type"): - self._parsed_content_type = parse_options_header(self.content_type) + self._parsed_content_type = http.parse_options_header(self.content_type) @property - def content_type(self) -> Optional[str]: + def content_type(self): """The content-type sent in the header. Usually not available""" return self.headers.get("content-type") @@ -3149,7 +2917,7 @@ def content_length(self): return int(self.headers.get("content-length") or 0) @property - def mimetype(self) -> str: + def mimetype(self): """Like :attr:`content_type`, but without parameters (eg, without charset, type etc.) and always lowercase. For example if the content type is ``text/HTML; charset=utf-8`` the mimetype would be @@ -3161,7 +2929,7 @@ def mimetype(self) -> str: return self._parsed_content_type[0].lower() @property - def mimetype_params(self) -> Dict[str, str]: + def mimetype_params(self): """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -3171,11 +2939,7 @@ def mimetype_params(self) -> Dict[str, str]: self._parse_content_type() return self._parsed_content_type[1] - def save( - self, - dst: Union[BytesIO, "PosixPath", BufferedWriter], - buffer_size: int = 16384, - ) -> None: + def save(self, dst, buffer_size=16384): """Save the file to a destination path or file object. If the destination is a file object you have to close it yourself after the call. The buffer size is the number of bytes held in memory during @@ -3196,31 +2960,29 @@ def save( close_dst = False if hasattr(dst, "__fspath__"): - dst = fspath(dst) # type: ignore + dst = fspath(dst) if isinstance(dst, str): dst = open(dst, "wb") close_dst = True try: - copyfileobj(self.stream, dst, buffer_size) # type: ignore + copyfileobj(self.stream, dst, buffer_size) finally: if close_dst: - dst.close() # type: ignore + dst.close() - def close(self) -> None: + def close(self): """Close the underlying file if possible.""" try: self.stream.close() except Exception: pass - def __nonzero__(self): + def __bool__(self): return bool(self.filename) - __bool__ = __nonzero__ - - def __getattr__(self, name: str) -> Union[bool, Callable]: + def __getattr__(self, name): try: return getattr(self.stream, name) except AttributeError: @@ -3228,10 +2990,10 @@ def __getattr__(self, name: str) -> Union[bool, Callable]: # attribute from its backing file instead. # https://github.com/python/cpython/pull/3249 if hasattr(self.stream, "_file"): - return getattr(self.stream._file, name) # type: ignore + return getattr(self.stream._file, name) raise - def __iter__(self) -> Iterator[Any]: + def __iter__(self): return iter(self.stream) def __repr__(self): @@ -3239,14 +3001,4 @@ def __repr__(self): # circular dependencies -from .http import dump_csp_header -from .http import dump_header -from .http import dump_options_header -from .http import generate_etag -from .http import http_date -from .http import is_byte_range_valid -from .http import parse_options_header -from .http import parse_set_header -from .http import quote_etag -from .http import quote_header_value -from .http import unquote_etag +from . import http diff --git a/src/werkzeug/datastructures.pyi b/src/werkzeug/datastructures.pyi new file mode 100644 index 000000000..cc9be0900 --- /dev/null +++ b/src/werkzeug/datastructures.pyi @@ -0,0 +1,893 @@ +from datetime import datetime +from os import PathLike +from typing import Any +from typing import BinaryIO +from typing import Callable +from typing import Collection +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Hashable +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Literal +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from wsgiref.types import WSGIEnvironment + +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") + +def is_immutable(self: object) -> NoReturn: ... +def iter_multi_items( + mapping: Union[Mapping[K, Union[V, Iterable[V]]], Iterable[Tuple[K, V]]] +) -> Iterator[Tuple[K, V]]: ... + +class ImmutableListMixin(List[V]): + _hash_cache: Optional[int] + def __hash__(self) -> int: ... # type: ignore + def __delitem__(self, key) -> NoReturn: ... + def __iadd__(self, other) -> NoReturn: ... # type: ignore + def __imul__(self, other) -> NoReturn: ... # type: ignore + def __setitem__(self, key, value) -> NoReturn: ... + def append(self, value) -> NoReturn: ... + def remove(self, value) -> NoReturn: ... + def extend(self, values) -> NoReturn: ... + def insert(self, pos, value) -> NoReturn: ... + def pop(self, index=-1) -> NoReturn: ... + def reverse(self) -> NoReturn: ... + def sort(self, key=None, reverse=False) -> NoReturn: ... + +class ImmutableList(ImmutableListMixin[V]): ... + +class ImmutableDictMixin(Dict[K, V]): + _hash_cache: Optional[int] + @classmethod + def fromkeys( # type: ignore + cls, keys: Iterable[K], value: Optional[V] = None + ) -> ImmutableDictMixin[K, V]: ... + def _iter_hashitems(self) -> Iterable[Hashable]: ... + def __hash__(self) -> int: ... # type: ignore + def setdefault(self, key, default=None) -> NoReturn: ... + def update(self, *args, **kwargs) -> NoReturn: ... + def pop(self, key, default=None) -> NoReturn: ... + def popitem(self) -> NoReturn: ... + def __setitem__(self, key, value) -> NoReturn: ... + def __delitem__(self, key) -> NoReturn: ... + def clear(self) -> NoReturn: ... + +class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): + def _iter_hashitems(self) -> Iterable[Hashable]: ... + def add(self, key, value) -> NoReturn: ... + def popitemlist(self) -> NoReturn: ... + def poplist(self, key) -> NoReturn: ... + def setlist(self, key, new_list) -> NoReturn: ... + def setlistdefault(self, key, default_list=None) -> NoReturn: ... + +def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ... + +class UpdateDictMixin(Dict[K, V]): + on_update: Optional[Callable[[UpdateDictMixin[K, V]], None]] + def setdefault(self, key: K, default: Optional[V] = None) -> V: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... + def __setitem__(self, key: K, value: V) -> None: ... + def __delitem__(self, key: K) -> None: ... + def clear(self) -> None: ... + def popitem(self) -> Tuple[K, V]: ... + def update( + self, *args: Union[Mapping[K, V], Iterable[Tuple[K, V]]], **kwargs: V + ) -> None: ... + +class TypeConversionDict(Dict[K, V]): + @overload # type: ignore + def get(self, key: K) -> Optional[V]: ... + @overload + def get(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... + @overload + def get( + self, key: K, default: Optional[T] = None, type: Callable[[V], T] = ... + ) -> Optional[T]: ... + +class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): + def copy(self) -> TypeConversionDict[K, V]: ... + def __copy__(self) -> ImmutableTypeConversionDict: ... + +class MultiDict(TypeConversionDict[K, V]): + def __init__( + self, + mapping: Optional[ + Union[Mapping[K, Union[V, Iterable[V]]], Iterable[Tuple[K, V]]] + ] = None, + ) -> None: ... + def __getitem__(self, item: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> List[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... + def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setdefault(self, key: K, default: Optional[V] = None) -> V: ... + def setlistdefault( + self, key: K, default_list: Optional[Iterable[V]] = None + ) -> List[V]: ... + def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[Tuple[K, List[V]]]: ... + def values(self) -> Iterator[V]: ... # type: ignore + def listvalues(self) -> Iterator[List[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + def deepcopy(self, memo=None) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> Dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... + def update( # type: ignore + self, mapping: Union[Mapping[K, V], Iterable[Tuple[K, V]]] + ) -> None: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... + def popitem(self) -> Tuple[K, V]: ... + def poplist(self, key: K) -> List[V]: ... + def popitemlist(self) -> Tuple[K, List[V]]: ... + def __copy__(self) -> MultiDict[K, V]: ... + def __deepcopy__(self, memo) -> MultiDict[K, V]: ... + +class _omd_bucket(Generic[K, V]): + prev: Optional[_omd_bucket] + next: Optional[_omd_bucket] + key: K + value: V + def __init__(self, omd: OrderedMultiDict, key: K, value: V) -> None: ... + def unlink(self, omd: OrderedMultiDict) -> None: ... + +class OrderedMultiDict(MultiDict[K, V]): + _first_bucket: Optional[_omd_bucket] + _last_bucket: Optional[_omd_bucket] + def __init__(self, mapping: Optional[Mapping[K, V]] = None) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, key: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def __delitem__(self, key: K) -> None: ... + def keys(self) -> Iterator[K]: ... # type: ignore + def __iter__(self) -> Iterator[K]: ... + def values(self) -> Iterator[V]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[Tuple[K, List[V]]]: ... + def listvalues(self) -> Iterator[List[V]]: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> List[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... + def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setlistdefault( + self, key: K, default_list: Optional[Iterable[V]] = None + ) -> List[V]: ... + def update( # type: ignore + self, mapping: Union[Mapping[K, V], Iterable[Tuple[K, V]]] + ) -> None: ... + def poplist(self, key: K) -> List[V]: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... + def popitem(self) -> Tuple[K, V]: ... + def popitemlist(self) -> Tuple[K, List[V]]: ... + +def _options_header_vkw( + value: str, kw: Mapping[str, Optional[Union[str, int]]] +) -> str: ... +def _unicodify_header_value(value: Union[str, int]) -> str: ... + +HV = Union[str, int] + +class Headers(Dict[str, str]): + _list: List[Tuple[str, str]] + def __init__( + self, + defaults: Optional[ + Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]] + ] = None, + ) -> None: ... + @overload + def __getitem__(self, key: str) -> str: ... + @overload + def __getitem__(self, key: int) -> Tuple[str, str]: ... + @overload + def __getitem__(self, key: slice) -> Headers: ... + @overload + def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... + def __eq__(self, other: object) -> bool: ... + @overload # type: ignore + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: ... + @overload + def get( + self, key: str, default: Optional[T] = None, type: Callable[[str], T] = ... + ) -> Optional[T]: ... + @overload + def getlist(self, key: str) -> List[str]: ... + @overload + def getlist(self, key: str, type: Callable[[str], T]) -> List[T]: ... + def get_all(self, name: str) -> List[str]: ... + def items( # type: ignore + self, lower: bool = False + ) -> Iterator[Tuple[str, str]]: ... + def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore + def values(self) -> Iterator[str]: ... # type: ignore + def extend( + self, + *args: Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]], + **kwargs: Union[HV, Iterable[HV]], + ) -> None: ... + @overload + def __delitem__(self, key: Union[str, int, slice]) -> None: ... + @overload + def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... + def remove(self, key: str) -> None: ... + @overload # type: ignore + def pop(self, key: str, default: Optional[str] = None) -> str: ... + @overload + def pop( + self, key: Optional[int] = None, default: Optional[Tuple[str, str]] = None + ) -> Tuple[str, str]: ... + def popitem(self) -> Tuple[str, str]: ... + def __contains__(self, key: str) -> bool: ... # type: ignore + def has_key(self, key: str) -> bool: ... + def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore + def add(self, _key: str, _value: HV, **kw: HV) -> None: ... + def _validate_value(self, value: str) -> None: ... + def add_header(self, _key: str, _value: HV, **_kw: HV) -> None: ... + def clear(self) -> None: ... + def set(self, _key: str, _value: HV, **kw: HV) -> None: ... + def setlist(self, key: str, values: Iterable[HV]) -> None: ... + def setdefault(self, key: str, default: HV) -> str: ... # type: ignore + def setlistdefault(self, key: str, default: Iterable[HV]) -> None: ... + @overload + def __setitem__(self, key: str, value: HV) -> None: ... + @overload + def __setitem__(self, key: int, value: Tuple[str, HV]) -> None: ... + @overload + def __setitem__(self, key: slice, value: Iterable[Tuple[str, HV]]) -> None: ... + def update( + self, + *args: Union[Mapping[str, HV], Iterable[Tuple[str, HV]]], + **kwargs: Union[HV, Iterable[HV]], + ) -> None: ... + def to_wsgi_list(self) -> List[Tuple[str, str]]: ... + def copy(self) -> Headers: ... + def __copy__(self) -> Headers: ... + +class ImmutableHeadersMixin(Headers): + def __delitem__(self, key, _index_operation: bool = True) -> NoReturn: ... + def __setitem__(self, key, value) -> NoReturn: ... + def set(self, _key, _value, **kw) -> NoReturn: ... + def setlist(self, key, values) -> NoReturn: ... + def add(self, _key, _value, **kw) -> NoReturn: ... + def add_header(self, _key, _value, **_kw) -> NoReturn: ... + def remove(self, key) -> NoReturn: ... + def extend(self, *args, **kwargs) -> NoReturn: ... + def update(self, *args, **kwargs) -> NoReturn: ... + def insert(self, pos, value) -> NoReturn: ... + def pop(self, key=None, default=...) -> NoReturn: ... + def popitem(self) -> NoReturn: ... + def setdefault(self, key, default) -> NoReturn: ... # type: ignore + def setlistdefault(self, key, default) -> NoReturn: ... + +class EnvironHeaders(ImmutableHeadersMixin, Headers): + environ: WSGIEnvironment + def __init__(self, environ: WSGIEnvironment) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__( # type: ignore + self, key: str, _get_mode: Literal[False] = False + ) -> str: ... + def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore + def copy(self) -> NoReturn: ... + +class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): + dicts: List[MultiDict[K, V]] + def __init__(self, dicts: Optional[Iterable[MultiDict[K, V]]]) -> None: ... + @classmethod + def fromkeys(cls, keys, value=None) -> NoReturn: ... + def __getitem__(self, key: K) -> V: ... + @overload # type: ignore + def get(self, key: K) -> Optional[V]: ... + @overload + def get(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... + @overload + def get( + self, key: K, default: Optional[T] = None, type: Callable[[V], T] = ... + ) -> Optional[T]: ... + @overload + def getlist(self, key: K) -> List[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... + def _keys_impl(self) -> Set[K]: ... + def keys(self) -> Set[K]: ... # type: ignore + def __iter__(self) -> Set[K]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore + def values(self) -> Iterator[V]: ... # type: ignore + def lists(self) -> Iterator[Tuple[K, List[V]]]: ... + def listvalues(self) -> Iterator[List[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> Dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... + def __contains__(self, key: K) -> bool: ... # type: ignore + def has_key(self, key: K) -> bool: ... + +class FileMultiDict(MultiDict[str, "FileStorage"]): + def add_file( + self, + name: str, + file: Union[FileStorage, str, BinaryIO], + filename: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: ... + +class ImmutableDict(ImmutableDictMixin[K, V], Dict[K, V]): + def copy(self) -> Dict[K, V]: ... + def __copy__(self) -> ImmutableDict[K, V]: ... + +class ImmutableMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): + def copy(self) -> MultiDict[K, V]: ... + def __copy__(self) -> ImmutableMultiDict[K, V]: ... + +class ImmutableOrderedMultiDict(ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V]): + def _iter_hashitems(self) -> Iterator[Tuple[int, Tuple[K, V]]]: ... + def copy(self) -> OrderedMultiDict[K, V]: ... + def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... + +class Accept(ImmutableList[Tuple[str, int]]): + provided: bool + def __init__( + self, values: Optional[Union[Accept, Iterable[Tuple[str, int]]]] = None + ) -> None: ... + def _specificity(self, value) -> Tuple[bool, ...]: ... + def _value_matches(self, value: str, item: str) -> bool: ... + @overload # type: ignore + def __getitem__(self, key: str) -> int: ... + @overload + def __getitem__(self, key: int) -> Tuple[str, int]: ... + @overload + def __getitem__(self, key: slice) -> Iterable[Tuple[str, int]]: ... + def quality(self, key: str) -> int: ... + def __contains__(self, value: str) -> bool: ... # type: ignore + def index(self, key: str) -> int: ... # type: ignore + def find(self, key: str) -> int: ... + def values(self) -> Iterator[str]: ... + def to_header(self) -> str: ... + def _best_single_match(self, match: str) -> Optional[Tuple[str, int]]: ... + def best_match( + self, matches: Iterable[str], default: Optional[str] = None + ) -> Optional[str]: ... + @property + def best(self) -> str: ... + +def _normalize_mime(value: str) -> List[str]: ... + +class MIMEAccept(Accept): + def _specificity(self, value) -> Tuple[bool, ...]: ... + def _value_matches(self, value: str, item: str) -> bool: ... + @property + def accept_html(self) -> bool: ... + @property + def accept_xhtml(self) -> bool: ... + @property + def accept_json(self) -> bool: ... + +def _normalize_lang(value: str) -> List[str]: ... + +class LanguageAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... + def best_match( + self, matches: Iterable[str], default: Optional[str] = None + ) -> Optional[str]: ... + +class CharsetAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... + +_CPT = TypeVar("_CPT", str, int, bool) +_OptCPT = Optional[_CPT] + +def cache_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... + +class _CacheControl(UpdateDictMixin[str, _OptCPT], Dict[str, _OptCPT]): + provided: bool + def __init__( + self, + values: Union[Mapping[str, _OptCPT], Iterable[Tuple[str, _OptCPT]]] = (), + on_update: Optional[Callable[[_CacheControl], None]] = None, + ) -> None: ... + @property + def no_cache(self) -> Optional[bool]: ... + @no_cache.setter + def no_cache(self, value: Optional[bool]) -> None: ... + @no_cache.deleter + def no_cache(self) -> None: ... + @property + def no_store(self) -> Optional[bool]: ... + @no_store.setter + def no_store(self, value: Optional[bool]) -> None: ... + @no_store.deleter + def no_store(self) -> None: ... + @property + def max_age(self) -> Optional[int]: ... + @max_age.setter + def max_age(self, value: Optional[int]) -> None: ... + @max_age.deleter + def max_age(self) -> None: ... + @property + def no_transform(self) -> Optional[bool]: ... + @no_transform.setter + def no_transform(self, value: Optional[bool]) -> None: ... + @no_transform.deleter + def no_transform(self) -> None: ... + def _get_cache_value(self, key: str, empty: Optional[T], type: Type[T]) -> T: ... + def _set_cache_value(self, key: str, value: Optional[T], type: Type[T]) -> None: ... + def _del_cache_value(self, key: str) -> None: ... + def to_header(self) -> str: ... + @staticmethod + def cache_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... + +class RequestCacheControl(ImmutableDictMixin[str, _OptCPT], _CacheControl): + @property + def max_stale(self) -> Optional[int]: ... + @max_stale.setter + def max_stale(self, value: Optional[int]) -> None: ... + @max_stale.deleter + def max_stale(self) -> None: ... + @property + def min_fresh(self) -> Optional[int]: ... + @min_fresh.setter + def min_fresh(self, value: Optional[int]) -> None: ... + @min_fresh.deleter + def min_fresh(self) -> None: ... + @property + def only_if_cached(self) -> Optional[bool]: ... + @only_if_cached.setter + def only_if_cached(self, value: Optional[bool]) -> None: ... + @only_if_cached.deleter + def only_if_cached(self) -> None: ... + +class ResponseCacheControl(_CacheControl): + @property + def public(self) -> Optional[bool]: ... + @public.setter + def public(self, value: Optional[bool]) -> None: ... + @public.deleter + def public(self) -> None: ... + @property + def private(self) -> Optional[bool]: ... + @private.setter + def private(self, value: Optional[bool]) -> None: ... + @private.deleter + def private(self) -> None: ... + @property + def must_revalidate(self) -> Optional[bool]: ... + @must_revalidate.setter + def must_revalidate(self, value: Optional[bool]) -> None: ... + @must_revalidate.deleter + def must_revalidate(self) -> None: ... + @property + def proxy_revalidate(self) -> Optional[bool]: ... + @proxy_revalidate.setter + def proxy_revalidate(self, value: Optional[bool]) -> None: ... + @proxy_revalidate.deleter + def proxy_revalidate(self) -> None: ... + @property + def s_maxage(self) -> Optional[int]: ... + @s_maxage.setter + def s_maxage(self, value: Optional[int]) -> None: ... + @s_maxage.deleter + def s_maxage(self) -> None: ... + @property + def immutable(self) -> Optional[bool]: ... + @immutable.setter + def immutable(self, value: Optional[bool]) -> None: ... + @immutable.deleter + def immutable(self) -> None: ... + +def csp_property(key: str) -> property: ... + +class ContentSecurityPolicy(UpdateDictMixin[str, str], Dict[str, str]): + @property + def base_uri(self) -> Optional[str]: ... + @base_uri.setter + def base_uri(self, value: Optional[str]) -> None: ... + @base_uri.deleter + def base_uri(self) -> None: ... + @property + def child_src(self) -> Optional[str]: ... + @child_src.setter + def child_src(self, value: Optional[str]) -> None: ... + @child_src.deleter + def child_src(self) -> None: ... + @property + def connect_src(self) -> Optional[str]: ... + @connect_src.setter + def connect_src(self, value: Optional[str]) -> None: ... + @connect_src.deleter + def connect_src(self) -> None: ... + @property + def default_src(self) -> Optional[str]: ... + @default_src.setter + def default_src(self, value: Optional[str]) -> None: ... + @default_src.deleter + def default_src(self) -> None: ... + @property + def font_src(self) -> Optional[str]: ... + @font_src.setter + def font_src(self, value: Optional[str]) -> None: ... + @font_src.deleter + def font_src(self) -> None: ... + @property + def form_action(self) -> Optional[str]: ... + @form_action.setter + def form_action(self, value: Optional[str]) -> None: ... + @form_action.deleter + def form_action(self) -> None: ... + @property + def frame_ancestors(self) -> Optional[str]: ... + @frame_ancestors.setter + def frame_ancestors(self, value: Optional[str]) -> None: ... + @frame_ancestors.deleter + def frame_ancestors(self) -> None: ... + @property + def frame_src(self) -> Optional[str]: ... + @frame_src.setter + def frame_src(self, value: Optional[str]) -> None: ... + @frame_src.deleter + def frame_src(self) -> None: ... + @property + def img_src(self) -> Optional[str]: ... + @img_src.setter + def img_src(self, value: Optional[str]) -> None: ... + @img_src.deleter + def img_src(self) -> None: ... + @property + def manifest_src(self) -> Optional[str]: ... + @manifest_src.setter + def manifest_src(self, value: Optional[str]) -> None: ... + @manifest_src.deleter + def manifest_src(self) -> None: ... + @property + def media_src(self) -> Optional[str]: ... + @media_src.setter + def media_src(self, value: Optional[str]) -> None: ... + @media_src.deleter + def media_src(self) -> None: ... + @property + def navigate_to(self) -> Optional[str]: ... + @navigate_to.setter + def navigate_to(self, value: Optional[str]) -> None: ... + @navigate_to.deleter + def navigate_to(self) -> None: ... + @property + def object_src(self) -> Optional[str]: ... + @object_src.setter + def object_src(self, value: Optional[str]) -> None: ... + @object_src.deleter + def object_src(self) -> None: ... + @property + def prefetch_src(self) -> Optional[str]: ... + @prefetch_src.setter + def prefetch_src(self, value: Optional[str]) -> None: ... + @prefetch_src.deleter + def prefetch_src(self) -> None: ... + @property + def plugin_types(self) -> Optional[str]: ... + @plugin_types.setter + def plugin_types(self, value: Optional[str]) -> None: ... + @plugin_types.deleter + def plugin_types(self) -> None: ... + @property + def report_to(self) -> Optional[str]: ... + @report_to.setter + def report_to(self, value: Optional[str]) -> None: ... + @report_to.deleter + def report_to(self) -> None: ... + @property + def report_uri(self) -> Optional[str]: ... + @report_uri.setter + def report_uri(self, value: Optional[str]) -> None: ... + @report_uri.deleter + def report_uri(self) -> None: ... + @property + def sandbox(self) -> Optional[str]: ... + @sandbox.setter + def sandbox(self, value: Optional[str]) -> None: ... + @sandbox.deleter + def sandbox(self) -> None: ... + @property + def script_src(self) -> Optional[str]: ... + @script_src.setter + def script_src(self, value: Optional[str]) -> None: ... + @script_src.deleter + def script_src(self) -> None: ... + @property + def script_src_attr(self) -> Optional[str]: ... + @script_src_attr.setter + def script_src_attr(self, value: Optional[str]) -> None: ... + @script_src_attr.deleter + def script_src_attr(self) -> None: ... + @property + def script_src_elem(self) -> Optional[str]: ... + @script_src_elem.setter + def script_src_elem(self, value: Optional[str]) -> None: ... + @script_src_elem.deleter + def script_src_elem(self) -> None: ... + @property + def style_src(self) -> Optional[str]: ... + @style_src.setter + def style_src(self, value: Optional[str]) -> None: ... + @style_src.deleter + def style_src(self) -> None: ... + @property + def style_src_attr(self) -> Optional[str]: ... + @style_src_attr.setter + def style_src_attr(self, value: Optional[str]) -> None: ... + @style_src_attr.deleter + def style_src_attr(self) -> None: ... + @property + def style_src_elem(self) -> Optional[str]: ... + @style_src_elem.setter + def style_src_elem(self, value: Optional[str]) -> None: ... + @style_src_elem.deleter + def style_src_elem(self) -> None: ... + @property + def worker_src(self) -> Optional[str]: ... + @worker_src.setter + def worker_src(self, value: Optional[str]) -> None: ... + @worker_src.deleter + def worker_src(self) -> None: ... + provided: bool + def __init__( + self, + values: Union[Mapping[str, str], Iterable[Tuple[str, str]]] = (), + on_update: Optional[Callable[[ContentSecurityPolicy], None]] = None, + ) -> None: ... + def _get_value(self, key: str) -> Optional[str]: ... + def _set_value(self, key: str, value: str) -> None: ... + def _del_value(self, key: str) -> None: ... + def to_header(self) -> str: ... + +class CallbackDict(UpdateDictMixin[K, V], Dict[K, V]): + def __init__( + self, + initial: Optional[Union[Mapping[K, V], Iterable[Tuple[K, V]]]] = None, + on_update: Optional[Callable[[CallbackDict], None]] = None, + ) -> None: ... + +class HeaderSet(Set[str]): + _headers: List[str] + _set: Set[str] + on_update: Optional[Callable[[HeaderSet], None]] + def __init__( + self, + headers: Optional[Iterable[str]] = None, + on_update: Optional[Callable[[HeaderSet], None]] = None, + ) -> None: ... + def add(self, header: str) -> None: ... + def remove(self, header: str) -> None: ... + def update(self, iterable: Iterable[str]) -> None: ... # type: ignore + def discard(self, header: str) -> None: ... + def find(self, header: str) -> int: ... + def index(self, header: str) -> int: ... + def clear(self) -> None: ... + def as_set(self, preserve_casing: bool = False) -> Set[str]: ... + def to_header(self) -> str: ... + def __getitem__(self, idx: int) -> str: ... + def __delitem__(self, idx: int) -> None: ... + def __setitem__(self, idx: int, value: str) -> None: ... + def __contains__(self, header: str) -> bool: ... # type: ignore + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... + +class ETags(Collection[str]): + _strong: FrozenSet[str] + _weak: FrozenSet[str] + star_tag: bool + def __init__( + self, + strong_etags: Optional[Iterable[str]] = None, + weak_etags: Optional[Iterable[str]] = None, + star_tag: bool = False, + ) -> None: ... + def as_set(self, include_weak: bool = False) -> Set[str]: ... + def is_weak(self, etag: str) -> bool: ... + def is_strong(self, etag: str) -> bool: ... + def contains_weak(self, etag: str) -> bool: ... + def contains(self, etag: str) -> bool: ... + def contains_raw(self, etag: str) -> bool: ... + def to_header(self) -> str: ... + def __call__( + self, + etag: Optional[str] = None, + data: Optional[bytes] = None, + include_weak: bool = False, + ) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... + def __contains__(self, item: str) -> bool: ... # type: ignore + +class IfRange: + etag: Optional[str] + date: Optional[datetime] + def __init__( + self, etag: Optional[str] = None, date: Optional[datetime] = None + ) -> None: ... + def to_header(self) -> str: ... + +class Range: + units: str + ranges: List[Tuple[int, Optional[int]]] + def __init__(self, units: str, ranges: List[Tuple[int, Optional[int]]]) -> None: ... + def range_for_length(self, length: Optional[int]) -> Optional[Tuple[int, int]]: ... + def make_content_range(self, length: Optional[int]) -> Optional[ContentRange]: ... + def to_header(self) -> str: ... + def to_content_range_header(self, length: Optional[int]) -> Optional[str]: ... + +def _callback_property(name: str) -> property: ... + +class ContentRange: + on_update: Optional[Callable[[ContentRange], None]] + def __init__( + self, + units: Optional[str], + start: Optional[int], + stop: Optional[int], + length: Optional[int] = None, + on_update: Optional[Callable[[ContentRange], None]] = None, + ) -> None: ... + @property + def units(self) -> Optional[str]: ... + @units.setter + def units(self, value: Optional[str]) -> None: ... + @property + def start(self) -> Optional[int]: ... + @start.setter + def start(self, value: Optional[int]) -> None: ... + @property + def stop(self) -> Optional[int]: ... + @stop.setter + def stop(self, value: Optional[int]) -> None: ... + @property + def length(self) -> Optional[int]: ... + @length.setter + def length(self, value: Optional[int]) -> None: ... + def set( + self, + start: Optional[int], + stop: Optional[int], + length: Optional[int] = None, + units: Optional[str] = "bytes", + ) -> None: ... + def unset(self) -> None: ... + def to_header(self) -> str: ... + +class Authorization(ImmutableDictMixin[str, str], Dict[str, str]): + type: str + def __init__( + self, + auth_type: str, + data: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, + ) -> None: ... + @property + def username(self) -> Optional[str]: ... + @property + def password(self) -> Optional[str]: ... + @property + def realm(self) -> Optional[str]: ... + @property + def nonce(self) -> Optional[str]: ... + @property + def uri(self) -> Optional[str]: ... + @property + def nc(self) -> Optional[str]: ... + @property + def cnonce(self) -> Optional[str]: ... + @property + def response(self) -> Optional[str]: ... + @property + def opaque(self) -> Optional[str]: ... + @property + def qop(self) -> Optional[str]: ... + def to_header(self) -> str: ... + +def auth_property(name: str, doc: Optional[str] = None) -> property: ... +def _set_property(name: str, doc: Optional[str] = None) -> property: ... + +class WWWAuthenticate(UpdateDictMixin[str, str], Dict[str, str]): + _require_quoting: FrozenSet[str] + def __init__( + self, + auth_type: Optional[str] = None, + values: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, + on_update: Optional[Callable[[WWWAuthenticate], None]] = None, + ) -> None: ... + def set_basic(self, realm: str = ...) -> None: ... + def set_digest( + self, + realm: str, + nonce: str, + qop: Iterable[str] = ("auth",), + opaque: Optional[str] = None, + algorithm: Optional[str] = None, + stale: bool = False, + ) -> None: ... + def to_header(self) -> str: ... + @property + def type(self) -> Optional[str]: ... + @type.setter + def type(self, value: Optional[str]) -> None: ... + @property + def realm(self) -> Optional[str]: ... + @realm.setter + def realm(self, value: Optional[str]) -> None: ... + @property + def domain(self) -> HeaderSet: ... + @property + def nonce(self) -> Optional[str]: ... + @nonce.setter + def nonce(self, value: Optional[str]) -> None: ... + @property + def opaque(self) -> Optional[str]: ... + @opaque.setter + def opaque(self, value: Optional[str]) -> None: ... + @property + def algorithm(self) -> Optional[str]: ... + @algorithm.setter + def algorithm(self, value: Optional[str]) -> None: ... + @property + def qop(self) -> HeaderSet: ... + @property + def stale(self) -> Optional[bool]: ... + @stale.setter + def stale(self, value: Optional[bool]) -> None: ... + @staticmethod + def auth_property(name: str, doc: Optional[str] = None) -> property: ... + +class FileStorage: + name: Optional[str] + stream: BinaryIO + filename: Optional[str] + headers: Headers + _parsed_content_type: Tuple[str, Dict[str, str]] + def __init__( + self, + stream: Optional[BinaryIO] = None, + filename: Optional[str] = None, + name: Optional[str] = None, + content_type: Optional[str] = None, + content_length: Optional[int] = None, + headers: Optional[Headers] = None, + ) -> None: ... + def _parse_content_type(self) -> None: ... + @property + def content_type(self) -> str: ... + @property + def content_length(self) -> int: ... + @property + def mimetype(self) -> str: ... + @property + def mimetype_params(self) -> Dict[str, str]: ... + def save( + self, dst: Union[str, PathLike, BinaryIO], buffer_size: int = ... + ) -> None: ... + def close(self) -> None: ... + def __iter__(self) -> Iterator[bytes]: ... diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index 0e10bb0f2..df0903764 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -7,18 +7,11 @@ import re import sys import time +import typing as t import uuid -from io import BytesIO from itertools import chain from os.path import basename from os.path import join -from typing import Any -from typing import Callable -from typing import Dict -from typing import Hashable -from typing import Iterator -from typing import Tuple -from typing import Union from .._internal import _log from ..http import parse_cookie @@ -28,22 +21,24 @@ from .console import Console from .tbtools import get_current_traceback from .tbtools import render_console_html -from werkzeug.types import WSGIEnvironment + +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment # A week PIN_TIME = 60 * 60 * 24 * 7 -def hash_pin(pin): - if isinstance(pin, str): - pin = pin.encode("utf-8", "replace") - return hashlib.md5(pin + b"shittysalt").hexdigest()[:12] +def hash_pin(pin: str) -> str: + return hashlib.md5(f"{pin} added salt".encode("utf-8", "replace")).hexdigest()[:12] -_machine_id = None +_machine_id: t.Optional[str] = None -def get_machine_id() -> Any: +def get_machine_id() -> str: global _machine_id if _machine_id is not None: @@ -83,7 +78,7 @@ def _generate(): from subprocess import Popen, PIPE dump = Popen( - ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE, + ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE ).communicate()[0] match = re.search(b'"serial-number" = <([^>]+)', dump) @@ -123,12 +118,14 @@ class _ConsoleFrame: standalone console. """ - def __init__(self, namespace): + def __init__(self, namespace: t.Dict[str, t.Any]): self.console = Console(namespace) self.id = 0 -def get_pin_and_cookie_name(app: Callable) -> Tuple[str, str]: +def get_pin_and_cookie_name( + app: "WSGIApplication", +) -> t.Union[t.Tuple[str, str], t.Tuple[None, None]]: """Given an application object this returns a semi-stable 9 digit pin code and a random key. The hope is that this is stable between restarts to not make debugging particularly frustrating. If the pin @@ -152,7 +149,8 @@ def get_pin_and_cookie_name(app: Callable) -> Tuple[str, str]: else: num = pin - modname = getattr(app, "__module__", app.__class__.__module__) + modname = getattr(app, "__module__", t.cast(object, app).__class__.__module__) + username: t.Optional[str] try: # getuser imports the pwd module, which does not exist in Google @@ -238,13 +236,16 @@ class DebuggedApplication: :param pin_logging: enables the logging of the pin system. """ + _pin: str + _pin_cookie: str + def __init__( self, - app: Callable, + app: "WSGIApplication", evalex: bool = False, request_key: str = "werkzeug.request", console_path: str = "/console", - console_init_func: None = None, + console_init_func: t.Optional[t.Callable[[], t.Dict[str, t.Any]]] = None, show_hidden_frames: bool = False, pin_security: bool = True, pin_logging: bool = True, @@ -253,8 +254,8 @@ def __init__( console_init_func = None self.app = app self.evalex = evalex - self.frames: Dict[Hashable, Any] = {} - self.tracebacks: Dict[Hashable, Any] = {} + self.frames: t.Dict[t.Hashable, t.Any] = {} + self.tracebacks: t.Dict[t.Hashable, t.Any] = {} self.request_key = request_key self.console_path = console_path self.console_init_func = console_init_func @@ -268,44 +269,44 @@ def __init__( if os.environ.get("WERKZEUG_RUN_MAIN") == "true" and pin_logging: _log("warning", " * Debugger is active!") if self.pin is None: - _log( - "warning", " * Debugger PIN disabled. DEBUGGER UNSECURED!", - ) + _log("warning", " * Debugger PIN disabled. DEBUGGER UNSECURED!") else: _log("info", " * Debugger PIN: %s", self.pin) else: self.pin = None @property - def pin(self): + def pin(self) -> t.Optional[str]: if not hasattr(self, "_pin"): - self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) + pin_cookie = get_pin_and_cookie_name(self.app) + self._pin, self._pin_cookie = pin_cookie # type: ignore return self._pin @pin.setter - def pin(self, value): + def pin(self, value: str) -> None: self._pin = value @property def pin_cookie_name(self) -> str: """The name of the pin cookie.""" if not hasattr(self, "_pin_cookie"): - self._pin, self._pin_cookie = get_pin_and_cookie_name(self.app) + pin_cookie = get_pin_and_cookie_name(self.app) + self._pin, self._pin_cookie = pin_cookie # type: ignore return self._pin_cookie def debug_application( - self, environ: WSGIEnvironment, start_response: Callable - ) -> Iterator[bytes]: + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterator[bytes]: """Run the application and conserve the traceback frames.""" app_iter = None try: app_iter = self.app(environ, start_response) yield from app_iter if hasattr(app_iter, "close"): - app_iter.close() + app_iter.close() # type: ignore except Exception: if hasattr(app_iter, "close"): - app_iter.close() + app_iter.close() # type: ignore traceback = get_current_traceback( skip=1, show_hidden_frames=self.show_hidden_frames, @@ -338,7 +339,7 @@ def debug_application( else: is_trusted = bool(self.check_pin_trust(environ)) yield traceback.render_full( - evalex=self.evalex, evalex_trusted=is_trusted, secret=self.secret, + evalex=self.evalex, evalex_trusted=is_trusted, secret=self.secret ).encode("utf-8", "replace") traceback.log(environ["wsgi.errors"]) @@ -347,7 +348,7 @@ def execute_command(self, request, command, frame): """Execute a command in a console.""" return Response(frame.console.eval(command), mimetype="text/html") - def display_console(self, request): + def display_console(self, request: Request) -> Response: """Display a standalone shell.""" if 0 not in self.frames: if self.console_init_func is None: @@ -362,7 +363,7 @@ def display_console(self, request): mimetype="text/html", ) - def get_resource(self, request, filename): + def get_resource(self, request: Request, filename: str) -> Response: """Return a static resource from the shared folder.""" filename = join("shared", basename(filename)) try: @@ -374,7 +375,7 @@ def get_resource(self, request, filename): return Response(data, mimetype=mimetype) return Response("Not Found", status=404) - def check_pin_trust(self, environ: WSGIEnvironment) -> bool: + def check_pin_trust(self, environ: "WSGIEnvironment") -> t.Optional[bool]: """Checks if the request passed the pin test. This returns `True` if the request is trusted on a pin/cookie basis and returns `False` if not. Additionally if the cookie's stored pin hash is wrong it will return @@ -382,7 +383,7 @@ def check_pin_trust(self, environ: WSGIEnvironment) -> bool: """ if self.pin is None: return True - val = parse_cookie(environ).get(self.pin_cookie_name) # type: ignore + val = parse_cookie(environ).get(self.pin_cookie_name) if not val or "|" not in val: return False ts, pin_hash = val.split("|", 1) @@ -392,15 +393,16 @@ def check_pin_trust(self, environ: WSGIEnvironment) -> bool: return None return (time.time() - PIN_TIME) < int(ts) - def _fail_pin_auth(self): + def _fail_pin_auth(self) -> None: time.sleep(5.0 if self._failed_pin_auth > 5 else 0.5) self._failed_pin_auth += 1 - def pin_auth(self, request): + def pin_auth(self, request: Request) -> Response: """Authenticates with the pin.""" exhausted = False auth = False trust = self.check_pin_trust(request.environ) + pin = t.cast(str, self.pin) # If the trust return value is `None` it means that the cookie is # set but the stored pin hash value is bad. This means that the @@ -422,8 +424,9 @@ def pin_auth(self, request): # Otherwise go through pin based authentication else: - entered_pin = request.args.get("pin") - if entered_pin.strip().replace("-", "") == self.pin.replace("-", ""): + entered_pin = request.args["pin"] + + if entered_pin.strip().replace("-", "") == pin.replace("-", ""): self._failed_pin_auth = 0 auth = True else: @@ -436,7 +439,7 @@ def pin_auth(self, request): if auth: rv.set_cookie( self.pin_cookie_name, - f"{int(time.time())}|{hash_pin(self.pin)}", + f"{int(time.time())}|{hash_pin(pin)}", httponly=True, samesite="None", ) @@ -444,18 +447,18 @@ def pin_auth(self, request): rv.delete_cookie(self.pin_cookie_name) return rv - def log_pin_request(self): + def log_pin_request(self) -> Response: """Log the pin if needed.""" if self.pin_logging and self.pin is not None: _log( - "info", " * To enable the debugger you need to enter the security pin:", + "info", " * To enable the debugger you need to enter the security pin:" ) _log("info", " * Debugger pin code: %s", self.pin) return Response("") def __call__( - self, environ: WSGIEnvironment, start_response: Callable, - ) -> Iterator[Any]: + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: """Dispatch the requests.""" # important: don't ever access a function here that reads the incoming # form data! Otherwise the application won't have access to that data @@ -468,11 +471,11 @@ def __call__( secret = request.args.get("s") frame = self.frames.get(request.args.get("frm", type=int)) if cmd == "resource" and arg: - response = self.get_resource(request, arg) + response = self.get_resource(request, arg) # type: ignore elif cmd == "pinauth" and secret == self.secret: - response = self.pin_auth(request) + response = self.pin_auth(request) # type: ignore elif cmd == "printpin" and secret == self.secret: - response = self.log_pin_request() + response = self.log_pin_request() # type: ignore elif ( self.evalex and cmd is not None @@ -486,5 +489,5 @@ def __call__( and self.console_path is not None and request.path == self.console_path ): - response = self.display_console(request) + response = self.display_console(request) # type: ignore return response(environ, start_response) diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 31391cf27..c2382c43a 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -1,44 +1,39 @@ import code import sys +import typing as t from html import escape from types import CodeType -from typing import Any -from typing import Callable -from typing import Dict -from typing import Hashable -from typing import List -from typing import Optional -from typing import Union from ..local import Local from .repr import debug_repr from .repr import dump from .repr import helper +if t.TYPE_CHECKING: + import codeop # noqa: F401 + _local = Local() class HTMLStringO: """A StringO version that HTML escapes on write.""" - _buffer: List[str] - def __init__(self) -> None: - self._buffer = [] + self._buffer: t.List[str] = [] - def isatty(self): + def isatty(self) -> bool: return False - def close(self): + def close(self) -> None: pass - def flush(self): + def flush(self) -> None: pass - def seek(self, n, mode: int = 0): + def seek(self, n: int, mode: int = 0) -> None: pass - def readline(self): + def readline(self) -> str: if len(self._buffer) == 0: return "" ret = self._buffer[0] @@ -68,7 +63,7 @@ class ThreadedStream: @staticmethod def push() -> None: if not isinstance(sys.stdout, ThreadedStream): - sys.stdout = ThreadedStream() # type: ignore + sys.stdout = t.cast(t.TextIO, ThreadedStream()) _local.stream = HTMLStringO() @staticmethod @@ -80,7 +75,7 @@ def fetch() -> str: return stream.reset() @staticmethod - def displayhook(obj): + def displayhook(obj: object) -> None: try: stream = _local.stream except AttributeError: @@ -91,22 +86,20 @@ def displayhook(obj): _local._current_ipy.locals["_"] = obj stream._write(debug_repr(obj)) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: t.Any) -> None: raise AttributeError(f"read only attribute {name}") - def __dir__(self): + def __dir__(self) -> t.List[str]: return dir(sys.__stdout__) - def __getattribute__(self, name: str) -> Union[Callable, List[str]]: - if name == "__members__": - return dir(sys.__stdout__) + def __getattribute__(self, name: str) -> t.Any: try: stream = _local.stream except AttributeError: stream = sys.__stdout__ return getattr(stream, name) - def __repr__(self): + def __repr__(self) -> str: return repr(sys.__stdout__) @@ -117,59 +110,53 @@ def __repr__(self): class _ConsoleLoader: def __init__(self) -> None: - self._storage: Dict[Hashable, Any] = {} + self._storage: t.Dict[int, str] = {} - def register(self, code, source): + def register(self, code: CodeType, source: str) -> None: self._storage[id(code)] = source # register code objects of wrapped functions too. for var in code.co_consts: if isinstance(var, CodeType): self._storage[id(var)] = source - def get_source_by_code(self, code): + def get_source_by_code(self, code: CodeType) -> t.Optional[str]: try: return self._storage[id(code)] except KeyError: - pass - - -def _wrap_compiler(console: "_InteractiveConsole") -> None: - compile = console.compile # type: ignore - - def func(source, filename, symbol): - code = compile(source, filename, symbol) - console.loader.register(code, source) - return code - - console.compile = func # type: ignore + return None class _InteractiveConsole(code.InteractiveInterpreter): - globals: Any - more: Any - buffer: Any - - def __init__(self, globals: Dict[Any, Any], locals: Dict[Any, Any]) -> None: - _locals = dict(globals) - _locals.update(locals) - locals = _locals - locals["dump"] = dump - locals["help"] = helper - locals["__loader__"] = self.loader = _ConsoleLoader() - code.InteractiveInterpreter.__init__(self, locals) + locals: t.Dict[str, t.Any] + + def __init__(self, globals: t.Dict[str, t.Any], locals: t.Dict[str, t.Any]) -> None: + self.loader = _ConsoleLoader() + locals = { + **globals, + **locals, + "dump": dump, + "help": helper, + "__loader__": self.loader, + } + super().__init__(locals) + original_compile = self.compile + + def compile(source: str, filename: str, symbol: str) -> CodeType: + code = original_compile(source, filename, symbol) + self.loader.register(code, source) + return code + + self.compile = compile self.more = False - self.buffer = [] - _wrap_compiler(self) + self.buffer: t.List[str] = [] - def runsource(self, source: str, **kwargs: Any) -> str: # type: ignore + def runsource(self, source: str, **kwargs: t.Any) -> str: # type: ignore source = f"{source.rstrip()}\n" ThreadedStream.push() prompt = "... " if self.more else ">>> " try: source_to_eval = "".join(self.buffer + [source]) - if code.InteractiveInterpreter.runsource( - self, source_to_eval, "", "single" - ): + if super().runsource(source_to_eval, "", "single"): self.more = True self.buffer.append(source) else: @@ -179,25 +166,25 @@ def runsource(self, source: str, **kwargs: Any) -> str: # type: ignore output = ThreadedStream.fetch() return prompt + escape(source) + output - def runcode(self, code): + def runcode(self, code: CodeType) -> None: try: exec(code, self.locals) except Exception: self.showtraceback() - def showtraceback(self): + def showtraceback(self) -> None: from .tbtools import get_current_traceback tb = get_current_traceback(skip=1) - sys.stdout._write(tb.render_summary()) + sys.stdout._write(tb.render_summary()) # type: ignore - def showsyntaxerror(self, filename: Optional[Any] = None): + def showsyntaxerror(self, filename: t.Optional[str] = None) -> None: from .tbtools import get_current_traceback tb = get_current_traceback(skip=4) sys.stdout._write(tb.render_summary()) # type: ignore - def write(self, data): + def write(self, data: str) -> None: sys.stdout.write(data) @@ -205,7 +192,9 @@ class Console: """An interactive console.""" def __init__( - self, globals: Optional[Any] = None, locals: Optional[Any] = None + self, + globals: t.Optional[t.Dict[str, t.Any]] = None, + locals: t.Optional[t.Dict[str, t.Any]] = None, ) -> None: if locals is None: locals = {} @@ -213,7 +202,7 @@ def __init__( globals = {} self._ipy = _InteractiveConsole(globals, locals) - def eval(self, code): + def eval(self, code: str) -> str: _local._current_ipy = self._ipy old_sys_stdout = sys.stdout try: diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index 4c74110dd..7d847b039 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -7,31 +7,22 @@ import codecs import re import sys +import typing as t from collections import deque from html import escape from traceback import format_exception_only -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Pattern -from typing import Tuple -from typing import Type -from typing import Union - -missing: Any = object() -_paragraph_re = re.compile(r"(?:\r\n|\r|\n){2,}") -RegexType: Any = type(_paragraph_re) +missing = object() +_paragraph_re = re.compile(r"(?:\r\n|\r|\n){2,}") +RegexType = type(_paragraph_re) -HELP_HTML: Any = """\ +HELP_HTML = """\

%(title)s

%(text)s
\ """ -OBJECT_DUMP_HTML: Any = """\ +OBJECT_DUMP_HTML = """\

%(title)s

%(repr)s @@ -40,12 +31,12 @@ """ -def debug_repr(obj: Any) -> str: +def debug_repr(obj: object) -> str: """Creates a debug repr of an object as HTML string.""" return DebugReprGenerator().repr(obj) -def dump(obj: Union[List[int], object] = missing) -> None: +def dump(obj: object = missing) -> None: """Print the object details to stdout._write (for the interactive console of the web debugger. """ @@ -62,10 +53,10 @@ class _Helper: debugger only because it requires a patched sys.stdout. """ - def __repr__(self): + def __repr__(self) -> str: return "Type help(object) for help about object." - def __call__(self, topic: Optional[Any] = None) -> None: + def __call__(self, topic: t.Optional[t.Any] = None) -> None: if topic is None: sys.stdout._write(f"{self!r}") # type: ignore return @@ -85,10 +76,12 @@ def __call__(self, topic: Optional[Any] = None) -> None: sys.stdout._write(HELP_HTML % {"title": title, "text": text}) # type: ignore -helper: Any = _Helper() +helper = _Helper() -def _add_subclass_info(inner: str, obj: Any, base: Any) -> str: +def _add_subclass_info( + inner: str, obj: object, base: t.Union[t.Type, t.Tuple[t.Type, ...]] +) -> str: if isinstance(base, tuple): for base in base: if type(obj) is base: @@ -101,48 +94,48 @@ def _add_subclass_info(inner: str, obj: Any, base: Any) -> str: return f"{module}{type(obj).__name__}({inner})" +def _sequence_repr_maker( + left: str, right: str, base: t.Type, limit: int = 8 +) -> t.Callable[["DebugReprGenerator", t.Iterable, bool], str]: + def proxy(self: "DebugReprGenerator", obj: t.Iterable, recursive: bool) -> str: + if recursive: + return _add_subclass_info(f"{left}...{right}", obj, base) + buf = [left] + have_extended_section = False + for idx, item in enumerate(obj): + if idx: + buf.append(", ") + if idx == limit: + buf.append('') + have_extended_section = True + buf.append(self.repr(item)) + if have_extended_section: + buf.append("") + buf.append(right) + return _add_subclass_info("".join(buf), obj, base) + + return proxy + + class DebugReprGenerator: def __init__(self) -> None: - self._stack: List[Any] = [] - - def _sequence_repr_maker( # type: ignore - left: str, right: str, base=object(), limit=8 # noqa: B008, B902 - ): - def proxy(self, obj, recursive): - if recursive: - return _add_subclass_info(f"{left}...{right}", obj, base) - buf = [left] - have_extended_section = False - for idx, item in enumerate(obj): - if idx: - buf.append(", ") - if idx == limit: - buf.append('') - have_extended_section = True - buf.append(self.repr(item)) - if have_extended_section: - buf.append("") - buf.append(right) - return _add_subclass_info("".join(buf), obj, base) - - return proxy - - list_repr: Any = _sequence_repr_maker("[", "]", list) - tuple_repr: Any = _sequence_repr_maker("(", ")", tuple) - set_repr: Any = _sequence_repr_maker("set([", "])", set) - frozenset_repr: Any = _sequence_repr_maker("frozenset([", "])", frozenset) - deque_repr: Any = _sequence_repr_maker( + self._stack: t.List[t.Any] = [] + + list_repr = _sequence_repr_maker("[", "]", list) + tuple_repr = _sequence_repr_maker("(", ")", tuple) + set_repr = _sequence_repr_maker("set([", "])", set) + frozenset_repr = _sequence_repr_maker("frozenset([", "])", frozenset) + deque_repr = _sequence_repr_maker( 'collections.deque([', "])", deque ) - del _sequence_repr_maker - def regex_repr(self, obj: Pattern) -> str: + def regex_repr(self, obj: t.Pattern) -> str: pattern = repr(obj.pattern) pattern = codecs.decode(pattern, "unicode-escape", "ignore") # type: ignore pattern = f"r{pattern}" return f're.compile({pattern})' - def string_repr(self, obj: Union[str, bytes], limit: int = 70) -> str: + def string_repr(self, obj: t.Union[str, bytes], limit: int = 70) -> str: buf = [''] r = repr(obj) @@ -171,7 +164,7 @@ def string_repr(self, obj: Union[str, bytes], limit: int = 70) -> str: def dict_repr( self, - d: Union[Dict[int, None], Dict[str, int], Dict[Union[str, int], int]], + d: t.Union[t.Dict[int, None], t.Dict[str, int], t.Dict[t.Union[str, int], int]], recursive: bool, limit: int = 5, ) -> str: @@ -195,12 +188,12 @@ def dict_repr( return _add_subclass_info("".join(buf), d, dict) def object_repr( - self, obj: Optional[Union[Type[dict], Callable, Type[list]]] + self, obj: t.Optional[t.Union[t.Type[dict], t.Callable, t.Type[list]]] ) -> str: r = repr(obj) return f'{escape(r)}' - def dispatch_repr(self, obj: Any, recursive: bool) -> str: + def dispatch_repr(self, obj: t.Any, recursive: bool) -> str: if obj is helper: return f'{helper!r}' if isinstance(obj, (int, float, complex)): @@ -233,7 +226,7 @@ def fallback_repr(self) -> str: f"<broken repr ({escape(info.strip())})>" ) - def repr(self, obj) -> str: + def repr(self, obj: object) -> str: recursive = False for item in self._stack: if item is obj: @@ -250,7 +243,8 @@ def repr(self, obj) -> str: def dump_object(self, obj: object) -> str: repr = None - items: Optional[List[Tuple[str, str]]] = None + items: t.Optional[t.List[t.Tuple[str, str]]] = None + if isinstance(obj, dict): title = "Contents of" items = [] @@ -271,12 +265,12 @@ def dump_object(self, obj: object) -> str: title += f" {object.__repr__(obj)[1:-1]}" return self.render_object_dump(items, title, repr) - def dump_locals(self, d: Dict[str, int]) -> str: + def dump_locals(self, d: t.Dict[str, t.Any]) -> str: items = [(key, self.repr(value)) for key, value in d.items()] return self.render_object_dump(items, "Local variables in frame") def render_object_dump( - self, items: List[Tuple[str, str]], title: str, repr: Optional[Any] = None, + self, items: t.List[t.Tuple[str, str]], title: str, repr: t.Optional[str] = None ) -> str: html_items = [] for key, value in items: diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index 6342ef674..ec9e7c6b1 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -5,14 +5,11 @@ import sys import sysconfig import traceback +import typing as t from html import escape from tokenize import TokenError -from typing import Any -from typing import List -from typing import Optional -from typing import TextIO -from typing import Tuple -from typing import Union +from types import CodeType +from types import TracebackType from .._internal import _to_str from ..filesystem import get_filesystem_encoding @@ -22,13 +19,6 @@ _coding_re = re.compile(br"coding[:=]\s*([-\w.]+)") _line_re = re.compile(br"^(.*?)$", re.MULTILINE) _funcdef_re = re.compile(r"^(\s*def\s)|(.*(? str: return CONSOLE_HTML % { "evalex": "true", "evalex_trusted": "true" if evalex_trusted else "false", @@ -172,8 +162,16 @@ def get_current_traceback( system exit or others. This behavior can be disabled by passing `False` to the function as first parameter. """ - exc_type, exc_value, tb = sys.exc_info() - if ignore_system_exceptions and exc_type in system_exceptions: + info = t.cast( + t.Tuple[t.Type[BaseException], BaseException, TracebackType], sys.exc_info() + ) + exc_type, exc_value, tb = info + + if ignore_system_exceptions and exc_type in { + SystemExit, + KeyboardInterrupt, + GeneratorExit, + }: raise for _ in range(skip): if tb.tb_next is None: @@ -197,7 +195,7 @@ def __init__(self, lineno: int, code: str) -> None: self.current = False @property - def classes(self): + def classes(self) -> t.List[str]: rv = ["line"] if self.in_frame: rv.append("in-frame") @@ -205,7 +203,7 @@ def classes(self): rv.append("current") return rv - def render(self): + def render(self) -> str: return SOURCE_LINE_HTML % { "classes": " ".join(self.classes), "lineno": self.lineno, @@ -216,7 +214,12 @@ def render(self): class Traceback: """Wraps a traceback.""" - def __init__(self, exc_type, exc_value, tb): + def __init__( + self, + exc_type: t.Type[BaseException], + exc_value: BaseException, + tb: TracebackType, + ) -> None: self.exc_type = exc_type self.exc_value = exc_value self.tb = tb @@ -231,11 +234,11 @@ def __init__(self, exc_type, exc_value, tb): while True: self.groups.append(Group(exc_type, exc_value, tb)) memo.add(id(exc_value)) - exc_value = exc_value.__cause__ or exc_value.__context__ + exc_value = exc_value.__cause__ or exc_value.__context__ # type: ignore if exc_value is None or id(exc_value) in memo: break exc_type = type(exc_value) - tb = exc_value.__traceback__ + tb = exc_value.__traceback__ # type: ignore self.groups.reverse() self.frames = [frame for group in self.groups for frame in group.frames] @@ -256,7 +259,7 @@ def exception(self) -> str: """String representation of the final exception.""" return self.groups[-1].exception - def log(self, logfile: Optional[Union[TextIO]] = None) -> None: + def log(self, logfile: t.Optional[t.TextIO] = None) -> None: """Log the ASCII traceback into a file object.""" if logfile is None: logfile = sys.stderr @@ -296,7 +299,7 @@ def render_summary(self, include_title: bool = True) -> str: def render_full( self, evalex: bool = False, - secret: Optional[str] = None, + secret: t.Optional[str] = None, evalex_trusted: bool = True, ) -> str: """Render the Full HTML page with the traceback info.""" @@ -316,7 +319,7 @@ def render_full( } @cached_property - def plaintext(self): + def plaintext(self) -> str: return "\n".join([group.render_text() for group in self.groups]) @property @@ -330,7 +333,12 @@ class Group: exception groups. """ - def __init__(self, exc_type, exc_value, tb): + def __init__( + self, + exc_type: t.Type[BaseException], + exc_value: BaseException, + tb: TracebackType, + ) -> None: self.exc_type = exc_type self.exc_value = exc_value self.info = None @@ -346,10 +354,10 @@ def __init__(self, exc_type, exc_value, tb): self.frames = [] while tb is not None: self.frames.append(Frame(exc_type, exc_value, tb)) - tb = tb.tb_next + tb = tb.tb_next # type: ignore def filter_hidden_frames(self) -> None: - new_frames: List[Frame] = [] + new_frames: t.List[Frame] = [] hidden = False for frame in self.frames: @@ -410,7 +418,12 @@ def render_text(self) -> str: class Frame: """A single frame in a traceback.""" - def __init__(self, exc_type, exc_value, tb): + def __init__( + self, + exc_type: t.Type[BaseException], + exc_value: BaseException, + tb: TracebackType, + ) -> None: self.lineno = tb.tb_lineno self.function_name = tb.tb_frame.f_code.co_name self.locals = tb.tb_frame.f_locals @@ -446,7 +459,7 @@ def render(self, mark_lib: bool = True) -> str: } @cached_property - def is_library(self): + def is_library(self) -> bool: return any( self.filename.startswith(path) for path in sysconfig.get_paths().values() ) @@ -461,7 +474,7 @@ def render_line_context(self) -> str: before, current, after = self.get_context_lines() rv = [] - def render_line(line, cls): + def render_line(line: str, cls: str) -> None: line = line.expandtabs().rstrip() stripped_line = line.strip() prefix = len(line) - len(stripped_line) @@ -478,7 +491,7 @@ def render_line(line, cls): return "\n".join(rv) - def get_annotated_lines(self) -> List[Line]: + def get_annotated_lines(self) -> t.List[Line]: """Helper function that returns lines with extra information.""" lines = [Line(idx + 1, x) for idx, x in enumerate(self.sourcelines)] @@ -504,14 +517,14 @@ def get_annotated_lines(self) -> List[Line]: return lines - def eval(self, code, mode="single"): + def eval(self, code: t.Union[str, CodeType], mode: str = "single") -> t.Any: """Evaluate code in the context of the frame.""" if isinstance(code, str): code = compile(code, "", mode) return eval(code, self.globals, self.locals) @cached_property - def sourcelines(self): + def sourcelines(self) -> t.List[str]: """The sourcecode of the file as list of strings.""" # get sourcecode from loader or file source = None @@ -538,13 +551,13 @@ def sourcelines(self): return source.splitlines() charset = "utf-8" - if source.startswith(UTF8_COOKIE): + if source.startswith(codecs.BOM_UTF8): source = source[3:] else: for idx, match in enumerate(_line_re.finditer(source)): - match = _coding_re.search(match.group()) - if match is not None: - charset = match.group(1) + coding_match = _coding_re.search(match.group()) + if coding_match is not None: + charset = coding_match.group(1).decode("utf-8") break if idx > 1: break @@ -560,7 +573,7 @@ def sourcelines(self): def get_context_lines( self, context: int = 5 - ) -> Union[Tuple[List[str], str, List[Any]], Tuple[List[str], str, List[str]]]: + ) -> t.Tuple[t.List[str], str, t.List[str]]: before = self.sourcelines[self.lineno - context - 1 : self.lineno - 1] past = self.sourcelines[self.lineno : self.lineno + context] return (before, self.current_line, past) @@ -573,7 +586,7 @@ def current_line(self) -> str: return "" @cached_property - def console(self): + def console(self) -> Console: return Console(self.globals, self.locals) @property diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index dc4f4e60a..43d3097bb 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -46,26 +46,17 @@ def application(environ, start_response): return e """ import sys +import typing as t from datetime import datetime from html import escape -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import Union from ._internal import _get_environ -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from io import BytesIO # noqa: F401 - from werkzeug.datastructures import WWWAuthenticate # noqa: F401 - from werkzeug.wrappers.response import Response - from werkzeug.wsgi import ClosingIterator +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIEnvironment + from .datastructures import WWWAuthenticate + from .wrappers.response import Response class HTTPException(Exception): @@ -74,11 +65,13 @@ class HTTPException(Exception): of it independently and render nicer error messages. """ - code: Optional[int] = None - description: Optional[str] = None + code: t.Optional[int] = None + description: t.Optional[str] = None def __init__( - self, description: Optional[str] = None, response: Optional["Response"] = None, + self, + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, ) -> None: super().__init__() if description is not None: @@ -86,7 +79,9 @@ def __init__( self.response = response @classmethod - def wrap(cls, exception, name=None): + def wrap( + cls, exception: t.Type[BaseException], name: t.Optional[str] = None + ) -> t.Type["HTTPException"]: """Create an exception that is a subclass of the calling HTTP exception and the ``exception`` argument. @@ -104,7 +99,7 @@ def wrap(cls, exception, name=None): The description includes the wrapped exception message. """ - class newcls(cls, exception): + class newcls(cls, exception): # type: ignore _description = cls.description show_exception = False @@ -130,7 +125,7 @@ def description(self): def description(self, value): self._description = value - newcls.__module__ = sys._getframe(1).f_globals.get("__name__") + newcls.__module__ = sys._getframe(1).f_globals["__name__"] name = name or cls.__name__ + exception.__name__ newcls.__name__ = newcls.__qualname__ = name return newcls @@ -140,16 +135,14 @@ def name(self) -> str: """The status name.""" from .http import HTTP_STATUS_CODES - return HTTP_STATUS_CODES.get(self.code, "Unknown Error") + return HTTP_STATUS_CODES.get(self.code, "Unknown Error") # type: ignore - def get_description( - self, environ: Any = None, # this type might need to be changed - ) -> str: + def get_description(self, environ: t.Optional["WSGIEnvironment"] = None) -> str: """Get the description.""" - description = escape(self.description).replace("\n", "
") + description = escape(self.description).replace("\n", "
") # type: ignore return f"

{description}

" - def get_body(self, environ: Optional[Dict[str, Any]] = None) -> str: + def get_body(self, environ: t.Optional["WSGIEnvironment"] = None) -> str: """Get the HTML body.""" return ( '\n' @@ -159,12 +152,12 @@ def get_body(self, environ: Optional[Dict[str, Any]] = None) -> str: ) def get_headers( - self, environ: Optional[WSGIEnvironment] = None - ) -> List[Tuple[str, str]]: + self, environ: t.Optional["WSGIEnvironment"] = None + ) -> t.List[t.Tuple[str, str]]: """Get a list of headers.""" return [("Content-Type", "text/html; charset=utf-8")] - def get_response(self, environ: Optional[Dict[str, Any]] = None) -> "Response": + def get_response(self, environ: t.Optional["WSGIEnvironment"] = None) -> "Response": """Get a response object. If one was passed to the exception it's returned directly. @@ -183,8 +176,8 @@ def get_response(self, environ: Optional[Dict[str, Any]] = None) -> "Response": return Response(self.get_body(environ), self.code, headers) def __call__( - self, environ: WSGIEnvironment, start_response: Callable, - ) -> "ClosingIterator": + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: """Call the exception as WSGI application. :param environ: the WSGI environment. @@ -295,24 +288,25 @@ class Unauthorized(HTTPException): def __init__( self, - description: Optional[str] = None, - response: None = None, - www_authenticate: Optional[ - Union[List["WWWAuthenticate"], Tuple["WWWAuthenticate"], "WWWAuthenticate"] + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, + www_authenticate: t.Optional[ + t.Union["WWWAuthenticate", t.Iterable["WWWAuthenticate"]] ] = None, ) -> None: - HTTPException.__init__(self, description, response) + super().__init__(description, response) - if www_authenticate is not None: - if not isinstance(www_authenticate, (tuple, list)): - www_authenticate = (www_authenticate,) + from .datastructures import WWWAuthenticate + + if isinstance(www_authenticate, WWWAuthenticate): + www_authenticate = (www_authenticate,) self.www_authenticate = www_authenticate def get_headers( - self, environ: Optional[Dict[Any, Any]] = None - ) -> List[Tuple[str, str]]: - headers = HTTPException.get_headers(self, environ) + self, environ: t.Optional["WSGIEnvironment"] = None + ) -> t.List[t.Tuple[str, str]]: + headers = super().get_headers(environ) if self.www_authenticate: headers.extend(("WWW-Authenticate", str(x)) for x in self.www_authenticate) return headers @@ -362,19 +356,19 @@ class MethodNotAllowed(HTTPException): def __init__( self, - valid_methods: Optional[List[str]] = None, - description: None = None, - response: None = None, + valid_methods: t.Optional[t.Iterable[str]] = None, + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, ) -> None: """Takes an optional list of valid http methods starting with werkzeug 0.3 the list will be mandatory.""" - HTTPException.__init__(self, description=description, response=response) + super().__init__(description=description, response=response) self.valid_methods = valid_methods def get_headers( - self, environ: Optional[Dict[Any, Any]] = None - ) -> List[Tuple[str, str]]: - headers = HTTPException.get_headers(self, environ) + self, environ: t.Optional["WSGIEnvironment"] = None + ) -> t.List[t.Tuple[str, str]]: + headers = super().get_headers(environ) if self.valid_methods: headers.append(("Allow", ", ".join(self.valid_methods))) return headers @@ -388,7 +382,6 @@ class NotAcceptable(HTTPException): """ code = 406 - description = ( "The resource identified by the request is only capable of" " generating response entities which have content" @@ -518,20 +511,22 @@ class RequestedRangeNotSatisfiable(HTTPException): def __init__( self, - length: Optional[int] = None, + length: t.Optional[int] = None, units: str = "bytes", - description: None = None, - response: None = None, + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, ) -> None: """Takes an optional `Content-Range` header value based on ``length`` parameter. """ - HTTPException.__init__(self, description=description, response=response) + super().__init__(description=description, response=response) self.length = length self.units = units - def get_headers(self, environ=None): - headers = HTTPException.get_headers(self, environ) + def get_headers( + self, environ: t.Optional["WSGIEnvironment"] = None + ) -> t.List[t.Tuple[str, str]]: + headers = super().get_headers(environ) if self.length is not None: headers.append(("Content-Range", f"{self.units} */{self.length}")) return headers @@ -628,16 +623,16 @@ class _RetryAfter(HTTPException): def __init__( self, - description: None = None, - response: None = None, - retry_after: Optional[Union[datetime, int]] = None, + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, + retry_after: t.Optional[t.Union[datetime, int]] = None, ) -> None: super().__init__(description, response) self.retry_after = retry_after def get_headers( - self, environ: Optional[Dict[Any, Any]] = None - ) -> List[Tuple[str, str]]: + self, environ: t.Optional["WSGIEnvironment"] = None + ) -> t.List[t.Tuple[str, str]]: headers = super().get_headers(environ) if self.retry_after: @@ -716,9 +711,9 @@ class InternalServerError(HTTPException): def __init__( self, - description: None = None, - response: None = None, - original_exception: None = None, + description: t.Optional[str] = None, + response: t.Optional["Response"] = None, + original_exception: t.Optional[BaseException] = None, ) -> None: #: The original exception that caused this 500 error. Can be #: used by frameworks to provide context when handling @@ -797,11 +792,10 @@ class HTTPVersionNotSupported(HTTPException): ) -default_exceptions: Dict[int, Type[HTTPException]] = {} -__all__ = ["HTTPException"] +default_exceptions: t.Dict[int, t.Type[HTTPException]] = {} -def _find_exceptions(): +def _find_exceptions() -> None: for obj in globals().values(): try: is_http_exception = issubclass(obj, HTTPException) @@ -809,7 +803,6 @@ def _find_exceptions(): is_http_exception = False if not is_http_exception or obj.code is None: continue - __all__.append(obj.__name__) old_obj = default_exceptions.get(obj.code, None) if old_obj is not None and issubclass(obj, old_obj): continue @@ -831,24 +824,28 @@ class Aborter: def __init__( self, - mapping: Optional[Dict[int, Type[NotFound]]] = None, - extra: Optional[Dict[int, Type[NotFound]]] = None, + mapping: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, + extra: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, ) -> None: if mapping is None: - mapping = default_exceptions # type: ignore + mapping = default_exceptions self.mapping = dict(mapping) if extra is not None: self.mapping.update(extra) - def __call__(self, code, *args, **kwargs): - if not args and not kwargs and not isinstance(code, int): + def __call__(self, code: t.Union[int, "Response"], *args, **kwargs) -> None: + from .wrappers.response import Response + + if isinstance(code, Response): raise HTTPException(response=code) + if code not in self.mapping: raise LookupError(f"no exception for {code!r}") + raise self.mapping[code](*args, **kwargs) -def abort(status, *args, **kwargs): +def abort(status: t.Union[int, "Response"], *args, **kwargs) -> None: """Raises an :py:exc:`HTTPException` for the given status code or WSGI application. @@ -860,7 +857,7 @@ def abort(status, *args, **kwargs): abort(Response('Hello World')) """ - return _aborter(status, *args, **kwargs) + _aborter(status, *args, **kwargs) _aborter = Aborter() diff --git a/src/werkzeug/filesystem.py b/src/werkzeug/filesystem.py index d72fac80b..bdb8c83ea 100644 --- a/src/werkzeug/filesystem.py +++ b/src/werkzeug/filesystem.py @@ -1,11 +1,9 @@ import codecs import sys import warnings -from typing import Any # We do not trust traditional unixes. - -has_likely_buggy_unicode_filesystem: Any = ( +has_likely_buggy_unicode_filesystem = ( sys.platform.startswith("linux") or "bsd" in sys.platform ) diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 831bdd11b..707f3834c 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -1,23 +1,10 @@ import codecs -import re +import typing as t from functools import update_wrapper from io import BytesIO from itertools import chain from itertools import repeat from itertools import tee -from typing import Any -from typing import AnyStr -from typing import BinaryIO -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import Union from . import exceptions from ._internal import _to_str @@ -29,11 +16,6 @@ from .wsgi import get_content_length from .wsgi import get_input_stream from .wsgi import make_line_iter -from werkzeug.types import WSGIEnvironment - -if TYPE_CHECKING: - from werkzeug.datastructures import ImmutableMultiDict # noqa: F401 - from werkzeug.wsgi import LimitedStream # noqa: F401 # there are some platforms where SpooledTemporaryFile is not available. # In that case we need to provide a fallback. @@ -44,44 +26,48 @@ SpooledTemporaryFile = None # type: ignore +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment -#: an iterator that yields empty strings -_empty_string_iter = repeat("") - -#: a regular expression for multipart boundaries -_multipart_boundary_re = re.compile("^[ -~]{0,200}[!-~]$") + t_parse_result = t.Tuple[t.BinaryIO, MultiDict, MultiDict] -#: supported http encodings that are also available in python we support -#: for multipart messages. -_supported_multipart_encodings = frozenset(["base64", "quoted-printable"]) + class TStreamFactory(t.Protocol): + def __call__( + self, + total_content_length: int, + content_type: t.Optional[str], + filename: t.Optional[str], + content_length: t.Optional[int] = None, + ) -> t.BinaryIO: + ... def default_stream_factory( total_content_length: int, - content_type: Optional[str], - filename: str, - content_length: Optional[int] = None, -) -> BinaryIO: - """The stream factory that is used per default.""" + content_type: t.Optional[str], + filename: t.Optional[str], + content_length: t.Optional[int] = None, +) -> t.BinaryIO: max_size = 1024 * 500 - # because these are opened in binary mode, `BytesIO` is an appropriate return type + if SpooledTemporaryFile is not None: - return SpooledTemporaryFile(max_size=max_size, mode="wb+") # type: ignore - if total_content_length is None or total_content_length > max_size: - return TemporaryFile("wb+") # type: ignore + return t.cast(t.BinaryIO, SpooledTemporaryFile(max_size=max_size, mode="wb+")) + elif total_content_length is None or total_content_length > max_size: + return t.cast(t.BinaryIO, TemporaryFile("wb+")) + return BytesIO() def parse_form_data( - environ: WSGIEnvironment, - stream_factory: None = None, + environ: "WSGIEnvironment", + stream_factory: t.Optional["TStreamFactory"] = None, charset: str = "utf-8", errors: str = "replace", - max_form_memory_size: None = None, - max_content_length: None = None, - cls: None = None, + max_form_memory_size: t.Optional[int] = None, + max_content_length: t.Optional[int] = None, + cls: t.Optional[t.Type[MultiDict]] = None, silent: bool = True, -) -> Tuple[BinaryIO, Type[dict], Type[dict]]: +) -> "t_parse_result": """Parse the form data in the environ and return it as tuple in the form ``(stream, form, files)``. You should only call this method if the transport method is `POST`, `PUT`, or `PATCH`. @@ -141,11 +127,13 @@ def wrapper(self, stream, *args, **kwargs): return f(self, stream, *args, **kwargs) finally: exhaust = getattr(stream, "exhaust", None) + if exhaust is not None: exhaust() else: - while 1: + while True: chunk = stream.read(1024 * 64) + if not chunk: break @@ -182,34 +170,40 @@ class FormDataParser: def __init__( self, - stream_factory: Optional[Callable] = None, + stream_factory: t.Optional["TStreamFactory"] = None, charset: str = "utf-8", errors: str = "replace", - max_form_memory_size: Optional[int] = None, - max_content_length: Optional[int] = None, - cls: Optional[Type[dict]] = None, + max_form_memory_size: t.Optional[int] = None, + max_content_length: t.Optional[int] = None, + cls: t.Optional[t.Type[MultiDict]] = None, silent: bool = True, ) -> None: if stream_factory is None: stream_factory = default_stream_factory + self.stream_factory = stream_factory self.charset = charset self.errors = errors self.max_form_memory_size = max_form_memory_size self.max_content_length = max_content_length + if cls is None: cls = MultiDict + self.cls = cls self.silent = silent def get_parse_func( - self, mimetype: str, options: Dict[str, str] - ) -> Optional[Callable]: + self, mimetype: str, options: t.Dict[str, str] + ) -> t.Optional[ + t.Callable[ + ["FormDataParser", t.BinaryIO, str, t.Optional[int], t.Dict[str, str]], + "t_parse_result", + ] + ]: return self.parse_functions.get(mimetype) - def parse_from_environ( - self, environ: WSGIEnvironment - ) -> Tuple[BytesIO, Type[dict], Type[dict]]: + def parse_from_environ(self, environ: "WSGIEnvironment") -> "t_parse_result": """Parses the information from the environment as form data. :param environ: the WSGI environment to be used for parsing. @@ -222,11 +216,11 @@ def parse_from_environ( def parse( self, - stream: Union["BytesIO", str, "LimitedStream"], + stream: t.BinaryIO, mimetype: str, - content_length: Optional[int], - options: Optional[Dict[str, str]] = None, - ) -> Tuple["BytesIO", Type[dict], Type[dict]]: + content_length: t.Optional[int], + options: t.Optional[t.Dict[str, str]] = None, + ) -> "t_parse_result": """Parses the information from the given stream, mimetype, content length and mimetype parameters. @@ -243,10 +237,12 @@ def parse( and content_length > self.max_content_length ): raise exceptions.RequestEntityTooLarge() + if options is None: options = {} parse_func = self.get_parse_func(mimetype, options) + if parse_func is not None: try: return parse_func(self, stream, mimetype, content_length, options) @@ -254,16 +250,16 @@ def parse( if not self.silent: raise - return stream, self.cls(), self.cls() # type: ignore + return stream, self.cls(), self.cls() @exhaust_stream def _parse_multipart( self, - stream: BinaryIO, + stream: t.BinaryIO, mimetype: str, content_length: int, - options: Dict[str, str], - ) -> Tuple[BinaryIO, dict, dict]: + options: t.Dict[str, str], + ) -> "t_parse_result": parser = MultiPartParser( self.stream_factory, self.charset, @@ -271,56 +267,60 @@ def _parse_multipart( max_form_memory_size=self.max_form_memory_size, cls=self.cls, ) - boundary = options.get("boundary") - if boundary is None: + boundary = options.get("boundary", "").encode("ascii") + + if not boundary: raise ValueError("Missing boundary") - if isinstance(boundary, str): - boundary = boundary.encode("ascii") # type: ignore - form, files = parser.parse(stream, boundary, content_length) # type: ignore + + form, files = parser.parse(stream, boundary, content_length) return stream, form, files @exhaust_stream def _parse_urlencoded( self, - stream: BinaryIO, + stream: t.BinaryIO, mimetype: str, content_length: int, - options: Dict[Any, Any], - ) -> Union[BinaryIO, Type[dict], Type[dict]]: + options: t.Dict[str, str], + ) -> "t_parse_result": if ( self.max_form_memory_size is not None and content_length is not None and content_length > self.max_form_memory_size ): raise exceptions.RequestEntityTooLarge() + form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls) - return stream, form, self.cls() # type: ignore + return stream, form, self.cls() #: mapping of mimetypes to parsing functions - parse_functions = { + parse_functions: t.Dict[ + str, + t.Callable[ + ["FormDataParser", t.BinaryIO, str, t.Optional[int], t.Dict[str, str]], + "t_parse_result", + ], + ] = { "multipart/form-data": _parse_multipart, "application/x-www-form-urlencoded": _parse_urlencoded, "application/x-url-encoded": _parse_urlencoded, } -def is_valid_multipart_boundary(boundary): - """Checks if the string given is a valid multipart boundary.""" - return _multipart_boundary_re.match(boundary) is not None - - -def _line_parse(line: str) -> Tuple[str, bool]: +def _line_parse(line: str) -> t.Tuple[str, bool]: """Removes line ending characters and returns a tuple (`stripped_line`, `is_terminated`). """ - if line[-2:] in ["\r\n", b"\r\n"]: + if line[-2:] == "\r\n": return line[:-2], True - elif line[-1:] in ["\r", "\n", b"\r", b"\n"]: + + elif line[-1:] in {"\r", "\n"}: return line[:-1], True + return line, False -def parse_multipart_headers(iterable: Union[List[str], chain]) -> Headers: +def parse_multipart_headers(iterable: t.Iterable[bytes]) -> Headers: """Parses multipart headers from an iterable that yields lines (including the trailing newline symbol). The iterable has to be newline terminated. @@ -329,12 +329,15 @@ def parse_multipart_headers(iterable: Union[List[str], chain]) -> Headers: :param iterable: iterable of strings that are newline terminated """ - result: List[Any] = [] - for line in iterable: - line = _to_str(line) + result: t.List[t.Tuple[str, str]] = [] + + for b_line in iterable: + line = _to_str(b_line) line, line_terminated = _line_parse(line) + if not line_terminated: raise ValueError("unexpected end of line in multipart header") + if not line: break elif line[0] in " \t" and result: @@ -342,6 +345,7 @@ def parse_multipart_headers(iterable: Union[List[str], chain]) -> Headers: result[-1] = (key, f"{value}\n {line[1:]}") else: parts = line.split(":", 1) + if len(parts) == 2: result.append((parts[0].strip(), parts[1].strip())) @@ -359,22 +363,26 @@ def parse_multipart_headers(iterable: Union[List[str], chain]) -> Headers: class MultiPartParser: def __init__( self, - stream_factory: Optional[Union[Callable, int]] = None, + stream_factory: t.Optional["TStreamFactory"] = None, charset: str = "utf-8", errors: str = "replace", - max_form_memory_size: Optional[int] = None, - cls: Optional[ - Union[Type["ImmutableMultiDict"], Type[dict], Type["MultiDict"]] - ] = None, + max_form_memory_size: t.Optional[int] = None, + cls: t.Optional[t.Type[MultiDict]] = None, buffer_size: int = 64 * 1024, ) -> None: self.charset = charset self.errors = errors self.max_form_memory_size = max_form_memory_size - self.stream_factory = ( - default_stream_factory if stream_factory is None else stream_factory - ) - self.cls = MultiDict if cls is None else cls + + if stream_factory is None: + stream_factory = default_stream_factory + + self.stream_factory = stream_factory + + if cls is None: + cls = MultiDict + + self.cls = cls # make sure the buffer size is divisible by four so that we can base64 # decode chunk by chunk @@ -382,7 +390,6 @@ def __init__( # also the buffer size has to be at least 1024 bytes long or long headers # will freak out the system assert buffer_size >= 1024, "buffer size has to be at least 1KB" - self.buffer_size = buffer_size def _fix_ie_filename(self, filename: str) -> str: @@ -392,53 +399,61 @@ def _fix_ie_filename(self, filename: str) -> str: """ if filename[1:3] == ":\\" or filename[:2] == "\\\\": return filename.split("\\")[-1] + return filename - def _find_terminator(self, iterator: Union[Iterable[AnyStr]]) -> Union[bytes, str]: + def _find_terminator(self, iterator: t.Iterable[bytes]) -> bytes: """The terminator might have some additional newlines before it. - There is at least one application that xsends additional newlines + There is at least one application that sends additional newlines before headers (the python setuptools package). """ for line in iterator: if not line: break + line = line.strip() + if line: return line + return b"" - def fail(self, message): + def fail(self, message: str) -> "t.NoReturn": raise ValueError(message) - def get_part_encoding(self, headers: Headers) -> Optional[str]: - transfer_encoding = headers.get("content-transfer-encoding") - if ( - transfer_encoding is not None - and transfer_encoding in _supported_multipart_encodings - ): + def get_part_encoding(self, headers: Headers) -> t.Optional[str]: + transfer_encoding: t.Optional[str] = headers.get("content-transfer-encoding") + + if transfer_encoding in {"base64", "quoted-printable"}: return transfer_encoding + return None def get_part_charset(self, headers: Headers) -> str: # Figure out input charset for current part content_type = headers.get("content-type") + if content_type: mimetype, ct_params = parse_options_header(content_type) return ct_params.get("charset", self.charset) + return self.charset def start_file_streaming( self, filename: str, headers: Headers, total_content_length: int - ) -> Union[Tuple[str, BytesIO], Tuple[str, SpooledTemporaryFile]]: + ) -> t.Tuple[str, t.BinaryIO]: if isinstance(filename, bytes): filename = filename.decode(self.charset, self.errors) + filename = self._fix_ie_filename(filename) content_type = headers.get("content-type") + try: content_length = int(headers["content-length"]) except (KeyError, ValueError): content_length = 0 - container = self.stream_factory( # type: ignore + + container = self.stream_factory( total_content_length=total_content_length, filename=filename, content_type=content_type, @@ -446,33 +461,18 @@ def start_file_streaming( ) return filename, container - def in_memory_threshold_reached(self, bytes): + def in_memory_threshold_reached(self, size: int) -> None: raise exceptions.RequestEntityTooLarge() - def validate_boundary(self, boundary): - if not boundary: - self.fail("Missing boundary") - if not is_valid_multipart_boundary(boundary): - self.fail(f"Invalid boundary: {boundary}") - if len(boundary) > self.buffer_size: - # this should never happen because we check for a minimum size - # of 1024 and boundaries may not be longer than 200. The only - # situation when this happens is for non debug builds where - # the assert is skipped. - self.fail("Boundary longer than buffer size") - def parse_lines( self, - file: BinaryIO, + file: t.BinaryIO, boundary: bytes, content_length: int, cap_at_buffer: bool = True, - ) -> Iterator[ - Union[ - Tuple[str, Tuple[Headers, str]], - Tuple[str, Union[str, bytes]], - Tuple[str, Tuple[Headers, str, str]], - Tuple[str, None], + ) -> t.Iterator[ + t.Tuple[ + str, t.Union[t.Tuple[Headers, str], t.Tuple[Headers, str, str], bytes, None] ] ]: """Generate parts of @@ -487,7 +487,6 @@ def parse_lines( """ next_part = b"--" + boundary last_part = next_part + b"--" - iterator = chain( make_line_iter( file, @@ -495,9 +494,8 @@ def parse_lines( buffer_size=self.buffer_size, cap_at_buffer=cap_at_buffer, ), - _empty_string_iter, + repeat(b""), ) - terminator = self._find_terminator(iterator) if terminator == last_part: @@ -507,38 +505,41 @@ def parse_lines( while terminator != last_part: headers = parse_multipart_headers(iterator) - disposition = headers.get("content-disposition") + if disposition is None: self.fail("Missing Content-Disposition header") + disposition, extra = parse_options_header(disposition) transfer_encoding = self.get_part_encoding(headers) - name = extra.get("name") + name = t.cast(str, extra.get("name")) filename = extra.get("filename") # if no content type is given we stream into memory. A list is # used as a temporary container. if filename is None: yield _begin_form, (headers, name) - # otherwise we parse the rest of the headers and ask the stream # factory for something we can write in. else: yield _begin_file, (headers, name, filename) buf = b"" + for line in iterator: if not line: self.fail("unexpected end of stream") - if line[:2] == b"--": # type: ignore + if line[:2] == b"--": terminator = line.rstrip() - if terminator in (next_part, last_part): # type: ignore + + if terminator in {next_part, last_part}: break if transfer_encoding is not None: if transfer_encoding == "base64": transfer_encoding = "base64_codec" + try: line = codecs.decode(line, transfer_encoding) # type: ignore except Exception: @@ -548,7 +549,6 @@ def parse_lines( # this is usually a newline delimiter. if buf: yield _cont, buf - buf = b"" # If the line ends with windows CRLF we write everything except # the last two bytes. In all other cases however we write @@ -559,88 +559,90 @@ def parse_lines( # truncate the stream. However we do have to make sure that # if something else than a newline is in there we write it # out. - if line[-2:] == b"\r\n": # type: ignore + if line[-2:] == b"\r\n": buf = b"\r\n" cutoff = -2 else: - buf = line[-1:] # type: ignore + buf = line[-1:] cutoff = -1 - yield _cont, line[:cutoff] + yield _cont, line[:cutoff] else: raise ValueError("unexpected end of part") # if we have a leftover in the buffer that is not a newline # character we have to flush it, otherwise we will chop of # certain values. - if buf not in (b"", b"\r", b"\n", b"\r\n"): + if buf not in {b"", b"\r", b"\n", b"\r\n"}: yield _cont, buf yield _end, None def parse_parts( - self, file: BinaryIO, boundary: bytes, content_length: int - ) -> Iterator[Union[Tuple[str, Tuple[str, Union[str, FileStorage]]]]]: + self, file: t.BinaryIO, boundary: bytes, content_length: int + ) -> t.Iterator[t.Tuple[str, t.Tuple[str, t.Union[str, FileStorage]]]]: """Generate ``('file', (name, val))`` and ``('form', (name, val))`` parts. """ in_memory = 0 + guard_memory: bool + is_file: bool + container: t.Union[t.BinaryIO, t.List[bytes]] + _write: t.Callable[[bytes], t.Any] + headers: Headers + name: str + filename: str for ellt, ell in self.parse_lines(file, boundary, content_length): if ellt == _begin_file: - headers, name, filename = ell # type: ignore + headers, name, filename = t.cast(t.Tuple[Headers, str, str], ell) is_file = True guard_memory = False filename, container = self.start_file_streaming( - filename, headers, content_length # type: ignore + filename, headers, content_length ) _write = container.write elif ellt == _begin_form: - headers, name = ell # type: ignore + headers, name = t.cast(t.Tuple[Headers, str], ell) is_file = False - container = [] # type: ignore - _write = container.append # type: ignore + container = [] + _write = container.append guard_memory = self.max_form_memory_size is not None elif ellt == _cont: - _write(ell) # type: ignore + ell = t.cast(bytes, ell) + _write(ell) # if we write into memory and there is a memory size limit we # count the number of bytes in memory and raise an exception if # there is too much data in memory. if guard_memory: in_memory += len(ell) - if in_memory > self.max_form_memory_size: + + if in_memory > self.max_form_memory_size: # type: ignore self.in_memory_threshold_reached(in_memory) elif ellt == _end: if is_file: + container = t.cast(t.BinaryIO, container) container.seek(0) - yield ( # type: ignore + yield ( "file", - ( - name, - FileStorage( - container, - filename, - name, # type: ignore - headers=headers, # type: ignore - ), - ), + (name, FileStorage(container, filename, name, headers=headers)), ) else: - part_charset = self.get_part_charset(headers) # type: ignore - yield ( # type: ignore + part_charset = self.get_part_charset(headers) + yield ( "form", - (name, b"".join(container).decode(part_charset, self.errors),), + (name, b"".join(container).decode(part_charset, self.errors)), ) def parse( - self, file: BinaryIO, boundary: bytes, content_length: int - ) -> Tuple[dict, dict]: - formstream, filestream = tee( + self, file: t.BinaryIO, boundary: bytes, content_length: int + ) -> t.Tuple[MultiDict, MultiDict]: + form_stream, file_stream = tee( self.parse_parts(file, boundary, content_length), 2 ) - form = (p[1] for p in formstream if p[0] == "form") - files = (p[1] for p in filestream if p[0] == "file") - return self.cls(form), self.cls(files) # type: ignore + form = (v for t, v in form_stream if t == "form") + files = (v for t, v in file_stream if t == "file") + return self.cls(form), self.cls(files) diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 0d655a3df..9eda5ea65 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -1,5 +1,7 @@ import base64 import re +import typing +import typing as t import warnings from datetime import datetime from datetime import timedelta @@ -8,18 +10,6 @@ from time import gmtime from time import struct_time from time import time -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import overload -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import TypeVar -from typing import Union from urllib.parse import unquote_to_bytes as _unquote from urllib.request import parse_http_list as _parse_list_header @@ -28,20 +18,11 @@ from ._internal import _make_cookie_domain from ._internal import _to_bytes from ._internal import _to_str -from .types import T -from .types import WSGIEnvironment - -if TYPE_CHECKING: - from .datastructures import ( # noqa: F401 - CallbackDict, - CharsetAccept, - LanguageAccept, - MIMEAccept, - ) - from .datastructures import AnyHeaders +from ._internal import _wsgi_decoding_dance + +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment -_cookie_charset = "latin1" -_basic_auth_charset = "utf-8" # for explanation of "media-range", etc. see Sections 5.3.{1,2} of RFC 7231 _accept_re = re.compile( r""" @@ -66,7 +47,6 @@ "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~" ) _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') -_unsafe_header_chars = set('()<>@,;:"/[]?={} \t') _option_header_piece_re = re.compile( r""" ;\s*,?\s* # newlines were replaced with commas @@ -98,7 +78,6 @@ flags=re.VERBOSE, ) _option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?") - _entity_headers = frozenset( [ "allow", @@ -125,8 +104,6 @@ "upgrade", ] ) - - HTTP_STATUS_CODES = { 100: "Continue", 101: "Switching Protocols", @@ -191,27 +168,12 @@ 507: "Insufficient Storage", 508: "Loop Detected", # see RFC 5842 510: "Not Extended", - 511: "Network Authentication Failed", # see RFC 6585 + 511: "Network Authentication Failed", } -def wsgi_to_bytes(data: Union[str, bytes]) -> bytes: - """If data is not bytes, encode it as latin1 for WSGI.""" - if isinstance(data, bytes): - return data - return data.encode("latin1") # XXX: utf8 fallback? - - -def bytes_to_wsgi(data: bytes) -> str: - assert isinstance(data, bytes), "data must be bytes" - if isinstance(data, str): - return data - else: - return data.decode("latin1") - - def quote_header_value( - value: Union[str, int], extra_chars: str = "", allow_token: bool = True + value: t.Union[str, int], extra_chars: str = "", allow_token: bool = True ) -> str: """Quote a header value if necessary. @@ -223,7 +185,7 @@ def quote_header_value( unchanged. """ if isinstance(value, bytes): - value = bytes_to_wsgi(value) + value = value.decode("latin1") value = str(value) if allow_token: token_chars = _token_chars | set(extra_chars) @@ -261,10 +223,7 @@ def unquote_header_value(value: str, is_filename: bool = False) -> str: def dump_options_header( - header: str, - options: Union[ - Dict[str, Optional[int]], "CallbackDict", Dict[str, int], Dict[str, str] - ], + header: str, options: t.Dict[str, t.Optional[t.Union[str, int]]] ) -> str: """The reverse function to :func:`parse_options_header`. @@ -282,7 +241,10 @@ def dump_options_header( return "; ".join(segments) -def dump_header(iterable: Iterable, allow_token: bool = True) -> str: +def dump_header( + iterable: t.Union[t.Dict[str, t.Union[str, int]], t.Iterable[str]], + allow_token: bool = True, +) -> str: """Dump an HTTP header again. This is the reversal of :func:`parse_list_header`, :func:`parse_set_header` and :func:`parse_dict_header`. This also quotes strings that include an @@ -311,7 +273,7 @@ def dump_header(iterable: Iterable, allow_token: bool = True) -> str: return ", ".join(items) -def dump_csp_header(header: "ContentSecurityPolicy") -> str: +def dump_csp_header(header: "ds.ContentSecurityPolicy") -> str: """Dump a Content Security Policy header. These are structured into policies such as "default-src 'self'; @@ -324,7 +286,7 @@ def dump_csp_header(header: "ContentSecurityPolicy") -> str: return "; ".join(f"{key} {value}" for key, value in header.items()) -def parse_list_header(value: str) -> List[str]: +def parse_list_header(value: str) -> t.List[str]: """Parse lists as described by RFC 2068 Section 2. In particular, parse comma-separated lists where the elements of @@ -354,9 +316,7 @@ def parse_list_header(value: str) -> List[str]: return result -def parse_dict_header( - value: Union[str, bytes], cls: Type[dict] = dict -) -> Dict[str, Optional[str]]: +def parse_dict_header(value: str, cls: t.Type[dict] = dict) -> t.Dict[str, str]: """Parse lists of key, value pairs as described by RFC 2068 Section 2 and convert them into a python dict (or any other mapping object created from the type with a dict like interface provided by the `cls` argument): @@ -383,9 +343,8 @@ def parse_dict_header( :return: an instance of `cls` """ result = cls() - if not isinstance(value, str): - # XXX: validate - value = bytes_to_wsgi(value) + if isinstance(value, bytes): + value = value.decode("latin1") for item in _parse_list_header(value): if "=" not in item: result[item] = None @@ -397,7 +356,21 @@ def parse_dict_header( return result -def parse_options_header(value: Optional[str], multiple: bool = False) -> Any: +@typing.overload +def parse_options_header( + value: t.Optional[str], multiple: "t.Literal[False]" = False +) -> t.Tuple[str, t.Dict[str, str]]: + ... + + +@typing.overload +def parse_options_header( + value: t.Optional[str], multiple: "t.Literal[True]" +) -> t.Tuple[t.Any, ...]: + ... + + +def parse_options_header(value, multiple=False): """Parse a ``Content-Type`` like header into a tuple with the content type and the options: @@ -429,7 +402,7 @@ def parse_options_header(value: Optional[str], multiple: bool = False) -> Any: if not match: break result.append(match.group(1)) # mimetype - options = {} # type: ignore + options = {} # Parse options rest = match.group(2) continued_encoding = None @@ -461,7 +434,7 @@ def parse_options_header(value: Optional[str], multiple: bool = False) -> Any: else: options[option] = option_value rest = rest[optmatch.end() :] - result.append(options) # type: ignore + result.append(options) if multiple is False: return tuple(result) value = rest @@ -469,18 +442,18 @@ def parse_options_header(value: Optional[str], multiple: bool = False) -> Any: return tuple(result) if result else ("", {}) -AcceptClass = TypeVar( - "AcceptClass", "Accept", "CharsetAccept", "LanguageAccept", "MIMEAccept" -) +_TAnyAccept = t.TypeVar("_TAnyAccept", bound="ds.Accept") -@overload -def parse_accept_header(value: str, cls: None,) -> "Accept": +@typing.overload +def parse_accept_header(value: t.Optional[str], cls: None = None) -> "ds.Accept": ... -@overload -def parse_accept_header(value: str, cls: Type[AcceptClass],) -> AcceptClass: +@typing.overload +def parse_accept_header( + value: t.Optional[str], cls: t.Type[_TAnyAccept] +) -> _TAnyAccept: ... @@ -501,7 +474,7 @@ def parse_accept_header(value, cls=None): :return: an instance of `cls`. """ if cls is None: - cls = Accept + cls = ds.Accept if not value: return cls(None) @@ -510,24 +483,28 @@ def parse_accept_header(value, cls=None): for match in _accept_re.finditer(value): quality_match = match.group(2) if not quality_match: - quality: Union[int, float] = 1 + quality: float = 1 else: quality = max(min(float(quality_match), 1), 0) result.append((match.group(1), quality)) return cls(result) -@overload +_TAnyCC = t.TypeVar("_TAnyCC", bound="ds._CacheControl") +_t_cc_update = t.Optional[t.Callable[[_TAnyCC], None]] + + +@typing.overload def parse_cache_control_header( - value: Optional[str], on_update: Optional[Callable], cls: None, -) -> "RequestCacheControl": + value: t.Optional[str], on_update: _t_cc_update, cls: None = None +) -> "ds.RequestCacheControl": ... -@overload +@typing.overload def parse_cache_control_header( - value: Optional[str], on_update: Optional[Callable], cls: Type[T], -) -> T: + value: t.Optional[str], on_update: _t_cc_update, cls: t.Type[_TAnyCC] +) -> _TAnyCC: ... @@ -549,21 +526,27 @@ def parse_cache_control_header(value, on_update=None, cls=None): :return: a `cls` object. """ if cls is None: - cls = RequestCacheControl + cls = ds.RequestCacheControl if not value: return cls(None, on_update) return cls(parse_dict_header(value), on_update) -@overload +_TAnyCSP = t.TypeVar("_TAnyCSP", bound="ds.ContentSecurityPolicy") +_t_csp_update = t.Optional[t.Callable[[_TAnyCSP], None]] + + +@typing.overload def parse_csp_header( - value: Optional[str], on_update: Callable, cls: None, -) -> "ContentSecurityPolicy": + value: t.Optional[str], on_update: _t_csp_update, cls: None = None +) -> "ds.ContentSecurityPolicy": ... -@overload -def parse_csp_header(value: Optional[str], on_update: Callable, cls: Type[T]) -> T: +@typing.overload +def parse_csp_header( + value: t.Optional[str], on_update: _t_csp_update, cls: t.Type[_TAnyCSP] +) -> _TAnyCSP: ... @@ -580,9 +563,8 @@ def parse_csp_header(value, on_update=None, cls=None): :class:`~werkzeug.datastructures.ContentSecurityPolicy` is used. :return: a `cls` object. """ - if cls is None: - cls = ContentSecurityPolicy + cls = ds.ContentSecurityPolicy if value is None: return cls(None, on_update) items = [] @@ -596,8 +578,9 @@ def parse_csp_header(value, on_update=None, cls=None): def parse_set_header( - value: Optional[str], on_update: Optional[Callable] = None -) -> "HeaderSet": + value: t.Optional[str], + on_update: t.Optional[t.Callable[["ds.HeaderSet"], None]] = None, +) -> "ds.HeaderSet": """Parse a set-like header and return a :class:`~werkzeug.datastructures.HeaderSet` object: @@ -623,11 +606,13 @@ def parse_set_header( :return: a :class:`~werkzeug.datastructures.HeaderSet` """ if not value: - return HeaderSet(None, on_update) - return HeaderSet(parse_list_header(value), on_update) + return ds.HeaderSet(None, on_update) + return ds.HeaderSet(parse_list_header(value), on_update) -def parse_authorization_header(value: Optional[str],) -> Optional["Authorization"]: +def parse_authorization_header( + value: t.Optional[str], +) -> t.Optional["ds.Authorization"]: """Parse an HTTP basic/digest authorization header transmitted by the web browser. The return value is either `None` if the header was invalid or not given, otherwise an :class:`~werkzeug.datastructures.Authorization` @@ -638,28 +623,28 @@ def parse_authorization_header(value: Optional[str],) -> Optional["Authorization """ if not value: return None - value = wsgi_to_bytes(value) + value = _wsgi_decoding_dance(value) try: auth_type, auth_info = value.split(None, 1) auth_type = auth_type.lower() except ValueError: return None - if auth_type == b"basic": + if auth_type == "basic": try: username, password = base64.b64decode(auth_info).split(b":", 1) except Exception: return None try: - return Authorization( + return ds.Authorization( "basic", { - "username": _to_str(username, _basic_auth_charset), - "password": _to_str(password, _basic_auth_charset), + "username": _to_str(username, "utf-8"), + "password": _to_str(password, "utf-8"), }, ) except UnicodeDecodeError: return None - elif auth_type == b"digest": + elif auth_type == "digest": auth_map = parse_dict_header(auth_info) for key in "username", "realm", "nonce", "uri", "response": if key not in auth_map: @@ -667,13 +652,14 @@ def parse_authorization_header(value: Optional[str],) -> Optional["Authorization if "qop" in auth_map: if not auth_map.get("nc") or not auth_map.get("cnonce"): return None - return Authorization("digest", auth_map) + return ds.Authorization("digest", auth_map) return None def parse_www_authenticate_header( - value: Optional[str], on_update: Optional[Callable] = None -) -> "WWWAuthenticate": + value: t.Optional[str], + on_update: t.Optional[t.Callable[["ds.WWWAuthenticate"], None]] = None, +) -> "ds.WWWAuthenticate": """Parse an HTTP WWW-Authenticate header into a :class:`~werkzeug.datastructures.WWWAuthenticate` object. @@ -684,33 +670,33 @@ def parse_www_authenticate_header( :return: a :class:`~werkzeug.datastructures.WWWAuthenticate` object. """ if not value: - return WWWAuthenticate(on_update=on_update) + return ds.WWWAuthenticate(on_update=on_update) try: auth_type, auth_info = value.split(None, 1) auth_type = auth_type.lower() except (ValueError, AttributeError): - return WWWAuthenticate(value.strip().lower(), on_update=on_update) - return WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) + return ds.WWWAuthenticate(value.strip().lower(), on_update=on_update) + return ds.WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) -def parse_if_range_header(value: Optional[str]) -> "IfRange": +def parse_if_range_header(value: t.Optional[str]) -> "ds.IfRange": """Parses an if-range header which can be an etag or a date. Returns a :class:`~werkzeug.datastructures.IfRange` object. .. versionadded:: 0.7 """ if not value: - return IfRange() + return ds.IfRange() date = parse_date(value) if date is not None: - return IfRange(date=date) + return ds.IfRange(date=date) # drop weakness information - return IfRange(unquote_etag(value)[0]) + return ds.IfRange(unquote_etag(value)[0]) def parse_range_header( - value: Optional[str], make_inclusive: bool = True -) -> Optional["Range"]: + value: t.Optional[str], make_inclusive: bool = True +) -> t.Optional["ds.Range"]: """Parses a range header into a :class:`~werkzeug.datastructures.Range` object. If the header is missing or malformed `None` is returned. `ranges` is a list of ``(start, stop)`` tuples where the ranges are @@ -740,31 +726,32 @@ def parse_range_header( end = None last_end = -1 elif "-" in item: - begin, end = item.split("-", 1) # type: ignore - begin = begin.strip() # type: ignore - end = end.strip() - if not begin.isdigit(): # type: ignore + begin_str, end_str = item.split("-", 1) + begin_str = begin_str.strip() + end_str = end_str.strip() + if not begin_str.isdigit(): return None - begin = int(begin) + begin = int(begin_str) if begin < last_end or last_end < 0: return None - if end: - if not end.isdigit(): + if end_str: + if not end_str.isdigit(): return None - end = int(end) + 1 # type: ignore - if begin >= end: # type: ignore + end = int(end_str) + 1 + if begin >= end: return None else: end = None - last_end = end # type: ignore + last_end = end if end is not None else -1 ranges.append((begin, end)) - return Range(units, ranges) # type: ignore + return ds.Range(units, ranges) def parse_content_range_header( - value: str, on_update: Optional[Callable] = None -) -> Optional["ContentRange"]: + value: t.Optional[str], + on_update: t.Optional[t.Callable[["ds.ContentRange"], None]] = None, +) -> t.Optional["ds.ContentRange"]: """Parses a range header into a :class:`~werkzeug.datastructures.ContentRange` object or `None` if parsing is not possible. @@ -785,32 +772,28 @@ def parse_content_range_header( if "/" not in rangedef: return None - rng, length = rangedef.split("/", 1) - if length == "*": + rng, length_str = rangedef.split("/", 1) + if length_str == "*": length = None - elif length.isdigit(): - length = int(length) # type: ignore + elif length_str.isdigit(): + length = int(length_str) else: return None if rng == "*": - return ContentRange( - units, None, None, length, on_update=on_update # type: ignore - ) + return ds.ContentRange(units, None, None, length, on_update=on_update) elif "-" not in rng: return None - start, stop = rng.split("-", 1) + start_str, stop_str = rng.split("-", 1) try: - start = int(start) # type: ignore - stop = int(stop) + 1 # type: ignore + start = int(start_str) + stop = int(stop_str) + 1 except ValueError: return None - if is_byte_range_valid(start, stop, length): # type: ignore - return ContentRange( - units, start, stop, length, on_update=on_update # type: ignore - ) + if is_byte_range_valid(start, stop, length): + return ds.ContentRange(units, start, stop, length, on_update=on_update) return None @@ -829,7 +812,9 @@ def quote_etag(etag: str, weak: bool = False) -> str: return etag -def unquote_etag(etag: Optional[str],) -> Union[Tuple[str, bool], Tuple[None, None]]: +def unquote_etag( + etag: t.Optional[str], +) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: """Unquote a single etag: >>> unquote_etag('W/"bar"') @@ -852,14 +837,14 @@ def unquote_etag(etag: Optional[str],) -> Union[Tuple[str, bool], Tuple[None, No return etag, weak -def parse_etags(value: Optional[str]) -> "ETags": +def parse_etags(value: t.Optional[str]) -> "ds.ETags": """Parse an etag header. :param value: the tag header to parse :return: an :class:`~werkzeug.datastructures.ETags` object. """ if not value: - return ETags() + return ds.ETags() strong = [] weak = [] end = len(value) @@ -870,7 +855,7 @@ def parse_etags(value: Optional[str]) -> "ETags": break is_weak, quoted, raw = match.groups() if raw == "*": - return ETags(star_tag=True) + return ds.ETags(star_tag=True) elif quoted: raw = quoted if is_weak: @@ -878,7 +863,7 @@ def parse_etags(value: Optional[str]) -> "ETags": else: strong.append(raw) pos = match.end() - return ETags(strong, weak) + return ds.ETags(strong, weak) def generate_etag(data: bytes) -> str: @@ -886,7 +871,7 @@ def generate_etag(data: bytes) -> str: return md5(data).hexdigest() -def parse_date(value: Optional[str]) -> Optional[datetime]: +def parse_date(value: t.Optional[str]) -> t.Optional[datetime]: """Parse one of the following date formats into a datetime object: .. sourcecode:: text @@ -919,9 +904,10 @@ def parse_date(value: Optional[str]) -> Optional[datetime]: return None -def _dump_date( - d: Optional[Union[float, datetime, int, struct_time]], delim: str -) -> str: +_t_date_input = t.Optional[t.Union[datetime, int, float, struct_time]] + + +def _dump_date(d: _t_date_input, delim: str) -> str: """Used for `http_date` and `cookie_date`.""" if d is None: d = gmtime() @@ -950,7 +936,7 @@ def _dump_date( ) -def cookie_date(expires: Optional[Union[datetime, int, float]] = None) -> str: +def cookie_date(expires: _t_date_input = None) -> str: """Formats the time to ensure compatibility with Netscape's cookie standard. @@ -965,7 +951,7 @@ def cookie_date(expires: Optional[Union[datetime, int, float]] = None) -> str: return _dump_date(expires, "-") -def http_date(timestamp: Optional[Union[float, datetime, int]] = None) -> str: +def http_date(timestamp: _t_date_input = None) -> str: """Formats the time to match the RFC1123 date format. Accepts a floating point number expressed in seconds since the epoch in, a @@ -979,7 +965,7 @@ def http_date(timestamp: Optional[Union[float, datetime, int]] = None) -> str: return _dump_date(timestamp, " ") -def parse_age(value: Optional[str] = None) -> timedelta: +def parse_age(value: t.Optional[str] = None) -> t.Optional[timedelta]: """Parses a base-10 integer count of seconds into a timedelta. If parsing fails, the return value is `None`. @@ -1001,7 +987,7 @@ def parse_age(value: Optional[str] = None) -> timedelta: return None -def dump_age(age: Optional[Union[timedelta, int]] = None) -> Optional[str]: +def dump_age(age: t.Optional[t.Union[timedelta, int]] = None) -> t.Optional[str]: """Formats the duration as a base-10 integer. :param age: should be an integer number of seconds, @@ -1011,8 +997,10 @@ def dump_age(age: Optional[Union[timedelta, int]] = None) -> Optional[str]: if age is None: return None if isinstance(age, timedelta): - age = age.total_seconds() # type: ignore - age = int(age) # type: ignore + age = int(age.total_seconds()) + else: + age = int(age) + if age < 0: raise ValueError("age cannot be negative") @@ -1020,10 +1008,10 @@ def dump_age(age: Optional[Union[timedelta, int]] = None) -> Optional[str]: def is_resource_modified( - environ: WSGIEnvironment, - etag: Optional[str] = None, - data: Optional[Union[bytes, str]] = None, - last_modified: Optional[datetime] = None, + environ: "WSGIEnvironment", + etag: t.Optional[str] = None, + data: t.Optional[t.Union[bytes]] = None, + last_modified: t.Optional[t.Union[datetime, str]] = None, ignore_if_range: bool = True, ) -> bool: """Convenience method for conditional requests. @@ -1041,7 +1029,7 @@ def is_resource_modified( The check is run for methods other than ``GET`` and ``HEAD``. """ if etag is None and data is not None: - etag = generate_etag(data) # type: ignore + etag = generate_etag(data) elif data is not None: raise TypeError("both data and etag given") @@ -1062,7 +1050,7 @@ def is_resource_modified( if_range = parse_if_range_header(environ.get("HTTP_IF_RANGE")) if if_range is not None and if_range.date is not None: - modified_since = if_range.date + modified_since: t.Optional[datetime] = if_range.date else: modified_since = parse_date(environ.get("HTTP_IF_MODIFIED_SINCE")) @@ -1071,6 +1059,8 @@ def is_resource_modified( if etag: etag, _ = unquote_etag(etag) + etag = t.cast(str, etag) + if if_range is not None and if_range.etag is not None: unmodified = parse_etags(if_range.etag).contains(etag) else: @@ -1092,7 +1082,8 @@ def is_resource_modified( def remove_entity_headers( - headers: "AnyHeaders", allowed: Tuple[str, str] = ("expires", "content-location"), + headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]], + allowed: t.Iterable[str] = ("expires", "content-location"), ) -> None: """Remove all entity headers from a list or :class:`Headers` object. This operation works in-place. `Expires` and `Content-Location` headers are @@ -1114,7 +1105,9 @@ def remove_entity_headers( ] -def remove_hop_by_hop_headers(headers: "AnyHeaders") -> None: +def remove_hop_by_hop_headers( + headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]] +) -> None: """Remove all HTTP/1.1 "Hop-by-Hop" headers from a list or :class:`Headers` object. This operation works in-place. @@ -1149,21 +1142,12 @@ def is_hop_by_hop_header(header: str) -> bool: return header.lower() in _hop_by_hop_headers -@overload -def parse_cookie( - header: Union[WSGIEnvironment, str], charset: str, errors: str, cls: None, -) -> "MultiDict": - ... - - -@overload def parse_cookie( - header: Union[WSGIEnvironment, str], charset: str, errors: str, cls: Type[dict], -) -> dict: - ... - - -def parse_cookie(header, charset="utf-8", errors="replace", cls=None): + header: t.Union["WSGIEnvironment", str, bytes], + charset: str = "utf-8", + errors: str = "replace", + cls: t.Optional[t.Type["ds.MultiDict"]] = None, +) -> "ds.MultiDict[str, str]": """Parse a cookie from a string or WSGI environ. The same key can be provided multiple times, the values are stored @@ -1197,7 +1181,7 @@ def parse_cookie(header, charset="utf-8", errors="replace", cls=None): header = header.encode("latin1", "replace") if cls is None: - cls = MultiDict + cls = ds.MultiDict def _parse_pairs(): for key, val in _cookie_parse_impl(header): @@ -1212,17 +1196,17 @@ def _parse_pairs(): def dump_cookie( key: str, - value: Union[str, bytes] = "", - max_age: Optional[Union[int, timedelta]] = None, - expires: Optional[Union[float, int, datetime]] = None, - path: str = "/", - domain: Optional[str] = None, + value: str = "", + max_age: t.Optional[t.Union[timedelta, int]] = None, + expires: t.Optional[t.Union[str, datetime, int, float]] = None, + path: t.Optional[str] = "/", + domain: t.Optional[str] = None, secure: bool = False, httponly: bool = False, charset: str = "utf-8", sync_expires: bool = True, max_size: int = 4093, - samesite: Optional[str] = None, + samesite: t.Optional[str] = None, ) -> str: """Create a Set-Cookie header without the ``Set-Cookie`` prefix. @@ -1277,9 +1261,9 @@ def dump_cookie( max_age = (max_age.days * 60 * 60 * 24) + max_age.seconds if expires is not None: if not isinstance(expires, str): - expires = cookie_date(expires) # type: ignore + expires = cookie_date(expires) elif max_age is not None and sync_expires: - expires = _to_bytes(cookie_date(time() + max_age)) # type: ignore + expires = cookie_date(time() + max_age) if samesite is not None: samesite = samesite.title() @@ -1330,7 +1314,7 @@ def dump_cookie( if max_size and cookie_size > max_size: value_size = len(value) warnings.warn( - f'The "{key}" cookie is too large: the value was' # type: ignore + f"The {key.decode(charset)!r} cookie is too large: the value was" f" {value_size} bytes but the" f" header required {cookie_size - value_size} extra bytes. The final size" f" was {cookie_size} bytes but the limit is {max_size} bytes. Browsers may" @@ -1342,7 +1326,7 @@ def dump_cookie( def is_byte_range_valid( - start: Optional[int], stop: Optional[int], length: Optional[int] + start: t.Optional[int], stop: t.Optional[int], length: t.Optional[int] ) -> bool: """Checks if a given byte content range is valid for the given length. @@ -1353,21 +1337,11 @@ def is_byte_range_valid( elif start is None: return length is None or length >= 0 elif length is None: - return 0 <= start < stop - elif start >= stop: + return 0 <= start < stop # type: ignore + elif start >= stop: # type: ignore return False return 0 <= start < length # circular dependencies -from .datastructures import Accept -from .datastructures import Authorization -from .datastructures import ContentRange -from .datastructures import ContentSecurityPolicy -from .datastructures import ETags -from .datastructures import HeaderSet -from .datastructures import IfRange -from .datastructures import MultiDict -from .datastructures import Range -from .datastructures import RequestCacheControl -from .datastructures import WWWAuthenticate +from . import datastructures as ds diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index 9e5e503a5..8f9352664 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -1,17 +1,8 @@ import copy +import typing as t from functools import update_wrapper -from typing import Any -from typing import Callable -from typing import List -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from .wsgi import ClosingIterator -from werkzeug.types import WSGIEnvironment - -if TYPE_CHECKING: - from werkzeug.debug.console import HTMLStringO, _InteractiveConsole # noqa: F401 # Each thread has its own greenlet, use that as the identifier for the # context. If greenlets are not available fall back to the current @@ -21,8 +12,11 @@ except ImportError: from threading import get_ident +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + -def release_local(local: Union["LocalStack", "Local"]) -> None: +def release_local(local: t.Union["Local", "LocalStack"]) -> None: """Releases the contents of the local for the current context. This makes it possible to use locals without a manager. @@ -52,7 +46,7 @@ def __init__(self) -> None: object.__setattr__(self, "__storage__", {}) object.__setattr__(self, "__ident_func__", get_ident) - def __iter__(self): + def __iter__(self) -> t.Iterator[t.Tuple[int, t.Any]]: return iter(self.__storage__.items()) def __call__(self, proxy: str) -> "LocalProxy": @@ -62,15 +56,13 @@ def __call__(self, proxy: str) -> "LocalProxy": def __release_local__(self) -> None: self.__storage__.pop(self.__ident_func__(), None) - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str) -> t.Any: try: return self.__storage__[self.__ident_func__()][name] except KeyError: raise AttributeError(name) - def __setattr__( - self, name: str, value: Union["_InteractiveConsole", "HTMLStringO", int] - ) -> None: + def __setattr__(self, name: str, value: t.Any) -> None: ident = self.__ident_func__() storage = self.__storage__ try: @@ -119,11 +111,11 @@ def __release_local__(self) -> None: self._local.__release_local__() @property - def __ident_func__(self): + def __ident_func__(self) -> t.Callable[[], int]: return self._local.__ident_func__ @__ident_func__.setter - def __ident_func__(self, value): + def __ident_func__(self, value: t.Callable[[], int]) -> None: object.__setattr__(self._local, "__ident_func__", value) def __call__(self) -> "LocalProxy": @@ -135,15 +127,15 @@ def _lookup(): return LocalProxy(_lookup) - def push(self, obj: Any) -> Any: + def push(self, obj: t.Any) -> t.List[t.Any]: """Pushes a new item to the stack""" rv = getattr(self._local, "stack", None) if rv is None: - self._local.stack = rv = [] # type: ignore + self._local.stack = rv = [] rv.append(obj) return rv - def pop(self) -> Any: + def pop(self) -> t.Any: """Removes the topmost item from the stack, will return the old value or `None` if the stack was already empty. """ @@ -157,7 +149,7 @@ def pop(self) -> Any: return stack.pop() @property - def top(self) -> Any: + def top(self) -> t.Any: """The topmost item on the stack. If the stack is empty, `None` is returned. """ @@ -186,15 +178,15 @@ class LocalManager: def __init__( self, - locals: Optional[List[Union[Local, LocalStack]]] = None, - ident_func: Optional[Callable] = None, + locals: t.Optional[t.Iterable[t.Union[Local, LocalStack]]] = None, + ident_func: t.Optional[t.Callable[[], int]] = None, ) -> None: if locals is None: self.locals = [] elif isinstance(locals, Local): self.locals = [locals] else: - self.locals = list(locals) # type: ignore + self.locals = list(locals) if ident_func is not None: self.ident_func = ident_func for local in self.locals: @@ -202,7 +194,7 @@ def __init__( else: self.ident_func = get_ident - def get_ident(self) -> Any: + def get_ident(self) -> int: """Return the context identifier the local objects use internally for this context. You cannot override this method to change the behavior but use it to link other context local objects (such as SQLAlchemy's @@ -222,9 +214,7 @@ def cleanup(self): for local in self.locals: release_local(local) - def make_middleware( - self, app: Callable[[Any, Any], Any] - ) -> Callable[[WSGIEnvironment, Any], ClosingIterator]: + def make_middleware(self, app: "WSGIApplication") -> "WSGIApplication": """Wrap a WSGI application so that cleaning up happens after request end. """ @@ -234,7 +224,7 @@ def application(environ, start_response): return application - def middleware(self, func: Callable) -> Callable: + def middleware(self, func: "WSGIApplication") -> "WSGIApplication": """Like `make_middleware` but for decorating functions. Example usage:: @@ -249,7 +239,7 @@ def application(environ, start_response): """ return update_wrapper(self.make_middleware(func), func) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} storages: {len(self.locals)}>" @@ -292,7 +282,9 @@ class LocalProxy: __slots__ = ("__local", "__dict__", "__name__", "__wrapped__") def __init__( - self, local: Union[Any, "LocalProxy", "LocalStack"], name: Optional[str] = None, + self, + local: t.Union["Local", t.Callable[[], t.Any]], + name: t.Optional[str] = None, ) -> None: object.__setattr__(self, "_LocalProxy__local", local) object.__setattr__(self, "__name__", name) @@ -301,7 +293,7 @@ def __init__( # LocalManager: mark it as a wrapped function. object.__setattr__(self, "__wrapped__", local) - def _get_current_object(self,) -> object: + def _get_current_object(self) -> t.Any: """Return the current object. This is useful if you want the real object behind the proxy at a time for performance reasons or because you want to pass the object into a different context. @@ -333,21 +325,21 @@ def __bool__(self) -> bool: except RuntimeError: return False - def __dir__(self): + def __dir__(self) -> t.List[str]: try: return dir(self._get_current_object()) except RuntimeError: return [] - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str) -> t.Any: if name == "__members__": return dir(self._get_current_object()) return getattr(self._get_current_object(), name) - def __setitem__(self, key: Any, value: Any) -> None: - self._get_current_object()[key] = value # type: ignore + def __setitem__(self, key: str, value: t.Any) -> None: + self._get_current_object()[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str): del self._get_current_object()[key] __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) # type: ignore diff --git a/src/werkzeug/middleware/dispatcher.py b/src/werkzeug/middleware/dispatcher.py index 7bc7d65d2..5f79db910 100644 --- a/src/werkzeug/middleware/dispatcher.py +++ b/src/werkzeug/middleware/dispatcher.py @@ -30,13 +30,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from typing import Iterable -from typing import Mapping -from typing import Optional -from typing import Text -from typing import TYPE_CHECKING +import typing as t -if TYPE_CHECKING: +if t.TYPE_CHECKING: from wsgiref.types import StartResponse from wsgiref.types import WSGIApplication from wsgiref.types import WSGIEnvironment @@ -55,14 +51,14 @@ class DispatcherMiddleware: def __init__( self, app: "WSGIApplication", - mounts: Optional[Mapping[Text, "WSGIApplication"]] = None, + mounts: t.Optional[t.Dict[str, "WSGIApplication"]] = None, ) -> None: self.app = app self.mounts = mounts or {} def __call__( self, environ: "WSGIEnvironment", start_response: "StartResponse" - ) -> Iterable[bytes]: + ) -> t.Iterable[bytes]: script = environ.get("PATH_INFO", "") path_info = "" diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 56d2f8eac..c09231c11 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,13 +7,8 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ +import typing as t from http import client -from typing import Any -from typing import Iterable -from typing import Mapping -from typing import MutableMapping -from typing import Text -from typing import TYPE_CHECKING from ..datastructures import EnvironHeaders from ..http import is_hop_by_hop_header @@ -21,14 +16,11 @@ from ..urls import url_quote from ..wsgi import get_input_stream -if TYPE_CHECKING: +if t.TYPE_CHECKING: from wsgiref.types import StartResponse from wsgiref.types import WSGIApplication from wsgiref.types import WSGIEnvironment -_MutableOpts = MutableMapping[Text, Any] -_Opts = Mapping[Text, Any] - class ProxyMiddleware: """Proxy requests under a path to an external server, routing other @@ -87,11 +79,11 @@ class ProxyMiddleware: def __init__( self, app: "WSGIApplication", - targets: Mapping[Text, _MutableOpts], + targets: t.Mapping[str, t.Dict[str, t.Any]], chunk_size: int = 2 << 13, timeout: int = 10, ) -> None: - def _set_defaults(opts): + def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: opts.setdefault("remove_prefix", False) opts.setdefault("host", "") opts.setdefault("headers", {}) @@ -105,10 +97,15 @@ def _set_defaults(opts): self.chunk_size = chunk_size self.timeout = timeout - def proxy_to(self, opts: _Opts, path: Text, prefix: Text) -> "WSGIApplication": + def proxy_to( + self, opts: t.Dict[str, t.Any], path: str, prefix: str + ) -> "WSGIApplication": target = url_parse(opts["target"]) + host = t.cast(str, target.ascii_host) - def application(environ, start_response): + def application( + environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: headers = list(EnvironHeaders(environ).items()) headers[:] = [ (k, v) @@ -119,7 +116,7 @@ def application(environ, start_response): headers.append(("Connection", "close")) if opts["host"] == "": - headers.append(("Host", target.ascii_host)) + headers.append(("Host", host)) elif opts["host"] is None: headers.append(("Host", environ["HTTP_HOST"])) else: @@ -136,7 +133,7 @@ def application(environ, start_response): chunked = False if content_length not in ("", None): - headers.append(("Content-Length", content_length)) + headers.append(("Content-Length", content_length)) # type: ignore elif content_length is not None: headers.append(("Transfer-Encoding", "chunked")) chunked = True @@ -144,11 +141,11 @@ def application(environ, start_response): try: if target.scheme == "http": con = client.HTTPConnection( - target.ascii_host, target.port or 80, timeout=self.timeout, + host, target.port or 80, timeout=self.timeout ) elif target.scheme == "https": con = client.HTTPSConnection( - target.ascii_host, + host, target.port or 443, timeout=self.timeout, context=opts["ssl_context"], @@ -177,7 +174,7 @@ def application(environ, start_response): con.endheaders() stream = get_input_stream(environ) - while 1: + while True: data = stream.read(self.chunk_size) if not data: @@ -203,8 +200,8 @@ def application(environ, start_response): ], ) - def read(): - while 1: + def read() -> t.Iterator[bytes]: + while True: try: data = resp.read(self.chunk_size) except OSError: @@ -221,7 +218,7 @@ def read(): def __call__( self, environ: "WSGIEnvironment", start_response: "StartResponse" - ) -> Iterable[bytes]: + ) -> t.Iterable[bytes]: path = environ["PATH_INFO"] app = self.app diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 298914836..0724f014f 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -12,21 +12,19 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from io import BytesIO -from typing import Any -from typing import Callable -from typing import IO -from typing import List -from typing import Set -from typing import Tuple -from typing import Union +import typing as t +from types import TracebackType from urllib.parse import urlparse from warnings import warn from ..datastructures import Headers from ..http import is_entity_header from ..wsgi import FileWrapper -from werkzeug.types import WSGIEnvironment + +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment class WSGIWarning(Warning): @@ -37,15 +35,17 @@ class HTTPWarning(Warning): """Warning class for HTTP warnings.""" -def check_string(context: str, obj: Any, stacklevel: int = 3) -> None: - if type(obj) is not str: +def check_type(context: str, obj: object, need: t.Type = str) -> None: + if type(obj) is not need: warn( - f"'{context}' requires strings, got '{type(obj).__name__}'", WSGIWarning, + f"{context!r} requires {need.__name__!r}, got {type(obj).__name__!r}.", + WSGIWarning, + stacklevel=3, ) class InputStream: - def __init__(self, stream: BytesIO) -> None: + def __init__(self, stream: t.BinaryIO) -> None: self._stream = stream def read(self, *args): @@ -92,18 +92,16 @@ def __iter__(self): return iter(()) def close(self): - warn( - "The application closed the input stream!", WSGIWarning, stacklevel=2, - ) + warn("The application closed the input stream!", WSGIWarning, stacklevel=2) self._stream.close() class ErrorStream: - def __init__(self, stream: IO) -> None: + def __init__(self, stream: t.TextIO) -> None: self._stream = stream def write(self, s): - check_string("wsgi.error.write()", s) + check_type("wsgi.error.write()", s, str) self._stream.write(s) def flush(self): @@ -114,29 +112,27 @@ def writelines(self, seq): self.write(line) def close(self): - warn( - "The application closed the error stream!", WSGIWarning, stacklevel=2, - ) + warn("The application closed the error stream!", WSGIWarning, stacklevel=2) self._stream.close() class GuardedWrite: - def __init__(self, write: Callable, chunks: List[Any]) -> None: + def __init__(self, write: t.Callable[[bytes], None], chunks: t.List[int]) -> None: self._write = write self._chunks = chunks - def __call__(self, s): - check_string("write()", s) - self._write.write(s) + def __call__(self, s: bytes) -> None: + check_type("write()", s, bytes) + self._write(s) self._chunks.append(len(s)) class GuardedIterator: def __init__( self, - iterator: List[str], - headers_set: Set[Union[int, Headers]], - chunks: List[Any], + iterator: t.Iterable[bytes], + headers_set: t.Tuple[int, Headers], + chunks: t.List[int], ) -> None: self._iterator = iterator self._next = iter(iterator).__next__ @@ -147,7 +143,7 @@ def __init__( def __iter__(self) -> "GuardedIterator": return self - def __next__(self) -> str: + def __next__(self) -> bytes: if self.closed: warn("Iterated over closed 'app_iter'.", WSGIWarning, stacklevel=2) @@ -160,7 +156,7 @@ def __next__(self) -> str: stacklevel=2, ) - check_string("application iterator items", rv) + check_type("application iterator items", rv, bytes) self.chunks.append(len(rv)) return rv @@ -173,34 +169,31 @@ def close(self) -> None: if self.headers_set: status_code, headers = self.headers_set bytes_sent = sum(self.chunks) - content_length = headers.get("content-length", type=int) # type: ignore + content_length = headers.get("content-length", type=int) if status_code == 304: - for key, _value in headers: # type: ignore + for key, _value in headers: key = key.lower() - if key not in ("expires", "content-location",) and is_entity_header( + if key not in ("expires", "content-location") and is_entity_header( key ): warn( - f"Entity header {key!r} found in 304 response.", - HTTPWarning, + f"Entity header {key!r} found in 304 response.", HTTPWarning ) if bytes_sent: warn("304 responses must not have a body.", HTTPWarning) - elif 100 <= status_code < 200 or status_code == 204: # type: ignore + elif 100 <= status_code < 200 or status_code == 204: if content_length != 0: warn( f"{status_code} responses must have an empty content length.", HTTPWarning, ) if bytes_sent: - warn( - f"{status_code} responses must not have a body.", HTTPWarning, - ) + warn(f"{status_code} responses must not have a body.", HTTPWarning) elif content_length is not None and content_length != bytes_sent: warn( - "Content-Length and the number of bytes sent to the client do not" - " match.", + "Content-Length and the number of bytes sent to the" + " client do not match.", WSGIWarning, ) @@ -208,7 +201,7 @@ def __del__(self) -> None: if not self.closed: try: warn( - "Iterator was garbage collected before it was closed.", WSGIWarning, + "Iterator was garbage collected before it was closed.", WSGIWarning ) except Exception: pass @@ -237,10 +230,10 @@ class LintMiddleware: app = LintMiddleware(app) """ - def __init__(self, app: Callable) -> None: + def __init__(self, app: "WSGIApplication") -> None: self.app = app - def check_environ(self, environ: WSGIEnvironment,) -> None: + def check_environ(self, environ: "WSGIEnvironment") -> None: if type(environ) is not dict: warn( "WSGI environment is not a standard Python dict.", @@ -287,52 +280,51 @@ def check_environ(self, environ: WSGIEnvironment,) -> None: def check_start_response( self, status: str, - headers: Union[List[Tuple[str, str]], Tuple[Tuple[str, str]], List[str]], - exc_info: None, - ) -> Tuple[int, Headers]: - check_string("status", status) - status_code = status.split(None, 1)[0] + headers: t.List[t.Tuple[str, str]], + exc_info: t.Optional[ + t.Tuple[t.Type[BaseException], BaseException, TracebackType] + ], + ) -> t.Tuple[int, Headers]: + check_type("status", status, str) + status_code_str = status.split(None, 1)[0] - if len(status_code) != 3 or not status_code.isdigit(): - warn(WSGIWarning("Status code must be three digits"), stacklevel=3) + if len(status_code_str) != 3 or not status_code_str.isdigit(): + warn("Status code must be three digits.", WSGIWarning, stacklevel=3) if len(status) < 4 or status[3] != " ": warn( - WSGIWarning( - f"Invalid value for status {status!r}. Valid " - "status strings are three digits, a space " - "and a status explanation" - ), + f"Invalid value for status {status!r}. Valid status strings are three" + " digits, a space and a status explanation.", + WSGIWarning, stacklevel=3, ) - status_code = int(status_code) + status_code = int(status_code_str) if status_code < 100: - warn(WSGIWarning("status code < 100 detected"), stacklevel=3) + warn("Status code < 100 detected.", WSGIWarning, stacklevel=3) if type(headers) is not list: - warn(WSGIWarning("header list is not a list"), stacklevel=3) + warn("Header list is not a list.", WSGIWarning, stacklevel=3) for item in headers: if type(item) is not tuple or len(item) != 2: + warn("Header items must be 2-item tuples.", WSGIWarning, stacklevel=3) + name, value = item + if type(name) is not str or type(value) is not str: warn( - WSGIWarning("Headers must tuple 2-item tuples"), stacklevel=3, + "Header keys and values must be strings.", WSGIWarning, stacklevel=3 ) - name, value = item # type: ignore - if type(name) is not str or type(value) is not str: - warn(WSGIWarning("header items must be strings"), stacklevel=3) if name.lower() == "status": warn( - WSGIWarning( - "The status header is not supported due to " - "conflicts with the CGI spec." - ), + "The status header is not supported due to" + " conflicts with the CGI spec.", + WSGIWarning, stacklevel=3, ) if exc_info is not None and not isinstance(exc_info, tuple): - warn(WSGIWarning("invalid value for exc_info"), stacklevel=3) + warn("Invalid value for exc_info.", WSGIWarning, stacklevel=3) headers = Headers(headers) self.check_headers(headers) @@ -346,46 +338,47 @@ def check_headers(self, headers: Headers) -> None: if etag.startswith(("W/", "w/")): if etag.startswith("w/"): warn( - HTTPWarning("weak etag indicator should be upcase."), + "Weak etag indicator should be upper case.", + HTTPWarning, stacklevel=4, ) etag = etag[2:] if not (etag[:1] == etag[-1:] == '"'): - warn(HTTPWarning("unquoted etag emitted."), stacklevel=4) + warn("Unquoted etag emitted.", HTTPWarning, stacklevel=4) location = headers.get("location") if location is not None: if not urlparse(location).netloc: warn( - HTTPWarning("absolute URLs required for location header"), + "Absolute URLs required for location header.", + HTTPWarning, stacklevel=4, ) - def check_iterator(self, app_iter: List[str]) -> None: - if isinstance(app_iter, str): + def check_iterator(self, app_iter: t.Iterable[bytes]) -> None: + if isinstance(app_iter, bytes): warn( - "The application returned a string. The response will send one" + "The application returned a bytestring. The response will send one" " character at a time to the client, which will kill performance." " Return a list or iterable instead.", WSGIWarning, stacklevel=3, ) - def __call__(self, *args, **kwargs) -> GuardedIterator: + def __call__(self, *args, **kwargs) -> t.Iterable[bytes]: if len(args) != 2: warn("A WSGI app takes two arguments.", WSGIWarning, stacklevel=2) if kwargs: warn( - "A WSGI app does not take keyword arguments.", - WSGIWarning, - stacklevel=2, + "A WSGI app does not take keyword arguments.", WSGIWarning, stacklevel=2 ) - environ, start_response = args + environ: "WSGIEnvironment" = args[0] + start_response: "StartResponse" = args[1] self.check_environ(environ) environ["wsgi.input"] = InputStream(environ["wsgi.input"]) @@ -395,11 +388,11 @@ def __call__(self, *args, **kwargs) -> GuardedIterator: # iterate to the end and we can check the content length. environ["wsgi.file_wrapper"] = FileWrapper - headers_set: List[Any] = [] - chunks: List[Any] = [] + headers_set: t.List[t.Any] = [] + chunks: t.List[int] = [] - def checking_start_response(*args, **kwargs): - if len(args) not in (2, 3): + def checking_start_response(*args, **kwargs) -> t.Callable[[bytes], None]: + if len(args) not in {2, 3}: warn( f"Invalid number of arguments: {len(args)}, expected 2 or 3.", WSGIWarning, @@ -407,20 +400,19 @@ def checking_start_response(*args, **kwargs): ) if kwargs: - warn( - "'start_response' does not take keyword arguments.", WSGIWarning, - ) + warn("'start_response' does not take keyword arguments.", WSGIWarning) - status, headers = args[:2] - - if len(args) == 3: - exc_info = args[2] - else: - exc_info = None + status: str = args[0] + headers: t.List[t.Tuple[str, str]] = args[1] + exc_info: t.Optional[ + t.Tuple[t.Type[BaseException], BaseException, TracebackType] + ] = (args[2] if len(args) == 3 else None) headers_set[:] = self.check_start_response(status, headers, exc_info) return GuardedWrite(start_response(status, headers, exc_info), chunks) - app_iter = self.app(environ, checking_start_response) + app_iter = self.app(environ, t.cast("StartResponse", checking_start_response)) self.check_iterator(app_iter) - return GuardedIterator(app_iter, headers_set, chunks) # type: ignore + return GuardedIterator( + app_iter, t.cast(t.Tuple[int, Headers], headers_set), chunks + ) diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index e46e6b112..6166e4563 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -14,26 +14,19 @@ import os.path import sys import time +import typing as t from pstats import Stats -from typing import IO -from typing import Iterable -from typing import List -from typing import Optional -from typing import Text -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union - -if TYPE_CHECKING: - from wsgiref.types import StartResponse - from wsgiref.types import WSGIApplication - from wsgiref.types import WSGIEnvironment try: from cProfile import Profile except ImportError: from profile import Profile # type: ignore +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment + class ProfilerMiddleware: """Wrap a WSGI application and profile the execution of each @@ -84,11 +77,11 @@ class ProfilerMiddleware: def __init__( self, app: "WSGIApplication", - stream: IO[str] = sys.stdout, - sort_by: Tuple[Text, Text] = ("time", "calls"), - restrictions: Iterable[Union[str, float]] = (), - profile_dir: Optional[Text] = None, - filename_format: Text = "{method}.{path}.{elapsed:.0f}ms.{time:.0f}.prof", + stream: t.TextIO = sys.stdout, + sort_by: t.Iterable[str] = ("time", "calls"), + restrictions: t.Iterable[t.Union[str, int, float]] = (), + profile_dir: t.Optional[str] = None, + filename_format: str = "{method}.{path}.{elapsed:.0f}ms.{time:.0f}.prof", ) -> None: self._app = app self._stream = stream @@ -99,15 +92,17 @@ def __init__( def __call__( self, environ: "WSGIEnvironment", start_response: "StartResponse" - ) -> List[bytes]: - response_body: List[bytes] = [] + ) -> t.Iterable[bytes]: + response_body: t.List[bytes] = [] def catching_start_response(status, headers, exc_info=None): start_response(status, headers, exc_info) return response_body.append def runapp(): - app_iter = self._app(environ, catching_start_response) + app_iter = self._app( + environ, t.cast("StartResponse", catching_start_response) + ) response_body.extend(app_iter) if hasattr(app_iter, "close"): @@ -125,9 +120,7 @@ def runapp(): else: filename = self._filename_format.format( method=environ["REQUEST_METHOD"], - path=( - environ.get("PATH_INFO").strip("/").replace("/", ".") or "root" - ), + path=environ["PATH_INFO"].strip("/").replace("/", ".") or "root", elapsed=elapsed * 1000.0, time=time.time(), ) diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index f141b9010..e1f8d3219 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -21,11 +21,14 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ -from typing import Callable -from typing import Optional +import typing as t -from werkzeug.http import parse_list_header -from werkzeug.types import WSGIEnvironment +from ..http import parse_list_header + +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment class ProxyFix: @@ -89,7 +92,7 @@ class ProxyFix: def __init__( self, - app: Callable, + app: "WSGIApplication", x_for: int = 1, x_proto: int = 1, x_host: int = 0, @@ -103,7 +106,7 @@ def __init__( self.x_port = x_port self.x_prefix = x_prefix - def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]: + def _get_real_value(self, trusted: int, value: t.Optional[str]) -> t.Optional[str]: """Get the real value from a list header based on the configured number of trusted proxies. @@ -124,7 +127,9 @@ def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]: return values[-trusted] return None - def __call__(self, environ: WSGIEnvironment, start_response: Callable,) -> Callable: + def __call__( + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: """Modify the WSGI environ based on the various ``Forwarded`` headers before calling the wrapped application. Store the original environ values in ``werkzeug.proxy_fix.orig_{key}``. diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 4c1338079..776498ee9 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -12,16 +12,11 @@ import os import pkgutil import posixpath +import typing as t from datetime import datetime from io import BytesIO from time import mktime from time import time -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union from zlib import adler32 from ..filesystem import get_filesystem_encoding @@ -31,8 +26,14 @@ from ..utils import get_content_type from ..wsgi import get_path_info from ..wsgi import wrap_file -from werkzeug.types import WSGIEnvironment -from werkzeug.wsgi import FileWrapper + +_TOpener = t.Callable[[], t.Tuple[t.BinaryIO, datetime, int]] +_TLoader = t.Callable[[t.Optional[str]], t.Tuple[str, _TOpener]] + +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment class SharedDataMiddleware: @@ -99,10 +100,10 @@ class SharedDataMiddleware: def __init__( self, - app: Optional[Callable], - exports: Union[ - List[Union[Tuple[str, str], Tuple[str, Tuple[str, str]]]], - Dict[str, Union[str, Tuple[str, str]]], + app: "WSGIApplication", + exports: t.Union[ + t.Dict[str, t.Union[str, t.Tuple[str, str]]], + t.Iterable[t.Tuple[str, t.Union[str, t.Tuple[str, str]]]], ], disallow: None = None, cache: bool = True, @@ -110,14 +111,14 @@ def __init__( fallback_mimetype: str = "application/octet-stream", ) -> None: self.app = app - self.exports = [] + self.exports: t.List[t.Tuple[str, _TLoader]] = [] self.cache = cache self.cache_timeout = cache_timeout - if hasattr(exports, "items"): - exports = exports.items() # type: ignore + if isinstance(exports, dict): + exports = exports.items() - for key, value in exports: # type: ignore + for key, value in exports: if isinstance(value, tuple): loader = self.get_package_loader(*value) elif isinstance(value, str): @@ -144,17 +145,17 @@ def is_allowed(self, filename: str) -> bool: """ return True - def _opener(self, filename: str) -> Callable: + def _opener(self, filename: str) -> _TOpener: return lambda: ( open(filename, "rb"), datetime.utcfromtimestamp(os.path.getmtime(filename)), int(os.path.getsize(filename)), ) - def get_file_loader(self, filename: str) -> Callable: + def get_file_loader(self, filename: str) -> _TLoader: return lambda x: (os.path.basename(filename), self._opener(filename)) - def get_package_loader(self, package: str, package_path: str) -> Callable: + def get_package_loader(self, package: str, package_path: str) -> _TLoader: load_time = datetime.utcnow() provider = pkgutil.get_loader(package) @@ -217,7 +218,7 @@ def loader(path): return loader - def get_directory_loader(self, directory: str) -> Callable: + def get_directory_loader(self, directory: str) -> _TLoader: def loader(path): if path is not None: path = safe_join(directory, path) @@ -242,10 +243,9 @@ def generate_etag(self, mtime: datetime, file_size: int, real_filename: str) -> return f"wzsdm-{timestamp}-{file_size}-{checksum}" def __call__( - self, environ: WSGIEnvironment, start_response: Callable, - ) -> Union[FileWrapper, list]: + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: path = get_path_info(environ) - file_loader = None for search_path, loader in self.exports: @@ -258,7 +258,7 @@ def __call__( if not search_path.endswith("/"): search_path += "/" - if path.startswith(search_path): # type: ignore + if path.startswith(search_path): real_filename, file_loader = loader(path[len(search_path) :]) if file_loader is not None: diff --git a/src/werkzeug/py.typed b/src/werkzeug/py.typed index f5642f79f..e69de29bb 100644 --- a/src/werkzeug/py.typed +++ b/src/werkzeug/py.typed @@ -1 +0,0 @@ -Marker diff --git a/src/werkzeug/routing.py b/src/werkzeug/routing.py index aba7c04ab..6efbec8ce 100644 --- a/src/werkzeug/routing.py +++ b/src/werkzeug/routing.py @@ -109,24 +109,14 @@ import difflib import posixpath import re +import typing +import typing as t import uuid import warnings from pprint import pformat from string import Template from threading import Lock -from typing import Any -from typing import Callable -from typing import Dict -from typing import Hashable -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Optional -from typing import Pattern -from typing import Set -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union +from types import CodeType from ._internal import _encode_idna from ._internal import _get_environ @@ -147,11 +137,10 @@ from .utils import cached_property from .utils import redirect from .wsgi import get_host -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from uuid import UUID - from werkzeug.wrappers.response import Response +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment _rule_re = re.compile( r""" @@ -187,7 +176,7 @@ _PYTHON_CONSTANTS = {"None": None, "True": True, "False": False} -def _pythonize(value: str) -> Optional[Union[bool, str, float, int]]: +def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: if value in _PYTHON_CONSTANTS: return _PYTHON_CONSTANTS[value] for convert in int, float: @@ -200,7 +189,7 @@ def _pythonize(value: str) -> Optional[Union[bool, str, float, int]]: return str(value) -def parse_converter_args(argstr: str) -> Any: +def parse_converter_args(argstr: str) -> t.Tuple[t.Tuple, t.Dict[str, t.Any]]: argstr += "," args = [] kwargs = {} @@ -219,9 +208,7 @@ def parse_converter_args(argstr: str) -> Any: return tuple(args), kwargs -def parse_rule( - rule: str, -) -> Iterator[Union[Tuple[Optional[str], Optional[str], Optional[str]]]]: +def parse_rule(rule: str) -> t.Iterator[t.Tuple[t.Optional[str], t.Optional[str], str]]: """Parse a rule and return it as generator. Each iteration yields tuples in the form ``(converter, arguments, variable)``. If the converter is `None` it's a static url part, otherwise it's a dynamic one. @@ -271,10 +258,10 @@ class RequestRedirect(HTTPException, RoutingException): code = 308 def __init__(self, new_url: str) -> None: - RoutingException.__init__(self, new_url) + super().__init__(new_url) self.new_url = new_url - def get_response(self, environ: Optional[WSGIEnvironment] = None,) -> "Response": + def get_response(self, environ=None): return redirect(self.new_url, self.code) @@ -284,13 +271,15 @@ class RequestPath(RoutingException): __slots__ = ("path_info",) def __init__(self, path_info: str) -> None: + super().__init__() self.path_info = path_info class RequestAliasRedirect(RoutingException): # noqa: B903 """This rule is an alias and wants to redirect to the canonical URL.""" - def __init__(self, matched_values: Dict[Any, Any]) -> None: + def __init__(self, matched_values: t.Mapping[str, t.Any]) -> None: + super().__init__() self.matched_values = matched_values @@ -302,22 +291,22 @@ class BuildError(RoutingException, LookupError): def __init__( self, endpoint: str, - values: Optional[Union[Dict[str, int], Dict[str, str]]], - method: Optional[str], - adapter: Optional["MapAdapter"] = None, + values: t.Mapping[str, t.Any], + method: t.Optional[str], + adapter: t.Optional["MapAdapter"] = None, ) -> None: - LookupError.__init__(self, endpoint, values, method) + super().__init__(endpoint, values, method) self.endpoint = endpoint self.values = values self.method = method self.adapter = adapter @cached_property - def suggested(self): + def suggested(self) -> t.Optional["Rule"]: return self.closest_rule(self.adapter) - def closest_rule(self, adapter: "MapAdapter") -> Optional["Rule"]: - def _score_rule(rule): + def closest_rule(self, adapter: t.Optional["MapAdapter"]) -> t.Optional["Rule"]: + def _score_rule(rule: "Rule") -> float: return sum( [ 0.98 @@ -378,7 +367,7 @@ class RuleFactory: be added by subclassing `RuleFactory` and overriding `get_rules`. """ - def get_rules(self, map): + def get_rules(self, map: "Map") -> t.Iterable["Rule"]: """Subclasses of `RuleFactory` have to override this method and return an iterable of rules.""" raise NotImplementedError() @@ -403,11 +392,11 @@ class Subdomain(RuleFactory): for the current request. """ - def __init__(self, subdomain: str, rules: List["Rule"]) -> None: + def __init__(self, subdomain: str, rules: t.Iterable["Rule"]) -> None: self.subdomain = subdomain self.rules = rules - def get_rules(self, map: "Map") -> Iterator["Rule"]: + def get_rules(self, map: "Map") -> t.Iterator["Rule"]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -429,11 +418,11 @@ class Submount(RuleFactory): Now the rule ``'blog/show'`` matches ``/blog/entry/``. """ - def __init__(self, path: str, rules: List["Rule"]) -> None: + def __init__(self, path: str, rules: t.Iterable["Rule"]) -> None: self.path = path.rstrip("/") self.rules = rules - def get_rules(self, map: "Map") -> Iterator["Rule"]: + def get_rules(self, map: "Map") -> t.Iterator["Rule"]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -454,11 +443,11 @@ class EndpointPrefix(RuleFactory): ]) """ - def __init__(self, prefix: str, rules: List["Rule"]) -> None: + def __init__(self, prefix: str, rules: t.Iterable["Rule"]) -> None: self.prefix = prefix self.rules = rules - def get_rules(self, map: "Map") -> Iterator["Rule"]: + def get_rules(self, map: "Map") -> t.Iterator["Rule"]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -485,7 +474,7 @@ class RuleTemplate: replace the placeholders in all the string parameters. """ - def __init__(self, rules: List[Union[Submount, EndpointPrefix, Subdomain]]) -> None: + def __init__(self, rules: t.Iterable["Rule"]) -> None: self.rules = list(rules) def __call__(self, *args, **kwargs) -> "RuleTemplateFactory": @@ -499,15 +488,11 @@ class RuleTemplateFactory(RuleFactory): :internal: """ - def __init__( - self, - rules: List[Union[Submount, EndpointPrefix, Subdomain]], - context: Dict[str, str], - ) -> None: + def __init__(self, rules: t.Iterable["Rule"], context: t.Dict[str, t.Any]) -> None: self.rules = rules self.context = context - def get_rules(self, map: "Map") -> Iterator["Rule"]: + def get_rules(self, map: "Map") -> t.Iterator["Rule"]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): new_defaults = subdomain = None @@ -533,7 +518,7 @@ def get_rules(self, map: "Map") -> Iterator["Rule"]: ) -def _prefix_names(src: str) -> Any: +def _prefix_names(src: str) -> ast.stmt: """ast parse and prefix names with `.` to avoid collision with user vars""" tree = ast.parse(src).body[0] if isinstance(tree, ast.Expr): @@ -680,25 +665,19 @@ def foo_with_slug(adapter, id): ``HEAD`` is added to ``methods`` if ``GET`` is present. """ - _trace: Optional[List[Any]] - _converters: Optional[Dict[Hashable, Any]] - _static_weights: Optional[List[Any]] - _argument_weights: Optional[List[Any]] - _regex: Optional[Pattern] - def __init__( self, string: str, - defaults: Optional[Any] = None, - subdomain: Optional[str] = None, - methods: Optional[Union[Iterable[str], str]] = None, + defaults: t.Optional[t.Mapping[str, t.Any]] = None, + subdomain: t.Optional[str] = None, + methods: t.Optional[t.Iterable[str]] = None, build_only: bool = False, - endpoint: Optional[str] = None, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - redirect_to: None = None, + endpoint: t.Optional[str] = None, + strict_slashes: t.Optional[bool] = None, + merge_slashes: t.Optional[bool] = None, + redirect_to: t.Optional[t.Union[str, t.Callable[..., str]]] = None, alias: bool = False, - host: Optional[str] = None, + host: t.Optional[str] = None, websocket: bool = False, ) -> None: if not string.startswith("/"): @@ -706,7 +685,7 @@ def __init__( self.rule = string self.is_leaf = not string.endswith("/") - self.map: Optional[Map] = None + self.map: "Map" = None # type: ignore self.strict_slashes = strict_slashes self.merge_slashes = merge_slashes self.subdomain = subdomain @@ -731,14 +710,13 @@ def __init__( ) self.methods = methods - self.endpoint = endpoint + self.endpoint: str = endpoint # type: ignore self.redirect_to = redirect_to if defaults: self.arguments = set(map(str, defaults)) else: self.arguments = set() - self._trace = self._converters = self._regex = self._argument_weights = None def empty(self) -> "Rule": """ @@ -748,11 +726,9 @@ def empty(self) -> "Rule": map. See ``get_empty_kwargs`` to override what keyword arguments are provided to the new copy. """ - return type(self)(self.rule, **self.get_empty_kwargs()) # type: ignore + return type(self)(self.rule, **self.get_empty_kwargs()) - def get_empty_kwargs( - self, - ) -> Dict[str, Optional[Union[Dict[str, str], str, Set[str], bool]]]: + def get_empty_kwargs(self) -> t.Mapping[str, t.Any]: """ Provides kwargs for instantiating empty copy with empty() @@ -779,10 +755,10 @@ def get_empty_kwargs( host=self.host, ) - def get_rules(self, map: "Map") -> Iterator["Rule"]: + def get_rules(self, map: "Map") -> t.Iterator["Rule"]: yield self - def refresh(self): + def refresh(self) -> None: """Rebinds and refreshes the URL. Call this if you modified the rule in place. @@ -811,8 +787,8 @@ def get_converter( self, variable_name: str, converter_name: str, - args: Tuple, - kwargs: Dict[str, bool], + args: t.Tuple, + kwargs: t.Mapping[str, t.Any], ) -> "BaseConverter": """Looks up the converter for the given parameter. @@ -822,9 +798,7 @@ def get_converter( raise LookupError(f"the converter {converter_name!r} does not exist") return self.map.converters[converter_name](self.map, *args, **kwargs) - def _encode_query_vars( - self, query_vars: Dict[str, Union[float, List[float], int]] - ) -> str: + def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]): return url_encode( query_vars, charset=self.map.charset, @@ -841,13 +815,13 @@ def compile(self) -> None: else: domain_rule = self.subdomain or "" - self._trace = [] - self._converters = {} - self._static_weights = [] - self._argument_weights = [] - regex_parts: List[Any] = [] + self._trace: t.List[t.Tuple[bool, str]] = [] + self._converters: t.Dict[str, "BaseConverter"] = {} + self._static_weights: t.List[t.Tuple[int, int]] = [] + self._argument_weights: t.List[int] = [] + regex_parts = [] - def _build_regex(rule): + def _build_regex(rule: str) -> None: index = 0 for converter, arguments, variable in parse_rule(rule): if converter is None: @@ -903,7 +877,9 @@ def _build_regex(rule): regex = f"^{''.join(regex_parts)}{tail}$" self._regex = re.compile(regex) - def match(self, path: str, method: Optional[str] = None) -> Optional[dict]: + def match( + self, path: str, method: t.Optional[str] = None + ) -> t.Optional[t.Mapping[str, t.Any]]: """Check if the rule matches a given path. Path is a string in the form ``"subdomain|/path"`` and is assembled by the map. If the map is doing host matching the subdomain part will be the host @@ -950,7 +926,7 @@ def match(self, path: str, method: Optional[str] = None) -> Optional[dict]: result.update(self.defaults) if self.merge_slashes: - new_path = "|".join(self.build(result, False)) + new_path = "|".join(self.build(result, False)) # type: ignore if path.endswith("/") and not new_path.endswith("/"): new_path += "/" if new_path.count("/") < path.count("/"): @@ -969,16 +945,18 @@ def match(self, path: str, method: Optional[str] = None) -> Optional[dict]: return None @staticmethod - def _get_func_code(code: Any, name: Any) -> Any: - globs: Dict[Any, Any] = {} - locs: Dict[Any, Any] = {} + def _get_func_code(code: CodeType, name: str) -> t.Callable[..., t.Tuple[str, str]]: + globs: t.Dict[str, t.Any] = {} + locs: t.Dict[str, t.Any] = {} exec(code, globs, locs) return locs[name] - def _compile_builder(self, append_unknown: bool = True) -> Callable: + def _compile_builder( + self, append_unknown: bool = True + ) -> t.Callable[..., t.Tuple[str, str]]: defaults = self.defaults or {} - dom_ops: List[Any] = [] - url_ops: List[Any] = [] + dom_ops: t.List[t.Tuple[bool, str]] = [] + url_ops: t.List[t.Tuple[bool, str]] = [] opl = dom_ops for is_dynamic, data in self._trace: @@ -993,7 +971,7 @@ def _compile_builder(self, append_unknown: bool = True) -> Callable: opl.append((False, data)) elif not is_dynamic: opl.append( - (False, url_quote(_to_bytes(data, self.map.charset), safe="/:|+"),) + (False, url_quote(_to_bytes(data, self.map.charset), safe="/:|+")) ) else: opl.append((True, data)) @@ -1042,7 +1020,7 @@ def _join(parts): ] kargs = [str(k) for k in defaults] - func_ast = _prefix_names("def _(): pass") + func_ast: ast.FunctionDef = _prefix_names("def _(): pass") # type: ignore func_ast.name = f"" func_ast.args.args.append(ast.arg(".self", None)) for arg in pargs + kargs: @@ -1070,8 +1048,8 @@ def _join(parts): return self._get_func_code(code, func_ast.name) def build( - self, values: Dict[str, Any], append_unknown: bool = True - ) -> Tuple[str, str]: + self, values: t.Mapping[str, t.Any], append_unknown: bool = True + ) -> t.Optional[t.Tuple[str, str]]: """Assembles the relative url for that rule and the subdomain. If building doesn't work for some reasons `None` is returned. @@ -1085,12 +1063,12 @@ def build( except ValidationError: return None - def provides_defaults_for(self, rule: "Rule") -> Optional[bool]: + def provides_defaults_for(self, rule: "Rule") -> bool: """Check if this rule has defaults for a given rule. :internal: """ - return ( + return bool( not self.build_only and self.defaults and self.endpoint == rule.endpoint @@ -1099,7 +1077,7 @@ def provides_defaults_for(self, rule: "Rule") -> Optional[bool]: ) def suitable_for( - self, values: Dict[str, Any], method: Optional[str] = None + self, values: t.Mapping[str, t.Any], method: t.Optional[str] = None ) -> bool: """Check if the dict of values has enough data for url generation. @@ -1133,12 +1111,7 @@ def suitable_for( def match_compare_key( self, - ) -> Union[ - Tuple[bool, int, List[Any], int, List[Any]], - Tuple[bool, int, List[Tuple[int, int]], int, List[int]], - Tuple[bool, int, List[Tuple[int, int]], int, List[Any]], - Tuple[bool, int, List[Any], int, List[int]], - ]: + ) -> t.Tuple[bool, int, t.Iterable[t.Tuple[int, int]], int, t.Iterable[int]]: """The match compare key for sorting. Current implementation: @@ -1164,26 +1137,19 @@ def match_compare_key( self._argument_weights, ) - def build_compare_key(self) -> Tuple[int, int, int]: + def build_compare_key(self) -> t.Tuple[int, int, int]: """The build compare key for sorting. :internal: """ - return ( - 1 if self.alias else 0, - -len(self.arguments), - -len(self.defaults or ()), - ) + return (1 if self.alias else 0, -len(self.arguments), -len(self.defaults or ())) def __eq__(self, other: object) -> bool: - return ( - self.__class__ is other.__class__ - and self._trace == other._trace # type: ignore - ) + return isinstance(other, type(self)) and self._trace == other._trace - __hash__ = None + __hash__ = None # type: ignore - def __str__(self): + def __str__(self) -> str: return self.rule def __repr__(self) -> str: @@ -1206,13 +1172,13 @@ class BaseConverter: regex = "[^/]+" weight = 100 - def __init__(self, map: "Map") -> None: + def __init__(self, map: "Map", *args, **kwargs) -> None: self.map = map - def to_python(self, value: str) -> str: + def to_python(self, value: str) -> t.Any: return value - def to_url(self, value: Union[str, int, bytes]) -> str: + def to_url(self, value: t.Any) -> str: if isinstance(value, (bytes, bytearray)): return _fast_url_quote(value) return _fast_url_quote(str(value).encode(self.map.charset)) @@ -1240,19 +1206,19 @@ def __init__( self, map: "Map", minlength: int = 1, - maxlength: Optional[int] = None, - length: Optional[int] = None, + maxlength: t.Optional[int] = None, + length: t.Optional[int] = None, ) -> None: - BaseConverter.__init__(self, map) + super().__init__(map) if length is not None: - length = f"{{{int(length)}}}" # type: ignore + length_regex = f"{{{int(length)}}}" else: if maxlength is None: - maxlength = "" # type: ignore + maxlength_value = "" else: - maxlength = int(maxlength) - length = f"{{{int(minlength)},{maxlength}}}" # type: ignore - self.regex = f"[^/]{length}" + maxlength_value = str(int(maxlength)) + length_regex = f"{{{int(minlength)},{maxlength_value}}}" + self.regex = f"[^/]{length_regex}" class AnyConverter(BaseConverter): @@ -1266,8 +1232,8 @@ class AnyConverter(BaseConverter): arguments. """ - def __init__(self, map: "Map", *items) -> None: - BaseConverter.__init__(self, map) + def __init__(self, map: "Map", *items: str) -> None: + super().__init__(map) self.regex = f"(?:{'|'.join([re.escape(x) for x in items])})" @@ -1292,38 +1258,39 @@ class NumberConverter(BaseConverter): """ weight = 50 + num_convert: t.Callable = int def __init__( self, map: "Map", fixed_digits: int = 0, - min: None = None, - max: None = None, + min: t.Optional[int] = None, + max: t.Optional[int] = None, signed: bool = False, ) -> None: if signed: self.regex = self.signed_regex - BaseConverter.__init__(self, map) + super().__init__(map) self.fixed_digits = fixed_digits self.min = min self.max = max self.signed = signed - def to_python(self, value: Any) -> Union[int, float]: # type: ignore + def to_python(self, value: str) -> t.Any: if self.fixed_digits and len(value) != self.fixed_digits: raise ValidationError() - value = self.num_convert(value) # type: ignore + value = self.num_convert(value) if (self.min is not None and value < self.min) or ( self.max is not None and value > self.max ): raise ValidationError() return value - def to_url(self, value: Any) -> str: - value = self.num_convert(value) # type: ignore + def to_url(self, value: t.Any) -> str: + value = str(self.num_convert(value)) if self.fixed_digits: - value = str(value).zfill(self.fixed_digits) - return str(value) + value = value.zfill(self.fixed_digits) + return value @property def signed_regex(self) -> str: @@ -1353,7 +1320,6 @@ class IntegerConverter(NumberConverter): """ regex = r"\d+" - num_convert = int class FloatConverter(NumberConverter): @@ -1378,10 +1344,8 @@ class FloatConverter(NumberConverter): regex = r"\d+\.\d+" num_convert = float - def __init__( - self, map: "Map", min: None = None, max: None = None, signed: bool = False, - ) -> None: - NumberConverter.__init__(self, map, min=min, max=max, signed=signed) + def __init__(self, map, min=None, max=None, signed=False): + super().__init__(map, min=min, max=max, signed=signed) class UUIDConverter(BaseConverter): @@ -1399,15 +1363,15 @@ class UUIDConverter(BaseConverter): r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" ) - def to_python(self, value: str) -> "UUID": # type: ignore + def to_python(self, value: str) -> uuid.UUID: return uuid.UUID(value) - def to_url(self, value: "UUID") -> str: # type: ignore + def to_url(self, value): return str(value) #: the default converter mapping for the map. -DEFAULT_CONVERTERS = { +DEFAULT_CONVERTERS: t.Mapping[str, t.Type[BaseConverter]] = { "default": UnicodeConverter, "string": UnicodeConverter, "any": AnyConverter, @@ -1473,20 +1437,20 @@ class Map: def __init__( self, - rules: Optional[Union[List[RuleTemplateFactory], List[Rule]]] = None, + rules: t.Optional[t.Iterable[RuleFactory]] = None, default_subdomain: str = "", charset: str = "utf-8", strict_slashes: bool = True, merge_slashes: bool = True, redirect_defaults: bool = True, - converters: None = None, + converters: t.Optional[t.Mapping[str, t.Type[BaseConverter]]] = None, sort_parameters: bool = False, - sort_key: Optional[Callable] = None, + sort_key: t.Optional[t.Callable[[t.Any], t.Any]] = None, encoding_errors: str = "replace", host_matching: bool = False, ) -> None: - self._rules: List[Any] = [] - self._rules_by_endpoint: Dict[Hashable, Any] = {} + self._rules: t.List[Rule] = [] + self._rules_by_endpoint: t.Dict[str, t.List[Rule]] = {} self._remap = True self._remap_lock = self.lock_class() @@ -1508,7 +1472,7 @@ def __init__( for rulefactory in rules or (): self.add(rulefactory) - def is_endpoint_expecting(self, endpoint, *arguments): + def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: """Iterate over all rules and check if the endpoint expects the arguments provided. This is for example useful if you have some URLs that expect a language code and others that do not and @@ -1528,7 +1492,7 @@ def is_endpoint_expecting(self, endpoint, *arguments): return True return False - def iter_rules(self, endpoint=None): + def iter_rules(self, endpoint: t.Optional[str] = None) -> t.Iterator[Rule]: """Iterate over all rules or the rules of an endpoint. :param endpoint: if provided only the rules for that endpoint @@ -1540,7 +1504,7 @@ def iter_rules(self, endpoint=None): return iter(self._rules_by_endpoint[endpoint]) return iter(self._rules) - def add(self, rulefactory: Union[Rule, RuleTemplateFactory]) -> None: + def add(self, rulefactory: RuleFactory) -> None: """Add a new rule or factory to the map and bind it. Requires that the rule is not bound to another map. @@ -1555,12 +1519,12 @@ def add(self, rulefactory: Union[Rule, RuleTemplateFactory]) -> None: def bind( self, server_name: str, - script_name: Optional[str] = None, - subdomain: Optional[str] = None, + script_name: t.Optional[str] = None, + subdomain: t.Optional[str] = None, url_scheme: str = "http", default_method: str = "GET", - path_info: Optional[str] = None, - query_args: Optional[str] = None, + path_info: t.Optional[str] = None, + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, ) -> "MapAdapter": """Return a new :class:`MapAdapter` with the details specified to the call. Note that `script_name` will default to ``'/'`` if not further @@ -1618,9 +1582,9 @@ def bind( def bind_to_environ( self, - environ: WSGIEnvironment, - server_name: Optional[str] = None, - subdomain: Optional[str] = None, + environ: "WSGIEnvironment", + server_name: t.Optional[str] = None, + subdomain: t.Optional[str] = None, ) -> "MapAdapter": """Like :meth:`bind` but you can pass it an WSGI environment and it will fetch the information from that dictionary. Note that because of @@ -1698,10 +1662,11 @@ def bind_to_environ( else: subdomain = ".".join(filter(None, cur_server_name[:offset])) - def _get_wsgi_string(name): + def _get_wsgi_string(name: str) -> t.Optional[str]: val = environ.get(name) if val is not None: return _wsgi_decoding_dance(val, self.charset) + return None script_name = _get_wsgi_string("SCRIPT_NAME") path_info = _get_wsgi_string("PATH_INFO") @@ -1747,14 +1712,14 @@ class MapAdapter: def __init__( self, map: Map, - server_name: Union[str, bytes], + server_name: str, script_name: str, - subdomain: Optional[str], + subdomain: t.Optional[str], url_scheme: str, path_info: str, default_method: str, - query_args: Optional[str] = None, - ) -> None: + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + ): self.map = map self.server_name = _to_str(server_name) script_name = _to_str(script_name) @@ -1770,11 +1735,11 @@ def __init__( def dispatch( self, - view_func: Callable, - path_info: Optional[str] = None, - method: Optional[str] = None, + view_func: t.Callable[[str, t.Mapping[str, t.Any]], "WSGIApplication"], + path_info: t.Optional[str] = None, + method: t.Optional[str] = None, catch_http_exceptions: bool = False, - ) -> Any: + ) -> "WSGIApplication": """Does the complete dispatching process. `view_func` is called with the endpoint and a dict with the values for the view. It should look up the view function, call it, and return a response object @@ -1828,14 +1793,36 @@ def application(environ, start_response): return e raise + @typing.overload + def match( # type: ignore + self, + path_info: t.Optional[str] = None, + method: t.Optional[str] = None, + return_rule: "t.Literal[False]" = False, + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + websocket: t.Optional[bool] = None, + ) -> t.Tuple[str, t.Mapping[str, t.Any]]: + ... + + @typing.overload + def match( + self, + path_info: t.Optional[str] = None, + method: t.Optional[str] = None, + return_rule: "t.Literal[True]" = True, + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + websocket: t.Optional[bool] = None, + ) -> t.Tuple[Rule, t.Mapping[str, t.Any]]: + ... + def match( self, - path_info: Optional[str] = None, - method: Optional[str] = None, - return_rule: bool = False, - query_args: Optional[Union[str, Dict[str, str]]] = None, - websocket: Optional[bool] = None, - ) -> Any: + path_info=None, + method=None, + return_rule=False, + query_args=None, + websocket=None, + ): """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). The following things can then happen: @@ -1957,11 +1944,7 @@ def match( except RequestAliasRedirect as e: raise RequestRedirect( self.make_alias_redirect_url( - path, - rule.endpoint, - e.matched_values, - method, - query_args, # type: ignore + path, rule.endpoint, e.matched_values, method, query_args ) ) if rv is None: @@ -1975,9 +1958,7 @@ def match( continue if self.map.redirect_defaults: - redirect_url = self.get_default_redirect( - rule, method, rv, query_args # type: ignore - ) + redirect_url = self.get_default_redirect(rule, method, rv, query_args) if redirect_url is not None: raise RequestRedirect(redirect_url) @@ -2007,7 +1988,7 @@ def _handle_match(match): if require_redirect: raise RequestRedirect( self.make_redirect_url( - url_quote(path_info, self.map.charset, safe="/:|+"), query_args, + url_quote(path_info, self.map.charset, safe="/:|+"), query_args ) ) @@ -2024,7 +2005,9 @@ def _handle_match(match): raise NotFound() - def test(self, path_info=None, method=None): + def test( + self, path_info: t.Optional[str] = None, method: t.Optional[str] = None + ) -> bool: """Test if a rule would match. Works like `match` but returns `True` if the URL matches, or `False` if it does not exist. @@ -2041,7 +2024,7 @@ def test(self, path_info=None, method=None): return False return True - def allowed_methods(self, path_info=None): + def allowed_methods(self, path_info: t.Optional[str] = None) -> t.Iterable[str]: """Returns the valid methods that match for a given path. .. versionadded:: 0.7 @@ -2049,12 +2032,12 @@ def allowed_methods(self, path_info=None): try: self.match(path_info, method="--") except MethodNotAllowed as e: - return e.valid_methods + return e.valid_methods # type: ignore except HTTPException: pass return [] - def get_host(self, domain_part: Optional[str]) -> str: + def get_host(self, domain_part: t.Optional[str]) -> str: """Figures out the full host name for the given domain part. The domain part is a subdomain in case host matching is disabled or a full host name. @@ -2078,9 +2061,9 @@ def get_default_redirect( self, rule: Rule, method: str, - values: Dict[str, Any], - query_args: Optional[str], - ) -> Optional[str]: + values: t.MutableMapping[str, t.Any], + query_args: t.Union[t.Mapping[str, t.Any], str], + ) -> t.Optional[str]: """A helper that returns the URL to redirect to if it finds one. This is used for default redirecting only. @@ -2094,21 +2077,21 @@ def get_default_redirect( if r is rule: break if r.provides_defaults_for(rule) and r.suitable_for(values, method): - values.update(r.defaults) - domain_part, path = r.build(values) + values.update(r.defaults) # type: ignore + domain_part, path = r.build(values) # type: ignore return self.make_redirect_url(path, query_args, domain_part=domain_part) return None - def encode_query_args(self, query_args: Union[str, Dict[str, str]]) -> str: + def encode_query_args(self, query_args: t.Union[t.Mapping[str, t.Any], str]) -> str: if not isinstance(query_args, str): - query_args = url_encode(query_args, self.map.charset) + return url_encode(query_args, self.map.charset) return query_args def make_redirect_url( self, path_info: str, - query_args: Optional[Union[str, Dict[str, str]]] = None, - domain_part: Optional[str] = None, + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + domain_part: t.Optional[str] = None, ) -> str: """Creates a redirect URL. @@ -2128,9 +2111,9 @@ def make_alias_redirect_url( self, path: str, endpoint: str, - values: Dict[str, int], + values: t.Mapping[str, t.Any], method: str, - query_args: Optional[Dict[str, str]], + query_args: t.Union[t.Mapping[str, t.Any], str], ) -> str: """Internally called to make an alias redirect URL.""" url = self.build( @@ -2144,10 +2127,10 @@ def make_alias_redirect_url( def _partial_build( self, endpoint: str, - values: Dict[str, Any], - method: Optional[str], + values: t.Mapping[str, t.Any], + method: t.Optional[str], append_unknown: bool, - ) -> Optional[Tuple[str, str, bool]]: + ) -> t.Optional[t.Tuple[str, str, bool]]: """Helper for :meth:`build`. Returns subdomain and path for the rule that accepts this endpoint, values and method. @@ -2168,10 +2151,10 @@ def _partial_build( for rule in self.map._rules_by_endpoint.get(endpoint, ()): if rule.suitable_for(values, method): - rv = rule.build(values, append_unknown) + build_rv = rule.build(values, append_unknown) - if rv is not None: - rv = (rv[0], rv[1], rule.websocket) + if build_rv is not None: + rv = (build_rv[0], build_rv[1], rule.websocket) if self.map.host_matching: if rv[0] == self.server_name: return rv @@ -2185,11 +2168,11 @@ def _partial_build( def build( self, endpoint: str, - values: Optional[Any] = None, - method: Optional[str] = None, + values: t.Optional[t.Mapping[str, t.Any]] = None, + method: t.Optional[str] = None, force_external: bool = False, append_unknown: bool = True, - url_scheme: Optional[str] = None, + url_scheme: t.Optional[str] = None, ) -> str: """Building URLs works pretty much the other way round. Instead of `match` you call `build` and pass it the endpoint and a dict of diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 8c95888ae..96579bd92 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -4,29 +4,27 @@ import os import posixpath import secrets -from hmac import HMAC -from typing import AnyStr -from typing import Callable -from typing import Optional -from typing import Tuple -from typing import Union +import typing as t from ._internal import _to_bytes +if t.TYPE_CHECKING: + pass + SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" DEFAULT_PBKDF2_ITERATIONS = 260000 -_os_alt_seps = list( - sep for sep in [os.path.sep, os.path.altsep] if sep not in (None, "/") +_os_alt_seps: t.List[str] = list( + sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/" ) def pbkdf2_hex( - data: AnyStr, - salt: str, + data: t.AnyStr, + salt: t.AnyStr, iterations: int = DEFAULT_PBKDF2_ITERATIONS, - keylen: Optional[int] = None, - hashfunc: Optional[str] = None, + keylen: t.Optional[int] = None, + hashfunc: t.Optional[str] = None, ) -> str: """Like :func:`pbkdf2_bin`, but returns a hex-encoded string. @@ -46,11 +44,11 @@ def pbkdf2_hex( def pbkdf2_bin( - data: AnyStr, - salt: str, + data: t.AnyStr, + salt: t.AnyStr, iterations: int = DEFAULT_PBKDF2_ITERATIONS, - keylen: Optional[int] = None, - hashfunc: Optional[str] = None, + keylen: t.Optional[int] = None, + hashfunc: t.Optional[str] = None, ) -> bytes: """Returns a binary digest for the PBKDF2 hash algorithm of `data` with the given `salt`. It iterates `iterations` times and produces a @@ -105,7 +103,7 @@ def gen_salt(length: int) -> str: return "".join(secrets.choice(SALT_CHARS) for _ in range(length)) -def _hash_internal(method: str, salt: str, password: str) -> Tuple[str, str]: +def _hash_internal(method: str, salt: str, password: str) -> t.Tuple[str, str]: """Internal password hash helper. Supports plaintext without salt, unsalted and salted passwords. In case salted passwords are used hmac is used. @@ -142,18 +140,14 @@ def _hash_internal(method: str, salt: str, password: str) -> Tuple[str, str]: return rv, actual_method -def _create_mac( - key: Union[bytes, bytearray], - msg: Union[bytes, bytearray], - method: Union[Callable, str], -) -> HMAC: +def _create_mac(key: bytes, msg: bytes, method: str) -> hmac.HMAC: if callable(method): - return hmac.HMAC(key, msg, method) # type: ignore + return hmac.new(key, msg, method) def hashfunc(d=b""): return hashlib.new(method, d) - return hmac.HMAC(key, msg, hashfunc) # type: ignore + return hmac.new(key, msg, hashfunc) def generate_password_hash( @@ -205,7 +199,7 @@ def check_password_hash(pwhash: str, password: str) -> bool: return safe_str_cmp(_hash_internal(method, salt, password)[0], hashval) -def safe_join(directory: str, *pathnames: str) -> Optional[str]: +def safe_join(directory: str, *pathnames: str) -> t.Optional[str]: """Safely join zero or more untrusted path components to a base directory to avoid escaping the base directory. diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index 451e663b5..3e9f90c11 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -18,24 +18,13 @@ import socket import socketserver import sys +import typing +import typing as t import warnings from datetime import datetime as dt from datetime import timedelta from http.server import BaseHTTPRequestHandler from http.server import HTTPServer -from typing import Any -from typing import Callable -from typing import List -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union - -if TYPE_CHECKING: - # these are used only in type checking and are quoted so flake8 can't pick them up - from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey # noqa: F401 - from cryptography.hazmat.backends.openssl.x509 import _Certificate # noqa: F401 -from werkzeug.types import WSGIEnvironment from ._internal import _log from ._internal import _wsgi_encoding_dance @@ -57,7 +46,7 @@ def __getattr__(self, name): try: import click except ImportError: - click = None + click = None # type: ignore can_fork = hasattr(os, "fork") @@ -72,16 +61,28 @@ class ForkingMixIn: # type: ignore try: af_unix = socket.AF_UNIX except AttributeError: - af_unix = None + af_unix = None # type: ignore LISTEN_QUEUE = 128 can_open_by_fd = not platform.system() == "Windows" and hasattr(socket, "fromfd") +_TSSLContextArg = t.Optional[ + t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], "t.Literal['adhoc']"] +] + +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKeyWithSerialization, + ) + from cryptography.x509 import Certificate + class DechunkedInput(io.RawIOBase): """An input stream that handles Transfer-Encoding 'chunked'""" - def __init__(self, rfile): + def __init__(self, rfile: t.BinaryIO) -> None: self._rfile = rfile self._done = False self._len = 0 @@ -99,7 +100,7 @@ def read_chunk_len(self) -> int: raise OSError("Negative chunk length not allowed") return _len - def readinto(self, buf) -> int: + def readinto(self, buf: bytearray) -> int: # type: ignore read = 0 while not self._done and read < len(buf): if self._len == 0: @@ -132,16 +133,17 @@ def readinto(self, buf) -> int: class WSGIRequestHandler(BaseHTTPRequestHandler): - """A request handler that implements WSGI dispatching.""" + server: "BaseWSGIServer" + @property def server_version(self) -> str: # type: ignore from . import __version__ return f"Werkzeug/{__version__}" - def make_environ(self) -> WSGIEnvironment: + def make_environ(self) -> "WSGIEnvironment": request_url = url_parse(self.path) def shutdown_server(): @@ -152,15 +154,12 @@ def shutdown_server(): ) self.server.shutdown_signal = True - url_scheme = ( - "http" if self.server.ssl_context is None else "https" # type: ignore - ) + url_scheme = "http" if self.server.ssl_context is None else "https" + if not self.client_address: - self.client_address = "" # type: ignore - if isinstance(self.client_address, str): + self.client_address = ("", 0) + elif isinstance(self.client_address, str): self.client_address = (self.client_address, 0) - else: - pass # If there was no scheme but the path started with two slashes, # the first segment may have been incorrectly parsed as the @@ -170,15 +169,15 @@ def shutdown_server(): else: path_info = request_url.path - path_info = url_unquote(path_info) # type: ignore + path_info = url_unquote(path_info) - environ = { + environ: "WSGIEnvironment" = { "wsgi.version": (1, 0), "wsgi.url_scheme": url_scheme, "wsgi.input": self.rfile, "wsgi.errors": sys.stderr, - "wsgi.multithread": self.server.multithread, # type: ignore - "wsgi.multiprocess": self.server.multiprocess, # type: ignore + "wsgi.multithread": self.server.multithread, + "wsgi.multiprocess": self.server.multiprocess, "wsgi.run_once": False, "werkzeug.server.shutdown": shutdown_server, "SERVER_SOFTWARE": self.server_version, @@ -224,9 +223,7 @@ def shutdown_server(): environ["SSL_CLIENT_CERT"] = ssl.DER_cert_to_PEM_cert(peer_cert) except ValueError: # SSL handshake hasn't finished. - self.server.log( # type: ignore - "error", "Cannot fetch SSL peer certificate info" - ) + self.server.log("error", "Cannot fetch SSL peer certificate info") except AttributeError: # Not using TLS, the socket will not have getpeercert(). pass @@ -238,21 +235,26 @@ def run_wsgi(self) -> None: self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.environ = environ = self.make_environ() - headers_set: List[Any] = [] - headers_sent: List[Any] = [] - - def write(data): - assert headers_set, "write() before start_response" - if not headers_sent: - status, response_headers = headers_sent[:] = headers_set + status_set: t.Optional[str] = None + headers_set: t.Optional[t.List[t.Tuple[str, str]]] = None + status_sent: t.Optional[str] = None + headers_sent: t.Optional[t.List[t.Tuple[str, str]]] = None + + def write(data: bytes) -> None: + nonlocal status_sent, headers_sent + assert status_set is not None, "write() before start_response" + assert headers_set is not None, "write() before start_response" + if status_sent is None: + status_sent = status_set + headers_sent = headers_set try: - code, msg = status.split(None, 1) + code_str, msg = status_sent.split(None, 1) except ValueError: - code, msg = status, "" - code = int(code) + code_str, msg = status_sent, "" + code = int(code_str) self.send_response(code, msg) header_keys = set() - for key, value in response_headers: + for key, value in headers_sent: self.send_header(key, value) key = key.lower() header_keys.add(key) @@ -274,7 +276,8 @@ def write(data): self.wfile.write(data) self.wfile.flush() - def start_response(status, response_headers, exc_info=None): + def start_response(status, headers, exc_info=None): + nonlocal status_set, headers_set if exc_info: try: if headers_sent: @@ -283,10 +286,11 @@ def start_response(status, response_headers, exc_info=None): exc_info = None elif headers_set: raise AssertionError("Headers already set") - headers_set[:] = [status, response_headers] + status_set = status + headers_set = headers return write - def execute(app): + def execute(app: "WSGIApplication") -> None: application_iter = app(environ, start_response) try: for data in application_iter: @@ -295,14 +299,14 @@ def execute(app): write(b"") finally: if hasattr(application_iter, "close"): - application_iter.close() + application_iter.close() # type: ignore try: - execute(self.server.app) # type: ignore + execute(self.server.app) except (ConnectionError, socket.timeout) as e: self.connection_dropped(e, environ) except Exception: - if self.server.passthrough_errors: # type: ignore + if self.server.passthrough_errors: raise from .debug.tbtools import get_current_traceback @@ -310,14 +314,13 @@ def execute(app): try: # if we haven't yet sent the headers but they are set # we roll back to be able to set them again. - if not headers_sent: - del headers_set[:] + if status_sent is None: + status_set = None + headers_set = None execute(InternalServerError()) except Exception: pass - self.server.log( # type: ignore - "error", "Error on request:\n%s", traceback.plaintext - ) + self.server.log("error", "Error on request:\n%s", traceback.plaintext) def handle(self) -> None: """Handles a request ignoring dropped connections.""" @@ -326,9 +329,9 @@ def handle(self) -> None: except (ConnectionError, socket.timeout) as e: self.connection_dropped(e) except Exception as e: - if self.server.ssl_context is None or not is_ssl_error(e): # type: ignore + if self.server.ssl_context is None or not is_ssl_error(e): raise - if self.server.shutdown_signal: # type: ignore + if self.server.shutdown_signal: self.initiate_shutdown() def initiate_shutdown(self) -> None: @@ -340,8 +343,8 @@ def initiate_shutdown(self) -> None: self.server._BaseServer__shutdown_request = True # type: ignore def connection_dropped( - self, error, environ: Optional[WSGIEnvironment] = None - ) -> Any: + self, error: BaseException, environ: t.Optional["WSGIEnvironment"] = None + ) -> None: """Called if the connection was closed by the client. By default nothing happens. """ @@ -350,11 +353,11 @@ def handle_one_request(self) -> None: """Handle a single HTTP request.""" self.raw_requestline = self.rfile.readline() if not self.raw_requestline: - self.close_connection = 1 # type: ignore + self.close_connection = True elif self.parse_request(): - return self.run_wsgi() + self.run_wsgi() - def send_response(self, code, message=None) -> None: + def send_response(self, code: int, message: t.Optional[str] = None) -> None: """Send the response header and log the response code.""" self.log_request(code) if message is None: @@ -364,22 +367,21 @@ def send_response(self, code, message=None) -> None: self.wfile.write(hdr.encode("ascii")) def version_string(self) -> str: - return BaseHTTPRequestHandler.version_string(self).strip() + return super().version_string().strip() def address_string(self) -> str: if getattr(self, "environ", None): return self.environ["REMOTE_ADDR"] - elif not self.client_address: + + if not self.client_address: return "" - elif isinstance(self.client_address, str): - return self.client_address - else: - return self.client_address[0] + + return self.client_address[0] def port_integer(self) -> int: return self.client_address[1] - def log_request(self, code="-", size="-") -> None: + def log_request(self, code: t.Union[int, str] = "-", size: t.Union[int, str] = "-"): try: path = uri_to_iri(self.path) msg = f"{self.command} {path} {self.request_version}" @@ -412,10 +414,10 @@ def log_request(self, code="-", size="-") -> None: def log_error(self, *args) -> None: self.log("error", *args) - def log_message(self, format, *args) -> None: + def log_message(self, format: str, *args) -> None: self.log("info", format, *args) - def log(self, type, message, *args) -> None: + def log(self, type: str, message: str, *args) -> None: _log( type, f"{self.address_string()} - - [{self.log_date_time_string()}] {message}\n", @@ -423,13 +425,9 @@ def log(self, type, message, *args) -> None: ) -#: backwards compatible name if someone is subclassing it -BaseRequestHandler = WSGIRequestHandler - - def generate_adhoc_ssl_pair( - cn: Optional[str] = None, -) -> Tuple["_Certificate", "_RSAPrivateKey"]: + cn: t.Optional[str] = None, +) -> t.Tuple["Certificate", "RSAPrivateKeyWithSerialization"]: try: from cryptography import x509 from cryptography.x509.oid import NameOID @@ -469,10 +467,8 @@ def generate_adhoc_ssl_pair( def make_ssl_devcert( - base_path: Union[str, "os.PathLike"], - host: Optional[str] = None, - cn: Optional[str] = None, -) -> Tuple[str, str]: + base_path: str, host: t.Optional[str] = None, cn: t.Optional[str] = None +) -> t.Tuple[str, str]: """Creates an SSL key for development. This should be used instead of the ``'adhoc'`` key which generates a new cert on each server start. It accepts a path for where it should store the key and cert and @@ -514,7 +510,7 @@ def make_ssl_devcert( return cert_file, pkey_file -def generate_adhoc_ssl_context() -> ssl.SSLContext: +def generate_adhoc_ssl_context() -> "ssl.SSLContext": """Generates an adhoc SSL context for the development server.""" import tempfile import atexit @@ -544,7 +540,9 @@ def generate_adhoc_ssl_context() -> ssl.SSLContext: return ctx -def load_ssl_context(cert_file, pkey_file=None, protocol=None) -> ssl.SSLContext: +def load_ssl_context( + cert_file: str, pkey_file: t.Optional[str] = None, protocol: t.Optional[int] = None +) -> "ssl.SSLContext": """Loads SSL context from cert/private key files and optional protocol. Many parameters are directly taken from the API of :py:class:`ssl.SSLContext`. @@ -570,7 +568,7 @@ def is_ssl_error(error=None): return isinstance(error, ssl.SSLError) -def select_address_family(host: str, port: Union[str, int]): +def select_address_family(host: str, port: int) -> socket.AddressFamily: """Return ``AF_INET4``, ``AF_INET6``, or ``AF_UNIX`` depending on the host and port.""" if host.startswith("unix://"): @@ -580,7 +578,9 @@ def select_address_family(host: str, port: Union[str, int]): return socket.AF_INET -def get_sockaddr(host, port, family): +def get_sockaddr( + host: str, port: int, family: socket.AddressFamily +) -> t.Union[t.Tuple[str, int], str]: """Return a fully qualified socket address that can be passed to :func:`socket.bind`.""" if family == af_unix: @@ -591,27 +591,27 @@ def get_sockaddr(host, port, family): ) except socket.gaierror: return host, port - return res[0][4] + return res[0][4] # type: ignore class BaseWSGIServer(HTTPServer): """Simple single-threaded, single-process WSGI server.""" - multithread: bool = False - multiprocess: bool = False - request_queue_size: int = LISTEN_QUEUE + multithread = False + multiprocess = False + request_queue_size = LISTEN_QUEUE def __init__( self, host: str, - port: Union[str, int], - app: Optional[Callable], - handler: Optional[Callable] = None, - passthrough_errors=False, - ssl_context: Union[str, Tuple[Any, ...], ssl.SSLContext] = None, - fd=None, - ): + port: int, + app: "WSGIApplication", + handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + passthrough_errors: bool = False, + ssl_context: t.Optional[_TSSLContextArg] = None, + fd: t.Optional[int] = None, + ) -> None: if handler is None: handler = WSGIRequestHandler @@ -624,9 +624,13 @@ def __init__( server_address = get_sockaddr(host, int(port), self.address_family) # remove socket file if it already exists - if self.address_family == af_unix and os.path.exists(server_address): - os.unlink(server_address) - HTTPServer.__init__(self, server_address, handler) + if self.address_family == af_unix: + server_address = t.cast(str, server_address) + + if os.path.exists(server_address): + os.unlink(server_address) + + super().__init__(server_address, handler) # type: ignore self.app = app self.passthrough_errors = passthrough_errors @@ -646,61 +650,55 @@ def __init__( if ssl_context == "adhoc": ssl_context = generate_adhoc_ssl_context() - self.socket = ssl_context.wrap_socket( # type: ignore - self.socket, server_side=True - ) - self.ssl_context = ssl_context + self.socket = ssl_context.wrap_socket(self.socket, server_side=True) + self.ssl_context: t.Optional["ssl.SSLContext"] = ssl_context else: self.ssl_context = None - def log(self, type, message, *args) -> None: + def log(self, type: str, message: str, *args) -> None: _log(type, message, *args) - def serve_forever(self) -> None: # type: ignore + def serve_forever(self, poll_interval=0.5) -> None: self.shutdown_signal = False try: - HTTPServer.serve_forever(self) + super().serve_forever(poll_interval=poll_interval) except KeyboardInterrupt: pass finally: self.server_close() - def handle_error(self, request, client_address) -> None: + def handle_error(self, request: t.Any, client_address: t.Tuple[str, int]) -> None: if self.passthrough_errors: raise - return HTTPServer.handle_error(self, request, client_address) - - def get_request(self) -> Tuple[ssl.SSLSocket, Any]: # type: ignore - con, info = self.socket.accept() - return con, info + return super().handle_error(request, client_address) class ThreadedWSGIServer(socketserver.ThreadingMixIn, BaseWSGIServer): """A WSGI server that does threading.""" - multithread: bool = True - daemon_threads: bool = True + multithread = True + daemon_threads = True class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer): """A WSGI server that does forking.""" - multiprocess: bool = True + multiprocess = True def __init__( self, - host, - port, - app, - processes=40, - handler=None, - passthrough_errors=False, - ssl_context=None, - fd=None, - ): + host: str, + port: int, + app: "WSGIApplication", + processes: int = 40, + handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + passthrough_errors: bool = False, + ssl_context: t.Optional[_TSSLContextArg] = None, + fd: t.Optional[int] = None, + ) -> None: if not can_fork: raise ValueError("Your platform does not support forking.") BaseWSGIServer.__init__( @@ -710,15 +708,15 @@ def __init__( def make_server( - host=None, - port=None, - app=None, - threaded=False, - processes=1, - request_handler=None, - passthrough_errors=False, - ssl_context=None, - fd=None, + host: str, + port: int, + app: "WSGIApplication", + threaded: bool = False, + processes: int = 1, + request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + passthrough_errors: bool = False, + ssl_context: t.Optional[_TSSLContextArg] = None, + fd: t.Optional[int] = None, ) -> BaseWSGIServer: """Create a new server instance that is either threaded, or forks or just processes one request after another. @@ -727,7 +725,7 @@ def make_server( raise ValueError("cannot have a multithreaded and multi process server.") elif threaded: return ThreadedWSGIServer( - host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd, + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd ) elif processes > 1: return ForkingWSGIServer( @@ -742,7 +740,7 @@ def make_server( ) else: return BaseWSGIServer( - host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd, + host, port, app, request_handler, passthrough_errors, ssl_context, fd=fd ) @@ -756,23 +754,23 @@ def is_running_from_reloader() -> bool: def run_simple( - hostname, - port, - application, - use_reloader=False, - use_debugger=False, - use_evalex=True, - extra_files=None, - exclude_patterns=None, - reloader_interval=1, - reloader_type="auto", - threaded=False, - processes=1, - request_handler=None, - static_files=None, - passthrough_errors=False, - ssl_context=None, -): + hostname: str, + port: int, + application: "WSGIApplication", + use_reloader: bool = False, + use_debugger: bool = False, + use_evalex: bool = True, + extra_files: t.Optional[t.Iterable[str]] = None, + exclude_patterns: t.Optional[t.Iterable[str]] = None, + reloader_interval: int = 1, + reloader_type: str = "auto", + threaded: bool = False, + processes: int = 1, + request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + static_files: t.Optional[t.Dict[str, t.Union[str, t.Tuple[str, str]]]] = None, + passthrough_errors: bool = False, + ssl_context: t.Optional[_TSSLContextArg] = None, +) -> None: """Start a WSGI application. Optional features include a reloader, multithreading and fork support. @@ -928,12 +926,13 @@ def inner(): else: s.close() if address_family == af_unix: + server_address = t.cast(str, server_address) _log("info", "Unlinking %s", server_address) os.unlink(server_address) - from ._reloader import run_with_reloader + from ._reloader import run_with_reloader as _rwr - run_with_reloader( + _rwr( inner, extra_files=extra_files, exclude_patterns=exclude_patterns, @@ -944,14 +943,14 @@ def inner(): inner() -def run_with_reloader(*args, **kwargs): +def run_with_reloader(*args, **kwargs) -> None: """Run a process with the reloader. This is not a public API, do not use this function. .. deprecated:: 2.0 This function will be removed in version 2.1. """ - from ._reloader import run_with_reloader + from ._reloader import run_with_reloader as _rwr warnings.warn( ( @@ -961,10 +960,10 @@ def run_with_reloader(*args, **kwargs): DeprecationWarning, stacklevel=2, ) - return run_with_reloader(*args, **kwargs) + _rwr(*args, **kwargs) -def main(): +def main() -> None: """A simple command-line interface for :py:func:`run_simple`.""" import argparse from .utils import import_string @@ -994,7 +993,7 @@ def main(): help="Reload the process if modules change.", ) parser.add_argument( - "application", help="Application to import and serve, in the form module:app.", + "application", help="Application to import and serve, in the form module:app." ) args = parser.parse_args() hostname, port = None, None diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index 93bc50526..231b7843d 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -1,29 +1,16 @@ import mimetypes import sys +import typing as t import warnings from collections import defaultdict +from datetime import datetime +from datetime import timedelta from http.cookiejar import CookieJar from io import BytesIO from itertools import chain from random import random from tempfile import TemporaryFile from time import time -from typing import Any -from typing import BinaryIO -from typing import Callable -from typing import cast -from typing import Dict -from typing import Generic -from typing import Hashable -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Tuple -from typing import Type -from typing import Union from urllib.request import Request as _UrllibRequest from ._internal import _get_environ @@ -31,7 +18,6 @@ from ._internal import _to_bytes from ._internal import _wsgi_decoding_dance from ._internal import _wsgi_encoding_dance -from .datastructures import AnyHeaders from .datastructures import Authorization from .datastructures import CallbackDict from .datastructures import CombinedMultiDict @@ -42,7 +28,6 @@ from .http import dump_cookie from .http import dump_options_header from .http import parse_options_header -from .types import WSGIEnvironment from .urls import iri_to_uri from .urls import url_encode from .urls import url_fix @@ -58,90 +43,107 @@ from .wsgi import ClosingIterator from .wsgi import get_current_url +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment + def stream_encode_multipart( - values: Mapping, + data: t.Mapping[str, t.Any], use_tempfile: bool = True, threshold: int = 1024 * 500, - boundary: Optional[str] = None, + boundary: t.Optional[str] = None, charset: str = "utf-8", -) -> Tuple[BinaryIO, int, str]: +) -> t.Tuple[t.BinaryIO, int, str]: """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. """ if boundary is None: boundary = f"---------------WerkzeugFormPart_{time()}{random()}" - _closure = [BytesIO(), 0, False] + + stream = BytesIO() + total_length = 0 + on_disk = False if use_tempfile: def write_binary(string): - stream, total_length, on_disk = _closure + nonlocal stream, total_length, on_disk + if on_disk: stream.write(string) else: length = len(string) - if length + _closure[1] <= threshold: + + if length + total_length <= threshold: stream.write(string) else: new_stream = TemporaryFile("wb+") new_stream.write(stream.getvalue()) new_stream.write(string) - _closure[0] = new_stream - _closure[2] = True - _closure[1] = total_length + length + stream = new_stream + on_disk = True + + total_length += length else: - write_binary = _closure[0].write # type: ignore + write_binary = stream.write def write(string): write_binary(string.encode(charset)) - if not isinstance(values, MultiDict): - values = MultiDict(values) - - for key, values in values.lists(): # type: ignore - for value in values: - write(f'--{boundary}\r\nContent-Disposition: form-data; name="{key}"') - reader = getattr(value, "read", None) - if reader is not None: - filename = getattr(value, "filename", getattr(value, "name", None)) - content_type = getattr(value, "content_type", None) - if content_type is None: - content_type = ( - filename - and mimetypes.guess_type(filename)[0] - or "application/octet-stream" - ) - if filename is not None: - write(f'; filename="{filename}"\r\n') - else: - write("\r\n") - write(f"Content-Type: {content_type}\r\n\r\n") - while 1: - chunk = reader(16384) - if not chunk: - break - write_binary(chunk) + for key, value in _iter_data(data): + write(f'--{boundary}\r\nContent-Disposition: form-data; name="{key}"') + reader = getattr(value, "read", None) + + if reader is not None: + filename = getattr(value, "filename", getattr(value, "name", None)) + content_type = getattr(value, "content_type", None) + + if content_type is None: + content_type = ( + filename + and mimetypes.guess_type(filename)[0] + or "application/octet-stream" + ) + + if filename is not None: + write(f'; filename="{filename}"\r\n') else: - if not isinstance(value, str): - value = str(value) + write("\r\n") + + write(f"Content-Type: {content_type}\r\n\r\n") + + while True: + chunk = reader(16384) + + if not chunk: + break + + write_binary(chunk) + else: + if not isinstance(value, str): + value = str(value) + + value = _to_bytes(value, charset) + write("\r\n\r\n") + write_binary(value) + + write("\r\n") - value = _to_bytes(value, charset) - write("\r\n\r\n") - write_binary(value) - write("\r\n") write(f"--{boundary}--\r\n") - length = int(_closure[0].tell()) # type: ignore - _closure[0].seek(0) # type: ignore - return _closure[0], length, boundary # type: ignore + length = stream.tell() + stream.seek(0) + return stream, length, boundary def encode_multipart( - values: Mapping, boundary: Optional[str] = None, charset: str = "utf-8" -) -> Tuple[str, bytes]: + values: t.Mapping[str, t.Any], + boundary: t.Optional[str] = None, + charset: str = "utf-8", +) -> t.Tuple[str, bytes]: """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is bytes. """ @@ -151,17 +153,13 @@ def encode_multipart( return boundary, stream.read() -class _TestCookieHeaders(Generic[AnyHeaders]): - - """A headers adapter for cookielib - """ +class _TestCookieHeaders: + """A headers adapter for cookielib""" - headers: AnyHeaders - - def __init__(self, headers: AnyHeaders) -> None: + def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: self.headers = headers - def getheaders(self, name): + def getheaders(self, name: str) -> t.Iterable[str]: headers = [] name = name.lower() for k, v in self.headers: @@ -169,23 +167,23 @@ def getheaders(self, name): headers.append(v) return headers - def get_all(self, name: str, default: Optional[List[Any]] = None) -> List[str]: - rv = [] - for k, v in self.headers: - if k.lower() == name.lower(): - rv.append(v) - return rv or default or [] + def get_all( + self, name: str, default: t.Optional[t.Iterable[str]] = None + ) -> t.Iterable[str]: + headers = self.getheaders(name) + if not headers: + return default # type: ignore + + return headers -class _TestCookieResponse: +class _TestCookieResponse: """Something that looks like a httplib.HTTPResponse, but is actually just an adapter for our test responses to make them available for cookielib. """ - headers: _TestCookieHeaders - - def __init__(self, headers: AnyHeaders) -> None: + def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: self.headers = _TestCookieHeaders(headers) def info(self) -> _TestCookieHeaders: @@ -193,12 +191,11 @@ def info(self) -> _TestCookieHeaders: class _TestCookieJar(CookieJar): - """A cookielib.CookieJar modified to inject and read cookie headers from and to wsgi environments, and wsgi application responses. """ - def inject_wsgi(self, environ: WSGIEnvironment) -> None: + def inject_wsgi(self, environ: "WSGIEnvironment") -> None: """Inject the cookies as client headers into the server's wsgi environment. """ @@ -209,7 +206,11 @@ def inject_wsgi(self, environ: WSGIEnvironment) -> None: else: environ.pop("HTTP_COOKIE", None) - def extract_wsgi(self, environ: WSGIEnvironment, headers: AnyHeaders) -> None: + def extract_wsgi( + self, + environ: "WSGIEnvironment", + headers: t.Union[Headers, t.List[t.Tuple[str, str]]], + ) -> None: """Extract the server's set-cookie headers as cookies into the cookie jar. """ @@ -219,23 +220,23 @@ def extract_wsgi(self, environ: WSGIEnvironment, headers: AnyHeaders) -> None: ) -def _iter_data(data: Any,) -> Iterator[Tuple[Hashable, Any]]: - """Iterates over a `dict` or :class:`MultiDict` yielding all keys and - values. - This is used to iterate over the data passed to the - :class:`EnvironBuilder`. +def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[t.Tuple[str, t.Any]]: + """Iterate over a mapping that might have a list of values, yielding + all key, value pairs. Almost like iter_multi_items but only allows + lists, not tuples, of values so tuples can be used for files. """ if isinstance(data, MultiDict): - for key, values in data.lists(): - for value in values: - yield key, value + yield from data.items(multi=True) else: - for key, values in data.items(): - if isinstance(values, list): - for value in values: - yield key, value + for key, value in data.items(): + if isinstance(value, list): + for v in value: + yield key, v else: - yield key, values + yield key, value + + +_TAnyMultiDict = t.TypeVar("_TAnyMultiDict", bound=MultiDict) class EnvironBuilder: @@ -346,34 +347,33 @@ class EnvironBuilder: json_dumps = staticmethod(json.dumps) del json + _args: t.Optional[MultiDict] + _query_string: t.Optional[str] + _input_stream: t.Optional[t.BinaryIO] + _form: t.Optional[MultiDict] + _files: t.Optional[FileMultiDict] + def __init__( self, path: str = "/", - base_url: Optional[str] = None, - query_string: Optional[Union[str, Dict[str, str]]] = None, + base_url: t.Optional[str] = None, + query_string: t.Optional[t.Union[t.Mapping[str, str], str]] = None, method: str = "GET", - input_stream: Optional[BytesIO] = None, - content_type: Optional[str] = None, - content_length: Optional[Union[str, int]] = None, - errors_stream: Optional[IO] = None, + input_stream: t.Optional[t.BinaryIO] = None, + content_type: t.Optional[str] = None, + content_length: t.Optional[int] = None, + errors_stream: t.Optional[t.TextIO] = None, multithread: bool = False, multiprocess: bool = False, run_once: bool = False, - headers: Optional[Union[Headers, Dict[str, str]]] = None, - data: Optional[Any] = None, - environ_base: Optional[Dict[str, str]] = None, - environ_overrides: Optional[ - Union[ - Dict[str, int], - Dict[str, Tuple[int, int]], - Dict[str, Union[str, Tuple[int, int], BytesIO, bool]], - Dict[str, str], - ] - ] = None, + headers: t.Optional[t.Union[Headers, t.Iterable[t.Tuple[str, str]]]] = None, + data: t.Optional[t.Union[t.BinaryIO, str, bytes, t.Mapping[str, t.Any]]] = None, + environ_base: t.Optional[t.Mapping[str, t.Any]] = None, + environ_overrides: t.Optional[t.Mapping[str, t.Any]] = None, charset: str = "utf-8", - mimetype: Optional[str] = None, - json: Optional[Union[List[int], Dict[str, str]]] = None, - auth: Union[Authorization, Tuple[str, str]] = None, + mimetype: t.Optional[str] = None, + json: t.Optional[t.Mapping[str, t.Any]] = None, + auth: t.Optional[t.Union[Authorization, t.Tuple[str, str]]] = None, ) -> None: path_s = _make_encode_wrapper(path) if query_string is not None and path_s("?") in path: @@ -386,7 +386,7 @@ def __init__( self.request_uri = path if base_url is not None: base_url = url_fix(iri_to_uri(base_url, charset), charset) - self.base_url = base_url + self.base_url = base_url # type: ignore if isinstance(query_string, (bytes, str)): self.query_string = query_string else: @@ -436,7 +436,7 @@ def __init__( if input_stream is not None: raise TypeError("can't provide input stream and data") if hasattr(data, "read"): - data = data.read() + data = data.read() # type: ignore if isinstance(data, str): data = data.encode(self.charset) if isinstance(data, bytes): @@ -444,7 +444,7 @@ def __init__( if self.content_length is None: self.content_length = len(data) else: - for key, value in _iter_data(data): + for key, value in _iter_data(data): # type: ignore if isinstance(value, (tuple, dict)) or hasattr(value, "read"): self._add_file_from_data(key, value) else: @@ -454,7 +454,7 @@ def __init__( self.mimetype = mimetype @classmethod - def from_environ(cls, environ: WSGIEnvironment, **kwargs) -> "EnvironBuilder": + def from_environ(cls, environ: "WSGIEnvironment", **kwargs) -> "EnvironBuilder": """Turn an environ dict back into a builder. Any extra kwargs override the args extracted from the environ. @@ -487,7 +487,11 @@ def from_environ(cls, environ: WSGIEnvironment, **kwargs) -> "EnvironBuilder": return cls(**out) def _add_file_from_data( - self, key: Hashable, value: Union[BytesIO, Tuple[BytesIO], Tuple[BytesIO, str]], + self, + key: str, + value: t.Union[ + t.BinaryIO, t.Tuple[t.BinaryIO, str], t.Tuple[t.BinaryIO, str, str] + ], ) -> None: """Called in the EnvironBuilder to add files from the data dict.""" if isinstance(value, tuple): @@ -500,14 +504,14 @@ def _make_base_url(scheme: str, host: str, script_root: str) -> str: return url_unparse((scheme, host, script_root, "", "")).rstrip("/") + "/" @property - def base_url(self): + def base_url(self) -> str: """The base URL is used to extract the URL scheme, host name, port, and root path. """ return self._make_base_url(self.url_scheme, self.host, self.script_root) @base_url.setter - def base_url(self, value): + def base_url(self, value: t.Optional[str]) -> None: if value is None: scheme = "http" netloc = "localhost" @@ -521,7 +525,7 @@ def base_url(self, value): self.url_scheme = scheme @property - def content_type(self): + def content_type(self) -> t.Optional[str]: """The content type for the request. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -536,14 +540,14 @@ def content_type(self): return ct @content_type.setter - def content_type(self, value): + def content_type(self, value: t.Optional[str]) -> None: if value is None: self.headers.pop("Content-Type", None) else: self.headers["Content-Type"] = value @property - def mimetype(self): + def mimetype(self) -> t.Optional[str]: """The mimetype (content type without charset etc.) .. versionadded:: 0.14 @@ -552,12 +556,12 @@ def mimetype(self): return ct.split(";")[0].strip() if ct else None @mimetype.setter - def mimetype(self, value): + def mimetype(self, value: str) -> None: self.content_type = get_content_type(value, self.charset) @property - def mimetype_params(self): - """ The mimetype parameters as dict. For example if the + def mimetype_params(self) -> t.Mapping[str, str]: + """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -571,7 +575,7 @@ def on_update(d): return CallbackDict(d, on_update) @property - def content_length(self): + def content_length(self) -> t.Optional[int]: """The content length as integer. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -579,15 +583,13 @@ def content_length(self): return self.headers.get("Content-Length", type=int) @content_length.setter - def content_length(self, value): + def content_length(self, value: t.Optional[int]) -> None: if value is None: self.headers.pop("Content-Length", None) else: self.headers["Content-Length"] = str(value) - def _get_form( - self, name: str, storage: Union[Type[FileMultiDict], Type[MultiDict]] - ) -> Union[MultiDict, FileMultiDict]: + def _get_form(self, name: str, storage: t.Type[_TAnyMultiDict]) -> _TAnyMultiDict: """Common behavior for getting the :attr:`form` and :attr:`files` properties. @@ -605,7 +607,7 @@ def _get_form( return rv - def _set_form(self, name, value): + def _set_form(self, name: str, value: MultiDict) -> None: """Common behavior for setting the :attr:`form` and :attr:`files` properties. @@ -616,27 +618,27 @@ def _set_form(self, name, value): setattr(self, name, value) @property - def form(self): + def form(self) -> MultiDict: """A :class:`MultiDict` of form values.""" return self._get_form("_form", MultiDict) @form.setter - def form(self, value): + def form(self, value: MultiDict) -> None: self._set_form("_form", value) @property - def files(self): + def files(self) -> FileMultiDict: """A :class:`FileMultiDict` of uploaded files. Use :meth:`~FileMultiDict.add_file` to add new files. """ return self._get_form("_files", FileMultiDict) @files.setter - def files(self, value): + def files(self, value: FileMultiDict) -> None: self._set_form("_files", value) @property - def input_stream(self): + def input_stream(self) -> t.Optional[t.BinaryIO]: """An optional input stream. This is mutually exclusive with setting :attr:`form` and :attr:`files`, setting it will clear those. Do not provide this if the method is not ``POST`` or @@ -645,13 +647,13 @@ def input_stream(self): return self._input_stream @input_stream.setter - def input_stream(self, value): + def input_stream(self, value: t.Optional[t.BinaryIO]) -> None: self._input_stream = value self._form = None self._files = None @property - def query_string(self): + def query_string(self) -> str: """The query string. If you set this to a string :attr:`args` will no longer be available. """ @@ -662,12 +664,12 @@ def query_string(self): return self._query_string @query_string.setter - def query_string(self, value): + def query_string(self, value: t.Optional[str]) -> None: self._query_string = value self._args = None @property - def args(self): + def args(self) -> MultiDict: """The URL arguments as :class:`MultiDict`.""" if self._query_string is not None: raise AttributeError("a query string is defined") @@ -676,7 +678,7 @@ def args(self): return self._args @args.setter - def args(self, value): + def args(self, value: t.Optional[MultiDict]) -> None: self._query_string = None self._args = value @@ -711,7 +713,7 @@ def close(self) -> None: try: files = self.files.values() except AttributeError: - files = () + files = () # type: ignore for f in files: try: f.close() @@ -719,7 +721,7 @@ def close(self) -> None: pass self.closed = True - def get_environ(self) -> WSGIEnvironment: + def get_environ(self) -> "WSGIEnvironment": """Return the built environ. .. versionchanged:: 0.15 @@ -740,24 +742,22 @@ def get_environ(self) -> WSGIEnvironment: input_stream.seek(start_pos) content_length = end_pos - start_pos elif mimetype == "multipart/form-data": - values = CombinedMultiDict([self.form, self.files]) input_stream, content_length, boundary = stream_encode_multipart( - values, charset=self.charset + CombinedMultiDict([self.form, self.files]), charset=self.charset ) content_type = f'{mimetype}; boundary="{boundary}"' elif mimetype == "application/x-www-form-urlencoded": - values = url_encode(self.form, charset=self.charset) # type: ignore - values = values.encode("ascii") # type: ignore - content_length = len(values) - input_stream = BytesIO(values) # type: ignore + form_encoded = url_encode(self.form, charset=self.charset).encode("ascii") + content_length = len(form_encoded) + input_stream = BytesIO(form_encoded) else: input_stream = BytesIO() - result = {} + result: "WSGIEnvironment" = {} if self.environ_base: result.update(self.environ_base) - def _path_encode(x): + def _path_encode(x: str) -> str: return _wsgi_encoding_dance(url_unquote(x, self.charset), self.charset) raw_uri = _wsgi_encoding_dance(self.request_uri, self.charset) @@ -775,13 +775,13 @@ def _path_encode(x): "SERVER_PORT": str(self.server_port), "HTTP_HOST": self.host, "SERVER_PROTOCOL": self.server_protocol, - "wsgi.version": self.wsgi_version, # type: ignore + "wsgi.version": self.wsgi_version, "wsgi.url_scheme": self.url_scheme, "wsgi.input": input_stream, - "wsgi.errors": self.errors_stream, # type: ignore - "wsgi.multithread": self.multithread, # type: ignore - "wsgi.multiprocess": self.multiprocess, # type: ignore - "wsgi.run_once": self.run_once, # type: ignore + "wsgi.errors": self.errors_stream, + "wsgi.multithread": self.multithread, + "wsgi.multiprocess": self.multiprocess, + "wsgi.run_once": self.run_once, } ) @@ -800,15 +800,15 @@ def _path_encode(x): for key, value in headers.to_wsgi_list(): combined_headers[f"HTTP_{key.upper().replace('-', '_')}"].append(value) - for key, values in combined_headers.items(): # type: ignore + for key, values in combined_headers.items(): result[key] = ", ".join(values) if self.environ_overrides: - result.update(self.environ_overrides) # type: ignore + result.update(self.environ_overrides) return result - def get_request(self, cls=None): + def get_request(self, cls: t.Optional[t.Type[Request]] = None) -> Request: """Returns a request with the data. If the request class is not specified :attr:`request_class` is used. @@ -858,8 +858,8 @@ class ClientResponse(BaseResponse): def __init__( self, - application: Callable[[WSGIEnvironment, Callable], Iterable[bytes]], - response_wrapper: Optional[Union[Type[Response], Type[BaseResponse]]] = None, + application: "WSGIApplication", + response_wrapper: t.Optional[t.Type["Response"]] = None, use_cookies: bool = True, allow_subdomain_redirects: bool = False, ) -> None: @@ -869,15 +869,15 @@ def __init__( response_wrapper = TestResponse elif not isinstance(response_wrapper, TestResponse): response_wrapper = type( - "WrapperTestResponse", (TestResponse, response_wrapper), {} + "WrapperTestResponse", + (TestResponse, response_wrapper), # type: ignore + {}, ) - self.response_wrapper: Type[TestResponse] = cast( - Type[TestResponse], response_wrapper - ) + self.response_wrapper = t.cast(t.Type["TestResponse"], response_wrapper) if use_cookies: - self.cookie_jar = _TestCookieJar() + self.cookie_jar: t.Optional[_TestCookieJar] = _TestCookieJar() else: self.cookie_jar = None @@ -888,13 +888,13 @@ def set_cookie( server_name: str, key: str, value: str = "", - max_age: Optional[int] = None, - expires: Optional[int] = None, + max_age: t.Optional[t.Union[timedelta, int]] = None, + expires: t.Optional[t.Union[str, datetime, int, float]] = None, path: str = "/", - domain: Optional[str] = None, + domain: t.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: Optional[str] = None, + samesite: t.Optional[str] = None, charset: str = "utf-8", ) -> None: """Sets a cookie in the client's cookie jar. The server name @@ -923,10 +923,10 @@ def delete_cookie( server_name: str, key: str, path: str = "/", - domain: Optional[str] = None, + domain: t.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: Optional[str] = None, + samesite: t.Optional[str] = None, ) -> None: """Deletes a cookie in the test client.""" self.set_cookie( @@ -942,8 +942,8 @@ def delete_cookie( ) def run_wsgi_app( - self, environ: WSGIEnvironment, buffered: bool = False - ) -> Tuple[Iterable[bytes], str, Headers]: + self, environ: "WSGIEnvironment", buffered: bool = False + ) -> t.Tuple[t.Iterable[bytes], str, Headers]: """Runs the wrapped WSGI app with the given environment. :meta private: @@ -1051,7 +1051,7 @@ def open( .. versionchanged:: 0.5 Added the ``follow_redirects`` parameter. """ - request = None + request: t.Optional["Request"] = None if not kwargs and len(args) == 1: arg = args[0] @@ -1061,7 +1061,7 @@ def open( elif isinstance(arg, dict): request = EnvironBuilder.from_environ(arg).get_request() elif isinstance(arg, BaseRequest): - request = arg + request = t.cast(Request, arg) if request is None: builder = EnvironBuilder(*args, **kwargs) @@ -1075,7 +1075,7 @@ def open( response = self.response_wrapper(*response, request=request) redirects = set() - history: List[TestResponse] = [] + history: t.List["TestResponse"] = [] while follow_redirects and response.status_code in { 301, @@ -1156,11 +1156,11 @@ def trace(self, *args, **kw) -> "TestResponse": kw["method"] = "TRACE" return self.open(*args, **kw) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {self.application!r}>" -def create_environ(*args, **kwargs) -> WSGIEnvironment: +def create_environ(*args, **kwargs) -> "WSGIEnvironment": """Create a new WSGI environ dict based on the values passed. The first parameter should be the path of the request which defaults to '/'. The second one can either be an absolute path (in that case the host is @@ -1184,10 +1184,8 @@ def create_environ(*args, **kwargs) -> WSGIEnvironment: def run_wsgi_app( - app: Callable[[WSGIEnvironment, Callable], Iterable[bytes]], - environ: WSGIEnvironment, - buffered: bool = False, -) -> Tuple[Iterable[bytes], str, Headers]: + app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False +) -> t.Tuple[t.Iterable[bytes], str, Headers]: """Return a tuple in the form (app_iter, status, headers) of the application output. This works best if you pass it an application that returns an iterator all the time. @@ -1207,8 +1205,9 @@ def run_wsgi_app( # Copy environ to ensure any mutations by the app (ProxyFix, for # example) don't affect subsequent requests (such as redirects). environ = _get_environ(environ).copy() - response: Optional[Tuple[str, List[Tuple[str, str]]]] = None - buffer: List[bytes] = [] + status: str + response: t.Optional[t.Tuple[str, t.List[t.Tuple[str, str]]]] = None + buffer: t.List[bytes] = [] def start_response(status, headers, exc_info=None): nonlocal response @@ -1224,13 +1223,13 @@ def start_response(status, headers, exc_info=None): app_rv = app(environ, start_response) close_func = getattr(app_rv, "close", None) - app_iter = iter(app_rv) + app_iter: t.Iterable[bytes] = iter(app_rv) # when buffering we emit the close call early and convert the # application iterator into a regular list if buffered: try: - app_iter = list(app_iter) # type: ignore + app_iter = list(app_iter) finally: if close_func is not None: close_func() @@ -1252,7 +1251,8 @@ def start_response(status, headers, exc_info=None): if close_func is not None and app_iter is not app_rv: app_iter = ClosingIterator(app_iter, close_func) - return app_iter, response[0], Headers(response[1]) + status, headers = response # type: ignore + return app_iter, status, Headers(headers) class TestResponse(JSONMixin, Response): # type: ignore @@ -1269,26 +1269,26 @@ class TestResponse(JSONMixin, Response): # type: ignore resulted in this response. """ - history: Tuple["TestResponse", ...] + history: t.Tuple["TestResponse", ...] """A list of intermediate responses. Populated when the test request is made with ``follow_redirects`` enabled. """ def __init__( self, - response: Iterable[bytes], + response: t.Iterable[bytes], status: str, headers: Headers, request: Request, - history: Tuple["TestResponse", ...] = (), + history: t.Tuple["TestResponse"] = (), # type: ignore **kwargs, - ): + ) -> None: super().__init__(response, status, headers, **kwargs) self.request = request self.history = history self._compat_tuple = response, status, headers - def __iter__(self): + def __iter__(self) -> t.Iterator: warnings.warn( ( "The test client no longer returns a tuple, it returns" @@ -1301,7 +1301,7 @@ def __iter__(self): ) return iter(self._compat_tuple) - def __getitem__(self, item): + def __getitem__(self, item: int) -> t.Any: warnings.warn( ( "The test client no longer returns a tuple, it returns" diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 26eb4e780..480fcd0c5 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -4,16 +4,20 @@ import base64 import os import sys +import typing as t from html import escape from textwrap import wrap -from typing import Any from . import __version__ as _werkzeug_version from .wrappers import BaseRequest as Request from .wrappers import BaseResponse as Response -from werkzeug.types import WSGIEnvironment -logo: Any = Response( +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIEnvironment + + +logo = Response( base64.b64decode( """ R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP///////// @@ -55,7 +59,7 @@ ) -TEMPLATE: Any = """\ +TEMPLATE = """\ WSGI Information @@ -134,7 +138,7 @@ """ -def iter_sys_path(): +def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: if os.name == "posix": def strip(x): @@ -154,11 +158,11 @@ def strip(x): yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item -def render_testapp(req): +def render_testapp(req: Request) -> bytes: try: import pkg_resources except ImportError: - eggs = () + eggs: t.Iterable[t.Any] = () else: eggs = sorted(pkg_resources.working_set, key=lambda x: x.project_name.lower()) python_eggs = [] @@ -203,7 +207,9 @@ def render_testapp(req): ).encode("utf-8") -def test_app(environ: WSGIEnvironment, start_response): +def test_app( + environ: "WSGIEnvironment", start_response: "StartResponse" +) -> t.Iterable[bytes]: """Simple test application that dumps the environment. You can use it to check if Werkzeug is working properly: diff --git a/src/werkzeug/types.py b/src/werkzeug/types.py deleted file mode 100644 index 307be7155..000000000 --- a/src/werkzeug/types.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Types which Werkzeug uses. - -Do not depend on these externally -""" -from typing import Any -from typing import Dict -from typing import TypeVar -from typing import Union - -BytesOrStr = Union[bytes, str] -# A value which can be encoded using Unicode. -UnicodeEncodable = Union[bytes, str, int] - -# a generic type parameter used in many functions -T = TypeVar("T") -# a number (either floating point or an integer) -Number = TypeVar("Number", int, float) - -# A WSGI environment -# TODO: At some point it may be possible to replace this with a -# `TypedDict` from the `typing` module. At present (21.06.2020) -# this isn't possible because `TypedDict` is only available on -# Python 3.8+ -WSGIEnvironment = Dict[str, Any] diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index 64b4b7081..771ecfc0a 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -6,32 +6,17 @@ import codecs import os import re +import typing as t import warnings -from io import StringIO -from typing import Any -from typing import AnyStr -from typing import BinaryIO -from typing import Callable -from typing import Dict -from typing import FrozenSet -from typing import Iterator -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import Union - -if TYPE_CHECKING: - from werkzeug.datastructures import MultiDict # noqa: F401 from ._internal import _check_str_tuple from ._internal import _decode_idna from ._internal import _encode_idna from ._internal import _make_encode_wrapper from ._internal import _to_str -from werkzeug.types import T + +if t.TYPE_CHECKING: + from . import datastructures as ds # A regular expression for what a valid schema looks like _scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") @@ -55,29 +40,33 @@ _bytetohex = [f"%{char:02X}".encode("ascii") for char in range(256)] -URLPartsTuple = Tuple[AnyStr, AnyStr, AnyStr, AnyStr, AnyStr] - - -class _URLTuple(NamedTuple): - scheme: Any - netloc: Any - path: Any - query: Any - fragment: Any +class _URLTuple(t.NamedTuple): + scheme: str + netloc: str + path: str + query: str + fragment: str class BaseURL(_URLTuple): """Superclass of :py:class:`URL` and :py:class:`BytesURL`.""" __slots__ = () + _at: str + _colon: str + _lbracket: str + _rbracket: str - def replace(self, **kwargs) -> "URL": + def __str__(self) -> str: + return self.to_url() + + def replace(self, **kwargs) -> "BaseURL": """Return an URL with the same values, except for those parameters given new values by whichever keyword arguments are specified.""" - return self._replace(**kwargs) # type: ignore + return self._replace(**kwargs) @property - def host(self) -> Optional[Union[str, bytes]]: + def host(self) -> t.Optional[str]: """The host part of the URL if available, otherwise `None`. The host is either the hostname or the IP address mentioned in the URL. It will not contain the port. @@ -85,7 +74,7 @@ def host(self) -> Optional[Union[str, bytes]]: return self._split_host()[0] @property - def ascii_host(self) -> Optional[str]: + def ascii_host(self) -> t.Optional[str]: """Works exactly like :attr:`host` but will return a result that is restricted to ASCII. If it finds a netloc that is not ASCII it will attempt to idna decode it. This is useful for socket @@ -94,13 +83,13 @@ def ascii_host(self) -> Optional[str]: rv = self.host if rv is not None and isinstance(rv, str): try: - rv = _encode_idna(rv) + rv = _encode_idna(rv) # type: ignore except UnicodeError: rv = rv.encode("ascii", "ignore") # type: ignore return _to_str(rv, "ascii", "ignore") @property - def port(self) -> Optional[int]: + def port(self) -> t.Optional[int]: """The port in the URL as an integer if it was present, `None` otherwise. This does not fill in default ports. """ @@ -113,94 +102,111 @@ def port(self) -> Optional[int]: return None @property - def auth(self) -> Optional[str]: + def auth(self) -> t.Optional[str]: """The authentication part in the URL if available, `None` otherwise. """ return self._split_netloc()[0] @property - def username(self) -> Optional[str]: + def username(self) -> t.Optional[str]: """The username if it was part of the URL, `None` otherwise. This undergoes URL decoding and will always be a string. """ rv = self._split_auth()[0] if rv is not None: - return _url_unquote_legacy(rv) # type: ignore + return _url_unquote_legacy(rv) return None @property - def raw_username(self) -> Optional[Union[str, bytes]]: + def raw_username(self) -> t.Optional[str]: """The username if it was part of the URL, `None` otherwise. Unlike :attr:`username` this one is not being decoded. """ return self._split_auth()[0] @property - def password(self) -> Optional[str]: + def password(self) -> t.Optional[str]: """The password if it was part of the URL, `None` otherwise. This undergoes URL decoding and will always be a string. """ rv = self._split_auth()[1] if rv is not None: - return _url_unquote_legacy(rv) # type: ignore + return _url_unquote_legacy(rv) return None @property - def raw_password(self) -> Optional[Union[str, bytes]]: + def raw_password(self) -> t.Optional[str]: """The password if it was part of the URL, `None` otherwise. Unlike :attr:`password` this one is not being decoded. """ return self._split_auth()[1] - def decode_query(self, *args, **kwargs): + def decode_query(self, *args, **kwargs) -> "ds.MultiDict[str, str]": """Decodes the query part of the URL. Ths is a shortcut for calling :func:`url_decode` on the query argument. The arguments and keyword arguments are forwarded to :func:`url_decode` unchanged. """ return url_decode(self.query, *args, **kwargs) - def join(self, *args, **kwargs): + def join(self, *args, **kwargs) -> "BaseURL": """Joins this URL with another one. This is just a convenience function for calling into :meth:`url_join` and then parsing the return value again. """ return url_parse(url_join(self, *args, **kwargs)) - def to_url(self): + def to_url(self) -> str: """Returns a URL string or bytes depending on the type of the information stored. This is just a convenience function for calling :meth:`url_unparse` for this URL. """ return url_unparse(self) + def encode_netloc(self) -> str: + """Encodes the netloc part to an ASCII safe URL as bytes.""" + rv = self.ascii_host or "" + if ":" in rv: + rv = f"[{rv}]" + port = self.port + if port is not None: + rv = f"{rv}:{port}" + auth = ":".join( + filter( + None, + [ + url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), + url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), + ], + ) + ) + if auth: + rv = f"{auth}@{rv}" + return rv + def decode_netloc(self) -> str: """Decodes the netloc part into a string.""" rv = _decode_idna(self.host or "") if ":" in rv: - rv = f"[{rv}]" # type: ignore + rv = f"[{rv}]" port = self.port if port is not None: - rv = f"{rv}:{port}" # type: ignore + rv = f"{rv}:{port}" auth = ":".join( filter( None, [ - _url_unquote_legacy( # type: ignore - self.raw_username or "", "/:%@" - ), - _url_unquote_legacy( # type: ignore - self.raw_password or "", "/:%@" - ), + _url_unquote_legacy(self.raw_username or "", "/:%@"), + _url_unquote_legacy(self.raw_password or "", "/:%@"), ], ) ) if auth: - rv = f"{auth}@{rv}" # type: ignore - return rv # type: ignore + rv = f"{auth}@{rv}" + return rv - def to_uri_tuple(self) -> "BytesURL": + def to_uri_tuple(self) -> "BaseURL": """Returns a :class:`BytesURL` tuple that holds a URI. This will encode all the information in the URL properly to ASCII using the rules a web browser would follow. @@ -208,9 +214,9 @@ def to_uri_tuple(self) -> "BytesURL": It's usually more interesting to directly call :meth:`iri_to_uri` which will return a string. """ - return url_parse(iri_to_uri(self).encode("ascii")) # type: ignore + return url_parse(iri_to_uri(self)) - def to_iri_tuple(self) -> "URL": + def to_iri_tuple(self) -> "BaseURL": """Returns a :class:`URL` tuple that holds a IRI. This will try to decode as much information as possible in the URL without losing information similar to how a web browser does it for the @@ -219,11 +225,11 @@ def to_iri_tuple(self) -> "URL": It's usually more interesting to directly call :meth:`uri_to_iri` which will return a string. """ - return url_parse(uri_to_iri(self)) # type: ignore + return url_parse(uri_to_iri(self)) def get_file_location( - self, pathformat: Optional[str] = None - ) -> Tuple[Optional[Union[bytes, str]], Optional[Union[bytes, str]]]: + self, pathformat: t.Optional[str] = None + ) -> t.Tuple[t.Optional[str], t.Optional[str]]: """Returns a tuple with the location of the file in the form ``(server, location)``. If the netloc is empty in the URL or points to localhost, it's represented as ``None``. @@ -254,12 +260,8 @@ def get_file_location( pathformat = "posix" if pathformat == "windows": - if ( - path[:1] == "/" - and path[1:2].isalpha() - and path[2:3] in "|:" # type: ignore - ): - path = f"{path[1:2]}:{path[3:]}" # type: ignore + if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": + path = f"{path[1:2]}:{path[3:]}" windows_share = path[:3] in ("\\" * 3, "/" * 3) import ntpath @@ -269,7 +271,7 @@ def get_file_location( # path like ``///host/directory``. We need to special-case this # because the path contains the hostname. if windows_share and host is None: - parts = path.lstrip("\\").split("\\", 1) # type: ignore + parts = path.lstrip("\\").split("\\", 1) if len(parts) == 2: host, path = parts else: @@ -287,40 +289,40 @@ def get_file_location( return host, path - def _split_netloc(self,) -> Union[List[Optional[AnyStr]], Tuple[None, AnyStr]]: - if self._at in self.netloc: # type: ignore - return self.netloc.split(self._at, 1) # type: ignore + def _split_netloc(self) -> t.Tuple[t.Optional[str], str]: + if self._at in self.netloc: + auth, _, netloc = self.netloc.partition(self._at) + return auth, netloc return None, self.netloc - def _split_auth(self) -> Union[List[AnyStr], Tuple[Optional[AnyStr], None]]: + def _split_auth(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: auth = self._split_netloc()[0] if not auth: return None, None - if self._colon not in auth: # type: ignore - return auth, None # type: ignore - return auth.split(self._colon, 1) # type: ignore - - def _split_host( - self, - ) -> Union[ - Tuple[Optional[Union[bytes, str]], Optional[Union[bytes, str]]], List[str], - ]: + if self._colon not in auth: + return auth, None + + username, _, password = auth.partition(self._colon) + return username, password + + def _split_host(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: rv = self._split_netloc()[1] if not rv: return None, None - if not rv.startswith(self._lbracket): # type: ignore - if self._colon in rv: # type: ignore - return rv.split(self._colon, 1) # type: ignore + if not rv.startswith(self._lbracket): + if self._colon in rv: + host, _, port = rv.partition(self._colon) + return host, port return rv, None - idx = rv.find(self._rbracket) # type: ignore + idx = rv.find(self._rbracket) if idx < 0: return rv, None host = rv[1:idx] rest = rv[idx + 1 :] - if rest.startswith(self._colon): # type: ignore + if rest.startswith(self._colon): return host, rest[1:] return host, None @@ -337,40 +339,16 @@ class URL(BaseURL): _lbracket = "[" _rbracket = "]" - def __str__(self): - return self.to_url() - - def encode_netloc(self) -> str: - """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or "" - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), - url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def encode(self, charset="utf-8", errors="replace"): + def encode(self, charset="utf-8", errors="replace") -> "BytesURL": """Encodes the URL to a tuple made out of bytes. The charset is only being used for the path, query and fragment. """ return BytesURL( - self.scheme.encode("ascii"), + self.scheme.encode("ascii"), # type: ignore self.encode_netloc(), - self.path.encode(charset, errors), - self.query.encode(charset, errors), - self.fragment.encode(charset, errors), + self.path.encode(charset, errors), # type: ignore + self.query.encode(charset, errors), # type: ignore + self.fragment.encode(charset, errors), # type: ignore ) @@ -378,42 +356,44 @@ class BytesURL(BaseURL): """Represents a parsed URL in bytes.""" __slots__ = () - _at = b"@" - _colon = b":" - _lbracket = b"[" - _rbracket = b"]" + _at = b"@" # type: ignore + _colon = b":" # type: ignore + _lbracket = b"[" # type: ignore + _rbracket = b"]" # type: ignore - def __str__(self): - return self.to_url().decode("utf-8", "replace") + def __str__(self) -> str: + return self.to_url().decode("utf-8", "replace") # type: ignore - def encode_netloc(self) -> bytes: + def encode_netloc(self) -> bytes: # type: ignore """Returns the netloc unchanged as bytes.""" - return self.netloc + return self.netloc # type: ignore - def decode(self, charset="utf-8", errors="replace") -> URL: + def decode(self, charset="utf-8", errors="replace") -> "URL": """Decodes the URL to a tuple made out of strings. The charset is only being used for the path, query and fragment. """ return URL( - self.scheme.decode("ascii"), + self.scheme.decode("ascii"), # type: ignore self.decode_netloc(), - self.path.decode(charset, errors), - self.query.decode(charset, errors), - self.fragment.decode(charset, errors), + self.path.decode(charset, errors), # type: ignore + self.query.decode(charset, errors), # type: ignore + self.fragment.decode(charset, errors), # type: ignore ) -_unquote_maps: Dict[FrozenSet, Dict[bytes, int]] = {frozenset(): _hextobyte} +_unquote_maps: t.Dict[t.FrozenSet[int], t.Dict[bytes, int]] = {frozenset(): _hextobyte} -def _unquote_to_bytes(string: Union[str, bytes], unsafe: str = "") -> bytes: +def _unquote_to_bytes( + string: t.Union[str, bytes], unsafe: t.Union[str, bytes] = "" +) -> bytes: if isinstance(string, str): string = string.encode("utf-8") if isinstance(unsafe, str): - unsafe = unsafe.encode("utf-8") # type: ignore + unsafe = unsafe.encode("utf-8") - unsafe = frozenset(bytearray(unsafe)) # type: ignore + unsafe = frozenset(bytearray(unsafe)) groups = iter(string.split(b"%")) result = bytearray(next(groups, b"")) @@ -438,26 +418,36 @@ def _unquote_to_bytes(string: Union[str, bytes], unsafe: str = "") -> bytes: def _url_encode_impl( - obj: Any, charset: str, sort: bool, key: Optional[Callable] -) -> Iterator[str]: + obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], + charset: str, + sort: bool, + key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]], +): from .datastructures import iter_multi_items - iterable = iter_multi_items(obj) + iterable: t.Iterable[t.Tuple[str, str]] = iter_multi_items(obj) + if sort: - iterable = sorted(iterable, key=key) # type: ignore - for key, value in iterable: - if value is None: + iterable = sorted(iterable, key=key) + + for key_str, value_str in iterable: + if value_str is None: continue - if not isinstance(key, bytes): - key = str(key).encode(charset) - if not isinstance(value, bytes): - value = str(value).encode(charset) - yield f"{_fast_url_quote_plus(key)}={_fast_url_quote_plus(value)}" + + if not isinstance(key_str, bytes): + key_bytes = str(key_str).encode(charset) + else: + key_bytes = key_str + + if not isinstance(value_str, bytes): + value_bytes = str(value_str).encode(charset) + else: + value_bytes = value_str + + yield f"{_fast_url_quote_plus(key_bytes)}={_fast_url_quote_plus(value_bytes)}" -def _url_unquote_legacy( - value: Union[str, bytes], unsafe: str = "" -) -> Union[str, bytes]: +def _url_unquote_legacy(value: str, unsafe: str = "") -> str: try: return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) except UnicodeError: @@ -465,8 +455,8 @@ def _url_unquote_legacy( def url_parse( - url: AnyStr, scheme: Optional[str] = None, allow_fragments: bool = True -) -> Union[BytesURL, URL]: + url: str, scheme: t.Optional[str] = None, allow_fragments: bool = True +) -> BaseURL: """Parses a URL from a string into a :class:`URL` tuple. If the URL is lacking a scheme it can be provided as second argument. Otherwise, it is ignored. Optionally fragments can be stripped from the URL @@ -483,7 +473,7 @@ def url_parse( is_text_based = isinstance(url, str) if scheme is None: - scheme = s("") # type: ignore + scheme = s("") netloc = query = fragment = s("") i = url.find(s(":")) if i > 0 and _scheme_re.match(_to_str(url[:i], errors="replace")): @@ -492,7 +482,7 @@ def url_parse( rest = url[i + 1 :] if not rest or any(c not in s("0123456789") for c in rest): # not a port number - scheme, url = url[:i].lower(), rest # type: ignore + scheme, url = url[:i].lower(), rest if url[:2] == s("//"): delim = len(url) @@ -512,10 +502,15 @@ def url_parse( url, query = url.split(s("?"), 1) result_type = URL if is_text_based else BytesURL - return result_type(scheme, netloc, url, query, fragment) # type: ignore + return result_type(scheme, netloc, url, query, fragment) -def _make_fast_url_quote(charset="utf-8", errors="strict", safe="/:", unsafe=""): +def _make_fast_url_quote( + charset: str = "utf-8", + errors: str = "strict", + safe: t.Union[str, bytes] = "/:", + unsafe: t.Union[str, bytes] = "", +) -> t.Callable[[bytes], str]: """Precompile the translation table for a URL encoding function. Unlike :func:`url_quote`, the generated function only takes the @@ -535,7 +530,7 @@ def _make_fast_url_quote(charset="utf-8", errors="strict", safe="/:", unsafe="") safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) table = [chr(c) if c in safe else f"%{c:02X}" for c in range(256)] - def quote(string): + def quote(string: bytes) -> str: return "".join([table[c] for c in string]) return quote @@ -550,11 +545,11 @@ def _fast_url_quote_plus(string: bytes) -> str: def url_quote( - string: Union[str, int, bytes], + string: t.Union[str, bytes], charset: str = "utf-8", errors: str = "strict", - safe: str = "/:", - unsafe: str = "", + safe: t.Union[str, bytes] = "/:", + unsafe: t.Union[str, bytes] = "", ) -> str: """URL encode a single string with a given encoding. @@ -571,12 +566,10 @@ def url_quote( if isinstance(string, str): string = string.encode(charset, errors) if isinstance(safe, str): - safe = safe.encode(charset, errors) # type: ignore + safe = safe.encode(charset, errors) if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) # type: ignore - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset( # type: ignore - bytearray(unsafe) # type: ignore - ) + unsafe = unsafe.encode(charset, errors) + safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) rv = bytearray() for char in bytearray(string): if char in safe: @@ -587,10 +580,7 @@ def url_quote( def url_quote_plus( - string: Union[str, int], - charset: str = "utf-8", - errors: str = "strict", - safe: str = "", + string: str, charset: str = "utf-8", errors: str = "strict", safe: str = "" ) -> str: """URL encode a single string with the given encoding and convert whitespace to "+". @@ -602,7 +592,7 @@ def url_quote_plus( return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") -def url_unparse(components: URLPartsTuple) -> AnyStr: +def url_unparse(components: t.Tuple[str, str, str, str, str]) -> str: """The reverse operation to :meth:`url_parse`. This accepts arbitrary as well as :class:`URL` tuples and returns a URL as a string. @@ -634,11 +624,11 @@ def url_unparse(components: URLPartsTuple) -> AnyStr: def url_unquote( - string: Union[bytes, str], - charset: Optional[str] = "utf-8", + s: t.Union[str, bytes], + charset: str = "utf-8", errors: str = "replace", unsafe: str = "", -) -> Union[bytes, str]: +) -> str: """URL decode a single string with a given encoding. If the charset is set to `None` no decoding is performed and raw bytes are returned. @@ -648,15 +638,15 @@ def url_unquote( no decoding will take place. :param errors: the error handling for the charset decoding. """ - rv = _unquote_to_bytes(string, unsafe) - if charset is not None: - rv = rv.decode(charset, errors) # type: ignore - return rv + rv = _unquote_to_bytes(s, unsafe) + if charset is None: + return rv + return rv.decode(charset, errors) def url_unquote_plus( - s: Union[str, bytes], charset: Optional[str] = "utf-8", errors: str = "replace", -) -> Union[str, bytes]: + s: t.Union[str, bytes], charset: str = "utf-8", errors: str = "replace" +) -> str: """URL decode a single string with the given `charset` and decode "+" to whitespace. @@ -709,21 +699,23 @@ def url_fix(s: str, charset: str = "utf-8") -> str: _to_iri_unsafe = "".join([chr(c) for c in range(128) if c not in _always_safe]) -def _codec_error_url_quote(e: UnicodeError) -> Tuple[Union[str, bytes], int]: +def _codec_error_url_quote(e): """Used in :func:`uri_to_iri` after unquoting to re-quote any invalid bytes. """ # the docs state that `UnicodeError` does have these attributes, # but mypy isn't picking them up? - out = _fast_url_quote(e.object[e.start : e.end]) # type: ignore - return out, e.end # type: ignore + out = _fast_url_quote(e.object[e.start : e.end]) + return out, e.end codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) def uri_to_iri( - uri: Union[str, bytes], charset: str = "utf-8", errors: str = "werkzeug.url_quote", + uri: t.Union[str, t.Tuple[str, str, str, str, str]], + charset: str = "utf-8", + errors: str = "werkzeug.url_quote", ) -> str: """Convert a URI to an IRI. All valid UTF-8 characters are unquoted, leaving all reserved and invalid characters quoted. If the URL has @@ -759,7 +751,7 @@ def uri_to_iri( def iri_to_uri( - iri: Union[bytes, str, URLPartsTuple], + iri: t.Union[str, t.Tuple[str, str, str, str, str]], charset: str = "utf-8", errors: str = "strict", safe_conversion: bool = False, @@ -825,14 +817,14 @@ def iri_to_uri( def url_decode( - s: bytes, - charset: Optional[str] = "utf-8", + s: str, + charset: str = "utf-8", decode_keys: None = None, include_empty: bool = True, errors: str = "replace", - separator: Union[str, bytes] = "&", - cls: Optional[T] = None, -) -> Type[T]: + separator: str = "&", + cls: t.Optional[t.Type["ds.MultiDict"]] = None, +) -> "ds.MultiDict[str, str]": """Parse a query string and return it as a :class:`MultiDict`. :param s: The query string to parse. @@ -864,27 +856,25 @@ def url_decode( if cls is None: from .datastructures import MultiDict # noqa: F811 - cls = MultiDict # type: ignore + cls = MultiDict if isinstance(s, str) and not isinstance(separator, str): separator = separator.decode(charset or "ascii") elif isinstance(s, bytes) and not isinstance(separator, bytes): separator = separator.encode(charset or "ascii") - return cls( # type: ignore - _url_decode_impl(s.split(separator), charset, include_empty, errors) - ) + return cls(_url_decode_impl(s.split(separator), charset, include_empty, errors)) def url_decode_stream( - stream: BinaryIO, - charset: str = "utf-8", + stream: t.BinaryIO, + charset="utf-8", decode_keys: None = None, include_empty: bool = True, errors: str = "replace", - separator: str = "&", - cls: Optional[Union[Type[T], Type["MultiDict"]]] = None, - limit: Optional[int] = None, + separator: bytes = b"&", + cls: t.Optional[t.Type["ds.MultiDict"]] = None, + limit: t.Optional[int] = None, return_iterator: bool = False, -) -> Union[T, Iterator[Tuple[Any, Any]], "MultiDict"]: +) -> "ds.MultiDict[str, str]": """Works like :func:`url_decode` but decodes a stream. The behavior of stream and limit follows functions like :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is @@ -924,24 +914,22 @@ def url_decode_stream( ) pair_iter = make_chunk_iter(stream, separator, limit) - decoder = _url_decode_impl( - pair_iter, charset, include_empty, errors # type: ignore - ) + decoder = _url_decode_impl(pair_iter, charset, include_empty, errors) if return_iterator: - return decoder + return decoder # type: ignore if cls is None: from .datastructures import MultiDict # noqa: F811 cls = MultiDict - return cls(decoder) # type: ignore + return cls(decoder) def _url_decode_impl( - pair_iter: List[bytes], charset: Optional[str], include_empty: bool, errors: str, -) -> Iterator[Tuple[AnyStr, AnyStr]]: + pair_iter: t.Iterable[t.AnyStr], charset: str, include_empty: bool, errors: str +) -> t.Iterator[t.Tuple[str, str]]: for pair in pair_iter: if not pair: continue @@ -954,16 +942,18 @@ def _url_decode_impl( continue key = pair value = s("") - key = url_unquote_plus(key, charset, errors) # type: ignore - yield key, url_unquote_plus(value, charset, errors) # type: ignore + yield ( + url_unquote_plus(key, charset, errors), + url_unquote_plus(value, charset, errors), + ) def url_encode( - obj: object, + obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], charset: str = "utf-8", encode_keys: None = None, sort: bool = False, - key: Optional[Callable] = None, + key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, separator: str = "&", ) -> str: """URL encode a dict/`MultiDict`. If a value is `None` it will not appear @@ -995,14 +985,14 @@ def url_encode( def url_encode_stream( - obj: object, - stream: Optional[StringIO] = None, + obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], + stream: t.Optional[t.TextIO] = None, charset: str = "utf-8", encode_keys: None = None, sort: bool = False, - key: None = None, + key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, separator: str = "&", -) -> Optional[Iterator[str]]: +) -> None: """Like :meth:`url_encode` but writes the results to a stream object. If the stream is `None` a generator over all encoded pairs is returned. @@ -1041,10 +1031,10 @@ def url_encode_stream( def url_join( - base: Union[str, URLPartsTuple], - url: Union[str, URLPartsTuple], + base: t.Union[str, t.Tuple[str, str, str, str, str]], + url: t.Union[str, t.Tuple[str, str, str, str, str]], allow_fragments: bool = True, -) -> str: +): """Join a base URL and a possibly relative URL to form an absolute interpretation of the latter. @@ -1091,11 +1081,11 @@ def url_join( # Resolve ".." and "." segments = [segment for segment in segments if segment != s(".")] - while 1: + while True: i = 1 n = len(segments) - 1 while i < n: - if segments[i] == s("..") and segments[i - 1] not in (s(""), s(".."),): + if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): del segments[i - 1 : i + 1] break i += 1 @@ -1162,13 +1152,7 @@ class Href: `sort` and `key` were added. """ - def __init__( - self, - base: str = "./", - charset: str = "utf-8", - sort: bool = False, - key: None = None, - ) -> None: + def __init__(self, base="./", charset="utf-8", sort=False, key=None): if not base: base = "./" self.base = base @@ -1176,7 +1160,7 @@ def __init__( self.sort = sort self.key = key - def __getattr__(self, name: str) -> "Href": + def __getattr__(self, name): if name[:2] == "__": raise AttributeError(name) base = self.base @@ -1184,7 +1168,7 @@ def __getattr__(self, name: str) -> "Href": base += "/" return Href(url_join(base, name), self.charset, self.sort, self.key) - def __call__(self, *path, **query) -> str: + def __call__(self, *path, **query): if path and isinstance(path[-1], dict): if query: raise TypeError("keyword arguments and query-dicts can't be combined") @@ -1205,6 +1189,6 @@ def __call__(self, *path, **query) -> str: rv = url_join(rv, f"./{path}") if query: rv += "?" + _to_str( - url_encode(query, self.charset, sort=self.sort, key=self.key), "ascii", + url_encode(query, self.charset, sort=self.sort, key=self.key), "ascii" ) return rv diff --git a/src/werkzeug/useragents.py b/src/werkzeug/useragents.py index b06c0eb37..ef9f79708 100644 --- a/src/werkzeug/useragents.py +++ b/src/werkzeug/useragents.py @@ -1,14 +1,14 @@ import re -from typing import Any -from typing import Optional -from typing import Tuple -from typing import Union +import typing as t + +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment class UserAgentParser: """A simple user agent parser. Used by the `UserAgent`.""" - platforms: Any = ( + platform_rules: t.ClassVar[t.Iterable[t.Tuple[str, str]]] = ( (" cros ", "chromeos"), ("iphone|ios", "iphone"), ("ipad", "ipad"), @@ -31,7 +31,7 @@ class UserAgentParser: ("blackberry|playbook", "blackberry"), ("symbian", "symbian"), ) - browsers: Any = ( + browser_rules: t.ClassVar[t.Iterable[t.Tuple[str, str]]] = ( ("googlebot", "google"), ("msnbot", "msn"), ("yahoo", "yahoo"), @@ -64,17 +64,20 @@ class UserAgentParser: ) def __init__(self) -> None: - self.platforms = [(b, re.compile(a, re.I)) for a, b in self.platforms] + self.platforms = [(b, re.compile(a, re.I)) for a, b in self.platform_rules] self.browsers = [ (b, re.compile(self._browser_version_re.format(pattern=a), re.I)) - for a, b in self.browsers + for a, b in self.browser_rules ] def __call__( self, user_agent: str - ) -> Union[ - Tuple[Optional[str], Optional[str], Optional[str], Optional[str]], - ]: + ) -> t.Tuple[t.Optional[str], t.Optional[str], t.Optional[str], t.Optional[str]]: + platform: t.Optional[str] + browser: t.Optional[str] + version: t.Optional[str] + language: t.Optional[str] + for platform, regex in self.platforms: # noqa: B007 match = regex.search(user_agent) if match is not None: @@ -172,15 +175,16 @@ class UserAgent: the language of the browser. ``None`` if not recognized. """ - string: Any _parser = UserAgentParser() - def __init__(self, environ_or_string: Any) -> None: + def __init__(self, environ_or_string: t.Union["WSGIEnvironment", str]) -> None: if isinstance(environ_or_string, dict): - environ_or_string = environ_or_string.get("HTTP_USER_AGENT", "") - self.string = environ_or_string + self.string = environ_or_string.get("HTTP_USER_AGENT", "") + else: + self.string = environ_or_string + self.platform, self.browser, self.version, self.language = self._parser( - environ_or_string + self.string ) def to_header(self) -> str: @@ -189,10 +193,8 @@ def to_header(self) -> str: def __str__(self) -> str: return self.string - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: return bool(self.browser) - __bool__: Any = __nonzero__ - def __repr__(self): return f"<{type(self).__name__} {self.browser!r}/{self.version}>" diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index cbeec31ed..f5dbfad1f 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -6,34 +6,29 @@ import pkgutil import re import sys +import typing as t import unicodedata import warnings +from datetime import datetime from html.entities import name2codepoint +from time import struct_time from time import time -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterator -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from zlib import adler32 from ._internal import _DictAccessorProperty from ._internal import _missing from ._internal import _parse_signature +from ._internal import _TAccessorValue from .datastructures import Headers from .exceptions import NotFound from .exceptions import RequestedRangeNotSatisfiable from .security import safe_join from .urls import url_quote from .wsgi import wrap_file -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from werkzeug.wrappers.base_request import BaseRequest # noqa: F401 - from werkzeug.wrappers.request import Request # noqa: F401 - from werkzeug.wrappers.response import Response +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment + from .wrappers import Response _entity_re = re.compile(r"&([^;]+);") _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") @@ -69,35 +64,30 @@ def foo(self): work. """ - # implementation detail: A subclass of python's builtin property - # decorator, we override __get__ to check for a cached value. If one - # chooses to invoke __get__ by hand the property will still work as - # expected because the lookup logic is replicated in __get__ for - # manual invocation. - def __init__( - self, func: Callable, name: Optional[str] = None, doc: None = None + self, + fget: t.Callable[[t.Any], t.Any], + name: t.Optional[str] = None, + doc: t.Optional[str] = None, ) -> None: - super().__init__() - self.__name__ = name or func.__name__ - self.__module__ = func.__module__ - self.__doc__ = doc or func.__doc__ - self.func = func + super().__init__(fget, doc=doc) + self.__name__ = name or fget.__name__ + self.__module__ = fget.__module__ - def __set__(self, obj, value): + def __set__(self, obj: object, value: t.Any) -> None: obj.__dict__[self.__name__] = value - def __get__(self, obj: Any, type: Optional[Any] = None) -> Any: + def __get__(self, obj: object, type: type = None) -> t.Any: # type: ignore if obj is None: return self value = obj.__dict__.get(self.__name__, _missing) if value is _missing: - value = self.func(obj) + value = self.fget(obj) # type: ignore obj.__dict__[self.__name__] = value return value -def invalidate_cached_property(obj, name): +def invalidate_cached_property(obj: object, name: str) -> None: """Invalidates the cache for a :class:`cached_property`: >>> class Test(object): @@ -124,10 +114,10 @@ def invalidate_cached_property(obj, name): f"Attribute {name!r} of object {obj} is not a" " cached_property, cannot be invalidated." ) - obj.__dict__[name] = _missing + del obj.__dict__[name] -class environ_property(_DictAccessorProperty): +class environ_property(_DictAccessorProperty[_TAccessorValue]): """Maps request attributes to environment variables. This works not only for the Werkzeug request object, but also any other class with an environ attribute: @@ -150,14 +140,14 @@ class environ_property(_DictAccessorProperty): read_only = True - def lookup(self, obj: Any) -> Dict[str, Any]: + def lookup(self, obj: t.Any) -> "WSGIEnvironment": return obj.environ -class header_property(_DictAccessorProperty): +class header_property(_DictAccessorProperty[_TAccessorValue]): """Like `environ_property` but for headers.""" - def lookup(self, obj: "Response") -> "Headers": + def lookup(self, obj: t.Any) -> Headers: return obj.headers @@ -336,7 +326,7 @@ def get_content_type(mimetype: str, charset: str) -> str: return mimetype -def detect_utf_encoding(data): +def detect_utf_encoding(data: bytes) -> str: """Detect which UTF encoding was used to encode the given bytes. The latest JSON standard (:rfc:`8259`) suggests that only UTF-8 is @@ -432,10 +422,9 @@ def secure_filename(filename: str) -> str: :param filename: the filename to secure """ - if isinstance(filename, str): - from unicodedata import normalize + filename = unicodedata.normalize("NFKD", filename) + filename = filename.encode("ascii", "ignore").decode("ascii") - filename = normalize("NFKD", filename).encode("ascii", "ignore").decode("ascii") for sep in os.path.sep, os.path.altsep: if sep: filename = filename.replace(sep, " ") @@ -503,7 +492,9 @@ def unescape(s): return html.unescape(s) -def redirect(location: str, code: int = 302, Response: None = None) -> "Response": +def redirect( + location: str, code: int = 302, Response: t.Optional[t.Type["Response"]] = None +) -> "Response": """Returns a response object (a WSGI application) that, if called, redirects the client to the target location. Supported codes are 301, 302, 303, 305, 307, and 308. 300 is not supported because @@ -549,7 +540,7 @@ def redirect(location: str, code: int = 302, Response: None = None) -> "Response return response -def append_slash_redirect(environ: WSGIEnvironment, code: int = 301,) -> "Response": +def append_slash_redirect(environ: "WSGIEnvironment", code: int = 301) -> "Response": """Redirects to the same URL but with a slash appended. The behavior of this function is undefined if the path ends with a slash already. @@ -565,18 +556,20 @@ def append_slash_redirect(environ: WSGIEnvironment, code: int = 301,) -> "Respon def send_file( - path_or_file, - environ, - mimetype=None, - as_attachment=False, - download_name=None, - conditional=True, - add_etags=True, - last_modified=None, - max_age=None, - use_x_sendfile=False, - response_class=None, - _root_path=None, + path_or_file: t.Union[os.PathLike, str, t.BinaryIO], + environ: "WSGIEnvironment", + mimetype: t.Optional[str] = None, + as_attachment: bool = False, + download_name: t.Optional[str] = None, + conditional: bool = True, + add_etags: bool = True, + last_modified: t.Optional[t.Union[datetime, int, float, struct_time]] = None, + max_age: t.Optional[ + t.Union[int, t.Callable[[t.Optional[t.Union[os.PathLike, str]]], int]] + ] = None, + use_x_sendfile: bool = False, + response_class: t.Optional[t.Type["Response"]] = None, + _root_path: t.Optional[t.Union[os.PathLike, str]] = None, ): """Send the contents of a file to the client. @@ -639,11 +632,20 @@ def send_file( default. """ if response_class is None: - from .wrappers import Response as response_class + from .wrappers import Response + + response_class = Response + + path: t.Optional[pathlib.Path] = None + file: t.Optional[t.BinaryIO] = None + size: t.Optional[int] = None + mtime: t.Optional[float] = None - if isinstance(path_or_file, (str, os.PathLike)) or hasattr( + if isinstance(path_or_file, (os.PathLike, str)) or hasattr( # type: ignore path_or_file, "__fspath__" ): + path_or_file = t.cast(t.Union[os.PathLike, str], path_or_file) + # Flask will pass app.root_path, allowing its send_file wrapper # to not have to deal with paths. if _root_path is not None: @@ -654,9 +656,7 @@ def send_file( stat = path.stat() size = stat.st_size mtime = stat.st_mtime - file = None else: - path = size = mtime = None file = path_or_file if download_name is None and path is not None: @@ -676,10 +676,10 @@ def send_file( if download_name is not None: try: - download_name = download_name.encode("ascii") + download_name.encode("ascii") except UnicodeEncodeError: simple = unicodedata.normalize("NFKD", download_name) - simple = simple.encode("ascii", "ignore") + simple = simple.encode("ascii", "ignore").decode("ascii") quoted = url_quote(download_name, safe="") names = {"filename": simple, "filename*": f"UTF-8''{quoted}"} else: @@ -698,7 +698,7 @@ def send_file( data = None else: if file is None: - file = path.open("rb") + file = path.open("rb") # type: ignore elif isinstance(file, io.BytesIO): size = file.getbuffer().nbytes elif isinstance(file, io.TextIOBase): @@ -714,9 +714,9 @@ def send_file( rv.content_length = size if last_modified is not None: - rv.last_modified = last_modified + rv.last_modified = last_modified # type: ignore elif mtime is not None: - rv.last_modified = mtime + rv.last_modified = mtime # type: ignore rv.cache_control.no_cache = True @@ -731,7 +731,7 @@ def send_file( rv.cache_control.public = True rv.cache_control.max_age = max_age - rv.expires = int(time() + max_age) + rv.expires = int(time() + max_age) # type: ignore if add_etags and path is not None: check = adler32(str(path).encode("utf-8")) & 0xFFFFFFFF @@ -754,7 +754,12 @@ def send_file( return rv -def send_from_directory(directory, path, environ, **kwargs): +def send_from_directory( + directory: t.Union[os.PathLike, str], + path: t.Union[os.PathLike, str], + environ: "WSGIEnvironment", + **kwargs, +) -> "Response": """Send a file from within a directory using :func:`send_file`. This is a secure way to serve files from a folder, such as static @@ -794,7 +799,7 @@ def send_from_directory(directory, path, environ, **kwargs): return send_file(path, environ, **kwargs) -def import_string(import_name: str, silent: bool = False) -> Any: +def import_string(import_name: str, silent: bool = False) -> t.Any: """Imports an object based on a string. This is useful if you want to use import paths as endpoints or something similar. An import path can be specified either in dotted notation (``xml.sax.saxutils.escape``) @@ -833,7 +838,7 @@ def import_string(import_name: str, silent: bool = False) -> Any: def find_modules( import_path: str, include_packages: bool = False, recursive: bool = False -) -> Iterator[str]: +) -> t.Iterator[str]: """Finds all the modules below a package. This can be useful to automatically import all views / controllers so that their metaclasses / function decorators have a chance to register themselves on the @@ -987,11 +992,10 @@ def __init__(self, missing=None, extra=None, extra_positional=None): self.missing = set(missing or ()) self.extra = extra or {} self.extra_positional = extra_positional or [] - ValueError.__init__( - self, + super().__init__( "function arguments invalid." f" ({len(self.missing)} missing," - f" {len(self.extra) + len(self.extra_positional)} additional)", + f" {len(self.extra) + len(self.extra_positional)} additional)" ) @@ -999,13 +1003,11 @@ class ImportStringError(ImportError): """Provides information about a failed :func:`import_string` attempt.""" #: String in dotted notation that failed to be imported. - import_name = None + import_name: str #: Wrapped exception. - exception = None + exception: BaseException - def __init__( - self, import_name: str, exception: Union[ImportError, ModuleNotFoundError], - ) -> None: + def __init__(self, import_name, exception): self.import_name = import_name self.exception = exception msg = import_name @@ -1033,7 +1035,7 @@ def __init__( ) break - ImportError.__init__(self, msg) + super().__init__(msg) def __repr__(self): return f"<{type(self).__name__}({self.import_name!r}, {self.exception!r})>" diff --git a/src/werkzeug/wrappers/accept.py b/src/werkzeug/wrappers/accept.py index c80ee3259..d1c3c5632 100644 --- a/src/werkzeug/wrappers/accept.py +++ b/src/werkzeug/wrappers/accept.py @@ -1,4 +1,6 @@ +from ..datastructures import Accept from ..datastructures import CharsetAccept +from ..datastructures import EnvironHeaders from ..datastructures import LanguageAccept from ..datastructures import MIMEAccept from ..http import parse_accept_header @@ -12,22 +14,24 @@ class AcceptMixin: thereof). """ + headers: EnvironHeaders + @cached_property - def accept_mimetypes(self): + def accept_mimetypes(self) -> MIMEAccept: """List of mimetypes this client supports as :class:`~werkzeug.datastructures.MIMEAccept` object. """ return parse_accept_header(self.headers.get("Accept"), MIMEAccept) @cached_property - def accept_charsets(self): + def accept_charsets(self) -> CharsetAccept: """List of charsets this client supports as :class:`~werkzeug.datastructures.CharsetAccept` object. """ return parse_accept_header(self.headers.get("Accept-Charset"), CharsetAccept) @cached_property - def accept_encodings(self): + def accept_encodings(self) -> Accept: """List of encodings this client accepts. Encodings in a HTTP term are compression encodings such as gzip. For charsets have a look at :attr:`accept_charset`. @@ -35,7 +39,7 @@ def accept_encodings(self): return parse_accept_header(self.headers.get("Accept-Encoding")) @cached_property - def accept_languages(self): + def accept_languages(self) -> LanguageAccept: """List of languages this client accepts as :class:`~werkzeug.datastructures.LanguageAccept` object. diff --git a/src/werkzeug/wrappers/auth.py b/src/werkzeug/wrappers/auth.py index 6b9612a8f..6db7a7416 100644 --- a/src/werkzeug/wrappers/auth.py +++ b/src/werkzeug/wrappers/auth.py @@ -1,12 +1,12 @@ -from typing import TYPE_CHECKING +import typing as t +from ..datastructures import Authorization +from ..datastructures import EnvironHeaders +from ..datastructures import Headers +from ..datastructures import WWWAuthenticate from ..http import parse_authorization_header from ..http import parse_www_authenticate_header from ..utils import cached_property -from werkzeug.types import WSGIEnvironment - -if TYPE_CHECKING: - from werkzeug.datastructures import WWWAuthenticate, Authorization class AuthorizationMixin: @@ -15,29 +15,28 @@ class AuthorizationMixin: :class:`~werkzeug.datastructures.Authorization` object. """ - environ: WSGIEnvironment + headers: EnvironHeaders @cached_property - def authorization(self) -> "Authorization": + def authorization(self) -> t.Optional[Authorization]: """The `Authorization` object in parsed form.""" - header = self.headers.get("Authorization") # type: ignore - return parse_authorization_header(header) + return parse_authorization_header(self.headers.get("Authorization")) class WWWAuthenticateMixin: """Adds a :attr:`www_authenticate` property to a response object.""" - environ: WSGIEnvironment + headers: Headers @property - def www_authenticate(self) -> "WWWAuthenticate": + def www_authenticate(self) -> WWWAuthenticate: """The `WWW-Authenticate` header in a parsed form.""" - def on_update(www_auth): + def on_update(www_auth: WWWAuthenticate) -> None: if not www_auth and "www-authenticate" in self.headers: del self.headers["www-authenticate"] elif www_auth: self.headers["WWW-Authenticate"] = www_auth.to_header() - header = self.headers.get("www-authenticate") # type: ignore + header = self.headers.get("www-authenticate") return parse_www_authenticate_header(header, on_update) diff --git a/src/werkzeug/wrappers/base_request.py b/src/werkzeug/wrappers/base_request.py index aed1ca8c7..6cb3334aa 100644 --- a/src/werkzeug/wrappers/base_request.py +++ b/src/werkzeug/wrappers/base_request.py @@ -1,15 +1,12 @@ +import typing as t from functools import update_wrapper from io import BytesIO -from typing import BinaryIO -from typing import Callable -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from .._internal import _to_str from .._internal import _wsgi_decoding_dance from ..datastructures import CombinedMultiDict from ..datastructures import EnvironHeaders +from ..datastructures import FileStorage from ..datastructures import ImmutableList from ..datastructures import ImmutableMultiDict from ..datastructures import iter_multi_items @@ -26,11 +23,10 @@ from ..wsgi import get_current_url from ..wsgi import get_host from ..wsgi import get_input_stream -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from werkzeug.wrappers.request import PlainRequest, Request # noqa: F401 - from werkzeug.wsgi import LimitedStream # noqa: F401 +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment class BaseRequest: @@ -79,10 +75,10 @@ class Request(BaseRequest, ETagRequestMixin): """ #: the charset for the request, defaults to utf-8 - charset = "utf-8" + charset: str = "utf-8" #: the error handling procedure for errors, defaults to 'replace' - encoding_errors = "replace" + encoding_errors: str = "replace" #: the maximum content length. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -93,7 +89,7 @@ class Request(BaseRequest, ETagRequestMixin): #: Have a look at :doc:`/request_data` for more details. #: #: .. versionadded:: 0.5 - max_content_length = None + max_content_length: t.Optional[int] = None #: the maximum form field size. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -104,7 +100,7 @@ class Request(BaseRequest, ETagRequestMixin): #: Have a look at :doc:`/request_data` for more details. #: #: .. versionadded:: 0.5 - max_form_memory_size = None + max_form_memory_size: t.Optional[int] = None #: the class to use for `args` and `form`. The default is an #: :class:`~werkzeug.datastructures.ImmutableMultiDict` which supports @@ -115,14 +111,14 @@ class Request(BaseRequest, ETagRequestMixin): #: possible to use mutable structures, but this is not recommended. #: #: .. versionadded:: 0.6 - parameter_storage_class = ImmutableMultiDict + parameter_storage_class: t.Type[MultiDict] = ImmutableMultiDict #: the type to be used for list values from the incoming WSGI environment. #: By default an :class:`~werkzeug.datastructures.ImmutableList` is used #: (for example for :attr:`access_list`). #: #: .. versionadded:: 0.6 - list_storage_class = ImmutableList + list_storage_class: t.Type[t.List] = ImmutableList #: The type to be used for dict values from the incoming WSGI #: environment. (For example for :attr:`cookies`.) By default an @@ -132,11 +128,11 @@ class Request(BaseRequest, ETagRequestMixin): #: Changed to ``ImmutableMultiDict`` to support multiple values. #: #: .. versionadded:: 0.6 - dict_storage_class = ImmutableMultiDict + dict_storage_class: t.Type[MultiDict] = ImmutableMultiDict #: The form data parser that shoud be used. Can be replaced to customize #: the form date parsing. - form_data_parser_class = FormDataParser + form_data_parser_class: t.Type[FormDataParser] = FormDataParser #: Optionally a list of hosts that is trusted by this request. By default #: all hosts are trusted which means that whatever the client sends the @@ -148,17 +144,17 @@ class Request(BaseRequest, ETagRequestMixin): #: behind one). #: #: .. versionadded:: 0.9 - trusted_hosts = None + trusted_hosts: t.Optional[t.List[str]] = None #: Indicates whether the data descriptor should be allowed to read and #: buffer up the input stream. By default it's enabled. #: #: .. versionadded:: 0.9 - disable_data_descriptor = False + disable_data_descriptor: bool = False def __init__( self, - environ: WSGIEnvironment, + environ: "WSGIEnvironment", populate_request: bool = True, shallow: bool = False, ) -> None: @@ -190,7 +186,7 @@ def url_charset(self) -> str: return self.charset @classmethod - def from_values(cls, *args, **kwargs): + def from_values(cls, *args, **kwargs) -> "BaseRequest": """Create a new request object based on the values provided. If environ is given missing values are filled from there. This method is useful for small scripts when you need to simulate a request from an URL. @@ -214,12 +210,14 @@ def from_values(cls, *args, **kwargs): kwargs["charset"] = charset builder = EnvironBuilder(*args, **kwargs) try: - return builder.get_request(cls) + return builder.get_request(cls) # type: ignore finally: builder.close() @classmethod - def application(cls, f: Callable) -> Callable: + def application( + cls, f: t.Callable[["BaseRequest"], "WSGIApplication"] + ) -> "WSGIApplication": """Decorate a function as responder that accepts the request as the last argument. This works like the :func:`responder` decorator but the function is passed the request object as the @@ -258,10 +256,10 @@ def application(*args): def _get_file_stream( self, total_content_length: int, - content_type: Optional[str], - filename: Optional[str] = None, - content_length: Optional[int] = None, - ) -> BinaryIO: + content_type: t.Optional[str], + filename: t.Optional[str] = None, + content_length: t.Optional[int] = None, + ): """Called to get a stream for the file upload. This must provide a file-like class with `read()`, `readline()` @@ -333,11 +331,11 @@ def _load_form_data(self) -> None: mimetype, options = parse_options_header(content_type) parser = self.make_form_data_parser() data = parser.parse( - self._get_stream_for_parsing(), mimetype, content_length, options, + self._get_stream_for_parsing(), mimetype, content_length, options ) else: data = ( - self.stream, # type: ignore + self.stream, self.parameter_storage_class(), self.parameter_storage_class(), ) @@ -347,7 +345,7 @@ def _load_form_data(self) -> None: d = self.__dict__ d["stream"], d["form"], d["files"] = data - def _get_stream_for_parsing(self) -> Union[BytesIO, "LimitedStream"]: + def _get_stream_for_parsing(self) -> t.BinaryIO: """This is the same as accessing :attr:`stream` with the difference that if it finds cached data from calling :meth:`get_data` first it will create a new stream out of the cached data. @@ -370,14 +368,14 @@ def close(self) -> None: for _key, value in iter_multi_items(files or ()): value.close() - def __enter__(self) -> "Request": - return self # type: ignore + def __enter__(self) -> "BaseRequest": + return self - def __exit__(self, exc_type: None, exc_value: None, tb: None) -> None: + def __exit__(self, exc_type, exc_value, tb) -> None: self.close() @cached_property - def stream(self): + def stream(self) -> t.BinaryIO: """ If the incoming form data was not encoded with a known mimetype the data is stored unmodified in this stream for consumption. Most @@ -407,7 +405,7 @@ def stream(self): ) @cached_property - def args(self): + def args(self) -> "MultiDict[str, str]": """The parsed URL parameters (the part in the URL after the question mark). @@ -425,7 +423,7 @@ def args(self): ) @cached_property - def data(self): + def data(self) -> bytes: """ Contains the incoming request data as string in case it came with a mimetype Werkzeug does not handle. @@ -443,8 +441,8 @@ def data(self): return self.get_data(parse_form_data=True) def get_data( - self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False, - ) -> Union[str, bytes]: + self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False + ) -> bytes: """This reads the buffered incoming data from the client into one bytes object. By default this is cached but that behavior can be changed by setting `cache` to `False`. @@ -481,7 +479,7 @@ def get_data( return rv @cached_property - def form(self): + def form(self) -> "ImmutableMultiDict[str, str]": """The form parameters. By default an :class:`~werkzeug.datastructures.ImmutableMultiDict` is returned from this function. This can be changed by setting @@ -500,7 +498,7 @@ def form(self): return self.form @cached_property - def values(self): + def values(self) -> "CombinedMultiDict[str, str]": """A :class:`werkzeug.datastructures.CombinedMultiDict` that combines :attr:`args` and :attr:`form`.""" args = [] @@ -511,7 +509,7 @@ def values(self): return CombinedMultiDict(args) @cached_property - def files(self): + def files(self) -> "ImmutableMultiDict[str, FileStorage]": """:class:`~werkzeug.datastructures.MultiDict` object containing all uploaded files. Each key in :attr:`files` is the name from the ````. Each value in :attr:`files` is a @@ -534,10 +532,10 @@ def files(self): return self.files @cached_property - def cookies(self): + def cookies(self) -> "ImmutableMultiDict[str, str]": """A :class:`dict` with the contents of all cookies transmitted with the request.""" - return parse_cookie( + return parse_cookie( # type: ignore self.environ, self.charset, self.encoding_errors, @@ -545,7 +543,7 @@ def cookies(self): ) @cached_property - def headers(self): + def headers(self) -> EnvironHeaders: """The headers from the WSGI environ as immutable :class:`~werkzeug.datastructures.EnvironHeaders`. """ @@ -558,7 +556,7 @@ def path(self) -> str: even if the URL root is accessed. """ raw_path = _wsgi_decoding_dance( - self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors, + self.environ.get("PATH_INFO") or "", self.charset, self.encoding_errors ) return "/" + raw_path.lstrip("/") @@ -571,12 +569,12 @@ def full_path(self) -> str: def script_root(self) -> str: """The root path of the script without the trailing slash.""" raw_path = _wsgi_decoding_dance( - self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors, + self.environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors ) return raw_path.rstrip("/") @cached_property - def url(self): + def url(self) -> str: """The reconstructed current URL as IRI. See also: :attr:`trusted_hosts`. """ @@ -588,7 +586,7 @@ def base_url(self) -> str: See also: :attr:`trusted_hosts`. """ return get_current_url( - self.environ, strip_querystring=True, trusted_hosts=self.trusted_hosts, + self.environ, strip_querystring=True, trusted_hosts=self.trusted_hosts ) @cached_property @@ -615,23 +613,21 @@ def host(self) -> str: """ return get_host(self.environ, trusted_hosts=self.trusted_hosts) - query_string = environ_property( + query_string = environ_property[bytes]( "QUERY_STRING", - "", - read_only=True, + b"", load_func=lambda x: x.encode("latin1"), doc="The URL parameters as raw bytes.", ) method = environ_property( "REQUEST_METHOD", "GET", - read_only=True, load_func=lambda x: x.upper(), doc="The request method. (For example ``'GET'`` or ``'POST'``).", ) @cached_property - def access_route(self): + def access_route(self) -> t.List[str]: """If a forwarded header exists this is a list of all ip addresses from the client ip to the last proxy server. """ @@ -644,17 +640,17 @@ def access_route(self): return self.list_storage_class() @property - def remote_addr(self) -> str: + def remote_addr(self) -> t.Optional[str]: """The remote address of the client.""" return self.environ.get("REMOTE_ADDR") - remote_user = environ_property( + remote_user = environ_property[str]( "REMOTE_USER", doc="""If the server supports user authentication, and the script is protected, this attribute contains the username the user has authenticated as.""", ) - scheme = environ_property( + scheme = environ_property[str]( "wsgi.url_scheme", doc=""" URL scheme (http or https). @@ -665,17 +661,17 @@ def remote_addr(self) -> str: lambda self: self.environ["wsgi.url_scheme"] == "https", doc="`True` if the request is secure.", ) - is_multithread = environ_property( + is_multithread = environ_property[bool]( "wsgi.multithread", doc="""boolean that is `True` if the application is served by a multithreaded WSGI server.""", ) - is_multiprocess = environ_property( + is_multiprocess = environ_property[bool]( "wsgi.multiprocess", doc="""boolean that is `True` if the application is served by a WSGI server that spawns multiple processes.""", ) - is_run_once = environ_property( + is_run_once = environ_property[bool]( "wsgi.run_once", doc="""boolean that is `True` if the application will be executed only once in a process lifetime. This is the case for @@ -684,9 +680,7 @@ def remote_addr(self) -> str: ) -def _assert_not_shallow( - request: Union["Request", "BaseRequest", "PlainRequest"] -) -> None: +def _assert_not_shallow(request: BaseRequest) -> None: if request.shallow: raise RuntimeError( "A shallow request tried to consume form data. If you really" diff --git a/src/werkzeug/wrappers/base_response.py b/src/werkzeug/wrappers/base_response.py index 81a7df52d..4e13dbba3 100644 --- a/src/werkzeug/wrappers/base_response.py +++ b/src/werkzeug/wrappers/base_response.py @@ -1,14 +1,8 @@ +import typing +import typing as t import warnings -from typing import Any -from typing import Callable -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import Union +from datetime import datetime +from datetime import timedelta from .._internal import _to_bytes from .._internal import _to_str @@ -21,16 +15,14 @@ from ..utils import get_content_type from ..wsgi import ClosingIterator from ..wsgi import get_current_url -from werkzeug.types import T -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from werkzeug.middleware.proxy_fix import ProxyFix # noqa: F401 - from werkzeug.wrappers.request import Request # noqa: F401 - from werkzeug.wrappers.response import Response # noqa: F401 +if t.TYPE_CHECKING: + from wsgiref.types import StartResponse + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment -def _warn_if_string(iterable: Any) -> None: +def _warn_if_string(iterable: t.Iterable) -> None: """Helper for the response objects to check if the iterable returned to the WSGI server is not a string. """ @@ -45,7 +37,9 @@ def _warn_if_string(iterable: Any) -> None: ) -def _iter_encoded(iterable: Any, charset: str) -> Iterator[bytes]: +def _iter_encoded( + iterable: t.Iterable[t.Union[str, bytes]], charset: str +) -> t.Iterator[bytes]: for item in iterable: if isinstance(item, str): yield item.encode(charset) @@ -53,7 +47,7 @@ def _iter_encoded(iterable: Any, charset: str) -> Iterator[bytes]: yield item -def _clean_accept_ranges(accept_ranges: bool) -> str: +def _clean_accept_ranges(accept_ranges: t.Union[bool, str]) -> str: if accept_ranges is True: return "bytes" elif accept_ranges is False: @@ -134,13 +128,13 @@ def application(environ, start_response): """ #: the charset of the response. - charset = "utf-8" + charset: str = "utf-8" #: the default status if none is provided. - default_status = 200 + default_status: int = 200 #: the default mimetype if none is provided. - default_mimetype = "text/plain" + default_mimetype: str = "text/plain" #: if set to `False` accessing properties on the response object will #: not try to consume the response iterator and convert it into a list. @@ -150,19 +144,19 @@ def application(environ, start_response): #: That attribute was previously called `implicit_seqence_conversion`. #: (Notice the typo). If you did use this feature, you have to adapt #: your code to the name change. - implicit_sequence_conversion = True + implicit_sequence_conversion: bool = True #: Should this response object correct the location header to be RFC #: conformant? This is true by default. #: #: .. versionadded:: 0.8 - autocorrect_location_header = True + autocorrect_location_header: bool = True #: Should this response object automatically set the content-length #: header if possible? This is true by default. #: #: .. versionadded:: 0.8 - automatically_set_content_length = True + automatically_set_content_length: bool = True #: Warn if a cookie header exceeds this size. The default, 4093, should be #: safely `supported by most browsers `_. A cookie larger than @@ -172,15 +166,22 @@ def application(environ, start_response): #: .. versionadded:: 0.13 #: #: .. _`cookie`: http://browsercookielimits.squawky.net/ - max_cookie_size = 4093 + max_cookie_size: int = 4093 def __init__( self, - response: Optional[Any] = None, - status: Optional[Union[Tuple, int, str]] = None, - headers: Optional[Union[Headers, List[Tuple[str, str]]]] = None, - mimetype: Optional[str] = None, - content_type: Optional[str] = None, + response: t.Optional[ + t.Union[t.Iterable[bytes], bytes, t.Iterable[str], str] + ] = None, + status: t.Optional[t.Union[int, str]] = None, + headers: t.Optional[ + t.Union[ + t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]], + t.Iterable[t.Tuple[str, t.Union[str, int]]], + ] + ] = None, + mimetype: t.Optional[str] = None, + content_type: t.Optional[str] = None, direct_passthrough: bool = False, ) -> None: if isinstance(headers, Headers): @@ -200,21 +201,21 @@ def __init__( self.headers["Content-Type"] = content_type if status is None: status = self.default_status - self.status = status + self.status = status # type: ignore self.direct_passthrough = direct_passthrough - self._on_close: List[Callable] = [] + self._on_close: t.List[t.Callable[[], t.Any]] = [] # we set the response after the headers so that if a class changes # the charset attribute, the data is set in the correct charset. if response is None: - self.response = [] + self.response: t.Union[t.Iterable[bytes], t.Iterable[str]] = [] elif isinstance(response, (str, bytes, bytearray)): self.set_data(response) else: self.response = response - def call_on_close(self, func: Callable) -> Callable: + def call_on_close(self, func: t.Callable[[], t.Any]) -> t.Callable[[], t.Any]: """Adds a function to the internal list of functions that should be called as part of closing down the response. Since 0.7 this function also returns the function that was passed so that this @@ -233,7 +234,9 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {body_info} [{self.status}]>" @classmethod - def force_type(cls: Type[T], response: Any, environ: WSGIEnvironment = None) -> T: + def force_type( + cls, response: "BaseResponse", environ: t.Optional["WSGIEnvironment"] = None + ) -> "BaseResponse": """Enforce that the WSGI response is a response object of the current type. Werkzeug will use the :class:`BaseResponse` internally in many situations like the exceptions. If you call :meth:`get_response` on an @@ -277,11 +280,8 @@ def force_type(cls: Type[T], response: Any, environ: WSGIEnvironment = None) -> @classmethod def from_app( - cls, - app: Union["Response", Callable, "ProxyFix"], - environ: WSGIEnvironment, - buffered: bool = False, - ) -> "Response": + cls, app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False + ) -> "BaseResponse": """Create a new response object from an application output. This works best if you pass it an application that returns a generator all the time. Sometimes applications may use the `write()` callable @@ -296,30 +296,30 @@ def from_app( """ from ..test import run_wsgi_app - return cls(*run_wsgi_app(app, environ, buffered)) # type: ignore + return cls(*run_wsgi_app(app, environ, buffered)) @property - def status_code(self): + def status_code(self) -> int: """The HTTP status code as a number.""" return self._status_code @status_code.setter - def status_code(self, code): - self.status = code + def status_code(self, code: int) -> None: + self.status = code # type: ignore @property - def status(self): + def status(self) -> str: """The HTTP status code as a string.""" return self._status @status.setter - def status(self, value): + def status(self, value: t.Union[str, int]) -> None: if not isinstance(value, (str, bytes, int)): raise TypeError("Invalid status argument") self._status, self._status_code = self._clean_status(value) - def _clean_status(self, value: Union[str, int]) -> Tuple[str, int]: + def _clean_status(self, value: t.Union[str, int]) -> t.Tuple[str, int]: status = _to_str(value, self.charset) split_status = status.split(None, 1) @@ -348,7 +348,15 @@ def _clean_status(self, value: Union[str, int]) -> Tuple[str, int]: # one-word message return f"0 {status}", 0 - def get_data(self, as_text: bool = False) -> Union[str, bytes]: + @typing.overload + def get_data(self, as_text: "t.Literal[False]" = False) -> str: + ... + + @typing.overload + def get_data(self, as_text: "t.Literal[True]") -> bytes: + ... + + def get_data(self, as_text=False): """The string representation of the response body. Whenever you call this property the response iterable is encoded and flattened. This can lead to unwanted behavior if you stream big data. @@ -364,10 +372,10 @@ def get_data(self, as_text: bool = False) -> Union[str, bytes]: self._ensure_sequence() rv = b"".join(self.iter_encoded()) if as_text: - rv = rv.decode(self.charset) # type: ignore + rv = rv.decode(self.charset) return rv - def set_data(self, value: Union[str, bytes]) -> None: + def set_data(self, value: t.Union[bytes, str]) -> None: """Sets a new string as response. The value must be a string or bytes. If a string is set it's encoded to the charset of the response (utf-8 by default). @@ -390,7 +398,7 @@ def set_data(self, value: Union[str, bytes]) -> None: doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.", ) - def calculate_content_length(self) -> int: + def calculate_content_length(self) -> t.Optional[int]: """Returns the content length if available or `None` otherwise.""" try: self._ensure_sequence() @@ -408,7 +416,7 @@ def _ensure_sequence(self, mutable: bool = False) -> None: if self.is_sequence: # if we need a mutable object, we ensure it's a list. if mutable and not isinstance(self.response, list): - self.response = list(self.response) + self.response = list(self.response) # type: ignore return if self.direct_passthrough: raise RuntimeError( @@ -440,7 +448,7 @@ def make_sequence(self) -> None: if close is not None: self.call_on_close(close) - def iter_encoded(self) -> Iterator[Any]: + def iter_encoded(self) -> t.Iterator[bytes]: """Iter the response encoded with the encoding of the response. If the response object is invoked as WSGI application the return value of this method is used as application iterator unless @@ -456,14 +464,14 @@ def iter_encoded(self) -> Iterator[Any]: def set_cookie( self, key: str, - value: Union[str, bytes] = "", - max_age: Optional[int] = None, - expires: Optional[int] = None, - path: str = "/", - domain: Optional[str] = None, + value: str = "", + max_age: t.Optional[t.Union[timedelta, int]] = None, + expires: t.Optional[t.Union[str, datetime, int, float]] = None, + path: t.Optional[str] = "/", + domain: t.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: Optional[str] = None, + samesite: t.Optional[str] = None, ) -> None: """Sets a cookie. @@ -510,11 +518,11 @@ def delete_cookie( self, key: str, path: str = "/", - domain: Optional[str] = None, + domain: t.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: Optional[str] = None, - ) -> None: + samesite: t.Optional[str] = None, + ): """Delete a cookie. Fails silently if key doesn't exist. :param key: the key (name) of the cookie to be deleted. @@ -550,7 +558,7 @@ def is_streamed(self) -> bool: filtering that should not take place for streamed responses. """ try: - len(self.response) + len(self.response) # type: ignore except (TypeError, AttributeError): return True return False @@ -580,7 +588,7 @@ def close(self) -> None: def __enter__(self) -> "BaseResponse": return self - def __exit__(self, exc_type: None, exc_value: None, tb: None) -> None: + def __exit__(self, exc_type, exc_value, tb): self.close() def freeze(self) -> None: @@ -596,7 +604,7 @@ def freeze(self) -> None: self.response = list(self.iter_encoded()) self.headers["Content-Length"] = str(sum(map(len, self.response))) - def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: + def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: """This is automatically called right before the response is started and returns headers modified for the given environment. It returns a copy of the headers from the response with some modifications applied @@ -621,9 +629,9 @@ def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: object. """ headers = Headers(self.headers) - location = None - content_location = None - content_length = None + location: t.Optional[str] = None + content_location: t.Optional[str] = None + content_length: t.Optional[t.Union[str, int]] = None status = self.status_code # iterate over the headers to find all values in one go. Because @@ -652,7 +660,7 @@ def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: current_url = iri_to_uri(current_url) location = url_join(current_url, location) if location != old_location: - headers["Location"] = location + headers["Location"] = location # type: ignore # make sure the content location is a URL if content_location is not None and isinstance(content_location, str): @@ -689,7 +697,7 @@ def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: return headers - def get_app_iter(self, environ: WSGIEnvironment) -> ClosingIterator: + def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]: """Returns the application iterator for the given environ. Depending on the request method and the current status code the return value might be an empty response rather than the one from the response. @@ -709,7 +717,7 @@ def get_app_iter(self, environ: WSGIEnvironment) -> ClosingIterator: or 100 <= status < 200 or status in (204, 304) ): - iterable: Iterable[Any] = () + iterable: t.Iterable[bytes] = () elif self.direct_passthrough: if __debug__: _warn_if_string(self.response) @@ -719,8 +727,8 @@ def get_app_iter(self, environ: WSGIEnvironment) -> ClosingIterator: return ClosingIterator(iterable, self.close) def get_wsgi_response( - self, environ: WSGIEnvironment - ) -> Tuple[ClosingIterator, str, List[Tuple[str, str]]]: + self, environ: "WSGIEnvironment" + ) -> t.Tuple[t.Iterable[bytes], str, t.List[t.Tuple[str, str]]]: """Returns the final WSGI response as tuple. The first item in the tuple is the application iterator, the second the status and the third the list of headers. The response returned is created @@ -738,8 +746,8 @@ def get_wsgi_response( return app_iter, self.status, headers.to_wsgi_list() def __call__( - self, environ: WSGIEnvironment, start_response: Callable - ) -> ClosingIterator: + self, environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: """Process this response as WSGI application. :param environ: the WSGI environment. diff --git a/src/werkzeug/wrappers/common_descriptors.py b/src/werkzeug/wrappers/common_descriptors.py index 678f3261c..c8b573709 100644 --- a/src/werkzeug/wrappers/common_descriptors.py +++ b/src/werkzeug/wrappers/common_descriptors.py @@ -1,9 +1,11 @@ +import typing as t from datetime import datetime from datetime import timedelta -from typing import Dict -from typing import Optional from ..datastructures import CallbackDict +from ..datastructures import EnvironHeaders +from ..datastructures import Headers +from ..datastructures import HeaderSet from ..http import dump_age from ..http import dump_csp_header from ..http import dump_header @@ -18,7 +20,9 @@ from ..utils import get_content_type from ..utils import header_property from ..wsgi import get_content_length -from werkzeug.types import WSGIEnvironment + +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment class CommonRequestDescriptorsMixin: @@ -29,9 +33,9 @@ class CommonRequestDescriptorsMixin: .. versionadded:: 0.5 """ - environ: WSGIEnvironment + headers: Headers - content_type = header_property( + content_type = header_property[str]( "Content-Type", doc="""The Content-Type entity-header field indicates the media type of the entity-body sent to the recipient or, in the case of @@ -41,7 +45,7 @@ class CommonRequestDescriptorsMixin: ) @cached_property - def content_length(self): + def content_length(self) -> t.Optional[int]: """The Content-Length entity-header field indicates the size of the entity-body in bytes or, in the case of the HEAD method, the size of the entity-body that would have been sent had the request been a @@ -49,7 +53,7 @@ def content_length(self): """ return get_content_length(self.headers) - content_encoding = header_property( + content_encoding = header_property[str]( "Content-Encoding", doc="""The Content-Encoding entity-header field is used as a modifier to the media-type. When present, its value indicates @@ -61,7 +65,7 @@ def content_length(self): .. versionadded:: 0.9""", read_only=True, ) - content_md5 = header_property( + content_md5 = header_property[str]( "Content-MD5", doc="""The Content-MD5 entity-header field, as defined in RFC 1864, is an MD5 digest of the entity-body for the purpose of @@ -73,7 +77,7 @@ def content_length(self): .. versionadded:: 0.9""", read_only=True, ) - referrer = header_property( + referrer = header_property[str]( "Referer", doc="""The Referer[sic] request-header field allows the client to specify, for the server's benefit, the address (URI) of the @@ -104,7 +108,7 @@ def content_length(self): def _parse_content_type(self) -> None: if not hasattr(self, "_parsed_content_type"): self._parsed_content_type = parse_options_header( - self.headers.get("Content-Type", "") # type: ignore + self.headers.get("Content-Type", "") ) @property @@ -118,7 +122,7 @@ def mimetype(self) -> str: return self._parsed_content_type[0].lower() @property - def mimetype_params(self) -> Dict[str, str]: + def mimetype_params(self) -> t.Dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -127,7 +131,7 @@ def mimetype_params(self) -> Dict[str, str]: return self._parsed_content_type[1] @cached_property - def pragma(self): + def pragma(self) -> HeaderSet: """The Pragma general-header field is used to include implementation-specific directives that might apply to any recipient along the request/response chain. All pragma directives specify @@ -137,27 +141,53 @@ def pragma(self): return parse_set_header(self.headers.get("Pragma", "")) +def _set_property(name: str, doc: t.Optional[str] = None) -> property: + def fget(self): + def on_update(header_set): + if not header_set and name in self.headers: + del self.headers[name] + elif header_set: + self.headers[name] = header_set.to_header() + + return parse_set_header(self.headers.get(name), on_update) + + def fset(self, value): + if not value: + del self.headers[name] + elif isinstance(value, str): + self.headers[name] = value + else: + self.headers[name] = dump_header(value) + + return property(fget, fset, doc=doc) + + class CommonResponseDescriptorsMixin: """A mixin for :class:`BaseResponse` subclasses. Response objects that mix this class in will automatically get descriptors for a couple of HTTP headers with automatic type conversion. """ - environ = WSGIEnvironment + charset: str + environ: "WSGIEnvironment" + headers: EnvironHeaders @property - def mimetype(self): + def mimetype(self) -> t.Optional[str]: """The mimetype (content type without charset etc.)""" ct = self.headers.get("content-type") + if ct: return ct.split(";")[0].strip() + else: + return None @mimetype.setter - def mimetype(self, value): + def mimetype(self, value: str) -> None: self.headers["Content-Type"] = get_content_type(value, self.charset) @property - def mimetype_params(self) -> CallbackDict: + def mimetype_params(self) -> t.Dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -168,12 +198,10 @@ def mimetype_params(self) -> CallbackDict: def on_update(d): self.headers["Content-Type"] = dump_options_header(self.mimetype, d) - d = parse_options_header(self.headers.get("content-type", ""))[ # type: ignore - 1 - ] + d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) - location = header_property( + location = header_property[str]( "Location", doc="""The Location response-header field is used to redirect the recipient to a location other than the Request-URI for @@ -184,7 +212,7 @@ def on_update(d): "Age", None, parse_age, - dump_age, + dump_age, # type: ignore doc="""The Age response-header field conveys the sender's estimate of the amount of time since the response (or its revalidation) was generated at the origin server. @@ -192,7 +220,7 @@ def on_update(d): Age values are non-negative decimal integers, representing time in seconds.""", ) - content_type = header_property( + content_type = header_property[str]( "Content-Type", doc="""The Content-Type entity-header field indicates the media type of the entity-body sent to the recipient or, in the case of @@ -210,14 +238,14 @@ def on_update(d): entity-body that would have been sent had the request been a GET.""", ) - content_location = header_property( + content_location = header_property[str]( "Content-Location", doc="""The Content-Location entity-header field MAY be used to supply the resource location for the entity enclosed in the message when that entity is accessible from a location separate from the requested resource's URI.""", ) - content_encoding = header_property( + content_encoding = header_property[str]( "Content-Encoding", doc="""The Content-Encoding entity-header field is used as a modifier to the media-type. When present, its value indicates @@ -226,7 +254,7 @@ def on_update(d): in order to obtain the media-type referenced by the Content-Type header field.""", ) - content_md5 = header_property( + content_md5 = header_property[str]( "Content-MD5", doc="""The Content-MD5 entity-header field, as defined in RFC 1864, is an MD5 digest of the entity-body for the purpose of @@ -238,7 +266,7 @@ def on_update(d): content_security_policy = header_property( "Content-Security-Policy", None, - parse_csp_header, + parse_csp_header, # type: ignore dump_csp_header, doc="""The Content-Security-Policy header adds an additional layer of security to help detect and mitigate certain types of attacks.""", @@ -246,7 +274,7 @@ def on_update(d): content_security_policy_report_only = header_property( "Content-Security-Policy-Report-Only", None, - parse_csp_header, + parse_csp_header, # type: ignore dump_csp_header, doc="""The Content-Security-Policy-Report-Only header adds a csp policy that is not enforced but is reported thereby helping detect @@ -281,7 +309,7 @@ def on_update(d): ) @property - def retry_after(self): + def retry_after(self) -> t.Optional[datetime]: """The Retry-After response-header field can be used with a 503 (Service Unavailable) response to indicate how long the service is expected to be unavailable to the requesting client. @@ -290,13 +318,13 @@ def retry_after(self): """ value = self.headers.get("retry-after") if value is None: - return + return None elif value.isdigit(): return datetime.utcnow() + timedelta(seconds=int(value)) return parse_date(value) @retry_after.setter - def retry_after(self, value): + def retry_after(self, value: t.Optional[t.Union[datetime, int, str]]) -> None: if value is None: if "retry-after" in self.headers: del self.headers["retry-after"] @@ -307,28 +335,6 @@ def retry_after(self, value): value = str(value) self.headers["Retry-After"] = value - def _set_property( # type: ignore - name: str, doc: Optional[str] = None # noqa: B902 - ): - def fget(self): - def on_update(header_set): - if not header_set and name in self.headers: - del self.headers[name] - elif header_set: - self.headers[name] = header_set.to_header() - - return parse_set_header(self.headers.get(name), on_update) - - def fset(self, value): - if not value: - del self.headers[name] - elif isinstance(value, str): - self.headers[name] = value - else: - self.headers[name] = dump_header(value) - - return property(fget, fset, doc=doc) - vary = _set_property( "Vary", doc="""The Vary field value indicates the set of request-header @@ -352,5 +358,3 @@ def fset(self, value): field MUST be present in a 405 (Method Not Allowed) response.""", ) - - del _set_property diff --git a/src/werkzeug/wrappers/cors.py b/src/werkzeug/wrappers/cors.py index 210c4a629..a8f68cdf6 100644 --- a/src/werkzeug/wrappers/cors.py +++ b/src/werkzeug/wrappers/cors.py @@ -1,3 +1,6 @@ +import typing as t + +from ..datastructures import Headers from ..http import dump_header from ..http import parse_set_header from ..utils import header_property @@ -11,7 +14,7 @@ class CORSRequestMixin: .. versionadded:: 1.0 """ - origin = header_property( + origin = header_property[str]( "Origin", doc=( "The host that the request originated from. Set" @@ -33,7 +36,7 @@ class CORSRequestMixin: read_only=True, ) - access_control_request_method = header_property( + access_control_request_method = header_property[str]( "Access-Control-Request-Method", doc=( "Sent with a preflight request to indicate which method" @@ -53,8 +56,10 @@ class CORSResponseMixin: .. versionadded:: 1.0 """ + headers: Headers + @property - def access_control_allow_credentials(self): + def access_control_allow_credentials(self) -> bool: """Whether credentials can be shared by the browser to JavaScript code. As part of the preflight request it indicates whether credentials can be used on the cross origin request. @@ -62,7 +67,7 @@ def access_control_allow_credentials(self): return "Access-Control-Allow-Credentials" in self.headers @access_control_allow_credentials.setter - def access_control_allow_credentials(self, value): + def access_control_allow_credentials(self, value: t.Optional[bool]) -> None: if value is True: self.headers["Access-Control-Allow-Credentials"] = "true" else: @@ -82,7 +87,7 @@ def access_control_allow_credentials(self, value): doc="Which methods can be used for the cross origin request.", ) - access_control_allow_origin = header_property( + access_control_allow_origin = header_property[str]( "Access-Control-Allow-Origin", doc="The origin or '*' for any origin that may make cross origin requests.", ) diff --git a/src/werkzeug/wrappers/etag.py b/src/werkzeug/wrappers/etag.py index 5d8c0a100..f290e3dd5 100644 --- a/src/werkzeug/wrappers/etag.py +++ b/src/werkzeug/wrappers/etag.py @@ -1,12 +1,15 @@ -from typing import Any -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union +import typing as t +from datetime import datetime from .._internal import _get_environ from ..datastructures import ContentRange +from ..datastructures import EnvironHeaders +from ..datastructures import ETags +from ..datastructures import Headers +from ..datastructures import IfRange +from ..datastructures import Range from ..datastructures import RequestCacheControl +from ..datastructures import ResponseCacheControl from ..http import generate_etag from ..http import http_date from ..http import is_resource_modified @@ -22,11 +25,9 @@ from ..utils import header_property from ..wrappers.base_response import _clean_accept_ranges from ..wsgi import _RangeWrapper -from werkzeug.datastructures import ResponseCacheControl -from werkzeug.types import WSGIEnvironment -if TYPE_CHECKING: - from werkzeug.wrappers.response import Response +if t.TYPE_CHECKING: + from wsgiref.types import WSGIEnvironment class ETagRequestMixin: @@ -35,8 +36,10 @@ class ETagRequestMixin: only provides access to etags but also to the cache control header. """ + headers: "EnvironHeaders" + @cached_property - def cache_control(self): + def cache_control(self) -> RequestCacheControl: """A :class:`~werkzeug.datastructures.RequestCacheControl` object for the incoming cache control headers. """ @@ -44,7 +47,7 @@ def cache_control(self): return parse_cache_control_header(cache_control, None, RequestCacheControl) @cached_property - def if_match(self): + def if_match(self) -> ETags: """An object containing all the etags in the `If-Match` header. :rtype: :class:`~werkzeug.datastructures.ETags` @@ -52,7 +55,7 @@ def if_match(self): return parse_etags(self.headers.get("If-Match")) @cached_property - def if_none_match(self): + def if_none_match(self) -> ETags: """An object containing all the etags in the `If-None-Match` header. :rtype: :class:`~werkzeug.datastructures.ETags` @@ -60,17 +63,17 @@ def if_none_match(self): return parse_etags(self.headers.get("If-None-Match")) @cached_property - def if_modified_since(self): + def if_modified_since(self) -> t.Optional[datetime]: """The parsed `If-Modified-Since` header as datetime object.""" return parse_date(self.headers.get("If-Modified-Since")) @cached_property - def if_unmodified_since(self): + def if_unmodified_since(self) -> t.Optional[datetime]: """The parsed `If-Unmodified-Since` header as datetime object.""" return parse_date(self.headers.get("If-Unmodified-Since")) @cached_property - def if_range(self): + def if_range(self) -> IfRange: """The parsed `If-Range` header. .. versionadded:: 0.7 @@ -80,7 +83,7 @@ def if_range(self): return parse_if_range_header(self.headers.get("If-Range")) @cached_property - def range(self): + def range(self) -> t.Optional[Range]: """The parsed `Range` header. .. versionadded:: 0.7 @@ -101,9 +104,9 @@ class ETagResponseMixin: response class does not do that. """ - headers: dict status_code: int - response: Any + headers: Headers + response: t.Iterable[bytes] @property def cache_control(self) -> "ResponseCacheControl": @@ -127,7 +130,7 @@ def _wrap_response(self, start: int, length: int) -> None: if self.status_code == 206: self.response = _RangeWrapper(self.response, start, length) - def _is_range_request_processable(self, environ: WSGIEnvironment,) -> bool: + def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool: """Return ``True`` if `Range` header is present and if underlying resource is considered unchanged when compared with `If-Range` header. """ @@ -144,9 +147,9 @@ def _is_range_request_processable(self, environ: WSGIEnvironment,) -> bool: def _process_range_request( self, - environ: WSGIEnvironment, - complete_length: Optional[int] = None, - accept_ranges: Optional[Union[str, bool]] = None, + environ: "WSGIEnvironment", + complete_length: t.Optional[int] = None, + accept_ranges: t.Optional[t.Union[bool, str]] = None, ) -> bool: """Handle Range Request related headers (RFC7233). If `Accept-Ranges` header is valid, and Range Request is processable, we set the headers @@ -181,17 +184,17 @@ def _process_range_request( content_length = range_tuple[1] - range_tuple[0] self.headers["Content-Length"] = content_length self.headers["Accept-Ranges"] = accept_ranges - self.content_range = content_range_header + self.content_range = content_range_header # type: ignore self.status_code = 206 self._wrap_response(range_tuple[0], content_length) return True def make_conditional( self, - request_or_environ: WSGIEnvironment, - accept_ranges: bool = False, - complete_length: Optional[int] = None, - ) -> "Response": + request_or_environ: "WSGIEnvironment", + accept_ranges: t.Union[bool, str] = False, + complete_length: t.Optional[int] = None, + ): """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -236,7 +239,7 @@ def make_conditional( # wsgiref. if "date" not in self.headers: self.headers["Date"] = http_date() - accept_ranges = _clean_accept_ranges(accept_ranges) # type: ignore + accept_ranges = _clean_accept_ranges(accept_ranges) is206 = self._process_range_request(environ, complete_length, accept_ranges) if not is206 and not is_resource_modified( environ, @@ -255,7 +258,7 @@ def make_conditional( length = self.calculate_content_length() # type: ignore if length is not None: self.headers["Content-Length"] = length - return self # type: ignore + return self def add_etag(self, overwrite: bool = False, weak: bool = False) -> None: """Add an etag for the current response if there is none yet.""" @@ -266,22 +269,22 @@ def set_etag(self, etag: str, weak: bool = False) -> None: """Set the etag, and override the old one if there was one.""" self.headers["ETag"] = quote_etag(etag, weak) - def get_etag(self) -> Union[Tuple[str, bool], Tuple[None, None]]: + def get_etag(self) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: """Return a tuple in the form ``(etag, is_weak)``. If there is no ETag the return value is ``(None, None)``. """ return unquote_etag(self.headers.get("ETag")) - def freeze(self, no_etag=False): + def freeze(self, no_etag: bool = False) -> None: """Call this method if you want to make your response object ready for pickeling. This buffers the generator if there is one. This also sets the etag unless `no_etag` is set to `True`. """ if not no_etag: self.add_etag() - super().freeze() + super().freeze() # type: ignore - accept_ranges = header_property( + accept_ranges = header_property[str]( "Accept-Ranges", doc="""The `Accept-Ranges` header. Even though the name would indicate that multiple values are supported, it must be one @@ -293,7 +296,7 @@ def freeze(self, no_etag=False): ) @property - def content_range(self): + def content_range(self) -> ContentRange: """The ``Content-Range`` header as a :class:`~werkzeug.datastructures.ContentRange` object. Available even if the header is not set. @@ -301,7 +304,7 @@ def content_range(self): .. versionadded:: 0.7 """ - def on_update(rng): + def on_update(rng: ContentRange) -> None: if not rng: del self.headers["content-range"] else: @@ -316,7 +319,7 @@ def on_update(rng): return rv @content_range.setter - def content_range(self, value): + def content_range(self, value: t.Optional[t.Union[ContentRange, str]]) -> None: if not value: del self.headers["content-range"] elif isinstance(value, str): diff --git a/src/werkzeug/wrappers/json.py b/src/werkzeug/wrappers/json.py index e36496f44..df465c908 100644 --- a/src/werkzeug/wrappers/json.py +++ b/src/werkzeug/wrappers/json.py @@ -1,16 +1,14 @@ import datetime import json +import typing as t import uuid -from typing import Dict -from typing import List -from typing import Union from ..exceptions import BadRequest class _JSONModule: @staticmethod - def _default(o): + def _default(o: t.Any) -> t.Any: if isinstance(o, datetime.date): return o.isoformat() @@ -23,14 +21,14 @@ def _default(o): raise TypeError() @classmethod - def dumps(cls, obj, **kw): + def dumps(cls, obj: t.Any, **kw) -> str: kw.setdefault("separators", (",", ":")) kw.setdefault("default", cls._default) kw.setdefault("sort_keys", True) return json.dumps(obj, **kw) @staticmethod - def loads(s: bytes, **kw) -> Union[List[int], Dict[str, str]]: + def loads(s: t.Union[str, bytes], **kw) -> t.Any: return json.loads(s, **kw) @@ -45,7 +43,7 @@ class JSONMixin: json_module = _JSONModule @property - def json(self): + def json(self) -> t.Optional[t.Any]: """The parsed JSON data if :attr:`mimetype` indicates JSON (:mimetype:`application/json`, see :meth:`is_json`). @@ -54,29 +52,31 @@ def json(self): return self.get_json() @property - def is_json(self): + def is_json(self) -> bool: """Check if the mimetype indicates JSON data, either :mimetype:`application/json` or :mimetype:`application/*+json`. """ - mt = self.mimetype + mt = self.mimetype # type: ignore return ( mt == "application/json" or mt.startswith("application/") and mt.endswith("+json") ) - def _get_data_for_json(self, cache): + def _get_data_for_json(self, cache: bool) -> bytes: try: - return self.get_data(cache=cache) + return self.get_data(cache=cache) # type: ignore except TypeError: # Response doesn't have cache param. - return self.get_data() + return self.get_data() # type: ignore # Cached values for ``(silent=False, silent=True)``. Initialized # with sentinel values. - _cached_json = (Ellipsis, Ellipsis) + _cached_json: t.Tuple[t.Any, t.Any] = (Ellipsis, Ellipsis) - def get_json(self, force=False, silent=False, cache=True): + def get_json( + self, force: bool = False, silent: bool = False, cache: bool = True + ) -> t.Optional[t.Any]: """Parse :attr:`data` as JSON. If the mimetype does not indicate JSON @@ -121,7 +121,7 @@ def get_json(self, force=False, silent=False, cache=True): return rv - def on_json_loading_failed(self, e): + def on_json_loading_failed(self, e: ValueError) -> t.Any: """Called if :meth:`get_json` parsing fails and isn't silenced. If this method returns a value, it is used as the return value for :meth:`get_json`. The default implementation raises diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 95502980d..0922922a9 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -7,7 +7,7 @@ from .user_agent import UserAgentMixin -class Request( +class Request( # type: ignore BaseRequest, AcceptMixin, ETagRequestMixin, @@ -42,7 +42,7 @@ class StreamOnlyMixin: want_form_data_parsed = False -class PlainRequest(StreamOnlyMixin, Request): +class PlainRequest(StreamOnlyMixin, Request): # type: ignore """A request object without special form parsing capabilities. .. versionadded:: 0.9 diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index ff8bca5b3..f34261a8b 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -1,5 +1,4 @@ -from typing import List -from typing import Union +import typing as t from ..utils import cached_property from .auth import WWWAuthenticateMixin @@ -15,32 +14,32 @@ class ResponseStream: iterable of the response object. """ - mode = "wb+" + mode: str = "wb+" - def __init__(self, response: "Response") -> None: + def __init__(self, response: BaseResponse): self.response = response self.closed = False - def write(self, value: Union[str, bytes]) -> int: + def write(self, value: bytes) -> int: if self.closed: raise ValueError("I/O operation on closed file") self.response._ensure_sequence(mutable=True) - self.response.response.append(value) + self.response.response.append(value) # type: ignore self.response.headers.pop("Content-Length", None) return len(value) - def writelines(self, seq: List[str]) -> None: + def writelines(self, seq: t.Iterable[bytes]) -> None: for item in seq: self.write(item) - def close(self): + def close(self) -> None: self.closed = True def flush(self) -> None: if self.closed: raise ValueError("I/O operation on closed file") - def isatty(self): + def isatty(self) -> bool: if self.closed: raise ValueError("I/O operation on closed file") return False @@ -50,7 +49,7 @@ def tell(self) -> int: return sum(map(len, self.response.response)) @property - def encoding(self): + def encoding(self) -> str: return self.response.charset @@ -61,9 +60,9 @@ class ResponseStreamMixin: """ @cached_property - def stream(self): + def stream(self) -> ResponseStream: """The response iterable as write-only stream.""" - return ResponseStream(self) + return ResponseStream(self) # type: ignore class Response( # type: ignore diff --git a/src/werkzeug/wrappers/user_agent.py b/src/werkzeug/wrappers/user_agent.py index 5c1efa276..a402d4260 100644 --- a/src/werkzeug/wrappers/user_agent.py +++ b/src/werkzeug/wrappers/user_agent.py @@ -1,3 +1,4 @@ +from ..datastructures import EnvironHeaders from ..useragents import UserAgent from ..utils import cached_property @@ -8,7 +9,9 @@ class UserAgentMixin: request as a :class:`~werkzeug.useragents.UserAgent` object. """ + headers: EnvironHeaders + @cached_property - def user_agent(self): + def user_agent(self) -> UserAgent: """The current user agent.""" - return UserAgent(self.headers.get("User-Agent", "")) + return UserAgent(self.headers.get("User-Agent", "")) # type: ignore diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index f3738a730..9a39b80f5 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -1,36 +1,26 @@ import io import re +import typing as t from functools import partial from functools import update_wrapper -from io import BufferedReader -from io import BytesIO -from io import FileIO from itertools import chain -from typing import Any -from typing import AnyStr -from typing import Callable -from typing import Dict -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union from ._internal import _encode_idna from ._internal import _make_encode_wrapper from ._internal import _to_bytes from ._internal import _to_str +from .urls import _URLTuple from .urls import uri_to_iri from .urls import url_join from .urls import url_parse from .urls import url_quote -from werkzeug.types import BytesOrStr -from werkzeug.types import WSGIEnvironment +if t.TYPE_CHECKING: + from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment -def responder(f: Callable) -> Callable: + +def responder(f: t.Callable[..., "WSGIApplication"]) -> "WSGIApplication": """Marks a function as responder. Decorate a function with it and it will automatically call the return value as WSGI application. @@ -44,11 +34,11 @@ def application(environ, start_response): def get_current_url( - environ: WSGIEnvironment, + environ: "WSGIEnvironment", root_only: bool = False, strip_querystring: bool = False, host_only: bool = False, - trusted_hosts: Optional[List[str]] = None, + trusted_hosts: t.Optional[t.Iterable[str]] = None, ) -> str: """A handy helper function that recreates the full URL as IRI for the current request or parts of it. Here's an example: @@ -99,7 +89,7 @@ def get_current_url( return uri_to_iri("".join(tmp)) -def host_is_trusted(hostname: str, trusted_list: List[str]) -> bool: +def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: """Checks if a host is trusted against a list. This also takes care of port normalization. @@ -116,13 +106,13 @@ def host_is_trusted(hostname: str, trusted_list: List[str]) -> bool: if isinstance(trusted_list, str): trusted_list = [trusted_list] - def _normalize(hostname): + def _normalize(hostname: str) -> bytes: if ":" in hostname: hostname = hostname.rsplit(":", 1)[0] return _encode_idna(hostname) try: - hostname = _normalize(hostname) + hostname_bytes = _normalize(hostname) except UnicodeError: return False for ref in trusted_list: @@ -132,18 +122,18 @@ def _normalize(hostname): else: suffix_match = False try: - ref = _normalize(ref) + ref_bytes = _normalize(ref) except UnicodeError: return False - if ref == hostname: + if ref_bytes == hostname_bytes: return True - if suffix_match and hostname.endswith(b"." + ref): # type: ignore + if suffix_match and hostname_bytes.endswith(b"." + ref_bytes): return True return False def get_host( - environ: WSGIEnvironment, trusted_hosts: Optional[List[str]] = None + environ: "WSGIEnvironment", trusted_hosts: t.Optional[t.Iterable[str]] = None ) -> str: """Return the host for the given WSGI environment. This first checks the ``Host`` header. If it's not present, then ``SERVER_NAME`` and @@ -181,7 +171,7 @@ def get_host( return rv -def get_content_length(environ: WSGIEnvironment) -> Optional[int]: +def get_content_length(environ: "WSGIEnvironment") -> t.Optional[int]: """Returns the content length from the WSGI environment as integer. If it's not available or chunked transfer encoding is used, ``None`` is returned. @@ -203,8 +193,8 @@ def get_content_length(environ: WSGIEnvironment) -> Optional[int]: def get_input_stream( - environ: WSGIEnvironment, safe_fallback: bool = True -) -> Union[BytesIO, "LimitedStream"]: + environ: "WSGIEnvironment", safe_fallback: bool = True +) -> t.BinaryIO: """Returns the input stream from the WSGI environment and wraps it in the most sensible way possible. The stream returned is not the raw WSGI stream in most cases but one that is safe to read from @@ -237,10 +227,10 @@ def get_input_stream( return io.BytesIO() if safe_fallback else stream # Otherwise limit the stream to the content length - return LimitedStream(stream, content_length) + return t.cast(t.BinaryIO, LimitedStream(stream, content_length)) -def get_query_string(environ: WSGIEnvironment) -> str: +def get_query_string(environ: "WSGIEnvironment") -> str: """Returns the ``QUERY_STRING`` from the WSGI environment. This also takes care of the WSGI decoding dance. The string returned will be restricted to ASCII characters. @@ -257,8 +247,8 @@ def get_query_string(environ: WSGIEnvironment) -> str: def get_path_info( - environ: WSGIEnvironment, charset: Optional[str] = "utf-8", errors: str = "replace", -) -> Union[str, bytes]: + environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" +) -> str: """Return the ``PATH_INFO`` from the WSGI environment and decode it unless ``charset`` is ``None``. @@ -274,8 +264,8 @@ def get_path_info( def get_script_name( - environ: WSGIEnvironment, charset: Optional[str] = "utf-8", errors: str = "replace", -) -> Union[str, bytes]: + environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" +) -> str: """Return the ``SCRIPT_NAME`` from the WSGI environment and decode it unless `charset` is set to ``None``. @@ -291,8 +281,8 @@ def get_script_name( def pop_path_info( - environ: WSGIEnvironment, charset: str = "utf-8", errors: str = "replace", -) -> Optional[str]: + environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" +) -> t.Optional[str]: """Removes and returns the next segment of `PATH_INFO`, pushing it onto `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. @@ -349,8 +339,8 @@ def pop_path_info( def peek_path_info( - environ: WSGIEnvironment, charset: Optional[str] = "utf-8", errors: str = "replace", -) -> Optional[Union[str, bytes]]: + environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" +) -> t.Optional[str]: """Returns the next segment on the `PATH_INFO` or `None` if there is none. Works like :func:`pop_path_info` without modifying the environment: @@ -374,20 +364,18 @@ def peek_path_info( segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) if segments: return _to_str( - segments[0].encode("latin1"), charset, errors, allow_none_charset=True, + segments[0].encode("latin1"), charset, errors, allow_none_charset=True ) return None def extract_path_info( - environ_or_baseurl: Union[ - str, Dict[str, Union[str, Tuple[int, int], BytesIO, bool]] - ], - path_or_url: str, + environ_or_baseurl: t.Union[str, "WSGIEnvironment"], + path_or_url: t.Union[str, _URLTuple], charset: str = "utf-8", errors: str = "werkzeug.url_quote", collapse_http_schemes: bool = True, -) -> Optional[str]: +) -> t.Optional[str]: """Extracts the path info from the given URL (or WSGI environment) and path. The path info returned is a string. The URLs might also be IRIs. @@ -448,7 +436,7 @@ def _normalize_netloc(scheme, netloc): environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) base_iri = uri_to_iri(environ_or_baseurl, charset, errors) base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path, = url_parse(url_join(base_iri, path))[:3] + cur_scheme, cur_netloc, cur_path = url_parse(url_join(base_iri, path))[:3] # normalize the network location base_netloc = _normalize_netloc(base_scheme, base_netloc) @@ -499,11 +487,13 @@ class ClosingIterator: def __init__( self, - iterable: Iterable, - callbacks: Optional[Union[Callable, List[Callable]]] = None, + iterable: t.Iterable[bytes], + callbacks: t.Optional[ + t.Union[t.Callable[[], None], t.Iterable[t.Callable[[], None]]] + ] = None, ) -> None: iterator = iter(iterable) - self._next = partial(next, iterator) + self._next = t.cast(t.Callable[[], bytes], partial(next, iterator)) if callbacks is None: callbacks = [] elif callable(callbacks): @@ -518,7 +508,7 @@ def __init__( def __iter__(self) -> "ClosingIterator": return self - def __next__(self) -> Any: + def __next__(self) -> bytes: return self._next() def close(self) -> None: @@ -527,10 +517,8 @@ def close(self) -> None: def wrap_file( - environ: WSGIEnvironment, - file: Union[FileIO, BufferedReader], - buffer_size: int = 8192, -) -> "FileWrapper": + environ: "WSGIEnvironment", file: t.BinaryIO, buffer_size: int = 8192 +) -> t.Iterable[bytes]: """Wraps a file. This uses the WSGI server's file wrapper if available or otherwise the generic :class:`FileWrapper`. @@ -567,9 +555,7 @@ class FileWrapper: :param buffer_size: number of bytes for one iteration. """ - def __init__( - self, file: Union[FileIO, BufferedReader], buffer_size: int = 8192 - ) -> None: + def __init__(self, file: t.BinaryIO, buffer_size: int = 8192) -> None: self.file = file self.buffer_size = buffer_size @@ -588,7 +574,7 @@ def seek(self, *args) -> None: if hasattr(self.file, "seek"): self.file.seek(*args) - def tell(self) -> int: + def tell(self) -> t.Optional[int]: if hasattr(self.file, "tell"): return self.file.tell() return None @@ -621,14 +607,19 @@ class _RangeWrapper: """ def __init__( - self, iterable: Iterable, start_byte: int = 0, byte_range: Optional[int] = None, - ) -> None: + self, + iterable: t.Union[t.Iterable[bytes], t.BinaryIO], + start_byte: int = 0, + byte_range: t.Optional[int] = None, + ): self.iterable = iter(iterable) self.byte_range = byte_range self.start_byte = start_byte self.end_byte = None + if byte_range is not None: - self.end_byte = self.start_byte + self.byte_range + self.end_byte = start_byte + byte_range + self.read_length = 0 self.seekable = ( hasattr(iterable, "seekable") and iterable.seekable() # type: ignore @@ -647,7 +638,7 @@ def _next_chunk(self) -> bytes: self.end_reached = True raise - def _first_iteration(self) -> Tuple[Optional[bytes], int]: + def _first_iteration(self) -> t.Tuple[t.Optional[bytes], int]: chunk = None if self.seekable: self.iterable.seek(self.start_byte) # type: ignore @@ -682,14 +673,16 @@ def __next__(self) -> bytes: self.end_reached = True raise StopIteration() - def close(self): + def close(self) -> None: if hasattr(self.iterable, "close"): - self.iterable.close() + self.iterable.close() # type: ignore def _make_chunk_iter( - stream: Union[IO[AnyStr], Iterator[AnyStr]], limit: Optional[int], buffer_size: int, -) -> Iterator[AnyStr]: + stream: t.Union[t.Iterable[bytes], t.BinaryIO], + limit: t.Optional[int], + buffer_size: int, +) -> t.Iterator[bytes]: """Helper for the line and chunk iter functions.""" if isinstance(stream, (bytes, bytearray, str)): raise TypeError( @@ -700,10 +693,11 @@ def _make_chunk_iter( if item: yield item return + stream = t.cast(t.BinaryIO, stream) if not isinstance(stream, LimitedStream) and limit is not None: - stream = LimitedStream(stream, limit) # type: ignore - _read = stream.read # type: ignore - while 1: + stream = t.cast(t.BinaryIO, LimitedStream(stream, limit)) + _read = stream.read + while True: item = _read(buffer_size) if not item: break @@ -711,11 +705,11 @@ def _make_chunk_iter( def make_line_iter( - stream: Union[Iterator[AnyStr], IO], - limit: Optional[int] = None, + stream: t.Union[t.Iterable[bytes], t.BinaryIO], + limit: t.Optional[int] = None, buffer_size: int = 10 * 1024, cap_at_buffer: bool = False, -) -> Iterator[AnyStr]: +) -> t.Iterator[bytes]: """Safely iterates line-based over an input stream. If the input stream is not a :class:`LimitedStream` the `limit` parameter is mandatory. @@ -747,30 +741,32 @@ def make_line_iter( that the buffer size might be exhausted by a factor of two however. """ - _iter: Iterator[AnyStr] = _make_chunk_iter(stream, limit, buffer_size) + _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item: AnyStr = next(_iter, "") # type: ignore + first_item = next(_iter, "") if not first_item: return s = _make_encode_wrapper(first_item) - empty = s("") - cr = s("\r") - lf = s("\n") - crlf = s("\r\n") + empty = t.cast(bytes, s("")) + cr = t.cast(bytes, s("\r")) + lf = t.cast(bytes, s("\n")) + crlf = t.cast(bytes, s("\r\n")) - _iter = chain((first_item,), _iter) + _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - def _iter_basic_lines() -> Iterator[AnyStr]: + def _iter_basic_lines() -> t.Iterator[bytes]: _join = empty.join - buffer: List[Any] = [] - while 1: + buffer: t.List[bytes] = [] + while True: new_data = next(_iter, "") if not new_data: break - new_buf = [] + new_buf: t.List[bytes] = [] buf_size = 0 - for item in chain(buffer, new_data.splitlines(True)): + for item in t.cast( + t.Iterator[bytes], chain(buffer, new_data.splitlines(True)) + ): new_buf.append(item) buf_size += len(item) if item and item[-1:] in crlf: @@ -801,12 +797,12 @@ def _iter_basic_lines() -> Iterator[AnyStr]: def make_chunk_iter( - stream: Union[Iterator[AnyStr], IO], - separator: str, - limit: Optional[int] = None, + stream: t.Union[t.Iterable[bytes], t.BinaryIO], + separator: bytes, + limit: t.Optional[int] = None, buffer_size: int = 10 * 1024, cap_at_buffer: bool = False, -) -> Iterator[str]: +) -> t.Iterator[bytes]: """Works like :func:`make_line_iter` but accepts a separator which divides chunks. If you want newline based processing you should use :func:`make_line_iter` instead as it @@ -833,27 +829,27 @@ def make_chunk_iter( """ _iter = _make_chunk_iter(stream, limit, buffer_size) - first_item = next(_iter, "") + first_item = next(_iter, b"") if not first_item: return - _iter = chain((first_item,), _iter) + _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) if isinstance(first_item, str): separator = _to_str(separator) _split = re.compile(f"({re.escape(separator)})").split - _join: Callable = "".join + _join = "".join else: - separator = _to_bytes(separator) # type: ignore - _split = re.compile(b"(" + re.escape(separator) + b")").split # type: ignore + separator = _to_bytes(separator) + _split = re.compile(b"(" + re.escape(separator) + b")").split _join = b"".join - buffer: List[Any] = [] - while 1: - new_data = next(_iter, "") + buffer: t.List[bytes] = [] + while True: + new_data = next(_iter, b"") if not new_data: break chunks = _split(new_data) - new_buf: List[Any] = [] + new_buf: t.List[bytes] = [] buf_size = 0 for item in chain(buffer, chunks): if item == separator: @@ -911,21 +907,21 @@ class LimitedStream(io.IOBase): end with `EOF` (like `wsgi.input`) """ - def __init__(self, stream: Union[IO], limit: int) -> None: + def __init__(self, stream: t.BinaryIO, limit: int) -> None: self._read = stream.read self._readline = stream.readline self._pos = 0 self.limit = limit - def __iter__(self) -> "LimitedStream": # type: ignore + def __iter__(self) -> "LimitedStream": return self @property - def is_exhausted(self): + def is_exhausted(self) -> bool: """If the stream is exhausted this attribute is `True`.""" return self._pos >= self.limit - def on_exhausted(self) -> Union[str, bytes]: + def on_exhausted(self) -> bytes: """This is called when the stream tries to read past the limit. The return value of this function is returned from the reading function. @@ -934,7 +930,7 @@ def on_exhausted(self) -> Union[str, bytes]: # correct end of stream marker. return self._read(0) - def on_disconnect(self): + def on_disconnect(self) -> bytes: """What should happen if a disconnect is detected? The return value of this function is returned from read functions in case the client went away. By default a @@ -959,7 +955,7 @@ def exhaust(self, chunk_size: int = 1024 * 64) -> None: self.read(chunk) to_read -= chunk - def read(self, size: Optional[int] = None) -> Union[str, bytes]: + def read(self, size: t.Optional[int] = None) -> bytes: """Read `size` bytes or if size is not provided everything is read. :param size: the number of bytes read. @@ -978,7 +974,7 @@ def read(self, size: Optional[int] = None) -> Union[str, bytes]: self._pos += len(read) return read - def readline(self, size: Optional[int] = None) -> BytesOrStr: # type: ignore + def readline(self, size: t.Optional[int] = None) -> bytes: """Reads one line from the stream.""" if self._pos >= self.limit: return self.on_exhausted() @@ -995,7 +991,7 @@ def readline(self, size: Optional[int] = None) -> BytesOrStr: # type: ignore self._pos += len(line) return line - def readlines(self, size: Optional[int] = None) -> List[str]: # type: ignore + def readlines(self, size: t.Optional[int] = None) -> t.List[bytes]: """Reads a file into a list of strings. It calls :meth:`readline` until the file is read to the end. It does support the optional `size` argument if the underlying stream supports it for @@ -1007,7 +1003,7 @@ def readlines(self, size: Optional[int] = None) -> List[str]: # type: ignore end = min(self.limit, last_pos + size) else: end = self.limit - while 1: + while True: if size is not None: size -= last_pos - self._pos if self._pos >= end: @@ -1015,7 +1011,7 @@ def readlines(self, size: Optional[int] = None) -> List[str]: # type: ignore result.append(self.readline(size)) if size is not None: last_pos = self._pos - return result # type: ignore + return result def tell(self) -> int: """Returns the position of the stream. @@ -1024,11 +1020,11 @@ def tell(self) -> int: """ return self._pos - def __next__(self) -> str: # type: ignore + def __next__(self) -> bytes: line = self.readline() if not line: raise StopIteration() - return line # type: ignore + return line def readable(self) -> bool: return True diff --git a/tests/middleware/test_lint.py b/tests/middleware/test_lint.py index fc321ac1f..71f2ebec3 100644 --- a/tests/middleware/test_lint.py +++ b/tests/middleware/test_lint.py @@ -9,7 +9,7 @@ def dummy_application(environ, start_response): start_response("200 OK", [("Content-Type", "text/plain")]) - return ["Foo"] + return [b"Foo"] def test_lint_middleware(): @@ -42,7 +42,7 @@ def test_lint_middleware_check_environ(key, value, message): def test_lint_middleware_invalid_status(): def my_dummy_application(environ, start_response): start_response("20 OK", [("Content-Type", "text/plain")]) - return ["Foo"] + return [b"Foo"] app = LintMiddleware(my_dummy_application) @@ -57,15 +57,15 @@ def my_dummy_application(environ, start_response): @pytest.mark.parametrize( "headers, message", [ - (tuple([("Content-Type", "text/plain")]), "header list is not a list"), - (["fo"], "Headers must tuple 2-item tuples"), - ([("status", "foo")], "The status header is not supported"), + (tuple([("Content-Type", "text/plain")]), "Header list is not a list."), + (["fo"], "Header items must be 2-item tuples."), + ([("status", "foo")], "The status header is not supported."), ], ) def test_lint_middleware_http_headers(headers, message): def my_dummy_application(environ, start_response): start_response("200 OK", headers) - return ["Foo"] + return [b"Foo"] app = LintMiddleware(my_dummy_application) @@ -77,10 +77,10 @@ def my_dummy_application(environ, start_response): def test_lint_middleware_invalid_location(): def my_dummy_application(environ, start_response): start_response("200 OK", [("location", "foo")]) - return ["Foo"] + return [b"Foo"] app = LintMiddleware(my_dummy_application) environ = create_environ("/test") - with pytest.warns(HTTPWarning, match="absolute URLs required for location header"): + with pytest.warns(HTTPWarning, match="Absolute URLs required for location header."): run_wsgi_app(app, environ, buffered=True) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 7e94146ab..553bbed59 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,17 +1,15 @@ import io import pickle import tempfile +import typing as t from contextlib import contextmanager from copy import copy from copy import deepcopy import pytest -from werkzeug import datastructures +from werkzeug import datastructures as ds from werkzeug import http -from werkzeug.datastructures import LanguageAccept -from werkzeug.datastructures import MIMEAccept -from werkzeug.datastructures import Range from werkzeug.exceptions import BadRequestKeyError @@ -44,7 +42,7 @@ def items(self, multi=1): class _MutableMultiDictTests: - storage_class = None + storage_class: t.Type["ds.MultiDict"] def test_pickle(self): cls = self.storage_class @@ -254,7 +252,7 @@ def test_basic_interface(self): class _ImmutableDictTests: - storage_class = None + storage_class: t.Type[dict] def test_follows_dict_interface(self): cls = self.storage_class @@ -301,11 +299,11 @@ def test_dict_is_hashable(self): class TestImmutableTypeConversionDict(_ImmutableDictTests): - storage_class = datastructures.ImmutableTypeConversionDict # type: ignore + storage_class = ds.ImmutableTypeConversionDict class TestImmutableMultiDict(_ImmutableDictTests): - storage_class = datastructures.ImmutableMultiDict # type: ignore + storage_class = ds.ImmutableMultiDict def test_multidict_is_hashable(self): cls = self.storage_class @@ -326,11 +324,11 @@ def test_multidict_is_hashable(self): class TestImmutableDict(_ImmutableDictTests): - storage_class = datastructures.ImmutableDict # type: ignore + storage_class = ds.ImmutableDict class TestImmutableOrderedMultiDict(_ImmutableDictTests): - storage_class = datastructures.ImmutableOrderedMultiDict # type: ignore + storage_class = ds.ImmutableOrderedMultiDict def test_ordered_multidict_is_hashable(self): a = self.storage_class([("a", 1), ("b", 1), ("a", 2)]) @@ -339,7 +337,7 @@ def test_ordered_multidict_is_hashable(self): class TestMultiDict(_MutableMultiDictTests): - storage_class = datastructures.MultiDict # type: ignore + storage_class = ds.MultiDict def test_multidict_pop(self): def make_d(): @@ -409,7 +407,7 @@ def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self): class TestOrderedMultiDict(_MutableMultiDictTests): - storage_class = datastructures.OrderedMultiDict # type: ignore + storage_class = ds.OrderedMultiDict def test_ordered_interface(self): cls = self.storage_class @@ -461,7 +459,7 @@ def test_ordered_interface(self): d.add("foo", 23) d.add("bar", 2) d.add("foo", 42) - assert d == datastructures.MultiDict(d) + assert d == ds.MultiDict(d) id = self.storage_class(d) assert d == id d.add("foo", 2) @@ -506,9 +504,9 @@ def test_ordered_interface(self): pytest.raises(TypeError, hash, d) def test_iterables(self): - a = datastructures.MultiDict((("key_a", "value_a"),)) - b = datastructures.MultiDict((("key_b", "value_b"),)) - ab = datastructures.CombinedMultiDict((a, b)) + a = ds.MultiDict((("key_a", "value_a"),)) + b = ds.MultiDict((("key_b", "value_b"),)) + ab = ds.CombinedMultiDict((a, b)) assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])] assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]] @@ -519,7 +517,7 @@ def test_iterables(self): assert sorted(ab.keys()) == ["key_a", "key_b"] def test_get_description(self): - data = datastructures.OrderedMultiDict() + data = ds.OrderedMultiDict() with pytest.raises(BadRequestKeyError) as exc_info: data["baz"] @@ -538,7 +536,7 @@ def test_get_description(self): class TestTypeConversionDict: - storage_class = datastructures.TypeConversionDict + storage_class = ds.TypeConversionDict def test_value_conversion(self): d = self.storage_class(foo="1") @@ -556,11 +554,11 @@ def test_propagate_exceptions_in_conversion(self): class TestCombinedMultiDict: - storage_class = datastructures.CombinedMultiDict + storage_class = ds.CombinedMultiDict def test_basic_interface(self): - d1 = datastructures.MultiDict([("foo", "1")]) - d2 = datastructures.MultiDict([("bar", "2"), ("bar", "3")]) + d1 = ds.MultiDict([("foo", "1")]) + d2 = ds.MultiDict([("bar", "2"), ("bar", "3")]) d = self.storage_class([d1, d2]) # lookup @@ -590,14 +588,14 @@ def test_basic_interface(self): d["foo"] = "blub" # make sure lists merges - md1 = datastructures.MultiDict((("foo", "bar"),)) - md2 = datastructures.MultiDict((("foo", "blafasel"),)) + md1 = ds.MultiDict((("foo", "bar"),)) + md2 = ds.MultiDict((("foo", "blafasel"),)) x = self.storage_class((md1, md2)) assert list(x.lists()) == [("foo", ["bar", "blafasel"])] def test_length(self): - d1 = datastructures.MultiDict([("foo", "1")]) - d2 = datastructures.MultiDict([("bar", "2")]) + d1 = ds.MultiDict([("foo", "1")]) + d2 = ds.MultiDict([("bar", "2")]) assert len(d1) == len(d2) == 1 d = self.storage_class([d1, d2]) assert len(d) == 2 @@ -607,7 +605,7 @@ def test_length(self): class TestHeaders: - storage_class = datastructures.Headers + storage_class = ds.Headers def test_basic_interface(self): headers = self.storage_class() @@ -721,7 +719,7 @@ def test_bytes_operations(self): def test_extend(self): h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")]) - h.extend(datastructures.Headers([("a", "3"), ("a", "4")])) + h.extend(ds.Headers([("a", "3"), ("a", "4")])) assert h.getlist("a") == ["0", "3", "4"] h.extend(b=["5", "6"]) assert h.getlist("b") == ["1", "5", "6"] @@ -734,7 +732,7 @@ def test_extend(self): def test_update(self): h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")]) - h.update(datastructures.Headers([("a", "3"), ("a", "4")])) + h.update(ds.Headers([("a", "3"), ("a", "4")])) assert h.getlist("a") == ["3", "4"] h.update(b=["5", "6"]) assert h.getlist("b") == ["5", "6"] @@ -792,7 +790,7 @@ def test_equality(self): class TestEnvironHeaders: - storage_class = datastructures.EnvironHeaders + storage_class = ds.EnvironHeaders def test_basic_interface(self): # this happens in multiple WSGI servers because they @@ -840,7 +838,7 @@ def test_bytes_operations(self): class TestHeaderSet: - storage_class = datastructures.HeaderSet + storage_class = ds.HeaderSet def test_basic_interface(self): hs = self.storage_class() @@ -865,7 +863,7 @@ def test_basic_interface(self): class TestImmutableList: - storage_class = datastructures.ImmutableList + storage_class = ds.ImmutableList def test_list_hashable(self): data = (1, 2, 3, 4) @@ -902,7 +900,7 @@ def wrapped(*args, **kwargs): class TestCallbackDict: - storage_class = datastructures.CallbackDict + storage_class = ds.CallbackDict def test_callback_dict_reads(self): assert_calls, func = make_call_asserter() @@ -943,11 +941,11 @@ def test_callback_dict_writes(self): class TestCacheControl: def test_repr(self): - cc = datastructures.RequestCacheControl([("max-age", "0"), ("private", "True")]) + cc = ds.RequestCacheControl([("max-age", "0"), ("private", "True")]) assert repr(cc) == "" def test_set_none(self): - cc = datastructures.ResponseCacheControl([("max-age", "0")]) + cc = ds.ResponseCacheControl([("max-age", "0")]) assert cc.no_cache is None cc.no_cache = None assert cc.no_cache is None @@ -955,9 +953,7 @@ def test_set_none(self): class TestContentSecurityPolicy: def test_construct(self): - csp = datastructures.ContentSecurityPolicy( - [("font-src", "'self'"), ("media-src", "*")] - ) + csp = ds.ContentSecurityPolicy([("font-src", "'self'"), ("media-src", "*")]) assert csp.font_src == "'self'" assert csp.media_src == "*" policies = [policy.strip() for policy in csp.to_header().split(";")] @@ -965,7 +961,7 @@ def test_construct(self): assert "media-src *" in policies def test_properties(self): - csp = datastructures.ContentSecurityPolicy() + csp = ds.ContentSecurityPolicy() csp.default_src = "* 'self' quart.com" csp.img_src = "'none'" policies = [policy.strip() for policy in csp.to_header().split(";")] @@ -974,7 +970,7 @@ def test_properties(self): class TestAccept: - storage_class = datastructures.Accept + storage_class = ds.Accept def test_accept_basic(self): accept = self.storage_class( @@ -1094,7 +1090,7 @@ class TestMIMEAccept: ], ) def test_mime_accept(self, values, matches, default, expect): - accept = MIMEAccept(values) + accept = ds.MIMEAccept(values) match = accept.best_match(matches, default=default) assert match == expect @@ -1118,13 +1114,13 @@ class TestLanguageAccept: ), ) def test_best_match_fallback(self, values, matches, default, expect): - accept = LanguageAccept(values) + accept = ds.LanguageAccept(values) best = accept.best_match(matches, default=default) assert best == expect class TestFileStorage: - storage_class = datastructures.FileStorage + storage_class = ds.FileStorage def test_mimetype_always_lowercase(self): file_storage = self.storage_class(content_type="APPLICATION/JSON") @@ -1180,7 +1176,7 @@ def test_save_to_file(self, tmp_path): @pytest.mark.parametrize("ranges", ([(0, 1), (-5, None)], [(5, None)])) def test_range_to_header(ranges): - header = Range("byes", ranges).to_header() + header = ds.Range("byes", ranges).to_header() r = http.parse_range_header(header) assert r.ranges == ranges @@ -1190,4 +1186,4 @@ def test_range_to_header(ranges): ) def test_range_validates_ranges(ranges): with pytest.raises(ValueError): - datastructures.Range("bytes", ranges) + ds.Range("bytes", ranges) diff --git a/tests/test_debug.py b/tests/test_debug.py index a0cdcc892..e1ca01b63 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -287,7 +287,7 @@ def test_filename_encoding(self, tmpdir, monkeypatch): moduledir.join("bar.py").write("def foo():\n 1/0\n") monkeypatch.syspath_prepend(str(moduledir)) - import bar + import bar # type: ignore try: bar.foo() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1936c169b..12ea8b920 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -132,7 +132,7 @@ def test_retry_after_mixin(cls, value, expect): "cls", sorted( (e for e in HTTPException.__subclasses__() if e.code and e.code >= 400), - key=lambda e: e.code, + key=lambda e: e.code, # type: ignore ), ) def test_passing_response(cls): diff --git a/tests/test_http.py b/tests/test_http.py index 5f4842569..a9af04b9f 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -695,4 +695,6 @@ def test_best_match_works(self): ], ) def test_authorization_to_header(value: str) -> None: - assert http.parse_authorization_header(value).to_header() == value + parsed = http.parse_authorization_header(value) + assert parsed is not None + assert parsed.to_header() == value diff --git a/tests/test_send_file.py b/tests/test_send_file.py index 54cf1e125..367b117da 100644 --- a/tests/test_send_file.py +++ b/tests/test_send_file.py @@ -42,7 +42,7 @@ def test_last_modified(): @pytest.mark.parametrize( - "file_factory", [lambda: txt_path.open("rb"), lambda: io.BytesIO(b"test")], + "file_factory", [lambda: txt_path.open("rb"), lambda: io.BytesIO(b"test")] ) def test_object(file_factory): rv = send_file(file_factory(), environ, mimetype="text/plain", use_x_sendfile=True) @@ -65,7 +65,7 @@ def test_object_mimetype_from_name(): @pytest.mark.parametrize( - "file_factory", [lambda: txt_path.open(), lambda: io.StringIO("test")], + "file_factory", [lambda: txt_path.open(), lambda: io.StringIO("test")] ) def test_text_mode_fails(file_factory): with file_factory() as f, pytest.raises(ValueError, match="binary mode"): @@ -84,13 +84,13 @@ def test_disposition_name(as_attachment, value): def test_object_attachment_requires_name(): with pytest.raises(TypeError, match="attachment"): send_file( - io.BytesIO(b"test"), environ, mimetype="text/plain", as_attachment=True, + io.BytesIO(b"test"), environ, mimetype="text/plain", as_attachment=True ) rv = send_file( - io.BytesIO(b"test"), environ, as_attachment=True, download_name="test.txt", + io.BytesIO(b"test"), environ, as_attachment=True, download_name="test.txt" ) - assert rv.headers["Content-Disposition"] == f"attachment; filename=test.txt" + assert rv.headers["Content-Disposition"] == "attachment; filename=test.txt" rv.close() diff --git a/tests/test_serving.py b/tests/test_serving.py index 9c1886c66..e9d2fdf9d 100644 --- a/tests/test_serving.py +++ b/tests/test_serving.py @@ -82,7 +82,7 @@ def test_ssl_object(dev_server): @pytest.mark.parametrize("reloader_type", ["stat", "watchdog"]) @pytest.mark.skipif( - os.name == "nt" and "CI" in os.environ, reason="unreliable on Windows during CI", + os.name == "nt" and "CI" in os.environ, reason="unreliable on Windows during CI" ) def test_reloader_sys_path(tmp_path, dev_server, reloader_type): """This tests the general behavior of the reloader. It also tests diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 9f0cb46e6..27d449dad 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1379,7 +1379,7 @@ def test_modified_url_encoding(): class ModifiedRequest(wrappers.Request): url_charset = "euc-kr" - req = ModifiedRequest.from_values("/?foo=정상처리".encode("euc-kr")) + req = ModifiedRequest.from_values(query_string={"foo": "정상처리"}, charset="euc-kr") assert req.args["foo"] == "정상처리" @@ -1502,7 +1502,7 @@ def test_samesite(self): class TestJSONMixin: - class Request(JSONMixin, wrappers.Request): + class Request(JSONMixin, wrappers.Request): # type: ignore pass class Response(JSONMixin, wrappers.Response): # type: ignore