diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py
index d25cf16680..ca088645cb 100644
--- a/pyro/infer/importance.py
+++ b/pyro/infer/importance.py
@@ -3,6 +3,7 @@
 
 import math
 import warnings
+from typing import List, Union
 
 import torch
 
@@ -15,45 +16,12 @@
 from .util import plate_log_prob_sum
 
 
-class Importance(TracePosterior):
+class LogWeightsMixin:
     """
-    :param model: probabilistic model defined as a function
-    :param guide: guide used for sampling defined as a function
-    :param num_samples: number of samples to draw from the guide (default 10)
-
-    This method performs posterior inference by importance sampling
-    using the guide as the proposal distribution.
-    If no guide is provided, it defaults to proposing from the model's prior.
+    Mixin class to compute analytics from a ``.log_weights`` attribute.
     """
 
-    def __init__(self, model, guide=None, num_samples=None):
-        """
-        Constructor. default to num_samples = 10, guide = model
-        """
-        super().__init__()
-        if num_samples is None:
-            num_samples = 10
-            warnings.warn(
-                "num_samples not provided, defaulting to {}".format(num_samples)
-            )
-        if guide is None:
-            # propose from the prior by making a guide from the model by hiding observes
-            guide = poutine.block(model, hide_types=["observe"])
-        self.num_samples = num_samples
-        self.model = model
-        self.guide = guide
-
-    def _traces(self, *args, **kwargs):
-        """
-        Generator of weighted samples from the proposal distribution.
-        """
-        for i in range(self.num_samples):
-            guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
-            model_trace = poutine.trace(
-                poutine.replay(self.model, trace=guide_trace)
-            ).get_trace(*args, **kwargs)
-            log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
-            yield (model_trace, log_weight)
+    log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor]
 
     def get_log_normalizer(self):
         """
@@ -61,9 +29,13 @@
         (mean of the unnormalized weights)
         """
         # ensure list is not empty
-        if self.log_weights:
-            log_w = torch.tensor(self.log_weights)
-            log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0))
+        if len(self.log_weights) > 0:
+            log_w = (
+                self.log_weights
+                if isinstance(self.log_weights, torch.Tensor)
+                else torch.tensor(self.log_weights)
+            )
+            log_num_samples = torch.log(torch.tensor(log_w.numel() * 1.0))
             return torch.logsumexp(log_w - log_num_samples, 0)
         else:
             warnings.warn(
@@ -74,8 +46,12 @@
         """
         Compute the normalized importance weights.
         """
-        if self.log_weights:
-            log_w = torch.tensor(self.log_weights)
+        if len(self.log_weights) > 0:
+            log_w = (
+                self.log_weights
+                if isinstance(self.log_weights, torch.Tensor)
+                else torch.tensor(self.log_weights)
+            )
             log_w_norm = log_w - torch.logsumexp(log_w, 0)
             return log_w_norm if log_scale else torch.exp(log_w_norm)
         else:
@@ -87,7 +63,7 @@
         """
         Compute (Importance Sampling) Effective Sample Size (ESS).
""" - if self.log_weights: + if len(self.log_weights) > 0: log_w_norm = self.get_normalized_weights(log_scale=True) ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) else: @@ -98,6 +74,47 @@ def get_ESS(self): return ess +class Importance(TracePosterior, LogWeightsMixin): + """ + :param model: probabilistic model defined as a function + :param guide: guide used for sampling defined as a function + :param num_samples: number of samples to draw from the guide (default 10) + + This method performs posterior inference by importance sampling + using the guide as the proposal distribution. + If no guide is provided, it defaults to proposing from the model's prior. + """ + + def __init__(self, model, guide=None, num_samples=None): + """ + Constructor. default to num_samples = 10, guide = model + """ + super().__init__() + if num_samples is None: + num_samples = 10 + warnings.warn( + "num_samples not provided, defaulting to {}".format(num_samples) + ) + if guide is None: + # propose from the prior by making a guide from the model by hiding observes + guide = poutine.block(model, hide_types=["observe"]) + self.num_samples = num_samples + self.model = model + self.guide = guide + + def _traces(self, *args, **kwargs): + """ + Generator of weighted samples from the proposal distribution. + """ + for i in range(self.num_samples): + guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(self.model, trace=guide_trace) + ).get_trace(*args, **kwargs) + log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum() + yield (model_trace, log_weight) + + def vectorized_importance_weights(model, guide, *args, **kwargs): """ :param model: probabilistic model defined as a function diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 6be8b5cb5f..ea89aff5e5 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +from dataclasses import dataclass from functools import reduce -from typing import List, NamedTuple, Union +from typing import List, Union import torch import pyro import pyro.poutine as poutine +from pyro.infer.importance import LogWeightsMixin from pyro.infer.util import plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -34,7 +36,8 @@ def _guess_max_plate_nesting(model, args, kwargs): return max_plate_nesting -class _predictiveResults(NamedTuple): +@dataclass(frozen=True, eq=False) +class _predictiveResults: """ Return value of call to ``_predictive`` and ``_predictive_sequential``. """ @@ -316,7 +319,8 @@ def get_vectorized_trace(self, *args, **kwargs): ).trace -class WeighedPredictiveResults(NamedTuple): +@dataclass(frozen=True, eq=False) +class WeighedPredictiveResults(LogWeightsMixin): """ Return value of call to instance of :class:`WeighedPredictive`. 
""" diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 1f28e1f05c..319a1196dd 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -46,6 +46,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): num_trials = ( torch.ones(5) * 400 ) # Reduced to 400 from 1000 in order for guide optimization to converge + num_samples = 10000 num_success = dist.Binomial(num_trials, true_probs).sample() conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -57,7 +58,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): posterior_predictive = predictive( model, guide=beta_guide, - num_samples=10000, + num_samples=num_samples, parallel=parallel, return_sites=["_RETURN"], ) @@ -71,6 +72,8 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape # Weights should be uniform as the guide has the same distribution as the model assert weighed_samples.log_weights.std() < 0.6 + # Effective sample size should be close to actual number of samples taken from the guide + assert weighed_samples.get_ESS() > 0.8 * num_samples assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)