Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace pre-commit linters (flake8, isort, black, ...) with ruff #539

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,43 @@ repos:
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.1.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-comprehensions
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/humitos/mirrors-autoflake.git
rev: v1.1
hooks:
- id: autoflake
exclude: |
(?x)^(
.*/?__init__\.py|
pytensor/graph/toolbox\.py|
pytensor/link/jax/jax_dispatch\.py|
pytensor/link/jax/jax_linker\.py|
pytensor/scalar/basic_scipy\.py|
pytensor/tensor/linalg\.py
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
# Run the linter
- id: ruff
args: [ --fix ]
# Run the formatter
- id: ruff-format
# - repo: https://github.com/psf/black
# rev: 23.1.0
# hooks:
# - id: black
# language_version: python3
# - repo: https://github.com/pycqa/flake8
# rev: 6.0.0
# hooks:
# - id: flake8
# additional_dependencies:
# - flake8-comprehensions
# - repo: https://github.com/pycqa/isort
# rev: 5.12.0
# hooks:
# - id: isort
# - repo: https://github.com/humitos/mirrors-autoflake.git
# rev: v1.1
# hooks:
# - id: autoflake
# exclude: |
# (?x)^(
# .*/?__init__\.py|
# pytensor/graph/toolbox\.py|
# pytensor/link/jax/jax_dispatch\.py|
# pytensor/link/jax/jax_linker\.py|
# pytensor/scalar/basic_scipy\.py|
# pytensor/tensor/linalg\.py
# )$
# args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.0
hooks:
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,17 @@ skip = "pytensor/version.py"
skip_glob = "**/*.pyx"

[tool.ruff]
select=["C","E","F","W"]
select=["C","E","F","W","I"]
ignore=["E501","E741","C408","C901"]
exclude = [
"doc/",
"pytensor/_version.py",
"bin/pytensor_cache.py",
]

[tool.ruff.isort]
lines-after-imports = 2

[tool.ruff.per-file-ignores]
# TODO: Get rid of these:
"**/__init__.py"=["F401","E402","F403"]
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ def accept(self, fgraph, no_recycling: Optional[list] = None, profile=None):
if no_recycling is None:
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
assert type(self) is _Linker
assert isinstance(self, _Linker)
return type(self)(maker=self.maker).accept(fgraph, no_recycling, profile)
self.fgraph = fgraph
self.no_recycling: list = no_recycling
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def thunk():
# Nothing should be in storage map after evaluating
# each the thunk (specifically the last one)
for r, s in storage_map.items():
assert type(s) is list
assert isinstance(s, list)
assert s[0] is None

# store our output variables to their respective storage lists
Expand Down
4 changes: 2 additions & 2 deletions pytensor/link/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,8 @@ def make_vm(
compute_map[vars_idx_inv[i]] for i in range(len(vars_idx_inv))
]
if nodes:
assert type(storage_map_list[0]) is list
assert type(compute_map_list[0]) is list
assert isinstance(storage_map_list[0], list)
assert isinstance(compute_map_list[0], list)

# Needed for allow_gc=True, profiling and storage_map reuse
dependency_map = self.compute_gc_dependencies(storage_map)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ def p(node, inputs, outputs):

if hasattr(self.fn.maker, "profile"):
profile = self.fn.maker.profile
if type(profile) is not bool and profile:
if not isinstance(profile, bool) and profile:
profile.vm_call_time += t_fn
profile.callcount += 1
profile.nbsteps += n_steps
Expand Down
2 changes: 1 addition & 1 deletion tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.link.jax.dispatch import jax_funcify # noqa: E402


try:
Expand Down
20 changes: 10 additions & 10 deletions tests/link/jax/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@


jax = pytest.importorskip("jax")
import jax.errors

import pytensor
import pytensor.tensor.basic as at
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import TestAlloc
import jax.errors # noqa: E402

import pytensor # noqa: E402
import pytensor.tensor.basic as at # noqa: E402
from pytensor.configdefaults import config # noqa: E402
from pytensor.graph.fg import FunctionGraph # noqa: E402
from pytensor.graph.op import get_test_value # noqa: E402
from pytensor.tensor.type import iscalar, matrix, scalar, vector # noqa: E402
from tests.link.jax.test_basic import compare_jax_and_py # noqa: E402
from tests.tensor.test_basic import TestAlloc # noqa: E402


def test_jax_Alloc():
Expand Down
50 changes: 25 additions & 25 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,31 @@

numba = pytest.importorskip("numba")

import pytensor.scalar as aes
import pytensor.scalar.math as aesm
import pytensor.tensor as at
import pytensor.tensor.math as aem
from pytensor import config, shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.ops import ViewOp
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas
from pytensor.tensor import subtensor as at_subtensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
import pytensor.scalar as aes # noqa: E402
import pytensor.scalar.math as aesm # noqa: E402
import pytensor.tensor as at # noqa: E402
import pytensor.tensor.math as aem # noqa: E402
from pytensor import config, shared # noqa: E402
from pytensor.compile.builders import OpFromGraph # noqa: E402
from pytensor.compile.function import function # noqa: E402
from pytensor.compile.mode import Mode # noqa: E402
from pytensor.compile.ops import ViewOp # noqa: E402
from pytensor.compile.sharedvalue import SharedVariable # noqa: E402
from pytensor.graph.basic import Apply, Constant # noqa: E402
from pytensor.graph.fg import FunctionGraph # noqa: E402
from pytensor.graph.op import Op, get_test_value # noqa: E402
from pytensor.graph.rewriting.db import RewriteDatabaseQuery # noqa: E402
from pytensor.graph.type import Type # noqa: E402
from pytensor.ifelse import ifelse # noqa: E402
from pytensor.link.numba.dispatch import basic as numba_basic # noqa: E402
from pytensor.link.numba.dispatch import numba_typify # noqa: E402
from pytensor.link.numba.linker import NumbaLinker # noqa: E402
from pytensor.raise_op import assert_op # noqa: E402
from pytensor.scalar.basic import ScalarOp, as_scalar # noqa: E402
from pytensor.tensor import blas # noqa: E402
from pytensor.tensor import subtensor as at_subtensor # noqa: E402
from pytensor.tensor.elemwise import Elemwise # noqa: E402
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape # noqa: E402


if TYPE_CHECKING:
Expand Down
7 changes: 5 additions & 2 deletions tests/link/numba/test_cython_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
numba = pytest.importorskip("numba")


from numba.types import float32, float64, int32, int64
from numba.types import float32, float64, int32, int64 # noqa: E402

from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function
from pytensor.link.numba.dispatch.cython_support import ( # noqa: E402
Signature,
wrap_cython_function,
)


@pytest.mark.parametrize(
Expand Down
14 changes: 7 additions & 7 deletions tests/link/numba/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

pytest.importorskip("numba")

import pytensor.tensor as aet
from pytensor import config
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.numba.linker import NumbaLinker
from pytensor.tensor.math import Max
import pytensor.tensor as aet # noqa: E402
from pytensor import config # noqa: E402
from pytensor.compile.function import function # noqa: E402
from pytensor.compile.mode import Mode # noqa: E402
from pytensor.graph.rewriting.db import RewriteDatabaseQuery # noqa: E402
from pytensor.link.numba.linker import NumbaLinker # noqa: E402
from pytensor.tensor.math import Max # noqa: E402


opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
Expand Down
8 changes: 4 additions & 4 deletions tests/link/numba/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@


# Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
from pytensor.sparse import Dot, SparseTensorType
from tests.link.numba.test_basic import compare_numba_and_py
import pytensor.link.numba.dispatch.sparse # noqa: E402,F401
from pytensor import config # noqa: E402
from pytensor.sparse import Dot, SparseTensorType # noqa: E402
from tests.link.numba.test_basic import compare_numba_and_py # noqa: E402


pytestmark = pytest.mark.filterwarnings("error")
Expand Down
2 changes: 1 addition & 1 deletion tests/link/numba/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


pytest.importorskip("numba")
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch import numba_funcify # noqa: E402


rng = np.random.default_rng(42849)
Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def test_local_mul_switch_sink(self):
f = self.function_remove_nan(
[condition[0], x[0], c], [y], mode=self.mode
)
if type(condition[1]) is list:
if isinstance(condition[1], list):
for i in range(len(condition[1])):
res = f(condition[1][i], x[1], -1)
assert (
Expand Down Expand Up @@ -886,7 +886,7 @@ def test_local_div_switch_sink(self):
f = self.function_remove_nan(
[condition[0], x[0], c], [y], mode=self.mode
)
if type(condition[1]) is list:
if isinstance(condition[1], list):
for i in range(len(condition[1])):
res = f(condition[1][i], x[1], -1)
assert (
Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2092,7 +2092,7 @@ def test_local_mul_switch_sink(self):
f = self.function_remove_nan(
[condition[0], x[0], c], [y], mode=self.mode
)
if type(condition[1]) is list:
if isinstance(condition[1], list):
for i in range(len(condition[1])):
res = f(condition[1][i], x[1], -1)
assert (
Expand Down Expand Up @@ -2132,7 +2132,7 @@ def test_local_div_switch_sink(self):
f = self.function_remove_nan(
[condition[0], x[0], c], [y], mode=self.mode
)
if type(condition[1]) is list:
if isinstance(condition[1], list):
for i in range(len(condition[1])):
res = f(condition[1][i], x[1], -1)
assert (
Expand Down
Loading