diff --git a/pyproject.toml b/pyproject.toml index 9114383c58a..067d1a0007c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,6 @@ ignore = [ "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo "B905", # todo: `zip()` without an explicit `strict=` parameter "PYI024", # todo: Use `typing.NamedTuple` instead of `collections.namedtuple` - "PYI041", # todo: Use `float` instead of `int | float`` ] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 00162753586..0b46b25d32a 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -42,7 +42,7 @@ def _spectral_distortion_index( def _error_relative_global_dimensionless_synthesis( preds: Tensor, target: Tensor, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", ) -> Tensor: """Wrapper for deprecated import. diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index 5df44018506..d237115fa2e 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Tuple import torch from torch import Tensor @@ -46,7 +46,7 @@ def _ergas_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: def _ergas_compute( preds: Tensor, target: Tensor, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", ) -> Tensor: """Erreur Relative Globale Adimensionnelle de Synthèse. @@ -86,7 +86,7 @@ def _ergas_compute( def error_relative_global_dimensionless_synthesis( preds: Tensor, target: Tensor, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", ) -> Tensor: """Erreur Relative Globale Adimensionnelle de Synthèse. diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 4f5a6ccbb5f..078760dedde 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -178,7 +178,7 @@ def _clip_iqa_update( images: Tensor, model: _CLIPModel, processor: _CLIPProcessor, - data_range: Union[int, float], + data_range: float, device: Union[str, torch.device], ) -> Tensor: images = images / float(data_range) @@ -221,7 +221,7 @@ def clip_image_quality_assessment( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ] = "clip_iqa", - data_range: Union[int, float] = 1.0, + data_range: float = 1.0, prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), ) -> Union[Tensor, Dict[str, Tensor]]: """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images. diff --git a/src/torchmetrics/functional/nominal/cramers.py b/src/torchmetrics/functional/nominal/cramers.py index 46d4058010c..75a606f6c19 100644 --- a/src/torchmetrics/functional/nominal/cramers.py +++ b/src/torchmetrics/functional/nominal/cramers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -34,7 +34,7 @@ def _cramers_v_update( target: Tensor, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: """Compute the bins to update the confusion matrix with for Cramer's V calculation. @@ -90,7 +90,7 @@ def cramers_v( target: Tensor, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series. @@ -142,7 +142,7 @@ def cramers_v_matrix( matrix: Tensor, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Cramer's V`_ statistic between a set of multiple variables. diff --git a/src/torchmetrics/functional/nominal/pearson.py b/src/torchmetrics/functional/nominal/pearson.py index b519b105a7a..bd25c701fe4 100644 --- a/src/torchmetrics/functional/nominal/pearson.py +++ b/src/torchmetrics/functional/nominal/pearson.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -32,7 +32,7 @@ def _pearsons_contingency_coefficient_update( target: Tensor, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: """Compute the bins to update the confusion matrix with for Pearson's Contingency Coefficient calculation. @@ -76,7 +76,7 @@ def pearsons_contingency_coefficient( preds: Tensor, target: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Pearson's Contingency Coefficient`_ for measuring the association between two categorical data series. @@ -131,7 +131,7 @@ def pearsons_contingency_coefficient( def pearsons_contingency_coefficient_matrix( matrix: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Pearson's Contingency Coefficient`_ statistic between a set of multiple variables. diff --git a/src/torchmetrics/functional/nominal/theils_u.py b/src/torchmetrics/functional/nominal/theils_u.py index a010ee48c1d..8bdaf38aa8a 100644 --- a/src/torchmetrics/functional/nominal/theils_u.py +++ b/src/torchmetrics/functional/nominal/theils_u.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -57,7 +57,7 @@ def _theils_u_update( target: Tensor, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: """Compute the bins to update the confusion matrix with for Theil's U calculation. @@ -109,7 +109,7 @@ def theils_u( preds: Tensor, target: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Theils Uncertainty coefficient`_ statistic measuring the association between two nominal data series. @@ -154,7 +154,7 @@ def theils_u( def theils_u_matrix( matrix: Tensor, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Theil's U`_ statistic between a set of multiple variables. diff --git a/src/torchmetrics/functional/nominal/tschuprows.py b/src/torchmetrics/functional/nominal/tschuprows.py index 1ee8756d243..2ea20d57f19 100644 --- a/src/torchmetrics/functional/nominal/tschuprows.py +++ b/src/torchmetrics/functional/nominal/tschuprows.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -34,7 +34,7 @@ def _tschuprows_t_update( target: Tensor, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: """Compute the bins to update the confusion matrix with for Tschuprow's T calculation. @@ -92,7 +92,7 @@ def tschuprows_t( target: Tensor, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Tschuprow's T`_ statistic measuring the association between two categorical (nominal) data series. @@ -148,7 +148,7 @@ def tschuprows_t_matrix( matrix: Tensor, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, ) -> Tensor: r"""Compute `Tschuprow's T`_ statistic between a set of multiple variables. diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py index c1baefb977f..258209326da 100644 --- a/src/torchmetrics/functional/nominal/utils.py +++ b/src/torchmetrics/functional/nominal/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import Tensor @@ -20,12 +20,12 @@ from torchmetrics.utilities.prints import rank_zero_warn -def _nominal_input_validation(nan_strategy: str, nan_replace_value: Optional[Union[int, float]]) -> None: +def _nominal_input_validation(nan_strategy: str, nan_replace_value: Optional[float]) -> None: if nan_strategy not in ["replace", "drop"]: raise ValueError( f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}" ) - if nan_strategy == "replace" and not isinstance(nan_replace_value, (int, float)): + if nan_strategy == "replace" and not isinstance(nan_replace_value, (float, int)): raise ValueError( "Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, " f"but got {nan_replace_value}" diff --git a/src/torchmetrics/functional/pairwise/minkowski.py b/src/torchmetrics/functional/pairwise/minkowski.py index 857e0d23943..298cedd1486 100644 --- a/src/torchmetrics/functional/pairwise/minkowski.py +++ b/src/torchmetrics/functional/pairwise/minkowski.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional import torch from torch import Tensor @@ -22,7 +22,7 @@ def _pairwise_minkowski_distance_update( - x: Tensor, y: Optional[Tensor] = None, exponent: Union[int, float] = 2, zero_diagonal: Optional[bool] = None + x: Tensor, y: Optional[Tensor] = None, exponent: float = 2, zero_diagonal: Optional[bool] = None ) -> Tensor: """Calculate the pairwise minkowski distance matrix. @@ -49,7 +49,7 @@ def _pairwise_minkowski_distance_update( def pairwise_minkowski_distance( x: Tensor, y: Optional[Tensor] = None, - exponent: Union[int, float] = 2, + exponent: float = 2, reduction: Literal["mean", "sum", "none", None] = None, zero_diagonal: Optional[bool] = None, ) -> Tensor: diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 45afe6dc178..7aa7a63743d 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -28,7 +28,7 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionles def __init__( self, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", **kwargs: Any, ) -> None: diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 77695b976de..2e22c34846a 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -75,7 +75,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): def __init__( self, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", **kwargs: Any, ) -> None: diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 2c4493e9877..f574320dc44 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -922,81 +922,67 @@ def __hash__(self) -> int: return hash(tuple(hash_vals)) - def __add__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __add__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the addition operator.""" return CompositionalMetric(torch.add, self, other) - def __and__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __and__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical and operator.""" return CompositionalMetric(torch.bitwise_and, self, other) - def __eq__( # type: ignore[override] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __eq__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[override] """Construct compositional metric using the equal operator.""" return CompositionalMetric(torch.eq, self, other) - def __floordiv__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __floordiv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the floor division operator.""" return CompositionalMetric(torch.floor_divide, self, other) - def __ge__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __ge__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the greater than or equal operator.""" return CompositionalMetric(torch.ge, self, other) - def __gt__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __gt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the greater than operator.""" return CompositionalMetric(torch.gt, self, other) - def __le__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __le__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the less than or equal operator.""" return CompositionalMetric(torch.le, self, other) - def __lt__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __lt__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the less than operator.""" return CompositionalMetric(torch.lt, self, other) - def __matmul__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __matmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the matrix multiplication operator.""" return CompositionalMetric(torch.matmul, self, other) - def __mod__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __mod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the remainder operator.""" return CompositionalMetric(torch.fmod, self, other) - def __mul__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __mul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the multiplication operator.""" return CompositionalMetric(torch.mul, self, other) - def __ne__( # type: ignore[override] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __ne__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[override] """Construct compositional metric using the not equal operator.""" return CompositionalMetric(torch.ne, self, other) - def __or__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __or__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical or operator.""" return CompositionalMetric(torch.bitwise_or, self, other) - def __pow__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __pow__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the exponential/power operator.""" return CompositionalMetric(torch.pow, self, other) - def __radd__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __radd__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the addition operator.""" return CompositionalMetric(torch.add, other, self) - def __rand__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __rand__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical and operator.""" # swap them since bitwise_and only supports that way and it's commutative return CompositionalMetric(torch.bitwise_and, self, other) @@ -1005,55 +991,47 @@ def __rfloordiv__(self, other: "CompositionalMetric") -> "Metric": """Construct compositional metric using the floor division operator.""" return CompositionalMetric(torch.floor_divide, other, self) - def __rmatmul__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __rmatmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the matrix multiplication operator.""" return CompositionalMetric(torch.matmul, other, self) - def __rmod__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __rmod__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the remainder operator.""" return CompositionalMetric(torch.fmod, other, self) - def __rmul__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __rmul__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the multiplication operator.""" return CompositionalMetric(torch.mul, other, self) - def __ror__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __ror__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical or operator.""" return CompositionalMetric(torch.bitwise_or, other, self) - def __rpow__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __rpow__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the exponential/power operator.""" return CompositionalMetric(torch.pow, other, self) - def __rsub__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __rsub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the subtraction operator.""" return CompositionalMetric(torch.sub, other, self) - def __rtruediv__( # type: ignore[misc] - self, other: Union["Metric", int, builtins.float, Tensor] - ) -> "CompositionalMetric": + def __rtruediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": # type: ignore[misc] """Construct compositional metric using the true divide operator.""" return CompositionalMetric(torch.true_divide, other, self) - def __rxor__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __rxor__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical xor operator.""" return CompositionalMetric(torch.bitwise_xor, other, self) - def __sub__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __sub__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the subtraction operator.""" return CompositionalMetric(torch.sub, self, other) - def __truediv__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __truediv__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the true divide operator.""" return CompositionalMetric(torch.true_divide, self, other) - def __xor__(self, other: Union["Metric", int, builtins.float, Tensor]) -> "CompositionalMetric": + def __xor__(self, other: Union["Metric", builtins.float, Tensor]) -> "CompositionalMetric": """Construct compositional metric using the logical xor operator.""" return CompositionalMetric(torch.bitwise_xor, self, other) @@ -1100,8 +1078,8 @@ class CompositionalMetric(Metric): def __init__( self, operator: Callable, - metric_a: Union[Metric, int, float, Tensor], - metric_b: Union[Metric, int, float, Tensor, None], + metric_a: Union[Metric, float, Tensor], + metric_b: Union[Metric, float, Tensor, None], ) -> None: """Class for creating compositions of metrics. diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index dff9e05f2e2..de48df1bbae 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -178,7 +178,7 @@ def __init__( "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", ] = "clip_iqa", - data_range: Union[int, float] = 1.0, + data_range: float = 1.0, prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), **kwargs: Any ) -> None: diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py index 1481a28a73d..2d780e7e379 100644 --- a/src/torchmetrics/nominal/cramers.py +++ b/src/torchmetrics/nominal/cramers.py @@ -91,7 +91,7 @@ def __init__( num_classes: int, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/nominal/pearson.py b/src/torchmetrics/nominal/pearson.py index 143e9d20411..2fc88c8e851 100644 --- a/src/torchmetrics/nominal/pearson.py +++ b/src/torchmetrics/nominal/pearson.py @@ -94,7 +94,7 @@ def __init__( self, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/nominal/theils_u.py b/src/torchmetrics/nominal/theils_u.py index f21a8efa385..f82c7658b1f 100644 --- a/src/torchmetrics/nominal/theils_u.py +++ b/src/torchmetrics/nominal/theils_u.py @@ -80,7 +80,7 @@ def __init__( self, num_classes: int, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/nominal/tschuprows.py b/src/torchmetrics/nominal/tschuprows.py index 540fd016fd6..a14832b4121 100644 --- a/src/torchmetrics/nominal/tschuprows.py +++ b/src/torchmetrics/nominal/tschuprows.py @@ -91,7 +91,7 @@ def __init__( num_classes: int, bias_correction: bool = True, nan_strategy: Literal["replace", "drop"] = "replace", - nan_replace_value: Optional[Union[int, float]] = 0.0, + nan_replace_value: Optional[float] = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index c2bf40a0fbf..09684c55919 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -91,9 +91,7 @@ def compute(self) -> Dict[str, Tensor]: """ val = self._base_metric.compute() if not self._is_suitable_val(val): - raise RuntimeError( - f"Returned value from base metric should be a scalar (int, float or tensor of size 1, but got {val}" - ) + raise RuntimeError(f"Returned value from base metric should be a float or scalar tensor, but got {val}.") self.max_val = val if self.max_val.to(val.device) < val else self.max_val.to(val.device) self.min_val = val if self.min_val.to(val.device) > val else self.min_val.to(val.device) return {"raw": val, "max": self.max_val, "min": self.min_val} @@ -108,7 +106,7 @@ def reset(self) -> None: self._base_metric.reset() @staticmethod - def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: + def _is_suitable_val(val: Union[float, Tensor]) -> bool: """Check whether min/max is a scalar value.""" if isinstance(val, (int, float)): return True diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index 313967cbf85..e87b4b63a0f 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import namedtuple from functools import partial -from typing import Union import pytest import torch @@ -44,7 +43,7 @@ def _baseline_ergas( preds: Tensor, target: Tensor, - ratio: Union[int, float] = 4, + ratio: float = 4, reduction: str = "elementwise_mean", ) -> Tensor: """Baseline implementation of Erreur Relative Globale Adimensionnelle de Synthèse.""" diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index ea3af7da2f7..90df27a3992 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -108,5 +108,5 @@ def test_no_scalar_compute() -> None: """Tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" min_max_nsm = MinMaxMetric(BinaryConfusionMatrix()) - with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"): + with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a float.*"): min_max_nsm.compute()