diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4e8d114f..25f98c53 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -47,6 +47,10 @@ jobs: run: | make pre-commit + - name: ruff + run: | + make ruff + - name: flake8 run: | make flake8 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b98670c..1140828d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,12 @@ repos: hooks: - id: clang-format stages: [commit, push, manual] + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.247 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + stages: [commit, push, manual] - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index f92fbc6d..41951cc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#33](https://github.com/metaopt/optree/pull/33) and [#34](https://github.com/metaopt/optree/pull/34). ### Changed diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 464bc3e9..8f2e2cf6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -45,7 +45,7 @@ make uninstall We use several tools to secure code quality, including: -- PEP8 code style: `black`, `isort`, `pylint`, `flake8` +- Python code style: `black`, `isort`, `pylint`, `flake8`, `ruff` - Type hint check: `mypy` - C++ code style: `cpplint`, `clang-format`, `clang-tidy` - License: `addlicense` diff --git a/Makefile b/Makefile index affaa575..8161fb7e 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,9 @@ py-format-install: $(call check_pip_install,isort) $(call check_pip_install,black) +ruff-install: + $(call check_pip_install,ruff) + mypy-install: $(call check_pip_install,mypy) @@ -132,6 +135,12 @@ py-format: py-format-install $(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \ $(PYTHON) -m black --check $(PYTHON_FILES) +ruff: ruff-install + $(PYTHON) -m ruff check . + +ruff-fix: ruff-install + $(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix + mypy: mypy-install $(PYTHON) -m mypy $(PROJECT_PATH) @@ -187,11 +196,12 @@ clean-docs: # Utility functions -lint: flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling +lint: ruff flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling -format: py-format-install clang-format-install addlicense-install +format: py-format-install ruff-install clang-format-install addlicense-install $(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES) $(PYTHON) -m black $(PYTHON_FILES) + $(PYTHON) -m ruff check . --fix --exit-zero $(CLANG_FORMAT) -style=file -i $(CXX_FILES) addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS) diff --git a/benchmark.py b/benchmark.py index 58c6ee74..28f117be 100755 --- a/benchmark.py +++ b/benchmark.py @@ -254,7 +254,7 @@ def fn3(x, y, z): ) -def cprint(text=''): +def cprint(text: str = '') -> None: text = ( text.replace( ', none_is_leaf=False', diff --git a/conda-recipe.yaml b/conda-recipe.yaml index be90b3ee..45bfc99d 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -75,6 +75,7 @@ dependencies: - flake8-docstrings - flake8-pyi - flake8-simplify + - ruff - doc8 - pydocstyle - xdoctest diff --git a/docs/source/conf.py b/docs/source/conf.py index cbb3dfd5..f43c13e5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,7 +36,7 @@ def get_version() -> str: sys.path.insert(0, str(PROJECT_ROOT / 'optree')) - import version # noqa + import version return version.__version__ diff --git a/optree/_C.pyi b/optree/_C.pyi index e363eb1f..47bf0655 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -27,24 +27,31 @@ MAX_RECURSION_DEPTH: int def flatten( tree: PyTree[T], - leaf_predicate: Callable[[T], bool] | None = None, - node_is_leaf: bool = False, - namespace: str = '', + leaf_predicate: Callable[[T], bool] | None = ..., # None + node_is_leaf: bool = ..., # False + namespace: str = ..., # '' ) -> builtins.tuple[list[T], PyTreeSpec]: ... def flatten_with_path( tree: PyTree[T], - leaf_predicate: Callable[[T], bool] | None = None, - node_is_leaf: bool = False, - namespace: str = '', + leaf_predicate: Callable[[T], bool] | None = ..., # None + node_is_leaf: bool = ..., # False + namespace: str = ..., # '' ) -> builtins.tuple[list[builtins.tuple[Any, ...]], list[T], PyTreeSpec]: ... def all_leaves( iterable: Iterable[T], - node_is_leaf: bool = False, - namespace: str = '', + node_is_leaf: bool = ..., # False + namespace: str = ..., # '' ) -> bool: ... -def leaf(node_is_leaf: bool = False) -> PyTreeSpec: ... -def none(node_is_leaf: bool = False) -> PyTreeSpec: ... -def tuple(treespecs: Sequence[PyTreeSpec], node_is_leaf: bool = False) -> PyTreeSpec: ... +def leaf( + node_is_leaf: bool = ..., # False +) -> PyTreeSpec: ... +def none( + node_is_leaf: bool = ..., # False +) -> PyTreeSpec: ... +def tuple( + treespecs: Sequence[PyTreeSpec], + node_is_leaf: bool = ..., # False +) -> PyTreeSpec: ... def is_namedtuple_class(cls: type) -> bool: ... def is_structseq_class(cls: type) -> bool: ... def structseq_fields(obj: object | type) -> builtins.tuple[str]: ... @@ -66,7 +73,10 @@ class PyTreeSpec: leaves: Iterable[T], ) -> U: ... def children(self) -> list[PyTreeSpec]: ... - def is_leaf(self, strict: bool = True) -> bool: ... + def is_leaf( + self, + strict: bool = ..., # True + ) -> bool: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... def __hash__(self) -> int: ... diff --git a/optree/ops.py b/optree/ops.py index e79a8d99..383ff679 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -24,7 +24,7 @@ from collections import deque from typing import Any, Callable, cast, overload -import optree._C as _C +from optree import _C from optree.registry import ( AttributeKeyPathEntry, FlattenedKeyPathEntry, @@ -1157,10 +1157,7 @@ def flatten_one_level( ) children, metadata, entries = flattened children = list(children) - if entries is None: - entries = tuple(range(len(children))) - else: - entries = tuple(entries) + entries = tuple(range(len(children)) if entries is None else entries) if len(children) != len(entries): raise RuntimeError( f'PyTree custom flatten function for type {node_type} returned inconsistent ' @@ -1300,7 +1297,7 @@ def _child_keys( namespace: str = '', ) -> list[KeyPathEntry]: treespec = tree_structure(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) - assert not treespec_is_strict_leaf(treespec) + assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node' handler = register_keypaths.get(type(tree)) # type: ignore[attr-defined] if handler: diff --git a/optree/registry.py b/optree/registry.py index 4b552385..bf918d25 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -23,7 +23,7 @@ from threading import Lock from typing import Any, Callable, Iterable, NamedTuple, Sequence, overload -import optree._C as _C +from optree import _C from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, PyTree, T, UnflattenFunc from optree.utils import safe_zip, unzip2 @@ -387,7 +387,7 @@ class _HashablePartialShim: def __init__(self, partial_func: functools.partial) -> None: self.partial_func: functools.partial = partial_func - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.partial_func(*args, **kwargs) def __hash__(self) -> int: @@ -449,7 +449,7 @@ class Partial(functools.partial, CustomTreeNode[Any]): # pylint: disable=too-fe args: tuple[Any, ...] keywords: dict[str, Any] - def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial: + def __new__(cls, func: Callable[..., Any], *args: Any, **keywords: Any) -> Partial: """Create a new :class:`Partial` instance.""" # In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__ # would merge the arguments of this Partial instance with the arguments of the func. We box @@ -458,7 +458,7 @@ def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial: if isinstance(func, functools.partial): original_func = func func = _HashablePartialShim(original_func) - assert not hasattr(func, 'func') + assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute' out = super().__new__(cls, func, *args, **keywords) func.func = original_func.func func.args = original_func.args @@ -488,7 +488,7 @@ def __add__(self, other: object) -> KeyPath: if isinstance(other, KeyPathEntry): return KeyPath((self, other)) if isinstance(other, KeyPath): - return KeyPath((self,) + other.keys) + return KeyPath((self, *other.keys)) return NotImplemented def __eq__(self, other: object) -> bool: @@ -504,7 +504,7 @@ class KeyPath(NamedTuple): def __add__(self, other: object) -> KeyPath: if isinstance(other, KeyPathEntry): - return KeyPath(self.keys + (other,)) + return KeyPath((*self.keys, other)) if isinstance(other, KeyPath): return KeyPath(self.keys + other.keys) return NotImplemented diff --git a/optree/typing.py b/optree/typing.py index ea4ce697..d0431864 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -29,6 +29,7 @@ Hashable, Iterable, List, + NoReturn, Optional, Sequence, Tuple, @@ -39,7 +40,7 @@ from typing_extensions import Protocol # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ -import optree._C as _C +from optree import _C try: @@ -118,18 +119,18 @@ def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTree _GenericAlias = type(Union[int, str]) -def _tp_cache(func): +def _tp_cache(func: Callable) -> Callable: import functools # pylint: disable=import-outside-toplevel cached = functools.lru_cache()(func) @functools.wraps(func) - def inner(*args, **kwds): - try: # noqa: SIM105 + def inner(*args: Any, **kwds: Any) -> Any: + try: return cached(*args, **kwds) except TypeError: - pass # All real errors (not unhashable args) are raised below. - return func(*args, **kwds) + # All real errors (not unhashable args) are raised below. + return func(*args, **kwds) return inner @@ -150,7 +151,9 @@ class PyTree(Generic[T]): # pylint: disable=too-few-public-methods """ @_tp_cache - def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAlias: + def __class_getitem__( # noqa: C901 + cls, item: T | tuple[T] | tuple[T, str | None] + ) -> TypeAlias: """Instantiate a PyTree type with the given type.""" if not isinstance(item, tuple): item = (item, None) @@ -200,15 +203,19 @@ def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAli pytree_alias.__pytree_args__ = item # type: ignore[attr-defined] return pytree_alias - def __init_subclass__(cls, *args, **kwargs): + def __new__(cls) -> NoReturn: # pylint: disable=arguments-differ + """Prohibit instantiation.""" + raise TypeError('Cannot instantiate special typing classes.') + + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn: """Prohibit subclassing.""" raise TypeError('Cannot subclass special typing classes.') - def __copy__(self): + def __copy__(self) -> PyTree: """Immutable copy.""" return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict[int, Any]) -> PyTree: """Immutable copy.""" return self @@ -235,15 +242,15 @@ def __new__(cls, name: str, param: type) -> TypeAlias: raise TypeError(f'{cls.__name__} only supports a string of type name. Got {name!r}.') return PyTree[param, name] # type: ignore[misc,valid-type] - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn: """Prohibit subclassing.""" raise TypeError('Cannot subclass special typing classes.') - def __copy__(self): + def __copy__(self) -> TypeAlias: """Immutable copy.""" return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias: """Immutable copy.""" return self @@ -296,5 +303,5 @@ class SubClass(cls): # type: ignore[misc,valid-type] # Ensure that the behavior is consistent with C++ implementation -# pylint: disable-next=wrong-import-position +# pylint: disable-next=wrong-import-position,ungrouped-imports from optree._C import is_namedtuple_class, is_structseq_class, structseq_fields diff --git a/pyproject.toml b/pyproject.toml index e1cad7a0..50bcd5b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ lint = [ "flake8-docstrings", "flake8-pyi", "flake8-simplify", + "ruff", "doc8 < 1.0.0a0", # unpin this when we drop support for Python 3.7 "pydocstyle", "pyenchant", @@ -135,5 +136,81 @@ convention = "google" [tool.doc8] max-line-length = 500 +[tool.ruff] +target-version = "py37" +line-length = 100 +show-source = true +src = ["optree", "tests"] +extend-exclude = ["tests"] +select = [ + "E", "W", # pycodestyle + "F", # pyflakes + "C90", # mccabe + "N", # pep8-naming + "UP", # pyupgrade + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "EXE", # flake8-executable + "ISC", # flake8-implicit-str-concat + "PIE", # flake8-pie + "PYI", # flake8-pyi + "RSE", # flake8-raise + "RET", # flake8-return + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "PL", # pylint + "RUF", # ruff +] +ignore = [ + # E501: line too long + # W505: doc line too long + # too long docstring due to long example blocks + "E501", + "W505", + # ANN101: missing type annotation for `self` in method + # ANN102: missing type annotation for `cls` in classmethod + "ANN101", + "ANN102", + # ANN401: dynamically typed expressions (typing.Any) are disallowed + "ANN401", + # S101: use of `assert` detected + # internal use and may never raise at runtime + "S101", + # PLR0402: use from {module} import {name} in lieu of alias + # use alias for import convention (e.g., `import torch.nn as nn`) + "PLR0402", +] +typing-modules = ["optree.typing"] + +[tool.ruff.per-file-ignores] +"__init__.py" = [ + "F401", # unused-import +] +"optree/typing.py" = [ + "E402", # module-import-not-at-top-of-file + "F722", # forward-annotation-syntax-error + "F811", # redefined-while-unused +] +"setup.py" = [ + "ANN", # flake8-annotations +] + +[tool.ruff.flake8-annotations] +allow-star-arg-any = true + +[tool.ruff.flake8-quotes] +docstring-quotes = "double" +multiline-quotes = "double" +inline-quotes = "single" + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.pylint] +allow-magic-value-types = ["int", "str", "float"] + [tool.pytest.ini_options] filterwarnings = ["error"] diff --git a/setup.py b/setup.py index 075cc212..3872b331 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def __init__(self, name, source_dir='.', target=None, **kwargs): self.target = target if target is not None else name.rpartition('.')[-1] -class cmake_build_ext(build_ext): +class cmake_build_ext(build_ext): # noqa: N801 def build_extension(self, ext): if not isinstance(ext, CMakeExtension): super().build_extension(ext) @@ -80,9 +80,9 @@ def build_extension(self, ext): try: os.chdir(build_temp) - self.spawn([cmake, ext.source_dir] + cmake_args) + self.spawn([cmake, ext.source_dir, *cmake_args]) if not self.dry_run: - self.spawn([cmake, '--build', '.'] + build_args) + self.spawn([cmake, '--build', '.', *build_args]) finally: os.chdir(HERE) diff --git a/tests/requirements.txt b/tests/requirements.txt index e0cc8035..233d4190 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -13,6 +13,7 @@ flake8-comprehensions flake8-docstrings flake8-pyi flake8-simplify +ruff doc8 < 1.0.0a0 # unpin this when we drop support for Python 3.7 pydocstyle pyenchant