Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ruff integration #34

Merged
merged 7 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ jobs:
run: |
make pre-commit

- name: ruff
run: |
make ruff

- name: flake8
run: |
make flake8
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ repos:
hooks:
- id: clang-format
stages: [commit, push, manual]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.247
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
stages: [commit, push, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#33](https://github.com/metaopt/optree/pull/33) and [#34](https://github.com/metaopt/optree/pull/34).

### Changed

Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ make uninstall

We use several tools to secure code quality, including:

- PEP8 code style: `black`, `isort`, `pylint`, `flake8`
- Python code style: `black`, `isort`, `pylint`, `flake8`, `ruff`
- Type hint check: `mypy`
- C++ code style: `cpplint`, `clang-format`, `clang-tidy`
- License: `addlicense`
Expand Down
14 changes: 12 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ py-format-install:
$(call check_pip_install,isort)
$(call check_pip_install,black)

ruff-install:
$(call check_pip_install,ruff)

mypy-install:
$(call check_pip_install,mypy)

Expand Down Expand Up @@ -132,6 +135,12 @@ py-format: py-format-install
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
$(PYTHON) -m black --check $(PYTHON_FILES)

ruff: ruff-install
$(PYTHON) -m ruff check .

ruff-fix: ruff-install
$(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix

mypy: mypy-install
$(PYTHON) -m mypy $(PROJECT_PATH)

Expand Down Expand Up @@ -187,11 +196,12 @@ clean-docs:

# Utility functions

lint: flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling
lint: ruff flake8 py-format mypy pylint doctest clang-format clang-tidy cpplint addlicense docstyle spelling

format: py-format-install clang-format-install addlicense-install
format: py-format-install ruff-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES)
$(PYTHON) -m ruff check . --fix --exit-zero
$(CLANG_FORMAT) -style=file -i $(CXX_FILES)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)

Expand Down
2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def fn3(x, y, z):
)


def cprint(text=''):
def cprint(text: str = '') -> None:
text = (
text.replace(
', none_is_leaf=False',
Expand Down
1 change: 1 addition & 0 deletions conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ dependencies:
- flake8-docstrings
- flake8-pyi
- flake8-simplify
- ruff
- doc8
- pydocstyle
- xdoctest
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

def get_version() -> str:
sys.path.insert(0, str(PROJECT_ROOT / 'optree'))
import version # noqa
import version

return version.__version__

Expand Down
34 changes: 22 additions & 12 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,31 @@ MAX_RECURSION_DEPTH: int

def flatten(
tree: PyTree[T],
leaf_predicate: Callable[[T], bool] | None = None,
node_is_leaf: bool = False,
namespace: str = '',
leaf_predicate: Callable[[T], bool] | None = ..., # None
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> builtins.tuple[list[T], PyTreeSpec]: ...
def flatten_with_path(
tree: PyTree[T],
leaf_predicate: Callable[[T], bool] | None = None,
node_is_leaf: bool = False,
namespace: str = '',
leaf_predicate: Callable[[T], bool] | None = ..., # None
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> builtins.tuple[list[builtins.tuple[Any, ...]], list[T], PyTreeSpec]: ...
def all_leaves(
iterable: Iterable[T],
node_is_leaf: bool = False,
namespace: str = '',
node_is_leaf: bool = ..., # False
namespace: str = ..., # ''
) -> bool: ...
def leaf(node_is_leaf: bool = False) -> PyTreeSpec: ...
def none(node_is_leaf: bool = False) -> PyTreeSpec: ...
def tuple(treespecs: Sequence[PyTreeSpec], node_is_leaf: bool = False) -> PyTreeSpec: ...
def leaf(
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def none(
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def tuple(
treespecs: Sequence[PyTreeSpec],
node_is_leaf: bool = ..., # False
) -> PyTreeSpec: ...
def is_namedtuple_class(cls: type) -> bool: ...
def is_structseq_class(cls: type) -> bool: ...
def structseq_fields(obj: object | type) -> builtins.tuple[str]: ...
Expand All @@ -66,7 +73,10 @@ class PyTreeSpec:
leaves: Iterable[T],
) -> U: ...
def children(self) -> list[PyTreeSpec]: ...
def is_leaf(self, strict: bool = True) -> bool: ...
def is_leaf(
self,
strict: bool = ..., # True
) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
Expand Down
9 changes: 3 additions & 6 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections import deque
from typing import Any, Callable, cast, overload

import optree._C as _C
from optree import _C
from optree.registry import (
AttributeKeyPathEntry,
FlattenedKeyPathEntry,
Expand Down Expand Up @@ -1157,10 +1157,7 @@ def flatten_one_level(
)
children, metadata, entries = flattened
children = list(children)
if entries is None:
entries = tuple(range(len(children)))
else:
entries = tuple(entries)
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
Expand Down Expand Up @@ -1300,7 +1297,7 @@ def _child_keys(
namespace: str = '',
) -> list[KeyPathEntry]:
treespec = tree_structure(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
assert not treespec_is_strict_leaf(treespec)
assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node'

handler = register_keypaths.get(type(tree)) # type: ignore[attr-defined]
if handler:
Expand Down
12 changes: 6 additions & 6 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from threading import Lock
from typing import Any, Callable, Iterable, NamedTuple, Sequence, overload

import optree._C as _C
from optree import _C
from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, PyTree, T, UnflattenFunc
from optree.utils import safe_zip, unzip2

Expand Down Expand Up @@ -387,7 +387,7 @@ class _HashablePartialShim:
def __init__(self, partial_func: functools.partial) -> None:
self.partial_func: functools.partial = partial_func

def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.partial_func(*args, **kwargs)

def __hash__(self) -> int:
Expand Down Expand Up @@ -449,7 +449,7 @@ class Partial(functools.partial, CustomTreeNode[Any]): # pylint: disable=too-fe
args: tuple[Any, ...]
keywords: dict[str, Any]

def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial:
def __new__(cls, func: Callable[..., Any], *args: Any, **keywords: Any) -> Partial:
"""Create a new :class:`Partial` instance."""
# In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__
# would merge the arguments of this Partial instance with the arguments of the func. We box
Expand All @@ -458,7 +458,7 @@ def __new__(cls, func: Callable[..., Any], *args, **keywords) -> Partial:
if isinstance(func, functools.partial):
original_func = func
func = _HashablePartialShim(original_func)
assert not hasattr(func, 'func')
assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute'
out = super().__new__(cls, func, *args, **keywords)
func.func = original_func.func
func.args = original_func.args
Expand Down Expand Up @@ -488,7 +488,7 @@ def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath((self, other))
if isinstance(other, KeyPath):
return KeyPath((self,) + other.keys)
return KeyPath((self, *other.keys))
return NotImplemented

def __eq__(self, other: object) -> bool:
Expand All @@ -504,7 +504,7 @@ class KeyPath(NamedTuple):

def __add__(self, other: object) -> KeyPath:
if isinstance(other, KeyPathEntry):
return KeyPath(self.keys + (other,))
return KeyPath((*self.keys, other))
if isinstance(other, KeyPath):
return KeyPath(self.keys + other.keys)
return NotImplemented
Expand Down
35 changes: 21 additions & 14 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Hashable,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Expand All @@ -39,7 +40,7 @@
from typing_extensions import Protocol # Python 3.8+
from typing_extensions import TypeAlias # Python 3.10+

import optree._C as _C
from optree import _C


try:
Expand Down Expand Up @@ -118,18 +119,18 @@ def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTree
_GenericAlias = type(Union[int, str])


def _tp_cache(func):
def _tp_cache(func: Callable) -> Callable:
import functools # pylint: disable=import-outside-toplevel

cached = functools.lru_cache()(func)

@functools.wraps(func)
def inner(*args, **kwds):
try: # noqa: SIM105
def inner(*args: Any, **kwds: Any) -> Any:
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
# All real errors (not unhashable args) are raised below.
return func(*args, **kwds)

return inner

Expand All @@ -150,7 +151,9 @@ class PyTree(Generic[T]): # pylint: disable=too-few-public-methods
"""

@_tp_cache
def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAlias:
def __class_getitem__( # noqa: C901
cls, item: T | tuple[T] | tuple[T, str | None]
) -> TypeAlias:
"""Instantiate a PyTree type with the given type."""
if not isinstance(item, tuple):
item = (item, None)
Expand Down Expand Up @@ -200,15 +203,19 @@ def __class_getitem__(cls, item: T | tuple[T] | tuple[T, str | None]) -> TypeAli
pytree_alias.__pytree_args__ = item # type: ignore[attr-defined]
return pytree_alias

def __init_subclass__(cls, *args, **kwargs):
def __new__(cls) -> NoReturn: # pylint: disable=arguments-differ
"""Prohibit instantiation."""
raise TypeError('Cannot instantiate special typing classes.')

def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')

def __copy__(self):
def __copy__(self) -> PyTree:
"""Immutable copy."""
return self

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any]) -> PyTree:
"""Immutable copy."""
return self

Expand All @@ -235,15 +242,15 @@ def __new__(cls, name: str, param: type) -> TypeAlias:
raise TypeError(f'{cls.__name__} only supports a string of type name. Got {name!r}.')
return PyTree[param, name] # type: ignore[misc,valid-type]

def __init_subclass__(cls, *args, **kwargs):
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
"""Prohibit subclassing."""
raise TypeError('Cannot subclass special typing classes.')

def __copy__(self):
def __copy__(self) -> TypeAlias:
"""Immutable copy."""
return self

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
"""Immutable copy."""
return self

Expand Down Expand Up @@ -296,5 +303,5 @@ class SubClass(cls): # type: ignore[misc,valid-type]


# Ensure that the behavior is consistent with C++ implementation
# pylint: disable-next=wrong-import-position
# pylint: disable-next=wrong-import-position,ungrouped-imports
from optree._C import is_namedtuple_class, is_structseq_class, structseq_fields
Loading