From 9908fa4a2e0dd88c0d17eeac51b51f085c3ed8d9 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 22 Apr 2024 00:43:52 +0300 Subject: [PATCH 1/5] Added the multivariate energy score metric along with tests for its equivalence to the CRPS metric in the univariate case. --- pyro/ops/stats.py | 51 +++++++++++++++++++++++++++++++++++++++++ tests/ops/test_stats.py | 15 ++++++++++++ 2 files changed, 66 insertions(+) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 8e0bd2631f..3921103028 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -508,3 +508,54 @@ def crps_empirical(pred, truth): weight = weight.reshape(weight.shape + (1,) * (diff.dim() - 1)) return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2 + + +def energy_score_empirical(pred, truth): + """ + Computes negative Energy Score ES* [1] between a + set of multivariate samples ``pred`` and a true data vector ``truth``. Running time + is quadratic in the number of samples ``n``. In case of univariate samples + the output coincides with the CRPS:: + + ES* = E|pred - truth| - 1/2 E|pred - pred'| + + Note that for a single sample this reduces to the Euclidean norm of the difference between + the sample ``pred`` and the ``truth``. + + This is a strictly proper metric so that for ``pred`` distirbuted according to a + distribution :math:`P` and ``truth`` distributed according to a distribution :math:`Q` + we have :math:`ES(P,Q) \ge ES(Q,Q)` with equality holding if and only if :math:`P=Q`. + + **References** + + [1] Tilmann Gneiting, Adrian E. Raftery (2007) + `Strictly Proper Scoring Rules, Prediction, and Estimation` + https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf + + :param torch.Tensor pred: A set of sample predictions batched on the second leftmost dim. + The leftmost dim is that of the multivariate sample. + :param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except + for the second leftmost dim which can have any value or be omitted. + :return: A tensor of shape ``truth.shape``. + :rtype: torch.Tensor + """ + if pred.dim() == (truth.dim() + 1): + remove_leftmost_dim = True + truth = truth[..., None, :] + elif pred.dim() == truth.dim(): + remove_leftmost_dim = False + else: + raise ValueError( + "Expected pred to have at most one extra dim versus truth." + "Actual shapes: {} versus {}".format(pred.shape, truth.shape) + ) + + retval = ( + torch.cdist(pred, truth).mean(dim=-2) + - 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None] + ) + + if remove_leftmost_dim: + retval = retval[..., 0] + + return retval diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index 4346be5feb..c6b9d6aa2a 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -12,6 +12,7 @@ autocovariance, crps_empirical, effective_sample_size, + energy_score_empirical, fit_generalized_pareto, gelman_rubin, hpdi, @@ -324,3 +325,17 @@ def test_crps_empirical(num_samples, event_shape): pred - pred.unsqueeze(1) ).abs().mean([0, 1]) assert_close(actual, expected) + + +@pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)]) +@pytest.mark.parametrize("num_samples", [1, 2, 3, 4, 10]) +def test_energy_score_empirical(num_samples, event_shape): + truth = torch.randn(event_shape) + pred = truth + 0.1 * torch.randn((num_samples,) + event_shape) + + actual = crps_empirical(pred, truth) + expected = energy_score_empirical( + pred[..., None].swapaxes(0, -1)[0, ..., None], truth[..., None] + ) + + assert_close(actual, expected) From a0c85183bf756a88fffcabf5509ec78510c89a1f Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 22 Apr 2024 01:42:07 +0300 Subject: [PATCH 2/5] Combined tests for CRPS and univariate energy score. --- tests/ops/test_stats.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index c6b9d6aa2a..1c88614bb2 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -314,7 +314,7 @@ def test_fit_generalized_pareto(k, sigma, n_samples=5000): @pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)]) @pytest.mark.parametrize("num_samples", [1, 2, 3, 4, 10]) -def test_crps_empirical(num_samples, event_shape): +def test_crps_univariate_energy_score_empirical(num_samples, event_shape): truth = torch.randn(event_shape) pred = truth + 0.1 * torch.randn((num_samples,) + event_shape) @@ -326,16 +326,7 @@ def test_crps_empirical(num_samples, event_shape): ).abs().mean([0, 1]) assert_close(actual, expected) - -@pytest.mark.parametrize("event_shape", [(), (4,), (3, 2)]) -@pytest.mark.parametrize("num_samples", [1, 2, 3, 4, 10]) -def test_energy_score_empirical(num_samples, event_shape): - truth = torch.randn(event_shape) - pred = truth + 0.1 * torch.randn((num_samples,) + event_shape) - - actual = crps_empirical(pred, truth) expected = energy_score_empirical( pred[..., None].swapaxes(0, -1)[0, ..., None], truth[..., None] ) - assert_close(actual, expected) From 97394b2ae66367016890a231a1495bb05ccf1e55 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 22 Apr 2024 11:08:14 +0300 Subject: [PATCH 3/5] Added test for the multivariate energy score. --- tests/ops/test_stats.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index 1c88614bb2..e5842ed953 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -330,3 +330,24 @@ def test_crps_univariate_energy_score_empirical(num_samples, event_shape): pred[..., None].swapaxes(0, -1)[0, ..., None], truth[..., None] ) assert_close(actual, expected) + + +@pytest.mark.parametrize("sample_dim", [3, 10, 30, 100]) +def test_multivariate_energy_score(sample_dim, num_samples = 10000): + pred_uncorrelated = torch.randn(num_samples, sample_dim) + + pred = torch.randn(num_samples, 1) + pred = pred.expand(pred_uncorrelated.shape) + + truth = torch.randn(num_samples, 1) + truth = truth.expand(pred_uncorrelated.shape) + + energy_score = energy_score_empirical(pred, truth).mean() + energy_score_uncorrelated = energy_score_empirical(pred_uncorrelated, truth).mean() + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + from scipy.stats import chi + + assert_close(energy_score, torch.tensor(0.5*chi(1).mean()*(2*sample_dim)**0.5), rtol=0.02) + assert energy_score * 1.02 < energy_score_uncorrelated From 46483316d4783f078810bbc4a61eed3e1245860b Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 22 Apr 2024 15:41:34 +0300 Subject: [PATCH 4/5] Fixed docs and linting and added typing for pyro.ops.stats.energy_score_empirical. --- pyro/ops/stats.py | 10 ++++++---- tests/ops/test_stats.py | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 3921103028..43ab876f0c 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -510,9 +510,9 @@ def crps_empirical(pred, truth): return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2 -def energy_score_empirical(pred, truth): +def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor: """ - Computes negative Energy Score ES* [1] between a + Computes negative Energy Score ES* (see equation 22 in [1]) between a set of multivariate samples ``pred`` and a true data vector ``truth``. Running time is quadratic in the number of samples ``n``. In case of univariate samples the output coincides with the CRPS:: @@ -522,9 +522,11 @@ def energy_score_empirical(pred, truth): Note that for a single sample this reduces to the Euclidean norm of the difference between the sample ``pred`` and the ``truth``. - This is a strictly proper metric so that for ``pred`` distirbuted according to a + This is a strictly proper score so that for ``pred`` distirbuted according to a distribution :math:`P` and ``truth`` distributed according to a distribution :math:`Q` - we have :math:`ES(P,Q) \ge ES(Q,Q)` with equality holding if and only if :math:`P=Q`. + we have :math:`ES*(P,Q) \ge ES*(Q,Q)` with equality holding if and only if :math:`P=Q', i.e. + if :math:`P` and :math:`Q` have the same multivariate distribution (it is not sufficient for + :math:`P` and :math:`Q` to have the same marginals in order for equality to hold). **References** diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index e5842ed953..41f7ba3c8c 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -333,7 +333,7 @@ def test_crps_univariate_energy_score_empirical(num_samples, event_shape): @pytest.mark.parametrize("sample_dim", [3, 10, 30, 100]) -def test_multivariate_energy_score(sample_dim, num_samples = 10000): +def test_multivariate_energy_score(sample_dim, num_samples=10000): pred_uncorrelated = torch.randn(num_samples, sample_dim) pred = torch.randn(num_samples, 1) @@ -349,5 +349,9 @@ def test_multivariate_energy_score(sample_dim, num_samples = 10000): warnings.filterwarnings("ignore", category=RuntimeWarning) from scipy.stats import chi - assert_close(energy_score, torch.tensor(0.5*chi(1).mean()*(2*sample_dim)**0.5), rtol=0.02) + assert_close( + energy_score, + torch.tensor(0.5 * chi(1).mean() * (2 * sample_dim) ** 0.5), + rtol=0.02, + ) assert energy_score * 1.02 < energy_score_uncorrelated From 8ff9be068cdbbb61e5e3f88e7a61e6a0f604406d Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 22 Apr 2024 16:07:38 +0300 Subject: [PATCH 5/5] Fix math in docs. --- pyro/ops/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 43ab876f0c..efa60134e5 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -524,7 +524,7 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten This is a strictly proper score so that for ``pred`` distirbuted according to a distribution :math:`P` and ``truth`` distributed according to a distribution :math:`Q` - we have :math:`ES*(P,Q) \ge ES*(Q,Q)` with equality holding if and only if :math:`P=Q', i.e. + we have :math:`ES^{*}(P,Q) \ge ES^{*}(Q,Q)` with equality holding if and only if :math:`P=Q`, i.e. if :math:`P` and :math:`Q` have the same multivariate distribution (it is not sufficient for :math:`P` and :math:`Q` to have the same marginals in order for equality to hold).