From 3a20bc399b7864ce436376f7a6178348fa347931 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 6 Oct 2023 16:34:37 +0200 Subject: [PATCH] Improve numeric stability of `LPIPS` (#2144) * improve stability * changelog (cherry picked from commit 1d102776be30175ecd466bfa116f0466f5cdf2ad) --- CHANGELOG.md | 2 +- src/torchmetrics/functional/image/lpips.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2600ee1526e..ea6cb7c67ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed numerical stability bug in `LearnedPerceptualImagePatchSimilarity` metric ([#2144](https://github.com/Lightning-AI/torchmetrics/pull/2144)) ## [1.2.0] - 2023-09-22 diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index d6480b4eea3..e6fa726f2ec 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -191,10 +191,10 @@ def _upsample(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor: return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) -def _normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor: - """Normalize tensors.""" - norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) - return in_feat / (norm_factor + eps) +def _normalize_tensor(in_feat: Tensor, eps: float = 1e-8) -> Tensor: + """Normalize input tensor.""" + norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / norm_factor def _resize_tensor(x: Tensor, size: int = 64) -> Tensor: