From 6511dcb42bf17edb2113c8074355dac29147a1bc Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Thu, 19 Jan 2023 19:51:49 -0500 Subject: [PATCH 01/27] test: reconcile main with mypy changes --- .gitignore | 1 + dev-requirements.txt | 6 ++ invoke/__init__.py | 32 ++++--- invoke/completion/complete.py | 18 +++- invoke/config.py | 148 +++++++++++++++------------- invoke/context.py | 40 ++++---- invoke/env.py | 24 +++-- invoke/exceptions.py | 43 +++++---- invoke/parser/__init__.py | 2 +- invoke/parser/parser.py | 95 +++++++++--------- invoke/py.typed | 0 invoke/runners.py | 176 +++++++++++++++++++--------------- invoke/terminals.py | 17 ++-- invoke/util.py | 31 +++--- mypy.ini | 4 + tox.ini | 2 +- 16 files changed, 361 insertions(+), 278 deletions(-) create mode 100644 invoke/py.typed create mode 100644 mypy.ini diff --git a/.gitignore b/.gitignore index 4e47c3533..2413cddbc 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ src/ htmlcov coverage.xml .cache +.mypy_cache/ diff --git a/dev-requirements.txt b/dev-requirements.txt index 7ff385a9a..3d3fb0555 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -19,3 +19,9 @@ black>=22.8,<22.9 setuptools>56 # Debuggery icecream>=2.1 +# typing +mypy>=0.942 +mypy-extensions>=0.4.3 +typed-ast>=1.4.3 +types-mock>=0.1.3 +types-PyYAML>=5.4.3 diff --git a/invoke/__init__.py b/invoke/__init__.py index acb963705..1e37d0fd6 100644 --- a/invoke/__init__.py +++ b/invoke/__init__.py @@ -1,8 +1,10 @@ -from ._version import __version_info__, __version__ # noqa -from .collection import Collection # noqa -from .config import Config # noqa -from .context import Context, MockContext # noqa -from .exceptions import ( # noqa +from typing import Any + +from invoke._version import __version_info__, __version__ # noqa +from invoke.collection import Collection # noqa +from invoke.config import Config # noqa +from invoke.context import Context, MockContext # noqa +from invoke.exceptions import ( # noqa AmbiguousEnvVar, 
AuthFailure, CollectionNotFound, @@ -19,17 +21,17 @@ WatcherError, CommandTimedOut, ) -from .executor import Executor # noqa -from .loader import FilesystemLoader # noqa -from .parser import Argument, Parser, ParserContext, ParseResult # noqa -from .program import Program # noqa -from .runners import Runner, Local, Failure, Result, Promise # noqa -from .tasks import task, call, Call, Task # noqa -from .terminals import pty_size # noqa -from .watchers import FailingResponder, Responder, StreamWatcher # noqa +from invoke.executor import Executor # noqa +from invoke.loader import FilesystemLoader # noqa +from invoke.parser import Argument, Parser, ParserContext, ParseResult # noqa +from invoke.program import Program # noqa +from invoke.runners import Runner, Local, Failure, Result, Promise # noqa +from invoke.tasks import task, call, Call, Task # noqa +from invoke.terminals import pty_size # noqa +from invoke.watchers import FailingResponder, Responder, StreamWatcher # noqa -def run(command, **kwargs): +def run(command: str, **kwargs: Any) -> Any: """ Run ``command`` in a subprocess and return a `.Result` object. @@ -48,7 +50,7 @@ def run(command, **kwargs): return Context().run(command, **kwargs) -def sudo(command, **kwargs): +def sudo(command: str, **kwargs: Any) -> Any: """ Run ``command`` in a ``sudo`` subprocess and return a `.Result` object. diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index eab996b26..a88c33bce 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -2,16 +2,28 @@ Command-line completion mechanisms, executed by the core ``--complete`` flag. 
""" +from typing import List import glob import os import re import shlex +from typing import TYPE_CHECKING from ..exceptions import Exit, ParseError from ..util import debug, task_name_sort_key +if TYPE_CHECKING: + from ..collection import Collection + from ..parser import Parser, Context -def complete(names, core, initial_context, collection, parser): + +def complete( + names: List[str], + core, + initial_context: "Context", + collection: "Collection", + parser: "Parser", +): # Strip out program name (scripts give us full command line) # TODO: this may not handle path/to/script though? invocation = re.sub(r"^({}) ".format("|".join(names)), "", core.remainder) @@ -80,7 +92,7 @@ def complete(names, core, initial_context, collection, parser): raise Exit -def print_task_names(collection): +def print_task_names(collection: Collection) -> None: for name in sorted(collection.task_names, key=task_name_sort_key): print(name) # Just stick aliases after the thing they're aliased to. Sorting isn't @@ -89,7 +101,7 @@ def print_task_names(collection): print(alias) -def print_completion_script(shell, names): +def print_completion_script(shell: str, names: List[str]) -> None: # Grab all .completion files in invoke/completion/. (These used to have no # suffix, but surprise, that's super fragile. completions = { diff --git a/invoke/config.py b/invoke/config.py index ee995160d..d1abd310e 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -2,7 +2,9 @@ import json import os import types +from os import PathLike from os.path import join, splitext, expanduser +from typing import Any, Dict, Iterator, Optional, Tuple from .env import Environment from .exceptions import UnknownFileType, UnpicklableConfigMember @@ -64,7 +66,9 @@ class DataProxy: ) @classmethod - def from_data(cls, data, root=None, keypath=tuple()): + def from_data( + cls, data: Dict[str, Any], root: Optional[str] = None, keypath=tuple() + ): """ Alternate constructor for 'baby' DataProxies used as sub-dict values. 
@@ -93,7 +97,7 @@ def from_data(cls, data, root=None, keypath=tuple()): obj._set(_keypath=keypath) return obj - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: # NOTE: due to default Python attribute-lookup semantics, "real" # attributes will always be yielded on attribute access and this method # is skipped. That behavior is good for us (it's more intuitive than @@ -113,7 +117,7 @@ def __getattr__(self, key): err += "\n\nValid real attributes: {!r}".format(attrs) raise AttributeError(err) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: str) -> None: # Turn attribute-sets into config updates anytime we don't have a real # attribute with the given name/key. has_real_attr = key in dir(self) @@ -124,12 +128,12 @@ def __setattr__(self, key, value): else: super().__setattr__(key, value) - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, Any]]: # For some reason Python is ignoring our __hasattr__ when determining # whether we support __iter__. BOO return iter(self._config) - def __eq__(self, other): + def __eq__(self, other) -> bool: # NOTE: Can't proxy __eq__ because the RHS will always be an obj of the # current class, not the proxied-to class, and that causes # NotImplemented. @@ -140,19 +144,19 @@ def __eq__(self, other): # itself just a dict. 
if isinstance(other, dict): other_val = other - return self._config == other_val + return bool(self._config == other_val) - def __len__(self): + def __len__(self) -> int: return len(self._config) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str) -> None: self._config[key] = value self._track_modification_of(key, value) - def __getitem__(self, key): + def __getitem__(self, key: str): return self._get(key) - def _get(self, key): + def _get(self, key: str) -> Any: # Short-circuit if pickling/copying mechanisms are asking if we've got # __setstate__ etc; they'll ask this w/o calling our __init__ first, so # we'd be in a RecursionError-causing catch-22 otherwise. @@ -172,7 +176,7 @@ def _get(self, key): value = DataProxy.from_data(data=value, root=root, keypath=keypath) return value - def _set(self, *args, **kwargs): + def _set(self, *args: Any, **kwargs: Any) -> None: """ Convenience workaround of default 'attrs are config keys' behavior. @@ -189,21 +193,21 @@ def _set(self, *args, **kwargs): for key, value in kwargs.items(): object.__setattr__(self, key, value) - def __repr__(self): + def __repr__(self) -> str: return "<{}: {}>".format(self.__class__.__name__, self._config) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self._config @property - def _is_leaf(self): + def _is_leaf(self) -> bool: return hasattr(self, "_root") @property - def _is_root(self): + def _is_root(self) -> bool: return hasattr(self, "_modify") - def _track_removal_of(self, key): + def _track_removal_of(self, key: str): # Grab the root object responsible for tracking removals; either the # referenced root (if we're a leaf) or ourselves (if we're not). 
# (Intermediate nodes never have anything but __getitem__ called on @@ -216,7 +220,7 @@ def _track_removal_of(self, key): if target is not None: target._remove(getattr(self, "_keypath", tuple()), key) - def _track_modification_of(self, key, value): + def _track_modification_of(self, key: str, value: str) -> None: target = None if self._is_leaf: target = self._root @@ -225,11 +229,11 @@ def _track_modification_of(self, key, value): if target is not None: target._modify(getattr(self, "_keypath", tuple()), key, value) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._config[key] self._track_removal_of(key) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: # Make sure we don't screw up true attribute deletion for the # situations that actually want it. (Uncommon, but not rare.) if name in self: @@ -237,12 +241,12 @@ def __delattr__(self, name): else: object.__delattr__(self, name) - def clear(self): + def clear(self) -> None: keys = list(self.keys()) for key in keys: del self[key] - def pop(self, *args): + def pop(self, *args: Any) -> Any: # Must test this up front before (possibly) mutating self._config key_existed = args and args[0] in self._config # We always have a _config (whether it's a real dict or a cache of @@ -259,12 +263,12 @@ def pop(self, *args): # In all cases, return the popped value. 
return ret - def popitem(self): + def popitem(self) -> Any: ret = self._config.popitem() self._track_removal_of(ret[0]) return ret - def setdefault(self, *args): + def setdefault(self, *args: Any) -> Any: # Must test up front whether the key existed beforehand key_existed = args and args[0] in self._config # Run locally @@ -279,7 +283,7 @@ def setdefault(self, *args): self._track_modification_of(key, default) return ret - def update(self, *args, **kwargs): + def update(self, *args: Any, **kwargs: Dict[str, Any]) -> None: if kwargs: for key, value in kwargs.items(): self[key] = value @@ -412,7 +416,7 @@ class Config(DataProxy): env_prefix = None @staticmethod - def global_defaults(): + def global_defaults() -> Dict[str, Any]: """ Return the core default settings for Invoke. @@ -496,13 +500,13 @@ def global_defaults(): def __init__( self, - overrides=None, - defaults=None, - system_prefix=None, - user_prefix=None, - project_location=None, - runtime_path=None, - lazy=False, + overrides: Optional[Dict[str, Any]] = None, + defaults: Optional[Dict[str, Any]] = None, + system_prefix: Optional[str] = None, + user_prefix: Optional[str] = None, + project_location: Optional[PathLike] = None, + runtime_path: Optional[PathLike] = None, + lazy: bool = False, ): """ Creates a new config object. @@ -639,12 +643,12 @@ def __init__( # a subroutine does so. self.merge() - def load_base_conf_files(self): + def load_base_conf_files(self) -> None: # Just a refactor of something done in unlazy init or in clone() self.load_system(merge=False) self.load_user(merge=False) - def load_defaults(self, data, merge=True): + def load_defaults(self, data: Dict[str, Any], merge: bool = True) -> None: """ Set or replace the 'defaults' configuration level, from ``data``. 
@@ -662,7 +666,7 @@ def load_defaults(self, data, merge=True): if merge: self.merge() - def load_overrides(self, data, merge=True): + def load_overrides(self, data: Dict[str, Any], merge: bool = True) -> None: """ Set or replace the 'overrides' configuration level, from ``data``. @@ -680,7 +684,7 @@ def load_overrides(self, data, merge=True): if merge: self.merge() - def load_system(self, merge=True): + def load_system(self, merge: bool = True) -> None: """ Load a system-level config file, if possible. @@ -697,7 +701,7 @@ def load_system(self, merge=True): """ self._load_file(prefix="system", merge=merge) - def load_user(self, merge=True): + def load_user(self, merge: bool = True) -> None: """ Load a user-level config file, if possible. @@ -714,7 +718,7 @@ def load_user(self, merge=True): """ self._load_file(prefix="user", merge=merge) - def load_project(self, merge=True): + def load_project(self, merge: bool = True) -> None: """ Load a project-level config file, if possible. @@ -736,7 +740,7 @@ def load_project(self, merge=True): """ self._load_file(prefix="project", merge=merge) - def set_runtime_path(self, path): + def set_runtime_path(self, path: Optional[PathLike]) -> None: """ Set the runtime config file path. @@ -750,7 +754,7 @@ def set_runtime_path(self, path): # if no loading has been attempted yet.) self._set(_runtime_found=None) - def load_runtime(self, merge=True): + def load_runtime(self, merge: bool = True) -> None: """ Load a runtime-level config file, if one was specified. @@ -768,7 +772,7 @@ def load_runtime(self, merge=True): """ self._load_file(prefix="runtime", absolute=True, merge=merge) - def load_shell_env(self): + def load_shell_env(self) -> None: """ Load values from the shell environment. 
@@ -793,7 +797,9 @@ def load_shell_env(self): debug("Loaded shell environment, triggering final merge") self.merge() - def load_collection(self, data, merge=True): + def load_collection( + self, data: Dict[str, Any], merge: bool = True + ) -> None: """ Update collection-driven config data. @@ -808,7 +814,7 @@ def load_collection(self, data, merge=True): if merge: self.merge() - def set_project_location(self, path): + def set_project_location(self, path: PathLike) -> None: """ Set the directory path where a project-level config file may be found. @@ -830,7 +836,9 @@ def set_project_location(self, path): # Data loaded from the per-project config file. self._set(_project={}) - def _load_file(self, prefix, absolute=False, merge=True): + def _load_file( + self, prefix: str, absolute: bool = False, merge: bool = True + ) -> None: # Setup found = "_{}_found".format(prefix) path = "_{}_path".format(prefix) @@ -891,18 +899,18 @@ def _load_file(self, prefix, absolute=False, merge=True): elif merge: self.merge() - def _load_yaml(self, path): + def _load_yaml(self, path: PathLike) -> Any: with open(path) as fd: return yaml.safe_load(fd) - def _load_yml(self, path): + def _load_yml(self, path: PathLike) -> Any: return self._load_yaml(path) - def _load_json(self, path): + def _load_json(self, path: PathLike) -> Any: with open(path) as fd: return json.load(fd) - def _load_py(self, path): + def _load_py(self, path: PathLike) -> Dict[str, Any]: data = {} for key, value in (load_source("mod", path)).items(): # Strip special members, as these are always going to be builtins @@ -920,7 +928,7 @@ def _load_py(self, path): data[key] = value return data - def merge(self): + def merge(self) -> None: """ Merge all config sources, in order. 
@@ -945,7 +953,7 @@ def merge(self): debug("Deletions: {!r}".format(self._deletions)) obliterate(self._config, self._deletions) - def _merge_file(self, name, desc): + def _merge_file(self, name: str, desc: str) -> None: # Setup desc += " config file" # yup found = getattr(self, "_{}_found".format(name)) @@ -1060,7 +1068,7 @@ def clone(self, into=None): new.merge() return new - def _clone_init_kwargs(self, into=None): + def _clone_init_kwargs(self, into=None) -> Dict[str, Any]: """ Supply kwargs suitable for initializing a new clone of this object. @@ -1087,7 +1095,7 @@ def _clone_init_kwargs(self, into=None): lazy=True, ) - def _modify(self, keypath, key, value): + def _modify(self, keypath: Tuple[str, ...], key: str, value: str) -> None: """ Update our user-modifications config level with new data. @@ -1106,9 +1114,9 @@ def _modify(self, keypath, key, value): excise(self._deletions, keypath + (key,)) # Now we can add it to the modifications structure. data = self._modifications - keypath = list(keypath) - while keypath: - subkey = keypath.pop(0) + keypath_list = list(keypath) + while keypath_list: + subkey = keypath_list.pop(0) # TODO: could use defaultdict here, but...meh? if subkey not in data: # TODO: generify this and the subsequent 3 lines... @@ -1117,7 +1125,7 @@ def _modify(self, keypath, key, value): data[key] = value self.merge() - def _remove(self, keypath, key): + def _remove(self, keypath: Tuple[str, ...], key: str) -> None: """ Like `._modify`, but for removal. """ @@ -1126,9 +1134,9 @@ def _remove(self, keypath, key): # inverse - remove from _deletions on modification. # TODO: may be sane to push this step up to callers? 
data = self._deletions - keypath = list(keypath) - while keypath: - subkey = keypath.pop(0) + keypath_list = list(keypath) + while keypath_list: + subkey = keypath_list.pop(0) if subkey in data: data = data[subkey] # If we encounter None, it means something higher up than our @@ -1153,7 +1161,9 @@ class AmbiguousMergeError(ValueError): pass -def merge_dicts(base, updates): +def merge_dicts( + base: Dict[str, Any], updates: Dict[str, Any] +) -> Dict[str, Any]: """ Recursively merge dict ``updates`` into dict ``base`` (mutating ``base``.) @@ -1210,7 +1220,7 @@ def merge_dicts(base, updates): return base -def _merge_error(orig, new_): +def _merge_error(orig: str, new_: Any) -> AmbiguousMergeError: return AmbiguousMergeError( "Can't cleanly merge {} with {}".format( _format_mismatch(orig), _format_mismatch(new_) @@ -1218,11 +1228,11 @@ def _merge_error(orig, new_): ) -def _format_mismatch(x): +def _format_mismatch(x: Any) -> str: return "{} ({!r})".format(type(x), x) -def copy_dict(source): +def copy_dict(source: Dict[str, Any]) -> Dict[str, Any]: """ Return a fresh copy of ``source`` with as little shared state as possible. @@ -1234,17 +1244,17 @@ def copy_dict(source): return merge_dicts({}, source) -def excise(dict_, keypath): +def excise(dict_: Dict[str, Any], keypath: Tuple[str, ...]) -> None: """ Remove key pointed at by ``keypath`` from nested dict ``dict_``, if exists. .. versionadded:: 1.0 """ data = dict_ - keypath = list(keypath) - leaf_key = keypath.pop() - while keypath: - key = keypath.pop(0) + keypath_list = list(keypath) + leaf_key = keypath_list.pop() + while keypath_list: + key = keypath_list.pop(0) if key not in data: # Not there, nothing to excise return @@ -1253,7 +1263,7 @@ def excise(dict_, keypath): del data[leaf_key] -def obliterate(base, deletions): +def obliterate(base: Dict[str, Any], deletions: Tuple[str, ...]) -> None: """ Remove all (nested) keys mentioned in ``deletions``, from ``base``. 
diff --git a/invoke/context.py b/invoke/context.py index 1a1607827..a06c4b7e1 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -2,6 +2,8 @@ import re from contextlib import contextmanager from itertools import cycle +from os import PathLike +from typing import Any, Iterator, List, Optional, Union from unittest.mock import Mock from .config import Config, DataProxy @@ -30,7 +32,7 @@ class Context(DataProxy): .. versionadded:: 1.0 """ - def __init__(self, config=None): + def __init__(self, config: Optional[Config] = None) -> None: """ :param config: `.Config` object to use as the base configuration. @@ -51,22 +53,22 @@ def __init__(self, config=None): #: A list of commands to run (via "&&") before the main argument to any #: `run` or `sudo` calls. Note that the primary API for manipulating #: this list is `prefix`; see its docs for details. - command_prefixes = list() + command_prefixes: List[str] = list() self._set(command_prefixes=command_prefixes) #: A list of directories to 'cd' into before running commands with #: `run` or `sudo`; intended for management via `cd`, please see its #: docs for details. - command_cwds = list() + command_cwds: List[str] = list() self._set(command_cwds=command_cwds) @property - def config(self): + def config(self) -> Any: # Allows Context to expose a .config attribute even though DataProxy # otherwise considers it a config key. return self._config @config.setter - def config(self, value): + def config(self, value: Any) -> None: # NOTE: mostly used by client libraries needing to tweak a Context's # config at execution time; i.e. a Context subclass that bears its own # unique data may want to be stood up when parameterizing/expanding a @@ -74,7 +76,7 @@ def config(self, value): # runtime. self._set(_config=value) - def run(self, command, **kwargs): + def run(self, command: str, **kwargs: Any): """ Execute a local shell command, honoring config options. 
@@ -93,11 +95,11 @@ def run(self, command, **kwargs): # NOTE: broken out of run() to allow for runner class injection in # Fabric/etc, which needs to juggle multiple runner class types (local and # remote). - def _run(self, runner, command, **kwargs): + def _run(self, runner, command: str, **kwargs: Any): command = self._prefix_commands(command) return runner.run(command, **kwargs) - def sudo(self, command, **kwargs): + def sudo(self, command: str, **kwargs: Any): """ Execute a shell command via ``sudo`` with password auto-response. @@ -170,7 +172,7 @@ def sudo(self, command, **kwargs): return self._sudo(runner, command, **kwargs) # NOTE: this is for runner injection; see NOTE above _run(). - def _sudo(self, runner, command, **kwargs): + def _sudo(self, runner, command: str, **kwargs: Any): prompt = self.config.sudo.prompt password = kwargs.pop("password", self.config.sudo.password) user = kwargs.pop("user", self.config.sudo.user) @@ -232,7 +234,7 @@ def _sudo(self, runner, command, **kwargs): # TODO: wonder if it makes sense to move this part of things inside Runner, # which would grow a `prefixes` and `cwd` init kwargs or similar. The less # that's stuffed into Context, probably the better. - def _prefix_commands(self, command): + def _prefix_commands(self, command: str) -> str: """ Prefixes ``command`` with all prefixes found in ``command_prefixes``. @@ -247,7 +249,7 @@ def _prefix_commands(self, command): return " && ".join(prefixes + [command]) @contextmanager - def prefix(self, command): + def prefix(self, command: str): """ Prefix all nested `run`/`sudo` commands with given command plus ``&&``. @@ -303,7 +305,7 @@ def prefix(self, command): self.command_prefixes.pop() @property - def cwd(self): + def cwd(self) -> Union[PathLike, str]: """ Return the current working directory, accounting for uses of `cd`. 
@@ -326,7 +328,7 @@ def cwd(self): return os.path.join(*paths) @contextmanager - def cd(self, path): + def cd(self, path: PathLike): """ Context manager that keeps directory state when executing commands. @@ -401,7 +403,7 @@ class MockContext(Context): Added ``Mock`` wrapping of ``run`` and ``sudo``. """ - def __init__(self, config=None, **kwargs): + def __init__(self, config: Optional[Config] = None, **kwargs: Any) -> None: """ Create a ``Context``-like object whose methods yield `.Result` objects. @@ -481,7 +483,7 @@ def __init__(self, config=None, **kwargs): # Wrap the method in a Mock self._set(method, Mock(wraps=getattr(self, method))) - def _normalize(self, value): + def _normalize(self, value: Any) -> Iterator[Any]: # First turn everything into an iterable if not hasattr(value, "__iter__") or isinstance(value, str): value = [value] @@ -501,7 +503,7 @@ def _normalize(self, value): # worth. Maybe in situations where Context grows a _lot_ of methods (e.g. # in Fabric 2; though Fabric could do its own sub-subclass in that case...) - def _yield_result(self, attname, command): + def _yield_result(self, attname: str, command: str): try: obj = getattr(self, attname) # Dicts need to try direct lookup or regex matching @@ -531,21 +533,21 @@ def _yield_result(self, attname, command): # raise_from(NotImplementedError(command), None) raise NotImplementedError(command) - def run(self, command, *args, **kwargs): + def run(self, command: str, *args: Any, **kwargs: Any): # TODO: perform more convenience stuff associating args/kwargs with the # result? E.g. filling in .command, etc? Possibly useful for debugging # if one hits unexpected-order problems with what they passed in to # __init__. return self._yield_result("__run", command) - def sudo(self, command, *args, **kwargs): + def sudo(self, command: str, *args: Any, **kwargs: Any): # TODO: this completely nukes the top-level behavior of sudo(), which # could be good or bad, depending. Most of the time I think it's good. 
# No need to supply dummy password config, etc. # TODO: see the TODO from run() re: injecting arg/kwarg values return self._yield_result("__sudo", command) - def set_result_for(self, attname, command, result): + def set_result_for(self, attname: str, command: str, result) -> None: """ Modify the stored mock results for given ``attname`` (e.g. ``run``). diff --git a/invoke/env.py b/invoke/env.py index f523347d2..c11377332 100644 --- a/invoke/env.py +++ b/invoke/env.py @@ -9,18 +9,22 @@ """ import os +from typing import TYPE_CHECKING, Any, Dict, List from .exceptions import UncastableEnvVar, AmbiguousEnvVar from .util import debug +if TYPE_CHECKING: + from .config import Config + class Environment: - def __init__(self, config, prefix): + def __init__(self, config: 'Config', prefix: str) -> None: self._config = config self._prefix = prefix - self.data = {} # Accumulator + self.data: Dict[str, Any] = {} # Accumulator - def load(self): + def load(self) -> Dict[str, Any]: """ Return a nested dict containing values from `os.environ`. @@ -41,7 +45,9 @@ def load(self): debug("Obtained env var config: {!r}".format(self.data)) return self.data - def _crawl(self, key_path, env_vars): + def _crawl( + self, key_path: List[str], env_vars: Dict[str, Any] + ) -> Dict[str, Any]: """ Examine config at location ``key_path`` & return potential env vars. @@ -55,7 +61,7 @@ def _crawl(self, key_path, env_vars): Returns another dictionary of new keypairs as per above. 
""" - new_vars = {} + new_vars: Dict[str, Any] = {} obj = self._path_get(key_path) # Sub-dict -> recurse if ( @@ -79,10 +85,10 @@ def _crawl(self, key_path, env_vars): new_vars[self._to_env_var(key_path)] = key_path return new_vars - def _to_env_var(self, key_path): + def _to_env_var(self, key_path: List[str]) -> str: return "_".join(key_path).upper() - def _path_get(self, key_path): + def _path_get(self, key_path: List[str]): # -> Config: # Gets are from self._config because that's what determines valid env # vars and/or values for typecasting. obj = self._config @@ -90,7 +96,7 @@ def _path_get(self, key_path): obj = obj[key] return obj - def _path_set(self, key_path, value): + def _path_set(self, key_path: List[str], value: str) -> None: # Sets are to self.data since that's what we are presenting to the # outer config object and debugging. obj = self.data @@ -102,7 +108,7 @@ def _path_set(self, key_path, value): new_ = self._cast(old, value) obj[key_path[-1]] = new_ - def _cast(self, old, new_): + def _cast(self, old: Any, new_: Any) -> Any: if isinstance(old, bool): return new_ not in ("0", "") elif isinstance(old, str): diff --git a/invoke/exceptions.py b/invoke/exceptions.py index 76f766ab1..94ac99627 100644 --- a/invoke/exceptions.py +++ b/invoke/exceptions.py @@ -6,12 +6,17 @@ condition in a way easily told apart from other, truly unexpected errors". """ -from traceback import format_exception from pprint import pformat +from traceback import format_exception +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +if TYPE_CHECKING: + from invoke.context import Context + from invoke.runners import Result class CollectionNotFound(Exception): - def __init__(self, name, start): + def __init__(self, name: str, start: str) -> None: self.name = name self.start = start @@ -41,11 +46,11 @@ class Failure(Exception): .. 
versionadded:: 1.0 """ - def __init__(self, result, reason=None): + def __init__(self, result: "Result", reason: Optional[str] = None) -> None: self.result = result self.reason = reason - def streams_for_display(self): + def streams_for_display(self) -> Tuple[str, str]: """ Return stdout/err streams as necessary for error display. @@ -75,10 +80,10 @@ def streams_for_display(self): stderr = self.result.tail("stderr") return stdout, stderr - def __repr__(self): + def __repr__(self) -> str: return self._repr() - def _repr(self, **kwargs): + def _repr(self, **kwargs: Any) -> str: """ Return ``__repr__``-like value from inner result + any kwargs. """ @@ -110,7 +115,7 @@ class UnexpectedExit(Failure): .. versionadded:: 1.0 """ - def __str__(self): + def __str__(self) -> str: stdout, stderr = self.streams_for_display() command = self.result.command exited = self.result.exited @@ -127,7 +132,7 @@ def __str__(self): """ return template.format(command, exited, stdout, stderr) - def _repr(self, **kwargs): + def _repr(self, **kwargs: Any) -> str: kwargs.setdefault("exited", self.result.exited) return super()._repr(**kwargs) @@ -137,14 +142,14 @@ class CommandTimedOut(Failure): Raised when a subprocess did not exit within a desired timeframe. """ - def __init__(self, result, timeout): + def __init__(self, result: "Result", timeout: int) -> None: super().__init__(result) self.timeout = timeout - def __repr__(self): + def __repr__(self) -> str: return self._repr(timeout=self.timeout) - def __str__(self): + def __str__(self) -> str: stdout, stderr = self.streams_for_display() command = self.result.command template = """Command did not complete within {} seconds! @@ -171,11 +176,11 @@ class AuthFailure(Failure): .. versionadded:: 1.0 """ - def __init__(self, result, prompt): + def __init__(self, result, prompt: str) -> None: self.result = result self.prompt = prompt - def __str__(self): + def __str__(self) -> str: err = "The password submitted to prompt {!r} was rejected." 
return err.format(self.prompt) @@ -189,7 +194,7 @@ class ParseError(Exception): .. versionadded:: 1.0 """ - def __init__(self, msg, context=None): + def __init__(self, msg: str, context: Optional["Context"] = None) -> None: super().__init__(msg) self.context = context @@ -215,12 +220,14 @@ class Exit(Exception): .. versionadded:: 1.0 """ - def __init__(self, message=None, code=None): + def __init__( + self, message: Optional[str] = None, code: Optional[int] = None + ) -> None: self.message = message self._code = code @property - def code(self): + def code(self) -> int: if self._code is not None: return self._code return 1 if self.message else 0 @@ -289,7 +296,7 @@ class UnpicklableConfigMember(Exception): pass -def _printable_kwargs(kwargs): +def _printable_kwargs(kwargs: Any) -> Dict[Any, Any]: """ Return print-friendly version of a thread-related ``kwargs`` dict. @@ -337,7 +344,7 @@ class ThreadException(Exception): #: Thread kwargs which appear to be very long (e.g. IO #: buffers) will be truncated when printed, to avoid huge #: unreadable error display. - exceptions = tuple() + exceptions: Tuple[str, ...] 
= tuple() def __init__(self, exceptions): self.exceptions = tuple(exceptions) diff --git a/invoke/parser/__init__.py b/invoke/parser/__init__.py index 02aa02622..b4620877e 100644 --- a/invoke/parser/__init__.py +++ b/invoke/parser/__init__.py @@ -1,5 +1,5 @@ # flake8: noqa -from .parser import * +from .parser import * # type: ignore from .context import ParserContext from .context import ParserContext as Context, to_flag, translate_underscores from .argument import Argument diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index 273fbe906..16adf62b0 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -1,22 +1,41 @@ import copy +from typing import Any, List try: - from ..vendor.lexicon import Lexicon - from ..vendor.fluidity import StateMachine, state, transition + from invoke.vendor.lexicon import Lexicon + from invoke.vendor.fluidity import StateMachine, state, transition except ImportError: - from lexicon import Lexicon - from fluidity import StateMachine, state, transition + from lexicon import Lexicon # type: ignore + from fluidity import StateMachine, state, transition # type: ignore -from ..util import debug -from ..exceptions import ParseError +# from invoke.parser import Context +from invoke.exceptions import ParseError +from invoke.util import debug # type: ignore -def is_flag(value): - return value.startswith("-") +def is_flag(value: str) -> bool: + return bool(value.startswith("-")) -def is_long_flag(value): - return value.startswith("--") +def is_long_flag(value: str) -> bool: + return bool(value.startswith("--")) + + +class ParseResult(list): + """ + List-like object with some extra parse-related attributes. + + Specifically, a ``.remainder`` attribute, which is the string found after a + ``--`` in any parsed argv list; and an ``.unparsed`` attribute, a list of + tokens that were unable to be parsed. + + .. 
versionadded:: 1.0 + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(ParseResult, self).__init__(*args, **kwargs) + self.remainder = "" + self.unparsed: List[str] = [] class Parser: @@ -40,7 +59,12 @@ class Parser: .. versionadded:: 1.0 """ - def __init__(self, contexts=(), initial=None, ignore_unknown=False): + def __init__( + self, + contexts=(), # : Tuple[Context, ...] = (), + initial=None, #: Optional[Context] = None, + ignore_unknown: bool = False, + ) -> None: self.initial = initial self.contexts = Lexicon() self.ignore_unknown = ignore_unknown @@ -57,7 +81,7 @@ def __init__(self, contexts=(), initial=None, ignore_unknown=False): raise ValueError(exists.format(alias)) self.contexts.alias(alias, to=context.name) - def parse_argv(self, argv): + def parse_argv(self, argv: List[str]) -> ParseResult: """ Parse an argv-style token list ``argv``. @@ -192,10 +216,10 @@ class ParseMachine(StateMachine): to="unknown", ) - def changing_state(self, from_, to): + def changing_state(self, from_: str, to: str) -> None: debug("ParseMachine: {!r} => {!r}".format(from_, to)) - def __init__(self, initial, contexts, ignore_unknown): + def __init__(self, initial, contexts, ignore_unknown) -> None: # Initialize self.ignore_unknown = ignore_unknown self.initial = self.context = copy.deepcopy(initial) @@ -209,7 +233,7 @@ def __init__(self, initial, contexts, ignore_unknown): super().__init__() @property - def waiting_for_flag_value(self): + def waiting_for_flag_value(self) -> bool: # Do we have a current flag, and does it expect a value (vs being a # bool/toggle)? takes_value = self.flag and self.flag.takes_value @@ -233,7 +257,7 @@ def waiting_for_flag_value(self): # Argument that can be queried, e.g. "arg.is_iterable"?) 
return not has_value - def handle(self, token): + def handle(self, token: str) -> None: debug("Handling token: {!r}".format(token)) # Handle unknown state at the top: we don't care about even # possibly-valid input if we've encountered unknown input. @@ -291,12 +315,12 @@ def handle(self, token): debug("Bottom-of-handle() see_unknown({!r})".format(token)) self.see_unknown(token) - def store_only(self, token): + def store_only(self, token: str) -> None: # Start off the unparsed list debug("Storing unknown token {!r}".format(token)) self.result.unparsed.append(token) - def complete_context(self): + def complete_context(self) -> None: debug( "Wrapping up context {!r}".format( self.context.name if self.context else self.context @@ -313,14 +337,14 @@ def complete_context(self): if self.context and self.context not in self.result: self.result.append(self.context) - def switch_to_context(self, name): + def switch_to_context(self, name: str) -> None: self.context = copy.deepcopy(self.contexts[name]) debug("Moving to context {!r}".format(name)) debug("Context args: {!r}".format(self.context.args)) debug("Context flags: {!r}".format(self.context.flags)) debug("Context inverse_flags: {!r}".format(self.context.inverse_flags)) - def complete_flag(self): + def complete_flag(self) -> None: if self.flag: msg = "Completing current flag {} before moving on" debug(msg.format(self.flag)) @@ -342,7 +366,7 @@ def complete_flag(self): # Skip casting so the bool gets preserved self.flag.set_value(True, cast=False) - def check_ambiguity(self, value): + def check_ambiguity(self, value: Any) -> bool: """ Guard against ambiguity when current flag takes an optional value. 
@@ -367,7 +391,7 @@ def check_ambiguity(self, value): msg = "{!r} is ambiguous when given after an optional-value flag" raise ParseError(msg.format(value)) - def switch_to_flag(self, flag, inverse=False): + def switch_to_flag(self, flag, inverse: bool = False) -> None: # Sanity check for ambiguity w/ prior optional-value flag self.check_ambiguity(flag) # Also tie it off, in case prior had optional value or etc. Seems to be @@ -395,42 +419,25 @@ def switch_to_flag(self, flag, inverse=False): # insufficient) self.flag_got_value = False # Handle boolean flags (which can immediately be updated) - if not self.flag.takes_value: + if self.flag and not self.flag.takes_value: val = not inverse debug("Marking seen flag {!r} as {}".format(self.flag, val)) self.flag.value = val - def see_value(self, value): + def see_value(self, value: Any) -> None: self.check_ambiguity(value) - if self.flag.takes_value: + if self.flag and self.flag.takes_value: debug("Setting flag {!r} to value {!r}".format(self.flag, value)) self.flag.value = value self.flag_got_value = True else: self.error("Flag {!r} doesn't take any value!".format(self.flag)) - def see_positional_arg(self, value): + def see_positional_arg(self, value) -> None: for arg in self.context.positional_args: if arg.value is None: arg.value = value break - def error(self, msg): + def error(self, msg: str) -> None: raise ParseError(msg, self.context) - - -class ParseResult(list): - """ - List-like object with some extra parse-related attributes. - - Specifically, a ``.remainder`` attribute, which is the string found after a - ``--`` in any parsed argv list; and an ``.unparsed`` attribute, a list of - tokens that were unable to be parsed. - - .. 
versionadded:: 1.0 - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.remainder = "" - self.unparsed = [] diff --git a/invoke/py.typed b/invoke/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/invoke/runners.py b/invoke/runners.py index adde96af7..b87e8be7e 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -2,11 +2,13 @@ import locale import os import struct -from subprocess import Popen, PIPE import sys import threading import time import signal +from subprocess import Popen, PIPE +from types import TracebackType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple # Import some platform-specific things at top level so they can be mocked for # tests. @@ -23,7 +25,7 @@ except ImportError: termios = None -from .exceptions import ( +from invoke.exceptions import ( UnexpectedExit, Failure, ThreadException, @@ -31,7 +33,7 @@ SubprocessPipeError, CommandTimedOut, ) -from .terminals import ( +from invoke.terminals import ( WINDOWS, pty_size, character_buffered, @@ -40,6 +42,9 @@ ) from .util import has_fileno, isatty, ExceptionHandlingThread +if TYPE_CHECKING: + from .context import Context + class Runner: """ @@ -55,7 +60,7 @@ class Runner: read_chunk_size = 1000 input_sleep = 0.01 - def __init__(self, context): + def __init__(self, context: "Context") -> None: """ Create a new runner with a handle on some `.Context`. @@ -95,7 +100,7 @@ def __init__(self, context): self.warned_about_pty_fallback = False #: A list of `.StreamWatcher` instances for use by `respond`. Is filled #: in at runtime by `run`. 
- self.watchers = [] + self.watchers: List[str] = [] # Optional timeout timer placeholder self._timer = None # Async flags (initialized for 'finally' referencing in case something @@ -103,7 +108,7 @@ def __init__(self, context): self._asynchronous = False self._disowned = False - def run(self, command, **kwargs): + def run(self, command: str, **kwargs: Any) -> Any: """ Execute ``command``, returning an instance of `Result` once complete. @@ -378,10 +383,10 @@ def run(self, command, **kwargs): if not (self._asynchronous or self._disowned): self.stop() - def echo(self, command): + def echo(self, command: str) -> None: print(self.opts["echo_format"].format(command=command)) - def _setup(self, command, kwargs): + def _setup(self, command: str, kwargs: Any) -> None: """ Prepare data on ``self`` so we're ready to start running. """ @@ -409,7 +414,7 @@ def _setup(self, command, kwargs): encoding=self.encoding, ) - def _run_body(self, command, **kwargs): + def _run_body(self, command: str, **kwargs: Any) -> Any: # Prepare all the bits n bobs. self._setup(command, kwargs) # If dry-run, stop here. @@ -431,7 +436,7 @@ def _run_body(self, command, **kwargs): # Wrap up or promise that we will, depending return self.make_promise() if self._asynchronous else self._finish() - def make_promise(self): + def make_promise(self) -> "Promise": """ Return a `Promise` allowing async control of the rest of lifecycle. @@ -439,7 +444,7 @@ def make_promise(self): """ return Promise(self) - def _finish(self): + def _finish(self) -> Any: # Wait for subprocess to run, forwarding signals as we get them. try: while True: @@ -499,7 +504,7 @@ def _finish(self): raise UnexpectedExit(result) return result - def _unify_kwargs_with_config(self, kwargs): + def _unify_kwargs_with_config(self, kwargs: Any) -> None: """ Unify `run` kwargs with config options to arrive at local options. 
@@ -559,7 +564,7 @@ def _unify_kwargs_with_config(self, kwargs): self.opts = opts self.streams = {"out": out_stream, "err": err_stream, "in": in_stream} - def _collate_result(self, watcher_errors): + def _collate_result(self, watcher_errors) -> Any: # At this point, we had enough success that we want to be returning or # raising detailed info about our execution; so we generate a Result. stdout = "".join(self.stdout) @@ -587,7 +592,7 @@ def _collate_result(self, watcher_errors): ) return result - def _thread_join_timeout(self, target): + def _thread_join_timeout(self, target) -> Optional[int]: # Add a timeout to out/err thread joins when it looks like they're not # dead but their counterpart is dead; this indicates issue #351 (fixed # by #432) where the subproc may hang because its stdout (or stderr) is @@ -603,7 +608,9 @@ def _thread_join_timeout(self, target): return 1 return None - def create_io_threads(self): + def create_io_threads( + self, + ) -> Tuple[Dict[Any, ExceptionHandlingThread], List[Any], List[Any]]: """ Create and return a dictionary of IO thread worker objects. @@ -642,7 +649,7 @@ def create_io_threads(self): threads[target] = t return threads, stdout, stderr - def generate_result(self, **kwargs): + def generate_result(self, **kwargs: Any) -> "Result": """ Create & return a suitable `Result` instance from the given ``kwargs``. @@ -687,7 +694,7 @@ def read_proc_output(self, reader): break yield self.decode(data) - def write_our_output(self, stream, string): + def write_our_output(self, stream, string: str) -> None: """ Write ``string`` to ``stream``. @@ -746,7 +753,7 @@ def handle_stdout(self, buffer_, hide, output): buffer_, hide, output, reader=self.read_proc_stdout ) - def handle_stderr(self, buffer_, hide, output): + def handle_stderr(self, buffer_, hide, output) -> None: """ Read process' stderr, storing into a buffer & printing/parsing. 
@@ -799,7 +806,7 @@ def read_our_stdin(self, input_): bytes_ = self.decode(bytes_) return bytes_ - def handle_stdin(self, input_, output, echo): + def handle_stdin(self, input_, output, echo) -> None: """ Read local stdin, copying into process' stdin as necessary. @@ -874,7 +881,7 @@ def should_echo_stdin(self, input_, output): """ return (not self.using_pty) and isatty(input_) - def respond(self, buffer_): + def respond(self, buffer_) -> None: """ Write to the program's stdin in response to patterns in ``buffer_``. @@ -901,7 +908,9 @@ def respond(self, buffer_): for response in watcher.submit(stream): self.write_proc_stdin(response) - def generate_env(self, env, replace_env): + def generate_env( + self, env: Dict[str, Any], replace_env: bool + ) -> Dict[str, Any]: """ Return a suitable environment dict based on user input & behavior. @@ -916,7 +925,7 @@ def generate_env(self, env, replace_env): """ return env if replace_env else dict(os.environ, **env) - def should_use_pty(self, pty, fallback): + def should_use_pty(self, pty: bool, fallback: bool) -> bool: """ Should execution attempt to use a pseudo-terminal? @@ -932,7 +941,7 @@ def should_use_pty(self, pty, fallback): return pty @property - def has_dead_threads(self): + def has_dead_threads(self) -> bool: """ Detect whether any IO threads appear to have terminated unexpectedly. @@ -948,7 +957,7 @@ def has_dead_threads(self): """ return any(x.is_dead for x in self.threads.values()) - def wait(self): + def wait(self) -> None: """ Block until the running command appears to have exited. @@ -963,7 +972,7 @@ def wait(self): break time.sleep(self.input_sleep) - def write_proc_stdin(self, data): + def write_proc_stdin(self, data: str) -> None: """ Write encoded ``data`` to the running process' stdin. @@ -977,7 +986,7 @@ def write_proc_stdin(self, data): # actual write to subprocess' stdin. 
self._write_proc_stdin(data.encode(self.encoding)) - def decode(self, data): + def decode(self, data: bytes) -> str: """ Decode some ``data`` bytes, returning Unicode. @@ -988,7 +997,7 @@ def decode(self, data): return data.decode(self.encoding, "replace") @property - def process_is_finished(self): + def process_is_finished(self) -> bool: """ Determine whether our subprocess has terminated. @@ -1004,7 +1013,7 @@ def process_is_finished(self): """ raise NotImplementedError - def start(self, command, shell, env): + def start(self, command: str, shell: str, env: Dict[str, Any]): """ Initiate execution of ``command`` (via ``shell``, with ``env``). @@ -1027,7 +1036,7 @@ def start(self, command, shell, env): """ raise NotImplementedError - def start_timer(self, timeout): + def start_timer(self, timeout: int) -> None: """ Start a timer to `kill` our subprocess after ``timeout`` seconds. """ @@ -1035,7 +1044,7 @@ def start_timer(self, timeout): self._timer = threading.Timer(timeout, self.kill) self._timer.start() - def read_proc_stdout(self, num_bytes): + def read_proc_stdout(self, num_bytes -> int) -> Union[bytes, str]: """ Read ``num_bytes`` from the running process' stdout stream. @@ -1047,7 +1056,7 @@ def read_proc_stdout(self, num_bytes): """ raise NotImplementedError - def read_proc_stderr(self, num_bytes): + def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str]: """ Read ``num_bytes`` from the running process' stderr stream. @@ -1059,7 +1068,7 @@ def read_proc_stderr(self, num_bytes): """ raise NotImplementedError - def _write_proc_stdin(self, data): + def _write_proc_stdin(self, data -> str) -> None: """ Write ``data`` to running process' stdin. @@ -1074,7 +1083,7 @@ def _write_proc_stdin(self, data): """ raise NotImplementedError - def close_proc_stdin(self): + def close_proc_stdin(self) -> None: """ Close running process' stdin. 
@@ -1084,7 +1093,7 @@ def close_proc_stdin(self): """ raise NotImplementedError - def default_encoding(self): + def default_encoding(self) -> str: """ Return a string naming the expected encoding of subprocess streams. @@ -1096,7 +1105,7 @@ def default_encoding(self): # subprocess. For now, good enough to assume both are the same. return default_encoding() - def send_interrupt(self, interrupt): + def send_interrupt(self, interrupt) -> None: """ Submit an interrupt signal to the running subprocess. @@ -1114,7 +1123,7 @@ def send_interrupt(self, interrupt): """ self.write_proc_stdin("\x03") - def returncode(self): + def returncode(self) > int: """ Return the numeric return/exit code resulting from command execution. @@ -1124,7 +1133,7 @@ def returncode(self): """ raise NotImplementedError - def stop(self): + def stop(self) -> None: """ Perform final cleanup, if necessary. @@ -1139,7 +1148,7 @@ def stop(self): if self._timer: self._timer.cancel() - def kill(self): + def kill(self) -> None: """ Forcibly terminate the subprocess. @@ -1152,7 +1161,7 @@ def kill(self): raise NotImplementedError @property - def timed_out(self): + def timed_out(self) -> bool: """ Returns ``True`` if the subprocess stopped because it timed out. @@ -1180,12 +1189,12 @@ class Local(Runner): .. versionadded:: 1.0 """ - def __init__(self, context): + def __init__(self, context: "Context") -> None: super().__init__(context) # Bookkeeping var for pty use case self.status = None - def should_use_pty(self, pty=False, fallback=True): + def should_use_pty(self, pty: bool = False, fallback: bool = True) -> bool: use_pty = False if pty: use_pty = True @@ -1198,7 +1207,7 @@ def should_use_pty(self, pty=False, fallback=True): use_pty = False return use_pty - def read_proc_stdout(self, num_bytes): + def read_proc_stdout(self, num_bytes: int): # Obtain useful read-some-bytes function if self.using_pty: # Need to handle spurious OSErrors on some Linux platforms. 
@@ -1223,12 +1232,12 @@ def read_proc_stdout(self, num_bytes): data = os.read(self.process.stdout.fileno(), num_bytes) return data - def read_proc_stderr(self, num_bytes): + def read_proc_stderr(self, num_bytes: int): # NOTE: when using a pty, this will never be called. # TODO: do we ever get those OSErrors on stderr? Feels like we could? return os.read(self.process.stderr.fileno(), num_bytes) - def _write_proc_stdin(self, data): + def _write_proc_stdin(self, data) -> int: # NOTE: parent_fd from os.fork() is a read/write pipe attached to our # forked process' stdout/stdin, respectively. fd = self.parent_fd if self.using_pty else self.process.stdin.fileno() @@ -1241,14 +1250,14 @@ def _write_proc_stdin(self, data): if "Broken pipe" not in str(e): raise - def close_proc_stdin(self): + def close_proc_stdin(self) -> None: if self.using_pty: # there is no working scenario to tell the process that stdin # closed when using pty raise SubprocessPipeError("Cannot close stdin when pty=True") self.process.stdin.close() - def start(self, command, shell, env): + def start(self, command: str, shell: str, env: Dict[str, Any]) -> None: if self.using_pty: if pty is None: # Encountered ImportError err = "You indicated pty=True, but your platform doesn't support the 'pty' module!" 
# noqa @@ -1287,12 +1296,12 @@ def start(self, command, shell, env): stdin=PIPE, ) - def kill(self): + def kill(self) -> None: pid = self.pid if self.using_pty else self.process.pid os.kill(pid, signal.SIGKILL) @property - def process_is_finished(self): + def process_is_finished(self) -> bool: if self.using_pty: # NOTE: # https://github.com/pexpect/ptyprocess/blob/4058faa05e2940662ab6da1330aa0586c6f9cd9c/ptyprocess/ptyprocess.py#L680-L687 @@ -1306,7 +1315,7 @@ def process_is_finished(self): else: return self.process.poll() is not None - def returncode(self): + def returncode(self) -> int: if self.using_pty: # No subprocess.returncode available; use WIFEXITED/WIFSIGNALED to # determine whch of WEXITSTATUS / WTERMSIG to use. @@ -1327,7 +1336,7 @@ def returncode(self): else: return self.process.returncode - def stop(self): + def stop(self) -> None: # If we opened a PTY for child communications, make sure to close() it, # otherwise long-running Invoke-using processes exhaust their file # descriptors eventually. @@ -1409,15 +1418,15 @@ class Result: # TODO: inherit from namedtuple instead? heh (or: use attrs from pypi) def __init__( self, - stdout="", - stderr="", - encoding=None, - command="", - shell="", - env=None, - exited=0, - pty=False, - hide=tuple(), + stdout: str = "", + stderr: str = "", + encoding: Optional[str] = None, + command: str = "", + shell: str = "", + env: Optional[Dict[str, Any]] = None, + exited: int = 0, + pty: bool = False, + hide: Tuple[str, ...] = tuple(), ): self.stdout = stdout self.stderr = stderr @@ -1432,7 +1441,7 @@ def __init__( self.hide = hide @property - def return_code(self): + def return_code(self) -> Any: """ An alias for ``.exited``. @@ -1440,10 +1449,17 @@ def return_code(self): """ return self.exited - def __bool__(self): + def __nonzero__(self) -> Any: + # NOTE: This is the method that (under Python 2) determines Boolean + # behavior for objects. 
return self.ok - def __str__(self): + def __bool__(self) -> Any: + # NOTE: And this is the Python 3 equivalent of __nonzero__. Much better + # name... + return self.__nonzero__() + + def __str__(self) -> str: if self.exited is not None: desc = "Command exited with status {}.".format(self.exited) else: @@ -1452,17 +1468,17 @@ def __str__(self): for x in ("stdout", "stderr"): val = getattr(self, x) ret.append( - """=== {} === + u"""=== {} === {} """.format( x, val.rstrip() ) if val - else "(no {})".format(x) + else u"(no {})".format(x) ) - return "\n".join(ret) + return u"\n".join(ret) - def __repr__(self): + def __repr__(self) -> str: # TODO: more? e.g. len of stdout/err? (how to represent cleanly in a # 'x=y' format like this? e.g. '4b' is ambiguous as to what it # represents @@ -1470,16 +1486,16 @@ def __repr__(self): return template.format(self.command, self.exited) @property - def ok(self): + def ok(self) -> bool: """ A boolean equivalent to ``exited == 0``. .. versionadded:: 1.0 """ - return self.exited == 0 + return bool(self.exited == 0) @property - def failed(self): + def failed(self) -> bool: """ The inverse of ``ok``. @@ -1490,7 +1506,7 @@ def failed(self): """ return not self.ok - def tail(self, stream, count=10): + def tail(self, stream: str, count: int = 10) -> str: """ Return the last ``count`` lines of ``stream``, plus leading whitespace. @@ -1523,7 +1539,7 @@ class Promise(Result): .. versionadded:: 1.4 """ - def __init__(self, runner): + def __init__(self, runner: "Runner") -> None: """ Create a new promise. @@ -1538,7 +1554,7 @@ def __init__(self, runner): for key, value in self.runner.result_kwargs.items(): setattr(self, key, value) - def join(self): + def join(self) -> Any: """ Block until associated subprocess exits, returning/raising the result. 
@@ -1558,16 +1574,22 @@ def join(self): try: return self.runner._finish() finally: - self.runner.stop() + self.runner._stop_everything() - def __enter__(self): + def __enter__(self) -> "Promise": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, exc_type, exc_value, exc_tb: Optional[TracebackType] + ) -> None: self.join() -def normalize_hide(val, out_stream=None, err_stream=None): +def normalize_hide( + val: Any, + out_stream: Optional[str] = None, + err_stream: Optional[str] = None, +) -> Tuple[str, ...]: # Normalize to list-of-stream-names hide_vals = (None, False, "out", "stdout", "err", "stderr", "both", True) if val not in hide_vals: @@ -1591,7 +1613,7 @@ def normalize_hide(val, out_stream=None, err_stream=None): return tuple(hide) -def default_encoding(): +def default_encoding() -> str: """ Obtain apparent interpreter-local default text encoding. diff --git a/invoke/terminals.py b/invoke/terminals.py index c793a4c2b..ae1efc81e 100644 --- a/invoke/terminals.py +++ b/invoke/terminals.py @@ -8,6 +8,7 @@ """ from contextlib import contextmanager +from typing import Any, Optional, Tuple import os import select import sys @@ -38,7 +39,7 @@ import tty -def pty_size(): +def pty_size() -> Tuple[int, int]: """ Determine current local pseudoterminal dimensions. @@ -50,10 +51,10 @@ def pty_size(): """ cols, rows = _pty_size() if not WINDOWS else _win_pty_size() # TODO: make defaults configurable? - return ((cols or 80), (rows or 24)) + return (int(cols or 80), int(rows or 24)) -def _pty_size(): +def _pty_size() -> Tuple[Optional[int], Optional[int]]: """ Suitable for most POSIX platforms. 
@@ -85,7 +86,7 @@ def _pty_size(): return size -def _win_pty_size(): +def _win_pty_size() -> Tuple[Optional[str], Optional[str]]: class CONSOLE_SCREEN_BUFFER_INFO(Structure): _fields_ = [ ("dwSize", _COORD), @@ -115,7 +116,7 @@ class CONSOLE_SCREEN_BUFFER_INFO(Structure): return (None, None) -def stdin_is_foregrounded_tty(stream): +def stdin_is_foregrounded_tty(stream) -> bool: """ Detect if given stdin ``stream`` seems to be in the foreground of a TTY. @@ -139,7 +140,7 @@ def stdin_is_foregrounded_tty(stream): return os.getpgrp() == os.tcgetpgrp(stream.fileno()) -def cbreak_already_set(stream): +def cbreak_already_set(stream) -> bool: # Explicitly not docstringed to remain private, for now. Eh. # Checks whether tty.setcbreak appears to have already been run against # ``stream`` (or if it would otherwise just not do anything). @@ -186,7 +187,7 @@ def character_buffered(stream): termios.tcsetattr(stream, termios.TCSADRAIN, old_settings) -def ready_for_reading(input_): +def ready_for_reading(input_) -> bool: """ Test ``input_`` to determine whether a read action will succeed. @@ -209,7 +210,7 @@ def ready_for_reading(input_): return bool(reads and reads[0] is input_) -def bytes_to_read(input_): +def bytes_to_read(input_) -> int: """ Query stream ``input_`` to see how many bytes may be readable. 
diff --git a/invoke/util.py b/invoke/util.py index 74043c4fc..39b726127 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -1,5 +1,8 @@ from collections import namedtuple from contextlib import contextmanager + +# from os import PathLike +from typing import Any, List, Optional, Tuple, Union import io import logging import os @@ -19,14 +22,14 @@ from .vendor.lexicon import Lexicon # noqa from .vendor import yaml # noqa except ImportError: - from lexicon import Lexicon # noqa - import yaml # noqa + from lexicon import Lexicon # type: ignore # noqa + import yaml # type: ignore # noqa LOG_FORMAT = "%(name)s.%(module)s.%(funcName)s: %(message)s" -def enable_logging(): +def enable_logging() -> None: logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT) @@ -41,7 +44,7 @@ def enable_logging(): globals()[x] = getattr(log, x) -def task_name_sort_key(name): +def task_name_sort_key(name: str) -> Tuple[List[str], str]: """ Return key tuple for use sorting dotted task names, via e.g. `sorted`. @@ -60,7 +63,7 @@ def task_name_sort_key(name): # TODO: Make part of public API sometime @contextmanager -def cd(where): +def cd(where: str): cwd = os.getcwd() os.chdir(where) try: @@ -69,7 +72,7 @@ def cd(where): os.chdir(cwd) -def has_fileno(stream): +def has_fileno(stream) -> bool: """ Cleanly determine whether ``stream`` has a useful ``.fileno()``. @@ -93,7 +96,7 @@ def has_fileno(stream): return False -def isatty(stream): +def isatty(stream) -> Union[bool, Any]: """ Cleanly determine whether ``stream`` is a TTY. @@ -126,7 +129,7 @@ def isatty(stream): return False -def helpline(obj): +def helpline(obj: object) -> Optional[str]: """ Yield an object's first docstring line, or None if there was no docstring. @@ -161,7 +164,7 @@ class ExceptionHandlingThread(threading.Thread): .. versionadded:: 1.0 """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """ Create a new exception-handling thread instance. 
@@ -177,7 +180,7 @@ def __init__(self, **kwargs): self.kwargs = kwargs self.exc_info = None - def run(self): + def run(self) -> None: try: # Allow subclasses implemented using the "override run()'s body" # approach to work, by using _run() instead of run(). If that @@ -214,7 +217,7 @@ def run(self): name = self.kwargs["target"].__name__ debug(msg.format(self.exc_info[1], name)) # noqa - def exception(self): + def exception(self) -> Optional['ExceptionWrapper']: """ If an exception occurred, return an `.ExceptionWrapper` around it. @@ -230,7 +233,7 @@ def exception(self): return ExceptionWrapper(self.kwargs, *self.exc_info) @property - def is_dead(self): + def is_dead(self) -> bool: """ Returns ``True`` if not alive and has a stored exception. @@ -243,9 +246,9 @@ def is_dead(self): # be thorough? return (not self.is_alive()) and self.exc_info is not None - def __repr__(self): + def __repr__(self) -> str: # TODO: beef this up more - return self.kwargs["target"].__name__ + return str(self.kwargs["target"].__name__) # NOTE: ExceptionWrapper defined here, not in exceptions.py, to avoid circular diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..e104bb452 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +exclude = (integration|invoke/vendor|site|tests|tasks) diff --git a/tox.ini b/tox.ini index 3739fcf81..0cdda0531 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py36, py37, py38, py39 +envlist = py36, py37, py38, py39, py310, py311 [testenv] commands = From f2c6cc9de92f1db6c004842c8d78ca61076f2540 Mon Sep 17 00:00:00 2001 From: "Jesse P.
Johnson" Date: Thu, 19 Jan 2023 20:38:21 -0500 Subject: [PATCH 02/27] test: fix tests with mypy runners --- invoke/completion/complete.py | 2 +- invoke/loader.py | 16 +++++++++------- invoke/runners.py | 35 ++++++++++++++--------------------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index a88c33bce..c61ec4a7a 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -92,7 +92,7 @@ def complete( raise Exit -def print_task_names(collection: Collection) -> None: +def print_task_names(collection: "Collection") -> None: for name in sorted(collection.task_names, key=task_name_sort_key): print(name) # Just stick aliases after the thing they're aliased to. Sorting isn't diff --git a/invoke/loader.py b/invoke/loader.py index 0748edf77..41706f44b 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -1,10 +1,12 @@ import os import sys import imp +from types import ModuleType +from typing import Any, IO, Optional, Tuple from . import Config from .exceptions import CollectionNotFound -from .util import debug +from .util import debug # type: ignore class Loader: @@ -14,7 +16,7 @@ class Loader: .. versionadded:: 1.0 """ - def __init__(self, config=None): + def __init__(self, config: Optional["Config"] = None) -> None: """ Set up a new loader with some `.Config`. @@ -27,7 +29,7 @@ def __init__(self, config=None): config = Config() self.config = config - def find(self, name): + def find(self, name: str) -> Tuple[str, str, str]: """ Implementation-specific finder method seeking collection ``name``. @@ -42,7 +44,7 @@ def find(self, name): """ raise NotImplementedError - def load(self, name=None): + def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]: """ Load and return collection module identified by ``name``. 
@@ -99,18 +101,18 @@ class FilesystemLoader(Loader): # TODO: otherwise Loader has to know about specific bits to transmit, such # as auto-dashes, and has to grow one of those for every bit Collection # ever needs to know - def __init__(self, start=None, **kwargs): + def __init__(self, start: Optional[str] = None, **kwargs: Any) -> None: super().__init__(**kwargs) if start is None: start = self.config.tasks.search_root self._start = start @property - def start(self): + def start(self) -> str: # Lazily determine default CWD if configured value is falsey return self._start or os.getcwd() - def find(self, name): + def find(self, name: str) -> Tuple[IO[Any], str, Tuple[str, str, int]]: # Accumulate all parent directories start = self.start debug("FilesystemLoader find starting at {!r}".format(start)) diff --git a/invoke/runners.py b/invoke/runners.py index b87e8be7e..9f733b80b 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -8,24 +8,24 @@ import signal from subprocess import Popen, PIPE from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union # Import some platform-specific things at top level so they can be mocked for # tests. 
try: import pty except ImportError: - pty = None + pty = None # type: ignore try: import fcntl except ImportError: - fcntl = None + fcntl = None # type: ignore try: import termios except ImportError: - termios = None + termios = None # type: ignore -from invoke.exceptions import ( +from .exceptions import ( UnexpectedExit, Failure, ThreadException, @@ -33,7 +33,7 @@ SubprocessPipeError, CommandTimedOut, ) -from invoke.terminals import ( +from .terminals import ( WINDOWS, pty_size, character_buffered, @@ -1044,7 +1044,7 @@ def start_timer(self, timeout: int) -> None: self._timer = threading.Timer(timeout, self.kill) self._timer.start() - def read_proc_stdout(self, num_bytes -> int) -> Union[bytes, str]: + def read_proc_stdout(self, num_bytes: int) -> Union[bytes, str]: """ Read ``num_bytes`` from the running process' stdout stream. @@ -1068,7 +1068,7 @@ def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str]: """ raise NotImplementedError - def _write_proc_stdin(self, data -> str) -> None: + def _write_proc_stdin(self, data: str) -> None: """ Write ``data`` to running process' stdin. @@ -1123,7 +1123,7 @@ def send_interrupt(self, interrupt) -> None: """ self.write_proc_stdin("\x03") - def returncode(self) > int: + def returncode(self) -> int: """ Return the numeric return/exit code resulting from command execution. @@ -1449,16 +1449,9 @@ def return_code(self) -> Any: """ return self.exited - def __nonzero__(self) -> Any: - # NOTE: This is the method that (under Python 2) determines Boolean - # behavior for objects. + def __bool__(self): return self.ok - def __bool__(self) -> Any: - # NOTE: And this is the Python 3 equivalent of __nonzero__. Much better - # name... 
- return self.__nonzero__() - def __str__(self) -> str: if self.exited is not None: desc = "Command exited with status {}.".format(self.exited) @@ -1468,15 +1461,15 @@ def __str__(self) -> str: for x in ("stdout", "stderr"): val = getattr(self, x) ret.append( - u"""=== {} === + """=== {} === {} """.format( x, val.rstrip() ) if val - else u"(no {})".format(x) + else "(no {})".format(x) ) - return u"\n".join(ret) + return "\n".join(ret) def __repr__(self) -> str: # TODO: more? e.g. len of stdout/err? (how to represent cleanly in a @@ -1574,7 +1567,7 @@ def join(self) -> Any: try: return self.runner._finish() finally: - self.runner._stop_everything() + self.runner.stop() def __enter__(self) -> "Promise": return self From dcaa09383dee298dd0773c505f03e380dd273778 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Thu, 19 Jan 2023 20:52:49 -0500 Subject: [PATCH 03/27] test: update mypy for env and config --- invoke/config.py | 28 +++++++++++++++++----------- invoke/env.py | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/invoke/config.py b/invoke/config.py index d1abd310e..de1d3a02f 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -2,7 +2,6 @@ import json import os import types -from os import PathLike from os.path import join, splitext, expanduser from typing import Any, Dict, Iterator, Optional, Tuple @@ -10,16 +9,18 @@ from .exceptions import UnknownFileType, UnpicklableConfigMember from .runners import Local from .terminals import WINDOWS -from .util import debug, yaml +from .util import debug, yaml # type: ignore try: from importlib.machinery import SourceFileLoader except ImportError: # PyPy3 - from importlib._bootstrap import _SourceFileLoader as SourceFileLoader + from importlib._bootstrap import ( # type: ignore + _SourceFileLoader as SourceFileLoader, + ) -def load_source(name, path): +def load_source(name: str, path: str) -> Dict[str, Any]: if not os.path.exists(path): return {} return vars(SourceFileLoader("mod", 
path).load_module()) @@ -67,7 +68,10 @@ class DataProxy: @classmethod def from_data( - cls, data: Dict[str, Any], root: Optional[str] = None, keypath=tuple() + cls, + data: Dict[str, Any], + root: Optional[str] = None, + keypath: Tuple[str, ...] = tuple(), ): """ Alternate constructor for 'baby' DataProxies used as sub-dict values. @@ -133,7 +137,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: # whether we support __iter__. BOO return iter(self._config) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: # NOTE: Can't proxy __eq__ because the RHS will always be an obj of the # current class, not the proxied-to class, and that causes # NotImplemented. @@ -910,7 +914,7 @@ def _load_json(self, path: PathLike) -> Any: with open(path) as fd: return json.load(fd) - def _load_py(self, path: PathLike) -> Dict[str, Any]: + def _load_py(self, path: str) -> Dict[str, Any]: data = {} for key, value in (load_source("mod", path)).items(): # Strip special members, as these are always going to be builtins @@ -972,7 +976,7 @@ def _merge_file(self, name: str, desc: str) -> None: # the negative? Just a branch here based on 'name'? debug("{} not found, skipping".format(desc)) - def clone(self, into=None): + def clone(self, into: Optional["Config"] = None) -> "Config": """ Return a copy of this configuration object. @@ -1025,7 +1029,7 @@ def clone(self, into=None): # instantiation" and "I want cloning to not trigger certain things like # external data source loading". 
# NOTE: this will include lazy=True, see end of method - new = klass(**self._clone_init_kwargs(into=into)) + new = klass(**self._clone_init_kwargs(into=into)) # type: ignore # Copy/merge/etc all 'private' data sources and attributes for name in """ collection @@ -1068,7 +1072,9 @@ def clone(self, into=None): new.merge() return new - def _clone_init_kwargs(self, into=None) -> Dict[str, Any]: + def _clone_init_kwargs( + self, into: Optional["Config"] = None + ) -> Dict[str, Any]: """ Supply kwargs suitable for initializing a new clone of this object. @@ -1263,7 +1269,7 @@ def excise(dict_: Dict[str, Any], keypath: Tuple[str, ...]) -> None: del data[leaf_key] -def obliterate(base: Dict[str, Any], deletions: Tuple[str, ...]) -> None: +def obliterate(base: Dict[str, Any], deletions: Dict[str, Any]) -> None: """ Remove all (nested) keys mentioned in ``deletions``, from ``base``. diff --git a/invoke/env.py b/invoke/env.py index c11377332..6e6ade7bf 100644 --- a/invoke/env.py +++ b/invoke/env.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Dict, List from .exceptions import UncastableEnvVar, AmbiguousEnvVar -from .util import debug +from .util import debug # type: ignore if TYPE_CHECKING: from .config import Config From 504525f569a968991b67725c3444f30f39941b15 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Thu, 19 Jan 2023 21:12:16 -0500 Subject: [PATCH 04/27] test: update mypy for env and config --- invoke/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invoke/config.py b/invoke/config.py index de1d3a02f..b2760f092 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -2,6 +2,7 @@ import json import os import types +from os import PathLike from os.path import join, splitext, expanduser from typing import Any, Dict, Iterator, Optional, Tuple From 6e2ec918e2b32ff71bd68efe895370158ebc6312 Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Wed, 25 Jan 2023 23:34:17 -0500 Subject: [PATCH 05/27] test: add type-hints --- invoke/__init__.py | 26 ++++---- invoke/completion/complete.py | 6 +- invoke/config.py | 10 +-- invoke/context.py | 36 ++++++++--- invoke/env.py | 4 +- invoke/exceptions.py | 18 +++--- invoke/executor.py | 33 ++++++++-- invoke/loader.py | 2 +- invoke/parser/parser.py | 29 +++++---- invoke/program.py | 117 ++++++++++++++++++++-------------- invoke/runners.py | 70 ++++++++++++++------ 11 files changed, 224 insertions(+), 127 deletions(-) diff --git a/invoke/__init__.py b/invoke/__init__.py index 1e37d0fd6..e7fa1208d 100644 --- a/invoke/__init__.py +++ b/invoke/__init__.py @@ -1,10 +1,10 @@ from typing import Any -from invoke._version import __version_info__, __version__ # noqa -from invoke.collection import Collection # noqa -from invoke.config import Config # noqa -from invoke.context import Context, MockContext # noqa -from invoke.exceptions import ( # noqa +from ._version import __version_info__, __version__ # noqa +from .collection import Collection # noqa +from .config import Config # noqa +from .context import Context, MockContext # noqa +from .exceptions import ( # noqa AmbiguousEnvVar, AuthFailure, CollectionNotFound, @@ -21,14 +21,14 @@ WatcherError, CommandTimedOut, ) -from invoke.executor import Executor # noqa -from invoke.loader import FilesystemLoader # noqa -from invoke.parser import Argument, Parser, ParserContext, ParseResult # noqa -from invoke.program import Program # noqa -from invoke.runners import Runner, Local, Failure, Result, Promise # noqa -from invoke.tasks import task, call, Call, Task # noqa -from invoke.terminals import pty_size # noqa -from invoke.watchers import FailingResponder, Responder, StreamWatcher # noqa +from .executor import Executor # noqa +from .loader import FilesystemLoader # noqa +from .parser import Argument, Parser, ParserContext, ParseResult # noqa +from .program import Program # noqa +from .runners import Runner, Local, 
Failure, Result, Promise # noqa +from .tasks import task, call, Call, Task # noqa +from .terminals import pty_size # noqa +from .watchers import FailingResponder, Responder, StreamWatcher # noqa def run(command: str, **kwargs: Any) -> Any: diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index c61ec4a7a..3fc6777fa 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -10,16 +10,16 @@ from typing import TYPE_CHECKING from ..exceptions import Exit, ParseError -from ..util import debug, task_name_sort_key +from ..util import debug, task_name_sort_key # type: ignore if TYPE_CHECKING: from ..collection import Collection - from ..parser import Parser, Context + from ..parser import Parser, ParseResult, Context def complete( names: List[str], - core, + core: "ParseResult", initial_context: "Context", collection: "Collection", parser: "Parser", diff --git a/invoke/config.py b/invoke/config.py index b2760f092..449a228c7 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -4,7 +4,7 @@ import types from os import PathLike from os.path import join, splitext, expanduser -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple, Type from .env import Environment from .exceptions import UnknownFileType, UnpicklableConfigMember @@ -71,7 +71,7 @@ class DataProxy: def from_data( cls, data: Dict[str, Any], - root: Optional[str] = None, + root: Optional['DataProxy'] = None, keypath: Tuple[str, ...] 
= tuple(), ): """ @@ -288,7 +288,7 @@ def setdefault(self, *args: Any) -> Any: self._track_modification_of(key, default) return ret - def update(self, *args: Any, **kwargs: Dict[str, Any]) -> None: + def update(self, *args: Any, **kwargs: Any) -> None: if kwargs: for key, value in kwargs.items(): self[key] = value @@ -819,7 +819,7 @@ def load_collection( if merge: self.merge() - def set_project_location(self, path: PathLike) -> None: + def set_project_location(self, path: Optional[PathLike]) -> None: """ Set the directory path where a project-level config file may be found. @@ -977,7 +977,7 @@ def _merge_file(self, name: str, desc: str) -> None: # the negative? Just a branch here based on 'name'? debug("{} not found, skipping".format(desc)) - def clone(self, into: Optional["Config"] = None) -> "Config": + def clone(self, into: Optional[Type["Config"]] = None) -> "Config": """ Return a copy of this configuration object. diff --git a/invoke/context.py b/invoke/context.py index a06c4b7e1..b05ebe3c1 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -3,7 +3,16 @@ from contextlib import contextmanager from itertools import cycle from os import PathLike -from typing import Any, Iterator, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Generator, + Iterator, + List, + Optional, + Type, + Union, +) from unittest.mock import Mock from .config import Config, DataProxy @@ -11,6 +20,9 @@ from .runners import Result from .watchers import FailingResponder +if TYPE_CHECKING: + from invoke.runners import Runner + class Context(DataProxy): """ @@ -76,7 +88,7 @@ def config(self, value: Any) -> None: # runtime. self._set(_config=value) - def run(self, command: str, **kwargs: Any): + def run(self, command: str, **kwargs: Any) -> Result: """ Execute a local shell command, honoring config options. 
@@ -95,11 +107,13 @@ def run(self, command: str, **kwargs: Any): # NOTE: broken out of run() to allow for runner class injection in # Fabric/etc, which needs to juggle multiple runner class types (local and # remote). - def _run(self, runner, command: str, **kwargs: Any): + def _run( + self, runner: Type["Runner"], command: str, **kwargs: Any + ) -> Result: command = self._prefix_commands(command) return runner.run(command, **kwargs) - def sudo(self, command: str, **kwargs: Any): + def sudo(self, command: str, **kwargs: Any) -> Result: """ Execute a shell command via ``sudo`` with password auto-response. @@ -172,7 +186,9 @@ def sudo(self, command: str, **kwargs: Any): return self._sudo(runner, command, **kwargs) # NOTE: this is for runner injection; see NOTE above _run(). - def _sudo(self, runner, command: str, **kwargs: Any): + def _sudo( + self, runner: Type["Runner"], command: str, **kwargs: Any + ) -> Result: prompt = self.config.sudo.prompt password = kwargs.pop("password", self.config.sudo.password) user = kwargs.pop("user", self.config.sudo.user) @@ -249,7 +265,7 @@ def _prefix_commands(self, command: str) -> str: return " && ".join(prefixes + [command]) @contextmanager - def prefix(self, command: str): + def prefix(self, command: str) -> Generator[None, None, None]: """ Prefix all nested `run`/`sudo` commands with given command plus ``&&``. @@ -328,7 +344,7 @@ def cwd(self) -> Union[PathLike, str]: return os.path.join(*paths) @contextmanager - def cd(self, path: PathLike): + def cd(self, path: PathLike) -> Generator[None, None, None]: """ Context manager that keeps directory state when executing commands. @@ -503,7 +519,7 @@ def _normalize(self, value: Any) -> Iterator[Any]: # worth. Maybe in situations where Context grows a _lot_ of methods (e.g. # in Fabric 2; though Fabric could do its own sub-subclass in that case...) 
- def _yield_result(self, attname: str, command: str): + def _yield_result(self, attname: str, command: str) -> Result: try: obj = getattr(self, attname) # Dicts need to try direct lookup or regex matching @@ -533,14 +549,14 @@ def _yield_result(self, attname: str, command: str): # raise_from(NotImplementedError(command), None) raise NotImplementedError(command) - def run(self, command: str, *args: Any, **kwargs: Any): + def run(self, command: str, *args: Any, **kwargs: Any) -> Result: # TODO: perform more convenience stuff associating args/kwargs with the # result? E.g. filling in .command, etc? Possibly useful for debugging # if one hits unexpected-order problems with what they passed in to # __init__. return self._yield_result("__run", command) - def sudo(self, command: str, *args: Any, **kwargs: Any): + def sudo(self, command: str, *args: Any, **kwargs: Any) -> Result: # TODO: this completely nukes the top-level behavior of sudo(), which # could be good or bad, depending. Most of the time I think it's good. # No need to supply dummy password config, etc. diff --git a/invoke/env.py b/invoke/env.py index 6e6ade7bf..f772fc6ce 100644 --- a/invoke/env.py +++ b/invoke/env.py @@ -19,7 +19,7 @@ class Environment: - def __init__(self, config: 'Config', prefix: str) -> None: + def __init__(self, config: "Config", prefix: str) -> None: self._config = config self._prefix = prefix self.data: Dict[str, Any] = {} # Accumulator @@ -88,7 +88,7 @@ def _crawl( def _to_env_var(self, key_path: List[str]) -> str: return "_".join(key_path).upper() - def _path_get(self, key_path: List[str]): # -> Config: + def _path_get(self, key_path: List[str]) -> "Config": # Gets are from self._config because that's what determines valid env # vars and/or values for typecasting. 
obj = self._config diff --git a/invoke/exceptions.py b/invoke/exceptions.py index 94ac99627..cee6dbeae 100644 --- a/invoke/exceptions.py +++ b/invoke/exceptions.py @@ -8,11 +8,11 @@ from pprint import pformat from traceback import format_exception -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple if TYPE_CHECKING: - from invoke.context import Context - from invoke.runners import Result + from .parser import ParserContext + from .runners import Result class CollectionNotFound(Exception): @@ -176,7 +176,7 @@ class AuthFailure(Failure): .. versionadded:: 1.0 """ - def __init__(self, result, prompt: str) -> None: + def __init__(self, result: "Result", prompt: str) -> None: self.result = result self.prompt = prompt @@ -194,7 +194,9 @@ class ParseError(Exception): .. versionadded:: 1.0 """ - def __init__(self, msg: str, context: Optional["Context"] = None) -> None: + def __init__( + self, msg: str, context: Optional["ParserContext"] = None + ) -> None: super().__init__(msg) self.context = context @@ -344,12 +346,12 @@ class ThreadException(Exception): #: Thread kwargs which appear to be very long (e.g. IO #: buffers) will be truncated when printed, to avoid huge #: unreadable error display. - exceptions: Tuple[str, ...] = tuple() + exceptions: Tuple[Exception, ...] 
= tuple() - def __init__(self, exceptions): + def __init__(self, exceptions: List[Exception]) -> None: self.exceptions = tuple(exceptions) - def __str__(self): + def __str__(self) -> str: details = [] for x in self.exceptions: # Build useful display diff --git a/invoke/executor.py b/invoke/executor.py index 8876e19bc..80853ad14 100644 --- a/invoke/executor.py +++ b/invoke/executor.py @@ -1,8 +1,15 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + from .config import Config from .parser import ParserContext -from .util import debug +from .util import debug # type: ignore from .tasks import Call, Task +if TYPE_CHECKING: + from .collection import Collection + from .runners import Result + from .parser import ParseResult + class Executor: """ @@ -14,7 +21,12 @@ class Executor: .. versionadded:: 1.0 """ - def __init__(self, collection, config=None, core=None): + def __init__( + self, + collection: "Collection", + config: Optional["Config"] = None, + core: Optional["ParseResult"] = None, + ) -> None: """ Initialize executor with handles to necessary data structures. @@ -34,7 +46,9 @@ def __init__(self, collection, config=None, core=None): self.config = config if config is not None else Config() self.core = core - def execute(self, *tasks): + def execute( + self, *tasks: Union[str, Tuple[str, Dict[str, Any]], ParserContext] + ) -> Dict["Task", "Result"]: """ Execute one or more ``tasks`` in sequence. @@ -132,7 +146,12 @@ def execute(self, *tasks): results[call.task] = result return results - def normalize(self, tasks): + def normalize( + self, + tasks: Tuple[ + Union[str, Tuple[str, Dict[str, Any]], ParserContext], ... + ], + ) -> List["Call"]: """ Transform arbitrary task list w/ various types, into `.Call` objects. 
@@ -142,9 +161,9 @@ def normalize(self, tasks): """ calls = [] for task in tasks: - name, kwargs = None, {} if isinstance(task, str): name = task + kwargs = {} elif isinstance(task, ParserContext): name = task.name kwargs = task.as_kwargs @@ -156,7 +175,7 @@ def normalize(self, tasks): calls = [Call(task=self.collection[self.collection.default])] return calls - def dedupe(self, calls): + def dedupe(self, calls: List["Call"]) -> List["Call"]: """ Deduplicate a list of `tasks <.Call>`. @@ -176,7 +195,7 @@ def dedupe(self, calls): debug("{!r}: found in list already, skipping".format(call)) return deduped - def expand_calls(self, calls): + def expand_calls(self, calls: List["Call"]) -> List["Call"]: """ Expand a list of `.Call` objects into a near-final list of same. diff --git a/invoke/loader.py b/invoke/loader.py index 41706f44b..356e0fe28 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -29,7 +29,7 @@ def __init__(self, config: Optional["Config"] = None) -> None: config = Config() self.config = config - def find(self, name: str) -> Tuple[str, str, str]: + def find(self, name: str) -> Tuple[IO[Any], str, Tuple[str, str, int]]: """ Implementation-specific finder method seeking collection ``name``. 
diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index 16adf62b0..85a1641e0 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -1,5 +1,5 @@ import copy -from typing import Any, List +from typing import TYPE_CHECKING, Any, List, Optional, Tuple try: from invoke.vendor.lexicon import Lexicon @@ -8,10 +8,12 @@ from lexicon import Lexicon # type: ignore from fluidity import StateMachine, state, transition # type: ignore -# from invoke.parser import Context from invoke.exceptions import ParseError from invoke.util import debug # type: ignore +if TYPE_CHECKING: + from .context import ParserContext + def is_flag(value: str) -> bool: return bool(value.startswith("-")) @@ -33,7 +35,7 @@ class ParseResult(list): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - super(ParseResult, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.remainder = "" self.unparsed: List[str] = [] @@ -61,8 +63,8 @@ class Parser: def __init__( self, - contexts=(), # : Tuple[Context, ...] = (), - initial=None, #: Optional[Context] = None, + contexts: Tuple["Lexicon", ...] = (), + initial: Optional["ParserContext"] = None, ignore_unknown: bool = False, ) -> None: self.initial = initial @@ -159,10 +161,10 @@ def parse_argv(self, argv: List[str]) -> ParseResult: debug(msg.format(token, rest)) mutations.append((index + 1, rest)) else: - rest = ["-{}".format(x) for x in rest] + _rest = ["-{}".format(x) for x in rest] msg = "Splitting multi-flag glob {!r} into {!r} and {!r}" # noqa - debug(msg.format(orig, token, rest)) - for item in reversed(rest): + debug(msg.format(orig, token, _rest)) + for item in reversed(_rest): mutations.append((index + 1, item)) # Here, we've got some possible mutations queued up, and 'token' # may have been overwritten as well. 
Whether we apply those and @@ -219,7 +221,12 @@ class ParseMachine(StateMachine): def changing_state(self, from_: str, to: str) -> None: debug("ParseMachine: {!r} => {!r}".format(from_, to)) - def __init__(self, initial, contexts, ignore_unknown) -> None: + def __init__( + self, + initial: "ParserContext", + contexts: Lexicon, + ignore_unknown: bool, + ) -> None: # Initialize self.ignore_unknown = ignore_unknown self.initial = self.context = copy.deepcopy(initial) @@ -391,7 +398,7 @@ def check_ambiguity(self, value: Any) -> bool: msg = "{!r} is ambiguous when given after an optional-value flag" raise ParseError(msg.format(value)) - def switch_to_flag(self, flag, inverse: bool = False) -> None: + def switch_to_flag(self, flag: str, inverse: bool = False) -> None: # Sanity check for ambiguity w/ prior optional-value flag self.check_ambiguity(flag) # Also tie it off, in case prior had optional value or etc. Seems to be @@ -433,7 +440,7 @@ def see_value(self, value: Any) -> None: else: self.error("Flag {!r} doesn't take any value!".format(self.flag)) - def see_positional_arg(self, value) -> None: + def see_positional_arg(self, value: Any) -> None: for arg in self.context.positional_args: if arg.value is None: arg.value = value diff --git a/invoke/program.py b/invoke/program.py index d0c1cb239..564f8deae 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -5,13 +5,20 @@ import sys import textwrap from importlib import import_module # buffalo buffalo +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from . 
import Collection, Config, Executor, FilesystemLoader from .completion.complete import complete, print_completion_script from .parser import Parser, ParserContext, Argument from .exceptions import UnexpectedExit, CollectionNotFound, ParseError, Exit from .terminals import pty_size -from .util import debug, enable_logging, helpline +from .util import debug, enable_logging, helpline # type: ignore + +if TYPE_CHECKING: + from .context import Context + from .loader import Loader + from .parser import ParseResult + from .util import Lexicon class Program: @@ -28,7 +35,9 @@ class Program: .. versionadded:: 1.0 """ - def core_args(self): + core: "ParseResult" + + def core_args(self) -> List["Argument"]: """ Return default core `.Argument` objects, as a list. @@ -132,7 +141,7 @@ def core_args(self): ), ] - def task_args(self): + def task_args(self) -> List["Argument"]: """ Return default task-related `.Argument` objects, as a list. @@ -171,15 +180,15 @@ def task_args(self): def __init__( self, - version=None, - namespace=None, - name=None, - binary=None, - loader_class=None, - executor_class=None, - config_class=None, - binary_names=None, - ): + version: Optional[str] = None, + namespace: Optional["Collection"] = None, + name: Optional[str] = None, + binary: Optional[str] = None, + loader_class: Optional["Loader"] = None, + executor_class: Optional["Executor"] = None, + config_class: Optional["Config"] = None, + binary_names: Optional[List[str]] = None, + ) -> None: """ Create a new, parameterized `.Program` instance. @@ -261,12 +270,12 @@ def __init__( # code to autogenerate it from binary_names.) 
self._binary = binary self._binary_names = binary_names - self.argv = None + self.argv: Optional[List[str]] = None self.loader_class = loader_class or FilesystemLoader self.executor_class = executor_class or Executor self.config_class = config_class or Config - def create_config(self): + def create_config(self) -> None: """ Instantiate a `.Config` (or subclass, depending) for use in task exec. @@ -279,9 +288,9 @@ def create_config(self): .. versionadded:: 1.0 """ - self.config = self.config_class() + self.config = self.config_class() # type: ignore - def update_config(self, merge=True): + def update_config(self, merge: bool = True) -> None: """ Update the previously instantiated `.Config` with parsed data. @@ -334,7 +343,9 @@ def update_config(self, merge=True): if merge: self.config.merge() - def run(self, argv=None, exit=True): + def run( + self, argv: Optional[Union[List[str], str]] = None, exit: bool = True + ) -> None: """ Execute main CLI logic, based on ``argv``. @@ -403,7 +414,7 @@ def run(self, argv=None, exit=True): except KeyboardInterrupt: sys.exit(1) # Same behavior as Python itself outside of REPL - def parse_core(self, argv): + def parse_core(self, argv: Optional[List[str]]) -> None: debug("argv given to Program.run: {!r}".format(argv)) self.normalize_argv(argv) @@ -433,7 +444,7 @@ def parse_core(self, argv): ) raise Exit - def parse_collection(self): + def parse_collection(self) -> None: """ Load a tasks collection & project-level config. @@ -469,7 +480,7 @@ def parse_collection(self): # TODO: load project conf, if possible, gracefully - def parse_cleanup(self): + def parse_cleanup(self) -> None: """ Post-parsing, pre-execution steps such as --help, --list, etc. 
@@ -531,14 +542,14 @@ def parse_cleanup(self): if not self.tasks and not self.collection.default: self.no_tasks_given() - def no_tasks_given(self): + def no_tasks_given(self) -> None: debug( "No tasks specified for execution and no default task; printing global help as fallback" # noqa ) self.print_help() raise Exit - def execute(self): + def execute(self) -> None: """ Hand off data & tasks-to-execute specification to an `.Executor`. @@ -561,10 +572,12 @@ def execute(self): # "normal" but also its own possible source of bugs/confusion... module = import_module(module_path) klass = getattr(module, class_name) - executor = klass(self.collection, self.config, self.core) + executor = klass( # type: ignore + self.collection, self.config, self.core + ) executor.execute(*self.tasks) - def normalize_argv(self, argv): + def normalize_argv(self, argv: Optional[List[str]]) -> None: """ Massages ``argv`` into a useful list of strings. @@ -588,7 +601,7 @@ def normalize_argv(self, argv): self.argv = argv @property - def name(self): + def name(self) -> str: """ Derive program's human-readable name based on `.binary`. @@ -597,7 +610,7 @@ def name(self): return self._name or self.binary.capitalize() @property - def called_as(self): + def called_as(self) -> str: """ Returns the program name we were actually called as. @@ -609,7 +622,7 @@ def called_as(self): return os.path.basename(self.argv[0]) @property - def binary(self): + def binary(self) -> str: """ Derive program's help-oriented binary name(s) from init args & argv. @@ -618,7 +631,7 @@ def binary(self): return self._binary or self.called_as @property - def binary_names(self): + def binary_names(self) -> List[str]: """ Derive program's completion-oriented binary name(s) from args & argv. @@ -628,7 +641,7 @@ def binary_names(self): # TODO 3.0: ugh rename this or core_args, they are too confusing @property - def args(self): + def args(self) -> "Lexicon": """ Obtain core program args from ``self.core`` parse result. 
@@ -637,7 +650,7 @@ def args(self): return self.core[0].args @property - def initial_context(self): + def initial_context(self) -> ParserContext: """ The initial parser context, aka core program flags. @@ -651,10 +664,10 @@ def initial_context(self): args += self.task_args() return ParserContext(args=args) - def print_version(self): + def print_version(self) -> None: print("{} {}".format(self.name, self.version or "unknown")) - def print_help(self): + def print_help(self) -> None: usage_suffix = "task1 [--task1-opts] ... taskN [--taskN-opts]" if self.namespace is not None: usage_suffix = " [--subcommand-opts] ..." @@ -666,7 +679,7 @@ def print_help(self): if self.namespace is not None: self.list_tasks() - def parse_core_args(self): + def parse_core_args(self) -> None: """ Filter out core args, leaving any tasks or their args for later. @@ -680,7 +693,7 @@ def parse_core_args(self): msg = "Core-args parse result: {!r} & unparsed: {!r}" debug(msg.format(self.core, self.core.unparsed)) - def load_collection(self): + def load_collection(self) -> None: """ Load a task collection based on parsed core args, or die trying. @@ -689,7 +702,9 @@ def load_collection(self): # NOTE: start, coll_name both fall back to configuration values within # Loader (which may, however, get them from our config.) start = self.args["search-root"].value - loader = self.loader_class(config=self.config, start=start) + loader = self.loader_class( # type: ignore + config=self.config, start=start + ) coll_name = self.args.collection.value try: module, parent = loader.load(coll_name) @@ -707,7 +722,9 @@ def load_collection(self): except CollectionNotFound as e: raise Exit("Can't find any collection named {!r}!".format(e.name)) - def _update_core_context(self, context, new_args): + def _update_core_context( + self, context: "Context", new_args: Dict[str, Any] + ) -> None: # Update core context w/ core_via_task args, if and only if the # via-task version of the arg was truly given a value. 
# TODO: push this into an Argument-aware Lexicon subclass and @@ -716,7 +733,7 @@ def _update_core_context(self, context, new_args): if arg.got_value: context.args[key]._value = arg._value - def _make_parser(self): + def _make_parser(self) -> Parser: return Parser( initial=self.initial_context, contexts=self.collection.to_contexts( @@ -724,7 +741,7 @@ def _make_parser(self): ), ) - def parse_tasks(self): + def parse_tasks(self) -> None: """ Parse leftover args, which are typically tasks & per-task args. @@ -748,7 +765,7 @@ def parse_tasks(self): self.tasks = result debug("Resulting task contexts: {!r}".format(self.tasks)) - def print_task_help(self, name): + def print_task_help(self, name: str) -> None: """ Print help for a specific task, e.g. ``inv --help ``. @@ -781,7 +798,7 @@ def print_task_help(self, name): print(self.leading_indent + "none") print("") - def list_tasks(self): + def list_tasks(self) -> None: # Short circuit if no tasks to show (Collection now implements bool) focus = self.scoped_collection if not focus: @@ -791,16 +808,20 @@ def list_tasks(self): # this a bit? 
getattr(self, "list_{}".format(self.list_format))() - def list_flat(self): + def list_flat(self) -> None: pairs = self._make_pairs(self.scoped_collection) self.display_with_columns(pairs=pairs) - def list_nested(self): + def list_nested(self) -> None: pairs = self._make_pairs(self.scoped_collection) extra = "'*' denotes collection defaults" self.display_with_columns(pairs=pairs, extra=extra) - def _make_pairs(self, coll, ancestors=None): + def _make_pairs( + self, + coll: "Collection", + ancestors: Optional[List[str]] = None, + ) -> Tuple[str, Optional[str]]: if ancestors is None: ancestors = [] pairs = [] @@ -865,7 +886,7 @@ def _make_pairs(self, coll, ancestors=None): pairs.extend(recursed_pairs) return pairs - def list_json(self): + def list_json(self) -> None: # Sanity: we can't cleanly honor the --list-depth argument without # changing the data schema or otherwise acting strangely; and it also # doesn't make a ton of sense to limit depth when the output is for a @@ -881,7 +902,7 @@ def list_json(self): data = coll.serialized() print(json.dumps(data)) - def task_list_opener(self, extra=""): + def task_list_opener(self, extra: str = "") -> str: root = self.list_root depth = self.list_depth specifier = " '{}'".format(root) if root else "" @@ -897,7 +918,9 @@ def task_list_opener(self, extra=""): text = "Subcommands" return text - def display_with_columns(self, pairs, extra=""): + def display_with_columns( + self, pairs: Tuple[str, Optional[str]], extra: str = "" + ) -> None: root = self.list_root print("{}:\n".format(self.task_list_opener(extra=extra))) self.print_columns(pairs) @@ -912,7 +935,7 @@ def display_with_columns(self, pairs, extra=""): # TODO: trim/prefix dots print("Default{} task: {}\n".format(specific, default)) - def print_columns(self, tuples): + def print_columns(self, tuples: Tuple[str, Optional[str]]) -> None: """ Print tabbed columns from (name, help) ``tuples``. 
diff --git a/invoke/runners.py b/invoke/runners.py index 9f733b80b..39d015627 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -8,7 +8,17 @@ import signal from subprocess import Popen, PIPE from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) # Import some platform-specific things at top level so they can be mocked for # tests. @@ -43,6 +53,7 @@ from .util import has_fileno, isatty, ExceptionHandlingThread if TYPE_CHECKING: + from io import StringIO, TextIOWrapper from .context import Context @@ -57,6 +68,8 @@ class Runner: .. versionadded:: 1.0 """ + opts: Dict[str, Any] + using_pty: bool read_chunk_size = 1000 input_sleep = 0.01 @@ -108,7 +121,7 @@ def __init__(self, context: "Context") -> None: self._asynchronous = False self._disowned = False - def run(self, command: str, **kwargs: Any) -> Any: + def run(self, command: str, **kwargs: Any) -> Optional["Result"]: """ Execute ``command``, returning an instance of `Result` once complete. @@ -414,7 +427,7 @@ def _setup(self, command: str, kwargs: Any) -> None: encoding=self.encoding, ) - def _run_body(self, command: str, **kwargs: Any) -> Any: + def _run_body(self, command: str, **kwargs: Any) -> Optional["Result"]: # Prepare all the bits n bobs. self._setup(command, kwargs) # If dry-run, stop here. @@ -427,7 +440,7 @@ def _run_body(self, command: str, **kwargs: Any) -> Any: # If disowned, we just stop here - no threads, no timer, no error # checking, nada. 
if self._disowned: - return + return None # Stand up & kick off IO, timer threads self.start_timer(self.opts["timeout"]) self.threads, self.stdout, self.stderr = self.create_io_threads() @@ -564,7 +577,7 @@ def _unify_kwargs_with_config(self, kwargs: Any) -> None: self.opts = opts self.streams = {"out": out_stream, "err": err_stream, "in": in_stream} - def _collate_result(self, watcher_errors) -> Any: + def _collate_result(self, watcher_errors: List[WatcherError]) -> "Result": # At this point, we had enough success that we want to be returning or # raising detailed info about our execution; so we generate a Result. stdout = "".join(self.stdout) @@ -592,7 +605,7 @@ def _collate_result(self, watcher_errors) -> Any: ) return result - def _thread_join_timeout(self, target) -> Optional[int]: + def _thread_join_timeout(self, target: Callable) -> Optional[int]: # Add a timeout to out/err thread joins when it looks like they're not # dead but their counterpart is dead; this indicates issue #351 (fixed # by #432) where the subproc may hang because its stdout (or stderr) is @@ -661,7 +674,7 @@ def generate_result(self, **kwargs: Any) -> "Result": """ return Result(**kwargs) - def read_proc_output(self, reader): + def read_proc_output(self, reader: Callable) -> Generator[str, None, None]: """ Iteratively read & decode bytes from a subprocess' out/err stream. @@ -694,7 +707,7 @@ def read_proc_output(self, reader): break yield self.decode(data) - def write_our_output(self, stream, string: str) -> None: + def write_our_output(self, stream: "TextIOWrapper", string: str) -> None: """ Write ``string`` to ``stream``. @@ -714,7 +727,13 @@ def write_our_output(self, stream, string: str) -> None: stream.write(string) stream.flush() - def _handle_output(self, buffer_, hide, output, reader): + def _handle_output( + self, + buffer_: List["StringIO"], + hide: bool, + output: "TextIOWrapper", + reader: Callable, + ) -> None: # TODO: store un-decoded/raw bytes somewhere as well... 
for data in self.read_proc_output(reader): # Echo to local stdout if necessary @@ -732,7 +751,9 @@ def _handle_output(self, buffer_, hide, output, reader): # Run our specific buffer through the autoresponder framework self.respond(buffer_) - def handle_stdout(self, buffer_, hide, output): + def handle_stdout( + self, buffer_: List["StringIO"], hide: bool, output: "TextIOWrapper" + ) -> None: """ Read process' stdout, storing into a buffer & printing/parsing. @@ -753,7 +774,9 @@ def handle_stdout(self, buffer_, hide, output): buffer_, hide, output, reader=self.read_proc_stdout ) - def handle_stderr(self, buffer_, hide, output) -> None: + def handle_stderr( + self, buffer_: List["StringIO"], hide: bool, output: "TextIOWrapper" + ) -> None: """ Read process' stderr, storing into a buffer & printing/parsing. @@ -766,7 +789,7 @@ def handle_stderr(self, buffer_, hide, output) -> None: buffer_, hide, output, reader=self.read_proc_stderr ) - def read_our_stdin(self, input_): + def read_our_stdin(self, input_: "TextIOWrapper") -> Optional[str]: """ Read & decode bytes from a local stdin stream. @@ -806,7 +829,12 @@ def read_our_stdin(self, input_): bytes_ = self.decode(bytes_) return bytes_ - def handle_stdin(self, input_, output, echo) -> None: + def handle_stdin( + self, + input_: "TextIOWrapper", + output: "TextIOWrapper", + echo: bool = False, + ) -> None: """ Read local stdin, copying into process' stdin as necessary. @@ -867,7 +895,9 @@ def handle_stdin(self, input_, output, echo) -> None: # Take a nap so we're not chewing CPU. time.sleep(self.input_sleep) - def should_echo_stdin(self, input_, output): + def should_echo_stdin( + self, input_: "StringIO", output: "TextIOWrapper" + ) -> bool: """ Determine whether data read from ``input_`` should echo to ``output``. 
@@ -881,7 +911,7 @@ def should_echo_stdin(self, input_, output): """ return (not self.using_pty) and isatty(input_) - def respond(self, buffer_) -> None: + def respond(self, buffer_: List[str]) -> None: """ Write to the program's stdin in response to patterns in ``buffer_``. @@ -1068,7 +1098,7 @@ def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str]: """ raise NotImplementedError - def _write_proc_stdin(self, data: str) -> None: + def _write_proc_stdin(self, data: bytes) -> None: """ Write ``data`` to running process' stdin. @@ -1105,7 +1135,7 @@ def default_encoding(self) -> str: # subprocess. For now, good enough to assume both are the same. return default_encoding() - def send_interrupt(self, interrupt) -> None: + def send_interrupt(self, interrupt: "KeyboardInterrupt") -> None: """ Submit an interrupt signal to the running subprocess. @@ -1237,7 +1267,7 @@ def read_proc_stderr(self, num_bytes: int): # TODO: do we ever get those OSErrors on stderr? Feels like we could? return os.read(self.process.stderr.fileno(), num_bytes) - def _write_proc_stdin(self, data) -> int: + def _write_proc_stdin(self, data: bytes) -> int: # NOTE: parent_fd from os.fork() is a read/write pipe attached to our # forked process' stdout/stdin, respectively. fd = self.parent_fd if self.using_pty else self.process.stdin.fileno() @@ -1323,7 +1353,7 @@ def returncode(self) -> int: # return whichever one of them is nondefault"? Probably not? # NOTE: doing this in an arbitrary order should be safe since only # one of the WIF* methods ought to ever return True. - code = None + code = 0 if os.WIFEXITED(self.status): code = os.WEXITSTATUS(self.status) elif os.WIFSIGNALED(self.status): @@ -1449,7 +1479,7 @@ def return_code(self) -> Any: """ return self.exited - def __bool__(self): + def __bool__(self) -> bool: return self.ok def __str__(self) -> str: From 82b2649b4e41f7114231e80a063edfcd9c95cb0d Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sat, 28 Jan 2023 08:42:52 -0500 Subject: [PATCH 06/27] test: Update dev-requirements.txt Co-authored-by: Sam Bull --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 3d3fb0555..3c8b36ea8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,7 +20,7 @@ setuptools>56 # Debuggery icecream>=2.1 # typing -mypy>=0.942 +mypy==0.991 mypy-extensions>=0.4.3 typed-ast>=1.4.3 types-mock>=0.1.3 From 72b337a695bbba256c56c33236b3303a08d41be0 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 09:40:16 -0500 Subject: [PATCH 07/27] test: implement additional typing --- .flake8 | 1 - dev-requirements.txt | 2 - invoke/collection.py | 125 ++++++++++++++++++++-------------- invoke/completion/complete.py | 10 +-- invoke/config.py | 10 +-- invoke/context.py | 28 ++++---- invoke/exceptions.py | 7 +- invoke/executor.py | 6 +- invoke/loader.py | 3 +- invoke/parser/__init__.py | 2 +- invoke/parser/argument.py | 57 +++++++++------- invoke/parser/context.py | 46 +++++++------ invoke/parser/parser.py | 2 +- invoke/program.py | 10 ++- invoke/runners.py | 8 ++- invoke/tasks.py | 120 ++++++++++++++++++++------------ invoke/terminals.py | 41 +++++++---- invoke/util.py | 26 ++++--- invoke/watchers.py | 17 +++-- mypy.ini | 4 -- 20 files changed, 317 insertions(+), 208 deletions(-) delete mode 100644 mypy.ini diff --git a/.flake8 b/.flake8 index 0fc73851b..fc6169a21 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,3 @@ exclude = invoke/vendor,sites,.git,build,dist,alt_env,appveyor ignore = E124,E125,E128,E261,E301,E302,E303,E306,W503,E731 max-line-length = 79 - diff --git a/dev-requirements.txt b/dev-requirements.txt index 3c8b36ea8..0d7c15288 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -21,7 +21,5 @@ setuptools>56 icecream>=2.1 # typing mypy==0.991 -mypy-extensions>=0.4.3 typed-ast>=1.4.3 -types-mock>=0.1.3 types-PyYAML>=5.4.3 diff --git 
a/invoke/collection.py b/invoke/collection.py index 5866ef614..835bc7e22 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -1,5 +1,6 @@ import copy -import types +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, Tuple from .util import Lexicon, helpline @@ -15,7 +16,7 @@ class Collection: .. versionadded:: 1.0 """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Create a new task collection/namespace. @@ -92,9 +93,9 @@ def __init__(self, *args, **kwargs): # Initialize self.tasks = Lexicon() self.collections = Lexicon() - self.default = None + self.default: Optional[str] = None self.name = None - self._configuration = {} + self._configuration: Dict[str, Any] = {} # Specific kwargs if applicable self.loaded_from = kwargs.pop("loaded_from", None) self.auto_dash_names = kwargs.pop("auto_dash_names", None) @@ -102,57 +103,62 @@ def __init__(self, *args, **kwargs): if self.auto_dash_names is None: self.auto_dash_names = True # Name if applicable - args = list(args) - if args and isinstance(args[0], str): - self.name = self.transform(args.pop(0)) + _args = list(args) + if _args and isinstance(args[0], str): + self.name = self.transform(_args.pop(0)) # Dispatch args/kwargs - for arg in args: + for arg in _args: self._add_object(arg) # Dispatch kwargs for name, obj in kwargs.items(): self._add_object(obj, name) - def _add_object(self, obj, name=None): + def _add_object( + self, obj: Any, name: Optional[str] = None + ) -> Callable[..., Any]: + method: Callable[..., Any] if isinstance(obj, Task): method = self.add_task - elif isinstance(obj, (Collection, types.ModuleType)): + elif isinstance(obj, (Collection, ModuleType)): method = self.add_collection else: raise TypeError("No idea how to insert {!r}!".format(type(obj))) return method(obj, name=name) - def __repr__(self): + def __repr__(self) -> str: task_names = list(self.tasks.keys()) collections = ["{}...".format(x) for x 
in self.collections.keys()] return "".format( self.name, ", ".join(sorted(task_names) + sorted(collections)) ) - def __eq__(self, other): - return ( - self.name == other.name - and self.tasks == other.tasks - and self.collections == other.collections - ) + def __eq__(self, other: object) -> bool: + if isinstance(other, Collection): + return ( + self.name == other.name + and self.tasks == other.tasks + and self.collections == other.collections + ) + return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.__bool__() - def __bool__(self): + def __bool__(self) -> bool: return bool(self.task_names) @classmethod def from_module( cls, - module, - name=None, - config=None, - loaded_from=None, - auto_dash_names=None, - ): + module: ModuleType, + name: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + loaded_from: Optional[str] = None, + auto_dash_names: Optional[bool] = None, + ) -> "Collection": """ Return a new `.Collection` created from ``module``. @@ -198,7 +204,7 @@ def from_module( """ module_name = module.__name__.split(".")[-1] - def instantiate(obj_name=None): + def instantiate(obj_name: Optional[str] = None) -> "Collection": # Explicitly given name wins over root ns name (if applicable), # which wins over actual module name. 
args = [name or obj_name or module_name] @@ -218,7 +224,9 @@ def instantiate(obj_name=None): ret = instantiate(obj_name=obj.name) ret.tasks = ret._transform_lexicon(obj.tasks) ret.collections = ret._transform_lexicon(obj.collections) - ret.default = ret.transform(obj.default) + ret.default = ( + ret.transform(obj.default) if obj.default else None + ) # Explicitly given config wins over root ns config obj_config = copy_dict(obj._configuration) if config: @@ -235,7 +243,13 @@ def instantiate(obj_name=None): collection.configure(config) return collection - def add_task(self, task, name=None, aliases=None, default=None): + def add_task( + self, + task: "Task", + name: Optional[str] = None, + aliases: Optional[Tuple[str, ...]] = None, + default: Optional[bool] = None + ) -> None: """ Add `.Task` ``task`` to this collection. @@ -275,7 +289,12 @@ def add_task(self, task, name=None, aliases=None, default=None): self._check_default_collision(name) self.default = name - def add_collection(self, coll, name=None, default=None): + def add_collection( + self, + coll: "Collection", + name: Optional[str] = None, + default: Optional[bool] = None + ) -> None: """ Add `.Collection` ``coll`` as a sub-collection of this one. @@ -294,7 +313,7 @@ def add_collection(self, coll, name=None, default=None): Added the ``default`` parameter. """ # Handle module-as-collection - if isinstance(coll, types.ModuleType): + if isinstance(coll, ModuleType): coll = Collection.from_module(coll) # Ensure we have a name, or die trying name = name or coll.name @@ -311,12 +330,12 @@ def add_collection(self, coll, name=None, default=None): self._check_default_collision(name) self.default = name - def _check_default_collision(self, name): + def _check_default_collision(self, name: str) -> None: if self.default: msg = "'{}' cannot be the default because '{}' already is!" 
raise ValueError(msg.format(name, self.default)) - def _split_path(self, path): + def _split_path(self, path: str) -> Tuple[str, str]: """ Obtain first collection + remainder, of a task path. @@ -331,7 +350,7 @@ def _split_path(self, path): rest = ".".join(parts) return coll, rest - def subcollection_from_path(self, path): + def subcollection_from_path(self, path: str) -> "Collection": """ Given a ``path`` to a subcollection, return that subcollection. @@ -343,7 +362,7 @@ def subcollection_from_path(self, path): collection = collection.collections[parts.pop(0)] return collection - def __getitem__(self, name=None): + def __getitem__(self, name: Optional[str] = None) -> Any: """ Returns task named ``name``. Honors aliases and subcollections. @@ -359,11 +378,15 @@ def __getitem__(self, name=None): """ return self.task_with_config(name)[0] - def _task_with_merged_config(self, coll, rest, ours): + def _task_with_merged_config( + self, coll: str, rest: str, ours: Dict[str, Any] + ) -> Tuple[str, Dict[str, Any]]: task, config = self.collections[coll].task_with_config(rest) return task, dict(config, **ours) - def task_with_config(self, name): + def task_with_config( + self, name: Optional[str] + ) -> Tuple[str, Dict[str, Any]]: """ Return task named ``name`` plus its configuration dict. @@ -397,14 +420,16 @@ def task_with_config(self, name): # Regular task lookup return self.tasks[name], ours - def __contains__(self, name): + def __contains__(self, name: str) -> bool: try: self[name] return True except KeyError: return False - def to_contexts(self, ignore_unknown_help=None): + def to_contexts( + self, ignore_unknown_help: Optional[bool] = None + ) -> List[ParserContext]: """ Returns all contained tasks and subtasks as a list of parser contexts. 
@@ -430,12 +455,12 @@ def to_contexts(self, ignore_unknown_help=None): ) return result - def subtask_name(self, collection_name, task_name): + def subtask_name(self, collection_name: str, task_name: str) -> str: return ".".join( [self.transform(collection_name), self.transform(task_name)] ) - def transform(self, name): + def transform(self, name: str) -> str: """ Transform ``name`` with the configured auto-dashes behavior. @@ -474,25 +499,25 @@ def transform(self, name): replaced.append(char) return "".join(replaced) - def _transform_lexicon(self, old): + def _transform_lexicon(self, old: Lexicon) -> Lexicon: """ Take a Lexicon and apply `.transform` to its keys and aliases. :returns: A new Lexicon. """ - new_ = Lexicon() + new = Lexicon() # Lexicons exhibit only their real keys in most places, so this will # only grab those, not aliases. for key, value in old.items(): # Deepcopy the value so we're not just copying a reference - new_[self.transform(key)] = copy.deepcopy(value) + new[self.transform(key)] = copy.deepcopy(value) # Also copy all aliases, which are string-to-string key mappings for key, value in old.aliases.items(): - new_.alias(from_=self.transform(key), to=self.transform(value)) - return new_ + new.alias(from_=self.transform(key), to=self.transform(value)) + return new @property - def task_names(self): + def task_names(self) -> Dict[str, Any]: """ Return all task identifiers for this collection as a one-level dict. @@ -523,7 +548,7 @@ def task_names(self): ret[self.subtask_name(coll_name, task_name)] = aliases return ret - def configuration(self, taskpath=None): + def configuration(self, taskpath: Optional[str] = None) -> Dict[str, Any]: """ Obtain merged configuration values from collection & children. 
@@ -541,7 +566,7 @@ def configuration(self, taskpath=None): return copy_dict(self._configuration) return self.task_with_config(taskpath)[1] - def configure(self, options): + def configure(self, options: Dict[str, Any]) -> None: """ (Recursively) merge ``options`` into the current `.configuration`. @@ -560,7 +585,7 @@ def configure(self, options): """ merge_dicts(self._configuration, options) - def serialized(self): + def serialized(self) -> Dict[str, Any]: """ Return an appropriate-for-serialization version of this object. diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index 3fc6777fa..d65901f57 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -14,16 +14,16 @@ if TYPE_CHECKING: from ..collection import Collection - from ..parser import Parser, ParseResult, Context + from ..parser import Parser, ParseResult, ParserContext def complete( names: List[str], core: "ParseResult", - initial_context: "Context", + initial_context: "ParserContext", collection: "Collection", parser: "Parser", -): +) -> Exit: # Strip out program name (scripts give us full command line) # TODO: this may not handle path/to/script though? invocation = re.sub(r"^({}) ".format("|".join(names)), "", core.remainder) @@ -37,13 +37,15 @@ def complete( # Gently parse invocation to obtain 'current' context. # Use last seen context in case of failure (required for # otherwise-invalid partial invocations being completed). + + # contexts: List[ParserContext, ParseResult] try: debug("Seeking context name in tokens: {!r}".format(tokens)) contexts = parser.parse_argv(tokens) except ParseError as e: msg = "Got parser error ({!r}), grabbing its last-seen context {!r}" # noqa debug(msg.format(e, e.context)) - contexts = [e.context] + contexts = [e.context] if e.context is not None else [] # Fall back to core context if no context seen. 
debug("Parsed invocation, contexts: {!r}".format(contexts)) if not contexts or not contexts[-1]: diff --git a/invoke/config.py b/invoke/config.py index 449a228c7..5301cdfb9 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -71,9 +71,9 @@ class DataProxy: def from_data( cls, data: Dict[str, Any], - root: Optional['DataProxy'] = None, + root: Optional["DataProxy"] = None, keypath: Tuple[str, ...] = tuple(), - ): + ) -> "DataProxy": """ Alternate constructor for 'baby' DataProxies used as sub-dict values. @@ -122,7 +122,7 @@ def __getattr__(self, key: str) -> Any: err += "\n\nValid real attributes: {!r}".format(attrs) raise AttributeError(err) - def __setattr__(self, key: str, value: str) -> None: + def __setattr__(self, key: str, value: Any) -> None: # Turn attribute-sets into config updates anytime we don't have a real # attribute with the given name/key. has_real_attr = key in dir(self) @@ -158,7 +158,7 @@ def __setitem__(self, key: str, value: str) -> None: self._config[key] = value self._track_modification_of(key, value) - def __getitem__(self, key: str): + def __getitem__(self, key: str) -> Any: return self._get(key) def _get(self, key: str) -> Any: @@ -212,7 +212,7 @@ def _is_leaf(self) -> bool: def _is_root(self) -> bool: return hasattr(self, "_modify") - def _track_removal_of(self, key: str): + def _track_removal_of(self, key: str) -> None: # Grab the root object responsible for tracking removals; either the # referenced root (if we're a leaf) or ourselves (if we're not). # (Intermediate nodes never have anything but __getitem__ called on diff --git a/invoke/context.py b/invoke/context.py index b05ebe3c1..69bdcbae6 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -10,7 +10,6 @@ Iterator, List, Optional, - Type, Union, ) from unittest.mock import Mock @@ -88,7 +87,7 @@ def config(self, value: Any) -> None: # runtime. 
self._set(_config=value) - def run(self, command: str, **kwargs: Any) -> Result: + def run(self, command: str, **kwargs: Any) -> Optional[Result]: """ Execute a local shell command, honoring config options. @@ -108,12 +107,12 @@ def run(self, command: str, **kwargs: Any) -> Result: # Fabric/etc, which needs to juggle multiple runner class types (local and # remote). def _run( - self, runner: Type["Runner"], command: str, **kwargs: Any - ) -> Result: + self, runner: "Runner", command: str, **kwargs: Any + ) -> Optional[Result]: command = self._prefix_commands(command) return runner.run(command, **kwargs) - def sudo(self, command: str, **kwargs: Any) -> Result: + def sudo(self, command: str, **kwargs: Any) -> Optional[Result]: """ Execute a shell command via ``sudo`` with password auto-response. @@ -187,8 +186,8 @@ def sudo(self, command: str, **kwargs: Any) -> Result: # NOTE: this is for runner injection; see NOTE above _run(). def _sudo( - self, runner: Type["Runner"], command: str, **kwargs: Any - ) -> Result: + self, runner: "Runner", command: str, **kwargs: Any + ) -> Optional[Result]: prompt = self.config.sudo.prompt password = kwargs.pop("password", self.config.sudo.password) user = kwargs.pop("user", self.config.sudo.user) @@ -215,8 +214,9 @@ def _sudo( cmd_str = "sudo -S -p '{}' {}{}{}".format( prompt, env_flags, user_flags, command ) + # FIXME pattern should be raw string prompt.encode('unicode_escape') watcher = FailingResponder( - pattern=re.escape(prompt), + pattern=re.escape(prompt), # type: ignore response="{}\n".format(password), sentinel="Sorry, try again.\n", ) @@ -321,7 +321,7 @@ def prefix(self, command: str) -> Generator[None, None, None]: self.command_prefixes.pop() @property - def cwd(self) -> Union[PathLike, str]: + def cwd(self) -> str: """ Return the current working directory, accounting for uses of `cd`. 
@@ -341,10 +341,10 @@ def cwd(self) -> Union[PathLike, str]: # TODO: see if there's a stronger "escape this path" function somewhere # we can reuse. e.g., escaping tildes or slashes in filenames. paths = [path.replace(" ", r"\ ") for path in self.command_cwds[i:]] - return os.path.join(*paths) + return str(os.path.join(*paths)) @contextmanager - def cd(self, path: PathLike) -> Generator[None, None, None]: + def cd(self, path: Union[PathLike, str]) -> Generator[None, None, None]: """ Context manager that keeps directory state when executing commands. @@ -539,7 +539,7 @@ def _yield_result(self, attname: str, command: str) -> Result: # Here, the value was either never a dict or has been extracted # from one, so we can assume it's an iterable of Result objects due # to work done by __init__. - result = next(obj) + result: Result = next(obj) # Populate Result's command string with what matched unless # explicitly given if not result.command: @@ -563,7 +563,9 @@ def sudo(self, command: str, *args: Any, **kwargs: Any) -> Result: # TODO: see the TODO from run() re: injecting arg/kwarg values return self._yield_result("__sudo", command) - def set_result_for(self, attname: str, command: str, result) -> None: + def set_result_for( + self, attname: str, command: str, result: Result + ) -> None: """ Modify the stored mock results for given ``attname`` (e.g. ``run``). diff --git a/invoke/exceptions.py b/invoke/exceptions.py index cee6dbeae..e1683c5c8 100644 --- a/invoke/exceptions.py +++ b/invoke/exceptions.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from .parser import ParserContext from .runners import Result + from .util import ExceptionWrapper class CollectionNotFound(Exception): @@ -298,7 +299,7 @@ class UnpicklableConfigMember(Exception): pass -def _printable_kwargs(kwargs: Any) -> Dict[Any, Any]: +def _printable_kwargs(kwargs: Any) -> Dict[str, Any]: """ Return print-friendly version of a thread-related ``kwargs`` dict. 
@@ -346,9 +347,9 @@ class ThreadException(Exception): #: Thread kwargs which appear to be very long (e.g. IO #: buffers) will be truncated when printed, to avoid huge #: unreadable error display. - exceptions: Tuple[Exception, ...] = tuple() + exceptions: Tuple["ExceptionWrapper", ...] = tuple() - def __init__(self, exceptions: List[Exception]) -> None: + def __init__(self, exceptions: List["ExceptionWrapper"]) -> None: self.exceptions = tuple(exceptions) def __str__(self) -> str: diff --git a/invoke/executor.py b/invoke/executor.py index 80853ad14..db7cc6f24 100644 --- a/invoke/executor.py +++ b/invoke/executor.py @@ -120,7 +120,6 @@ def execute( # moment... for call in calls: autoprint = call in direct and call.autoprint - args = call.args debug("Executing {!r}".format(call)) # Hand in reference to our config, which will preserve user # modifications across the lifetime of the session. @@ -137,7 +136,7 @@ def execute( # an appropriate one; e.g. subclasses might use extra data from # being parameterized), handing in this config for use there. context = call.make_context(config) - args = (context,) + args + args = (context,) + call.args result = call.task(*args, **call.kwargs) if autoprint: print(result) @@ -165,7 +164,8 @@ def normalize( name = task kwargs = {} elif isinstance(task, ParserContext): - name = task.name + # FIXME: task.name can be none here + name = task.name # type: ignore kwargs = task.as_kwargs else: name, kwargs = task diff --git a/invoke/loader.py b/invoke/loader.py index 356e0fe28..db6d582a6 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -74,8 +74,9 @@ def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]: parent = os.path.dirname(path) if parent not in sys.path: sys.path.insert(0, parent) + # FIXME: deprecated capability that needs replacement # Actual import - module = imp.load_module(name, fd, path, desc) + module = imp.load_module(name, fd, path, desc) # type: ignore # Return module + path. 
# TODO: is there a reason we're not simply having clients refer to # os.path.dirname(module.__file__)? diff --git a/invoke/parser/__init__.py b/invoke/parser/__init__.py index b4620877e..02aa02622 100644 --- a/invoke/parser/__init__.py +++ b/invoke/parser/__init__.py @@ -1,5 +1,5 @@ # flake8: noqa -from .parser import * # type: ignore +from .parser import * from .context import ParserContext from .context import ParserContext as Context, to_flag, translate_underscores from .argument import Argument diff --git a/invoke/parser/argument.py b/invoke/parser/argument.py index ceb199453..43603dce9 100644 --- a/invoke/parser/argument.py +++ b/invoke/parser/argument.py @@ -1,3 +1,9 @@ +from typing import Any, Iterable, Optional, Tuple + +# TODO: dynamic map kind +# T = TypeVar('T') + + class Argument: """ A command-line argument/flag. @@ -35,24 +41,27 @@ class Argument: def __init__( self, - name=None, - names=(), - kind=str, - default=None, - help=None, - positional=False, - optional=False, - incrementable=False, - attr_name=None, - ): + name: Optional[str] = None, + names: Iterable[str] = (), + kind: Any = str, + default: Optional[Any] = None, + help: Optional[str] = None, + positional: bool = False, + optional: bool = False, + incrementable: bool = False, + attr_name: Optional[str] = None, + ) -> None: if name and names: msg = "Cannot give both 'name' and 'names' arguments! Pick one." raise TypeError(msg) if not (name or names): raise TypeError("An Argument must have at least one name.") - self.names = tuple(names if names else (name,)) + if names: + self.names = tuple(names) + elif name and not names: + self.names = (name,) self.kind = kind - initial_value = None + initial_value: Optional[Any] = None # Special case: list-type args start out as empty list, not None. 
if kind is list: initial_value = [] @@ -67,7 +76,7 @@ def __init__( self.incrementable = incrementable self.attr_name = attr_name - def __repr__(self): + def __repr__(self) -> str: nicks = "" if self.nicknames: nicks = " ({})".format(", ".join(self.nicknames)) @@ -88,7 +97,7 @@ def __repr__(self): ) @property - def name(self): + def name(self) -> Optional[str]: """ The canonical attribute-friendly name for this argument. @@ -100,11 +109,11 @@ def name(self): return self.attr_name or self.names[0] @property - def nicknames(self): + def nicknames(self) -> Tuple[str, ...]: return self.names[1:] @property - def takes_value(self): + def takes_value(self) -> bool: if self.kind is bool: return False if self.incrementable: @@ -112,14 +121,15 @@ def takes_value(self): return True @property - def value(self): + def value(self) -> Any: + # TODO: should probably be optional instead return self._value if self._value is not None else self.default @value.setter - def value(self, arg): + def value(self, arg: str) -> None: self.set_value(arg, cast=True) - def set_value(self, value, cast=True): + def set_value(self, value: Any, cast: bool = True) -> None: """ Actual explicit value-setting API call. @@ -143,15 +153,16 @@ def set_value(self, value, cast=True): func = self.kind # If self.kind is a list, append instead of using cast func. if self.kind is list: - func = lambda x: self._value + [x] + func = lambda x: self.value + [x] # If incrementable, just increment. if self.incrementable: - # TODO: explode nicely if self._value was not an int to start with - func = lambda x: self._value + 1 + # TODO: explode nicely if self.value was not an int to start + # with + func = lambda x: self.value + 1 self._value = func(value) @property - def got_value(self): + def got_value(self) -> bool: """ Returns whether the argument was ever given a (non-default) value. 
diff --git a/invoke/parser/context.py b/invoke/parser/context.py index a15f48d59..583b05b3f 100644 --- a/invoke/parser/context.py +++ b/invoke/parser/context.py @@ -1,41 +1,42 @@ import itertools +from typing import Any, Dict, List, Iterable, Optional, Tuple, Union try: from ..vendor.lexicon import Lexicon except ImportError: - from lexicon import Lexicon + from lexicon import Lexicon # type: ignore from .argument import Argument -def translate_underscores(name): +def translate_underscores(name: str) -> str: return name.lstrip("_").rstrip("_").replace("_", "-") -def to_flag(name): +def to_flag(name: str) -> str: name = translate_underscores(name) if len(name) == 1: return "-" + name return "--" + name -def sort_candidate(arg): +def sort_candidate(arg: Argument) -> str: names = arg.names # TODO: is there no "split into two buckets on predicate" builtin? shorts = {x for x in names if len(x.strip("-")) == 1} longs = {x for x in names if x not in shorts} - return sorted(shorts if shorts else longs)[0] + return str(sorted(shorts if shorts else longs)[0]) -def flag_key(x): +def flag_key(arg: Argument) -> List[Union[int, str]]: """ Obtain useful key list-of-ints for sorting CLI flags. .. versionadded:: 1.0 """ # Setup - ret = [] - x = sort_candidate(x) + ret: List[Union[int, str]] = [] + x = sort_candidate(arg) # Long-style flags win over short-style ones, so the first item of # comparison is simply whether the flag is a single character long (with # non-length-1 flags coming "first" [lower number]) @@ -67,7 +68,12 @@ class ParserContext: .. versionadded:: 1.0 """ - def __init__(self, name=None, aliases=(), args=()): + def __init__( + self, + name: Optional[str] = None, + aliases: Iterable[str] = (), + args: Iterable[Argument] = (), + ) -> None: """ Create a new ``ParserContext`` named ``name``, with ``aliases``. @@ -83,15 +89,15 @@ def __init__(self, name=None, aliases=(), args=()): ``for arg in args: self.add_arg(arg)`` after initialization. 
""" self.args = Lexicon() - self.positional_args = [] + self.positional_args: List[Argument] = [] self.flags = Lexicon() - self.inverse_flags = {} # No need for Lexicon here + self.inverse_flags: Dict[str, str] = {} # No need for Lexicone self.name = name self.aliases = aliases for arg in args: self.add_arg(arg) - def __repr__(self): + def __repr__(self) -> str: aliases = "" if self.aliases: aliases = " ({})".format(", ".join(self.aliases)) @@ -99,7 +105,7 @@ def __repr__(self): args = (": {!r}".format(self.args)) if self.args else "" return "".format(name, args) - def add_arg(self, *args, **kwargs): + def add_arg(self, *args: Any, **kwargs: Any) -> None: """ Adds given ``Argument`` (or constructor args for one) to this context. @@ -149,11 +155,11 @@ def add_arg(self, *args, **kwargs): self.inverse_flags[inverse_name] = to_flag(main) @property - def missing_positional_args(self): + def missing_positional_args(self) -> List[Argument]: return [x for x in self.positional_args if x.value is None] @property - def as_kwargs(self): + def as_kwargs(self) -> Dict[str, Any]: """ This context's arguments' values keyed by their ``.name`` attribute. @@ -167,11 +173,11 @@ def as_kwargs(self): ret[arg.name] = arg.value return ret - def names_for(self, flag): + def names_for(self, flag: str) -> List[str]: # TODO: should probably be a method on Lexicon/AliasDict return list(set([flag] + self.flags.aliases_of(flag))) - def help_for(self, flag): + def help_for(self, flag: str) -> Tuple[str, str]: """ Return 2-tuple of ``(flag-spec, help-string)`` for given ``flag``. @@ -210,7 +216,7 @@ def help_for(self, flag): helpstr = arg.help or "" return namestr, helpstr - def help_tuples(self): + def help_tuples(self) -> List[Tuple[str, Optional[str]]]: """ Return sorted iterable of help tuples for all member Arguments. 
@@ -244,7 +250,7 @@ def help_tuples(self): ) ) - def flag_names(self): + def flag_names(self) -> Tuple[str, ...]: """ Similar to `help_tuples` but returns flag names only, no helpstrs. @@ -256,5 +262,5 @@ def flag_names(self): flags = sorted(self.flags.values(), key=flag_key) names = [self.names_for(to_flag(x.name)) for x in flags] # Inverse flag names sold separately - names.append(self.inverse_flags.keys()) + names.append(list(self.inverse_flags.keys())) return tuple(itertools.chain.from_iterable(names)) diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index 85a1641e0..f4fb343ba 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -108,7 +108,7 @@ def parse_argv(self, argv: List[str]) -> ParseResult: .. versionadded:: 1.0 """ machine = ParseMachine( - initial=self.initial, + initial=self.initial, # type: ignore # FIXME: should not be none contexts=self.contexts, ignore_unknown=self.ignore_unknown, ) diff --git a/invoke/program.py b/invoke/program.py index 564f8deae..8e9084dd6 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -5,7 +5,7 @@ import sys import textwrap from importlib import import_module # buffalo buffalo -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from . import Collection, Config, Executor, FilesystemLoader from .completion.complete import complete, print_completion_script @@ -343,9 +343,7 @@ def update_config(self, merge: bool = True) -> None: if merge: self.config.merge() - def run( - self, argv: Optional[Union[List[str], str]] = None, exit: bool = True - ) -> None: + def run(self, argv: Optional[List[str]] = None, exit: bool = True) -> None: """ Execute main CLI logic, based on ``argv``. @@ -619,7 +617,7 @@ def called_as(self) -> str: .. 
versionadded:: 1.2 """ - return os.path.basename(self.argv[0]) + return os.path.basename(self.argv[0]) if self.argv else 'invoke' @property def binary(self) -> str: @@ -935,7 +933,7 @@ def display_with_columns( # TODO: trim/prefix dots print("Default{} task: {}\n".format(specific, default)) - def print_columns(self, tuples: Tuple[str, Optional[str]]) -> None: + def print_columns(self, tuples: List[Tuple[str, Optional[str]]]) -> None: """ Print tabbed columns from (name, help) ``tuples``. diff --git a/invoke/runners.py b/invoke/runners.py index 39d015627..dc0220130 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -55,6 +55,7 @@ if TYPE_CHECKING: from io import StringIO, TextIOWrapper from .context import Context + from .watchers import StreamWatcher class Runner: @@ -113,7 +114,7 @@ def __init__(self, context: "Context") -> None: self.warned_about_pty_fallback = False #: A list of `.StreamWatcher` instances for use by `respond`. Is filled #: in at runtime by `run`. - self.watchers: List[str] = [] + self.watchers: List["StreamWatcher"] = [] # Optional timeout timer placeholder self._timer = None # Async flags (initialized for 'finally' referencing in case something @@ -630,9 +631,10 @@ def create_io_threads( Caller is expected to handle persisting and/or starting the wrapped threads. 
""" - stdout, stderr = [], [] + stdout: List[str] = [] + stderr: List[str] = [] # Set up IO thread parameters (format - body_func: {kwargs}) - thread_args = { + thread_args: Dict[Callable[..., Any], Any] = { self.handle_stdout: { "buffer_": stdout, "hide": "stdout" in self.opts["hide"], diff --git a/invoke/tasks.py b/invoke/tasks.py index af8ca0781..5209aa3a5 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -6,10 +6,27 @@ from copy import deepcopy import inspect import types +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Iterable, + Optional, + Set, + Tuple, + Type, + Union, +) from .context import Context from .parser import Argument, translate_underscores +if TYPE_CHECKING: + from inspect import Signature + from .config import Config + class Task: """ @@ -36,20 +53,20 @@ class Task: # except a debug shell whose frame is exactly inside this class. def __init__( self, - body, - name=None, - aliases=(), - positional=None, - optional=(), - default=False, - auto_shortflags=True, - help=None, - pre=None, - post=None, - autoprint=False, - iterable=None, - incrementable=None, - ): + body: Callable, + name: Optional[str] = None, + aliases: Tuple[str, ...] 
= (), + positional: Optional[Iterable[str]] = None, + optional: Iterable[str] = (), + default: bool = False, + auto_shortflags: bool = True, + help: Optional[Dict[str, Any]] = None, + pre: Optional[Union[List[str], str]] = None, + post: Optional[Union[List[str], str]] = None, + autoprint: bool = False, + iterable: Optional[Iterable[str]] = None, + incrementable: Optional[Iterable[str]] = None, + ) -> None: # Real callable self.body = body # Copy a bunch of special properties from the body for the benefit of @@ -77,16 +94,16 @@ def __init__( self.autoprint = autoprint @property - def name(self): + def name(self) -> str: return self._name or self.__name__ - def __repr__(self): + def __repr__(self) -> str: aliases = "" if self.aliases: aliases = " ({})".format(", ".join(self.aliases)) return "".format(self.name, aliases) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Task) or self.name != other.name: return False # Functions do not define __eq__ but func_code objects apparently do. @@ -100,13 +117,13 @@ def __eq__(self, other): except AttributeError: return False - def __hash__(self): + def __hash__(self) -> int: # Presumes name and body will never be changed. Hrm. # Potentially cleaner to just not use Tasks as hash keys, but let's do # this for now. return hash(self.name) + hash(self.body) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: # Guard against calling tasks with no context. if not isinstance(args[0], Context): err = "Task expected a Context as its first arg, got {} instead!" @@ -117,10 +134,10 @@ def __call__(self, *args, **kwargs): return result @property - def called(self): + def called(self) -> bool: return self.times_called > 0 - def argspec(self, body): + def argspec(self, body: Callable[..., Any]) -> "Signature": """ Returns a modified `inspect.Signature` based on that of ``body``. 
@@ -136,7 +153,11 @@ def argspec(self, body): returning an `inspect.Signature`. """ # Handle callable-but-not-function objects - func = body if isinstance(body, types.FunctionType) else body.__call__ + func = ( + body + if isinstance(body, types.FunctionType) + else body.__call__ # type: ignore + ) # Rebuild signature with first arg dropped, or die usefully(ish trying sig = inspect.signature(func) params = list(sig.parameters.values()) @@ -147,7 +168,9 @@ def argspec(self, body): raise TypeError("Tasks must have an initial Context argument!") return sig.replace(parameters=params[1:]) - def fill_implicit_positionals(self, positional): + def fill_implicit_positionals( + self, positional: Optional[Iterable[str]] + ) -> Iterable[str]: # If positionals is None, everything lacking a default # value will be automatically considered positional. if positional is None: @@ -158,8 +181,10 @@ def fill_implicit_positionals(self, positional): ] return positional - def arg_opts(self, name, default, taken_names): - opts = {} + def arg_opts( + self, name: str, default: str, taken_names: Set[str] + ) -> Dict[str, Any]: + opts: Dict[str, Any] = {} # Whether it's positional or not opts["positional"] = name in self.positional # Whether it is a value-optional flag @@ -205,7 +230,9 @@ def arg_opts(self, name, default, taken_names): break return opts - def get_arguments(self, ignore_unknown_help=None): + def get_arguments( + self, ignore_unknown_help: Optional[bool] = None + ) -> List[Argument]: """ Return a list of Argument objects representing this task's signature. 
@@ -225,9 +252,9 @@ def get_arguments(self, ignore_unknown_help=None): # Build arg list (arg_opts will take care of setting up shortnames, # etc) args = [] - for arg in sig.parameters.values(): + for param in sig.parameters.values(): new_arg = Argument( - **self.arg_opts(arg.name, arg.default, taken_names) + **self.arg_opts(param.name, param.default, taken_names) ) args.append(new_arg) # Update taken_names list with new argument's full name list @@ -245,7 +272,7 @@ def get_arguments(self, ignore_unknown_help=None): # Now we need to ensure positionals end up in the front of the list, in # order given in self.positionals, so that when Context consumes them, # this order is preserved. - for posarg in reversed(self.positional): + for posarg in reversed(list(self.positional)): for i, arg in enumerate(args): if arg.name == posarg: args.insert(0, args.pop(i)) @@ -253,7 +280,7 @@ def get_arguments(self, ignore_unknown_help=None): return args -def task(*args, **kwargs): +def task(*args: Any, **kwargs: Any) -> Callable[..., Any]: """ Marks wrapped callable object as a valid Invoke task. @@ -335,8 +362,8 @@ def task(*args, **kwargs): post = kwargs.pop("post", []) autoprint = kwargs.pop("autoprint", False) - def inner(obj): - obj = klass( + def inner(obj: Callable) -> Task: + _obj = klass( obj, name=name, aliases=aliases, @@ -353,7 +380,7 @@ def inner(obj): # Pass in any remaining kwargs as-is. **kwargs ) - return obj + return _obj return inner @@ -369,7 +396,12 @@ class Call: .. versionadded:: 1.0 """ - def __init__(self, task, called_as=None, args=None, kwargs=None): + def __init__( + self, task: "Task", + called_as: Optional[str] = None, + args: Optional[Tuple[str, ...]] = None, + kwargs: Optional[Dict[str, Any]] = None + ) -> None: """ Create a new `.Call` object. @@ -392,13 +424,13 @@ def __init__(self, task, called_as=None, args=None, kwargs=None): self.kwargs = kwargs or dict() # TODO: just how useful is this? 
feels like maybe overkill magic - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.task, name) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: object) -> "Call": return self.clone() - def __repr__(self): + def __repr__(self) -> str: aka = "" if self.called_as is not None and self.called_as != self.task.name: aka = " (called as: {!r})".format(self.called_as) @@ -410,7 +442,7 @@ def __repr__(self): self.kwargs, ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: # NOTE: Not comparing 'called_as'; a named call of a given Task with # same args/kwargs should be considered same as an unnamed call of the # same Task with the same args/kwargs (e.g. pre/post task specified w/o @@ -420,7 +452,7 @@ def __eq__(self, other): return False return True - def make_context(self, config): + def make_context(self, config: "Config") -> Context: """ Generate a `.Context` appropriate for this call, with given config. @@ -428,7 +460,7 @@ def make_context(self, config): """ return Context(config=config) - def clone_data(self): + def clone_data(self) -> Dict[str, Any]: """ Return keyword args suitable for cloning this call into another. @@ -441,7 +473,11 @@ def clone_data(self): kwargs=deepcopy(self.kwargs), ) - def clone(self, into=None, with_=None): + def clone( + self, + into: Optional[Type["Call"]] = None, + with_: Optional[Dict[str, Any]] = None + ) -> "Call": """ Return a standalone copy of this Call. @@ -471,7 +507,7 @@ def clone(self, into=None, with_=None): return klass(**data) -def call(task, *args, **kwargs): +def call(task: Task, *args: Any, **kwargs: Any) -> "Call": """ Describes execution of a `.Task`, typically with pre-supplied arguments. 
diff --git a/invoke/terminals.py b/invoke/terminals.py index ae1efc81e..d11a17125 100644 --- a/invoke/terminals.py +++ b/invoke/terminals.py @@ -8,7 +8,7 @@ """ from contextlib import contextmanager -from typing import Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Generator, IO, Optional, Tuple, Union import os import select import sys @@ -16,6 +16,9 @@ # TODO: move in here? They're currently platform-agnostic... from .util import has_fileno, isatty +if TYPE_CHECKING: + from io import BytesIO, StringIO, TextIOWrapper + WINDOWS = sys.platform == "win32" """ @@ -30,7 +33,13 @@ if WINDOWS: import msvcrt - from ctypes import Structure, c_ushort, windll, POINTER, byref + from ctypes import ( # type: ignore + Structure, + c_ushort, + windll, + POINTER, + byref, + ) from ctypes.wintypes import HANDLE, _COORD, _SMALL_RECT else: import fcntl @@ -38,6 +47,8 @@ import termios import tty +StreamTypes = Union["BytesIO", "StringIO", "TextIOWrapper"] + def pty_size() -> Tuple[int, int]: """ @@ -116,7 +127,7 @@ class CONSOLE_SCREEN_BUFFER_INFO(Structure): return (None, None) -def stdin_is_foregrounded_tty(stream) -> bool: +def stdin_is_foregrounded_tty(stream: StreamTypes) -> bool: """ Detect if given stdin ``stream`` seems to be in the foreground of a TTY. @@ -140,7 +151,7 @@ def stdin_is_foregrounded_tty(stream) -> bool: return os.getpgrp() == os.tcgetpgrp(stream.fileno()) -def cbreak_already_set(stream) -> bool: +def cbreak_already_set(stream: StreamTypes) -> bool: # Explicitly not docstringed to remain private, for now. Eh. # Checks whether tty.setcbreak appears to have already been run against # ``stream`` (or if it would otherwise just not do anything). 
@@ -163,7 +174,10 @@ def cbreak_already_set(stream) -> bool: @contextmanager -def character_buffered(stream): +def character_buffered( + stream: Union[int, IO[str]], + # Union[BytesIO, StringIO, TextIOWrapper], +) -> Generator[None, None, None]: """ Force local terminal ``stream`` be character, not line, buffered. @@ -173,9 +187,9 @@ def character_buffered(stream): """ if ( WINDOWS - or not isatty(stream) - or not stdin_is_foregrounded_tty(stream) - or cbreak_already_set(stream) + or not isatty(stream) # type: ignore + or not stdin_is_foregrounded_tty(stream) # type: ignore + or cbreak_already_set(stream) # type: ignore ): yield else: @@ -187,7 +201,7 @@ def character_buffered(stream): termios.tcsetattr(stream, termios.TCSADRAIN, old_settings) -def ready_for_reading(input_) -> bool: +def ready_for_reading(input_: StreamTypes) -> bool: """ Test ``input_`` to determine whether a read action will succeed. @@ -204,13 +218,16 @@ def ready_for_reading(input_) -> bool: if not has_fileno(input_): return True if WINDOWS: - return msvcrt.kbhit() + return msvcrt.kbhit() # type: ignore else: reads, _, _ = select.select([input_], [], [], 0.0) return bool(reads and reads[0] is input_) -def bytes_to_read(input_) -> int: +def bytes_to_read( + input_: StreamTypes, + # Union["BytesIO", "StringIO", "TextIOWrapper"] +) -> int: """ Query stream ``input_`` to see how many bytes may be readable. @@ -230,5 +247,5 @@ def bytes_to_read(input_) -> int: # going to work re: ioctl(). 
if not WINDOWS and isatty(input_) and has_fileno(input_): fionread = fcntl.ioctl(input_, termios.FIONREAD, " ") - return struct.unpack("h", fionread)[0] + return int(struct.unpack("h", fionread)[0]) return 1 diff --git a/invoke/util.py b/invoke/util.py index 39b726127..0c1e698d9 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -1,8 +1,8 @@ from collections import namedtuple from contextlib import contextmanager - -# from os import PathLike -from typing import Any, List, Optional, Tuple, Union +from io import BytesIO, StringIO, TextIOWrapper +from types import TracebackType +from typing import Any, Generator, List, Optional, Tuple, Type, Union import io import logging import os @@ -63,7 +63,7 @@ def task_name_sort_key(name: str) -> Tuple[List[str], str]: # TODO: Make part of public API sometime @contextmanager -def cd(where: str): +def cd(where: str) -> Generator[None, None, None]: cwd = os.getcwd() os.chdir(where) try: @@ -72,7 +72,7 @@ def cd(where: str): os.chdir(cwd) -def has_fileno(stream) -> bool: +def has_fileno(stream: Union[BytesIO, TextIOWrapper]) -> bool: """ Cleanly determine whether ``stream`` has a useful ``.fileno()``. @@ -96,7 +96,9 @@ def has_fileno(stream) -> bool: return False -def isatty(stream) -> Union[bool, Any]: +def isatty( + stream: Union[BytesIO, StringIO, TextIOWrapper] +) -> Union[bool, Any]: """ Cleanly determine whether ``stream`` is a TTY. @@ -164,6 +166,14 @@ class ExceptionHandlingThread(threading.Thread): .. versionadded:: 1.0 """ + # TODO: legacy cruft that needs to be removed + exc_info: Optional[ + Union[ + Tuple[Type[BaseException], BaseException, TracebackType], + Tuple[None, None, None], + ] + ] + def __init__(self, **kwargs: Any) -> None: """ Create a new exception-handling thread instance. 
@@ -215,9 +225,9 @@ def run(self) -> None: name = "_run" if "target" in self.kwargs: name = self.kwargs["target"].__name__ - debug(msg.format(self.exc_info[1], name)) # noqa + debug(msg.format(self.exc_info[1], name)) # type: ignore # noqa - def exception(self) -> Optional['ExceptionWrapper']: + def exception(self) -> Optional["ExceptionWrapper"]: """ If an exception occurred, return an `.ExceptionWrapper` around it. diff --git a/invoke/watchers.py b/invoke/watchers.py index b0e1ef571..6f6495bf9 100644 --- a/invoke/watchers.py +++ b/invoke/watchers.py @@ -1,5 +1,6 @@ import re import threading +from typing import Generator, Iterable, Literal from .exceptions import ResponseNotAccepted @@ -34,7 +35,7 @@ class StreamWatcher(threading.local): .. versionadded:: 1.0 """ - def submit(self, stream): + def submit(self, stream: str) -> Iterable[str]: """ Act on ``stream`` data, potentially returning responses. @@ -58,7 +59,7 @@ class Responder(StreamWatcher): .. versionadded:: 1.0 """ - def __init__(self, pattern, response): + def __init__(self, pattern: Literal["pattern"], response: str) -> None: r""" Imprint this `Responder` with necessary parameters. @@ -75,7 +76,9 @@ def __init__(self, pattern, response): self.response = response self.index = 0 - def pattern_matches(self, stream, pattern, index_attr): + def pattern_matches( + self, stream: str, pattern: str, index_attr: str + ) -> Iterable[str]: """ Generic "search for pattern in stream, using index" behavior. @@ -101,7 +104,7 @@ def pattern_matches(self, stream, pattern, index_attr): setattr(self, index_attr, index + len(new_)) return matches - def submit(self, stream): + def submit(self, stream: str) -> Generator[str, None, None]: # Iterate over findall() response in case >1 match occurred. for _ in self.pattern_matches(stream, self.pattern, "index"): yield self.response @@ -118,13 +121,15 @@ class FailingResponder(Responder): .. 
versionadded:: 1.0 """ - def __init__(self, pattern, response, sentinel): + def __init__( + self, pattern: Literal["pattern"], response: str, sentinel: str + ) -> None: super().__init__(pattern, response) self.sentinel = sentinel self.failure_index = 0 self.tried = False - def submit(self, stream): + def submit(self, stream: str) -> Generator[str, None, None]: # Behave like regular Responder initially response = super().submit(stream) # Also check stream for our failure sentinel diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index e104bb452..000000000 --- a/mypy.ini +++ /dev/null @@ -1,4 +0,0 @@ -[mypy] -warn_return_any = True -warn_unused_configs = True -exclude = (integration|invoke/vendor|site|tests|tasks) From 8e0d612afe073c70cf468e7859a3ac5ce517d347 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 17:22:54 -0500 Subject: [PATCH 08/27] test: completed mypy integration --- invoke/completion/complete.py | 5 ++- invoke/config.py | 6 ++- invoke/exceptions.py | 4 +- invoke/parser/parser.py | 4 +- invoke/program.py | 15 +++++--- invoke/runners.py | 69 +++++++++++++++++++++-------------- invoke/tasks.py | 2 +- invoke/terminals.py | 29 +++++---------- invoke/util.py | 9 ++--- 9 files changed, 78 insertions(+), 65 deletions(-) diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index d65901f57..3f9b413e7 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -7,7 +7,7 @@ import os import re import shlex -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from ..exceptions import Exit, ParseError from ..util import debug, task_name_sort_key # type: ignore @@ -38,7 +38,8 @@ def complete( # Use last seen context in case of failure (required for # otherwise-invalid partial invocations being completed). 
- # contexts: List[ParserContext, ParseResult] + # FIXME: this seems wonky + contexts: Union[List[ParserContext], ParseResult] try: debug("Seeking context name in tokens: {!r}".format(tokens)) contexts = parser.parse_argv(tokens) diff --git a/invoke/config.py b/invoke/config.py index 5301cdfb9..2cd3fda69 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -4,7 +4,7 @@ import types from os import PathLike from os.path import join, splitext, expanduser -from typing import Any, Dict, Iterator, Optional, Tuple, Type +from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union from .env import Environment from .exceptions import UnknownFileType, UnpicklableConfigMember @@ -819,7 +819,9 @@ def load_collection( if merge: self.merge() - def set_project_location(self, path: Optional[PathLike]) -> None: + def set_project_location( + self, path: Optional[Union[PathLike, str]] + ) -> None: """ Set the directory path where a project-level config file may be found. diff --git a/invoke/exceptions.py b/invoke/exceptions.py index e1683c5c8..19ca563bc 100644 --- a/invoke/exceptions.py +++ b/invoke/exceptions.py @@ -47,7 +47,9 @@ class Failure(Exception): .. versionadded:: 1.0 """ - def __init__(self, result: "Result", reason: Optional[str] = None) -> None: + def __init__( + self, result: "Result", reason: Optional["WatcherError"] = None + ) -> None: self.result = result self.reason = reason diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index f4fb343ba..dd732f027 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Iterable, List, Optional try: from invoke.vendor.lexicon import Lexicon @@ -63,7 +63,7 @@ class Parser: def __init__( self, - contexts: Tuple["Lexicon", ...] 
= (), + contexts: Iterable["ParserContext"] = (), initial: Optional["ParserContext"] = None, ignore_unknown: bool = False, ) -> None: diff --git a/invoke/program.py b/invoke/program.py index 8e9084dd6..f9b4a6d5f 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -170,6 +170,7 @@ def task_args(self) -> List["Argument"]: ), ] + argv: List[str] # Other class-level global variables a subclass might override sometime # maybe? leading_indent_width = 2 @@ -270,7 +271,7 @@ def __init__( # code to autogenerate it from binary_names.) self._binary = binary self._binary_names = binary_names - self.argv: Optional[List[str]] = None + self.argv = [] self.loader_class = loader_class or FilesystemLoader self.executor_class = executor_class or Executor self.config_class = config_class or Config @@ -471,8 +472,8 @@ def parse_collection(self) -> None: # Set these up for potential use later when listing tasks # TODO: be nice if these came from the config...! Users would love to # say they default to nested for example. Easy 2.x feature-add. 
- self.list_root = None - self.list_depth = None + self.list_root: Optional[str] = None + self.list_depth: Optional[int] = None self.list_format = "flat" self.scoped_collection = self.collection @@ -506,6 +507,10 @@ def parse_cleanup(self) -> None: # Print discovered tasks if necessary list_root = self.args.list.value # will be True or string + # print('list_root', type(list_root), self.args.list.value) + # print('args', self.args) + # print('args.list', self.args.list) + # print('args.list.value', self.args.list.value) self.list_format = self.args["list-format"].value self.list_depth = self.args["list-depth"].value if list_root: @@ -819,7 +824,7 @@ def _make_pairs( self, coll: "Collection", ancestors: Optional[List[str]] = None, - ) -> Tuple[str, Optional[str]]: + ) -> List[Tuple[str, Optional[str]]]: if ancestors is None: ancestors = [] pairs = [] @@ -917,7 +922,7 @@ def task_list_opener(self, extra: str = "") -> str: return text def display_with_columns( - self, pairs: Tuple[str, Optional[str]], extra: str = "" + self, pairs: List[Tuple[str, Optional[str]]], extra: str = "" ) -> None: root = self.list_root print("{}:\n".format(self.task_list_opener(extra=extra))) diff --git a/invoke/runners.py b/invoke/runners.py index dc0220130..48995928d 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -14,9 +14,11 @@ Callable, Dict, Generator, + IO, List, Optional, Tuple, + Type, Union, ) @@ -53,7 +55,7 @@ from .util import has_fileno, isatty, ExceptionHandlingThread if TYPE_CHECKING: - from io import StringIO, TextIOWrapper + # from io import BytesIO, StringIO, TextIOWrapper from .context import Context from .watchers import StreamWatcher @@ -116,7 +118,7 @@ def __init__(self, context: "Context") -> None: #: in at runtime by `run`. 
self.watchers: List["StreamWatcher"] = [] # Optional timeout timer placeholder - self._timer = None + self._timer: Optional[threading.Timer] = None # Async flags (initialized for 'finally' referencing in case something # goes REAL bad during options parsing) self._asynchronous = False @@ -709,7 +711,7 @@ def read_proc_output(self, reader: Callable) -> Generator[str, None, None]: break yield self.decode(data) - def write_our_output(self, stream: "TextIOWrapper", string: str) -> None: + def write_our_output(self, stream: IO, string: str) -> None: """ Write ``string`` to ``stream``. @@ -731,9 +733,9 @@ def write_our_output(self, stream: "TextIOWrapper", string: str) -> None: def _handle_output( self, - buffer_: List["StringIO"], + buffer_: List[str], hide: bool, - output: "TextIOWrapper", + output: IO, reader: Callable, ) -> None: # TODO: store un-decoded/raw bytes somewhere as well... @@ -754,7 +756,7 @@ def _handle_output( self.respond(buffer_) def handle_stdout( - self, buffer_: List["StringIO"], hide: bool, output: "TextIOWrapper" + self, buffer_: List[str], hide: bool, output: IO ) -> None: """ Read process' stdout, storing into a buffer & printing/parsing. @@ -777,7 +779,7 @@ def handle_stdout( ) def handle_stderr( - self, buffer_: List["StringIO"], hide: bool, output: "TextIOWrapper" + self, buffer_: List[str], hide: bool, output: IO ) -> None: """ Read process' stderr, storing into a buffer & printing/parsing. @@ -791,7 +793,7 @@ def handle_stderr( buffer_, hide, output, reader=self.read_proc_stderr ) - def read_our_stdin(self, input_: "TextIOWrapper") -> Optional[str]: + def read_our_stdin(self, input_: IO) -> Optional[str]: """ Read & decode bytes from a local stdin stream. 
@@ -833,8 +835,8 @@ def read_our_stdin(self, input_: "TextIOWrapper") -> Optional[str]: def handle_stdin( self, - input_: "TextIOWrapper", - output: "TextIOWrapper", + input_: IO, + output: IO, echo: bool = False, ) -> None: """ @@ -897,9 +899,7 @@ def handle_stdin( # Take a nap so we're not chewing CPU. time.sleep(self.input_sleep) - def should_echo_stdin( - self, input_: "StringIO", output: "TextIOWrapper" - ) -> bool: + def should_echo_stdin(self, input_: IO, output: IO) -> bool: """ Determine whether data read from ``input_`` should echo to ``output``. @@ -1045,7 +1045,7 @@ def process_is_finished(self) -> bool: """ raise NotImplementedError - def start(self, command: str, shell: str, env: Dict[str, Any]): + def start(self, command: str, shell: str, env: Dict[str, Any]) -> None: """ Initiate execution of ``command`` (via ``shell``, with ``env``). @@ -1076,7 +1076,7 @@ def start_timer(self, timeout: int) -> None: self._timer = threading.Timer(timeout, self.kill) self._timer.start() - def read_proc_stdout(self, num_bytes: int) -> Union[bytes, str]: + def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: """ Read ``num_bytes`` from the running process' stdout stream. @@ -1088,7 +1088,7 @@ def read_proc_stdout(self, num_bytes: int) -> Union[bytes, str]: """ raise NotImplementedError - def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str]: + def read_proc_stderr(self, num_bytes: int) -> Optional[Union[bytes, str]]: """ Read ``num_bytes`` from the running process' stderr stream. @@ -1201,7 +1201,7 @@ def timed_out(self) -> bool: """ # Timer expiry implies we did time out. (The timer itself will have # killed the subprocess, allowing us to even get to this point.) 
- return self._timer and not self._timer.is_alive() + return True if self._timer and not self._timer.is_alive() else False class Local(Runner): @@ -1224,7 +1224,7 @@ class Local(Runner): def __init__(self, context: "Context") -> None: super().__init__(context) # Bookkeeping var for pty use case - self.status = None + self.status = 0 def should_use_pty(self, pty: bool = False, fallback: bool = True) -> bool: use_pty = False @@ -1239,7 +1239,7 @@ def should_use_pty(self, pty: bool = False, fallback: bool = True) -> bool: use_pty = False return use_pty - def read_proc_stdout(self, num_bytes: int): + def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: # Obtain useful read-some-bytes function if self.using_pty: # Need to handle spurious OSErrors on some Linux platforms. @@ -1260,24 +1260,33 @@ def read_proc_stdout(self, num_bytes: int): # appeared, so we return a falsey value, which triggers the # "end of output" logic in code using reader functions. data = None - else: + elif self.process and self.process.stdout: data = os.read(self.process.stdout.fileno(), num_bytes) + else: + data = None return data - def read_proc_stderr(self, num_bytes: int): + def read_proc_stderr(self, num_bytes: int) -> Optional[Union[bytes, str]]: # NOTE: when using a pty, this will never be called. # TODO: do we ever get those OSErrors on stderr? Feels like we could? - return os.read(self.process.stderr.fileno(), num_bytes) + if self.process and self.process.stderr: + return os.read(self.process.stderr.fileno(), num_bytes) + return None - def _write_proc_stdin(self, data: bytes) -> int: + def _write_proc_stdin(self, data: bytes) -> None: # NOTE: parent_fd from os.fork() is a read/write pipe attached to our # forked process' stdout/stdin, respectively. 
- fd = self.parent_fd if self.using_pty else self.process.stdin.fileno() + if self.using_pty: + fd = self.parent_fd + elif self.process and self.process.stdin: + fd = self.process.stdin.fileno() + else: + raise SubprocessPipeError("No stdin process exists") # Try to write, ignoring broken pipes if encountered (implies child # process exited before the process piping stdin to us finished; # there's nothing we can do about that!) try: - return os.write(fd, data) + os.write(fd, data) except OSError as e: if "Broken pipe" not in str(e): raise @@ -1287,7 +1296,10 @@ def close_proc_stdin(self) -> None: # there is no working scenario to tell the process that stdin # closed when using pty raise SubprocessPipeError("Cannot close stdin when pty=True") - self.process.stdin.close() + elif self.process and self.process.stdin: + self.process.stdin.close() + else: + raise SubprocessPipeError("No stdin process exists") def start(self, command: str, shell: str, env: Dict[str, Any]) -> None: if self.using_pty: @@ -1605,7 +1617,10 @@ def __enter__(self) -> "Promise": return self def __exit__( - self, exc_type, exc_value, exc_tb: Optional[TracebackType] + self, + exc_type: Optional[Type[BaseException]], + exc_value: BaseException, + exc_tb: Optional[TracebackType], ) -> None: self.join() diff --git a/invoke/tasks.py b/invoke/tasks.py index 5209aa3a5..8c2430899 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -55,7 +55,7 @@ def __init__( self, body: Callable, name: Optional[str] = None, - aliases: Tuple[str, ...] 
= (), + aliases: Iterable[str] = (), positional: Optional[Iterable[str]] = None, optional: Iterable[str] = (), default: bool = False, diff --git a/invoke/terminals.py b/invoke/terminals.py index d11a17125..490750c08 100644 --- a/invoke/terminals.py +++ b/invoke/terminals.py @@ -8,7 +8,7 @@ """ from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, IO, Optional, Tuple, Union +from typing import Generator, IO, Optional, Tuple import os import select import sys @@ -16,9 +16,6 @@ # TODO: move in here? They're currently platform-agnostic... from .util import has_fileno, isatty -if TYPE_CHECKING: - from io import BytesIO, StringIO, TextIOWrapper - WINDOWS = sys.platform == "win32" """ @@ -47,8 +44,6 @@ import termios import tty -StreamTypes = Union["BytesIO", "StringIO", "TextIOWrapper"] - def pty_size() -> Tuple[int, int]: """ @@ -127,7 +122,7 @@ class CONSOLE_SCREEN_BUFFER_INFO(Structure): return (None, None) -def stdin_is_foregrounded_tty(stream: StreamTypes) -> bool: +def stdin_is_foregrounded_tty(stream: IO) -> bool: """ Detect if given stdin ``stream`` seems to be in the foreground of a TTY. @@ -151,7 +146,7 @@ def stdin_is_foregrounded_tty(stream: StreamTypes) -> bool: return os.getpgrp() == os.tcgetpgrp(stream.fileno()) -def cbreak_already_set(stream: StreamTypes) -> bool: +def cbreak_already_set(stream: IO) -> bool: # Explicitly not docstringed to remain private, for now. Eh. # Checks whether tty.setcbreak appears to have already been run against # ``stream`` (or if it would otherwise just not do anything). @@ -175,8 +170,7 @@ def cbreak_already_set(stream: StreamTypes) -> bool: @contextmanager def character_buffered( - stream: Union[int, IO[str]], - # Union[BytesIO, StringIO, TextIOWrapper], + stream: IO, ) -> Generator[None, None, None]: """ Force local terminal ``stream`` be character, not line, buffered. 
@@ -187,9 +181,9 @@ def character_buffered( """ if ( WINDOWS - or not isatty(stream) # type: ignore - or not stdin_is_foregrounded_tty(stream) # type: ignore - or cbreak_already_set(stream) # type: ignore + or not isatty(stream) + or not stdin_is_foregrounded_tty(stream) + or cbreak_already_set(stream) ): yield else: @@ -201,7 +195,7 @@ def character_buffered( termios.tcsetattr(stream, termios.TCSADRAIN, old_settings) -def ready_for_reading(input_: StreamTypes) -> bool: +def ready_for_reading(input_: IO) -> bool: """ Test ``input_`` to determine whether a read action will succeed. @@ -224,10 +218,7 @@ def ready_for_reading(input_: StreamTypes) -> bool: return bool(reads and reads[0] is input_) -def bytes_to_read( - input_: StreamTypes, - # Union["BytesIO", "StringIO", "TextIOWrapper"] -) -> int: +def bytes_to_read(input_: IO) -> int: """ Query stream ``input_`` to see how many bytes may be readable. @@ -246,6 +237,6 @@ def bytes_to_read( # it's not a tty but has a fileno, or vice versa; neither is typically # going to work re: ioctl(). 
if not WINDOWS and isatty(input_) and has_fileno(input_): - fionread = fcntl.ioctl(input_, termios.FIONREAD, " ") + fionread = fcntl.ioctl(input_, termios.FIONREAD, b" ") return int(struct.unpack("h", fionread)[0]) return 1 diff --git a/invoke/util.py b/invoke/util.py index 0c1e698d9..a8f1ee974 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -1,8 +1,7 @@ from collections import namedtuple from contextlib import contextmanager -from io import BytesIO, StringIO, TextIOWrapper from types import TracebackType -from typing import Any, Generator, List, Optional, Tuple, Type, Union +from typing import Any, Generator, List, IO, Optional, Tuple, Type, Union import io import logging import os @@ -72,7 +71,7 @@ def cd(where: str) -> Generator[None, None, None]: os.chdir(cwd) -def has_fileno(stream: Union[BytesIO, TextIOWrapper]) -> bool: +def has_fileno(stream: IO) -> bool: """ Cleanly determine whether ``stream`` has a useful ``.fileno()``. @@ -96,9 +95,7 @@ def has_fileno(stream: Union[BytesIO, TextIOWrapper]) -> bool: return False -def isatty( - stream: Union[BytesIO, StringIO, TextIOWrapper] -) -> Union[bool, Any]: +def isatty(stream: IO) -> Union[bool, Any]: """ Cleanly determine whether ``stream`` is a TTY. From e9bdb6e954d4b26041f8c970854dcabc64e7392e Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sun, 29 Jan 2023 17:28:08 -0500 Subject: [PATCH 09/27] test: add pyproject extensions and config --- invoke/vendor/typing_extensions.py | 2274 ++++++++++++++++++++++++++++ pyproject.toml | 37 + 2 files changed, 2311 insertions(+) create mode 100644 invoke/vendor/typing_extensions.py create mode 100644 pyproject.toml diff --git a/invoke/vendor/typing_extensions.py b/invoke/vendor/typing_extensions.py new file mode 100644 index 000000000..c32c63d16 --- /dev/null +++ b/invoke/vendor/typing_extensions.py @@ -0,0 +1,2274 @@ +import abc +import collections +import collections.abc +import functools +import inspect +import operator +import sys +import types as _types +import typing + + +__all__ = [ + # Super-special typing primitives. + 'Any', + 'ClassVar', + 'Concatenate', + 'Final', + 'LiteralString', + 'ParamSpec', + 'ParamSpecArgs', + 'ParamSpecKwargs', + 'Self', + 'Type', + 'TypeVar', + 'TypeVarTuple', + 'Unpack', + + # ABCs (from collections.abc). + 'Awaitable', + 'AsyncIterator', + 'AsyncIterable', + 'Coroutine', + 'AsyncGenerator', + 'AsyncContextManager', + 'ChainMap', + + # Concrete collection types. + 'ContextManager', + 'Counter', + 'Deque', + 'DefaultDict', + 'NamedTuple', + 'OrderedDict', + 'TypedDict', + + # Structural checks, a.k.a. protocols. + 'SupportsIndex', + + # One-off things. + 'Annotated', + 'assert_never', + 'assert_type', + 'clear_overloads', + 'dataclass_transform', + 'deprecated', + 'get_overloads', + 'final', + 'get_args', + 'get_origin', + 'get_type_hints', + 'IntVar', + 'is_typeddict', + 'Literal', + 'NewType', + 'overload', + 'override', + 'Protocol', + 'reveal_type', + 'runtime', + 'runtime_checkable', + 'Text', + 'TypeAlias', + 'TypeGuard', + 'TYPE_CHECKING', + 'Never', + 'NoReturn', + 'Required', + 'NotRequired', +] + +# for backward compatibility +PEP_560 = True +GenericMeta = type + +# The functions below are modified copies of typing internal helpers. 
+# They are needed by _ProtocolMeta and they provide support for PEP 646. + +_marker = object() + + +def _check_generic(cls, parameters, elen=_marker): + """Check correct count for parameters of a generic cls (internal helper). + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + if elen is _marker: + if not hasattr(cls, "__parameters__") or not cls.__parameters__: + raise TypeError(f"{cls} is not a generic class") + elen = len(cls.__parameters__) + alen = len(parameters) + if alen != elen: + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) + if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): + return + raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" + f" actual {alen}, expected {elen}") + + +if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): + return isinstance( + t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) + ) +elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): + return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) +else: + def _should_collect_from_parameters(t): + return isinstance(t, typing._GenericAlias) and not t._special + + +def _collect_type_vars(types, typevar_types=None): + """Collect all type variable contained in types in order of + first appearance (lexicographic order). For example:: + + _collect_type_vars((T, List[S, T])) == (T, S) + """ + if typevar_types is None: + typevar_types = typing.TypeVar + tvars = [] + for t in types: + if ( + isinstance(t, typevar_types) and + t not in tvars and + not _is_unpack(t) + ): + tvars.append(t) + if _should_collect_from_parameters(t): + tvars.extend([t for t in t.__parameters__ if t not in tvars]) + return tuple(tvars) + + +NoReturn = typing.NoReturn + +# Some unconstrained type variables. 
These are used by the container types. +# (These are not for export.) +T = typing.TypeVar('T') # Any type. +KT = typing.TypeVar('KT') # Key type. +VT = typing.TypeVar('VT') # Value type. +T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. + + +if sys.version_info >= (3, 11): + from typing import Any +else: + + class _AnyMeta(type): + def __instancecheck__(self, obj): + if self is Any: + raise TypeError("typing_extensions.Any cannot be used with isinstance()") + return super().__instancecheck__(obj) + + def __repr__(self): + if self is Any: + return "typing_extensions.Any" + return super().__repr__() + + class Any(metaclass=_AnyMeta): + """Special type indicating an unconstrained type. + - Any is compatible with every type. + - Any assumed to have all methods. + - All values assumed to be instances of Any. + Note that all the above statements are true from the point of view of + static type checkers. At runtime, Any should not be used with instance + checks. + """ + def __new__(cls, *args, **kwargs): + if cls is Any: + raise TypeError("Any cannot be instantiated") + return super().__new__(cls, *args, **kwargs) + + +ClassVar = typing.ClassVar + +# On older versions of typing there is an internal class named "Final". +# 3.8+ +if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): + Final = typing.Final +# 3.7 +else: + class _FinalForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Final = _FinalForm('Final', + doc="""A special typing construct to indicate that a name + cannot be re-assigned or overridden in a subclass. 
+ For example: + + MAX_SIZE: Final = 9000 + MAX_SIZE += 1 # Error reported by type checker + + class Connection: + TIMEOUT: Final[int] = 10 + class FastConnector(Connection): + TIMEOUT = 1 # Error reported by type checker + + There is no runtime checking of these properties.""") + +if sys.version_info >= (3, 11): + final = typing.final +else: + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + + +def IntVar(name): + return typing.TypeVar(name) + + +# 3.8+: +if hasattr(typing, 'Literal'): + Literal = typing.Literal +# 3.7: +else: + class _LiteralForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + return typing._GenericAlias(self, parameters) + + Literal = _LiteralForm('Literal', + doc="""A type that can be used to indicate to type checkers + that the corresponding value has a value literally equivalent + to the provided parameter. 
For example: + + var: Literal[4] = 4 + + The type checker understands that 'var' is literally equal to + the value 4 and no other value. + + Literal[...] cannot be subclassed. There is no runtime + checking verifying that the parameter is actually a value + instead of a type.""") + + +_overload_dummy = typing._overload_dummy # noqa + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. + """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. 
+ pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + +# This is not a real generic class. Don't use outside annotations. +Type = typing.Type + +# Various ABCs mimicking those in collections.abc. +# A few are simply re-exported for completeness. + + +Awaitable = typing.Awaitable +Coroutine = typing.Coroutine +AsyncIterable = typing.AsyncIterable +AsyncIterator = typing.AsyncIterator +Deque = typing.Deque +ContextManager = typing.ContextManager +AsyncContextManager = typing.AsyncContextManager +DefaultDict = typing.DefaultDict + +# 3.7.2+ +if hasattr(typing, 'OrderedDict'): + OrderedDict = typing.OrderedDict +# 3.7.0-3.7.2 +else: + OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) + +Counter = typing.Counter +ChainMap = typing.ChainMap +AsyncGenerator = typing.AsyncGenerator +NewType = typing.NewType +Text = typing.Text +TYPE_CHECKING = typing.TYPE_CHECKING + + +_PROTO_WHITELIST = ['Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'ContextManager', 'AsyncContextManager'] + + +def _get_protocol_attrs(cls): + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + 
'__args__', '__slots__', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker', '_gorg')): + attrs.add(attr) + return attrs + + +def _is_callable_members_only(cls): + return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) + + +def _maybe_adjust_parameters(cls): + """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. + + The contents of this function are very similar + to logic found in typing.Generic.__init_subclass__ + on the CPython main branch. + """ + tvars = [] + if '__orig_bases__' in cls.__dict__: + tvars = typing._collect_type_vars(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...] and/or Protocol[...]. + gvars = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (typing.Generic, Protocol)): + # for error messages + the_base = base.__origin__.__name__ + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...]" + " and/or Protocol[...] 
multiple types.") + gvars = base.__parameters__ + if gvars is None: + gvars = tvars + else: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]") + tvars = gvars + cls.__parameters__ = tuple(tvars) + + +# 3.8+ +if hasattr(typing, 'Protocol'): + Protocol = typing.Protocol +# 3.7 +else: + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + class _ProtocolMeta(abc.ABCMeta): # noqa: B024 + # This metaclass is a bit unfortunate and exists only because of the lack + # of __instancehook__. + def __instancecheck__(cls, instance): + # We need this method for situations where attributes are + # assigned in __init__. + if ((not getattr(cls, '_is_protocol', False) or + _is_callable_members_only(cls)) and + issubclass(instance.__class__, cls)): + return True + if cls._is_protocol: + if all(hasattr(instance, attr) and + (not callable(getattr(cls, attr, None)) or + getattr(instance, attr) is not None) + for attr in _get_protocol_attrs(cls)): + return True + return super().__instancecheck__(instance) + + class Protocol(metaclass=_ProtocolMeta): + # There is quite a lot of overlapping code with typing.Generic. + # Unfortunately it is hard to avoid this while these live in two different + # modules. The duplicated code will be removed when Protocol is moved to typing. + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... 
+ + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing_extensions.runtime act as simple-minded runtime protocol that checks + only the presence of given attributes, ignoring their type signatures. + + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... + """ + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if cls is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can only be used as a base class") + return super().__new__(cls) + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple): + params = (params,) + if not params and cls is not typing.Tuple: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty") + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) # noqa + if cls is Protocol: + # Generic can only be subscripted with unique type variables. + if not all(isinstance(p, typing.TypeVar) for p in params): + i = 0 + while isinstance(params[i], typing.TypeVar): + i += 1 + raise TypeError( + "Parameters to Protocol[...] must all be type variables." + f" Parameter {i + 1} is {params[i]}") + if len(set(params)) != len(params): + raise TypeError( + "Parameters to Protocol[...] must all be unique") + else: + # Subscripting a regular Generic subclass. 
+ _check_generic(cls, params, len(cls.__parameters__)) + return typing._GenericAlias(cls, params) + + def __init_subclass__(cls, *args, **kwargs): + if '__orig_bases__' in cls.__dict__: + error = typing.Generic in cls.__orig_bases__ + else: + error = typing.Generic in cls.__bases__ + if error: + raise TypeError("Cannot inherit from plain Generic") + _maybe_adjust_parameters(cls) + + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) + + # Set (or override) the protocol subclass hook. + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + if not getattr(cls, '_is_runtime_protocol', False): + if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + return NotImplemented + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") + if not _is_callable_members_only(cls): + if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + return NotImplemented + raise TypeError("Protocols with non-method members" + " don't support issubclass()") + if not isinstance(other, type): + # Same error as for issubclass(1, int) + raise TypeError('issubclass() arg 1 must be a class') + for attr in _get_protocol_attrs(cls): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + annotations = getattr(base, '__annotations__', {}) + if (isinstance(annotations, typing.Mapping) and + attr in annotations and + isinstance(other, _ProtocolMeta) and + other._is_protocol): + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # We have nothing more to do for non-protocols. + if not cls._is_protocol: + return + + # Check consistency of bases. 
+ for base in cls.__bases__: + if not (base in (object, typing.Generic) or + base.__module__ == 'collections.abc' and + base.__name__ in _PROTO_WHITELIST or + isinstance(base, _ProtocolMeta) and base._is_protocol): + raise TypeError('Protocols can only inherit from other' + f' protocols, got {repr(base)}') + cls.__init__ = _no_init + + +# 3.8+ +if hasattr(typing, 'runtime_checkable'): + runtime_checkable = typing.runtime_checkable +# 3.7 +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + f' got {cls!r}') + cls._is_runtime_protocol = True + return cls + + +# Exists for backwards compatibility. +runtime = runtime_checkable + + +# 3.8+ +if hasattr(typing, 'SupportsIndex'): + SupportsIndex = typing.SupportsIndex +# 3.7 +else: + @runtime_checkable + class SupportsIndex(Protocol): + __slots__ = () + + @abc.abstractmethod + def __index__(self) -> int: + pass + + +if hasattr(typing, "Required"): + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834 + # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" + # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 + # The standard library TypedDict below Python 3.11 does not store runtime + # information about optional and required keys when using Required or NotRequired. + # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. 
+ TypedDict = typing.TypedDict + _TypedDictMeta = typing._TypedDictMeta + is_typeddict = typing.is_typeddict +else: + def _check_fails(cls, other): + try: + if sys._getframe(1).f_globals['__name__'] not in ['abc', + 'functools', + 'typing']: + # Typed dicts are only for static structural subtyping. + raise TypeError('TypedDict does not support instance and class checks') + except (AttributeError, ValueError): + pass + return False + + def _dict_new(*args, **kwargs): + if not args: + raise TypeError('TypedDict.__new__(): not enough arguments') + _, args = args[0], args[1:] # allow the "cls" keyword be passed + return dict(*args, **kwargs) + + _dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)' + + def _typeddict_new(*args, total=True, **kwargs): + if not args: + raise TypeError('TypedDict.__new__(): not enough arguments') + _, args = args[0], args[1:] # allow the "cls" keyword be passed + if args: + typename, args = args[0], args[1:] # allow the "_typename" keyword be passed + elif '_typename' in kwargs: + typename = kwargs.pop('_typename') + import warnings + warnings.warn("Passing '_typename' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + raise TypeError("TypedDict.__new__() missing 1 required positional " + "argument: '_typename'") + if args: + try: + fields, = args # allow the "_fields" keyword be passed + except ValueError: + raise TypeError('TypedDict.__new__() takes from 2 to 3 ' + f'positional arguments but {len(args) + 2} ' + 'were given') + elif '_fields' in kwargs and len(kwargs) == 1: + fields = kwargs.pop('_fields') + import warnings + warnings.warn("Passing '_fields' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + fields = None + + if fields is None: + fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") + + ns = {'__annotations__': dict(fields)} + try: + # Setting correct module is 
necessary to make typed dict classes pickleable. + ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + + return _TypedDictMeta(typename, (), ns, total=total) + + _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' + ' /, *, total=True, **kwargs)') + + _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters + + class _TypedDictMeta(type): + def __init__(cls, name, bases, ns, total=True): + super().__init__(name, bases, ns) + + def __new__(cls, name, bases, ns, total=True): + # Create new typed dict class object. + # This method is called directly when TypedDict is subclassed, + # or via _typeddict_new when TypedDict is instantiated. This way + # TypedDict supports all three syntaxes described in its docstring. + # Subclasses and instances of TypedDict return actual dictionaries + # via _dict_new. + ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new + # Don't insert typing.Generic into __bases__ here, + # or Generic.__init_subclass__ will raise TypeError + # in the super().__new__() call. + # Instead, monkey-patch __bases__ onto the class after it's been created. 
+ tp_dict = super().__new__(cls, name, (dict,), ns) + + if any(issubclass(base, typing.Generic) for base in bases): + tp_dict.__bases__ = (typing.Generic, dict) + _maybe_adjust_parameters(tp_dict) + + annotations = {} + own_annotations = ns.get('__annotations__', {}) + msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + kwds = {"module": tp_dict.__module__} if _TAKES_MODULE else {} + own_annotations = { + n: typing._type_check(tp, msg, **kwds) + for n, tp in own_annotations.items() + } + required_keys = set() + optional_keys = set() + + for base in bases: + annotations.update(base.__dict__.get('__annotations__', {})) + required_keys.update(base.__dict__.get('__required_keys__', ())) + optional_keys.update(base.__dict__.get('__optional_keys__', ())) + + annotations.update(own_annotations) + for annotation_key, annotation_type in own_annotations.items(): + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + annotation_origin = get_origin(annotation_type) + + if annotation_origin is Required: + required_keys.add(annotation_key) + elif annotation_origin is NotRequired: + optional_keys.add(annotation_key) + elif total: + required_keys.add(annotation_key) + else: + optional_keys.add(annotation_key) + + tp_dict.__annotations__ = annotations + tp_dict.__required_keys__ = frozenset(required_keys) + tp_dict.__optional_keys__ = frozenset(optional_keys) + if not hasattr(tp_dict, '__total__'): + tp_dict.__total__ = total + return tp_dict + + __instancecheck__ = __subclasscheck__ = _check_fails + + TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) + TypedDict.__module__ = __name__ + TypedDict.__doc__ = \ + """A simple typed name space. At runtime it is equivalent to a plain dict. 
+ + TypedDict creates a dictionary type that expects all of its + instances to have a certain set of keys, with each key + associated with a value of a consistent type. This expectation + is not checked at runtime but is only enforced by type checkers. + Usage:: + + class Point2D(TypedDict): + x: int + y: int + label: str + + a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + + assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + + The type info can be accessed via the Point2D.__annotations__ dict, and + the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. + TypedDict supports two additional equivalent forms:: + + Point2D = TypedDict('Point2D', x=int, y=int, label=str) + Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) + + The class syntax is only supported in Python 3.6+, while two other + syntax forms work for Python 2.7 and 3.2+ + """ + + if hasattr(typing, "_TypedDictMeta"): + _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) + else: + _TYPEDDICT_TYPES = (_TypedDictMeta,) + + def is_typeddict(tp): + """Check if an annotation is a TypedDict class + + For example:: + class Film(TypedDict): + title: str + year: int + + is_typeddict(Film) # => True + is_typeddict(Union[list, str]) # => False + """ + return isinstance(tp, tuple(_TYPEDDICT_TYPES)) + + +if hasattr(typing, "assert_type"): + assert_type = typing.assert_type + +else: + def assert_type(__val, __typ): + """Assert (to the type checker) that the value is of the given type. + + When the type checker encounters a call to assert_type(), it + emits an error if the value is not of the specified type:: + + def greet(name: str) -> None: + assert_type(name, str) # ok + assert_type(name, int) # type checker error + + At runtime this returns the first argument unchanged and otherwise + does nothing. 
+ """ + return __val + + +if hasattr(typing, "Required"): + get_type_hints = typing.get_type_hints +else: + import functools + import types + + # replaces _strip_annotations() + def _strip_extras(t): + """Strips Annotated, Required and NotRequired from a given type.""" + if isinstance(t, _AnnotatedAlias): + return _strip_extras(t.__origin__) + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): + return _strip_extras(t.__args__[0]) + if isinstance(t, typing._GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return t.copy_with(stripped_args) + if hasattr(types, "GenericAlias") and isinstance(t, types.GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return types.GenericAlias(t.__origin__, stripped_args) + if hasattr(types, "UnionType") and isinstance(t, types.UnionType): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + + return t + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + """Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]', 'Required[T]' or 'NotRequired[T]' with 'T' + (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. 
+ + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + if hasattr(typing, "Annotated"): + hint = typing.get_type_hints( + obj, globalns=globalns, localns=localns, include_extras=True + ) + else: + hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if include_extras: + return hint + return {k: _strip_extras(t) for k, t in hint.items()} + + +# Python 3.9+ has PEP 593 (Annotated) +if hasattr(typing, 'Annotated'): + Annotated = typing.Annotated + # Not exported and not a public API, but needed for get_origin() and get_args() + # to work. + _AnnotatedAlias = typing._AnnotatedAlias +# 3.7-3.8 +else: + class _AnnotatedAlias(typing._GenericAlias, _root=True): + """Runtime representation of an annotated type. + + At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' + with extra annotations. The alias behaves like a normal typing alias, + instantiating is the same as instantiating the underlying type, binding + it to types is also the same. 
+ """ + def __init__(self, origin, metadata): + if isinstance(origin, _AnnotatedAlias): + metadata = origin.__metadata__ + metadata + origin = origin.__origin__ + super().__init__(origin, origin) + self.__metadata__ = metadata + + def copy_with(self, params): + assert len(params) == 1 + new_type = params[0] + return _AnnotatedAlias(new_type, self.__metadata__) + + def __repr__(self): + return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]") + + def __reduce__(self): + return operator.getitem, ( + Annotated, (self.__origin__,) + self.__metadata__ + ) + + def __eq__(self, other): + if not isinstance(other, _AnnotatedAlias): + return NotImplemented + if self.__origin__ != other.__origin__: + return False + return self.__metadata__ == other.__metadata__ + + def __hash__(self): + return hash((self.__origin__, self.__metadata__)) + + class Annotated: + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type (and will be in + the __origin__ field), the remaining arguments are kept as a tuple in + the __extra__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. 
+ - Nested Annotated are flattened:: + + Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize()] + Optimized[int] == Annotated[int, runtime.Optimize()] + + OptimizedList = Annotated[List[T], runtime.Optimize()] + OptimizedList[int] == Annotated[List[int], runtime.Optimize()] + """ + + __slots__ = () + + def __new__(cls, *args, **kwargs): + raise TypeError("Type Annotated cannot be instantiated.") + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + allowed_special_forms = (ClassVar, Final) + if get_origin(params[0]) in allowed_special_forms: + origin = params[0] + else: + msg = "Annotated[t, ...]: t must be a type." + origin = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return _AnnotatedAlias(origin, metadata) + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + f"Cannot subclass {cls.__module__}.Annotated" + ) + +# Python 3.8 has get_origin() and get_args() but those implementations aren't +# Annotated-aware, so we can't use those. Python 3.9's versions don't support +# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. +if sys.version_info[:2] >= (3, 10): + get_origin = typing.get_origin + get_args = typing.get_args +# 3.7-3.9 +else: + try: + # 3.9+ + from typing import _BaseGenericAlias + except ImportError: + _BaseGenericAlias = typing._GenericAlias + try: + # 3.9+ + from typing import GenericAlias as _typing_GenericAlias + except ImportError: + _typing_GenericAlias = typing._GenericAlias + + def get_origin(tp): + """Get the unsubscripted version of a type. 
+ + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar + and Annotated. Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + get_origin(P.args) is P + """ + if isinstance(tp, _AnnotatedAlias): + return Annotated + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, + ParamSpecArgs, ParamSpecKwargs)): + return tp.__origin__ + if tp is typing.Generic: + return typing.Generic + return None + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _AnnotatedAlias): + return (tp.__origin__,) + tp.__metadata__ + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias)): + if getattr(tp, "_special", False): + return () + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () + + +# 3.10+ +if hasattr(typing, 'TypeAlias'): + TypeAlias = typing.TypeAlias +# 3.9 +elif sys.version_info[:2] >= (3, 9): + class _TypeAliasForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_TypeAliasForm + def TypeAlias(self, parameters): + """Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. 
+ + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example above. + """ + raise TypeError(f"{self} is not subscriptable") +# 3.7-3.8 +else: + class _TypeAliasForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + TypeAlias = _TypeAliasForm('TypeAlias', + doc="""Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. + + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example + above.""") + + +class _DefaultMixin: + """Mixin for TypeVarLike defaults.""" + + __slots__ = () + + def __init__(self, default): + if isinstance(default, (tuple, list)): + self.__default__ = tuple((typing._type_check(d, "Default must be a type") + for d in default)) + elif default != _marker: + self.__default__ = typing._type_check(default, "Default must be a type") + else: + self.__default__ = None + + +# Add default and infer_variance parameters from PEP 696 and 695 +class TypeVar(typing.TypeVar, _DefaultMixin, _root=True): + """Type variable.""" + + __module__ = 'typing' + + def __init__(self, name, *constraints, bound=None, + covariant=False, contravariant=False, + default=_marker, infer_variance=False): + super().__init__(name, *constraints, bound=bound, covariant=covariant, + contravariant=contravariant) + _DefaultMixin.__init__(self, default) + self.__infer_variance__ = infer_variance + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + +# Python 3.10+ has PEP 612 +if hasattr(typing, 'ParamSpecArgs'): + ParamSpecArgs = typing.ParamSpecArgs + ParamSpecKwargs = typing.ParamSpecKwargs +# 3.7-3.9 +else: + class _Immutable: + """Mixin to indicate that object should not be 
copied.""" + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + class ParamSpecArgs(_Immutable): + """The args for a ParamSpec object. + + Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. + + ParamSpecArgs objects have a reference back to their ParamSpec: + + P.args.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.args" + + def __eq__(self, other): + if not isinstance(other, ParamSpecArgs): + return NotImplemented + return self.__origin__ == other.__origin__ + + class ParamSpecKwargs(_Immutable): + """The kwargs for a ParamSpec object. + + Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. + + ParamSpecKwargs objects have a reference back to their ParamSpec: + + P.kwargs.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. 
+ """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.kwargs" + + def __eq__(self, other): + if not isinstance(other, ParamSpecKwargs): + return NotImplemented + return self.__origin__ == other.__origin__ + +# 3.10+ +if hasattr(typing, 'ParamSpec'): + + # Add default Parameter - PEP 696 + class ParamSpec(typing.ParamSpec, _DefaultMixin, _root=True): + """Parameter specification variable.""" + + __module__ = 'typing' + + def __init__(self, name, *, bound=None, covariant=False, contravariant=False, + default=_marker): + super().__init__(name, bound=bound, covariant=covariant, + contravariant=contravariant) + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + +# 3.7-3.9 +else: + + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class ParamSpec(list, _DefaultMixin): + """Parameter specification variable. + + Usage:: + + P = ParamSpec('P') + + Parameter specification variables exist primarily for the benefit of static + type checkers. They are used to forward the parameter types of one + callable to another callable, a pattern commonly found in higher order + functions and decorators. They are only valid when used in ``Concatenate``, + or s the first argument to ``Callable``. In Python 3.10 and higher, + they are also supported in user-defined Generics at runtime. + See class Generic for more information on generic types. 
An + example for annotating a decorator:: + + T = TypeVar('T') + P = ParamSpec('P') + + def add_logging(f: Callable[P, T]) -> Callable[P, T]: + '''A type-safe decorator to add logging to a function.''' + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + logging.info(f'{f.__name__} was called') + return f(*args, **kwargs) + return inner + + @add_logging + def add_two(x: float, y: float) -> float: + '''Add two numbers together.''' + return x + y + + Parameter specification variables defined with covariant=True or + contravariant=True can be used to declare covariant or contravariant + generic types. These keyword arguments are valid, but their actual semantics + are yet to be decided. See PEP 612 for details. + + Parameter specification variables can be introspected. e.g.: + + P.__name__ == 'T' + P.__bound__ == None + P.__covariant__ == False + P.__contravariant__ == False + + Note that only parameter specification variables defined in global scope can + be pickled. + """ + + # Trick Generic __parameters__. 
+ __class__ = typing.TypeVar + + @property + def args(self): + return ParamSpecArgs(self) + + @property + def kwargs(self): + return ParamSpecKwargs(self) + + def __init__(self, name, *, bound=None, covariant=False, contravariant=False, + default=_marker): + super().__init__([self]) + self.__name__ = name + self.__covariant__ = bool(covariant) + self.__contravariant__ = bool(contravariant) + if bound: + self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + else: + self.__bound__ = None + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __repr__(self): + if self.__covariant__: + prefix = '+' + elif self.__contravariant__: + prefix = '-' + else: + prefix = '~' + return prefix + self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + # Hack to get typing._type_check to pass. + def __call__(self, *args, **kwargs): + pass + + +# 3.7-3.9 +if not hasattr(typing, 'Concatenate'): + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class _ConcatenateGenericAlias(list): + + # Trick Generic into looking into this for __parameters__. + __class__ = typing._GenericAlias + + # Flag in 3.8. + _special = False + + def __init__(self, origin, args): + super().__init__(args) + self.__origin__ = origin + self.__args__ = args + + def __repr__(self): + _type_repr = typing._type_repr + return (f'{_type_repr(self.__origin__)}' + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + + def __hash__(self): + return hash((self.__origin__, self.__args__)) + + # Hack to get typing._type_check to pass in Generic. 
+ def __call__(self, *args, **kwargs): + pass + + @property + def __parameters__(self): + return tuple( + tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + ) + + +# 3.7-3.9 +@typing._tp_cache +def _concatenate_getitem(self, parameters): + if parameters == (): + raise TypeError("Cannot take a Concatenate of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + if not isinstance(parameters[-1], ParamSpec): + raise TypeError("The last parameter to Concatenate should be a " + "ParamSpec variable.") + msg = "Concatenate[arg, ...]: each arg must be a type." + parameters = tuple(typing._type_check(p, msg) for p in parameters) + return _ConcatenateGenericAlias(self, parameters) + + +# 3.10+ +if hasattr(typing, 'Concatenate'): + Concatenate = typing.Concatenate + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_TypeAliasForm + def Concatenate(self, parameters): + """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """ + return _concatenate_getitem(self, parameters) +# 3.7-8 +else: + class _ConcatenateForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + return _concatenate_getitem(self, parameters) + + Concatenate = _ConcatenateForm( + 'Concatenate', + doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. 
+ """) + +# 3.10+ +if hasattr(typing, 'TypeGuard'): + TypeGuard = typing.TypeGuard +# 3.9 +elif sys.version_info[:2] >= (3, 9): + class _TypeGuardForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_TypeGuardForm + def TypeGuard(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. 
+ + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.7-3.8 +else: + class _TypeGuardForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeGuard = _TypeGuardForm( + 'TypeGuard', + doc="""Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... 
+ + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """) + + +# Vendored from cpython typing._SpecialFrom +class _SpecialForm(typing._Final, _root=True): + __slots__ = ('_name', '__doc__', '_getitem') + + def __init__(self, getitem): + self._getitem = getitem + self._name = getitem.__name__ + self.__doc__ = getitem.__doc__ + + def __getattr__(self, item): + if item in {'__name__', '__qualname__'}: + return self._name + + raise AttributeError(item) + + def __mro_entries__(self, bases): + raise TypeError(f"Cannot subclass {self!r}") + + def __repr__(self): + return f'typing_extensions.{self._name}' + + def __reduce__(self): + return self._name + + def __call__(self, *args, **kwds): + raise TypeError(f"Cannot instantiate {self!r}") + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance()") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass()") + + @typing._tp_cache + def __getitem__(self, parameters): + return self._getitem(self, parameters) + + +if hasattr(typing, "LiteralString"): + LiteralString = typing.LiteralString +else: + @_SpecialForm + def LiteralString(self, params): + """Represents an arbitrary literal string. + + Example:: + + from typing_extensions import LiteralString + + def query(sql: LiteralString) -> ...: + ... 
+ + query("SELECT * FROM table") # ok + query(f"SELECT * FROM {input()}") # not ok + + See PEP 675 for details. + + """ + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Self"): + Self = typing.Self +else: + @_SpecialForm + def Self(self, params): + """Used to spell the type of "self" in classes. + + Example:: + + from typing import Self + + class ReturnsSelf: + def parse(self, data: bytes) -> Self: + ... + return self + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Never"): + Never = typing.Never +else: + @_SpecialForm + def Never(self, params): + """The bottom type, a type that has no members. + + This can be used to define a function that should never be + called, or a function that never returns:: + + from typing_extensions import Never + + def never_call_me(arg: Never) -> None: + pass + + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # ok, arg is of type Never + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, 'Required'): + Required = typing.Required + NotRequired = typing.NotRequired +elif sys.version_info[:2] >= (3, 9): + class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_ExtensionsSpecialForm + def Required(self, parameters): + """A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. 
+ """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + @_ExtensionsSpecialForm + def NotRequired(self, parameters): + """A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: + class _RequiredForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Required = _RequiredForm( + 'Required', + doc="""A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """) + NotRequired = _RequiredForm( + 'NotRequired', + doc="""A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """) + + +if hasattr(typing, "Unpack"): # 3.11+ + Unpack = typing.Unpack +elif sys.version_info[:2] >= (3, 9): + class _UnpackSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' 
+ self._name + + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + @_UnpackSpecialForm + def Unpack(self, parameters): + """A special typing construct to unpack a variadic type. For example: + + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) + + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... + + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + +else: + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + class _UnpackForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + Unpack = _UnpackForm( + 'Unpack', + doc="""A special typing construct to unpack a variadic type. For example: + + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) + + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... + + """) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + + +if hasattr(typing, "TypeVarTuple"): # 3.11+ + + # Add default Parameter - PEP 696 + class TypeVarTuple(typing.TypeVarTuple, _DefaultMixin, _root=True): + """Type variable tuple.""" + + def __init__(self, name, *, default=_marker): + super().__init__(name) + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + +else: + class TypeVarTuple(_DefaultMixin): + """Type variable tuple. 
+ + Usage:: + + Ts = TypeVarTuple('Ts') + + In the same way that a normal type variable is a stand-in for a single + type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* + type such as ``Tuple[int, str]``. + + Type variable tuples can be used in ``Generic`` declarations. + Consider the following example:: + + class Array(Generic[*Ts]): ... + + The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, + where ``T1`` and ``T2`` are type variables. To use these type variables + as type parameters of ``Array``, we must *unpack* the type variable tuple using + the star operator: ``*Ts``. The signature of ``Array`` then behaves + as if we had simply written ``class Array(Generic[T1, T2]): ...``. + In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows + us to parameterise the class with an *arbitrary* number of type parameters. + + Type variable tuples can be used anywhere a normal ``TypeVar`` can. + This includes class definitions, as shown above, as well as function + signatures and variable annotations:: + + class Array(Generic[*Ts]): + + def __init__(self, shape: Tuple[*Ts]): + self._shape: Tuple[*Ts] = shape + + def get_shape(self) -> Tuple[*Ts]: + return self._shape + + shape = (Height(480), Width(640)) + x: Array[Height, Width] = Array(shape) + y = abs(x) # Inferred type is Array[Height, Width] + z = x + x # ... is Array[Height, Width] + x.get_shape() # ... is tuple[Height, Width] + + """ + + # Trick Generic __parameters__. 
+ __class__ = typing.TypeVar + + def __iter__(self): + yield self.__unpacked__ + + def __init__(self, name, *, default=_marker): + self.__name__ = name + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + self.__unpacked__ = Unpack[self] + + def __repr__(self): + return self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(self, *args, **kwds): + if '_root' not in kwds: + raise TypeError("Cannot subclass special typing classes") + + +if hasattr(typing, "reveal_type"): + reveal_type = typing.reveal_type +else: + def reveal_type(__obj: T) -> T: + """Reveal the inferred type of a variable. + + When a static type checker encounters a call to ``reveal_type()``, + it will emit the inferred type of the argument:: + + x: int = 1 + reveal_type(x) + + Running a static type checker (e.g., ``mypy``) on this example + will produce output similar to 'Revealed type is "builtins.int"'. + + At runtime, the function prints the runtime type of the + argument and returns it unchanged. + + """ + print(f"Runtime type is {type(__obj).__name__!r}", file=sys.stderr) + return __obj + + +if hasattr(typing, "assert_never"): + assert_never = typing.assert_never +else: + def assert_never(__arg: Never) -> Never: + """Assert to the type checker that a line of code is unreachable. + + Example:: + + def int_or_str(arg: int | str) -> None: + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + assert_never(arg) + + If a type checker finds that a call to assert_never() is + reachable, it will emit an error. + + At runtime, this throws an exception when called. 
+ + """ + raise AssertionError("Expected code to be unreachable") + + +if sys.version_info >= (3, 12): + # dataclass_transform exists in 3.11 but lacks the frozen_default parameter + dataclass_transform = typing.dataclass_transform +else: + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: typing.Tuple[ + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], + ... + ] = (), + **kwargs: typing.Any, + ) -> typing.Callable[[T], T]: + """Decorator that marks a function, class, or metaclass as providing + dataclass-like behavior. + + Example: + + from typing_extensions import dataclass_transform + + _T = TypeVar("_T") + + # Used on a decorator function + @dataclass_transform() + def create_model(cls: type[_T]) -> type[_T]: + ... + return cls + + @create_model + class CustomerModel: + id: int + name: str + + # Used on a base class + @dataclass_transform() + class ModelBase: ... + + class CustomerModel(ModelBase): + id: int + name: str + + # Used on a metaclass + @dataclass_transform() + class ModelMeta(type): ... + + class ModelBase(metaclass=ModelMeta): ... + + class CustomerModel(ModelBase): + id: int + name: str + + Each of the ``CustomerModel`` classes defined in this example will now + behave similarly to a dataclass created with the ``@dataclasses.dataclass`` + decorator. For example, the type checker will synthesize an ``__init__`` + method. + + The arguments to this decorator can be used to customize this behavior: + - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be + True or False if it is omitted by the caller. + - ``order_default`` indicates whether the ``order`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``kw_only_default`` indicates whether the ``kw_only`` parameter is + assumed to be True or False if it is omitted by the caller. 
+ - ``frozen_default`` indicates whether the ``frozen`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``field_specifiers`` specifies a static list of supported classes + or functions that describe fields, similar to ``dataclasses.field()``. + + At runtime, this decorator records its arguments in the + ``__dataclass_transform__`` attribute on the decorated object. + + See PEP 681 for details. + + """ + def decorator(cls_or_fn): + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + return decorator + + +if hasattr(typing, "override"): + override = typing.override +else: + _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + def override(__arg: _F) -> _F: + """Indicate that a method is intended to override a method in a base class. + + Usage: + + class Base: + def method(self) -> None: ... + pass + + class Child(Base): + @override + def method(self) -> None: + super().method() + + When this decorator is applied to a method, the type checker will + validate that it overrides a method with the same name on a base class. + This helps prevent bugs that may occur when a base class is changed + without an equivalent change to a child class. + + There is no runtime checking of these properties. The decorator + sets the ``__override__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + + See PEP 698 for details. + + """ + try: + __arg.__override__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. 
+ pass + return __arg + + +if hasattr(typing, "deprecated"): + deprecated = typing.deprecated +else: + _T = typing.TypeVar("_T") + + def deprecated(__msg: str) -> typing.Callable[[_T], _T]: + """Indicate that a class, function or overload is deprecated. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + No runtime warning is issued. The decorator sets the ``__deprecated__`` + attribute on the decorated object to the deprecation message + passed to the decorator. If applied to an overload, the decorator + must be after the ``@overload`` decorator for the attribute to + exist on the overload as returned by ``get_overloads()``. + + See PEP 702 for details. + + """ + def decorator(__arg: _T) -> _T: + __arg.__deprecated__ = __msg + return __arg + + return decorator + + +# We have to do some monkey patching to deal with the dual nature of +# Unpack/TypeVarTuple: +# - We want Unpack to be a kind of TypeVar so it gets accepted in +# Generic[Unpack[Ts]] +# - We want it to *not* be treated as a TypeVar for the purposes of +# counting generic parameters, so that when we subscript a generic, +# the runtime doesn't try to substitute the Unpack with the subscripted type. +if not hasattr(typing, "TypeVarTuple"): + typing._collect_type_vars = _collect_type_vars + typing._check_generic = _check_generic + + +# Backport typing.NamedTuple as it exists in Python 3.11. +# In 3.11, the ability to define generic `NamedTuple`s was supported. +# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8. 
+if sys.version_info >= (3, 11): + NamedTuple = typing.NamedTuple +else: + def _caller(): + try: + return sys._getframe(2).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): # For platforms without _getframe() + return None + + def _make_nmtuple(name, types, module, defaults=()): + fields = [n for n, t in types] + annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types} + nm_tpl = collections.namedtuple(name, fields, + defaults=defaults, module=module) + nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations + # The `_field_types` attribute was removed in 3.9; + # in earlier versions, it is the same as the `__annotations__` attribute + if sys.version_info < (3, 9): + nm_tpl._field_types = annotations + return nm_tpl + + _prohibited_namedtuple_fields = typing._prohibited + _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + + class _NamedTupleMeta(type): + def __new__(cls, typename, bases, ns): + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not typing.Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) + types = ns.get('__annotations__', {}) + default_names = [] + for field_name in types: + if field_name in ns: + default_names.append(field_name) + elif default_names: + raise TypeError(f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}") + nm_tpl = _make_nmtuple( + typename, types.items(), + defaults=[ns[n] for n in default_names], + module=ns['__module__'] + ) + nm_tpl.__bases__ = bases + if typing.Generic in bases: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) + # update from user namespace without overriding special 
namedtuple attributes + for key in ns: + if key in _prohibited_namedtuple_fields: + raise AttributeError("Cannot overwrite NamedTuple attribute " + key) + elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + if typing.Generic in bases: + nm_tpl.__init_subclass__() + return nm_tpl + + def NamedTuple(__typename, __fields=None, **kwargs): + if __fields is None: + __fields = kwargs.items() + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + return _make_nmtuple(__typename, __fields, module=_caller()) + + NamedTuple.__doc__ = typing.NamedTuple.__doc__ + _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + + # On 3.8+, alter the signature so that it matches typing.NamedTuple. + # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, + # so just leave the signature as it is on 3.7. + if sys.version_info >= (3, 8): + NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' + + def _namedtuple_mro_entries(bases): + assert NamedTuple in bases + return (_NamedTuple,) + + NamedTuple.__mro_entries__ = _namedtuple_mro_entries diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..9742e28e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,37 @@ +[tool.mypy] +# check_untyped_defs = true +# follow_imports_for_stubs = true +# disallow_any_decorated = true +# disallow_any_generics = true +# disallow_any_unimported = true +# disallow_incomplete_defs = true +# disallow_subclassing_any = true +# disallow_untyped_calls = true +# disallow_untyped_decorators = true +disallow_untyped_defs = true +# enable_error_code = redundant-expr, truthy-bool, ignore-without-code, unused-awaitable +# implicit_reexport = False +exclude = [ + "integration/", "tests/", "setup.py", "tasks.py", "sites/www/conf.py" +] +ignore_missing_imports = true +# no_implicit_optional = true +# pretty = true +# 
show_column_numbers = true +# show_error_codes = true +# strict_equality = true +warn_incomplete_stub = true +warn_redundant_casts = true +# warn_return_any = true +# warn_unreachable = true +warn_unused_ignores = true + +[[tool.mypy.overrides]] +module = "invoke.vendor.*" +ignore_errors = true + +# [[tool.mypy.overrides]] +# module = "mypy-tests.*" +# disallow_any_decorated = False +# disallow_untyped_calls = False +# disallow_untyped_defs = False From efef3ceef4f4a2c385297ae717736f93c4d26a6a Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 17:36:14 -0500 Subject: [PATCH 10/27] test: cleanup mypy config --- invoke/loader.py | 4 ++-- pyproject.toml | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/invoke/loader.py b/invoke/loader.py index db6d582a6..af8439b93 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -29,7 +29,7 @@ def __init__(self, config: Optional["Config"] = None) -> None: config = Config() self.config = config - def find(self, name: str) -> Tuple[IO[Any], str, Tuple[str, str, int]]: + def find(self, name: str) -> Tuple[IO, str, Tuple[str, str, int]]: """ Implementation-specific finder method seeking collection ``name``. 
@@ -113,7 +113,7 @@ def start(self) -> str: # Lazily determine default CWD if configured value is falsey return self._start or os.getcwd() - def find(self, name: str) -> Tuple[IO[Any], str, Tuple[str, str, int]]: + def find(self, name: str) -> Tuple[IO, str, Tuple[str, str, int]]: # Accumulate all parent directories start = self.start debug("FilesystemLoader find starting at {!r}".format(start)) diff --git a/pyproject.toml b/pyproject.toml index 9742e28e0..8ec47c6ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,12 @@ # disallow_untyped_calls = true # disallow_untyped_decorators = true disallow_untyped_defs = true -# enable_error_code = redundant-expr, truthy-bool, ignore-without-code, unused-awaitable +# enable_error_code = [ +# "redundant-expr", +# "truthy-bool", +# "ignore-without-code", +# "unused-awaitable", +# # implicit_reexport = False exclude = [ "integration/", "tests/", "setup.py", "tasks.py", "sites/www/conf.py" @@ -29,9 +34,3 @@ warn_unused_ignores = true [[tool.mypy.overrides]] module = "invoke.vendor.*" ignore_errors = true - -# [[tool.mypy.overrides]] -# module = "mypy-tests.*" -# disallow_any_decorated = False -# disallow_untyped_calls = False -# disallow_untyped_defs = False From 56ddd30eff0cf833a332e130dfaa6f468b853fd3 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 17:43:56 -0500 Subject: [PATCH 11/27] test: cleanup mypy config --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 0d7c15288..34a0773fc 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,6 +20,6 @@ setuptools>56 # Debuggery icecream>=2.1 # typing -mypy==0.991 +mypy>=0.942 typed-ast>=1.4.3 types-PyYAML>=5.4.3 From 6b08c5876142e64f8b28c470190cef173a96d61c Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sun, 29 Jan 2023 18:01:01 -0500 Subject: [PATCH 12/27] test: attempt local typing import --- dev-requirements.txt | 6 +++--- invoke/watchers.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 34a0773fc..1b2edcefe 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,6 +20,6 @@ setuptools>56 # Debuggery icecream>=2.1 # typing -mypy>=0.942 -typed-ast>=1.4.3 -types-PyYAML>=5.4.3 +mypy>=0.942,<1 +typed-ast>=1.4.3,<2 +types-PyYAML>=5.4.3,<6 diff --git a/invoke/watchers.py b/invoke/watchers.py index 6f6495bf9..89538b2ec 100644 --- a/invoke/watchers.py +++ b/invoke/watchers.py @@ -1,9 +1,14 @@ import re import threading -from typing import Generator, Iterable, Literal +from typing import Generator, Iterable from .exceptions import ResponseNotAccepted +try: + from .vendor.typing_extensions import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + class StreamWatcher(threading.local): """ From 629155dba8b468ec828cb1633a3f5811797368c5 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 18:10:50 -0500 Subject: [PATCH 13/27] test: attempt local typing import --- invoke/vendor/typing_extensions.py | 1870 +++++++++++++++++++--------- 1 file changed, 1252 insertions(+), 618 deletions(-) diff --git a/invoke/vendor/typing_extensions.py b/invoke/vendor/typing_extensions.py index c32c63d16..194731cd3 100644 --- a/invoke/vendor/typing_extensions.py +++ b/invoke/vendor/typing_extensions.py @@ -1,27 +1,33 @@ import abc import collections import collections.abc -import functools -import inspect import operator import sys import types as _types import typing +# After PEP 560, internal typing API was substantially reworked. +# This is especially important for Protocol class which uses internal APIs +# quite extensively. 
+PEP_560 = sys.version_info[:3] >= (3, 7, 0) +if PEP_560: + GenericMeta = type +else: + # 3.6 + from typing import GenericMeta, _type_vars # noqa + + +# Please keep __all__ alphabetized within each category. __all__ = [ # Super-special typing primitives. - 'Any', 'ClassVar', 'Concatenate', 'Final', 'LiteralString', 'ParamSpec', - 'ParamSpecArgs', - 'ParamSpecKwargs', 'Self', 'Type', - 'TypeVar', 'TypeVarTuple', 'Unpack', @@ -39,7 +45,6 @@ 'Counter', 'Deque', 'DefaultDict', - 'NamedTuple', 'OrderedDict', 'TypedDict', @@ -49,21 +54,13 @@ # One-off things. 'Annotated', 'assert_never', - 'assert_type', - 'clear_overloads', 'dataclass_transform', - 'deprecated', - 'get_overloads', 'final', - 'get_args', - 'get_origin', - 'get_type_hints', 'IntVar', 'is_typeddict', 'Literal', 'NewType', 'overload', - 'override', 'Protocol', 'reveal_type', 'runtime', @@ -78,13 +75,21 @@ 'NotRequired', ] -# for backward compatibility -PEP_560 = True -GenericMeta = type +if PEP_560: + __all__.extend(["get_args", "get_origin", "get_type_hints"]) # The functions below are modified copies of typing internal helpers. # They are needed by _ProtocolMeta and they provide support for PEP 646. + +def _no_slots_copy(dct): + dict_copy = dict(dct) + if '__slots__' in dict_copy: + for slot in dict_copy['__slots__']: + dict_copy.pop(slot, None) + return dict_copy + + _marker = object() @@ -143,7 +148,32 @@ def _collect_type_vars(types, typevar_types=None): return tuple(tvars) -NoReturn = typing.NoReturn +# 3.6.2+ +if hasattr(typing, 'NoReturn'): + NoReturn = typing.NoReturn +# 3.6.0-3.6.1 +else: + class _NoReturn(typing._FinalTypingBase, _root=True): + """Special type indicating functions that never return. + Example:: + + from typing import NoReturn + + def stop() -> NoReturn: + raise Exception('no way') + + This type is invalid in other positions, e.g., ``List[NoReturn]`` + will fail in static type checkers. 
+ """ + __slots__ = () + + def __instancecheck__(self, obj): + raise TypeError("NoReturn cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError("NoReturn cannot be used with issubclass().") + + NoReturn = _NoReturn(_root=True) # Some unconstrained type variables. These are used by the container types. # (These are not for export.) @@ -153,37 +183,6 @@ def _collect_type_vars(types, typevar_types=None): T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. - -if sys.version_info >= (3, 11): - from typing import Any -else: - - class _AnyMeta(type): - def __instancecheck__(self, obj): - if self is Any: - raise TypeError("typing_extensions.Any cannot be used with isinstance()") - return super().__instancecheck__(obj) - - def __repr__(self): - if self is Any: - return "typing_extensions.Any" - return super().__repr__() - - class Any(metaclass=_AnyMeta): - """Special type indicating an unconstrained type. - - Any is compatible with every type. - - Any assumed to have all methods. - - All values assumed to be instances of Any. - Note that all the above statements are true from the point of view of - static type checkers. At runtime, Any should not be used with instance - checks. - """ - def __new__(cls, *args, **kwargs): - if cls is Any: - raise TypeError("Any cannot be instantiated") - return super().__new__(cls, *args, **kwargs) - - ClassVar = typing.ClassVar # On older versions of typing there is an internal class named "Final". 
@@ -191,7 +190,7 @@ def __new__(cls, *args, **kwargs): if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): Final = typing.Final # 3.7 -else: +elif sys.version_info[:2] >= (3, 7): class _FinalForm(typing._SpecialForm, _root=True): def __repr__(self): @@ -199,7 +198,7 @@ def __repr__(self): def __getitem__(self, parameters): item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + f'{self._name} accepts only single type') return typing._GenericAlias(self, (item,)) Final = _FinalForm('Final', @@ -216,6 +215,61 @@ class FastConnector(Connection): TIMEOUT = 1 # Error reported by type checker There is no runtime checking of these properties.""") +# 3.6 +else: + class _Final(typing._FinalTypingBase, _root=True): + """A special typing construct to indicate that a name + cannot be re-assigned or overridden in a subclass. + For example: + + MAX_SIZE: Final = 9000 + MAX_SIZE += 1 # Error reported by type checker + + class Connection: + TIMEOUT: Final[int] = 10 + class FastConnector(Connection): + TIMEOUT = 1 # Error reported by type checker + + There is no runtime checking of these properties. 
+ """ + + __slots__ = ('__type__',) + + def __init__(self, tp=None, **kwds): + self.__type__ = tp + + def __getitem__(self, item): + cls = type(self) + if self.__type__ is None: + return cls(typing._type_check(item, + f'{cls.__name__[1:]} accepts only single type.'), + _root=True) + raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + + def _eval_type(self, globalns, localns): + new_tp = typing._eval_type(self.__type__, globalns, localns) + if new_tp == self.__type__: + return self + return type(self)(new_tp, _root=True) + + def __repr__(self): + r = super().__repr__() + if self.__type__ is not None: + r += f'[{typing._type_repr(self.__type__)}]' + return r + + def __hash__(self): + return hash((type(self).__name__, self.__type__)) + + def __eq__(self, other): + if not isinstance(other, _Final): + return NotImplemented + if self.__type__ is not None: + return self.__type__ == other.__type__ + return self is other + + Final = _Final(_root=True) + if sys.version_info >= (3, 11): final = typing.final @@ -263,7 +317,7 @@ def IntVar(name): if hasattr(typing, 'Literal'): Literal = typing.Literal # 3.7: -else: +elif sys.version_info[:2] >= (3, 7): class _LiteralForm(typing._SpecialForm, _root=True): def __repr__(self): @@ -285,75 +339,59 @@ def __getitem__(self, parameters): Literal[...] cannot be subclassed. There is no runtime checking verifying that the parameter is actually a value instead of a type.""") +# 3.6: +else: + class _Literal(typing._FinalTypingBase, _root=True): + """A type that can be used to indicate to type checkers that the + corresponding value has a value literally equivalent to the + provided parameter. For example: + var: Literal[4] = 4 -_overload_dummy = typing._overload_dummy # noqa - + The type checker understands that 'var' is literally equal to the + value 4 and no other value. 
-if hasattr(typing, "get_overloads"): # 3.11+ - overload = typing.overload - get_overloads = typing.get_overloads - clear_overloads = typing.clear_overloads -else: - # {module: {qualname: {firstlineno: func}}} - _overload_registry = collections.defaultdict( - functools.partial(collections.defaultdict, dict) - ) - - def overload(func): - """Decorator for overloaded functions/methods. - - In a stub file, place two or more stub definitions for the same - function in a row, each decorated with @overload. For example: - - @overload - def utf8(value: None) -> None: ... - @overload - def utf8(value: bytes) -> bytes: ... - @overload - def utf8(value: str) -> bytes: ... - - In a non-stub file (i.e. a regular .py file), do the same but - follow it with an implementation. The implementation should *not* - be decorated with @overload. For example: - - @overload - def utf8(value: None) -> None: ... - @overload - def utf8(value: bytes) -> bytes: ... - @overload - def utf8(value: str) -> bytes: ... - def utf8(value): - # implementation goes here - - The overloads for a function can be retrieved at runtime using the - get_overloads() function. + Literal[...] cannot be subclassed. There is no runtime checking + verifying that the parameter is actually a value instead of a type. """ - # classmethod and staticmethod - f = getattr(func, "__func__", func) - try: - _overload_registry[f.__module__][f.__qualname__][ - f.__code__.co_firstlineno - ] = func - except AttributeError: - # Not a normal function; ignore. 
- pass - return _overload_dummy - def get_overloads(func): - """Return all defined overloads for *func* as a sequence.""" - # classmethod and staticmethod - f = getattr(func, "__func__", func) - if f.__module__ not in _overload_registry: - return [] - mod_dict = _overload_registry[f.__module__] - if f.__qualname__ not in mod_dict: - return [] - return list(mod_dict[f.__qualname__].values()) + __slots__ = ('__values__',) + + def __init__(self, values=None, **kwds): + self.__values__ = values + + def __getitem__(self, values): + cls = type(self) + if self.__values__ is None: + if not isinstance(values, tuple): + values = (values,) + return cls(values, _root=True) + raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + + def _eval_type(self, globalns, localns): + return self + + def __repr__(self): + r = super().__repr__() + if self.__values__ is not None: + r += f'[{", ".join(map(typing._type_repr, self.__values__))}]' + return r + + def __hash__(self): + return hash((type(self).__name__, self.__values__)) + + def __eq__(self, other): + if not isinstance(other, _Literal): + return NotImplemented + if self.__values__ is not None: + return self.__values__ == other.__values__ + return self is other + + Literal = _Literal(_root=True) - def clear_overloads(): - """Clear all overloads in the registry.""" - _overload_registry.clear() + +_overload_dummy = typing._overload_dummy # noqa +overload = typing.overload # This is not a real generic class. Don't use outside annotations. @@ -363,30 +401,154 @@ def clear_overloads(): # A few are simply re-exported for completeness. +class _ExtensionsGenericMeta(GenericMeta): + def __subclasscheck__(self, subclass): + """This mimics a more modern GenericMeta.__subclasscheck__() logic + (that does not have problems with recursion) to work around interactions + between collections, typing, and typing_extensions on older + versions of Python, see https://github.com/python/typing/issues/501. 
+ """ + if self.__origin__ is not None: + if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + raise TypeError("Parameterized generics cannot be used with class " + "or instance checks") + return False + if not self.__extra__: + return super().__subclasscheck__(subclass) + res = self.__extra__.__subclasshook__(subclass) + if res is not NotImplemented: + return res + if self.__extra__ in subclass.__mro__: + return True + for scls in self.__extra__.__subclasses__(): + if isinstance(scls, GenericMeta): + continue + if issubclass(subclass, scls): + return True + return False + + Awaitable = typing.Awaitable Coroutine = typing.Coroutine AsyncIterable = typing.AsyncIterable AsyncIterator = typing.AsyncIterator -Deque = typing.Deque + +# 3.6.1+ +if hasattr(typing, 'Deque'): + Deque = typing.Deque +# 3.6.0 +else: + class Deque(collections.deque, typing.MutableSequence[T], + metaclass=_ExtensionsGenericMeta, + extra=collections.deque): + __slots__ = () + + def __new__(cls, *args, **kwds): + if cls._gorg is Deque: + return collections.deque(*args, **kwds) + return typing._generic_new(collections.deque, cls, *args, **kwds) + ContextManager = typing.ContextManager -AsyncContextManager = typing.AsyncContextManager +# 3.6.2+ +if hasattr(typing, 'AsyncContextManager'): + AsyncContextManager = typing.AsyncContextManager +# 3.6.0-3.6.1 +else: + from _collections_abc import _check_methods as _check_methods_in_mro # noqa + + class AsyncContextManager(typing.Generic[T_co]): + __slots__ = () + + async def __aenter__(self): + return self + + @abc.abstractmethod + async def __aexit__(self, exc_type, exc_value, traceback): + return None + + @classmethod + def __subclasshook__(cls, C): + if cls is AsyncContextManager: + return _check_methods_in_mro(C, "__aenter__", "__aexit__") + return NotImplemented + DefaultDict = typing.DefaultDict # 3.7.2+ if hasattr(typing, 'OrderedDict'): OrderedDict = typing.OrderedDict # 3.7.0-3.7.2 -else: +elif (3, 7, 0) <= 
sys.version_info[:3] < (3, 7, 2): OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) +# 3.6 +else: + class OrderedDict(collections.OrderedDict, typing.MutableMapping[KT, VT], + metaclass=_ExtensionsGenericMeta, + extra=collections.OrderedDict): + + __slots__ = () + + def __new__(cls, *args, **kwds): + if cls._gorg is OrderedDict: + return collections.OrderedDict(*args, **kwds) + return typing._generic_new(collections.OrderedDict, cls, *args, **kwds) + +# 3.6.2+ +if hasattr(typing, 'Counter'): + Counter = typing.Counter +# 3.6.0-3.6.1 +else: + class Counter(collections.Counter, + typing.Dict[T, int], + metaclass=_ExtensionsGenericMeta, extra=collections.Counter): + + __slots__ = () + + def __new__(cls, *args, **kwds): + if cls._gorg is Counter: + return collections.Counter(*args, **kwds) + return typing._generic_new(collections.Counter, cls, *args, **kwds) + +# 3.6.1+ +if hasattr(typing, 'ChainMap'): + ChainMap = typing.ChainMap +elif hasattr(collections, 'ChainMap'): + class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT], + metaclass=_ExtensionsGenericMeta, + extra=collections.ChainMap): + + __slots__ = () + + def __new__(cls, *args, **kwds): + if cls._gorg is ChainMap: + return collections.ChainMap(*args, **kwds) + return typing._generic_new(collections.ChainMap, cls, *args, **kwds) + +# 3.6.1+ +if hasattr(typing, 'AsyncGenerator'): + AsyncGenerator = typing.AsyncGenerator +# 3.6.0 +else: + class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra], + metaclass=_ExtensionsGenericMeta, + extra=collections.abc.AsyncGenerator): + __slots__ = () -Counter = typing.Counter -ChainMap = typing.ChainMap -AsyncGenerator = typing.AsyncGenerator NewType = typing.NewType Text = typing.Text TYPE_CHECKING = typing.TYPE_CHECKING +def _gorg(cls): + """This function exists for compatibility with old typing versions.""" + assert isinstance(cls, GenericMeta) + if hasattr(cls, '_gorg'): + return cls._gorg + while cls.__origin__ is not None: + 
cls = cls.__origin__ + return cls + + _PROTO_WHITELIST = ['Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', @@ -416,57 +578,17 @@ def _is_callable_members_only(cls): return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) -def _maybe_adjust_parameters(cls): - """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. - - The contents of this function are very similar - to logic found in typing.Generic.__init_subclass__ - on the CPython main branch. - """ - tvars = [] - if '__orig_bases__' in cls.__dict__: - tvars = typing._collect_type_vars(cls.__orig_bases__) - # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. - # If found, tvars must be a subset of it. - # If not found, tvars is it. - # Also check for and reject plain Generic, - # and reject multiple Generic[...] and/or Protocol[...]. - gvars = None - for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): - # for error messages - the_base = base.__origin__.__name__ - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...]" - " and/or Protocol[...] 
multiple types.") - gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") - tvars = gvars - cls.__parameters__ = tuple(tvars) - - # 3.8+ if hasattr(typing, 'Protocol'): Protocol = typing.Protocol # 3.7 -else: +elif PEP_560: def _no_init(self, *args, **kwargs): if type(self)._is_protocol: raise TypeError('Protocols cannot be instantiated') - class _ProtocolMeta(abc.ABCMeta): # noqa: B024 + class _ProtocolMeta(abc.ABCMeta): # This metaclass is a bit unfortunate and exists only because of the lack # of __instancehook__. def __instancecheck__(cls, instance): @@ -552,13 +674,43 @@ def __class_getitem__(cls, params): return typing._GenericAlias(cls, params) def __init_subclass__(cls, *args, **kwargs): + tvars = [] if '__orig_bases__' in cls.__dict__: error = typing.Generic in cls.__orig_bases__ else: error = typing.Generic in cls.__bases__ if error: raise TypeError("Cannot inherit from plain Generic") - _maybe_adjust_parameters(cls) + if '__orig_bases__' in cls.__dict__: + tvars = typing._collect_type_vars(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...] and/or Protocol[...]. + gvars = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (typing.Generic, Protocol)): + # for error messages + the_base = base.__origin__.__name__ + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...]" + " and/or Protocol[...] 
multiple types.") + gvars = base.__parameters__ + if gvars is None: + gvars = tvars + else: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]") + tvars = gvars + cls.__parameters__ = tuple(tvars) # Determine if this is a protocol or a concrete subclass. if not cls.__dict__.get('_is_protocol', None): @@ -612,12 +764,250 @@ def _proto_hook(other): raise TypeError('Protocols can only inherit from other' f' protocols, got {repr(base)}') cls.__init__ = _no_init +# 3.6 +else: + from typing import _next_in_mro, _type_check # noqa + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + class _ProtocolMeta(GenericMeta): + """Internal metaclass for Protocol. + + This exists so Protocol classes can be generic without deriving + from Generic. + """ + def __new__(cls, name, bases, namespace, + tvars=None, args=None, origin=None, extra=None, orig_bases=None): + # This is just a version copied from GenericMeta.__new__ that + # includes "Protocol" special treatment. (Comments removed for brevity.) + assert extra is None # Protocols should not have extra + if tvars is not None: + assert origin is not None + assert all(isinstance(t, typing.TypeVar) for t in tvars), tvars + else: + tvars = _type_vars(bases) + gvars = None + for base in bases: + if base is typing.Generic: + raise TypeError("Cannot inherit from plain Generic") + if (isinstance(base, GenericMeta) and + base.__origin__ in (typing.Generic, Protocol)): + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...] or" + " Protocol[...] 
multiple times.") + gvars = base.__parameters__ + if gvars is None: + gvars = tvars + else: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ", ".join(str(t) for t in tvars if t not in gvarset) + s_args = ", ".join(str(g) for g in gvars) + cls_name = "Generic" if any(b.__origin__ is typing.Generic + for b in bases) else "Protocol" + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in {cls_name}[{s_args}]") + tvars = gvars + + initial_bases = bases + if (extra is not None and type(extra) is abc.ABCMeta and + extra not in bases): + bases = (extra,) + bases + bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b + for b in bases) + if any(isinstance(b, GenericMeta) and b is not typing.Generic for b in bases): + bases = tuple(b for b in bases if b is not typing.Generic) + namespace.update({'__origin__': origin, '__extra__': extra}) + self = super(GenericMeta, cls).__new__(cls, name, bases, namespace, + _root=True) + super(GenericMeta, self).__setattr__('_gorg', + self if not origin else + _gorg(origin)) + self.__parameters__ = tvars + self.__args__ = tuple(... 
if a is typing._TypingEllipsis else + () if a is typing._TypingEmpty else + a for a in args) if args else None + self.__next_in_mro__ = _next_in_mro(self) + if orig_bases is None: + self.__orig_bases__ = initial_bases + elif origin is not None: + self._abc_registry = origin._abc_registry + self._abc_cache = origin._abc_cache + if hasattr(self, '_subs_tree'): + self.__tree_hash__ = (hash(self._subs_tree()) if origin else + super(GenericMeta, self).__hash__()) + return self + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + isinstance(b, _ProtocolMeta) and + b.__origin__ is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (object, typing.Generic) or + base.__module__ == 'collections.abc' and + base.__name__ in _PROTO_WHITELIST or + isinstance(base, typing.TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and + base.__origin__ is typing.Generic): + raise TypeError(f'Protocols can only inherit from other' + f' protocols, got {repr(base)}') + + cls.__init__ = _no_init + + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + if not isinstance(other, type): + # Same error as for issubclass(1, int) + raise TypeError('issubclass() arg 1 must be a class') + for attr in _get_protocol_attrs(cls): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + annotations = getattr(base, '__annotations__', {}) + if (isinstance(annotations, typing.Mapping) and + attr in annotations and + isinstance(other, _ProtocolMeta) and + other._is_protocol): + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + def __instancecheck__(self, instance): + # We need this method for situations where attributes are + 
# assigned in __init__. + if ((not getattr(self, '_is_protocol', False) or + _is_callable_members_only(self)) and + issubclass(instance.__class__, self)): + return True + if self._is_protocol: + if all(hasattr(instance, attr) and + (not callable(getattr(self, attr, None)) or + getattr(instance, attr) is not None) + for attr in _get_protocol_attrs(self)): + return True + return super(GenericMeta, self).__instancecheck__(instance) + + def __subclasscheck__(self, cls): + if self.__origin__ is not None: + if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + raise TypeError("Parameterized generics cannot be used with class " + "or instance checks") + return False + if (self.__dict__.get('_is_protocol', None) and + not self.__dict__.get('_is_runtime_protocol', None)): + if sys._getframe(1).f_globals['__name__'] in ['abc', + 'functools', + 'typing']: + return False + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") + if (self.__dict__.get('_is_runtime_protocol', None) and + not _is_callable_members_only(self)): + if sys._getframe(1).f_globals['__name__'] in ['abc', + 'functools', + 'typing']: + return super(GenericMeta, self).__subclasscheck__(cls) + raise TypeError("Protocols with non-method members" + " don't support issubclass()") + return super(GenericMeta, self).__subclasscheck__(cls) + + @typing._tp_cache + def __getitem__(self, params): + # We also need to copy this from GenericMeta.__getitem__ to get + # special treatment of "Protocol". (Comments removed for brevity.) + if not isinstance(params, tuple): + params = (params,) + if not params and _gorg(self) is not typing.Tuple: + raise TypeError( + f"Parameter list to {self.__qualname__}[...] cannot be empty") + msg = "Parameters to generic types must be types." 
+ params = tuple(_type_check(p, msg) for p in params) + if self in (typing.Generic, Protocol): + if not all(isinstance(p, typing.TypeVar) for p in params): + raise TypeError( + f"Parameters to {repr(self)}[...] must all be type variables") + if len(set(params)) != len(params): + raise TypeError( + f"Parameters to {repr(self)}[...] must all be unique") + tvars = params + args = params + elif self in (typing.Tuple, typing.Callable): + tvars = _type_vars(params) + args = params + elif self.__origin__ in (typing.Generic, Protocol): + raise TypeError(f"Cannot subscript already-subscripted {repr(self)}") + else: + _check_generic(self, params, len(self.__parameters__)) + tvars = _type_vars(params) + args = params + + prepend = (self,) if self.__origin__ is None else () + return self.__class__(self.__name__, + prepend + self.__bases__, + _no_slots_copy(self.__dict__), + tvars=tvars, + args=args, + origin=self, + extra=self.__extra__, + orig_bases=self.__orig_bases__) + + class Protocol(metaclass=_ProtocolMeta): + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing_extensions.runtime act as simple-minded runtime protocol that checks + only the presence of given attributes, ignoring their type signatures. + + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... 
+ """ + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if _gorg(cls) is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return typing._generic_new(cls.__next_in_mro__, cls, *args, **kwds) # 3.8+ if hasattr(typing, 'runtime_checkable'): runtime_checkable = typing.runtime_checkable -# 3.7 +# 3.6-3.7 else: def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it @@ -641,7 +1031,7 @@ def runtime_checkable(cls): # 3.8+ if hasattr(typing, 'SupportsIndex'): SupportsIndex = typing.SupportsIndex -# 3.7 +# 3.6-3.7 else: @runtime_checkable class SupportsIndex(Protocol): @@ -659,7 +1049,6 @@ def __index__(self) -> int: # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 # The standard library TypedDict below Python 3.11 does not store runtime # information about optional and required keys when using Required or NotRequired. - # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. TypedDict = typing.TypedDict _TypedDictMeta = typing._TypedDictMeta is_typeddict = typing.is_typeddict @@ -730,8 +1119,6 @@ def _typeddict_new(*args, total=True, **kwargs): _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' ' /, *, total=True, **kwargs)') - _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters - class _TypedDictMeta(type): def __init__(cls, name, bases, ns, total=True): super().__init__(name, bases, ns) @@ -744,23 +1131,13 @@ def __new__(cls, name, bases, ns, total=True): # Subclasses and instances of TypedDict return actual dictionaries # via _dict_new. ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new - # Don't insert typing.Generic into __bases__ here, - # or Generic.__init_subclass__ will raise TypeError - # in the super().__new__() call. - # Instead, monkey-patch __bases__ onto the class after it's been created. 
tp_dict = super().__new__(cls, name, (dict,), ns) - if any(issubclass(base, typing.Generic) for base in bases): - tp_dict.__bases__ = (typing.Generic, dict) - _maybe_adjust_parameters(tp_dict) - annotations = {} own_annotations = ns.get('__annotations__', {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" - kwds = {"module": tp_dict.__module__} if _TAKES_MODULE else {} own_annotations = { - n: typing._type_check(tp, msg, **kwds) - for n, tp in own_annotations.items() + n: typing._type_check(tp, msg) for n, tp in own_annotations.items() } required_keys = set() optional_keys = set() @@ -771,22 +1148,29 @@ def __new__(cls, name, bases, ns, total=True): optional_keys.update(base.__dict__.get('__optional_keys__', ())) annotations.update(own_annotations) - for annotation_key, annotation_type in own_annotations.items(): - annotation_origin = get_origin(annotation_type) - if annotation_origin is Annotated: - annotation_args = get_args(annotation_type) - if annotation_args: - annotation_type = annotation_args[0] - annotation_origin = get_origin(annotation_type) - - if annotation_origin is Required: - required_keys.add(annotation_key) - elif annotation_origin is NotRequired: - optional_keys.add(annotation_key) - elif total: - required_keys.add(annotation_key) + if PEP_560: + for annotation_key, annotation_type in own_annotations.items(): + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + annotation_origin = get_origin(annotation_type) + + if annotation_origin is Required: + required_keys.add(annotation_key) + elif annotation_origin is NotRequired: + optional_keys.add(annotation_key) + elif total: + required_keys.add(annotation_key) + else: + optional_keys.add(annotation_key) + else: + own_annotation_keys = set(own_annotations.keys()) + if total: + required_keys.update(own_annotation_keys) else: - 
optional_keys.add(annotation_key) + optional_keys.update(own_annotation_keys) tp_dict.__annotations__ = annotations tp_dict.__required_keys__ = frozenset(required_keys) @@ -847,30 +1231,9 @@ class Film(TypedDict): """ return isinstance(tp, tuple(_TYPEDDICT_TYPES)) - -if hasattr(typing, "assert_type"): - assert_type = typing.assert_type - -else: - def assert_type(__val, __typ): - """Assert (to the type checker) that the value is of the given type. - - When the type checker encounters a call to assert_type(), it - emits an error if the value is not of the specified type:: - - def greet(name: str) -> None: - assert_type(name, str) # ok - assert_type(name, int) # type checker error - - At runtime this returns the first argument unchanged and otherwise - does nothing. - """ - return __val - - if hasattr(typing, "Required"): get_type_hints = typing.get_type_hints -else: +elif PEP_560: import functools import types @@ -949,7 +1312,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): # to work. _AnnotatedAlias = typing._AnnotatedAlias # 3.7-3.8 -else: +elif PEP_560: class _AnnotatedAlias(typing._GenericAlias, _root=True): """Runtime representation of an annotated type. @@ -1046,45 +1409,191 @@ def __init_subclass__(cls, *args, **kwargs): raise TypeError( f"Cannot subclass {cls.__module__}.Annotated" ) - -# Python 3.8 has get_origin() and get_args() but those implementations aren't -# Annotated-aware, so we can't use those. Python 3.9's versions don't support -# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. 
-if sys.version_info[:2] >= (3, 10): - get_origin = typing.get_origin - get_args = typing.get_args -# 3.7-3.9 +# 3.6 else: - try: - # 3.9+ - from typing import _BaseGenericAlias - except ImportError: - _BaseGenericAlias = typing._GenericAlias - try: - # 3.9+ - from typing import GenericAlias as _typing_GenericAlias - except ImportError: - _typing_GenericAlias = typing._GenericAlias - def get_origin(tp): - """Get the unsubscripted version of a type. + def _is_dunder(name): + """Returns True if name is a __dunder_variable_name__.""" + return len(name) > 4 and name.startswith('__') and name.endswith('__') - This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar - and Annotated. Return None for unsupported types. Examples:: + # Prior to Python 3.7 types did not have `copy_with`. A lot of the equality + # checks, argument expansion etc. are done on the _subs_tre. As a result we + # can't provide a get_type_hints function that strips out annotations. - get_origin(Literal[42]) is Literal - get_origin(int) is None - get_origin(ClassVar[int]) is ClassVar - get_origin(Generic) is Generic - get_origin(Generic[T]) is Generic - get_origin(Union[T, int]) is Union - get_origin(List[Tuple[T, T]][int]) == list - get_origin(P.args) is P - """ - if isinstance(tp, _AnnotatedAlias): - return Annotated - if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, - ParamSpecArgs, ParamSpecKwargs)): + class AnnotatedMeta(typing.GenericMeta): + """Metaclass for Annotated""" + + def __new__(cls, name, bases, namespace, **kwargs): + if any(b is not object for b in bases): + raise TypeError("Cannot subclass " + str(Annotated)) + return super().__new__(cls, name, bases, namespace, **kwargs) + + @property + def __metadata__(self): + return self._subs_tree()[2] + + def _tree_repr(self, tree): + cls, origin, metadata = tree + if not isinstance(origin, tuple): + tp_repr = typing._type_repr(origin) + else: + tp_repr = origin[0]._tree_repr(origin) + 
metadata_reprs = ", ".join(repr(arg) for arg in metadata) + return f'{cls}[{tp_repr}, {metadata_reprs}]' + + def _subs_tree(self, tvars=None, args=None): # noqa + if self is Annotated: + return Annotated + res = super()._subs_tree(tvars=tvars, args=args) + # Flatten nested Annotated + if isinstance(res[1], tuple) and res[1][0] is Annotated: + sub_tp = res[1][1] + sub_annot = res[1][2] + return (Annotated, sub_tp, sub_annot + res[2]) + return res + + def _get_cons(self): + """Return the class used to create instance of this type.""" + if self.__origin__ is None: + raise TypeError("Cannot get the underlying type of a " + "non-specialized Annotated type.") + tree = self._subs_tree() + while isinstance(tree, tuple) and tree[0] is Annotated: + tree = tree[1] + if isinstance(tree, tuple): + return tree[0] + else: + return tree + + @typing._tp_cache + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + if self.__origin__ is not None: # specializing an instantiated type + return super().__getitem__(params) + elif not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be instantiated " + "with at least two arguments (a type and an " + "annotation).") + else: + if ( + isinstance(params[0], typing._TypingBase) and + type(params[0]).__name__ == "_ClassVar" + ): + tp = params[0] + else: + msg = "Annotated[t, ...]: t must be a type." + tp = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return self.__class__( + self.__name__, + self.__bases__, + _no_slots_copy(self.__dict__), + tvars=_type_vars((tp,)), + # Metadata is a tuple so it won't be touched by _replace_args et al. 
+ args=(tp, metadata), + origin=self, + ) + + def __call__(self, *args, **kwargs): + cons = self._get_cons() + result = cons(*args, **kwargs) + try: + result.__orig_class__ = self + except AttributeError: + pass + return result + + def __getattr__(self, attr): + # For simplicity we just don't relay all dunder names + if self.__origin__ is not None and not _is_dunder(attr): + return getattr(self._get_cons(), attr) + raise AttributeError(attr) + + def __setattr__(self, attr, value): + if _is_dunder(attr) or attr.startswith('_abc_'): + super().__setattr__(attr, value) + elif self.__origin__ is None: + raise AttributeError(attr) + else: + setattr(self._get_cons(), attr, value) + + def __instancecheck__(self, obj): + raise TypeError("Annotated cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError("Annotated cannot be used with issubclass().") + + class Annotated(metaclass=AnnotatedMeta): + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type, the remaining + arguments are kept as a tuple in the __metadata__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. 
+ - Nested Annotated are flattened:: + + Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize()] + Optimized[int] == Annotated[int, runtime.Optimize()] + + OptimizedList = Annotated[List[T], runtime.Optimize()] + OptimizedList[int] == Annotated[List[int], runtime.Optimize()] + """ + +# Python 3.8 has get_origin() and get_args() but those implementations aren't +# Annotated-aware, so we can't use those. Python 3.9's versions don't support +# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. +if sys.version_info[:2] >= (3, 10): + get_origin = typing.get_origin + get_args = typing.get_args +# 3.7-3.9 +elif PEP_560: + try: + # 3.9+ + from typing import _BaseGenericAlias + except ImportError: + _BaseGenericAlias = typing._GenericAlias + try: + # 3.9+ + from typing import GenericAlias + except ImportError: + GenericAlias = typing._GenericAlias + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar + and Annotated. Return None for unsupported types. 
Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + get_origin(P.args) is P + """ + if isinstance(tp, _AnnotatedAlias): + return Annotated + if isinstance(tp, (typing._GenericAlias, GenericAlias, _BaseGenericAlias, + ParamSpecArgs, ParamSpecKwargs)): return tp.__origin__ if tp is typing.Generic: return typing.Generic @@ -1103,7 +1612,7 @@ def get_args(tp): """ if isinstance(tp, _AnnotatedAlias): return (tp.__origin__,) + tp.__metadata__ - if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias)): + if isinstance(tp, (typing._GenericAlias, GenericAlias)): if getattr(tp, "_special", False): return () res = tp.__args__ @@ -1136,7 +1645,7 @@ def TypeAlias(self, parameters): """ raise TypeError(f"{self} is not subscriptable") # 3.7-3.8 -else: +elif sys.version_info[:2] >= (3, 7): class _TypeAliasForm(typing._SpecialForm, _root=True): def __repr__(self): return 'typing_extensions.' + self._name @@ -1152,51 +1661,44 @@ def __repr__(self): It's invalid when used anywhere except as in the example above.""") +# 3.6 +else: + class _TypeAliasMeta(typing.TypingMeta): + """Metaclass for TypeAlias""" + def __repr__(self): + return 'typing_extensions.TypeAlias' -class _DefaultMixin: - """Mixin for TypeVarLike defaults.""" + class _TypeAliasBase(typing._FinalTypingBase, metaclass=_TypeAliasMeta, _root=True): + """Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. 
- __slots__ = () + For example:: - def __init__(self, default): - if isinstance(default, (tuple, list)): - self.__default__ = tuple((typing._type_check(d, "Default must be a type") - for d in default)) - elif default != _marker: - self.__default__ = typing._type_check(default, "Default must be a type") - else: - self.__default__ = None + Predicate: TypeAlias = Callable[..., bool] + It's invalid when used anywhere except as in the example above. + """ + __slots__ = () -# Add default and infer_variance parameters from PEP 696 and 695 -class TypeVar(typing.TypeVar, _DefaultMixin, _root=True): - """Type variable.""" + def __instancecheck__(self, obj): + raise TypeError("TypeAlias cannot be used with isinstance().") - __module__ = 'typing' + def __subclasscheck__(self, cls): + raise TypeError("TypeAlias cannot be used with issubclass().") - def __init__(self, name, *constraints, bound=None, - covariant=False, contravariant=False, - default=_marker, infer_variance=False): - super().__init__(name, *constraints, bound=bound, covariant=covariant, - contravariant=contravariant) - _DefaultMixin.__init__(self, default) - self.__infer_variance__ = infer_variance + def __repr__(self): + return 'typing_extensions.TypeAlias' - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + TypeAlias = _TypeAliasBase(_root=True) # Python 3.10+ has PEP 612 if hasattr(typing, 'ParamSpecArgs'): ParamSpecArgs = typing.ParamSpecArgs ParamSpecKwargs = typing.ParamSpecKwargs -# 3.7-3.9 +# 3.6-3.9 else: class _Immutable: """Mixin to indicate that object should not be copied.""" @@ -1256,32 +1758,12 @@ def __eq__(self, other): # 3.10+ if hasattr(typing, 'ParamSpec'): - - # Add default Parameter - PEP 696 - class ParamSpec(typing.ParamSpec, _DefaultMixin, _root=True): - """Parameter specification variable.""" - - __module__ = 'typing' - - def 
__init__(self, name, *, bound=None, covariant=False, contravariant=False, - default=_marker): - super().__init__(name, bound=bound, covariant=covariant, - contravariant=contravariant) - _DefaultMixin.__init__(self, default) - - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod - -# 3.7-3.9 + ParamSpec = typing.ParamSpec +# 3.6-3.9 else: # Inherits from list as a workaround for Callable checks in Python < 3.9.2. - class ParamSpec(list, _DefaultMixin): + class ParamSpec(list): """Parameter specification variable. Usage:: @@ -1339,8 +1821,7 @@ def args(self): def kwargs(self): return ParamSpecKwargs(self) - def __init__(self, name, *, bound=None, covariant=False, contravariant=False, - default=_marker): + def __init__(self, name, *, bound=None, covariant=False, contravariant=False): super().__init__([self]) self.__name__ = name self.__covariant__ = bool(covariant) @@ -1349,7 +1830,6 @@ def __init__(self, name, *, bound=None, covariant=False, contravariant=False, self.__bound__ = typing._type_check(bound, 'Bound must be a type.') else: self.__bound__ = None - _DefaultMixin.__init__(self, default) # for pickling: try: @@ -1381,17 +1861,28 @@ def __reduce__(self): def __call__(self, *args, **kwargs): pass + if not PEP_560: + # Only needed in 3.6. + def _get_type_vars(self, tvars): + if self not in tvars: + tvars.append(self) -# 3.7-3.9 + +# 3.6-3.9 if not hasattr(typing, 'Concatenate'): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. class _ConcatenateGenericAlias(list): # Trick Generic into looking into this for __parameters__. - __class__ = typing._GenericAlias + if PEP_560: + __class__ = typing._GenericAlias + else: + __class__ = typing._TypingBase # Flag in 3.8. _special = False + # Attribute in 3.6 and earlier. 
+ _gorg = typing.Generic def __init__(self, origin, args): super().__init__(args) @@ -1416,8 +1907,14 @@ def __parameters__(self): tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) ) + if not PEP_560: + # Only required in 3.6. + def _get_type_vars(self, tvars): + if self.__origin__ and self.__parameters__: + typing._get_type_vars(self.__parameters__, tvars) -# 3.7-3.9 + +# 3.6-3.9 @typing._tp_cache def _concatenate_getitem(self, parameters): if parameters == (): @@ -1452,7 +1949,7 @@ def Concatenate(self, parameters): """ return _concatenate_getitem(self, parameters) # 3.7-8 -else: +elif sys.version_info[:2] >= (3, 7): class _ConcatenateForm(typing._SpecialForm, _root=True): def __repr__(self): return 'typing_extensions.' + self._name @@ -1472,6 +1969,42 @@ def __getitem__(self, parameters): See PEP 612 for detailed information. """) +# 3.6 +else: + class _ConcatenateAliasMeta(typing.TypingMeta): + """Metaclass for Concatenate.""" + + def __repr__(self): + return 'typing_extensions.Concatenate' + + class _ConcatenateAliasBase(typing._FinalTypingBase, + metaclass=_ConcatenateAliasMeta, + _root=True): + """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """ + __slots__ = () + + def __instancecheck__(self, obj): + raise TypeError("Concatenate cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError("Concatenate cannot be used with issubclass().") + + def __repr__(self): + return 'typing_extensions.Concatenate' + + def __getitem__(self, parameters): + return _concatenate_getitem(self, parameters) + + Concatenate = _ConcatenateAliasBase(_root=True) # 3.10+ if hasattr(typing, 'TypeGuard'): @@ -1526,10 +2059,10 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. 
For more information, see PEP 647 (User-Defined Type Guards). """ - item = typing._type_check(parameters, f'{self} accepts only a single type.') + item = typing._type_check(parameters, f'{self} accepts only single type.') return typing._GenericAlias(self, (item,)) # 3.7-3.8 -else: +elif sys.version_info[:2] >= (3, 7): class _TypeGuardForm(typing._SpecialForm, _root=True): def __repr__(self): @@ -1584,55 +2117,138 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). """) +# 3.6 +else: + class _TypeGuard(typing._FinalTypingBase, _root=True): + """Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". -# Vendored from cpython typing._SpecialFrom -class _SpecialForm(typing._Final, _root=True): - __slots__ = ('_name', '__doc__', '_getitem') + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. - def __init__(self, getitem): - self._getitem = getitem - self._name = getitem.__name__ - self.__doc__ = getitem.__doc__ + Using ``-> TypeGuard`` tells the static type checker that for a given + function: - def __getattr__(self, item): - if item in {'__name__', '__qualname__'}: - return self._name + 1. The return value is a boolean. + 2. 
If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. - raise AttributeError(item) + For example:: - def __mro_entries__(self, bases): - raise TypeError(f"Cannot subclass {self!r}") + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... - def __repr__(self): - return f'typing_extensions.{self._name}' + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. - def __reduce__(self): - return self._name + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). 
+ """ + + __slots__ = ('__type__',) + + def __init__(self, tp=None, **kwds): + self.__type__ = tp + + def __getitem__(self, item): + cls = type(self) + if self.__type__ is None: + return cls(typing._type_check(item, + f'{cls.__name__[1:]} accepts only a single type.'), + _root=True) + raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + + def _eval_type(self, globalns, localns): + new_tp = typing._eval_type(self.__type__, globalns, localns) + if new_tp == self.__type__: + return self + return type(self)(new_tp, _root=True) + + def __repr__(self): + r = super().__repr__() + if self.__type__ is not None: + r += f'[{typing._type_repr(self.__type__)}]' + return r + + def __hash__(self): + return hash((type(self).__name__, self.__type__)) + + def __eq__(self, other): + if not isinstance(other, _TypeGuard): + return NotImplemented + if self.__type__ is not None: + return self.__type__ == other.__type__ + return self is other + + TypeGuard = _TypeGuard(_root=True) + + +if sys.version_info[:2] >= (3, 7): + # Vendored from cpython typing._SpecialFrom + class _SpecialForm(typing._Final, _root=True): + __slots__ = ('_name', '__doc__', '_getitem') + + def __init__(self, getitem): + self._getitem = getitem + self._name = getitem.__name__ + self.__doc__ = getitem.__doc__ + + def __getattr__(self, item): + if item in {'__name__', '__qualname__'}: + return self._name + + raise AttributeError(item) + + def __mro_entries__(self, bases): + raise TypeError(f"Cannot subclass {self!r}") - def __call__(self, *args, **kwds): - raise TypeError(f"Cannot instantiate {self!r}") + def __repr__(self): + return f'typing_extensions.{self._name}' - def __or__(self, other): - return typing.Union[self, other] + def __reduce__(self): + return self._name - def __ror__(self, other): - return typing.Union[other, self] + def __call__(self, *args, **kwds): + raise TypeError(f"Cannot instantiate {self!r}") - def __instancecheck__(self, obj): - raise TypeError(f"{self} cannot be used with 
isinstance()") + def __or__(self, other): + return typing.Union[self, other] - def __subclasscheck__(self, cls): - raise TypeError(f"{self} cannot be used with issubclass()") + def __ror__(self, other): + return typing.Union[other, self] - @typing._tp_cache - def __getitem__(self, parameters): - return self._getitem(self, parameters) + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance()") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass()") + + @typing._tp_cache + def __getitem__(self, parameters): + return self._getitem(self, parameters) if hasattr(typing, "LiteralString"): LiteralString = typing.LiteralString -else: +elif sys.version_info[:2] >= (3, 7): @_SpecialForm def LiteralString(self, params): """Represents an arbitrary literal string. @@ -1651,11 +2267,38 @@ def query(sql: LiteralString) -> ...: """ raise TypeError(f"{self} is not subscriptable") +else: + class _LiteralString(typing._FinalTypingBase, _root=True): + """Represents an arbitrary literal string. + + Example:: + + from typing_extensions import LiteralString + + def query(sql: LiteralString) -> ...: + ... + + query("SELECT * FROM table") # ok + query(f"SELECT * FROM {input()}") # not ok + + See PEP 675 for details. + + """ + + __slots__ = () + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass().") + + LiteralString = _LiteralString(_root=True) if hasattr(typing, "Self"): Self = typing.Self -else: +elif sys.version_info[:2] >= (3, 7): @_SpecialForm def Self(self, params): """Used to spell the type of "self" in classes. @@ -1672,11 +2315,35 @@ def parse(self, data: bytes) -> Self: """ raise TypeError(f"{self} is not subscriptable") +else: + class _Self(typing._FinalTypingBase, _root=True): + """Used to spell the type of "self" in classes. 
+ + Example:: + + from typing import Self + + class ReturnsSelf: + def parse(self, data: bytes) -> Self: + ... + return self + + """ + + __slots__ = () + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass().") + + Self = _Self(_root=True) if hasattr(typing, "Never"): Never = typing.Never -else: +elif sys.version_info[:2] >= (3, 7): @_SpecialForm def Never(self, params): """The bottom type, a type that has no members. @@ -1702,6 +2369,39 @@ def int_or_str(arg: int | str) -> None: """ raise TypeError(f"{self} is not subscriptable") +else: + class _Never(typing._FinalTypingBase, _root=True): + """The bottom type, a type that has no members. + + This can be used to define a function that should never be + called, or a function that never returns:: + + from typing_extensions import Never + + def never_call_me(arg: Never) -> None: + pass + + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # ok, arg is of type Never + + """ + + __slots__ = () + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance().") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass().") + + Never = _Never(_root=True) if hasattr(typing, 'Required'): @@ -1729,7 +2429,7 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. 
""" - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check(parameters, f'{self._name} accepts only single type') return typing._GenericAlias(self, (item,)) @_ExtensionsSpecialForm @@ -1746,17 +2446,17 @@ class Movie(TypedDict): year=1999, ) """ - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check(parameters, f'{self._name} accepts only single type') return typing._GenericAlias(self, (item,)) -else: +elif sys.version_info[:2] >= (3, 7): class _RequiredForm(typing._SpecialForm, _root=True): def __repr__(self): return 'typing_extensions.' + self._name def __getitem__(self, parameters): item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + '{} accepts only single type'.format(self._name)) return typing._GenericAlias(self, (item,)) Required = _RequiredForm( @@ -1790,11 +2490,81 @@ class Movie(TypedDict): year=1999, ) """) +else: + # NOTE: Modeled after _Final's implementation when _FinalTypingBase available + class _MaybeRequired(typing._FinalTypingBase, _root=True): + __slots__ = ('__type__',) + + def __init__(self, tp=None, **kwds): + self.__type__ = tp + + def __getitem__(self, item): + cls = type(self) + if self.__type__ is None: + return cls(typing._type_check(item, + '{} accepts only single type.'.format(cls.__name__[1:])), + _root=True) + raise TypeError('{} cannot be further subscripted' + .format(cls.__name__[1:])) + + def _eval_type(self, globalns, localns): + new_tp = typing._eval_type(self.__type__, globalns, localns) + if new_tp == self.__type__: + return self + return type(self)(new_tp, _root=True) + def __repr__(self): + r = super().__repr__() + if self.__type__ is not None: + r += '[{}]'.format(typing._type_repr(self.__type__)) + return r -if hasattr(typing, "Unpack"): # 3.11+ - Unpack = typing.Unpack -elif sys.version_info[:2] >= (3, 9): + def __hash__(self): + return hash((type(self).__name__, 
self.__type__)) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + if self.__type__ is not None: + return self.__type__ == other.__type__ + return self is other + + class _Required(_MaybeRequired, _root=True): + """A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """ + + class _NotRequired(_MaybeRequired, _root=True): + """A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """ + + Required = _Required(_root=True) + NotRequired = _NotRequired(_root=True) + + +if sys.version_info[:2] >= (3, 9): class _UnpackSpecialForm(typing._SpecialForm, _root=True): def __repr__(self): return 'typing_extensions.' + self._name @@ -1814,13 +2584,13 @@ def add_batch_axis( ) -> Array[Batch, Unpack[Shape]]: ... 
""" - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check(parameters, f'{self._name} accepts only single type') return _UnpackAlias(self, (item,)) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) -else: +elif sys.version_info[:2] >= (3, 7): class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar @@ -1830,7 +2600,7 @@ def __repr__(self): def __getitem__(self, parameters): item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + f'{self._name} accepts only single type') return _UnpackAlias(self, (item,)) Unpack = _UnpackForm( @@ -1849,105 +2619,149 @@ def add_batch_axis( def _is_unpack(obj): return isinstance(obj, _UnpackAlias) +else: + # NOTE: Modeled after _Final's implementation when _FinalTypingBase available + class _Unpack(typing._FinalTypingBase, _root=True): + """A special typing construct to unpack a variadic type. For example: -if hasattr(typing, "TypeVarTuple"): # 3.11+ + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) - # Add default Parameter - PEP 696 - class TypeVarTuple(typing.TypeVarTuple, _DefaultMixin, _root=True): - """Type variable tuple.""" + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... - def __init__(self, name, *, default=_marker): - super().__init__(name) - _DefaultMixin.__init__(self, default) + """ + __slots__ = ('__type__',) + __class__ = typing.TypeVar - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + def __init__(self, tp=None, **kwds): + self.__type__ = tp -else: - class TypeVarTuple(_DefaultMixin): - """Type variable tuple. 
+ def __getitem__(self, item): + cls = type(self) + if self.__type__ is None: + return cls(typing._type_check(item, + 'Unpack accepts only single type.'), + _root=True) + raise TypeError('Unpack cannot be further subscripted') - Usage:: + def _eval_type(self, globalns, localns): + new_tp = typing._eval_type(self.__type__, globalns, localns) + if new_tp == self.__type__: + return self + return type(self)(new_tp, _root=True) - Ts = TypeVarTuple('Ts') + def __repr__(self): + r = super().__repr__() + if self.__type__ is not None: + r += '[{}]'.format(typing._type_repr(self.__type__)) + return r - In the same way that a normal type variable is a stand-in for a single - type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* - type such as ``Tuple[int, str]``. + def __hash__(self): + return hash((type(self).__name__, self.__type__)) - Type variable tuples can be used in ``Generic`` declarations. - Consider the following example:: + def __eq__(self, other): + if not isinstance(other, _Unpack): + return NotImplemented + if self.__type__ is not None: + return self.__type__ == other.__type__ + return self is other - class Array(Generic[*Ts]): ... + # For 3.6 only + def _get_type_vars(self, tvars): + self.__type__._get_type_vars(tvars) - The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, - where ``T1`` and ``T2`` are type variables. To use these type variables - as type parameters of ``Array``, we must *unpack* the type variable tuple using - the star operator: ``*Ts``. The signature of ``Array`` then behaves - as if we had simply written ``class Array(Generic[T1, T2]): ...``. - In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows - us to parameterise the class with an *arbitrary* number of type parameters. + Unpack = _Unpack(_root=True) - Type variable tuples can be used anywhere a normal ``TypeVar`` can. 
- This includes class definitions, as shown above, as well as function - signatures and variable annotations:: + def _is_unpack(obj): + return isinstance(obj, _Unpack) - class Array(Generic[*Ts]): - def __init__(self, shape: Tuple[*Ts]): - self._shape: Tuple[*Ts] = shape +class TypeVarTuple: + """Type variable tuple. - def get_shape(self) -> Tuple[*Ts]: - return self._shape + Usage:: - shape = (Height(480), Width(640)) - x: Array[Height, Width] = Array(shape) - y = abs(x) # Inferred type is Array[Height, Width] - z = x + x # ... is Array[Height, Width] - x.get_shape() # ... is tuple[Height, Width] + Ts = TypeVarTuple('Ts') - """ + In the same way that a normal type variable is a stand-in for a single + type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* type such as + ``Tuple[int, str]``. - # Trick Generic __parameters__. - __class__ = typing.TypeVar + Type variable tuples can be used in ``Generic`` declarations. + Consider the following example:: - def __iter__(self): - yield self.__unpacked__ + class Array(Generic[*Ts]): ... - def __init__(self, name, *, default=_marker): - self.__name__ = name - _DefaultMixin.__init__(self, default) + The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, + where ``T1`` and ``T2`` are type variables. To use these type variables + as type parameters of ``Array``, we must *unpack* the type variable tuple using + the star operator: ``*Ts``. The signature of ``Array`` then behaves + as if we had simply written ``class Array(Generic[T1, T2]): ...``. + In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows + us to parameterise the class with an *arbitrary* number of type parameters. - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + Type variable tuples can be used anywhere a normal ``TypeVar`` can. 
+ This includes class definitions, as shown above, as well as function + signatures and variable annotations:: - self.__unpacked__ = Unpack[self] + class Array(Generic[*Ts]): - def __repr__(self): - return self.__name__ + def __init__(self, shape: Tuple[*Ts]): + self._shape: Tuple[*Ts] = shape - def __hash__(self): - return object.__hash__(self) + def get_shape(self) -> Tuple[*Ts]: + return self._shape - def __eq__(self, other): - return self is other + shape = (Height(480), Width(640)) + x: Array[Height, Width] = Array(shape) + y = abs(x) # Inferred type is Array[Height, Width] + z = x + x # ... is Array[Height, Width] + x.get_shape() # ... is tuple[Height, Width] - def __reduce__(self): - return self.__name__ + """ + + # Trick Generic __parameters__. + __class__ = typing.TypeVar + + def __iter__(self): + yield self.__unpacked__ + + def __init__(self, name): + self.__name__ = name + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + self.__unpacked__ = Unpack[self] + + def __repr__(self): + return self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(self, *args, **kwds): + if '_root' not in kwds: + raise TypeError("Cannot subclass special typing classes") - def __init_subclass__(self, *args, **kwds): - if '_root' not in kwds: - raise TypeError("Cannot subclass special typing classes") + if not PEP_560: + # Only needed in 3.6. 
+ def _get_type_vars(self, tvars): + if self not in tvars: + tvars.append(self) if hasattr(typing, "reveal_type"): @@ -1999,8 +2813,7 @@ def int_or_str(arg: int | str) -> None: raise AssertionError("Expected code to be unreachable") -if sys.version_info >= (3, 12): - # dataclass_transform exists in 3.11 but lacks the frozen_default parameter +if hasattr(typing, 'dataclass_transform'): dataclass_transform = typing.dataclass_transform else: def dataclass_transform( @@ -2008,12 +2821,10 @@ def dataclass_transform( eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, - frozen_default: bool = False, - field_specifiers: typing.Tuple[ + field_descriptors: typing.Tuple[ typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], ... ] = (), - **kwargs: typing.Any, ) -> typing.Callable[[T], T]: """Decorator that marks a function, class, or metaclass as providing dataclass-like behavior. @@ -2065,10 +2876,8 @@ class CustomerModel(ModelBase): assumed to be True or False if it is omitted by the caller. - ``kw_only_default`` indicates whether the ``kw_only`` parameter is assumed to be True or False if it is omitted by the caller. - - ``frozen_default`` indicates whether the ``frozen`` parameter is - assumed to be True or False if it is omitted by the caller. - - ``field_specifiers`` specifies a static list of supported classes - or functions that describe fields, similar to ``dataclasses.field()``. + - ``field_descriptors`` specifies a static list of supported classes + or functions, that describe fields, similar to ``dataclasses.field()``. At runtime, this decorator records its arguments in the ``__dataclass_transform__`` attribute on the decorated object. 
@@ -2081,98 +2890,12 @@ def decorator(cls_or_fn): "eq_default": eq_default, "order_default": order_default, "kw_only_default": kw_only_default, - "frozen_default": frozen_default, - "field_specifiers": field_specifiers, - "kwargs": kwargs, + "field_descriptors": field_descriptors, } return cls_or_fn return decorator -if hasattr(typing, "override"): - override = typing.override -else: - _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) - - def override(__arg: _F) -> _F: - """Indicate that a method is intended to override a method in a base class. - - Usage: - - class Base: - def method(self) -> None: ... - pass - - class Child(Base): - @override - def method(self) -> None: - super().method() - - When this decorator is applied to a method, the type checker will - validate that it overrides a method with the same name on a base class. - This helps prevent bugs that may occur when a base class is changed - without an equivalent change to a child class. - - There is no runtime checking of these properties. The decorator - sets the ``__override__`` attribute to ``True`` on the decorated object - to allow runtime introspection. - - See PEP 698 for details. - - """ - try: - __arg.__override__ = True - except (AttributeError, TypeError): - # Skip the attribute silently if it is not writable. - # AttributeError happens if the object has __slots__ or a - # read-only property, TypeError if it's a builtin class. - pass - return __arg - - -if hasattr(typing, "deprecated"): - deprecated = typing.deprecated -else: - _T = typing.TypeVar("_T") - - def deprecated(__msg: str) -> typing.Callable[[_T], _T]: - """Indicate that a class, function or overload is deprecated. - - Usage: - - @deprecated("Use B instead") - class A: - pass - - @deprecated("Use g instead") - def f(): - pass - - @overload - @deprecated("int support is deprecated") - def g(x: int) -> int: ... - @overload - def g(x: str) -> int: ... 
- - When this decorator is applied to an object, the type checker - will generate a diagnostic on usage of the deprecated object. - - No runtime warning is issued. The decorator sets the ``__deprecated__`` - attribute on the decorated object to the deprecation message - passed to the decorator. If applied to an overload, the decorator - must be after the ``@overload`` decorator for the attribute to - exist on the overload as returned by ``get_overloads()``. - - See PEP 702 for details. - - """ - def decorator(__arg: _T) -> _T: - __arg.__deprecated__ = __msg - return __arg - - return decorator - - # We have to do some monkey patching to deal with the dual nature of # Unpack/TypeVarTuple: # - We want Unpack to be a kind of TypeVar so it gets accepted in @@ -2183,92 +2906,3 @@ def decorator(__arg: _T) -> _T: if not hasattr(typing, "TypeVarTuple"): typing._collect_type_vars = _collect_type_vars typing._check_generic = _check_generic - - -# Backport typing.NamedTuple as it exists in Python 3.11. -# In 3.11, the ability to define generic `NamedTuple`s was supported. -# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8. 
-if sys.version_info >= (3, 11): - NamedTuple = typing.NamedTuple -else: - def _caller(): - try: - return sys._getframe(2).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): # For platforms without _getframe() - return None - - def _make_nmtuple(name, types, module, defaults=()): - fields = [n for n, t in types] - annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") - for n, t in types} - nm_tpl = collections.namedtuple(name, fields, - defaults=defaults, module=module) - nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations - # The `_field_types` attribute was removed in 3.9; - # in earlier versions, it is the same as the `__annotations__` attribute - if sys.version_info < (3, 9): - nm_tpl._field_types = annotations - return nm_tpl - - _prohibited_namedtuple_fields = typing._prohibited - _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) - - class _NamedTupleMeta(type): - def __new__(cls, typename, bases, ns): - assert _NamedTuple in bases - for base in bases: - if base is not _NamedTuple and base is not typing.Generic: - raise TypeError( - 'can only inherit from a NamedTuple type and Generic') - bases = tuple(tuple if base is _NamedTuple else base for base in bases) - types = ns.get('__annotations__', {}) - default_names = [] - for field_name in types: - if field_name in ns: - default_names.append(field_name) - elif default_names: - raise TypeError(f"Non-default namedtuple field {field_name} " - f"cannot follow default field" - f"{'s' if len(default_names) > 1 else ''} " - f"{', '.join(default_names)}") - nm_tpl = _make_nmtuple( - typename, types.items(), - defaults=[ns[n] for n in default_names], - module=ns['__module__'] - ) - nm_tpl.__bases__ = bases - if typing.Generic in bases: - class_getitem = typing.Generic.__class_getitem__.__func__ - nm_tpl.__class_getitem__ = classmethod(class_getitem) - # update from user namespace without overriding special 
namedtuple attributes - for key in ns: - if key in _prohibited_namedtuple_fields: - raise AttributeError("Cannot overwrite NamedTuple attribute " + key) - elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: - setattr(nm_tpl, key, ns[key]) - if typing.Generic in bases: - nm_tpl.__init_subclass__() - return nm_tpl - - def NamedTuple(__typename, __fields=None, **kwargs): - if __fields is None: - __fields = kwargs.items() - elif kwargs: - raise TypeError("Either list of fields or keywords" - " can be provided to NamedTuple, not both") - return _make_nmtuple(__typename, __fields, module=_caller()) - - NamedTuple.__doc__ = typing.NamedTuple.__doc__ - _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) - - # On 3.8+, alter the signature so that it matches typing.NamedTuple. - # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, - # so just leave the signature as it is on 3.7. - if sys.version_info >= (3, 8): - NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' - - def _namedtuple_mro_entries(bases): - assert NamedTuple in bases - return (_NamedTuple,) - - NamedTuple.__mro_entries__ = _namedtuple_mro_entries From a9e5617c169edee0c53e5cd775283c5485d6c1f1 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 18:15:37 -0500 Subject: [PATCH 14/27] style: blacken --- invoke/collection.py | 4 ++-- invoke/program.py | 2 +- invoke/tasks.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/invoke/collection.py b/invoke/collection.py index 835bc7e22..7e1005125 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -248,7 +248,7 @@ def add_task( task: "Task", name: Optional[str] = None, aliases: Optional[Tuple[str, ...]] = None, - default: Optional[bool] = None + default: Optional[bool] = None, ) -> None: """ Add `.Task` ``task`` to this collection. 
@@ -293,7 +293,7 @@ def add_collection( self, coll: "Collection", name: Optional[str] = None, - default: Optional[bool] = None + default: Optional[bool] = None, ) -> None: """ Add `.Collection` ``coll`` as a sub-collection of this one. diff --git a/invoke/program.py b/invoke/program.py index f9b4a6d5f..d2102a451 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -622,7 +622,7 @@ def called_as(self) -> str: .. versionadded:: 1.2 """ - return os.path.basename(self.argv[0]) if self.argv else 'invoke' + return os.path.basename(self.argv[0]) if self.argv else "invoke" @property def binary(self) -> str: diff --git a/invoke/tasks.py b/invoke/tasks.py index 8c2430899..1a53106e2 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -397,10 +397,11 @@ class Call: """ def __init__( - self, task: "Task", + self, + task: "Task", called_as: Optional[str] = None, args: Optional[Tuple[str, ...]] = None, - kwargs: Optional[Dict[str, Any]] = None + kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Create a new `.Call` object. @@ -476,7 +477,7 @@ def clone_data(self) -> Dict[str, Any]: def clone( self, into: Optional[Type["Call"]] = None, - with_: Optional[Dict[str, Any]] = None + with_: Optional[Dict[str, Any]] = None, ) -> "Call": """ Return a standalone copy of this Call. From 202a1eff9043a54a344b08c58411b30b342c6eab Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sun, 29 Jan 2023 19:03:23 -0500 Subject: [PATCH 15/27] refactor: remove py2 cruft --- invoke/collection.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/invoke/collection.py b/invoke/collection.py index 7e1005125..4d4ff6080 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -144,9 +144,6 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self == other - def __nonzero__(self) -> bool: - return self.__bool__() - def __bool__(self) -> bool: return bool(self.task_names) From 85e9cf625723d02bf1aaf15a5859ff16378650c9 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 19:24:13 -0500 Subject: [PATCH 16/27] Update argument.py --- invoke/parser/argument.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invoke/parser/argument.py b/invoke/parser/argument.py index 43603dce9..b50cf9a39 100644 --- a/invoke/parser/argument.py +++ b/invoke/parser/argument.py @@ -1,6 +1,6 @@ from typing import Any, Iterable, Optional, Tuple -# TODO: dynamic map kind +# TODO: dynamic type for kind # T = TypeVar('T') From cfdd42d496276bc5b4b6f64ce22ae70896bad726 Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sun, 29 Jan 2023 19:31:10 -0500 Subject: [PATCH 17/27] refactor: revert import changes --- invoke/parser/parser.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index dd732f027..f7945debf 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional try: - from invoke.vendor.lexicon import Lexicon - from invoke.vendor.fluidity import StateMachine, state, transition + from ..vendor.lexicon import Lexicon + from ..vendor.fluidity import StateMachine, state, transition except ImportError: from lexicon import Lexicon # type: ignore from fluidity import StateMachine, state, transition # type: ignore -from invoke.exceptions import ParseError -from invoke.util import debug # type: ignore +from ..exceptions import ParseError +from ..util import debug # type: ignore if TYPE_CHECKING: from .context import ParserContext From d6fd2b94dac89c5b0b5170e415d4508a59dafe3f Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 29 Jan 2023 19:36:10 -0500 Subject: [PATCH 18/27] refactor: remove cruft --- invoke/program.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/invoke/program.py b/invoke/program.py index d2102a451..985fe975d 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -507,10 +507,6 @@ def parse_cleanup(self) -> None: # Print discovered tasks if necessary list_root = self.args.list.value # will be True or string - # print('list_root', type(list_root), self.args.list.value) - # print('args', self.args) - # print('args.list', self.args.list) - # print('args.list.value', self.args.list.value) self.list_format = self.args["list-format"].value self.list_depth = self.args["list-depth"].value if list_root: From 7c20e5a7743d7e940bc9697221f1ce7b53dc441f Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sat, 4 Feb 2023 11:38:38 -0500 Subject: [PATCH 19/27] test: resolve 906 threads --- dev-requirements.txt | 7 ++++--- invoke/collection.py | 3 --- invoke/config.py | 4 +--- invoke/parser/argument.py | 5 +++-- invoke/parser/context.py | 2 +- invoke/parser/parser.py | 4 ++-- invoke/program.py | 1 + invoke/runners.py | 8 ++++---- invoke/util.py | 16 +++++++--------- invoke/watchers.py | 1 + pyproject.toml | 2 +- sites/www/changelog.rst | 3 +++ 12 files changed, 28 insertions(+), 28 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 1b2edcefe..6bf5d9157 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,6 +20,7 @@ setuptools>56 # Debuggery icecream>=2.1 # typing -mypy>=0.942,<1 -typed-ast>=1.4.3,<2 -types-PyYAML>=5.4.3,<6 +mypy==0.971 +typed-ast==1.5.4 +types-PyYAML==6 +pytest-mypy==0.10.3 diff --git a/invoke/collection.py b/invoke/collection.py index 4d4ff6080..26523769d 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -141,9 +141,6 @@ def __eq__(self, other: object) -> bool: ) return False - def __ne__(self, other: object) -> bool: - return not self == other - def __bool__(self) -> bool: return bool(self.task_names) diff --git a/invoke/config.py b/invoke/config.py index 2cd3fda69..00400164c 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -819,9 +819,7 @@ def load_collection( if merge: self.merge() - def set_project_location( - self, path: Optional[Union[PathLike, str]] - ) -> None: + def set_project_location(self, path: Union[PathLike, str, None]) -> None: """ Set the directory path where a project-level config file may be found. diff --git a/invoke/parser/argument.py b/invoke/parser/argument.py index b50cf9a39..761eb6021 100644 --- a/invoke/parser/argument.py +++ b/invoke/parser/argument.py @@ -52,8 +52,9 @@ def __init__( attr_name: Optional[str] = None, ) -> None: if name and names: - msg = "Cannot give both 'name' and 'names' arguments! Pick one." 
- raise TypeError(msg) + raise TypeError( + "Cannot give both 'name' and 'names' arguments! Pick one." + ) if not (name or names): raise TypeError("An Argument must have at least one name.") if names: diff --git a/invoke/parser/context.py b/invoke/parser/context.py index 583b05b3f..e8a465faf 100644 --- a/invoke/parser/context.py +++ b/invoke/parser/context.py @@ -91,7 +91,7 @@ def __init__( self.args = Lexicon() self.positional_args: List[Argument] = [] self.flags = Lexicon() - self.inverse_flags: Dict[str, str] = {} # No need for Lexicone + self.inverse_flags: Dict[str, str] = {} # No need for Lexicon here self.name = name self.aliases = aliases for arg in args: diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index f7945debf..b005e35e4 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -16,11 +16,11 @@ def is_flag(value: str) -> bool: - return bool(value.startswith("-")) + return value.startswith("-") def is_long_flag(value: str) -> bool: - return bool(value.startswith("--")) + return value.startswith("--") class ParseResult(list): diff --git a/invoke/program.py b/invoke/program.py index 985fe975d..e02301886 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -618,6 +618,7 @@ def called_as(self) -> str: .. versionadded:: 1.2 """ + # FIXME: need to return a string here but argv is optional return os.path.basename(self.argv[0]) if self.argv else "invoke" @property diff --git a/invoke/runners.py b/invoke/runners.py index 48995928d..be1baf779 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -1076,7 +1076,7 @@ def start_timer(self, timeout: int) -> None: self._timer = threading.Timer(timeout, self.kill) self._timer.start() - def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: + def read_proc_stdout(self, num_bytes: int) -> Union[bytes, str, None]: """ Read ``num_bytes`` from the running process' stdout stream. 
@@ -1088,7 +1088,7 @@ def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: """ raise NotImplementedError - def read_proc_stderr(self, num_bytes: int) -> Optional[Union[bytes, str]]: + def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str, None]: """ Read ``num_bytes`` from the running process' stderr stream. @@ -1239,7 +1239,7 @@ def should_use_pty(self, pty: bool = False, fallback: bool = True) -> bool: use_pty = False return use_pty - def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: + def read_proc_stdout(self, num_bytes: int) -> Union[bytes, str, None]: # Obtain useful read-some-bytes function if self.using_pty: # Need to handle spurious OSErrors on some Linux platforms. @@ -1266,7 +1266,7 @@ def read_proc_stdout(self, num_bytes: int) -> Optional[Union[bytes, str]]: data = None return data - def read_proc_stderr(self, num_bytes: int) -> Optional[Union[bytes, str]]: + def read_proc_stderr(self, num_bytes: int) -> Union[bytes, str, None]: # NOTE: when using a pty, this will never be called. # TODO: do we ever get those OSErrors on stderr? Feels like we could? if self.process and self.process.stderr: diff --git a/invoke/util.py b/invoke/util.py index a8f1ee974..5ec92de64 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -163,14 +163,6 @@ class ExceptionHandlingThread(threading.Thread): .. versionadded:: 1.0 """ - # TODO: legacy cruft that needs to be removed - exc_info: Optional[ - Union[ - Tuple[Type[BaseException], BaseException, TracebackType], - Tuple[None, None, None], - ] - ] - def __init__(self, **kwargs: Any) -> None: """ Create a new exception-handling thread instance. 
@@ -185,7 +177,13 @@ def __init__(self, **kwargs: Any) -> None: self.daemon = True # Track exceptions raised in run() self.kwargs = kwargs - self.exc_info = None + # TODO: legacy cruft that needs to be removed + self.exc_info: Optional[ + Union[ + Tuple[Type[BaseException], BaseException, TracebackType], + Tuple[None, None, None], + ] + ] = None def run(self) -> None: try: diff --git a/invoke/watchers.py b/invoke/watchers.py index 89538b2ec..dd1718986 100644 --- a/invoke/watchers.py +++ b/invoke/watchers.py @@ -4,6 +4,7 @@ from .exceptions import ResponseNotAccepted +# TODO: update imports so that Litaral is used as type try: from .vendor.typing_extensions import Literal except ImportError: diff --git a/pyproject.toml b/pyproject.toml index 8ec47c6ff..e2f64f8a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,11 +15,11 @@ disallow_untyped_defs = true # "ignore-without-code", # "unused-awaitable", # -# implicit_reexport = False exclude = [ "integration/", "tests/", "setup.py", "tasks.py", "sites/www/conf.py" ] ignore_missing_imports = true +# implicit_reexport = False # no_implicit_optional = true # pretty = true # show_column_numbers = true diff --git a/sites/www/changelog.rst b/sites/www/changelog.rst index 31dc42052..7f4813e3e 100644 --- a/sites/www/changelog.rst +++ b/sites/www/changelog.rst @@ -2,6 +2,9 @@ Changelog ========= +- :support:`906` Implement type hints and type checking tests with mypy to + reduce errors and improve code documentation. Patches by Jesse P. Johnson and + review by Sam Bull. - :support:`901 backported` (via :issue:`903`) Tweak test suite ``setup`` methods to be named ``setup_method`` so pytest stops whining about it. Patch via Jesse P. Johnson. From dc76dabf1e1b7285f136300627f40c16bca37937 Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sat, 4 Feb 2023 14:19:27 -0500 Subject: [PATCH 20/27] test: resolve 906 threads --- dev-requirements.txt | 1 - invoke/program.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6bf5d9157..8901bc3bf 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -23,4 +23,3 @@ icecream>=2.1 mypy==0.971 typed-ast==1.5.4 types-PyYAML==6 -pytest-mypy==0.10.3 diff --git a/invoke/program.py b/invoke/program.py index e02301886..9b2980c2a 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -618,8 +618,8 @@ def called_as(self) -> str: .. versionadded:: 1.2 """ - # FIXME: need to return a string here but argv is optional - return os.path.basename(self.argv[0]) if self.argv else "invoke" + # XXX: defaults to empty string if 'argv' is '[]' or 'None' + return os.path.basename(self.argv[0]) if self.argv else "" @property def binary(self) -> str: From f24917cd887a840b13384d04d84b678a4ce7af5e Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sat, 4 Feb 2023 15:52:54 -0500 Subject: [PATCH 21/27] test: resolve 906 threads --- invoke/context.py | 2 +- invoke/watchers.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/invoke/context.py b/invoke/context.py index 69bdcbae6..a1bafeda8 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -216,7 +216,7 @@ def _sudo( ) # FIXME pattern should be raw string prompt.encode('unicode_escape') watcher = FailingResponder( - pattern=re.escape(prompt), # type: ignore + pattern=re.escape(prompt), response="{}\n".format(password), sentinel="Sorry, try again.\n", ) diff --git a/invoke/watchers.py b/invoke/watchers.py index dd1718986..2ce98fe03 100644 --- a/invoke/watchers.py +++ b/invoke/watchers.py @@ -4,12 +4,6 @@ from .exceptions import ResponseNotAccepted -# TODO: update imports so that Litaral is used as type -try: - from .vendor.typing_extensions import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - class StreamWatcher(threading.local): """ @@ -65,7 +59,7 @@ class Responder(StreamWatcher): .. versionadded:: 1.0 """ - def __init__(self, pattern: Literal["pattern"], response: str) -> None: + def __init__(self, pattern: str, response: str) -> None: r""" Imprint this `Responder` with necessary parameters. @@ -127,9 +121,7 @@ class FailingResponder(Responder): .. versionadded:: 1.0 """ - def __init__( - self, pattern: Literal["pattern"], response: str, sentinel: str - ) -> None: + def __init__(self, pattern: str, response: str, sentinel: str) -> None: super().__init__(pattern, response) self.sentinel = sentinel self.failure_index = 0 From 38d0a91aa37e38ea6693b8cdc7b074fb0adba5a4 Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Sun, 5 Feb 2023 10:38:52 -0500 Subject: [PATCH 22/27] test: resolve 906 threads --- invoke/collection.py | 6 +- invoke/runners.py | 2 +- invoke/tasks.py | 4 +- invoke/vendor/typing_extensions.py | 2908 ---------------------------- 4 files changed, 5 insertions(+), 2915 deletions(-) delete mode 100644 invoke/vendor/typing_extensions.py diff --git a/invoke/collection.py b/invoke/collection.py index 26523769d..9510b4f1d 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -113,10 +113,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for name, obj in kwargs.items(): self._add_object(obj, name) - def _add_object( - self, obj: Any, name: Optional[str] = None - ) -> Callable[..., Any]: - method: Callable[..., Any] + def _add_object(self, obj: Any, name: Optional[str] = None) -> Callable: + method: Callable if isinstance(obj, Task): method = self.add_task elif isinstance(obj, (Collection, ModuleType)): diff --git a/invoke/runners.py b/invoke/runners.py index be1baf779..21f8ca691 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -636,7 +636,7 @@ def create_io_threads( stdout: List[str] = [] stderr: List[str] = [] # Set up IO thread parameters (format - body_func: {kwargs}) - thread_args: Dict[Callable[..., Any], Any] = { + thread_args: Dict[Callable, Any] = { self.handle_stdout: { "buffer_": stdout, "hide": "stdout" in self.opts["hide"], diff --git a/invoke/tasks.py b/invoke/tasks.py index 1a53106e2..806ac8582 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -137,7 +137,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: def called(self) -> bool: return self.times_called > 0 - def argspec(self, body: Callable[..., Any]) -> "Signature": + def argspec(self, body: Callable) -> "Signature": """ Returns a modified `inspect.Signature` based on that of ``body``. 
@@ -280,7 +280,7 @@ def get_arguments( return args -def task(*args: Any, **kwargs: Any) -> Callable[..., Any]: +def task(*args: Any, **kwargs: Any) -> Callable: """ Marks wrapped callable object as a valid Invoke task. diff --git a/invoke/vendor/typing_extensions.py b/invoke/vendor/typing_extensions.py deleted file mode 100644 index 194731cd3..000000000 --- a/invoke/vendor/typing_extensions.py +++ /dev/null @@ -1,2908 +0,0 @@ -import abc -import collections -import collections.abc -import operator -import sys -import types as _types -import typing - -# After PEP 560, internal typing API was substantially reworked. -# This is especially important for Protocol class which uses internal APIs -# quite extensively. -PEP_560 = sys.version_info[:3] >= (3, 7, 0) - -if PEP_560: - GenericMeta = type -else: - # 3.6 - from typing import GenericMeta, _type_vars # noqa - - -# Please keep __all__ alphabetized within each category. -__all__ = [ - # Super-special typing primitives. - 'ClassVar', - 'Concatenate', - 'Final', - 'LiteralString', - 'ParamSpec', - 'Self', - 'Type', - 'TypeVarTuple', - 'Unpack', - - # ABCs (from collections.abc). - 'Awaitable', - 'AsyncIterator', - 'AsyncIterable', - 'Coroutine', - 'AsyncGenerator', - 'AsyncContextManager', - 'ChainMap', - - # Concrete collection types. - 'ContextManager', - 'Counter', - 'Deque', - 'DefaultDict', - 'OrderedDict', - 'TypedDict', - - # Structural checks, a.k.a. protocols. - 'SupportsIndex', - - # One-off things. - 'Annotated', - 'assert_never', - 'dataclass_transform', - 'final', - 'IntVar', - 'is_typeddict', - 'Literal', - 'NewType', - 'overload', - 'Protocol', - 'reveal_type', - 'runtime', - 'runtime_checkable', - 'Text', - 'TypeAlias', - 'TypeGuard', - 'TYPE_CHECKING', - 'Never', - 'NoReturn', - 'Required', - 'NotRequired', -] - -if PEP_560: - __all__.extend(["get_args", "get_origin", "get_type_hints"]) - -# The functions below are modified copies of typing internal helpers. 
-# They are needed by _ProtocolMeta and they provide support for PEP 646. - - -def _no_slots_copy(dct): - dict_copy = dict(dct) - if '__slots__' in dict_copy: - for slot in dict_copy['__slots__']: - dict_copy.pop(slot, None) - return dict_copy - - -_marker = object() - - -def _check_generic(cls, parameters, elen=_marker): - """Check correct count for parameters of a generic cls (internal helper). - This gives a nice error message in case of count mismatch. - """ - if not elen: - raise TypeError(f"{cls} is not a generic class") - if elen is _marker: - if not hasattr(cls, "__parameters__") or not cls.__parameters__: - raise TypeError(f"{cls} is not a generic class") - elen = len(cls.__parameters__) - alen = len(parameters) - if alen != elen: - if hasattr(cls, "__parameters__"): - parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] - num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) - if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): - return - raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" - f" actual {alen}, expected {elen}") - - -if sys.version_info >= (3, 10): - def _should_collect_from_parameters(t): - return isinstance( - t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) - ) -elif sys.version_info >= (3, 9): - def _should_collect_from_parameters(t): - return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) -else: - def _should_collect_from_parameters(t): - return isinstance(t, typing._GenericAlias) and not t._special - - -def _collect_type_vars(types, typevar_types=None): - """Collect all type variable contained in types in order of - first appearance (lexicographic order). 
For example:: - - _collect_type_vars((T, List[S, T])) == (T, S) - """ - if typevar_types is None: - typevar_types = typing.TypeVar - tvars = [] - for t in types: - if ( - isinstance(t, typevar_types) and - t not in tvars and - not _is_unpack(t) - ): - tvars.append(t) - if _should_collect_from_parameters(t): - tvars.extend([t for t in t.__parameters__ if t not in tvars]) - return tuple(tvars) - - -# 3.6.2+ -if hasattr(typing, 'NoReturn'): - NoReturn = typing.NoReturn -# 3.6.0-3.6.1 -else: - class _NoReturn(typing._FinalTypingBase, _root=True): - """Special type indicating functions that never return. - Example:: - - from typing import NoReturn - - def stop() -> NoReturn: - raise Exception('no way') - - This type is invalid in other positions, e.g., ``List[NoReturn]`` - will fail in static type checkers. - """ - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError("NoReturn cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError("NoReturn cannot be used with issubclass().") - - NoReturn = _NoReturn(_root=True) - -# Some unconstrained type variables. These are used by the container types. -# (These are not for export.) -T = typing.TypeVar('T') # Any type. -KT = typing.TypeVar('KT') # Key type. -VT = typing.TypeVar('VT') # Value type. -T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. -T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. - -ClassVar = typing.ClassVar - -# On older versions of typing there is an internal class named "Final". -# 3.8+ -if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): - Final = typing.Final -# 3.7 -elif sys.version_info[:2] >= (3, 7): - class _FinalForm(typing._SpecialForm, _root=True): - - def __repr__(self): - return 'typing_extensions.' 
+ self._name - - def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only single type') - return typing._GenericAlias(self, (item,)) - - Final = _FinalForm('Final', - doc="""A special typing construct to indicate that a name - cannot be re-assigned or overridden in a subclass. - For example: - - MAX_SIZE: Final = 9000 - MAX_SIZE += 1 # Error reported by type checker - - class Connection: - TIMEOUT: Final[int] = 10 - class FastConnector(Connection): - TIMEOUT = 1 # Error reported by type checker - - There is no runtime checking of these properties.""") -# 3.6 -else: - class _Final(typing._FinalTypingBase, _root=True): - """A special typing construct to indicate that a name - cannot be re-assigned or overridden in a subclass. - For example: - - MAX_SIZE: Final = 9000 - MAX_SIZE += 1 # Error reported by type checker - - class Connection: - TIMEOUT: Final[int] = 10 - class FastConnector(Connection): - TIMEOUT = 1 # Error reported by type checker - - There is no runtime checking of these properties. 
- """ - - __slots__ = ('__type__',) - - def __init__(self, tp=None, **kwds): - self.__type__ = tp - - def __getitem__(self, item): - cls = type(self) - if self.__type__ is None: - return cls(typing._type_check(item, - f'{cls.__name__[1:]} accepts only single type.'), - _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') - - def _eval_type(self, globalns, localns): - new_tp = typing._eval_type(self.__type__, globalns, localns) - if new_tp == self.__type__: - return self - return type(self)(new_tp, _root=True) - - def __repr__(self): - r = super().__repr__() - if self.__type__ is not None: - r += f'[{typing._type_repr(self.__type__)}]' - return r - - def __hash__(self): - return hash((type(self).__name__, self.__type__)) - - def __eq__(self, other): - if not isinstance(other, _Final): - return NotImplemented - if self.__type__ is not None: - return self.__type__ == other.__type__ - return self is other - - Final = _Final(_root=True) - - -if sys.version_info >= (3, 11): - final = typing.final -else: - # @final exists in 3.8+, but we backport it for all versions - # before 3.11 to keep support for the __final__ attribute. - # See https://bugs.python.org/issue46342 - def final(f): - """This decorator can be used to indicate to type checkers that - the decorated method cannot be overridden, and decorated class - cannot be subclassed. For example: - - class Base: - @final - def done(self) -> None: - ... - class Sub(Base): - def done(self) -> None: # Error reported by type checker - ... - @final - class Leaf: - ... - class Other(Leaf): # Error reported by type checker - ... - - There is no runtime checking of these properties. The decorator - sets the ``__final__`` attribute to ``True`` on the decorated object - to allow runtime introspection. - """ - try: - f.__final__ = True - except (AttributeError, TypeError): - # Skip the attribute silently if it is not writable. 
- # AttributeError happens if the object has __slots__ or a - # read-only property, TypeError if it's a builtin class. - pass - return f - - -def IntVar(name): - return typing.TypeVar(name) - - -# 3.8+: -if hasattr(typing, 'Literal'): - Literal = typing.Literal -# 3.7: -elif sys.version_info[:2] >= (3, 7): - class _LiteralForm(typing._SpecialForm, _root=True): - - def __repr__(self): - return 'typing_extensions.' + self._name - - def __getitem__(self, parameters): - return typing._GenericAlias(self, parameters) - - Literal = _LiteralForm('Literal', - doc="""A type that can be used to indicate to type checkers - that the corresponding value has a value literally equivalent - to the provided parameter. For example: - - var: Literal[4] = 4 - - The type checker understands that 'var' is literally equal to - the value 4 and no other value. - - Literal[...] cannot be subclassed. There is no runtime - checking verifying that the parameter is actually a value - instead of a type.""") -# 3.6: -else: - class _Literal(typing._FinalTypingBase, _root=True): - """A type that can be used to indicate to type checkers that the - corresponding value has a value literally equivalent to the - provided parameter. For example: - - var: Literal[4] = 4 - - The type checker understands that 'var' is literally equal to the - value 4 and no other value. - - Literal[...] cannot be subclassed. There is no runtime checking - verifying that the parameter is actually a value instead of a type. 
- """ - - __slots__ = ('__values__',) - - def __init__(self, values=None, **kwds): - self.__values__ = values - - def __getitem__(self, values): - cls = type(self) - if self.__values__ is None: - if not isinstance(values, tuple): - values = (values,) - return cls(values, _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') - - def _eval_type(self, globalns, localns): - return self - - def __repr__(self): - r = super().__repr__() - if self.__values__ is not None: - r += f'[{", ".join(map(typing._type_repr, self.__values__))}]' - return r - - def __hash__(self): - return hash((type(self).__name__, self.__values__)) - - def __eq__(self, other): - if not isinstance(other, _Literal): - return NotImplemented - if self.__values__ is not None: - return self.__values__ == other.__values__ - return self is other - - Literal = _Literal(_root=True) - - -_overload_dummy = typing._overload_dummy # noqa -overload = typing.overload - - -# This is not a real generic class. Don't use outside annotations. -Type = typing.Type - -# Various ABCs mimicking those in collections.abc. -# A few are simply re-exported for completeness. - - -class _ExtensionsGenericMeta(GenericMeta): - def __subclasscheck__(self, subclass): - """This mimics a more modern GenericMeta.__subclasscheck__() logic - (that does not have problems with recursion) to work around interactions - between collections, typing, and typing_extensions on older - versions of Python, see https://github.com/python/typing/issues/501. 
- """ - if self.__origin__ is not None: - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError("Parameterized generics cannot be used with class " - "or instance checks") - return False - if not self.__extra__: - return super().__subclasscheck__(subclass) - res = self.__extra__.__subclasshook__(subclass) - if res is not NotImplemented: - return res - if self.__extra__ in subclass.__mro__: - return True - for scls in self.__extra__.__subclasses__(): - if isinstance(scls, GenericMeta): - continue - if issubclass(subclass, scls): - return True - return False - - -Awaitable = typing.Awaitable -Coroutine = typing.Coroutine -AsyncIterable = typing.AsyncIterable -AsyncIterator = typing.AsyncIterator - -# 3.6.1+ -if hasattr(typing, 'Deque'): - Deque = typing.Deque -# 3.6.0 -else: - class Deque(collections.deque, typing.MutableSequence[T], - metaclass=_ExtensionsGenericMeta, - extra=collections.deque): - __slots__ = () - - def __new__(cls, *args, **kwds): - if cls._gorg is Deque: - return collections.deque(*args, **kwds) - return typing._generic_new(collections.deque, cls, *args, **kwds) - -ContextManager = typing.ContextManager -# 3.6.2+ -if hasattr(typing, 'AsyncContextManager'): - AsyncContextManager = typing.AsyncContextManager -# 3.6.0-3.6.1 -else: - from _collections_abc import _check_methods as _check_methods_in_mro # noqa - - class AsyncContextManager(typing.Generic[T_co]): - __slots__ = () - - async def __aenter__(self): - return self - - @abc.abstractmethod - async def __aexit__(self, exc_type, exc_value, traceback): - return None - - @classmethod - def __subclasshook__(cls, C): - if cls is AsyncContextManager: - return _check_methods_in_mro(C, "__aenter__", "__aexit__") - return NotImplemented - -DefaultDict = typing.DefaultDict - -# 3.7.2+ -if hasattr(typing, 'OrderedDict'): - OrderedDict = typing.OrderedDict -# 3.7.0-3.7.2 -elif (3, 7, 0) <= sys.version_info[:3] < (3, 7, 2): - OrderedDict = 
typing._alias(collections.OrderedDict, (KT, VT)) -# 3.6 -else: - class OrderedDict(collections.OrderedDict, typing.MutableMapping[KT, VT], - metaclass=_ExtensionsGenericMeta, - extra=collections.OrderedDict): - - __slots__ = () - - def __new__(cls, *args, **kwds): - if cls._gorg is OrderedDict: - return collections.OrderedDict(*args, **kwds) - return typing._generic_new(collections.OrderedDict, cls, *args, **kwds) - -# 3.6.2+ -if hasattr(typing, 'Counter'): - Counter = typing.Counter -# 3.6.0-3.6.1 -else: - class Counter(collections.Counter, - typing.Dict[T, int], - metaclass=_ExtensionsGenericMeta, extra=collections.Counter): - - __slots__ = () - - def __new__(cls, *args, **kwds): - if cls._gorg is Counter: - return collections.Counter(*args, **kwds) - return typing._generic_new(collections.Counter, cls, *args, **kwds) - -# 3.6.1+ -if hasattr(typing, 'ChainMap'): - ChainMap = typing.ChainMap -elif hasattr(collections, 'ChainMap'): - class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT], - metaclass=_ExtensionsGenericMeta, - extra=collections.ChainMap): - - __slots__ = () - - def __new__(cls, *args, **kwds): - if cls._gorg is ChainMap: - return collections.ChainMap(*args, **kwds) - return typing._generic_new(collections.ChainMap, cls, *args, **kwds) - -# 3.6.1+ -if hasattr(typing, 'AsyncGenerator'): - AsyncGenerator = typing.AsyncGenerator -# 3.6.0 -else: - class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra], - metaclass=_ExtensionsGenericMeta, - extra=collections.abc.AsyncGenerator): - __slots__ = () - -NewType = typing.NewType -Text = typing.Text -TYPE_CHECKING = typing.TYPE_CHECKING - - -def _gorg(cls): - """This function exists for compatibility with old typing versions.""" - assert isinstance(cls, GenericMeta) - if hasattr(cls, '_gorg'): - return cls._gorg - while cls.__origin__ is not None: - cls = cls.__origin__ - return cls - - -_PROTO_WHITELIST = ['Callable', 'Awaitable', - 'Iterable', 'Iterator', 'AsyncIterable', 
'AsyncIterator', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', - 'ContextManager', 'AsyncContextManager'] - - -def _get_protocol_attrs(cls): - attrs = set() - for base in cls.__mro__[:-1]: # without object - if base.__name__ in ('Protocol', 'Generic'): - continue - annotations = getattr(base, '__annotations__', {}) - for attr in list(base.__dict__.keys()) + list(annotations.keys()): - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__doc__', '__subclasshook__', '__init__', '__new__', - '__module__', '_MutableMapping__marker', '_gorg')): - attrs.add(attr) - return attrs - - -def _is_callable_members_only(cls): - return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) - - -# 3.8+ -if hasattr(typing, 'Protocol'): - Protocol = typing.Protocol -# 3.7 -elif PEP_560: - - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') - - class _ProtocolMeta(abc.ABCMeta): - # This metaclass is a bit unfortunate and exists only because of the lack - # of __instancehook__. - def __instancecheck__(cls, instance): - # We need this method for situations where attributes are - # assigned in __init__. - if ((not getattr(cls, '_is_protocol', False) or - _is_callable_members_only(cls)) and - issubclass(instance.__class__, cls)): - return True - if cls._is_protocol: - if all(hasattr(instance, attr) and - (not callable(getattr(cls, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(cls)): - return True - return super().__instancecheck__(instance) - - class Protocol(metaclass=_ProtocolMeta): - # There is quite a lot of overlapping code with typing.Generic. 
- # Unfortunately it is hard to avoid this while these live in two different - # modules. The duplicated code will be removed when Protocol is moved to typing. - """Base class for protocol classes. Protocol classes are defined as:: - - class Proto(Protocol): - def meth(self) -> int: - ... - - Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: - - class C: - def meth(self) -> int: - return 0 - - def func(x: Proto) -> int: - return x.meth() - - func(C()) # Passes static type check - - See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks - only the presence of given attributes, ignoring their type signatures. - - Protocol classes can be generic, they are defined as:: - - class GenProto(Protocol[T]): - def meth(self) -> T: - ... - """ - __slots__ = () - _is_protocol = True - - def __new__(cls, *args, **kwds): - if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") - return super().__new__(cls) - - @typing._tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple): - params = (params,) - if not params and cls is not typing.Tuple: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") - msg = "Parameters to generic types must be types." - params = tuple(typing._type_check(p, msg) for p in params) # noqa - if cls is Protocol: - # Generic can only be subscripted with unique type variables. - if not all(isinstance(p, typing.TypeVar) for p in params): - i = 0 - while isinstance(params[i], typing.TypeVar): - i += 1 - raise TypeError( - "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") - if len(set(params)) != len(params): - raise TypeError( - "Parameters to Protocol[...] must all be unique") - else: - # Subscripting a regular Generic subclass. 
- _check_generic(cls, params, len(cls.__parameters__)) - return typing._GenericAlias(cls, params) - - def __init_subclass__(cls, *args, **kwargs): - tvars = [] - if '__orig_bases__' in cls.__dict__: - error = typing.Generic in cls.__orig_bases__ - else: - error = typing.Generic in cls.__bases__ - if error: - raise TypeError("Cannot inherit from plain Generic") - if '__orig_bases__' in cls.__dict__: - tvars = typing._collect_type_vars(cls.__orig_bases__) - # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. - # If found, tvars must be a subset of it. - # If not found, tvars is it. - # Also check for and reject plain Generic, - # and reject multiple Generic[...] and/or Protocol[...]. - gvars = None - for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): - # for error messages - the_base = base.__origin__.__name__ - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...]" - " and/or Protocol[...] multiple types.") - gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") - tvars = gvars - cls.__parameters__ = tuple(tvars) - - # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', None): - cls._is_protocol = any(b is Protocol for b in cls.__bases__) - - # Set (or override) the protocol subclass hook. 
- def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): - return NotImplemented - if not getattr(cls, '_is_runtime_protocol', False): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: - return NotImplemented - raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") - if not _is_callable_members_only(cls): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: - return NotImplemented - raise TypeError("Protocols with non-method members" - " don't support issubclass()") - if not isinstance(other, type): - # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') - for attr in _get_protocol_attrs(cls): - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): - break - else: - return NotImplemented - return True - if '__subclasshook__' not in cls.__dict__: - cls.__subclasshook__ = _proto_hook - - # We have nothing more to do for non-protocols. - if not cls._is_protocol: - return - - # Check consistency of bases. - for base in cls.__bases__: - if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, _ProtocolMeta) and base._is_protocol): - raise TypeError('Protocols can only inherit from other' - f' protocols, got {repr(base)}') - cls.__init__ = _no_init -# 3.6 -else: - from typing import _next_in_mro, _type_check # noqa - - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') - - class _ProtocolMeta(GenericMeta): - """Internal metaclass for Protocol. - - This exists so Protocol classes can be generic without deriving - from Generic. 
- """ - def __new__(cls, name, bases, namespace, - tvars=None, args=None, origin=None, extra=None, orig_bases=None): - # This is just a version copied from GenericMeta.__new__ that - # includes "Protocol" special treatment. (Comments removed for brevity.) - assert extra is None # Protocols should not have extra - if tvars is not None: - assert origin is not None - assert all(isinstance(t, typing.TypeVar) for t in tvars), tvars - else: - tvars = _type_vars(bases) - gvars = None - for base in bases: - if base is typing.Generic: - raise TypeError("Cannot inherit from plain Generic") - if (isinstance(base, GenericMeta) and - base.__origin__ in (typing.Generic, Protocol)): - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...] or" - " Protocol[...] multiple times.") - gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ", ".join(str(t) for t in tvars if t not in gvarset) - s_args = ", ".join(str(g) for g in gvars) - cls_name = "Generic" if any(b.__origin__ is typing.Generic - for b in bases) else "Protocol" - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {cls_name}[{s_args}]") - tvars = gvars - - initial_bases = bases - if (extra is not None and type(extra) is abc.ABCMeta and - extra not in bases): - bases = (extra,) + bases - bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b - for b in bases) - if any(isinstance(b, GenericMeta) and b is not typing.Generic for b in bases): - bases = tuple(b for b in bases if b is not typing.Generic) - namespace.update({'__origin__': origin, '__extra__': extra}) - self = super(GenericMeta, cls).__new__(cls, name, bases, namespace, - _root=True) - super(GenericMeta, self).__setattr__('_gorg', - self if not origin else - _gorg(origin)) - self.__parameters__ = tvars - self.__args__ = tuple(... 
if a is typing._TypingEllipsis else - () if a is typing._TypingEmpty else - a for a in args) if args else None - self.__next_in_mro__ = _next_in_mro(self) - if orig_bases is None: - self.__orig_bases__ = initial_bases - elif origin is not None: - self._abc_registry = origin._abc_registry - self._abc_cache = origin._abc_cache - if hasattr(self, '_subs_tree'): - self.__tree_hash__ = (hash(self._subs_tree()) if origin else - super(GenericMeta, self).__hash__()) - return self - - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - if not cls.__dict__.get('_is_protocol', None): - cls._is_protocol = any(b is Protocol or - isinstance(b, _ProtocolMeta) and - b.__origin__ is Protocol - for b in cls.__bases__) - if cls._is_protocol: - for base in cls.__mro__[1:]: - if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, typing.TypingMeta) and base._is_protocol or - isinstance(base, GenericMeta) and - base.__origin__ is typing.Generic): - raise TypeError(f'Protocols can only inherit from other' - f' protocols, got {repr(base)}') - - cls.__init__ = _no_init - - def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): - return NotImplemented - if not isinstance(other, type): - # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') - for attr in _get_protocol_attrs(cls): - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): - break - else: - return NotImplemented - return True - if '__subclasshook__' not in cls.__dict__: - cls.__subclasshook__ = _proto_hook - - def __instancecheck__(self, instance): - # We need this method for situations where attributes are - 
# assigned in __init__. - if ((not getattr(self, '_is_protocol', False) or - _is_callable_members_only(self)) and - issubclass(instance.__class__, self)): - return True - if self._is_protocol: - if all(hasattr(instance, attr) and - (not callable(getattr(self, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(self)): - return True - return super(GenericMeta, self).__instancecheck__(instance) - - def __subclasscheck__(self, cls): - if self.__origin__ is not None: - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError("Parameterized generics cannot be used with class " - "or instance checks") - return False - if (self.__dict__.get('_is_protocol', None) and - not self.__dict__.get('_is_runtime_protocol', None)): - if sys._getframe(1).f_globals['__name__'] in ['abc', - 'functools', - 'typing']: - return False - raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") - if (self.__dict__.get('_is_runtime_protocol', None) and - not _is_callable_members_only(self)): - if sys._getframe(1).f_globals['__name__'] in ['abc', - 'functools', - 'typing']: - return super(GenericMeta, self).__subclasscheck__(cls) - raise TypeError("Protocols with non-method members" - " don't support issubclass()") - return super(GenericMeta, self).__subclasscheck__(cls) - - @typing._tp_cache - def __getitem__(self, params): - # We also need to copy this from GenericMeta.__getitem__ to get - # special treatment of "Protocol". (Comments removed for brevity.) - if not isinstance(params, tuple): - params = (params,) - if not params and _gorg(self) is not typing.Tuple: - raise TypeError( - f"Parameter list to {self.__qualname__}[...] cannot be empty") - msg = "Parameters to generic types must be types." 
- params = tuple(_type_check(p, msg) for p in params) - if self in (typing.Generic, Protocol): - if not all(isinstance(p, typing.TypeVar) for p in params): - raise TypeError( - f"Parameters to {repr(self)}[...] must all be type variables") - if len(set(params)) != len(params): - raise TypeError( - f"Parameters to {repr(self)}[...] must all be unique") - tvars = params - args = params - elif self in (typing.Tuple, typing.Callable): - tvars = _type_vars(params) - args = params - elif self.__origin__ in (typing.Generic, Protocol): - raise TypeError(f"Cannot subscript already-subscripted {repr(self)}") - else: - _check_generic(self, params, len(self.__parameters__)) - tvars = _type_vars(params) - args = params - - prepend = (self,) if self.__origin__ is None else () - return self.__class__(self.__name__, - prepend + self.__bases__, - _no_slots_copy(self.__dict__), - tvars=tvars, - args=args, - origin=self, - extra=self.__extra__, - orig_bases=self.__orig_bases__) - - class Protocol(metaclass=_ProtocolMeta): - """Base class for protocol classes. Protocol classes are defined as:: - - class Proto(Protocol): - def meth(self) -> int: - ... - - Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: - - class C: - def meth(self) -> int: - return 0 - - def func(x: Proto) -> int: - return x.meth() - - func(C()) # Passes static type check - - See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks - only the presence of given attributes, ignoring their type signatures. - - Protocol classes can be generic, they are defined as:: - - class GenProto(Protocol[T]): - def meth(self) -> T: - ... 
- """ - __slots__ = () - _is_protocol = True - - def __new__(cls, *args, **kwds): - if _gorg(cls) is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can be used only as a base class") - return typing._generic_new(cls.__next_in_mro__, cls, *args, **kwds) - - -# 3.8+ -if hasattr(typing, 'runtime_checkable'): - runtime_checkable = typing.runtime_checkable -# 3.6-3.7 -else: - def runtime_checkable(cls): - """Mark a protocol class as a runtime protocol, so that it - can be used with isinstance() and issubclass(). Raise TypeError - if applied to a non-protocol class. - - This allows a simple-minded structural check very similar to the - one-offs in collections.abc such as Hashable. - """ - if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: - raise TypeError('@runtime_checkable can be only applied to protocol classes,' - f' got {cls!r}') - cls._is_runtime_protocol = True - return cls - - -# Exists for backwards compatibility. -runtime = runtime_checkable - - -# 3.8+ -if hasattr(typing, 'SupportsIndex'): - SupportsIndex = typing.SupportsIndex -# 3.6-3.7 -else: - @runtime_checkable - class SupportsIndex(Protocol): - __slots__ = () - - @abc.abstractmethod - def __index__(self) -> int: - pass - - -if hasattr(typing, "Required"): - # The standard library TypedDict in Python 3.8 does not store runtime information - # about which (if any) keys are optional. See https://bugs.python.org/issue38834 - # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" - # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 - # The standard library TypedDict below Python 3.11 does not store runtime - # information about optional and required keys when using Required or NotRequired. 
- TypedDict = typing.TypedDict - _TypedDictMeta = typing._TypedDictMeta - is_typeddict = typing.is_typeddict -else: - def _check_fails(cls, other): - try: - if sys._getframe(1).f_globals['__name__'] not in ['abc', - 'functools', - 'typing']: - # Typed dicts are only for static structural subtyping. - raise TypeError('TypedDict does not support instance and class checks') - except (AttributeError, ValueError): - pass - return False - - def _dict_new(*args, **kwargs): - if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') - _, args = args[0], args[1:] # allow the "cls" keyword be passed - return dict(*args, **kwargs) - - _dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)' - - def _typeddict_new(*args, total=True, **kwargs): - if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') - _, args = args[0], args[1:] # allow the "cls" keyword be passed - if args: - typename, args = args[0], args[1:] # allow the "_typename" keyword be passed - elif '_typename' in kwargs: - typename = kwargs.pop('_typename') - import warnings - warnings.warn("Passing '_typename' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - raise TypeError("TypedDict.__new__() missing 1 required positional " - "argument: '_typename'") - if args: - try: - fields, = args # allow the "_fields" keyword be passed - except ValueError: - raise TypeError('TypedDict.__new__() takes from 2 to 3 ' - f'positional arguments but {len(args) + 2} ' - 'were given') - elif '_fields' in kwargs and len(kwargs) == 1: - fields = kwargs.pop('_fields') - import warnings - warnings.warn("Passing '_fields' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - fields = None - - if fields is None: - fields = kwargs - elif kwargs: - raise TypeError("TypedDict takes either a dict or keyword arguments," - " but not both") - - ns = {'__annotations__': dict(fields)} - try: - # Setting correct module is 
necessary to make typed dict classes pickleable. - ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - pass - - return _TypedDictMeta(typename, (), ns, total=total) - - _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' - ' /, *, total=True, **kwargs)') - - class _TypedDictMeta(type): - def __init__(cls, name, bases, ns, total=True): - super().__init__(name, bases, ns) - - def __new__(cls, name, bases, ns, total=True): - # Create new typed dict class object. - # This method is called directly when TypedDict is subclassed, - # or via _typeddict_new when TypedDict is instantiated. This way - # TypedDict supports all three syntaxes described in its docstring. - # Subclasses and instances of TypedDict return actual dictionaries - # via _dict_new. - ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new - tp_dict = super().__new__(cls, name, (dict,), ns) - - annotations = {} - own_annotations = ns.get('__annotations__', {}) - msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" - own_annotations = { - n: typing._type_check(tp, msg) for n, tp in own_annotations.items() - } - required_keys = set() - optional_keys = set() - - for base in bases: - annotations.update(base.__dict__.get('__annotations__', {})) - required_keys.update(base.__dict__.get('__required_keys__', ())) - optional_keys.update(base.__dict__.get('__optional_keys__', ())) - - annotations.update(own_annotations) - if PEP_560: - for annotation_key, annotation_type in own_annotations.items(): - annotation_origin = get_origin(annotation_type) - if annotation_origin is Annotated: - annotation_args = get_args(annotation_type) - if annotation_args: - annotation_type = annotation_args[0] - annotation_origin = get_origin(annotation_type) - - if annotation_origin is Required: - required_keys.add(annotation_key) - elif annotation_origin is NotRequired: - optional_keys.add(annotation_key) - elif total: - 
required_keys.add(annotation_key) - else: - optional_keys.add(annotation_key) - else: - own_annotation_keys = set(own_annotations.keys()) - if total: - required_keys.update(own_annotation_keys) - else: - optional_keys.update(own_annotation_keys) - - tp_dict.__annotations__ = annotations - tp_dict.__required_keys__ = frozenset(required_keys) - tp_dict.__optional_keys__ = frozenset(optional_keys) - if not hasattr(tp_dict, '__total__'): - tp_dict.__total__ = total - return tp_dict - - __instancecheck__ = __subclasscheck__ = _check_fails - - TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) - TypedDict.__module__ = __name__ - TypedDict.__doc__ = \ - """A simple typed name space. At runtime it is equivalent to a plain dict. - - TypedDict creates a dictionary type that expects all of its - instances to have a certain set of keys, with each key - associated with a value of a consistent type. This expectation - is not checked at runtime but is only enforced by type checkers. - Usage:: - - class Point2D(TypedDict): - x: int - y: int - label: str - - a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK - b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check - - assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') - - The type info can be accessed via the Point2D.__annotations__ dict, and - the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. 
- TypedDict supports two additional equivalent forms:: - - Point2D = TypedDict('Point2D', x=int, y=int, label=str) - Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) - - The class syntax is only supported in Python 3.6+, while two other - syntax forms work for Python 2.7 and 3.2+ - """ - - if hasattr(typing, "_TypedDictMeta"): - _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) - else: - _TYPEDDICT_TYPES = (_TypedDictMeta,) - - def is_typeddict(tp): - """Check if an annotation is a TypedDict class - - For example:: - class Film(TypedDict): - title: str - year: int - - is_typeddict(Film) # => True - is_typeddict(Union[list, str]) # => False - """ - return isinstance(tp, tuple(_TYPEDDICT_TYPES)) - -if hasattr(typing, "Required"): - get_type_hints = typing.get_type_hints -elif PEP_560: - import functools - import types - - # replaces _strip_annotations() - def _strip_extras(t): - """Strips Annotated, Required and NotRequired from a given type.""" - if isinstance(t, _AnnotatedAlias): - return _strip_extras(t.__origin__) - if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): - return _strip_extras(t.__args__[0]) - if isinstance(t, typing._GenericAlias): - stripped_args = tuple(_strip_extras(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return t.copy_with(stripped_args) - if hasattr(types, "GenericAlias") and isinstance(t, types.GenericAlias): - stripped_args = tuple(_strip_extras(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return types.GenericAlias(t.__origin__, stripped_args) - if hasattr(types, "UnionType") and isinstance(t, types.UnionType): - stripped_args = tuple(_strip_extras(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return functools.reduce(operator.or_, stripped_args) - - return t - - def get_type_hints(obj, globalns=None, localns=None, include_extras=False): - """Return type hints for an object. 
- - This is often the same as obj.__annotations__, but it handles - forward references encoded as string literals, adds Optional[t] if a - default value equal to None is set and recursively replaces all - 'Annotated[T, ...]', 'Required[T]' or 'NotRequired[T]' with 'T' - (unless 'include_extras=True'). - - The argument may be a module, class, method, or function. The annotations - are returned as a dictionary. For classes, annotations include also - inherited members. - - TypeError is raised if the argument is not of a type that can contain - annotations, and an empty dictionary is returned if no annotations are - present. - - BEWARE -- the behavior of globalns and localns is counterintuitive - (unless you are familiar with how eval() and exec() work). The - search order is locals first, then globals. - - - If no dict arguments are passed, an attempt is made to use the - globals from obj (or the respective module's globals for classes), - and these are also used as the locals. If the object does not appear - to have globals, an empty dictionary is used. - - - If one dict argument is passed, it is used for both globals and - locals. - - - If two dict arguments are passed, they specify globals and - locals, respectively. - """ - if hasattr(typing, "Annotated"): - hint = typing.get_type_hints( - obj, globalns=globalns, localns=localns, include_extras=True - ) - else: - hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) - if include_extras: - return hint - return {k: _strip_extras(t) for k, t in hint.items()} - - -# Python 3.9+ has PEP 593 (Annotated) -if hasattr(typing, 'Annotated'): - Annotated = typing.Annotated - # Not exported and not a public API, but needed for get_origin() and get_args() - # to work. - _AnnotatedAlias = typing._AnnotatedAlias -# 3.7-3.8 -elif PEP_560: - class _AnnotatedAlias(typing._GenericAlias, _root=True): - """Runtime representation of an annotated type. 
- - At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' - with extra annotations. The alias behaves like a normal typing alias, - instantiating is the same as instantiating the underlying type, binding - it to types is also the same. - """ - def __init__(self, origin, metadata): - if isinstance(origin, _AnnotatedAlias): - metadata = origin.__metadata__ + metadata - origin = origin.__origin__ - super().__init__(origin, origin) - self.__metadata__ = metadata - - def copy_with(self, params): - assert len(params) == 1 - new_type = params[0] - return _AnnotatedAlias(new_type, self.__metadata__) - - def __repr__(self): - return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " - f"{', '.join(repr(a) for a in self.__metadata__)}]") - - def __reduce__(self): - return operator.getitem, ( - Annotated, (self.__origin__,) + self.__metadata__ - ) - - def __eq__(self, other): - if not isinstance(other, _AnnotatedAlias): - return NotImplemented - if self.__origin__ != other.__origin__: - return False - return self.__metadata__ == other.__metadata__ - - def __hash__(self): - return hash((self.__origin__, self.__metadata__)) - - class Annotated: - """Add context specific metadata to a type. - - Example: Annotated[int, runtime_check.Unsigned] indicates to the - hypothetical runtime_check module that this type is an unsigned int. - Every other consumer of this type can ignore this metadata and treat - this type as int. - - The first argument to Annotated must be a valid type (and will be in - the __origin__ field), the remaining arguments are kept as a tuple in - the __extra__ field. - - Details: - - - It's an error to call `Annotated` with less than two arguments. 
- - Nested Annotated are flattened:: - - Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] - - - Instantiating an annotated type is equivalent to instantiating the - underlying type:: - - Annotated[C, Ann1](5) == C(5) - - - Annotated can be used as a generic type alias:: - - Optimized = Annotated[T, runtime.Optimize()] - Optimized[int] == Annotated[int, runtime.Optimize()] - - OptimizedList = Annotated[List[T], runtime.Optimize()] - OptimizedList[int] == Annotated[List[int], runtime.Optimize()] - """ - - __slots__ = () - - def __new__(cls, *args, **kwargs): - raise TypeError("Type Annotated cannot be instantiated.") - - @typing._tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be used " - "with at least two arguments (a type and an " - "annotation).") - allowed_special_forms = (ClassVar, Final) - if get_origin(params[0]) in allowed_special_forms: - origin = params[0] - else: - msg = "Annotated[t, ...]: t must be a type." - origin = typing._type_check(params[0], msg) - metadata = tuple(params[1:]) - return _AnnotatedAlias(origin, metadata) - - def __init_subclass__(cls, *args, **kwargs): - raise TypeError( - f"Cannot subclass {cls.__module__}.Annotated" - ) -# 3.6 -else: - - def _is_dunder(name): - """Returns True if name is a __dunder_variable_name__.""" - return len(name) > 4 and name.startswith('__') and name.endswith('__') - - # Prior to Python 3.7 types did not have `copy_with`. A lot of the equality - # checks, argument expansion etc. are done on the _subs_tre. As a result we - # can't provide a get_type_hints function that strips out annotations. 
- - class AnnotatedMeta(typing.GenericMeta): - """Metaclass for Annotated""" - - def __new__(cls, name, bases, namespace, **kwargs): - if any(b is not object for b in bases): - raise TypeError("Cannot subclass " + str(Annotated)) - return super().__new__(cls, name, bases, namespace, **kwargs) - - @property - def __metadata__(self): - return self._subs_tree()[2] - - def _tree_repr(self, tree): - cls, origin, metadata = tree - if not isinstance(origin, tuple): - tp_repr = typing._type_repr(origin) - else: - tp_repr = origin[0]._tree_repr(origin) - metadata_reprs = ", ".join(repr(arg) for arg in metadata) - return f'{cls}[{tp_repr}, {metadata_reprs}]' - - def _subs_tree(self, tvars=None, args=None): # noqa - if self is Annotated: - return Annotated - res = super()._subs_tree(tvars=tvars, args=args) - # Flatten nested Annotated - if isinstance(res[1], tuple) and res[1][0] is Annotated: - sub_tp = res[1][1] - sub_annot = res[1][2] - return (Annotated, sub_tp, sub_annot + res[2]) - return res - - def _get_cons(self): - """Return the class used to create instance of this type.""" - if self.__origin__ is None: - raise TypeError("Cannot get the underlying type of a " - "non-specialized Annotated type.") - tree = self._subs_tree() - while isinstance(tree, tuple) and tree[0] is Annotated: - tree = tree[1] - if isinstance(tree, tuple): - return tree[0] - else: - return tree - - @typing._tp_cache - def __getitem__(self, params): - if not isinstance(params, tuple): - params = (params,) - if self.__origin__ is not None: # specializing an instantiated type - return super().__getitem__(params) - elif not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be instantiated " - "with at least two arguments (a type and an " - "annotation).") - else: - if ( - isinstance(params[0], typing._TypingBase) and - type(params[0]).__name__ == "_ClassVar" - ): - tp = params[0] - else: - msg = "Annotated[t, ...]: t must be a type." 
- tp = typing._type_check(params[0], msg) - metadata = tuple(params[1:]) - return self.__class__( - self.__name__, - self.__bases__, - _no_slots_copy(self.__dict__), - tvars=_type_vars((tp,)), - # Metadata is a tuple so it won't be touched by _replace_args et al. - args=(tp, metadata), - origin=self, - ) - - def __call__(self, *args, **kwargs): - cons = self._get_cons() - result = cons(*args, **kwargs) - try: - result.__orig_class__ = self - except AttributeError: - pass - return result - - def __getattr__(self, attr): - # For simplicity we just don't relay all dunder names - if self.__origin__ is not None and not _is_dunder(attr): - return getattr(self._get_cons(), attr) - raise AttributeError(attr) - - def __setattr__(self, attr, value): - if _is_dunder(attr) or attr.startswith('_abc_'): - super().__setattr__(attr, value) - elif self.__origin__ is None: - raise AttributeError(attr) - else: - setattr(self._get_cons(), attr, value) - - def __instancecheck__(self, obj): - raise TypeError("Annotated cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError("Annotated cannot be used with issubclass().") - - class Annotated(metaclass=AnnotatedMeta): - """Add context specific metadata to a type. - - Example: Annotated[int, runtime_check.Unsigned] indicates to the - hypothetical runtime_check module that this type is an unsigned int. - Every other consumer of this type can ignore this metadata and treat - this type as int. - - The first argument to Annotated must be a valid type, the remaining - arguments are kept as a tuple in the __metadata__ field. - - Details: - - - It's an error to call `Annotated` with less than two arguments. 
- - Nested Annotated are flattened:: - - Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] - - - Instantiating an annotated type is equivalent to instantiating the - underlying type:: - - Annotated[C, Ann1](5) == C(5) - - - Annotated can be used as a generic type alias:: - - Optimized = Annotated[T, runtime.Optimize()] - Optimized[int] == Annotated[int, runtime.Optimize()] - - OptimizedList = Annotated[List[T], runtime.Optimize()] - OptimizedList[int] == Annotated[List[int], runtime.Optimize()] - """ - -# Python 3.8 has get_origin() and get_args() but those implementations aren't -# Annotated-aware, so we can't use those. Python 3.9's versions don't support -# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. -if sys.version_info[:2] >= (3, 10): - get_origin = typing.get_origin - get_args = typing.get_args -# 3.7-3.9 -elif PEP_560: - try: - # 3.9+ - from typing import _BaseGenericAlias - except ImportError: - _BaseGenericAlias = typing._GenericAlias - try: - # 3.9+ - from typing import GenericAlias - except ImportError: - GenericAlias = typing._GenericAlias - - def get_origin(tp): - """Get the unsubscripted version of a type. - - This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar - and Annotated. Return None for unsupported types. Examples:: - - get_origin(Literal[42]) is Literal - get_origin(int) is None - get_origin(ClassVar[int]) is ClassVar - get_origin(Generic) is Generic - get_origin(Generic[T]) is Generic - get_origin(Union[T, int]) is Union - get_origin(List[Tuple[T, T]][int]) == list - get_origin(P.args) is P - """ - if isinstance(tp, _AnnotatedAlias): - return Annotated - if isinstance(tp, (typing._GenericAlias, GenericAlias, _BaseGenericAlias, - ParamSpecArgs, ParamSpecKwargs)): - return tp.__origin__ - if tp is typing.Generic: - return typing.Generic - return None - - def get_args(tp): - """Get type arguments with all substitutions performed. 
- - For unions, basic simplifications used by Union constructor are performed. - Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) - """ - if isinstance(tp, _AnnotatedAlias): - return (tp.__origin__,) + tp.__metadata__ - if isinstance(tp, (typing._GenericAlias, GenericAlias)): - if getattr(tp, "_special", False): - return () - res = tp.__args__ - if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: - res = (list(res[:-1]), res[-1]) - return res - return () - - -# 3.10+ -if hasattr(typing, 'TypeAlias'): - TypeAlias = typing.TypeAlias -# 3.9 -elif sys.version_info[:2] >= (3, 9): - class _TypeAliasForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - @_TypeAliasForm - def TypeAlias(self, parameters): - """Special marker indicating that an assignment should - be recognized as a proper type alias definition by type - checkers. - - For example:: - - Predicate: TypeAlias = Callable[..., bool] - - It's invalid when used anywhere except as in the example above. - """ - raise TypeError(f"{self} is not subscriptable") -# 3.7-3.8 -elif sys.version_info[:2] >= (3, 7): - class _TypeAliasForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - TypeAlias = _TypeAliasForm('TypeAlias', - doc="""Special marker indicating that an assignment should - be recognized as a proper type alias definition by type - checkers. 
- - For example:: - - Predicate: TypeAlias = Callable[..., bool] - - It's invalid when used anywhere except as in the example - above.""") -# 3.6 -else: - class _TypeAliasMeta(typing.TypingMeta): - """Metaclass for TypeAlias""" - - def __repr__(self): - return 'typing_extensions.TypeAlias' - - class _TypeAliasBase(typing._FinalTypingBase, metaclass=_TypeAliasMeta, _root=True): - """Special marker indicating that an assignment should - be recognized as a proper type alias definition by type - checkers. - - For example:: - - Predicate: TypeAlias = Callable[..., bool] - - It's invalid when used anywhere except as in the example above. - """ - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError("TypeAlias cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError("TypeAlias cannot be used with issubclass().") - - def __repr__(self): - return 'typing_extensions.TypeAlias' - - TypeAlias = _TypeAliasBase(_root=True) - - -# Python 3.10+ has PEP 612 -if hasattr(typing, 'ParamSpecArgs'): - ParamSpecArgs = typing.ParamSpecArgs - ParamSpecKwargs = typing.ParamSpecKwargs -# 3.6-3.9 -else: - class _Immutable: - """Mixin to indicate that object should not be copied.""" - __slots__ = () - - def __copy__(self): - return self - - def __deepcopy__(self, memo): - return self - - class ParamSpecArgs(_Immutable): - """The args for a ParamSpec object. - - Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. - - ParamSpecArgs objects have a reference back to their ParamSpec: - - P.args.__origin__ is P - - This type is meant for runtime introspection and has no special meaning to - static type checkers. 
- """ - def __init__(self, origin): - self.__origin__ = origin - - def __repr__(self): - return f"{self.__origin__.__name__}.args" - - def __eq__(self, other): - if not isinstance(other, ParamSpecArgs): - return NotImplemented - return self.__origin__ == other.__origin__ - - class ParamSpecKwargs(_Immutable): - """The kwargs for a ParamSpec object. - - Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. - - ParamSpecKwargs objects have a reference back to their ParamSpec: - - P.kwargs.__origin__ is P - - This type is meant for runtime introspection and has no special meaning to - static type checkers. - """ - def __init__(self, origin): - self.__origin__ = origin - - def __repr__(self): - return f"{self.__origin__.__name__}.kwargs" - - def __eq__(self, other): - if not isinstance(other, ParamSpecKwargs): - return NotImplemented - return self.__origin__ == other.__origin__ - -# 3.10+ -if hasattr(typing, 'ParamSpec'): - ParamSpec = typing.ParamSpec -# 3.6-3.9 -else: - - # Inherits from list as a workaround for Callable checks in Python < 3.9.2. - class ParamSpec(list): - """Parameter specification variable. - - Usage:: - - P = ParamSpec('P') - - Parameter specification variables exist primarily for the benefit of static - type checkers. They are used to forward the parameter types of one - callable to another callable, a pattern commonly found in higher order - functions and decorators. They are only valid when used in ``Concatenate``, - or s the first argument to ``Callable``. In Python 3.10 and higher, - they are also supported in user-defined Generics at runtime. - See class Generic for more information on generic types. 
An - example for annotating a decorator:: - - T = TypeVar('T') - P = ParamSpec('P') - - def add_logging(f: Callable[P, T]) -> Callable[P, T]: - '''A type-safe decorator to add logging to a function.''' - def inner(*args: P.args, **kwargs: P.kwargs) -> T: - logging.info(f'{f.__name__} was called') - return f(*args, **kwargs) - return inner - - @add_logging - def add_two(x: float, y: float) -> float: - '''Add two numbers together.''' - return x + y - - Parameter specification variables defined with covariant=True or - contravariant=True can be used to declare covariant or contravariant - generic types. These keyword arguments are valid, but their actual semantics - are yet to be decided. See PEP 612 for details. - - Parameter specification variables can be introspected. e.g.: - - P.__name__ == 'T' - P.__bound__ == None - P.__covariant__ == False - P.__contravariant__ == False - - Note that only parameter specification variables defined in global scope can - be pickled. - """ - - # Trick Generic __parameters__. 
- __class__ = typing.TypeVar - - @property - def args(self): - return ParamSpecArgs(self) - - @property - def kwargs(self): - return ParamSpecKwargs(self) - - def __init__(self, name, *, bound=None, covariant=False, contravariant=False): - super().__init__([self]) - self.__name__ = name - self.__covariant__ = bool(covariant) - self.__contravariant__ = bool(contravariant) - if bound: - self.__bound__ = typing._type_check(bound, 'Bound must be a type.') - else: - self.__bound__ = None - - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod - - def __repr__(self): - if self.__covariant__: - prefix = '+' - elif self.__contravariant__: - prefix = '-' - else: - prefix = '~' - return prefix + self.__name__ - - def __hash__(self): - return object.__hash__(self) - - def __eq__(self, other): - return self is other - - def __reduce__(self): - return self.__name__ - - # Hack to get typing._type_check to pass. - def __call__(self, *args, **kwargs): - pass - - if not PEP_560: - # Only needed in 3.6. - def _get_type_vars(self, tvars): - if self not in tvars: - tvars.append(self) - - -# 3.6-3.9 -if not hasattr(typing, 'Concatenate'): - # Inherits from list as a workaround for Callable checks in Python < 3.9.2. - class _ConcatenateGenericAlias(list): - - # Trick Generic into looking into this for __parameters__. - if PEP_560: - __class__ = typing._GenericAlias - else: - __class__ = typing._TypingBase - - # Flag in 3.8. - _special = False - # Attribute in 3.6 and earlier. 
- _gorg = typing.Generic - - def __init__(self, origin, args): - super().__init__(args) - self.__origin__ = origin - self.__args__ = args - - def __repr__(self): - _type_repr = typing._type_repr - return (f'{_type_repr(self.__origin__)}' - f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') - - def __hash__(self): - return hash((self.__origin__, self.__args__)) - - # Hack to get typing._type_check to pass in Generic. - def __call__(self, *args, **kwargs): - pass - - @property - def __parameters__(self): - return tuple( - tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) - ) - - if not PEP_560: - # Only required in 3.6. - def _get_type_vars(self, tvars): - if self.__origin__ and self.__parameters__: - typing._get_type_vars(self.__parameters__, tvars) - - -# 3.6-3.9 -@typing._tp_cache -def _concatenate_getitem(self, parameters): - if parameters == (): - raise TypeError("Cannot take a Concatenate of no types.") - if not isinstance(parameters, tuple): - parameters = (parameters,) - if not isinstance(parameters[-1], ParamSpec): - raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") - msg = "Concatenate[arg, ...]: each arg must be a type." - parameters = tuple(typing._type_check(p, msg) for p in parameters) - return _ConcatenateGenericAlias(self, parameters) - - -# 3.10+ -if hasattr(typing, 'Concatenate'): - Concatenate = typing.Concatenate - _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa -# 3.9 -elif sys.version_info[:2] >= (3, 9): - @_TypeAliasForm - def Concatenate(self, parameters): - """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a - higher order function which adds, removes or transforms parameters of a - callable. - - For example:: - - Callable[Concatenate[int, P], int] - - See PEP 612 for detailed information. 
- """ - return _concatenate_getitem(self, parameters) -# 3.7-8 -elif sys.version_info[:2] >= (3, 7): - class _ConcatenateForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - def __getitem__(self, parameters): - return _concatenate_getitem(self, parameters) - - Concatenate = _ConcatenateForm( - 'Concatenate', - doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a - higher order function which adds, removes or transforms parameters of a - callable. - - For example:: - - Callable[Concatenate[int, P], int] - - See PEP 612 for detailed information. - """) -# 3.6 -else: - class _ConcatenateAliasMeta(typing.TypingMeta): - """Metaclass for Concatenate.""" - - def __repr__(self): - return 'typing_extensions.Concatenate' - - class _ConcatenateAliasBase(typing._FinalTypingBase, - metaclass=_ConcatenateAliasMeta, - _root=True): - """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a - higher order function which adds, removes or transforms parameters of a - callable. - - For example:: - - Callable[Concatenate[int, P], int] - - See PEP 612 for detailed information. - """ - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError("Concatenate cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError("Concatenate cannot be used with issubclass().") - - def __repr__(self): - return 'typing_extensions.Concatenate' - - def __getitem__(self, parameters): - return _concatenate_getitem(self, parameters) - - Concatenate = _ConcatenateAliasBase(_root=True) - -# 3.10+ -if hasattr(typing, 'TypeGuard'): - TypeGuard = typing.TypeGuard -# 3.9 -elif sys.version_info[:2] >= (3, 9): - class _TypeGuardForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' 
+ self._name - - @_TypeGuardForm - def TypeGuard(self, parameters): - """Special typing form used to annotate the return type of a user-defined - type guard function. ``TypeGuard`` only accepts a single type argument. - At runtime, functions marked this way should return a boolean. - - ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static - type checkers to determine a more precise type of an expression within a - program's code flow. Usually type narrowing is done by analyzing - conditional code flow and applying the narrowing to a block of code. The - conditional expression here is sometimes referred to as a "type guard". - - Sometimes it would be convenient to use a user-defined boolean function - as a type guard. Such a function should use ``TypeGuard[...]`` as its - return type to alert static type checkers to this intention. - - Using ``-> TypeGuard`` tells the static type checker that for a given - function: - - 1. The return value is a boolean. - 2. If the return value is ``True``, the type of its argument - is the type inside ``TypeGuard``. - - For example:: - - def is_str(val: Union[str, float]): - # "isinstance" type guard - if isinstance(val, str): - # Type of ``val`` is narrowed to ``str`` - ... - else: - # Else, type of ``val`` is narrowed to ``float``. - ... - - Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower - form of ``TypeA`` (it can even be a wider form) and this may lead to - type-unsafe results. The main reason is to allow for things like - narrowing ``List[object]`` to ``List[str]`` even though the latter is not - a subtype of the former, since ``List`` is invariant. The responsibility of - writing type-safe type guards is left to the user. - - ``TypeGuard`` also works with type variables. For more information, see - PEP 647 (User-Defined Type Guards). 
- """ - item = typing._type_check(parameters, f'{self} accepts only single type.') - return typing._GenericAlias(self, (item,)) -# 3.7-3.8 -elif sys.version_info[:2] >= (3, 7): - class _TypeGuardForm(typing._SpecialForm, _root=True): - - def __repr__(self): - return 'typing_extensions.' + self._name - - def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type') - return typing._GenericAlias(self, (item,)) - - TypeGuard = _TypeGuardForm( - 'TypeGuard', - doc="""Special typing form used to annotate the return type of a user-defined - type guard function. ``TypeGuard`` only accepts a single type argument. - At runtime, functions marked this way should return a boolean. - - ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static - type checkers to determine a more precise type of an expression within a - program's code flow. Usually type narrowing is done by analyzing - conditional code flow and applying the narrowing to a block of code. The - conditional expression here is sometimes referred to as a "type guard". - - Sometimes it would be convenient to use a user-defined boolean function - as a type guard. Such a function should use ``TypeGuard[...]`` as its - return type to alert static type checkers to this intention. - - Using ``-> TypeGuard`` tells the static type checker that for a given - function: - - 1. The return value is a boolean. - 2. If the return value is ``True``, the type of its argument - is the type inside ``TypeGuard``. - - For example:: - - def is_str(val: Union[str, float]): - # "isinstance" type guard - if isinstance(val, str): - # Type of ``val`` is narrowed to ``str`` - ... - else: - # Else, type of ``val`` is narrowed to ``float``. - ... - - Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower - form of ``TypeA`` (it can even be a wider form) and this may lead to - type-unsafe results. 
The main reason is to allow for things like - narrowing ``List[object]`` to ``List[str]`` even though the latter is not - a subtype of the former, since ``List`` is invariant. The responsibility of - writing type-safe type guards is left to the user. - - ``TypeGuard`` also works with type variables. For more information, see - PEP 647 (User-Defined Type Guards). - """) -# 3.6 -else: - class _TypeGuard(typing._FinalTypingBase, _root=True): - """Special typing form used to annotate the return type of a user-defined - type guard function. ``TypeGuard`` only accepts a single type argument. - At runtime, functions marked this way should return a boolean. - - ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static - type checkers to determine a more precise type of an expression within a - program's code flow. Usually type narrowing is done by analyzing - conditional code flow and applying the narrowing to a block of code. The - conditional expression here is sometimes referred to as a "type guard". - - Sometimes it would be convenient to use a user-defined boolean function - as a type guard. Such a function should use ``TypeGuard[...]`` as its - return type to alert static type checkers to this intention. - - Using ``-> TypeGuard`` tells the static type checker that for a given - function: - - 1. The return value is a boolean. - 2. If the return value is ``True``, the type of its argument - is the type inside ``TypeGuard``. - - For example:: - - def is_str(val: Union[str, float]): - # "isinstance" type guard - if isinstance(val, str): - # Type of ``val`` is narrowed to ``str`` - ... - else: - # Else, type of ``val`` is narrowed to ``float``. - ... - - Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower - form of ``TypeA`` (it can even be a wider form) and this may lead to - type-unsafe results. 
The main reason is to allow for things like - narrowing ``List[object]`` to ``List[str]`` even though the latter is not - a subtype of the former, since ``List`` is invariant. The responsibility of - writing type-safe type guards is left to the user. - - ``TypeGuard`` also works with type variables. For more information, see - PEP 647 (User-Defined Type Guards). - """ - - __slots__ = ('__type__',) - - def __init__(self, tp=None, **kwds): - self.__type__ = tp - - def __getitem__(self, item): - cls = type(self) - if self.__type__ is None: - return cls(typing._type_check(item, - f'{cls.__name__[1:]} accepts only a single type.'), - _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') - - def _eval_type(self, globalns, localns): - new_tp = typing._eval_type(self.__type__, globalns, localns) - if new_tp == self.__type__: - return self - return type(self)(new_tp, _root=True) - - def __repr__(self): - r = super().__repr__() - if self.__type__ is not None: - r += f'[{typing._type_repr(self.__type__)}]' - return r - - def __hash__(self): - return hash((type(self).__name__, self.__type__)) - - def __eq__(self, other): - if not isinstance(other, _TypeGuard): - return NotImplemented - if self.__type__ is not None: - return self.__type__ == other.__type__ - return self is other - - TypeGuard = _TypeGuard(_root=True) - - -if sys.version_info[:2] >= (3, 7): - # Vendored from cpython typing._SpecialFrom - class _SpecialForm(typing._Final, _root=True): - __slots__ = ('_name', '__doc__', '_getitem') - - def __init__(self, getitem): - self._getitem = getitem - self._name = getitem.__name__ - self.__doc__ = getitem.__doc__ - - def __getattr__(self, item): - if item in {'__name__', '__qualname__'}: - return self._name - - raise AttributeError(item) - - def __mro_entries__(self, bases): - raise TypeError(f"Cannot subclass {self!r}") - - def __repr__(self): - return f'typing_extensions.{self._name}' - - def __reduce__(self): - return self._name - - def 
__call__(self, *args, **kwds): - raise TypeError(f"Cannot instantiate {self!r}") - - def __or__(self, other): - return typing.Union[self, other] - - def __ror__(self, other): - return typing.Union[other, self] - - def __instancecheck__(self, obj): - raise TypeError(f"{self} cannot be used with isinstance()") - - def __subclasscheck__(self, cls): - raise TypeError(f"{self} cannot be used with issubclass()") - - @typing._tp_cache - def __getitem__(self, parameters): - return self._getitem(self, parameters) - - -if hasattr(typing, "LiteralString"): - LiteralString = typing.LiteralString -elif sys.version_info[:2] >= (3, 7): - @_SpecialForm - def LiteralString(self, params): - """Represents an arbitrary literal string. - - Example:: - - from typing_extensions import LiteralString - - def query(sql: LiteralString) -> ...: - ... - - query("SELECT * FROM table") # ok - query(f"SELECT * FROM {input()}") # not ok - - See PEP 675 for details. - - """ - raise TypeError(f"{self} is not subscriptable") -else: - class _LiteralString(typing._FinalTypingBase, _root=True): - """Represents an arbitrary literal string. - - Example:: - - from typing_extensions import LiteralString - - def query(sql: LiteralString) -> ...: - ... - - query("SELECT * FROM table") # ok - query(f"SELECT * FROM {input()}") # not ok - - See PEP 675 for details. - - """ - - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError(f"{self} cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError(f"{self} cannot be used with issubclass().") - - LiteralString = _LiteralString(_root=True) - - -if hasattr(typing, "Self"): - Self = typing.Self -elif sys.version_info[:2] >= (3, 7): - @_SpecialForm - def Self(self, params): - """Used to spell the type of "self" in classes. - - Example:: - - from typing import Self - - class ReturnsSelf: - def parse(self, data: bytes) -> Self: - ... 
- return self - - """ - - raise TypeError(f"{self} is not subscriptable") -else: - class _Self(typing._FinalTypingBase, _root=True): - """Used to spell the type of "self" in classes. - - Example:: - - from typing import Self - - class ReturnsSelf: - def parse(self, data: bytes) -> Self: - ... - return self - - """ - - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError(f"{self} cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError(f"{self} cannot be used with issubclass().") - - Self = _Self(_root=True) - - -if hasattr(typing, "Never"): - Never = typing.Never -elif sys.version_info[:2] >= (3, 7): - @_SpecialForm - def Never(self, params): - """The bottom type, a type that has no members. - - This can be used to define a function that should never be - called, or a function that never returns:: - - from typing_extensions import Never - - def never_call_me(arg: Never) -> None: - pass - - def int_or_str(arg: int | str) -> None: - never_call_me(arg) # type checker error - match arg: - case int(): - print("It's an int") - case str(): - print("It's a str") - case _: - never_call_me(arg) # ok, arg is of type Never - - """ - - raise TypeError(f"{self} is not subscriptable") -else: - class _Never(typing._FinalTypingBase, _root=True): - """The bottom type, a type that has no members. 
- - This can be used to define a function that should never be - called, or a function that never returns:: - - from typing_extensions import Never - - def never_call_me(arg: Never) -> None: - pass - - def int_or_str(arg: int | str) -> None: - never_call_me(arg) # type checker error - match arg: - case int(): - print("It's an int") - case str(): - print("It's a str") - case _: - never_call_me(arg) # ok, arg is of type Never - - """ - - __slots__ = () - - def __instancecheck__(self, obj): - raise TypeError(f"{self} cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - raise TypeError(f"{self} cannot be used with issubclass().") - - Never = _Never(_root=True) - - -if hasattr(typing, 'Required'): - Required = typing.Required - NotRequired = typing.NotRequired -elif sys.version_info[:2] >= (3, 9): - class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - @_ExtensionsSpecialForm - def Required(self, parameters): - """A special typing construct to mark a key of a total=False TypedDict - as required. For example: - - class Movie(TypedDict, total=False): - title: Required[str] - year: int - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - - There is no runtime checking that a required key is actually provided - when instantiating a related TypedDict. - """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') - return typing._GenericAlias(self, (item,)) - - @_ExtensionsSpecialForm - def NotRequired(self, parameters): - """A special typing construct to mark a key of a TypedDict as - potentially missing. 
For example: - - class Movie(TypedDict): - title: str - year: NotRequired[int] - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') - return typing._GenericAlias(self, (item,)) - -elif sys.version_info[:2] >= (3, 7): - class _RequiredForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - def __getitem__(self, parameters): - item = typing._type_check(parameters, - '{} accepts only single type'.format(self._name)) - return typing._GenericAlias(self, (item,)) - - Required = _RequiredForm( - 'Required', - doc="""A special typing construct to mark a key of a total=False TypedDict - as required. For example: - - class Movie(TypedDict, total=False): - title: Required[str] - year: int - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - - There is no runtime checking that a required key is actually provided - when instantiating a related TypedDict. - """) - NotRequired = _RequiredForm( - 'NotRequired', - doc="""A special typing construct to mark a key of a TypedDict as - potentially missing. 
For example: - - class Movie(TypedDict): - title: str - year: NotRequired[int] - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - """) -else: - # NOTE: Modeled after _Final's implementation when _FinalTypingBase available - class _MaybeRequired(typing._FinalTypingBase, _root=True): - __slots__ = ('__type__',) - - def __init__(self, tp=None, **kwds): - self.__type__ = tp - - def __getitem__(self, item): - cls = type(self) - if self.__type__ is None: - return cls(typing._type_check(item, - '{} accepts only single type.'.format(cls.__name__[1:])), - _root=True) - raise TypeError('{} cannot be further subscripted' - .format(cls.__name__[1:])) - - def _eval_type(self, globalns, localns): - new_tp = typing._eval_type(self.__type__, globalns, localns) - if new_tp == self.__type__: - return self - return type(self)(new_tp, _root=True) - - def __repr__(self): - r = super().__repr__() - if self.__type__ is not None: - r += '[{}]'.format(typing._type_repr(self.__type__)) - return r - - def __hash__(self): - return hash((type(self).__name__, self.__type__)) - - def __eq__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - if self.__type__ is not None: - return self.__type__ == other.__type__ - return self is other - - class _Required(_MaybeRequired, _root=True): - """A special typing construct to mark a key of a total=False TypedDict - as required. For example: - - class Movie(TypedDict, total=False): - title: Required[str] - year: int - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - - There is no runtime checking that a required key is actually provided - when instantiating a related TypedDict. - """ - - class _NotRequired(_MaybeRequired, _root=True): - """A special typing construct to mark a key of a TypedDict as - potentially missing. 
For example: - - class Movie(TypedDict): - title: str - year: NotRequired[int] - - m = Movie( - title='The Matrix', # typechecker error if key is omitted - year=1999, - ) - """ - - Required = _Required(_root=True) - NotRequired = _NotRequired(_root=True) - - -if sys.version_info[:2] >= (3, 9): - class _UnpackSpecialForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - class _UnpackAlias(typing._GenericAlias, _root=True): - __class__ = typing.TypeVar - - @_UnpackSpecialForm - def Unpack(self, parameters): - """A special typing construct to unpack a variadic type. For example: - - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) - - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... - - """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') - return _UnpackAlias(self, (item,)) - - def _is_unpack(obj): - return isinstance(obj, _UnpackAlias) - -elif sys.version_info[:2] >= (3, 7): - class _UnpackAlias(typing._GenericAlias, _root=True): - __class__ = typing.TypeVar - - class _UnpackForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only single type') - return _UnpackAlias(self, (item,)) - - Unpack = _UnpackForm( - 'Unpack', - doc="""A special typing construct to unpack a variadic type. For example: - - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) - - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... - - """) - - def _is_unpack(obj): - return isinstance(obj, _UnpackAlias) - -else: - # NOTE: Modeled after _Final's implementation when _FinalTypingBase available - class _Unpack(typing._FinalTypingBase, _root=True): - """A special typing construct to unpack a variadic type. 
For example: - - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) - - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... - - """ - __slots__ = ('__type__',) - __class__ = typing.TypeVar - - def __init__(self, tp=None, **kwds): - self.__type__ = tp - - def __getitem__(self, item): - cls = type(self) - if self.__type__ is None: - return cls(typing._type_check(item, - 'Unpack accepts only single type.'), - _root=True) - raise TypeError('Unpack cannot be further subscripted') - - def _eval_type(self, globalns, localns): - new_tp = typing._eval_type(self.__type__, globalns, localns) - if new_tp == self.__type__: - return self - return type(self)(new_tp, _root=True) - - def __repr__(self): - r = super().__repr__() - if self.__type__ is not None: - r += '[{}]'.format(typing._type_repr(self.__type__)) - return r - - def __hash__(self): - return hash((type(self).__name__, self.__type__)) - - def __eq__(self, other): - if not isinstance(other, _Unpack): - return NotImplemented - if self.__type__ is not None: - return self.__type__ == other.__type__ - return self is other - - # For 3.6 only - def _get_type_vars(self, tvars): - self.__type__._get_type_vars(tvars) - - Unpack = _Unpack(_root=True) - - def _is_unpack(obj): - return isinstance(obj, _Unpack) - - -class TypeVarTuple: - """Type variable tuple. - - Usage:: - - Ts = TypeVarTuple('Ts') - - In the same way that a normal type variable is a stand-in for a single - type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* type such as - ``Tuple[int, str]``. - - Type variable tuples can be used in ``Generic`` declarations. - Consider the following example:: - - class Array(Generic[*Ts]): ... - - The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, - where ``T1`` and ``T2`` are type variables. To use these type variables - as type parameters of ``Array``, we must *unpack* the type variable tuple using - the star operator: ``*Ts``. 
The signature of ``Array`` then behaves - as if we had simply written ``class Array(Generic[T1, T2]): ...``. - In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows - us to parameterise the class with an *arbitrary* number of type parameters. - - Type variable tuples can be used anywhere a normal ``TypeVar`` can. - This includes class definitions, as shown above, as well as function - signatures and variable annotations:: - - class Array(Generic[*Ts]): - - def __init__(self, shape: Tuple[*Ts]): - self._shape: Tuple[*Ts] = shape - - def get_shape(self) -> Tuple[*Ts]: - return self._shape - - shape = (Height(480), Width(640)) - x: Array[Height, Width] = Array(shape) - y = abs(x) # Inferred type is Array[Height, Width] - z = x + x # ... is Array[Height, Width] - x.get_shape() # ... is tuple[Height, Width] - - """ - - # Trick Generic __parameters__. - __class__ = typing.TypeVar - - def __iter__(self): - yield self.__unpacked__ - - def __init__(self, name): - self.__name__ = name - - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod - - self.__unpacked__ = Unpack[self] - - def __repr__(self): - return self.__name__ - - def __hash__(self): - return object.__hash__(self) - - def __eq__(self, other): - return self is other - - def __reduce__(self): - return self.__name__ - - def __init_subclass__(self, *args, **kwds): - if '_root' not in kwds: - raise TypeError("Cannot subclass special typing classes") - - if not PEP_560: - # Only needed in 3.6. - def _get_type_vars(self, tvars): - if self not in tvars: - tvars.append(self) - - -if hasattr(typing, "reveal_type"): - reveal_type = typing.reveal_type -else: - def reveal_type(__obj: T) -> T: - """Reveal the inferred type of a variable. 
- - When a static type checker encounters a call to ``reveal_type()``, - it will emit the inferred type of the argument:: - - x: int = 1 - reveal_type(x) - - Running a static type checker (e.g., ``mypy``) on this example - will produce output similar to 'Revealed type is "builtins.int"'. - - At runtime, the function prints the runtime type of the - argument and returns it unchanged. - - """ - print(f"Runtime type is {type(__obj).__name__!r}", file=sys.stderr) - return __obj - - -if hasattr(typing, "assert_never"): - assert_never = typing.assert_never -else: - def assert_never(__arg: Never) -> Never: - """Assert to the type checker that a line of code is unreachable. - - Example:: - - def int_or_str(arg: int | str) -> None: - match arg: - case int(): - print("It's an int") - case str(): - print("It's a str") - case _: - assert_never(arg) - - If a type checker finds that a call to assert_never() is - reachable, it will emit an error. - - At runtime, this throws an exception when called. - - """ - raise AssertionError("Expected code to be unreachable") - - -if hasattr(typing, 'dataclass_transform'): - dataclass_transform = typing.dataclass_transform -else: - def dataclass_transform( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - field_descriptors: typing.Tuple[ - typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], - ... - ] = (), - ) -> typing.Callable[[T], T]: - """Decorator that marks a function, class, or metaclass as providing - dataclass-like behavior. - - Example: - - from typing_extensions import dataclass_transform - - _T = TypeVar("_T") - - # Used on a decorator function - @dataclass_transform() - def create_model(cls: type[_T]) -> type[_T]: - ... - return cls - - @create_model - class CustomerModel: - id: int - name: str - - # Used on a base class - @dataclass_transform() - class ModelBase: ... 
- - class CustomerModel(ModelBase): - id: int - name: str - - # Used on a metaclass - @dataclass_transform() - class ModelMeta(type): ... - - class ModelBase(metaclass=ModelMeta): ... - - class CustomerModel(ModelBase): - id: int - name: str - - Each of the ``CustomerModel`` classes defined in this example will now - behave similarly to a dataclass created with the ``@dataclasses.dataclass`` - decorator. For example, the type checker will synthesize an ``__init__`` - method. - - The arguments to this decorator can be used to customize this behavior: - - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be - True or False if it is omitted by the caller. - - ``order_default`` indicates whether the ``order`` parameter is - assumed to be True or False if it is omitted by the caller. - - ``kw_only_default`` indicates whether the ``kw_only`` parameter is - assumed to be True or False if it is omitted by the caller. - - ``field_descriptors`` specifies a static list of supported classes - or functions, that describe fields, similar to ``dataclasses.field()``. - - At runtime, this decorator records its arguments in the - ``__dataclass_transform__`` attribute on the decorated object. - - See PEP 681 for details. - - """ - def decorator(cls_or_fn): - cls_or_fn.__dataclass_transform__ = { - "eq_default": eq_default, - "order_default": order_default, - "kw_only_default": kw_only_default, - "field_descriptors": field_descriptors, - } - return cls_or_fn - return decorator - - -# We have to do some monkey patching to deal with the dual nature of -# Unpack/TypeVarTuple: -# - We want Unpack to be a kind of TypeVar so it gets accepted in -# Generic[Unpack[Ts]] -# - We want it to *not* be treated as a TypeVar for the purposes of -# counting generic parameters, so that when we subscript a generic, -# the runtime doesn't try to substitute the Unpack with the subscripted type. 
-if not hasattr(typing, "TypeVarTuple"): - typing._collect_type_vars = _collect_type_vars - typing._check_generic = _check_generic From 2ee76a7926b4903e9147967b577d38481f0c7784 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 5 Feb 2023 10:44:41 -0500 Subject: [PATCH 23/27] test: update changelog for 376 bug --- sites/www/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sites/www/changelog.rst b/sites/www/changelog.rst index 7f4813e3e..0088923c4 100644 --- a/sites/www/changelog.rst +++ b/sites/www/changelog.rst @@ -2,6 +2,8 @@ Changelog ========= +- :bug:`376` Resolve equality comparison bug for non-collections. Patch via + Jesse P. Johnson - :support:`906` Implement type hints and type checking tests with mypy to reduce errors and impove code documentation. Patches by Jesse P. Johnson and review by Sam Bull. From 1f3e5a456a2f830943e16553d95960f04b79ff28 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 5 Feb 2023 15:38:00 -0500 Subject: [PATCH 24/27] test: update changelog for 376 bug --- invoke/collection.py | 3 ++- invoke/util.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/invoke/collection.py b/invoke/collection.py index 9510b4f1d..6441e38a3 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -264,8 +264,9 @@ def add_task( if name is None: if task.name: name = task.name + # XXX https://github.com/python/mypy/issues/1424 elif hasattr(task.body, "func_name"): - name = task.body.func_name + name = task.body.func_name # type: ignore elif hasattr(task.body, "__name__"): name = task.__name__ else: diff --git a/invoke/util.py b/invoke/util.py index 5ec92de64..c622f8192 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -191,7 +191,8 @@ def run(self) -> None: # approach to work, by using _run() instead of run(). If that # doesn't appear to be the case, then assume we're being used # directly and just use super() ourselves. 
- if hasattr(self, "_run") and callable(self._run): + # XXX https://github.com/python/mypy/issues/1424 + if hasattr(self, "_run") and callable(self._run): # type: ignore # TODO: this could be: # - io worker with no 'result' (always local) # - tunnel worker, also with no 'result' (also always local) @@ -206,7 +207,7 @@ def run(self) -> None: # and let it continue acting like a normal thread (meh) # - assume the run/sudo/etc case will use a queue inside its # worker body, orthogonal to how exception handling works - self._run() + self._run() # type: ignore else: super().run() except BaseException: From 70e9d93a23c1f42e7b7b2950fc09e59877f547c9 Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 5 Feb 2023 15:52:45 -0500 Subject: [PATCH 25/27] refactor: set debug without globals injection --- invoke/completion/complete.py | 2 +- invoke/config.py | 2 +- invoke/env.py | 2 +- invoke/executor.py | 2 +- invoke/loader.py | 2 +- invoke/parser/parser.py | 2 +- invoke/program.py | 2 +- invoke/util.py | 5 ++--- 8 files changed, 9 insertions(+), 10 deletions(-) diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index 3f9b413e7..57f61f64b 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Union from ..exceptions import Exit, ParseError -from ..util import debug, task_name_sort_key # type: ignore +from ..util import debug, task_name_sort_key if TYPE_CHECKING: from ..collection import Collection diff --git a/invoke/config.py b/invoke/config.py index 00400164c..08c65974b 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -10,7 +10,7 @@ from .exceptions import UnknownFileType, UnpicklableConfigMember from .runners import Local from .terminals import WINDOWS -from .util import debug, yaml # type: ignore +from .util import debug, yaml try: diff --git a/invoke/env.py b/invoke/env.py index f772fc6ce..2e90f4e95 100644 --- a/invoke/env.py +++ b/invoke/env.py @@ -12,7 +12,7 
@@ from typing import TYPE_CHECKING, Any, Dict, List from .exceptions import UncastableEnvVar, AmbiguousEnvVar -from .util import debug # type: ignore +from .util import debug if TYPE_CHECKING: from .config import Config diff --git a/invoke/executor.py b/invoke/executor.py index db7cc6f24..8e80d6e82 100644 --- a/invoke/executor.py +++ b/invoke/executor.py @@ -2,7 +2,7 @@ from .config import Config from .parser import ParserContext -from .util import debug # type: ignore +from .util import debug from .tasks import Call, Task if TYPE_CHECKING: diff --git a/invoke/loader.py b/invoke/loader.py index af8439b93..23bffdf0f 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -6,7 +6,7 @@ from . import Config from .exceptions import CollectionNotFound -from .util import debug # type: ignore +from .util import debug class Loader: diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index b005e35e4..f926b2168 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -9,7 +9,7 @@ from fluidity import StateMachine, state, transition # type: ignore from ..exceptions import ParseError -from ..util import debug # type: ignore +from ..util import debug if TYPE_CHECKING: from .context import ParserContext diff --git a/invoke/program.py b/invoke/program.py index 9b2980c2a..76cb22806 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -12,7 +12,7 @@ from .parser import Parser, ParserContext, Argument from .exceptions import UnexpectedExit, CollectionNotFound, ParseError, Exit from .terminals import pty_size -from .util import debug, enable_logging, helpline # type: ignore +from .util import debug, enable_logging, helpline if TYPE_CHECKING: from .context import Context diff --git a/invoke/util.py b/invoke/util.py index c622f8192..31f9a11dc 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -39,8 +39,7 @@ def enable_logging() -> None: # Add top level logger functions to global namespace. Meh. 
log = logging.getLogger("invoke") -for x in ("debug",): - globals()[x] = getattr(log, x) +debug = log.debug def task_name_sort_key(name: str) -> Tuple[List[str], str]: @@ -221,7 +220,7 @@ def run(self) -> None: name = "_run" if "target" in self.kwargs: name = self.kwargs["target"].__name__ - debug(msg.format(self.exc_info[1], name)) # type: ignore # noqa + debug(msg.format(self.exc_info[1], name)) # noqa def exception(self) -> Optional["ExceptionWrapper"]: """ From 0a23f95c591d98f7f5706ef2d30e35c68de6890c Mon Sep 17 00:00:00 2001 From: "Jesse P. Johnson" Date: Sun, 5 Feb 2023 16:48:05 -0500 Subject: [PATCH 26/27] refactor: use userlist --- invoke/context.py | 1 - invoke/parser/parser.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invoke/context.py b/invoke/context.py index a1bafeda8..84936d916 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -214,7 +214,6 @@ def _sudo( cmd_str = "sudo -S -p '{}' {}{}{}".format( prompt, env_flags, user_flags, command ) - # FIXME pattern should be raw string prompt.encode('unicode_escape') watcher = FailingResponder( pattern=re.escape(prompt), response="{}\n".format(password), diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index f926b2168..7a7e0746d 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -1,4 +1,5 @@ import copy +from collections import UserList from typing import TYPE_CHECKING, Any, Iterable, List, Optional try: @@ -23,7 +24,7 @@ def is_long_flag(value: str) -> bool: return value.startswith("--") -class ParseResult(list): +class ParseResult(UserList): """ List-like object with some extra parse-related attributes. From 0b6a448e357f04a9927e63ff2e108d0bae25e9d3 Mon Sep 17 00:00:00 2001 From: "Jesse P. 
Johnson" Date: Fri, 10 Feb 2023 19:12:22 -0500 Subject: [PATCH 27/27] test: add Sam's suggestions --- dev-requirements.txt | 2 +- invoke/__init__.py | 6 +- invoke/collection.py | 6 +- invoke/completion/complete.py | 5 +- invoke/config.py | 12 +- invoke/context.py | 4 +- invoke/env.py | 26 +-- invoke/executor.py | 12 +- invoke/parser/context.py | 2 +- invoke/parser/parser.py | 13 +- invoke/program.py | 32 ++- invoke/runners.py | 17 +- invoke/tasks.py | 62 ++--- invoke/terminals.py | 141 ++++++------ invoke/util.py | 4 +- invoke/vendor/decorator.py | 414 ---------------------------------- invoke/watchers.py | 6 +- pyproject.toml | 18 +- tasks.py | 41 ++-- 19 files changed, 214 insertions(+), 609 deletions(-) delete mode 100644 invoke/vendor/decorator.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 8901bc3bf..4b56b30ff 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -22,4 +22,4 @@ icecream>=2.1 # typing mypy==0.971 typed-ast==1.5.4 -types-PyYAML==6 +types-PyYAML==6.0.12.4 diff --git a/invoke/__init__.py b/invoke/__init__.py index e7fa1208d..b70726759 100644 --- a/invoke/__init__.py +++ b/invoke/__init__.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from ._version import __version_info__, __version__ # noqa from .collection import Collection # noqa @@ -31,7 +31,7 @@ from .watchers import FailingResponder, Responder, StreamWatcher # noqa -def run(command: str, **kwargs: Any) -> Any: +def run(command: str, **kwargs: Any) -> Optional[Result]: """ Run ``command`` in a subprocess and return a `.Result` object. @@ -50,7 +50,7 @@ def run(command: str, **kwargs: Any) -> Any: return Context().run(command, **kwargs) -def sudo(command: str, **kwargs: Any) -> Any: +def sudo(command: str, **kwargs: Any) -> Optional[Result]: """ Run ``command`` in a ``sudo`` subprocess and return a `.Result` object.
diff --git a/invoke/collection.py b/invoke/collection.py index 6441e38a3..23dcff928 100644 --- a/invoke/collection.py +++ b/invoke/collection.py @@ -113,7 +113,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for name, obj in kwargs.items(): self._add_object(obj, name) - def _add_object(self, obj: Any, name: Optional[str] = None) -> Callable: + def _add_object(self, obj: Any, name: Optional[str] = None) -> None: method: Callable if isinstance(obj, Task): method = self.add_task @@ -121,7 +121,7 @@ def _add_object(self, obj: Any, name: Optional[str] = None) -> Callable: method = self.add_collection else: raise TypeError("No idea how to insert {!r}!".format(type(obj))) - return method(obj, name=name) + method(obj, name=name) def __repr__(self) -> str: task_names = list(self.tasks.keys()) @@ -510,7 +510,7 @@ def _transform_lexicon(self, old: Lexicon) -> Lexicon: return new @property - def task_names(self) -> Dict[str, Any]: + def task_names(self) -> Dict[str, List[str]]: """ Return all task identifiers for this collection as a one-level dict. diff --git a/invoke/completion/complete.py b/invoke/completion/complete.py index 57f61f64b..97e9a959e 100644 --- a/invoke/completion/complete.py +++ b/invoke/completion/complete.py @@ -7,7 +7,7 @@ import os import re import shlex -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from ..exceptions import Exit, ParseError from ..util import debug, task_name_sort_key @@ -38,8 +38,7 @@ def complete( # Use last seen context in case of failure (required for # otherwise-invalid partial invocations being completed). 
- # FIXME: this seems wonky - contexts: Union[List[ParserContext], ParseResult] + contexts: List[ParserContext] try: debug("Seeking context name in tokens: {!r}".format(tokens)) contexts = parser.parse_argv(tokens) diff --git a/invoke/config.py b/invoke/config.py index 08c65974b..c38afc67c 100644 --- a/invoke/config.py +++ b/invoke/config.py @@ -16,7 +16,7 @@ try: from importlib.machinery import SourceFileLoader except ImportError: # PyPy3 - from importlib._bootstrap import ( # type: ignore + from importlib._bootstrap import ( # type: ignore[no-redef] _SourceFileLoader as SourceFileLoader, ) @@ -1030,7 +1030,7 @@ def clone(self, into: Optional[Type["Config"]] = None) -> "Config": # instantiation" and "I want cloning to not trigger certain things like # external data source loading". # NOTE: this will include lazy=True, see end of method - new = klass(**self._clone_init_kwargs(into=into)) # type: ignore + new = klass(**self._clone_init_kwargs(into=into)) # Copy/merge/etc all 'private' data sources and attributes for name in """ collection @@ -1074,7 +1074,7 @@ def clone(self, into: Optional[Type["Config"]] = None) -> "Config": return new def _clone_init_kwargs( - self, into: Optional["Config"] = None + self, into: Optional[Type["Config"]] = None ) -> Dict[str, Any]: """ Supply kwargs suitable for initializing a new clone of this object. 
@@ -1227,15 +1227,15 @@ def merge_dicts( return base -def _merge_error(orig: str, new_: Any) -> AmbiguousMergeError: +def _merge_error(orig: object, new: object) -> AmbiguousMergeError: return AmbiguousMergeError( "Can't cleanly merge {} with {}".format( - _format_mismatch(orig), _format_mismatch(new_) + _format_mismatch(orig), _format_mismatch(new) ) ) -def _format_mismatch(x: Any) -> str: +def _format_mismatch(x: object) -> str: return "{} ({!r})".format(type(x), x) diff --git a/invoke/context.py b/invoke/context.py index a1bafeda8..c3c4e4aad 100644 --- a/invoke/context.py +++ b/invoke/context.py @@ -73,13 +73,13 @@ def __init__(self, config: Optional[Config] = None) -> None: self._set(command_cwds=command_cwds) @property - def config(self) -> Any: + def config(self) -> Config: # Allows Context to expose a .config attribute even though DataProxy # otherwise considers it a config key. return self._config @config.setter - def config(self, value: Any) -> None: + def config(self, value: Config) -> None: # NOTE: mostly used by client libraries needing to tweak a Context's # config at execution time; i.e. a Context subclass that bears its own # unique data may want to be stood up when parameterizing/expanding a diff --git a/invoke/env.py b/invoke/env.py index 2e90f4e95..2c7aaa692 100644 --- a/invoke/env.py +++ b/invoke/env.py @@ -9,7 +9,7 @@ """ import os -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Sequence from .exceptions import UncastableEnvVar, AmbiguousEnvVar from .util import debug @@ -46,7 +46,7 @@ def load(self) -> Dict[str, Any]: return self.data def _crawl( - self, key_path: List[str], env_vars: Dict[str, Any] + self, key_path: List[str], env_vars: Mapping[str, Sequence[str]] ) -> Dict[str, Any]: """ Examine config at location ``key_path`` & return potential env vars. @@ -61,7 +61,7 @@ def _crawl( Returns another dictionary of new keypairs as per above. 
""" - new_vars: Dict[str, Any] = {} + new_vars: Dict[str, List[str]] = {} obj = self._path_get(key_path) # Sub-dict -> recurse if ( @@ -85,10 +85,10 @@ def _crawl( new_vars[self._to_env_var(key_path)] = key_path return new_vars - def _to_env_var(self, key_path: List[str]) -> str: + def _to_env_var(self, key_path: Iterable[str]) -> str: return "_".join(key_path).upper() - def _path_get(self, key_path: List[str]) -> "Config": + def _path_get(self, key_path: Iterable[str]) -> "Config": # Gets are from self._config because that's what determines valid env # vars and/or values for typecasting. obj = self._config @@ -96,7 +96,7 @@ def _path_get(self, key_path: List[str]) -> "Config": obj = obj[key] return obj - def _path_set(self, key_path: List[str], value: str) -> None: + def _path_set(self, key_path: Sequence[str], value: str) -> None: # Sets are to self.data since that's what we are presenting to the # outer config object and debugging. obj = self.data @@ -105,19 +105,19 @@ def _path_set(self, key_path: List[str], value: str) -> None: obj[key] = {} obj = obj[key] old = self._path_get(key_path) - new_ = self._cast(old, value) - obj[key_path[-1]] = new_ + new = self._cast(old, value) + obj[key_path[-1]] = new - def _cast(self, old: Any, new_: Any) -> Any: + def _cast(self, old: Any, new: Any) -> Any: if isinstance(old, bool): - return new_ not in ("0", "") + return new not in ("0", "") elif isinstance(old, str): - return new_ + return new elif old is None: - return new_ + return new elif isinstance(old, (list, tuple)): err = "Can't adapt an environment string into a {}!" err = err.format(type(old)) raise UncastableEnvVar(err) else: - return old.__class__(new_) + return old.__class__(new) diff --git a/invoke/executor.py b/invoke/executor.py index 8e80d6e82..08aa74e31 100644 --- a/invoke/executor.py +++ b/invoke/executor.py @@ -136,7 +136,7 @@ def execute( # an appropriate one; e.g. 
subclasses might use extra data from # being parameterized), handing in this config for use there. context = call.make_context(config) - args = (context,) + call.args + args = (context, *call.args) result = call.task(*args, **call.kwargs) if autoprint: print(result) @@ -160,19 +160,19 @@ def normalize( """ calls = [] for task in tasks: + name: Optional[str] if isinstance(task, str): name = task kwargs = {} elif isinstance(task, ParserContext): - # FIXME: task.name can be none here - name = task.name # type: ignore + name = task.name kwargs = task.as_kwargs else: name, kwargs = task - c = Call(task=self.collection[name], kwargs=kwargs, called_as=name) + c = Call(self.collection[name], kwargs=kwargs, called_as=name) calls.append(c) if not tasks and self.collection.default is not None: - calls = [Call(task=self.collection[self.collection.default])] + calls = [Call(self.collection[self.collection.default])] return calls def dedupe(self, calls: List["Call"]) -> List["Call"]: @@ -213,7 +213,7 @@ def expand_calls(self, calls: List["Call"]) -> List["Call"]: # Normalize to Call (this method is sometimes called with pre/post # task lists, which may contain 'raw' Task objects) if isinstance(call, Task): - call = Call(task=call) + call = Call(call) debug("Expanding task-call {!r}".format(call)) # TODO: this is where we _used_ to call Executor.config_for(call, # config)... 
diff --git a/invoke/parser/context.py b/invoke/parser/context.py index e8a465faf..359e9f9e2 100644 --- a/invoke/parser/context.py +++ b/invoke/parser/context.py @@ -4,7 +4,7 @@ try: from ..vendor.lexicon import Lexicon except ImportError: - from lexicon import Lexicon # type: ignore + from lexicon import Lexicon # type: ignore[no-redef] from .argument import Argument diff --git a/invoke/parser/parser.py b/invoke/parser/parser.py index f926b2168..43e95df04 100644 --- a/invoke/parser/parser.py +++ b/invoke/parser/parser.py @@ -5,8 +5,12 @@ from ..vendor.lexicon import Lexicon from ..vendor.fluidity import StateMachine, state, transition except ImportError: - from lexicon import Lexicon # type: ignore - from fluidity import StateMachine, state, transition # type: ignore + from lexicon import Lexicon # type: ignore[no-redef] + from fluidity import ( # type: ignore[no-redef] + StateMachine, + state, + transition, + ) from ..exceptions import ParseError from ..util import debug @@ -23,7 +27,7 @@ def is_long_flag(value: str) -> bool: return value.startswith("--") -class ParseResult(list): +class ParseResult(List["ParserContext"]): """ List-like object with some extra parse-related attributes. @@ -108,7 +112,8 @@ def parse_argv(self, argv: List[str]) -> ParseResult: .. versionadded:: 1.0 """ machine = ParseMachine( - initial=self.initial, # type: ignore # FIXME: should not be none + # FIXME: initial should not be none + initial=self.initial, # type: ignore[arg-type] contexts=self.contexts, ignore_unknown=self.ignore_unknown, ) diff --git a/invoke/program.py b/invoke/program.py index 76cb22806..c7e5cd004 100644 --- a/invoke/program.py +++ b/invoke/program.py @@ -5,7 +5,16 @@ import sys import textwrap from importlib import import_module # buffalo buffalo -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, +) from . 
import Collection, Config, Executor, FilesystemLoader from .completion.complete import complete, print_completion_script @@ -15,7 +24,6 @@ from .util import debug, enable_logging, helpline if TYPE_CHECKING: - from .context import Context from .loader import Loader from .parser import ParseResult from .util import Lexicon @@ -185,9 +193,9 @@ def __init__( namespace: Optional["Collection"] = None, name: Optional[str] = None, binary: Optional[str] = None, - loader_class: Optional["Loader"] = None, - executor_class: Optional["Executor"] = None, - config_class: Optional["Config"] = None, + loader_class: Optional[Type["Loader"]] = None, + executor_class: Optional[Type["Executor"]] = None, + config_class: Optional[Type["Config"]] = None, binary_names: Optional[List[str]] = None, ) -> None: """ @@ -289,7 +297,7 @@ def create_config(self) -> None: .. versionadded:: 1.0 """ - self.config = self.config_class() # type: ignore + self.config = self.config_class() def update_config(self, merge: bool = True) -> None: """ @@ -571,9 +579,7 @@ def execute(self) -> None: # "normal" but also its own possible source of bugs/confusion... module = import_module(module_path) klass = getattr(module, class_name) - executor = klass( # type: ignore - self.collection, self.config, self.core - ) + executor = klass(self.collection, self.config, self.core) executor.execute(*self.tasks) def normalize_argv(self, argv: Optional[List[str]]) -> None: @@ -723,7 +729,7 @@ def load_collection(self) -> None: raise Exit("Can't find any collection named {!r}!".format(e.name)) def _update_core_context( - self, context: "Context", new_args: Dict[str, Any] + self, context: ParserContext, new_args: Dict[str, Any] ) -> None: # Update core context w/ core_via_task args, if and only if the # via-task version of the arg was truly given a value. 
@@ -919,7 +925,7 @@ def task_list_opener(self, extra: str = "") -> str: return text def display_with_columns( - self, pairs: List[Tuple[str, Optional[str]]], extra: str = "" + self, pairs: Sequence[Tuple[str, Optional[str]]], extra: str = "" ) -> None: root = self.list_root print("{}:\n".format(self.task_list_opener(extra=extra))) @@ -935,7 +941,9 @@ def display_with_columns( # TODO: trim/prefix dots print("Default{} task: {}\n".format(specific, default)) - def print_columns(self, tuples: List[Tuple[str, Optional[str]]]) -> None: + def print_columns( + self, tuples: Sequence[Tuple[str, Optional[str]]] + ) -> None: """ Print tabbed columns from (name, help) ``tuples``. diff --git a/invoke/runners.py b/invoke/runners.py index 21f8ca691..f24bb5f34 100644 --- a/invoke/runners.py +++ b/invoke/runners.py @@ -27,15 +27,15 @@ try: import pty except ImportError: - pty = None # type: ignore + pty = None # type: ignore[assignment] try: import fcntl except ImportError: - fcntl = None # type: ignore + fcntl = None # type: ignore[assignment] try: import termios except ImportError: - termios = None # type: ignore + termios = None # type: ignore[assignment] from .exceptions import ( UnexpectedExit, @@ -55,7 +55,6 @@ from .util import has_fileno, isatty, ExceptionHandlingThread if TYPE_CHECKING: - # from io import BytesIO, StringIO, TextIOWrapper from .context import Context from .watchers import StreamWatcher @@ -460,7 +459,7 @@ def make_promise(self) -> "Promise": """ return Promise(self) - def _finish(self) -> Any: + def _finish(self) -> "Result": # Wait for subprocess to run, forwarding signals as we get them. try: while True: @@ -626,7 +625,7 @@ def _thread_join_timeout(self, target: Callable) -> Optional[int]: def create_io_threads( self, - ) -> Tuple[Dict[Any, ExceptionHandlingThread], List[Any], List[Any]]: + ) -> Tuple[Dict[Callable, ExceptionHandlingThread], List[str], List[str]]: """ Create and return a dictionary of IO thread worker objects. 
@@ -1201,7 +1200,7 @@ def timed_out(self) -> bool: """ # Timer expiry implies we did time out. (The timer itself will have # killed the subprocess, allowing us to even get to this point.) - return True if self._timer and not self._timer.is_alive() else False + return bool(self._timer and not self._timer.is_alive()) class Local(Runner): @@ -1485,7 +1484,7 @@ def __init__( self.hide = hide @property - def return_code(self) -> Any: + def return_code(self) -> int: """ An alias for ``.exited``. @@ -1591,7 +1590,7 @@ def __init__(self, runner: "Runner") -> None: for key, value in self.runner.result_kwargs.items(): setattr(self, key, value) - def join(self) -> Any: + def join(self) -> Result: """ Block until associated subprocess exits, returning/raising the result. diff --git a/invoke/tasks.py b/invoke/tasks.py index 806ac8582..22ce59521 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -3,20 +3,23 @@ generate new tasks. """ -from copy import deepcopy import inspect import types +from copy import deepcopy +from functools import update_wrapper from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, + Generic, Iterable, Optional, Set, Tuple, Type, + TypeVar, Union, ) @@ -27,8 +30,10 @@ from inspect import Signature from .config import Config +T = TypeVar("T", bound=Callable) + -class Task: +class Task(Generic[T]): """ Core object representing an executable task & its argument specification. @@ -54,6 +59,7 @@ class Task: def __init__( self, body: Callable, + /, name: Optional[str] = None, aliases: Iterable[str] = (), positional: Optional[Iterable[str]] = None, @@ -69,6 +75,7 @@ def __init__( ) -> None: # Real callable self.body = body + update_wrapper(self, self.body) # Copy a bunch of special properties from the body for the benefit of # Sphinx autodoc or other introspectors. 
self.__doc__ = getattr(body, "__doc__", "") @@ -81,7 +88,7 @@ def __init__( self.is_default = default # Arg/flag/parser hints self.positional = self.fill_implicit_positionals(positional) - self.optional = optional + self.optional = tuple(optional) self.iterable = iterable or [] self.incrementable = incrementable or [] self.auto_shortflags = auto_shortflags @@ -123,7 +130,7 @@ def __hash__(self) -> int: # this for now. return hash(self.name) + hash(self.body) - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> T: # Guard against calling tasks with no context. if not isinstance(args[0], Context): err = "Task expected a Context as its first arg, got {} instead!" @@ -334,7 +341,7 @@ def task(*args: Any, **kwargs: Any) -> Callable: .. versionchanged:: 1.1 Added the ``klass`` keyword argument. """ - klass = kwargs.pop("klass", Task) + klass: Type[Task] = kwargs.pop("klass", Task) # @task -- no options were (probably) given. if len(args) == 1 and callable(args[0]) and not isinstance(args[0], Task): return klass(args[0], **kwargs) @@ -345,43 +352,12 @@ def task(*args: Any, **kwargs: Any) -> Callable: "May not give *args and 'pre' kwarg simultaneously!" ) kwargs["pre"] = args - # @task(options) - # TODO: why the heck did we originally do this in this manner instead of - # simply delegating to Task?! Let's just remove all this sometime & see - # what, if anything, breaks. 
- name = kwargs.pop("name", None) - aliases = kwargs.pop("aliases", ()) - positional = kwargs.pop("positional", None) - optional = tuple(kwargs.pop("optional", ())) - iterable = kwargs.pop("iterable", None) - incrementable = kwargs.pop("incrementable", None) - default = kwargs.pop("default", False) - auto_shortflags = kwargs.pop("auto_shortflags", True) - help = kwargs.pop("help", {}) - pre = kwargs.pop("pre", []) - post = kwargs.pop("post", []) - autoprint = kwargs.pop("autoprint", False) - - def inner(obj: Callable) -> Task: - _obj = klass( - obj, - name=name, - aliases=aliases, - positional=positional, - optional=optional, - iterable=iterable, - incrementable=incrementable, - default=default, - auto_shortflags=auto_shortflags, - help=help, - pre=pre, - post=post, - autoprint=autoprint, - # Pass in any remaining kwargs as-is. - **kwargs - ) - return _obj + def inner(body: Callable) -> Task[T]: + _task = klass(body, **kwargs) + return _task + + # update_wrapper(inner, klass) return inner @@ -508,7 +484,7 @@ def clone( return klass(**data) -def call(task: Task, *args: Any, **kwargs: Any) -> "Call": +def call(task: "Task", /, *args: Any, **kwargs: Any) -> "Call": """ Describes execution of a `.Task`, typically with pre-supplied arguments. @@ -541,4 +517,4 @@ def clean_build(c): .. versionadded:: 1.0 """ - return Call(task=task, args=args, kwargs=kwargs) + return Call(task, args=args, kwargs=kwargs) diff --git a/invoke/terminals.py b/invoke/terminals.py index 490750c08..2694712fa 100644 --- a/invoke/terminals.py +++ b/invoke/terminals.py @@ -28,9 +28,9 @@ .. 
versionadded:: 1.0 """ -if WINDOWS: +if sys.platform == "win32": import msvcrt - from ctypes import ( # type: ignore + from ctypes import ( Structure, c_ushort, windll, @@ -45,6 +45,73 @@ import tty +if sys.platform == "win32": + + def _pty_size() -> Tuple[Optional[int], Optional[int]]: + class CONSOLE_SCREEN_BUFFER_INFO(Structure): + _fields_ = [ + ("dwSize", _COORD), + ("dwCursorPosition", _COORD), + ("wAttributes", c_ushort), + ("srWindow", _SMALL_RECT), + ("dwMaximumWindowSize", _COORD), + ] + + GetStdHandle = windll.kernel32.GetStdHandle + GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo + GetStdHandle.restype = HANDLE + GetConsoleScreenBufferInfo.argtypes = [ + HANDLE, + POINTER(CONSOLE_SCREEN_BUFFER_INFO), + ] + + hstd = GetStdHandle(-11) # STD_OUTPUT_HANDLE = -11 + csbi = CONSOLE_SCREEN_BUFFER_INFO() + ret = GetConsoleScreenBufferInfo(hstd, byref(csbi)) + + if ret: + sizex = csbi.srWindow.Right - csbi.srWindow.Left + 1 + sizey = csbi.srWindow.Bottom - csbi.srWindow.Top + 1 + return sizex, sizey + else: + return (None, None) + +else: + + def _pty_size() -> Tuple[Optional[int], Optional[int]]: + """ + Suitable for most POSIX platforms. + + .. versionadded:: 1.0 + """ + # Sentinel values to be replaced w/ defaults by caller + size = (None, None) + # We want two short unsigned integers (rows, cols) + fmt = "HH" + # Create an empty (zeroed) buffer for ioctl to map onto. Yay for C! + buf = struct.pack(fmt, 0, 0) + # Call TIOCGWINSZ to get window size of stdout, returns our filled + # buffer + try: + result = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf) + # Unpack buffer back into Python data types + # NOTE: this unpack gives us rows x cols, but we return the + # inverse. 
+ rows, cols = struct.unpack(fmt, result) + return (cols, rows) + # Fallback to emptyish return value in various failure cases: + # * sys.stdout being monkeypatched, such as in testing, and lacking + # * .fileno + # * sys.stdout having a .fileno but not actually being attached to a + # * TTY + # * termios not having a TIOCGWINSZ attribute (happens sometimes...) + # * other situations where ioctl doesn't explode but the result isn't + # something unpack can deal with + except (struct.error, TypeError, IOError, AttributeError): + pass + return size + + def pty_size() -> Tuple[int, int]: """ Determine current local pseudoterminal dimensions. @@ -55,71 +122,9 @@ def pty_size() -> Tuple[int, int]: .. versionadded:: 1.0 """ - cols, rows = _pty_size() if not WINDOWS else _win_pty_size() + cols, rows = _pty_size() # TODO: make defaults configurable? - return (int(cols or 80), int(rows or 24)) - - -def _pty_size() -> Tuple[Optional[int], Optional[int]]: - """ - Suitable for most POSIX platforms. - - .. versionadded:: 1.0 - """ - # Sentinel values to be replaced w/ defaults by caller - size = (None, None) - # We want two short unsigned integers (rows, cols) - fmt = "HH" - # Create an empty (zeroed) buffer for ioctl to map onto. Yay for C! - buf = struct.pack(fmt, 0, 0) - # Call TIOCGWINSZ to get window size of stdout, returns our filled - # buffer - try: - result = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf) - # Unpack buffer back into Python data types - # NOTE: this unpack gives us rows x cols, but we return the - # inverse. - rows, cols = struct.unpack(fmt, result) - return (cols, rows) - # Fallback to emptyish return value in various failure cases: - # * sys.stdout being monkeypatched, such as in testing, and lacking .fileno - # * sys.stdout having a .fileno but not actually being attached to a TTY - # * termios not having a TIOCGWINSZ attribute (happens sometimes...) 
- # * other situations where ioctl doesn't explode but the result isn't - # something unpack can deal with - except (struct.error, TypeError, IOError, AttributeError): - pass - return size - - -def _win_pty_size() -> Tuple[Optional[str], Optional[str]]: - class CONSOLE_SCREEN_BUFFER_INFO(Structure): - _fields_ = [ - ("dwSize", _COORD), - ("dwCursorPosition", _COORD), - ("wAttributes", c_ushort), - ("srWindow", _SMALL_RECT), - ("dwMaximumWindowSize", _COORD), - ] - - GetStdHandle = windll.kernel32.GetStdHandle - GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo - GetStdHandle.restype = HANDLE - GetConsoleScreenBufferInfo.argtypes = [ - HANDLE, - POINTER(CONSOLE_SCREEN_BUFFER_INFO), - ] - - hstd = GetStdHandle(-11) # STD_OUTPUT_HANDLE = -11 - csbi = CONSOLE_SCREEN_BUFFER_INFO() - ret = GetConsoleScreenBufferInfo(hstd, byref(csbi)) - - if ret: - sizex = csbi.srWindow.Right - csbi.srWindow.Left + 1 - sizey = csbi.srWindow.Bottom - csbi.srWindow.Top + 1 - return sizex, sizey - else: - return (None, None) + return (cols or 80, rows or 24) def stdin_is_foregrounded_tty(stream: IO) -> bool: @@ -211,8 +216,8 @@ def ready_for_reading(input_: IO) -> bool: # nonblocking fashion (e.g. a StringIO or regular file). 
if not has_fileno(input_): return True - if WINDOWS: - return msvcrt.kbhit() # type: ignore + if sys.platform == "win32": + return msvcrt.kbhit() else: reads, _, _ = select.select([input_], [], [], 0.0) return bool(reads and reads[0] is input_) diff --git a/invoke/util.py b/invoke/util.py index 31f9a11dc..df29c841a 100644 --- a/invoke/util.py +++ b/invoke/util.py @@ -21,8 +21,8 @@ from .vendor.lexicon import Lexicon # noqa from .vendor import yaml # noqa except ImportError: - from lexicon import Lexicon # type: ignore # noqa - import yaml # type: ignore # noqa + from lexicon import Lexicon # type: ignore[no-redef] # noqa + import yaml # type: ignore[no-redef] # noqa LOG_FORMAT = "%(name)s.%(module)s.%(funcName)s: %(message)s" diff --git a/invoke/vendor/decorator.py b/invoke/vendor/decorator.py deleted file mode 100644 index 43be27959..000000000 --- a/invoke/vendor/decorator.py +++ /dev/null @@ -1,414 +0,0 @@ -# ######################### LICENSE ############################ # - -# Copyright (c) 2005-2017, Michele Simionato -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in bytecode form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in -# the documentation and/or other materials provided with the -# distribution. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS -# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. - -""" -Decorator module, see http://pypi.python.org/pypi/decorator -for the documentation. -""" -from __future__ import print_function - -import re -import sys -import inspect -import operator -import itertools -import collections - -__version__ = '4.0.11' - -if sys.version >= '3': - from inspect import getfullargspec - - def get_init(cls): - return cls.__init__ -else: - FullArgSpec = collections.namedtuple( - 'FullArgSpec', 'args varargs varkw defaults ' - 'kwonlyargs kwonlydefaults') - - def getfullargspec(f): - "A quick and dirty replacement for getfullargspec for Python 2.X" - return FullArgSpec._make(inspect.getargspec(f) + ([], None)) - - def get_init(cls): - return cls.__init__.__func__ - -# getargspec has been deprecated in Python 3.5 -ArgSpec = collections.namedtuple( - 'ArgSpec', 'args varargs varkw defaults') - - -def getargspec(f): - """A replacement for inspect.getargspec""" - spec = getfullargspec(f) - return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults) - - -DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(') - - -# basic functionality -class FunctionMaker(object): - """ - An object with the ability to create functions with a given signature. - It has attributes name, doc, module, signature, defaults, dict and - methods update and make. 
- """ - - # Atomic get-and-increment provided by the GIL - _compile_count = itertools.count() - - # make pylint happy - args = varargs = varkw = defaults = kwonlyargs = kwonlydefaults = () - - def __init__(self, func=None, name=None, signature=None, - defaults=None, doc=None, module=None, funcdict=None): - self.shortsignature = signature - if func: - # func can be a class or a callable, but not an instance method - self.name = func.__name__ - if self.name == '': # small hack for lambda functions - self.name = '_lambda_' - self.doc = func.__doc__ - self.module = func.__module__ - if inspect.isfunction(func): - argspec = getfullargspec(func) - self.annotations = getattr(func, '__annotations__', {}) - for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', - 'kwonlydefaults'): - setattr(self, a, getattr(argspec, a)) - for i, arg in enumerate(self.args): - setattr(self, 'arg%d' % i, arg) - if sys.version < '3': # easy way - self.shortsignature = self.signature = ( - inspect.formatargspec( - formatvalue=lambda val: "", *argspec[:-2])[1:-1]) - else: # Python 3 way - allargs = list(self.args) - allshortargs = list(self.args) - if self.varargs: - allargs.append('*' + self.varargs) - allshortargs.append('*' + self.varargs) - elif self.kwonlyargs: - allargs.append('*') # single star syntax - for a in self.kwonlyargs: - allargs.append('%s=None' % a) - allshortargs.append('%s=%s' % (a, a)) - if self.varkw: - allargs.append('**' + self.varkw) - allshortargs.append('**' + self.varkw) - self.signature = ', '.join(allargs) - self.shortsignature = ', '.join(allshortargs) - self.dict = func.__dict__.copy() - # func=None happens when decorating a caller - if name: - self.name = name - if signature is not None: - self.signature = signature - if defaults: - self.defaults = defaults - if doc: - self.doc = doc - if module: - self.module = module - if funcdict: - self.dict = funcdict - # check existence required attributes - assert hasattr(self, 'name') - if not hasattr(self, 
'signature'): - raise TypeError('You are decorating a non function: %s' % func) - - def update(self, func, **kw): - "Update the signature of func with the data in self" - func.__name__ = self.name - func.__doc__ = getattr(self, 'doc', None) - func.__dict__ = getattr(self, 'dict', {}) - func.__defaults__ = self.defaults - func.__kwdefaults__ = self.kwonlydefaults or None - func.__annotations__ = getattr(self, 'annotations', None) - try: - frame = sys._getframe(3) - except AttributeError: # for IronPython and similar implementations - callermodule = '?' - else: - callermodule = frame.f_globals.get('__name__', '?') - func.__module__ = getattr(self, 'module', callermodule) - func.__dict__.update(kw) - - def make(self, src_templ, evaldict=None, addsource=False, **attrs): - "Make a new function from a given template and update the signature" - src = src_templ % vars(self) # expand name and signature - evaldict = evaldict or {} - mo = DEF.match(src) - if mo is None: - raise SyntaxError('not a valid function template\n%s' % src) - name = mo.group(1) # extract the function name - names = set([name] + [arg.strip(' *') for arg in - self.shortsignature.split(',')]) - for n in names: - if n in ('_func_', '_call_'): - raise NameError('%s is overridden in\n%s' % (n, src)) - - if not src.endswith('\n'): # add a newline for old Pythons - src += '\n' - - # Ensure each generated function has a unique filename for profilers - # (such as cProfile) that depend on the tuple of (, - # , ) being unique. 
- filename = '' % (next(self._compile_count),) - try: - code = compile(src, filename, 'single') - exec(code, evaldict) - except: - print('Error in generated code:', file=sys.stderr) - print(src, file=sys.stderr) - raise - func = evaldict[name] - if addsource: - attrs['__source__'] = src - self.update(func, **attrs) - return func - - @classmethod - def create(cls, obj, body, evaldict, defaults=None, - doc=None, module=None, addsource=True, **attrs): - """ - Create a function from the strings name, signature and body. - evaldict is the evaluation dictionary. If addsource is true an - attribute __source__ is added to the result. The attributes attrs - are added, if any. - """ - if isinstance(obj, str): # "name(signature)" - name, rest = obj.strip().split('(', 1) - signature = rest[:-1] # strip a right parens - func = None - else: # a function - name = None - signature = None - func = obj - self = cls(func, name, signature, defaults, doc, module) - ibody = '\n'.join(' ' + line for line in body.splitlines()) - return self.make('def %(name)s(%(signature)s):\n' + ibody, - evaldict, addsource, **attrs) - - -def decorate(func, caller): - """ - decorate(func, caller) decorates a function using a caller. 
- """ - evaldict = dict(_call_=caller, _func_=func) - fun = FunctionMaker.create( - func, "return _call_(_func_, %(shortsignature)s)", - evaldict, __wrapped__=func) - if hasattr(func, '__qualname__'): - fun.__qualname__ = func.__qualname__ - return fun - - -def decorator(caller, _func=None): - """decorator(caller) converts a caller function into a decorator""" - if _func is not None: # return a decorated function - # this is obsolete behavior; you should use decorate instead - return decorate(_func, caller) - # else return a decorator function - if inspect.isclass(caller): - name = caller.__name__.lower() - doc = 'decorator(%s) converts functions/generators into ' \ - 'factories of %s objects' % (caller.__name__, caller.__name__) - elif inspect.isfunction(caller): - if caller.__name__ == '': - name = '_lambda_' - else: - name = caller.__name__ - doc = caller.__doc__ - else: # assume caller is an object with a __call__ method - name = caller.__class__.__name__.lower() - doc = caller.__call__.__doc__ - evaldict = dict(_call_=caller, _decorate_=decorate) - return FunctionMaker.create( - '%s(func)' % name, 'return _decorate_(func, _call_)', - evaldict, doc=doc, module=caller.__module__, - __wrapped__=caller) - - -# ####################### contextmanager ####################### # - -try: # Python >= 3.2 - from contextlib import _GeneratorContextManager -except ImportError: # Python >= 2.5 - from contextlib import GeneratorContextManager as _GeneratorContextManager - - -class ContextManager(_GeneratorContextManager): - def __call__(self, func): - """Context manager decorator""" - return FunctionMaker.create( - func, "with _self_: return _func_(%(shortsignature)s)", - dict(_self_=self, _func_=func), __wrapped__=func) - - -init = getfullargspec(_GeneratorContextManager.__init__) -n_args = len(init.args) -if n_args == 2 and not init.varargs: # (self, genobj) Python 2.7 - def __init__(self, g, *a, **k): - return _GeneratorContextManager.__init__(self, g(*a, **k)) - 
ContextManager.__init__ = __init__ -elif n_args == 2 and init.varargs: # (self, gen, *a, **k) Python 3.4 - pass -elif n_args == 4: # (self, gen, args, kwds) Python 3.5 - def __init__(self, g, *a, **k): - return _GeneratorContextManager.__init__(self, g, a, k) - ContextManager.__init__ = __init__ - -contextmanager = decorator(ContextManager) - - -# ############################ dispatch_on ############################ # - -def append(a, vancestors): - """ - Append ``a`` to the list of the virtual ancestors, unless it is already - included. - """ - add = True - for j, va in enumerate(vancestors): - if issubclass(va, a): - add = False - break - if issubclass(a, va): - vancestors[j] = a - add = False - if add: - vancestors.append(a) - - -# inspired from simplegeneric by P.J. Eby and functools.singledispatch -def dispatch_on(*dispatch_args): - """ - Factory of decorators turning a function into a generic function - dispatching on the given arguments. - """ - assert dispatch_args, 'No dispatch args passed' - dispatch_str = '(%s,)' % ', '.join(dispatch_args) - - def check(arguments, wrong=operator.ne, msg=''): - """Make sure one passes the expected number of arguments""" - if wrong(len(arguments), len(dispatch_args)): - raise TypeError('Expected %d arguments, got %d%s' % - (len(dispatch_args), len(arguments), msg)) - - def gen_func_dec(func): - """Decorator turning a function into a generic function""" - - # first check the dispatch arguments - argset = set(getfullargspec(func).args) - if not set(dispatch_args) <= argset: - raise NameError('Unknown dispatch arguments %s' % dispatch_str) - - typemap = {} - - def vancestors(*types): - """ - Get a list of sets of virtual ancestors for the given types - """ - check(types) - ras = [[] for _ in range(len(dispatch_args))] - for types_ in typemap: - for t, type_, ra in zip(types, types_, ras): - if issubclass(t, type_) and type_ not in t.mro(): - append(type_, ra) - return [set(ra) for ra in ras] - - def ancestors(*types): - """ - 
Get a list of virtual MROs, one for each type - """ - check(types) - lists = [] - for t, vas in zip(types, vancestors(*types)): - n_vas = len(vas) - if n_vas > 1: - raise RuntimeError( - 'Ambiguous dispatch for %s: %s' % (t, vas)) - elif n_vas == 1: - va, = vas - mro = type('t', (t, va), {}).mro()[1:] - else: - mro = t.mro() - lists.append(mro[:-1]) # discard t and object - return lists - - def register(*types): - """ - Decorator to register an implementation for the given types - """ - check(types) - - def dec(f): - check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__) - typemap[types] = f - return f - return dec - - def dispatch_info(*types): - """ - An utility to introspect the dispatch algorithm - """ - check(types) - lst = [] - for anc in itertools.product(*ancestors(*types)): - lst.append(tuple(a.__name__ for a in anc)) - return lst - - def _dispatch(dispatch_args, *args, **kw): - types = tuple(type(arg) for arg in dispatch_args) - try: # fast path - f = typemap[types] - except KeyError: - pass - else: - return f(*args, **kw) - combinations = itertools.product(*ancestors(*types)) - next(combinations) # the first one has been already tried - for types_ in combinations: - f = typemap.get(types_) - if f is not None: - return f(*args, **kw) - - # else call the default implementation - return func(*args, **kw) - - return FunctionMaker.create( - func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str, - dict(_f_=_dispatch), register=register, default=func, - typemap=typemap, vancestors=vancestors, ancestors=ancestors, - dispatch_info=dispatch_info, __wrapped__=func) - - gen_func_dec.__name__ = 'dispatch_on' + dispatch_str - return gen_func_dec diff --git a/invoke/watchers.py b/invoke/watchers.py index 2ce98fe03..eb813df28 100644 --- a/invoke/watchers.py +++ b/invoke/watchers.py @@ -96,12 +96,12 @@ def pattern_matches( # once, e.g. in FailingResponder. # Only look at stream contents we haven't seen yet, to avoid dupes. 
index = getattr(self, index_attr) - new_ = stream[index:] + new = stream[index:] # Search, across lines if necessary - matches = re.findall(pattern, new_, re.S) + matches = re.findall(pattern, new, re.S) # Update seek index if we've matched if matches: - setattr(self, index_attr, index + len(new_)) + setattr(self, index_attr, index + len(new)) return matches def submit(self, stream: str) -> Generator[str, None, None]: diff --git a/pyproject.toml b/pyproject.toml index e2f64f8a7..5965db502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ disallow_untyped_defs = true # "unused-awaitable", # exclude = [ - "integration/", "tests/", "setup.py", "tasks.py", "sites/www/conf.py" + "integration/", "tests/", "setup.py", "sites/www/conf.py" ] ignore_missing_imports = true # implicit_reexport = False @@ -34,3 +34,19 @@ warn_unused_ignores = true [[tool.mypy.overrides]] module = "invoke.vendor.*" ignore_errors = true + +[[tool.mypy.overrides]] +module = "alabaster" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "icecream" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "invocations" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "pytest_relaxed" +ignore_missing_imports = true diff --git a/tasks.py b/tasks.py index 63930a1cc..228b29314 100644 --- a/tasks.py +++ b/tasks.py @@ -1,4 +1,5 @@ import os +from typing import TYPE_CHECKING, Optional from invoke import Collection, task, Exit @@ -7,19 +8,22 @@ from invocations.pytest import coverage as coverage_, test as test_ from invocations.packaging import vendorize, release +if TYPE_CHECKING: + from invoke import Context + @task def test( - c, - verbose=False, - color=True, - capture="no", - module=None, - k=None, - x=False, - opts="", - pty=True, -): + c: "Context", + verbose: bool = False, + color: bool = True, + capture: str = "no", + module: Optional[str] = None, + k: Optional[str] = None, + x: bool = False, + opts: str = "", + pty: bool = True, +) -> 
None: """ Run pytest. See `invocations.pytest.test` for details. @@ -34,7 +38,7 @@ def test( """ # TODO: update test suite to use c.config.run.in_stream = False globally. # somehow. - return test_( + test_( c, verbose=verbose, color=color, @@ -47,10 +51,15 @@ def test( ) +print('test', vars(test), type(test)) + + # TODO: replace with invocations' once the "call truly local tester" problem is # solved (see other TODOs). For now this is just a copy/paste/modify. -@task(help=test.help) -def integration(c, opts=None, pty=True): +@task(help=test.help) # type: ignore +def integration( + c: "Context", opts: Optional[str] = None, pty: bool = True +) -> None: """ Run the integration test suite. May be slow! """ @@ -67,7 +76,9 @@ def integration(c, opts=None, pty=True): @task -def coverage(c, report="term", opts="", codecov=False): +def coverage( + c: "Context", report: str = "term", opts: str = "", codecov: bool = False +) -> None: """ Run pytest in coverage mode. See `invocations.pytest.coverage` for details. """ @@ -86,7 +97,7 @@ def coverage(c, report="term", opts="", codecov=False): @task -def regression(c, jobs=8): +def regression(c: "Context", jobs: int = 8) -> None: """ Run an expensive, hard-to-test-in-pytest run() regression checker.