Skip to content

Commit

Permalink
Merge pull request #906 from kuwv/typing
Browse files Browse the repository at this point in the history
test: implement type hints
  • Loading branch information
bitprophet authored Feb 12, 2023
2 parents 0bcee75 + f24e290 commit 4a48966
Show file tree
Hide file tree
Showing 27 changed files with 985 additions and 1,050 deletions.
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@
exclude = invoke/vendor,sites,.git,build,dist,alt_env,appveyor
ignore = E124,E125,E128,E261,E301,E302,E303,E306,W503,E731
max-line-length = 79

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ src/
htmlcov
coverage.xml
.cache
.mypy_cache/
4 changes: 4 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ black>=22.8,<22.9
setuptools>56
# Debuggery
icecream>=2.1
# typing
mypy==0.971
typed-ast==1.5.4
types-PyYAML==6.0.12.4
6 changes: 4 additions & 2 deletions invoke/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Optional

from ._version import __version_info__, __version__ # noqa
from .collection import Collection # noqa
from .config import Config # noqa
Expand Down Expand Up @@ -29,7 +31,7 @@
from .watchers import FailingResponder, Responder, StreamWatcher # noqa


def run(command, **kwargs):
def run(command: str, **kwargs: Any) -> Optional[Result]:
"""
Run ``command`` in a subprocess and return a `.Result` object.
Expand All @@ -48,7 +50,7 @@ def run(command, **kwargs):
return Context().run(command, **kwargs)


def sudo(command, **kwargs):
def sudo(command: str, **kwargs: Any) -> Optional[Result]:
"""
Run ``command`` in a ``sudo`` subprocess and return a `.Result` object.
Expand Down
130 changes: 74 additions & 56 deletions invoke/collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import types
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple

from .util import Lexicon, helpline

Expand All @@ -15,7 +16,7 @@ class Collection:
.. versionadded:: 1.0
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Create a new task collection/namespace.
Expand Down Expand Up @@ -92,67 +93,64 @@ def __init__(self, *args, **kwargs):
# Initialize
self.tasks = Lexicon()
self.collections = Lexicon()
self.default = None
self.default: Optional[str] = None
self.name = None
self._configuration = {}
self._configuration: Dict[str, Any] = {}
# Specific kwargs if applicable
self.loaded_from = kwargs.pop("loaded_from", None)
self.auto_dash_names = kwargs.pop("auto_dash_names", None)
# splat-kwargs version of default value (auto_dash_names=True)
if self.auto_dash_names is None:
self.auto_dash_names = True
# Name if applicable
args = list(args)
if args and isinstance(args[0], str):
self.name = self.transform(args.pop(0))
_args = list(args)
if _args and isinstance(args[0], str):
self.name = self.transform(_args.pop(0))
# Dispatch args/kwargs
for arg in args:
for arg in _args:
self._add_object(arg)
# Dispatch kwargs
for name, obj in kwargs.items():
self._add_object(obj, name)

def _add_object(self, obj, name=None):
def _add_object(self, obj: Any, name: Optional[str] = None) -> None:
method: Callable
if isinstance(obj, Task):
method = self.add_task
elif isinstance(obj, (Collection, types.ModuleType)):
elif isinstance(obj, (Collection, ModuleType)):
method = self.add_collection
else:
raise TypeError("No idea how to insert {!r}!".format(type(obj)))
return method(obj, name=name)
method(obj, name=name)

def __repr__(self):
def __repr__(self) -> str:
task_names = list(self.tasks.keys())
collections = ["{}...".format(x) for x in self.collections.keys()]
return "<Collection {!r}: {}>".format(
self.name, ", ".join(sorted(task_names) + sorted(collections))
)

def __eq__(self, other):
return (
self.name == other.name
and self.tasks == other.tasks
and self.collections == other.collections
)

def __ne__(self, other):
return not self == other

def __nonzero__(self):
return self.__bool__()
def __eq__(self, other: object) -> bool:
if isinstance(other, Collection):
return (
self.name == other.name
and self.tasks == other.tasks
and self.collections == other.collections
)
return False

def __bool__(self):
def __bool__(self) -> bool:
return bool(self.task_names)

@classmethod
def from_module(
cls,
module,
name=None,
config=None,
loaded_from=None,
auto_dash_names=None,
):
module: ModuleType,
name: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
loaded_from: Optional[str] = None,
auto_dash_names: Optional[bool] = None,
) -> "Collection":
"""
Return a new `.Collection` created from ``module``.
Expand Down Expand Up @@ -198,7 +196,7 @@ def from_module(
"""
module_name = module.__name__.split(".")[-1]

def instantiate(obj_name=None):
def instantiate(obj_name: Optional[str] = None) -> "Collection":
# Explicitly given name wins over root ns name (if applicable),
# which wins over actual module name.
args = [name or obj_name or module_name]
Expand All @@ -218,7 +216,9 @@ def instantiate(obj_name=None):
ret = instantiate(obj_name=obj.name)
ret.tasks = ret._transform_lexicon(obj.tasks)
ret.collections = ret._transform_lexicon(obj.collections)
ret.default = ret.transform(obj.default)
ret.default = (
ret.transform(obj.default) if obj.default else None
)
# Explicitly given config wins over root ns config
obj_config = copy_dict(obj._configuration)
if config:
Expand All @@ -235,7 +235,13 @@ def instantiate(obj_name=None):
collection.configure(config)
return collection

