Skip to content

Commit

Permalink
Fix device and dtype for LPIPS functional metric (#2234)
Browse files Browse the repository at this point in the history
(cherry picked from commit a57dfae)
  • Loading branch information
SkafteNicki authored and Borda committed Nov 30, 2023
1 parent ab8af8b commit a1ab373
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222))


- Fix device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234))


## [1.2.0] - 2023-09-22

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,6 @@ def learned_perceptual_image_patch_similarity(
tensor(0.1008, grad_fn=<DivBackward0>)
"""
net = _NoTrainLpips(net=net_type)
net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
loss, total = _lpips_update(img1, img2, net, normalize)
return _lpips_compute(loss.sum(), total, reduction)
11 changes: 11 additions & 0 deletions tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from lpips import LPIPS as LPIPS_reference # noqa: N811
from torch import Tensor
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_GREATER_EQUAL_1_9

Expand Down Expand Up @@ -68,6 +69,16 @@ def test_lpips(self, net_type, ddp):
metric_args={"net_type": net_type},
)

def test_lpips_functional(self):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=_inputs.img1,
target=_inputs.img2,
metric_functional=learned_perceptual_image_patch_similarity,
reference_metric=partial(_compare_fn, net_type="alex"),
metric_args={"net_type": "alex"},
)

def test_lpips_differentiability(self):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
Expand Down

0 comments on commit a1ab373

Please sign in to comment.