From 04c371fbbe6ed264e343fbdd2733f5e0d00070a8 Mon Sep 17 00:00:00 2001 From: Ben Zickel <35469979+BenZickel@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:57:45 +0300 Subject: [PATCH] Add batched calculation option to `energy_score_empirical` in order to reduce memory consumption (#3402) * Add batched calculation option to energy_score_empirical in order to reduce memory consumption. * Replace native Python sum with torch stack(...).sum(). --------- Co-authored-by: Ben Zickel --- pyro/ops/stats.py | 50 ++++++++++++++++++++++++++++++++++++----- tests/ops/test_stats.py | 23 +++++++++++++++++++ 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index a0a546059a..73f6054e5a 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -3,7 +3,7 @@ import math import numbers -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch.fft import irfft, rfft @@ -510,7 +510,9 @@ def crps_empirical(pred, truth): return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2 -def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor: +def energy_score_empirical( + pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None +) -> torch.Tensor: r""" Computes negative Energy Score ES* (see equation 22 in [1]) between a set of multivariate samples ``pred`` and a true data vector ``truth``. Running time @@ -538,6 +540,8 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten The leftmost dim is that of the multivariate sample. :param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except for the second leftmost dim which can have any value or be omitted. + :param int pred_batch_size: If specified the predictions will be batched before calculation + according to the specified batch size in order to reduce memory consumption. :return: A tensor of shape ``truth.shape``. :rtype: torch.Tensor """ @@ -552,10 +556,44 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten "Actual shapes: {} versus {}".format(pred.shape, truth.shape) ) - retval = ( - torch.cdist(pred, truth).mean(dim=-2) - - 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None] - ) + if pred_batch_size is None: + retval = ( + torch.cdist(pred, truth).mean(dim=-2) + - 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None] + ) + else: + # Divide predictions into batches + pred_len = pred.shape[-2] + pred_batches = [] + while pred.numel() > 0: + pred_batches.append(pred[..., :pred_batch_size, :]) + pred = pred[..., pred_batch_size:, :] + # Calculate predictions distance to truth + retval = ( + torch.stack( + [ + torch.cdist(pred_batch, truth).sum(dim=-2) + for pred_batch in pred_batches + ], + dim=0, + ).sum(dim=0) + / pred_len + ) + # Calculate predictions self distance + for aux_pred_batch in pred_batches: + retval = ( + retval + - 0.5 + * torch.stack( + [ + torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2]) + for pred_batch in pred_batches + ], + dim=0, + ).sum(dim=0)[..., None] + / pred_len + / pred_len + ) if remove_leftmost_dim: retval = retval[..., 0] diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index 41f7ba3c8c..cae8ef5aba 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -355,3 +355,26 @@ def test_multivariate_energy_score(sample_dim, num_samples=10000): rtol=0.02, ) assert energy_score * 1.02 < energy_score_uncorrelated + + +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)]) +@pytest.mark.parametrize("sample_dim", [30, 100]) +@pytest.mark.parametrize( + "num_samples, pred_batch_size", [(100, 10), (100, 30), (100, 100), (100, 200)] +) +def test_energy_score_empirical_batched_calculation( + batch_shape, sample_dim, num_samples, pred_batch_size +): + # Generate data + truth = torch.randn(batch_shape + (sample_dim,)) + pred = torch.randn(batch_shape + (num_samples, sample_dim)) + # Do batched and regular calculation + expected = energy_score_empirical(pred, truth) + actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size) + # Check accuracy + assert_close(actual, expected) + + +def test_jit_compilation(): + # Test that functions can be JIT compiled + torch.jit.script(energy_score_empirical)