def add_task(self, task, name=None, aliases=None, default=None):
def add_task(
self,
task: "Task",
name: Optional[str] = None,
aliases: Optional[Tuple[str, ...]] = None,
default: Optional[bool] = None,
) -> None:
"""
Add `.Task` ``task`` to this collection.
Expand All @@ -258,8 +264,9 @@ def add_task(self, task, name=None, aliases=None, default=None):
if name is None:
if task.name:
name = task.name
# XXX https://github.com/python/mypy/issues/1424
elif hasattr(task.body, "func_name"):
name = task.body.func_name
name = task.body.func_name # type: ignore
elif hasattr(task.body, "__name__"):
name = task.__name__
else:
Expand All @@ -275,7 +282,12 @@ def add_task(self, task, name=None, aliases=None, default=None):
self._check_default_collision(name)
self.default = name

def add_collection(self, coll, name=None, default=None):
def add_collection(
self,
coll: "Collection",
name: Optional[str] = None,
default: Optional[bool] = None,
) -> None:
"""
Add `.Collection` ``coll`` as a sub-collection of this one.
Expand All @@ -294,7 +306,7 @@ def add_collection(self, coll, name=None, default=None):
Added the ``default`` parameter.
"""
# Handle module-as-collection
if isinstance(coll, types.ModuleType):
if isinstance(coll, ModuleType):
coll = Collection.from_module(coll)
# Ensure we have a name, or die trying
name = name or coll.name
Expand All @@ -311,12 +323,12 @@ def add_collection(self, coll, name=None, default=None):
self._check_default_collision(name)
self.default = name

def _check_default_collision(self, name):
def _check_default_collision(self, name: str) -> None:
if self.default:
msg = "'{}' cannot be the default because '{}' already is!"
raise ValueError(msg.format(name, self.default))

def _split_path(self, path):
def _split_path(self, path: str) -> Tuple[str, str]:
"""
Obtain first collection + remainder, of a task path.
Expand All @@ -331,7 +343,7 @@ def _split_path(self, path):
rest = ".".join(parts)
return coll, rest

def subcollection_from_path(self, path):
def subcollection_from_path(self, path: str) -> "Collection":
"""
Given a ``path`` to a subcollection, return that subcollection.
Expand All @@ -343,7 +355,7 @@ def subcollection_from_path(self, path):
collection = collection.collections[parts.pop(0)]
return collection

def __getitem__(self, name=None):
def __getitem__(self, name: Optional[str] = None) -> Any:
"""
Returns task named ``name``. Honors aliases and subcollections.
Expand All @@ -359,11 +371,15 @@ def __getitem__(self, name=None):
"""
return self.task_with_config(name)[0]

def _task_with_merged_config(self, coll, rest, ours):
def _task_with_merged_config(
self, coll: str, rest: str, ours: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
task, config = self.collections[coll].task_with_config(rest)
return task, dict(config, **ours)

def task_with_config(self, name):
def task_with_config(
self, name: Optional[str]
) -> Tuple[str, Dict[str, Any]]:
"""
Return task named ``name`` plus its configuration dict.
Expand Down Expand Up @@ -397,14 +413,16 @@ def task_with_config(self, name):
# Regular task lookup
return self.tasks[name], ours

def __contains__(self, name):
def __contains__(self, name: str) -> bool:
try:
self[name]
return True
except KeyError:
return False

def to_contexts(self, ignore_unknown_help=None):
def to_contexts(
self, ignore_unknown_help: Optional[bool] = None
) -> List[ParserContext]:
"""
Returns all contained tasks and subtasks as a list of parser contexts.
Expand All @@ -430,12 +448,12 @@ def to_contexts(self, ignore_unknown_help=None):
)
return result

def subtask_name(self, collection_name, task_name):
def subtask_name(self, collection_name: str, task_name: str) -> str:
return ".".join(
[self.transform(collection_name), self.transform(task_name)]
)

def transform(self, name):
def transform(self, name: str) -> str:
"""
Transform ``name`` with the configured auto-dashes behavior.
Expand Down Expand Up @@ -474,25 +492,25 @@ def transform(self, name):
replaced.append(char)
return "".join(replaced)

def _transform_lexicon(self, old):
def _transform_lexicon(self, old: Lexicon) -> Lexicon:
"""
Take a Lexicon and apply `.transform` to its keys and aliases.
:returns: A new Lexicon.
"""
new_ = Lexicon()
new = Lexicon()
# Lexicons exhibit only their real keys in most places, so this will
# only grab those, not aliases.
for key, value in old.items():
# Deepcopy the value so we're not just copying a reference
new_[self.transform(key)] = copy.deepcopy(value)
new[self.transform(key)] = copy.deepcopy(value)
# Also copy all aliases, which are string-to-string key mappings
for key, value in old.aliases.items():
new_.alias(from_=self.transform(key), to=self.transform(value))
return new_
new.alias(from_=self.transform(key), to=self.transform(value))
return new

@property
def task_names(self):
def task_names(self) -> Dict[str, List[str]]:
"""
Return all task identifiers for this collection as a one-level dict.
Expand Down Expand Up @@ -523,7 +541,7 @@ def task_names(self):
ret[self.subtask_name(coll_name, task_name)] = aliases
return ret

def configuration(self, taskpath=None):
def configuration(self, taskpath: Optional[str] = None) -> Dict[str, Any]:
"""
Obtain merged configuration values from collection & children.
Expand All @@ -541,7 +559,7 @@ def configuration(self, taskpath=None):
return copy_dict(self._configuration)
return self.task_with_config(taskpath)[1]

def configure(self, options):
def configure(self, options: Dict[str, Any]) -> None:
"""
(Recursively) merge ``options`` into the current `.configuration`.
Expand All @@ -560,7 +578,7 @@ def configure(self, options):
"""
merge_dicts(self._configuration, options)

def serialized(self):
def serialized(self) -> Dict[str, Any]:
"""
Return an appropriate-for-serialization version of this object.
Expand Down
Loading

0 comments on commit 4a48966

Please sign in to comment.