Merge branch 'master' into psnrb
mergify[bot] authored Apr 17, 2023
2 parents 36dfff0 + 7d34ce2 commit 8181e61
Showing 4 changed files with 49 additions and 88 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -115,6 +115,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allowed FID with `torch.float64` ([#1628](https://github.com/Lightning-AI/metrics/pull/1628))


- Changed FID matrix square root calculation from `scipy` to `torch` ([#1708](https://github.com/Lightning-AI/torchmetrics/pull/1708))

### Deprecated

- Deprecated domain metrics import from package root (
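
Note: the CHANGELOG entry above corresponds to the reformulation of `_compute_fid` in src/torchmetrics/image/fid.py below. It rests on the identity that the trace of the matrix square root of the product of the two covariance matrices equals the sum of the square roots of that product's eigenvalues,

    \operatorname{tr}\big((\Sigma_1 \Sigma_2)^{1/2}\big) = \sum_i \sqrt{\lambda_i(\Sigma_1 \Sigma_2)},

so the cross term can be evaluated with `torch.linalg.eigvals` instead of `scipy.linalg.sqrtm`.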
79 changes: 19 additions & 60 deletions src/torchmetrics/image/fid.py
@@ -23,11 +23,13 @@

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_info
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE, _TORCH_GREATER_EQUAL_1_9
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

__doctest_skip__ = ["FrechetInceptionDistance.__init__"] if not _TORCH_GREATER_EQUAL_1_9 else []

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["FrechetInceptionDistance.plot"]
__doctest_skip__ += ["FrechetInceptionDistance.plot"]

if _TORCH_FIDELITY_AVAILABLE:
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
@@ -43,9 +45,8 @@ class _FeatureExtractorInceptionV3(Module):

__doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"]


if _SCIPY_AVAILABLE:
import scipy
if not _TORCH_GREATER_EQUAL_1_9:
__doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"]


class NoTrainInceptionV3(_FeatureExtractorInceptionV3):
@@ -156,45 +157,7 @@ def forward(self, x: Tensor) -> Tensor:
return out[0].reshape(x.shape[0], -1)


class MatrixSquareRoot(Function):
"""Square root of a positive definite matrix.
All credit to `Square Root of a Positive Definite Matrix`_
"""

@staticmethod
def forward(ctx: Any, input_data: Tensor) -> Tensor:
"""Forward pass for the matrix square root."""
# TODO: update whenever pytorch gets an matrix square root function
# Issue: https://github.com/pytorch/pytorch/issues/9983
m = input_data.detach().cpu().numpy().astype(np.float_)
scipy_res, _ = scipy.linalg.sqrtm(m, disp=False)
sqrtm = torch.from_numpy(scipy_res.real).to(input_data)
ctx.save_for_backward(sqrtm)
return sqrtm

@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tensor:
"""Backward pass for matrix square root."""
if not ctx.needs_input_grad[0]:
return None
(sqrtm,) = ctx.saved_tensors
sqrtm = sqrtm.data.cpu().numpy().astype(np.float_)
gm = grad_output.data.cpu().numpy().astype(np.float_)

# Given a positive semi-definite matrix X,
# since X = X^{1/2}X^{1/2}, we can compute the gradient of the
# matrix square root dX^{1/2} by solving the Sylvester equation:
# dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}).
grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm)

return torch.from_numpy(grad_sqrtm).to(grad_output)


sqrtm = MatrixSquareRoot.apply


def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor, eps: float = 1e-6) -> Tensor:
def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor:
r"""Compute adjusted version of `Fid Score`_.
The Frechet Inception Distance between two multivariate Gaussians X_x ~ N(mu_1, sigm_1)
@@ -205,29 +168,22 @@ def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor, eps:
sigma1: covariance matrix over activations calculated on predicted (x) samples
mu2: mean of activations calculated on target (y) samples
sigma2: covariance matrix over activations calculated on target (y) samples
eps: offset constant - used if sigma_1 @ sigma_2 matrix is singular
Returns:
Scalar value of the distance between sets.
"""
diff = mu1 - mu2

covmean = sqrtm(sigma1.mm(sigma2))
# Product might be almost singular
if not torch.isfinite(covmean).all():
rank_zero_info(f"FID calculation produces singular product; adding {eps} to diagonal of covariance estimates")
offset = torch.eye(sigma1.size(0), device=mu1.device, dtype=mu1.dtype) * eps
covmean = sqrtm((sigma1 + offset).mm(sigma2 + offset))
a = (mu1 - mu2).square().sum(dim=-1)
b = sigma1.trace() + sigma2.trace()
c = torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1)

tr_covmean = torch.trace(covmean)
return diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean
return a + b - 2 * c


class FrechetInceptionDistance(Metric):
r"""Calculate Fréchet inception distance (FID_) which is used to access the quality of generated images.
.. math::
FID = |\mu - \mu_w| + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})
FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})
where :math:`\mathcal{N}(\mu, \Sigma)` is the multivariate normal distribution estimated from Inception v3
(`fid ref1`_) features calculated on real life images and :math:`\mathcal{N}(\mu_w, \Sigma_w)` is the
@@ -246,12 +202,10 @@ class FrechetInceptionDistance(Metric):
that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype`
method of the metric.
.. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install
torchmetrics[image]`` or ``pip install scipy``
.. note:: using this metric requires you to have torch 1.9 or higher installed
.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or
``pip install torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity``
As input to ``forward`` and ``update`` the metric accepts the following input
Expand All @@ -278,6 +232,8 @@ class FrechetInceptionDistance(Metric):
Raises:
ValueError:
If torch version is lower than 1.9
ModuleNotFoundError:
If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed
ValueError:
If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
@@ -322,6 +278,9 @@ def __init__(
) -> None:
super().__init__(**kwargs)

if not _TORCH_GREATER_EQUAL_1_9:
raise ValueError("FrechetInceptionDistance metric requires that PyTorch is version 1.9.0 or higher.")

if isinstance(feature, int):
num_features = feature
if not _TORCH_FIDELITY_AVAILABLE:
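
The following is a minimal sketch (not taken from the repository) of the numerical equivalence the new `_compute_fid` relies on; the helper names `trace_sqrtm_via_eigvals` and `random_cov` are illustrative, and the optional cross-check assumes `scipy` is available:

import torch

def trace_sqrtm_via_eigvals(sigma1: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
    # tr((sigma1 @ sigma2)^(1/2)) computed from eigenvalues, mirroring the new code path
    return torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1)

def random_cov(n: int) -> torch.Tensor:
    # symmetric positive semi-definite matrix built from random data
    data = torch.randn(2 * n, n, dtype=torch.float64)
    data = data - data.mean(dim=0)
    return data.T @ data

s1, s2 = random_cov(8), random_cov(8)
via_eig = trace_sqrtm_via_eigvals(s1, s2)

try:
    from scipy.linalg import sqrtm  # only needed for the cross-check
    via_scipy = torch.from_numpy(sqrtm((s1 @ s2).numpy()).real).trace()
    assert torch.allclose(via_eig, via_scipy)
except ImportError:
    pass

print(float(via_eig))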
54 changes: 26 additions & 28 deletions tests/unittests/image/test_fid.py
@@ -16,32 +16,16 @@

import pytest
import torch
from scipy.linalg import sqrtm as scipy_sqrtm
from torch.nn import Module
from torch.utils.data import Dataset

from torchmetrics.image.fid import FrechetInceptionDistance, sqrtm
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCH_GREATER_EQUAL_1_9

torch.manual_seed(42)


@pytest.mark.parametrize("matrix_size", [2, 10, 100, 500])
def test_matrix_sqrt(matrix_size):
"""Test that metrix sqrt function works as expected."""

def generate_cov(n):
data = torch.randn(2 * n, n)
return (data - data.mean(dim=0)).T @ (data - data.mean(dim=0))

cov1 = generate_cov(matrix_size)
cov2 = generate_cov(matrix_size)

scipy_res = scipy_sqrtm((cov1 @ cov2).numpy()).real
tm_res = sqrtm(cov1 @ cov2)
assert torch.allclose(torch.tensor(scipy_res).float().trace(), tm_res.trace())


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_no_train():
"""Assert that metric never leaves evaluation mode."""
@@ -60,6 +44,7 @@ def forward(self, x):
assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen"


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_fid_pickle():
"""Assert that we can initialize the metric and pickle it."""
@@ -73,21 +58,28 @@ def test_fid_pickle():

def test_fid_raises_errors_and_warnings():
"""Test that expected warnings and errors are raised."""
if _TORCH_FIDELITY_AVAILABLE:
with pytest.raises(ValueError, match="Integer input to argument `feature` must be one of .*"):
_ = FrechetInceptionDistance(feature=2)
if _TORCH_GREATER_EQUAL_1_9:
if _TORCH_FIDELITY_AVAILABLE:
with pytest.raises(ValueError, match="Integer input to argument `feature` must be one of .*"):
_ = FrechetInceptionDistance(feature=2)
else:
with pytest.raises(
ModuleNotFoundError,
match="FID metric requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`.",
):
_ = FrechetInceptionDistance()

with pytest.raises(TypeError, match="Got unknown input to argument `feature`"):
_ = FrechetInceptionDistance(feature=[1, 2])
else:
with pytest.raises(
ModuleNotFoundError,
match="FID metric requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`.",
ValueError, match="FrechetInceptionDistance metric requires that PyTorch is version 1.9.0 or higher."
):
_ = FrechetInceptionDistance()

with pytest.raises(TypeError, match="Got unknown input to argument `feature`"):
_ = FrechetInceptionDistance(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_fid_same_input(feature):
@@ -119,6 +111,7 @@ def __len__(self):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("equal_size", [False, True])
def test_compare_fid(tmpdir, equal_size, feature=768):
@@ -156,6 +149,7 @@ def test_compare_fid(tmpdir, equal_size, feature=768):
assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid["frechet_inception_distance"]]), atol=1e-3)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
@pytest.mark.parametrize("reset_real_features", [True, False])
def test_reset_real_features_arg(reset_real_features):
"""Test that `reset_real_features` argument works as expected."""
@@ -185,6 +179,7 @@ def test_reset_real_features_arg(reset_real_features):
assert metric.real_features_cov_sum.shape == torch.Size([64, 64])


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_normalize_arg_true():
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
@@ -193,6 +188,7 @@ def test_normalize_arg_true():
metric.update(img, real=True)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_normalize_arg_false():
"""Test that normalize argument works as expected."""
img = torch.rand(2, 3, 299, 299)
@@ -201,6 +197,7 @@ def test_normalize_arg_false():
metric.update(img, real=True)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_not_enough_samples():
"""Test that an error is raised if not enough samples were provided."""
img = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
@@ -213,6 +210,7 @@ def test_not_enough_samples():
metric.compute()


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9")
def test_dtype_transfer_to_submodule():
"""Test that change in dtype also changes the default inception net."""
imgs = torch.randn(1, 3, 256, 256)
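
For orientation, a minimal usage sketch of the metric these tests exercise (not taken from the test file; it assumes torch>=1.9 and `torch-fidelity` are installed, and the batch size and `feature=64` are arbitrary choices):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

metric = FrechetInceptionDistance(feature=64)
real = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
metric.update(real, real=True)   # accumulate statistics for real images
metric.update(fake, real=False)  # accumulate statistics for generated images
print(metric.compute())          # scalar tensor with the FID value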
2 changes: 2 additions & 0 deletions tests/unittests/utilities/test_plot.py
@@ -158,6 +158,7 @@
WordInfoLost,
WordInfoPreserved,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9
from torchmetrics.wrappers import BootStrapper, ClasswiseWrapper, MetricTracker, MinMaxMetric, MultioutputWrapper

_rand_input = lambda: torch.rand(10)
@@ -614,6 +615,7 @@ def test_plot_methods(metric_class: object, preds: Callable, target: Callable, n
lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8),
False,
id="frechet inception distance",
marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9"),
),
pytest.param(
partial(InceptionScore, feature=64),
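
The hunk above gates a single parametrized case on the torch version. A standalone sketch of that pytest pattern, with `_TORCH_OK` and `test_example` as placeholder names:

import pytest

_TORCH_OK = True  # placeholder for a flag such as _TORCH_GREATER_EQUAL_1_9

@pytest.mark.parametrize(
    "value",
    [
        pytest.param(1, id="always runs"),
        pytest.param(
            2,
            id="version gated",
            marks=pytest.mark.skipif(not _TORCH_OK, reason="test requires torch>=1.9"),
        ),
    ],
)
def test_example(value):
    assert value in (1, 2)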
