typing 1/n adjustment for base class (#1542)
* adjust typing for base class
* Apply suggestions from code review
* fix

---------

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Feb 22, 2023
1 parent 6c3e1a5 commit 2322414
Showing 2 changed files with 84 additions and 80 deletions.
pyproject.toml (1 change: 0 additions & 1 deletion)
@@ -208,7 +208,6 @@ module = [
     "torchmetrics.image.ssim",
     "torchmetrics.image.tv",
     "torchmetrics.image.uqi",
-    "torchmetrics.metric",
     "torchmetrics.regression.kl_divergence",
     "torchmetrics.regression.log_mse",
     "torchmetrics.regression.mae",
src/torchmetrics/metric.py (163 changes: 84 additions & 79 deletions)
@@ -126,8 +126,8 @@ def __init__(

         # initialize
         self._update_signature = inspect.signature(self.update)
-        self.update: Callable = self._wrap_update(self.update)
-        self.compute: Callable = self._wrap_compute(self.compute)
+        self.update: Callable = self._wrap_update(self.update)  # type: ignore[assignment]
+        self.compute: Callable = self._wrap_compute(self.compute)  # type: ignore[assignment]
         self._computed = None
         self._forward_cache = None
         self._update_count = 0
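An aside on the `# type: ignore[assignment]` additions: mypy flags an instance attribute that shadows a method of the same name, which is exactly what wrapping `update`/`compute` per instance does. A minimal sketch of the pattern (names are illustrative, not torchmetrics internals):

import functools
from typing import Any, Callable


class Base:
    def update(self, x: int) -> None:
        """The method mypy sees on the class."""

    def __init__(self) -> None:
        # Shadowing a method with an attribute is rejected by mypy
        # ([assignment] in mypy <1.0, [method-assign] in newer releases).
        self.update: Callable[..., Any] = self._wrap(self.update)  # type: ignore[assignment]

    @staticmethod
    def _wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # a real wrapper adds bookkeeping; torchmetrics tracks update counts here
            return fn(*args, **kwargs)

        return wrapped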
@@ -359,9 +359,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
             reduced = torch.stack([global_state, local_state])
         elif reduce_fn is None and isinstance(global_state, list):
             reduced = _flatten([global_state, local_state])
-        else:
+        elif reduce_fn and callable(reduce_fn):
             reduced = reduce_fn(torch.stack([global_state, local_state]))
-
+        else:
+            raise TypeError(f"Unsupported reduce_fn: {reduce_fn}")
         setattr(self, attr, reduced)

     def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
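For orientation, a simplified sketch of the dispatch this hunk tightens; it is illustrative only (the real `_reduce_states` loops over registered state attributes, and the guard on the first tensor branch is outside this excerpt):

from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor


def reduce_pair(
    global_state: Union[Tensor, List],
    local_state: Union[Tensor, List],
    reduce_fn: Optional[Callable],
) -> Any:
    if reduce_fn is None and isinstance(global_state, Tensor):
        return torch.stack([global_state, local_state])  # keep both, unreduced
    if reduce_fn is None and isinstance(global_state, list):
        return [*global_state, *local_state]  # list states get flattened together
    if reduce_fn is not None and callable(reduce_fn):
        return reduce_fn(torch.stack([global_state, local_state]))
    # the branch added by this commit: fail immediately on anything else
    raise TypeError(f"Unsupported reduce_fn: {reduce_fn}")


print(reduce_pair(torch.tensor(2.0), torch.tensor(3.0), torch.sum))  # tensor(5.)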
@@ -597,8 +598,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         # manually restore update and compute functions for pickling
         self.__dict__.update(state)
         self._update_signature = inspect.signature(self.update)
-        self.update: Callable = self._wrap_update(self.update)
-        self.compute: Callable = self._wrap_compute(self.compute)
+        self.update: Callable = self._wrap_update(self.update)  # type: ignore[assignment]
+        self.compute: Callable = self._wrap_compute(self.compute)  # type: ignore[assignment]

     def __setattr__(self, name: str, value: Any) -> None:
         """Overwrite default method to prevent specific attributes from being set by user."""
@@ -685,12 +686,12 @@ def persistent(self, mode: bool = False) -> None:
         for key in self._persistent:
             self._persistent[key] = mode

-    def state_dict(
+    def state_dict(  # type: ignore[override]  # todo
         self,
         destination: Dict[str, Any] = None,
         prefix: str = "",
         keep_vars: bool = False,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Dict[str, Any]:
         """Get the current state of metric as an dictionary.

         Args:
@@ -700,7 +701,9 @@ def state_dict(
             keep_vars: by default the :class:`~torch.Tensor`s returned in the state dict are detached from autograd.
                 If set to ``True``, detaching will not be performed.
         """
-        destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+        destination: Dict[str, Union[torch.Tensor, List, Any]] = super().state_dict(
+            destination=destination, prefix=prefix, keep_vars=keep_vars  # type: ignore[arg-type]
+        )
         # Register metric states to be part of the state_dict
         for key in self._defaults:
             if not self._persistent[key]:
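On why the override is silenced rather than re-modelled: `torch.nn.Module.state_dict` is declared with overloads whose return type depends on `destination`, while `Metric` always returns a plain dict, hence the `# type: ignore[override]` and the trailing `# todo`. A short usage sketch (metric choice illustrative; metric states only enter the state dict once marked persistent):

import torch
from torchmetrics.classification import BinaryAccuracy

metric = BinaryAccuracy()
metric.persistent(True)  # opt the metric states into the state dict
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

sd = metric.state_dict()  # typed as Dict[str, Any] after this change
print(sorted(sd))  # e.g. ['fn', 'fp', 'tn', 'tp']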
@@ -778,147 +781,149 @@ def __hash__(self) -> int:

         return hash(tuple(hash_vals))

-    def __add__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the addition operator."""
+    def __add__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the addition operator."""
         return CompositionalMetric(torch.add, self, other)

-    def __and__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical and operator."""
+    def __and__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical and operator."""
         return CompositionalMetric(torch.bitwise_and, self, other)

-    def __eq__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the equal operator."""
+    def __eq__(self, other: "Metric") -> "CompositionalMetric":  # type: ignore[override]
+        """Construct compositional metric using the equal operator."""
         return CompositionalMetric(torch.eq, self, other)

-    def __floordiv__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the floor division operator."""
+    def __floordiv__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the floor division operator."""
         return CompositionalMetric(torch.floor_divide, self, other)

-    def __ge__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the greater than or equal operator."""
+    def __ge__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the greater than or equal operator."""
         return CompositionalMetric(torch.ge, self, other)

-    def __gt__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the greater than operator."""
+    def __gt__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the greater than operator."""
         return CompositionalMetric(torch.gt, self, other)

-    def __le__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the less than or equal operator."""
+    def __le__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the less than or equal operator."""
         return CompositionalMetric(torch.le, self, other)

-    def __lt__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the less than operator."""
+    def __lt__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the less than operator."""
         return CompositionalMetric(torch.lt, self, other)

-    def __matmul__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the matrix multiplication operator."""
+    def __matmul__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the matrix multiplication operator."""
         return CompositionalMetric(torch.matmul, self, other)

-    def __mod__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the remainder operator."""
+    def __mod__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the remainder operator."""
         return CompositionalMetric(torch.fmod, self, other)

-    def __mul__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the multiplication operator."""
+    def __mul__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the multiplication operator."""
         return CompositionalMetric(torch.mul, self, other)

     # Fixme: this shall return bool instead of Metric
-    def __ne__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the not equal operator."""
+    def __ne__(self, other: "Metric") -> "CompositionalMetric":  # type: ignore[override]
+        """Construct compositional metric using the not equal operator."""
         return CompositionalMetric(torch.ne, self, other)

-    def __or__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical or operator."""
+    def __or__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical or operator."""
         return CompositionalMetric(torch.bitwise_or, self, other)

-    def __pow__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the exponential/power operator."""
+    def __pow__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the exponential/power operator."""
         return CompositionalMetric(torch.pow, self, other)

-    def __radd__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the addition operator."""
+    def __radd__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the addition operator."""
         return CompositionalMetric(torch.add, other, self)

-    def __rand__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical and operator."""
+    def __rand__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical and operator."""
         # swap them since bitwise_and only supports that way and it's commutative
         return CompositionalMetric(torch.bitwise_and, self, other)

-    def __rfloordiv__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the floor division operator."""
+    def __rfloordiv__(self, other: "CompositionalMetric") -> "Metric":
+        """Construct compositional metric using the floor division operator."""
         return CompositionalMetric(torch.floor_divide, other, self)

-    def __rmatmul__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the matrix multiplication operator."""
+    def __rmatmul__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the matrix multiplication operator."""
        return CompositionalMetric(torch.matmul, other, self)

-    def __rmod__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the remainder operator."""
+    def __rmod__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the remainder operator."""
         return CompositionalMetric(torch.fmod, other, self)

-    def __rmul__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the multiplication operator."""
+    def __rmul__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the multiplication operator."""
         return CompositionalMetric(torch.mul, other, self)

-    def __ror__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical or operator."""
+    def __ror__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical or operator."""
         return CompositionalMetric(torch.bitwise_or, other, self)

-    def __rpow__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the exponential/power operator."""
+    def __rpow__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the exponential/power operator."""
         return CompositionalMetric(torch.pow, other, self)

-    def __rsub__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the subtraction operator."""
+    def __rsub__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the subtraction operator."""
         return CompositionalMetric(torch.sub, other, self)

-    def __rtruediv__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the true divide operator."""
+    def __rtruediv__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the true divide operator."""
         return CompositionalMetric(torch.true_divide, other, self)

-    def __rxor__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical xor operator."""
+    def __rxor__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical xor operator."""
         return CompositionalMetric(torch.bitwise_xor, other, self)

-    def __sub__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the subtraction operator."""
+    def __sub__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the subtraction operator."""
         return CompositionalMetric(torch.sub, self, other)

-    def __truediv__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the true divide operator."""
+    def __truediv__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the true divide operator."""
         return CompositionalMetric(torch.true_divide, self, other)

-    def __xor__(self, other: "Metric") -> "Metric":
-        """Construct conpositional metric using the logical xor operator."""
+    def __xor__(self, other: "Metric") -> "CompositionalMetric":
+        """Construct compositional metric using the logical xor operator."""
         return CompositionalMetric(torch.bitwise_xor, self, other)

-    def __abs__(self) -> "Metric":
-        """Construct conpositional metric using the absolute operator."""
+    def __abs__(self) -> "CompositionalMetric":
+        """Construct compositional metric using the absolute operator."""
         return CompositionalMetric(torch.abs, self, None)

-    def __inv__(self) -> "Metric":
-        """Construct conpositional metric using the not operator."""
+    def __inv__(self) -> "CompositionalMetric":
+        """Construct compositional metric using the not operator."""
         return CompositionalMetric(torch.bitwise_not, self, None)

-    def __invert__(self) -> "Metric":
-        """Construct conpositional metric using the not operator."""
+    def __invert__(self) -> "CompositionalMetric":
+        """Construct compositional metric using the not operator."""
         return self.__inv__()

-    def __neg__(self) -> "Metric":
-        """Construct conpositional metric using absolute negative operator."""
+    def __neg__(self) -> "CompositionalMetric":
+        """Construct compositional metric using absolute negative operator."""
         return CompositionalMetric(_neg, self, None)

-    def __pos__(self) -> "Metric":
-        """Construct conpositional metric using absolute operator."""
+    def __pos__(self) -> "CompositionalMetric":
+        """Construct compositional metric using absolute operator."""
         return CompositionalMetric(torch.abs, self, None)

-    def __getitem__(self, idx: int) -> "Metric":
-        """Construct conpositional metric using the get item operator."""
+    def __getitem__(self, idx: int) -> "CompositionalMetric":
+        """Construct compositional metric using the get item operator."""
         return CompositionalMetric(lambda x: x[idx], self, None)

     def __getnewargs__(self) -> Tuple:
-        """Needede method for construction of new metrics __new__ method."""
-        return (Metric.__str__(self),)
+        """Needed method for construction of new metrics __new__ method."""
+        return tuple(
+            Metric.__str__(self),
+        )

     __iter__ = None
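These overloads are what make metric arithmetic such as `acc + prec` work lazily; the commit only narrows the declared return type from `Metric` to `CompositionalMetric` so type checkers see what the calls actually produce. A short usage sketch (metrics and values are illustrative):

import torch
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

acc, prec = BinaryAccuracy(), BinaryPrecision()
combined = acc + prec  # a CompositionalMetric, now also at the type level

preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 0, 0, 1])
acc.update(preds, target)
prec.update(preds, target)

print(combined.compute())  # tensor(1.4167): 0.75 accuracy + 0.6667 precision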
