diff --git a/CHANGELOG.md b/CHANGELOG.md index 7eee849cf93..2600ee1526e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - ### Changed -- +- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089)) ### Removed diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index f699673742a..47d4839c5e2 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, List, Optional, Sequence, Union -from torch import Tensor +from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.image.sam import _sam_compute, _sam_update @@ -75,33 +75,50 @@ class SpectralAngleMapper(Metric): preds: List[Tensor] target: List[Tensor] + sum_sam: Tensor + numel: Tensor def __init__( self, - reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", + reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean", **kwargs: Any, ) -> None: super().__init__(**kwargs) - rank_zero_warn( - "Metric `SpectralAngleMapper` will save all targets and predictions in the buffer." - " For large datasets, this may lead to a large memory footprint." - ) - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") + if reduction not in ("elementwise_mean", "sum", "none", None): + raise ValueError( + f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None." + ) + if reduction == "none" or reduction is None: + rank_zero_warn( + "Metric `SpectralAngleMapper` will save all targets and predictions in the buffer when using" + "`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint." + ) + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.add_state("sum_sam", tensor(0.0), dist_reduce_fx="sum") + self.add_state("numel", tensor(0), dist_reduce_fx="sum") self.reduction = reduction def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" preds, target = _sam_update(preds, target) - self.preds.append(preds) - self.target.append(target) + if self.reduction == "none" or self.reduction is None: + self.preds.append(preds) + self.target.append(target) + else: + sam_score = _sam_compute(preds, target, reduction="sum") + self.sum_sam += sam_score + p_shape = preds.shape + self.numel += p_shape[0] * p_shape[2] * p_shape[3] def compute(self) -> Tensor: """Compute spectra over state.""" - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - return _sam_compute(preds, target, self.reduction) + if self.reduction == "none" or self.reduction is None: + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + return _sam_compute(preds, target, self.reduction) + return self.sum_sam / self.numel if self.reduction == "elementwise_mean" else self.sum_sam def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index 92c15f23e92..3fea5e8986f 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, List, Optional, Sequence, Union -from torch import Tensor +from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update @@ -73,6 +73,8 @@ class UniversalImageQualityIndex(Metric): preds: List[Tensor] target: List[Tensor] + sum_uqi: Tensor + numel: Tensor def __init__( self, @@ -82,14 +84,20 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - rank_zero_warn( - "Metric `UniversalImageQualityIndex` will save all targets and" - " predictions in buffer. For large datasets this may lead" - " to large memory footprint." - ) - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") + if reduction not in ("elementwise_mean", "sum", "none", None): + raise ValueError( + f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None." + ) + if reduction is None or reduction == "none": + rank_zero_warn( + "Metric `UniversalImageQualityIndex` will save all targets and predictions in the buffer when using" + "`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint." + ) + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.add_state("sum_uqi", tensor(0.0), dist_reduce_fx="sum") + self.add_state("numel", tensor(0), dist_reduce_fx="sum") self.kernel_size = kernel_size self.sigma = sigma self.reduction = reduction @@ -97,14 +105,22 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" preds, target = _uqi_update(preds, target) - self.preds.append(preds) - self.target.append(target) + if self.reduction is None or self.reduction == "none": + self.preds.append(preds) + self.target.append(target) + else: + uqi_score = _uqi_compute(preds, target, self.kernel_size, self.sigma, reduction="sum") + self.sum_uqi += uqi_score + ps = preds.shape + self.numel += ps[0] * ps[1] * (ps[2] - self.kernel_size[0] + 1) * (ps[3] - self.kernel_size[1] + 1) def compute(self) -> Tensor: """Compute explained variance over state.""" - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction) + if self.reduction == "none" or self.reduction is None: + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction) + return self.sum_uqi / self.numel if self.reduction == "elementwise_mean" else self.sum_uqi def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None