Skip to content

Commit

Permalink
lint: switch from Black to ruff-format (#2388)
Browse files Browse the repository at this point in the history
* lint: switch from Black to `ruff-format`
* manual fixes
* mypy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Feb 15, 2024
1 parent cdce36b commit 4c999b8
Show file tree
Hide file tree
Showing 111 changed files with 509 additions and 470 deletions.
10 changes: 3 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ repos:
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.1
hooks:
- id: black
name: Format code

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
hooks:
Expand Down Expand Up @@ -130,7 +124,9 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.2.0
hooks:
- id: ruff-format
args: ["--preview"]
- id: ruff
args: ["--fix"]
4 changes: 1 addition & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import sys

import lai_sphinx_theme
import torchmetrics
from lightning_utilities.docs import fetch_external_assets
from lightning_utilities.docs.formatting import _transform_changelog

import torchmetrics


_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/binary_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/binary_accuracy_multistep.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/collection_binary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/collection_binary_together.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/multiclass_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 0 additions & 1 deletion docs/source/pyplots/tracker_binary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
16 changes: 7 additions & 9 deletions examples/bert_score-own_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,13 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
sentence.lower().split()[:max_len] + [self.PAD_TOKEN] * (max_len - len(sentence.lower().split()))
for sentence in sentences
]
output_dict["input_ids"] = torch.cat(
[torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences]
)
output_dict["attention_mask"] = torch.cat(
[
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
for sentence in tokenized_sentences
]
).long()
output_dict["input_ids"] = torch.cat([
torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences
])
output_dict["attention_mask"] = torch.cat([
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
for sentence in tokenized_sentences
]).long()

return output_dict

Expand Down
21 changes: 8 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ concurrency = "thread"
relative_files = true


[tool.black]
# https://github.com/psf/black
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"

[tool.docformatter]
recursive = true
# some docstring start with r"""
Expand Down Expand Up @@ -85,7 +80,7 @@ wil = "wil"
target-version = "py38"
line-length = 120
# Enable Pyflakes `E` and `F` codes by default.
select = [
lint.select = [
"E",
"W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
Expand All @@ -94,7 +89,7 @@ select = [
"N", # see: https://pypi.org/project/pep8-naming
"S", # see: https://pypi.org/project/flake8-bandit
]
extend-select = [
lint.extend-select = [
"A", # see: https://pypi.org/project/flake8-builtins
"B", # see: https://pypi.org/project/flake8-bugbear
"C4", # see: https://pypi.org/project/flake8-comprehensions
Expand All @@ -114,7 +109,7 @@ extend-select = [
"PERF", # see: https://pypi.org/project/perflint/
"PYI", # see: https://pypi.org/project/flake8-pyi/
]
ignore = [
lint.ignore = [
"E731", # Do not assign a lambda expression, use a def
"D100", # todo: Missing docstring in public module
"D104", # todo: Missing docstring in public package
Expand All @@ -136,22 +131,22 @@ exclude = [
"dist",
"docs",
]
ignore-init-module-imports = true
unfixable = ["F401"]
lint.ignore-init-module-imports = true
lint.unfixable = ["F401"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"setup.py" = ["ANN202", "ANN401"]
"src/**" = ["ANN401"]
"tests/**" = ["S101", "ANN001", "ANN201", "ANN202", "ANN401"]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
# Use Google-style docstrings.
convention = "google"

#[tool.ruff.pycodestyle]
#ignore-overlong-task-comments = true

[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10

Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
r"""Root package info."""

import logging as __logging
import os

Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class SignalNoiseRatio(Metric):
tensor(16.1805)
"""

full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ShortTimeObjectiveIntelligibility(Metric):
tensor(-0.0100)
"""

sum_stoi: Tensor
total: Tensor
full_state_update: bool = False
Expand Down
11 changes: 8 additions & 3 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class BinaryAccuracy(BinaryStatScores):
tensor([0.3333, 0.1667])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -244,6 +245,7 @@ class MulticlassAccuracy(MulticlassStatScores):
[0.0000, 0.3333, 0.5000]])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -396,6 +398,7 @@ class MultilabelAccuracy(MultilabelStatScores):
[0.0000, 0.0000, 0.5000]])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -499,9 +502,11 @@ def __new__( # type: ignore[misc]
"""Initialize task metric."""
task = ClassificationTask.from_str(task)

kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})

if task == ClassificationTask.BINARY:
return BinaryAccuracy(threshold, **kwargs)
Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):
tensor(0.5000)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -273,7 +274,10 @@ def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
state,
self.num_classes,
self.average, # type: ignore[arg-type]
self.thresholds,
)

def plot( # type: ignore[override]
Expand Down Expand Up @@ -396,6 +400,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
tensor([0.6250, 0.5000, 0.8333])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
tensor(0.6667)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -271,7 +272,10 @@ def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
state,
self.num_classes,
self.average, # type: ignore[arg-type]
self.thresholds,
)

def plot( # type: ignore[override]
Expand Down Expand Up @@ -399,6 +403,7 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
tensor([0.7500, 0.6667, 0.9167])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/classification/calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class BinaryCalibrationError(Metric):
tensor(0.3167)
"""

is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
Expand Down Expand Up @@ -249,6 +250,7 @@ class MulticlassCalibrationError(Metric):
tensor(0.2333)
"""

is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class labels.
tensor(0.5000)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -216,6 +217,7 @@ class labels.
tensor(0.6364)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down
3 changes: 3 additions & 0 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class BinaryConfusionMatrix(Metric):
[1, 1]])
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
Expand Down Expand Up @@ -248,6 +249,7 @@ class MulticlassConfusionMatrix(Metric):
[0, 0, 1]])
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
Expand Down Expand Up @@ -387,6 +389,7 @@ class MultilabelConfusionMatrix(Metric):
[[0, 1], [0, 1]]])
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class Dice(Metric):
tensor(0.2500)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down
9 changes: 6 additions & 3 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class MulticlassExactMatch(Metric):
tensor([1., 0.])
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
Expand Down Expand Up @@ -405,9 +406,11 @@ def __new__(
) -> Metric:
"""Initialize task metric."""
task = ClassificationTaskNoBinary.from_str(task)
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
kwargs.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
})
if task == ClassificationTaskNoBinary.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
Expand Down
Loading

0 comments on commit 4c999b8

Please sign in to comment.