From 44964c3a74f1dd4e1dc40317638b57009606e125 Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Mon, 18 May 2020 07:34:08 -0700 Subject: [PATCH 1/9] Added with_cache() and refactored basic element-wise transforms (#2443) --- docs/source/distributions.rst | 4 - pyro/distributions/transforms/__init__.py | 13 +-- pyro/distributions/transforms/basic.py | 84 +++++++++++++++++++ .../transforms/discrete_cosine.py | 5 ++ .../transforms/lower_cholesky_affine.py | 9 +- .../transforms/neural_autoregressive.py | 80 +----------------- pyro/distributions/transforms/permute.py | 9 +- tests/distributions/test_transforms.py | 4 - 8 files changed, 114 insertions(+), 94 deletions(-) create mode 100644 pyro/distributions/transforms/basic.py diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 546d2fb55a..3bcc822d55 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -622,7 +622,3 @@ spline sylvester --------- .. autofunction:: pyro.distributions.transforms.sylvester - -tanh ----- -.. autofunction:: pyro.distributions.transforms.tanh diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 2e82d51dc3..80d23078d3 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -13,6 +13,7 @@ from pyro.distributions.transforms.affine_coupling import (AffineCoupling, ConditionalAffineCoupling, affine_coupling, conditional_affine_coupling) from pyro.distributions.transforms.batchnorm import BatchNorm, batchnorm +from pyro.distributions.transforms.basic import ELUTransform, LeakyReLUTransform, elu, leaky_relu from pyro.distributions.transforms.block_autoregressive import BlockAutoregressive, block_autoregressive from pyro.distributions.transforms.cholesky import CorrLCholeskyTransform from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform @@ -23,10 +24,11 @@ from pyro.distributions.transforms.householder import (ConditionalHouseholder, Householder, conditional_householder, householder) from pyro.distributions.transforms.lower_cholesky_affine import LowerCholeskyAffine -from pyro.distributions.transforms.neural_autoregressive import (ConditionalNeuralAutoregressive, ELUTransform, - LeakyReLUTransform, NeuralAutoregressive, - conditional_neural_autoregressive, elu, leaky_relu, - neural_autoregressive, tanh) +from pyro.distributions.transforms.neural_autoregressive import ( + ConditionalNeuralAutoregressive, + NeuralAutoregressive, + conditional_neural_autoregressive, + neural_autoregressive) from pyro.distributions.transforms.permute import Permute, permute from pyro.distributions.transforms.planar import ConditionalPlanar, Planar, conditional_planar, planar from pyro.distributions.transforms.polynomial import Polynomial, polynomial @@ -111,8 +113,7 @@ def iterated(repeats, base_fn, *args, **kwargs): 'polynomial', 'radial', 'spline', - 'sylvester', - 'tanh', + 'sylvester' ] __all__.extend(torch_transforms) diff --git a/pyro/distributions/transforms/basic.py b/pyro/distributions/transforms/basic.py new file mode 100644 index 0000000000..4cbe219bc6 --- /dev/null +++ b/pyro/distributions/transforms/basic.py @@ -0,0 +1,84 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import torch
+from torch.distributions.transforms import Transform, TanhTransform
+from torch.distributions import constraints
+import torch.nn.functional as F
+
+# TODO: Move upstream
+
+
+class ELUTransform(Transform):
+    r"""
+    Bijective transform via the mapping :math:`y = \text{ELU}(x)`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, ELUTransform)
+
+    def _call(self, x):
+        return F.elu(x)
+
+    def _inverse(self, y, eps=1e-8):
+        return torch.max(y, torch.zeros_like(y)) + torch.min(torch.log1p(y + eps), torch.zeros_like(y))
+
+    def log_abs_det_jacobian(self, x, y):
+        return -F.relu(-x)
+
+
+def elu():
+    """
+    A helper function to create an
+    :class:`~pyro.distributions.transforms.ELUTransform` object for consistency with
+    other helpers.
+    """
+    return ELUTransform()
+
+# TODO: Move upstream
+
+
+class LeakyReLUTransform(Transform):
+    r"""
+    Bijective transform via the mapping :math:`y = \text{LeakyReLU}(x)`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, LeakyReLUTransform)
+
+    def _call(self, x):
+        return F.leaky_relu(x)
+
+    def _inverse(self, y):
+        return F.leaky_relu(y, negative_slope=100.0)
+
+    def log_abs_det_jacobian(self, x, y):
+        return torch.where(x >= 0., torch.zeros_like(x), torch.ones_like(x) * math.log(0.01))
+
+
+def leaky_relu():
+    """
+    A helper function to create a
+    :class:`~pyro.distributions.transforms.LeakyReLUTransform` object for
+    consistency with other helpers.
+    """
+    return LeakyReLUTransform()
+
+
+def tanh():
+    """
+    A helper function to create a
+    :class:`~pyro.distributions.transforms.TanhTransform` object for consistency
+    with other helpers.
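+
+    Unlike the :class:`ELUTransform` and :class:`LeakyReLUTransform` helpers
+    above, this returns the upstream
+    :class:`torch.distributions.transforms.TanhTransform` rather than a class
+    defined in this module.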
+ """ + return TanhTransform() diff --git a/pyro/distributions/transforms/discrete_cosine.py b/pyro/distributions/transforms/discrete_cosine.py index bc47dcc650..ddbf3c1e56 100644 --- a/pyro/distributions/transforms/discrete_cosine.py +++ b/pyro/distributions/transforms/discrete_cosine.py @@ -73,3 +73,8 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): return x.new_zeros((1,) * self.event_dim) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return DiscreteCosineTransform(-self.event_dim, self.smooth, cache_size=cache_size) diff --git a/pyro/distributions/transforms/lower_cholesky_affine.py b/pyro/distributions/transforms/lower_cholesky_affine.py index 98658aacf0..cc8caeab10 100644 --- a/pyro/distributions/transforms/lower_cholesky_affine.py +++ b/pyro/distributions/transforms/lower_cholesky_affine.py @@ -28,8 +28,8 @@ class LowerCholeskyAffine(Transform): event_dim = 1 volume_preserving = False - def __init__(self, loc, scale_tril): - super().__init__(cache_size=1) + def __init__(self, loc, scale_tril, cache_size=0): + super().__init__(cache_size=cache_size) self.loc = loc self.scale_tril = scale_tril assert loc.size(-1) == scale_tril.size(-1) == scale_tril.size(-2), \ @@ -64,3 +64,8 @@ def log_abs_det_jacobian(self, x, y): """ return torch.ones(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device) * \ self.scale_tril.diag().log().sum() + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return LowerCholeskyAffine(self.loc, self.scale_tril, cache_size=cache_size) diff --git a/pyro/distributions/transforms/neural_autoregressive.py b/pyro/distributions/transforms/neural_autoregressive.py index 11f583516f..2d9438505f 100644 --- a/pyro/distributions/transforms/neural_autoregressive.py +++ b/pyro/distributions/transforms/neural_autoregressive.py @@ -1,96 +1,23 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from __future__ import absolute_import, division, print_function - -import math from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from torch.distributions import constraints -from torch.distributions.transforms import SigmoidTransform, TanhTransform, Transform +from torch.distributions.transforms import SigmoidTransform, TanhTransform from pyro.distributions.conditional import ConditionalTransformModule from pyro.distributions.torch_transform import TransformModule +from pyro.distributions.transforms.basic import ELUTransform, LeakyReLUTransform from pyro.distributions.util import copy_docs_from from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN eps = 1e-8 -class ELUTransform(Transform): - r""" - Bijective transform via the mapping :math:`y = \text{ELU}(x)`. - """ - domain = constraints.real - codomain = constraints.positive - bijective = True - sign = +1 - - def __eq__(self, other): - return isinstance(other, ELUTransform) - - def _call(self, x): - return F.elu(x) - - def _inverse(self, y): - return torch.max(y, torch.zeros_like(y)) + torch.min(torch.log1p(y + eps), torch.zeros_like(y)) - - def log_abs_det_jacobian(self, x, y): - return -F.relu(-x) - - -def elu(): - """ - A helper function to create an - :class:`~pyro.distributions.transform.ELUTransform` object for consistency with - other helpers. - """ - return ELUTransform() - - -class LeakyReLUTransform(Transform): - r""" - Bijective transform via the mapping :math:`y = \text{LeakyReLU}(x)`. 
- """ - domain = constraints.real - codomain = constraints.positive - bijective = True - sign = +1 - - def __eq__(self, other): - return isinstance(other, LeakyReLUTransform) - - def _call(self, x): - return F.leaky_relu(x) - - def _inverse(self, y): - return F.leaky_relu(y, negative_slope=100.0) - - def log_abs_det_jacobian(self, x, y): - return torch.where(x >= 0., torch.zeros_like(x), torch.ones_like(x) * math.log(0.01)) - - -def leaky_relu(): - """ - A helper function to create a - :class:`~pyro.distributions.transforms.LeakyReLUTransform` object for - consistency with other helpers. - """ - return LeakyReLUTransform() - - -def tanh(): - """ - A helper function to create a - :class:`~pyro.distributions.transforms.TanhTransform` object for consistency - with other helpers. - """ - return TanhTransform() - - @copy_docs_from(TransformModule) class NeuralAutoregressive(TransformModule): r""" @@ -135,6 +62,7 @@ class NeuralAutoregressive(TransformModule): codomain = constraints.real bijective = True event_dim = 1 + eps = 1e-8 autoregressive = True def __init__(self, autoregressive_nn, hidden_units=16, activation='sigmoid'): @@ -201,7 +129,7 @@ def log_abs_det_jacobian(self, x, y): T = self.T log_dydD = self._cached_log_df_inv_dx - log_dDdx = torch.logsumexp(torch.log(A + eps) + self.logsoftmax(W_pre) + + log_dDdx = torch.logsumexp(torch.log(A + self.eps) + self.logsoftmax(W_pre) + T.log_abs_det_jacobian(C, T_C), dim=-2) log_det = log_dydD + log_dDdx return log_det.sum(-1) diff --git a/pyro/distributions/transforms/permute.py b/pyro/distributions/transforms/permute.py index 78dc4ec56b..30626779f6 100644 --- a/pyro/distributions/transforms/permute.py +++ b/pyro/distributions/transforms/permute.py @@ -45,8 +45,8 @@ class Permute(Transform): event_dim = 1 volume_preserving = True - def __init__(self, permutation): - super().__init__(cache_size=1) + def __init__(self, permutation, cache_size=1): + super().__init__(cache_size=cache_size) self.permutation = permutation @@ -91,6 +91,11 @@ def log_abs_det_jacobian(self, x, y): return torch.zeros(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device) + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return Permute(self.permutation, cache_size=cache_size) + def permute(input_dim, permutation=None): """ diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index f6adbc48d1..f2b33dcbf7 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -261,7 +261,3 @@ def test_spline(self): def test_sylvester(self): self._test(T.sylvester, inverse=False) - - def test_tanh(self): - # NOTE: Need following since helper function mistakenly doesn't take input dim - self._test(lambda input_dim: T.tanh()) From 28dde7463e883474df34563054b33f587753ad05 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 18 May 2020 14:46:48 -0700 Subject: [PATCH 2/9] Use systematic resampling in SMCFilter (#2488) * Use systematic resampling in SMCFilter * Remove debug statement * Add retry logic to SMC for epidemiology * Fix formatting issue --- pyro/contrib/epidemiology/compartmental.py | 28 +++++++++++----- pyro/infer/smcfilter.py | 39 +++++++++++++++++----- tests/infer/test_smcfilter.py | 18 ++++++++++ 3 files changed, 68 insertions(+), 17 deletions(-) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index a95baf46e7..0633b69110 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ 
b/pyro/contrib/epidemiology/compartmental.py @@ -20,6 +20,7 @@ from pyro.infer.autoguide import init_to_generated, init_to_value from pyro.infer.mcmc import ArrowheadMassMatrix from pyro.infer.reparam import DiscreteCosineReparam +from pyro.infer.smcfilter import SMCFailed from pyro.util import warn_if_nan from .distributions import set_approx_sample_thresh @@ -144,7 +145,7 @@ def _clear_plates(self): @torch.no_grad() @set_approx_sample_thresh(1000) - def heuristic(self, num_particles=1024, ess_threshold=0.5): + def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): """ Finds an initial feasible guess of all latent variables, consistent with observed data. This is needed because not all hypotheses are @@ -163,12 +164,20 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5): # Run SMC. model = _SMCModel(self) guide = _SMCGuide(self) - smc = SMCFilter(model, guide, num_particles=num_particles, - ess_threshold=ess_threshold, - max_plate_nesting=self.max_plate_nesting) - smc.init() - for t in range(1, self.duration): - smc.step() + for attempt in range(1, 1 + retries): + smc = SMCFilter(model, guide, num_particles=num_particles, + ess_threshold=ess_threshold, + max_plate_nesting=self.max_plate_nesting) + try: + smc.init() + for t in range(1, self.duration): + smc.step() + break + except SMCFailed as e: + if attempt == retries: + raise + logger.info("{}. Retrying...".format(e)) + continue # Select the most probable hypothesis. i = int(smc.state._log_weights.max(0).indices) @@ -309,7 +318,6 @@ def fit(self, **options): if k.startswith("heuristic_")} def heuristic(): - logger.info("Heuristically initializing...") with poutine.block(): init_values = self.heuristic(**heuristic_options) assert isinstance(init_values, dict) @@ -321,6 +329,10 @@ def heuristic(): x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) x = DiscreteCosineTransform(smooth=self._dct)(x) init_values["auxiliary_dct"] = x + logger.info("Heuristic init: {}".format(", ".join( + "{}={:0.3g}".format(k, v.item()) + for k, v in init_values.items() + if v.numel() == 1))) return init_to_value(values=init_values) # Configure a kernel. diff --git a/pyro/infer/smcfilter.py b/pyro/infer/smcfilter.py index 2199a89f80..22d2360748 100644 --- a/pyro/infer/smcfilter.py +++ b/pyro/infer/smcfilter.py @@ -13,6 +13,14 @@ from pyro.poutine.util import prune_subsample_sites +class SMCFailed(ValueError): + """ + Exception raised when :class:`SMCFilter` fails to find any hypothesis with + nonzero probability. 
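+
+    Callers may catch this exception and retry from a different random state,
+    as the retry loop in
+    :meth:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel.heuristic`
+    does.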
+ """ + pass + + class SMCFilter: """ :class:`SMCFilter` is the top-level interface for filtering via sequential @@ -112,16 +120,16 @@ def _update_weights(self, model_trace, guide_trace): log_q = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p - log_q if not (self.state._log_weights.max() > -math.inf): - raise ValueError("Failed to find feasible hypothesis after site {}" - .format(name)) + raise SMCFailed("Failed to find feasible hypothesis after site {}" + .format(name)) for site in model_trace.nodes.values(): if site["type"] == "sample" and site["is_observed"]: log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p if not (self.state._log_weights.max() > -math.inf): - raise ValueError("Failed to find feasible hypothesis after site {}" - .format(site["name"])) + raise SMCFailed("Failed to find feasible hypothesis after site {}" + .format(site["name"])) self.state._log_weights -= self.state._log_weights.max() @@ -130,16 +138,29 @@ def _maybe_importance_resample(self): return # Decide whether to resample based on ESS. logp = self.state._log_weights - logp -= logp.logsumexp(dim=-1) - ess = logp.mul(2).exp().sum().reciprocal() + logp -= logp.logsumexp(-1) + probs = logp.exp() + ess = probs.dot(probs).reciprocal() if ess < self.ess_threshold * self.num_particles: - self._importance_resample() + self._importance_resample(probs) - def _importance_resample(self): - index = dist.Categorical(logits=self.state._log_weights).sample(sample_shape=(self.num_particles,)) + def _importance_resample(self, probs): + index = _systematic_sample(probs) self.state._resample(index) +def _systematic_sample(probs): + # Systematic sampling preserves diversity better than multinomial sampling + # via Categorical(probs).sample(). 
+ batch_shape, size = probs.shape[:-1], probs.size(-1) + n = probs.cumsum(-1).mul_(size).add_(torch.rand(batch_shape + (1,))) + n = n.floor_().clamp_(min=0, max=size).long() + diff = probs.new_zeros(batch_shape + (size + 1,)) + diff.scatter_add_(-1, n, torch.ones_like(probs)) + index = diff[..., :-1].cumsum(-1).long() + return index + + class SMCState(dict): """ Dictionary-like object to hold a vectorized collection of tensors to diff --git a/tests/infer/test_smcfilter.py b/tests/infer/test_smcfilter.py index 1c180a656e..4073496f10 100644 --- a/tests/infer/test_smcfilter.py +++ b/tests/infer/test_smcfilter.py @@ -8,9 +8,27 @@ import pyro.distributions as dist import pyro.poutine as poutine from pyro.infer import SMCFilter +from pyro.infer.smcfilter import _systematic_sample from tests.common import assert_close +@pytest.mark.parametrize("size", range(1, 32)) +def test_systematic_sample(size): + pyro.set_rng_seed(size) + probs = torch.randn(size).exp() + probs /= probs.sum() + + num_samples = 20000 + index = _systematic_sample(probs.expand(num_samples, size)) + histogram = torch.zeros_like(probs) + histogram.scatter_add_(-1, index.reshape(-1), + probs.new_ones(1).expand(num_samples * size)) + + expected = probs * size + actual = histogram / num_samples + assert_close(actual, expected, atol=0.01) + + class SmokeModel: def __init__(self, state_size, plate_size): From f81c70c85ca6a2cbf9adb31d26ccc49007d5a335 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 19 May 2020 10:40:54 -0700 Subject: [PATCH 3/9] Revise binomial approximation interface (#2490) * Refactor binomial approximation logic * Remove debug statement --- pyro/contrib/epidemiology/compartmental.py | 4 ++- pyro/contrib/epidemiology/distributions.py | 40 ++++++++++----------- pyro/distributions/conjugate.py | 15 ++------ pyro/distributions/torch.py | 42 +++++++++------------- tests/distributions/test_binomial.py | 18 +++++----- 5 files changed, 50 insertions(+), 69 deletions(-) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 0633b69110..202896a35a 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -144,7 +144,7 @@ def _clear_plates(self): full_mass = False @torch.no_grad() - @set_approx_sample_thresh(1000) + @set_approx_sample_thresh(100) # This is robust to gross approximation. def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): """ Finds an initial feasible guess of all latent variables, consistent @@ -258,6 +258,7 @@ def transition_bwd(self, params, prev, curr, t): # Inference interface ######################################## @torch.no_grad() + @set_approx_sample_thresh(1000) def generate(self, fixed={}): """ Generate data from the prior. @@ -361,6 +362,7 @@ def heuristic(): return mcmc # E.g. so user can run mcmc.summary(). @torch.no_grad() + @set_approx_sample_thresh(10000) def predict(self, forecast=0): """ Predict latent variables and optionally forecast forward. diff --git a/pyro/contrib/epidemiology/distributions.py b/pyro/contrib/epidemiology/distributions.py index 6bca2f280e..0da2f6d0ee 100644 --- a/pyro/contrib/epidemiology/distributions.py +++ b/pyro/contrib/epidemiology/distributions.py @@ -8,25 +8,34 @@ import pyro.distributions as dist -_APPROX_SAMPLE_THRESH = 10000 - @contextmanager def set_approx_sample_thresh(thresh): """ - Temporarily set global approx_sample_thresh in ``infection_dist``. - The default global value is 10000. 
+ EXPERIMENTAL Temporarily set the global default value of + ``Binomial.approx_sample_thresh``, thereby decreasing the computational + complexity of sampling from :class:`~pyro.distributions.Binomial`, + :class:`~pyro.distributions.BetaBinomial`, + :class:`~pyro.distributions.ExtendedBinomial`, + :class:`~pyro.distributions.ExtendedBetaBinomial`, and distributions + returned by :func:`infection_dist`. + + This is useful for sampling from very large ``total_count``. + + This is used internally by + :class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`. :param thresh: New temporary threshold. :type thresh: int or float. """ - global _APPROX_SAMPLE_THRESH - old = _APPROX_SAMPLE_THRESH + assert isinstance(thresh, (float, int)) + assert thresh > 0 + old = dist.Binomial.approx_sample_thresh try: - _APPROX_SAMPLE_THRESH = thresh + dist.Binomial.approx_sample_thresh = thresh yield finally: - _APPROX_SAMPLE_THRESH = old + dist.Binomial.approx_sample_thresh = old def infection_dist(*, @@ -34,8 +43,7 @@ def infection_dist(*, num_infectious, num_susceptible=math.inf, population=math.inf, - concentration=math.inf, - approx_sample_thresh=None): + concentration=math.inf): """ Create a :class:`~pyro.distributions.Distribution` over the number of new infections at a discrete time step. @@ -82,10 +90,6 @@ def infection_dist(*, :param concentration: The concentration or dispersion parameter ``k`` in overdispersed models of superspreaders [1,2]. This defaults to minimum variance ``concentration = ∞``. - :param approx_sample_thresh: Population threshold above which Binomial - samples will be approximated as clamped Poisson samples, including - internally in BetaBinomial sampling. Defaults to the global value which - defaults to 10000. """ # Convert to colloquial variable names. R = individual_rate @@ -110,19 +114,15 @@ def infection_dist(*, # Combine infections from all individuals. combined_p = p.neg().log1p().mul(I).expm1().neg() # = 1 - (1 - p)**I combined_p = combined_p.clamp(min=1e-6) - if approx_sample_thresh is None: - approx_sample_thresh = _APPROX_SAMPLE_THRESH if isinstance(k, float) and k == math.inf: # Return a pure Binomial model, combining the independent Binomial # models of each infectious individual. - return dist.ExtendedBinomial( - S, combined_p, approx_sample_thresh=approx_sample_thresh) + return dist.ExtendedBinomial(S, combined_p) else: # Return an overdispersed Beta-Binomial model, combining # independent BetaBinomial(c1,c0,S) models for each infectious # individual. c1 = (k * I).clamp(min=1e-6) c0 = c1 * (combined_p.reciprocal() - 1).clamp(min=1e-6) - return dist.ExtendedBetaBinomial( - c1, c0, S, approx_sample_thresh=approx_sample_thresh) + return dist.ExtendedBetaBinomial(c1, c0, S) diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py index d80f9ec086..0ec7279537 100644 --- a/pyro/distributions/conjugate.py +++ b/pyro/distributions/conjugate.py @@ -1,7 +1,6 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import math import numbers import torch @@ -41,25 +40,17 @@ class BetaBinomial(TorchDistribution): :param float or torch.Tensor concentration0: 2nd concentration parameter (beta) for the Beta distribution. :param int or torch.Tensor total_count: number of Bernoulli trials. - :param approx_sample_thresh: EXPERIMENTAL total_count above which sampling - will use a clamped Poisson approximation for Binomial samples. This is useful - for sampling very large populations. 
- :type approx_sample_thresh: int or float """ arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive, 'total_count': constraints.nonnegative_integer} has_enumerate_support = True support = Binomial.support - def __init__(self, concentration1, concentration0, total_count=1, validate_args=None, - *, approx_sample_thresh=math.inf): - assert isinstance(approx_sample_thresh, numbers.Number) - assert approx_sample_thresh >= 0 + def __init__(self, concentration1, concentration0, total_count=1, validate_args=None): concentration1, concentration0, total_count = broadcast_all( concentration1, concentration0, total_count) self._beta = Beta(concentration1, concentration0) self.total_count = total_count - self.approx_sample_thresh = approx_sample_thresh super().__init__(self._beta._batch_shape, validate_args=validate_args) @property @@ -75,15 +66,13 @@ def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) new._beta = self._beta.expand(batch_shape) new.total_count = self.total_count.expand_as(new._beta.concentration0) - new.approx_sample_thresh = self.approx_sample_thresh super(BetaBinomial, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args return new def sample(self, sample_shape=()): probs = self._beta.sample(sample_shape) - return Binomial(self.total_count, probs, - approx_sample_thresh=self.approx_sample_thresh).sample() + return Binomial(self.total_count, probs, validate_args=False).sample() def log_prob(self, value): if self._validate_args: diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 8323e854ac..e1b5c9b3e8 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import math -import numbers import torch from torch.distributions import constraints @@ -32,33 +31,16 @@ def _log_normalizer(d): class Binomial(torch.distributions.Binomial, TorchDistributionMixin): - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None, *, - approx_sample_thresh=math.inf): - assert isinstance(approx_sample_thresh, numbers.Number) - assert approx_sample_thresh >= 0 - super().__init__(total_count=total_count, probs=probs, logits=logits, - validate_args=validate_args) - self.approx_sample_thresh = approx_sample_thresh - - def expand(self, batch_shape, _instance=None): - new = self._get_checked_instance(Binomial, _instance) - batch_shape = torch.Size(batch_shape) - new.total_count = self.total_count.expand(batch_shape) - if 'probs' in self.__dict__: - new.probs = self.probs.expand(batch_shape) - new._param = new.probs - if 'logits' in self.__dict__: - new.logits = self.logits.expand(batch_shape) - new._param = new.logits - new.approx_sample_thresh = self.approx_sample_thresh - super(torch.distributions.Binomial, new).__init__(batch_shape, validate_args=False) - new._validate_args = self._validate_args - return new + # EXPERIMENTAL threshold on total_count above which sampling will use a + # clamped Poisson approximation for Binomial samples. This is useful for + # sampling very large populations. + approx_sample_thresh = math.inf def sample(self, sample_shape=torch.Size()): if self.approx_sample_thresh < math.inf: - if self.approx_sample_thresh < self.total_count.min(): - # Approximate with a moment-matched clamped Poisson. + exact = self.total_count <= self.approx_sample_thresh + if not exact.all(): + # Approximate large counts with a moment-matched clamped Poisson. 
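+                # A Poisson(variance) draw shifted by round(mean - variance)
+                # matches the Binomial's first two moments; the rarer of the
+                # two outcomes is simulated (and complemented when p >= q) so
+                # the Poisson rate stays small, and results are clamped to
+                # total_count.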
with torch.no_grad(): shape = self._extended_shape(sample_shape) p = self.probs @@ -68,7 +50,15 @@ def sample(self, sample_shape=torch.Size()): shift = (mean - variance).round() result = torch.poisson(variance.expand(shape)) result = torch.min(result + shift, self.total_count) - return torch.where(p < q, result, self.total_count - result) + sample = torch.where(p < q, result, self.total_count - result) + # Draw exact samples for remaining items. + if exact.any(): + total_count = torch.where(exact, self.total_count, + torch.zeros_like(self.total_count)) + exact_sample = torch.distributions.Binomial( + total_count, self.probs, validate_args=False).sample(sample_shape) + sample = torch.where(exact, exact_sample, sample) + return sample return super().sample(sample_shape) diff --git a/tests/distributions/test_binomial.py b/tests/distributions/test_binomial.py index a83723e8b1..1d6a04a950 100644 --- a/tests/distributions/test_binomial.py +++ b/tests/distributions/test_binomial.py @@ -4,6 +4,7 @@ import pytest import pyro.distributions as dist +from pyro.contrib.epidemiology.distributions import set_approx_sample_thresh from tests.common import assert_close @@ -11,10 +12,10 @@ @pytest.mark.parametrize("prob", [0.01, 0.1, 0.5, 0.9, 0.99]) def test_binomial_approx_sample(total_count, prob): sample_shape = (10000,) - d1 = dist.Binomial(total_count, prob) - d2 = dist.Binomial(total_count, prob, approx_sample_thresh=200) - expected = d1.sample(sample_shape) - actual = d2.sample(sample_shape) + d = dist.Binomial(total_count, prob) + expected = d.sample(sample_shape) + with set_approx_sample_thresh(200): + actual = d.sample(sample_shape) assert_close(expected.mean(), actual.mean(), rtol=0.05) assert_close(expected.std(), actual.std(), rtol=0.05) @@ -25,11 +26,10 @@ def test_binomial_approx_sample(total_count, prob): @pytest.mark.parametrize("concentration0", [0.1, 1.0, 10.]) def test_beta_binomial_approx_sample(concentration1, concentration0, total_count): sample_shape = (10000,) - d1 = dist.BetaBinomial(concentration1, concentration0, total_count) - d2 = dist.BetaBinomial(concentration1, concentration0, total_count, - approx_sample_thresh=200) - expected = d1.sample(sample_shape) - actual = d2.sample(sample_shape) + d = dist.BetaBinomial(concentration1, concentration0, total_count) + expected = d.sample(sample_shape) + with set_approx_sample_thresh(200): + actual = d.sample(sample_shape) assert_close(expected.mean(), actual.mean(), rtol=0.1) assert_close(expected.std(), actual.std(), rtol=0.1) From 144682b3cc18cc7287af65b0bee975674843bd06 Mon Sep 17 00:00:00 2001 From: martinjankowiak Date: Tue, 19 May 2020 17:17:55 -0700 Subject: [PATCH 4/9] add HaarTransform (#2492) --- pyro/distributions/transforms/__init__.py | 2 + .../transforms/discrete_cosine.py | 6 +- pyro/distributions/transforms/haar.py | 63 +++++++++++++++++++ pyro/ops/tensor_utils.py | 42 +++++++++++++ tests/distributions/test_haar.py | 13 ++++ tests/distributions/test_transforms.py | 5 ++ 6 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 pyro/distributions/transforms/haar.py create mode 100644 tests/distributions/test_haar.py diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 80d23078d3..6ff2b37272 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -17,6 +17,7 @@ from pyro.distributions.transforms.block_autoregressive import BlockAutoregressive, block_autoregressive from pyro.distributions.transforms.cholesky 
import CorrLCholeskyTransform from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform +from pyro.distributions.transforms.haar import HaarTransform from pyro.distributions.transforms.generalized_channel_permute import (ConditionalGeneralizedChannelPermute, GeneralizedChannelPermute, conditional_generalized_channel_permute, @@ -82,6 +83,7 @@ def iterated(repeats, base_fn, *args, **kwargs): 'DiscreteCosineTransform', 'ELUTransform', 'GeneralizedChannelPermute', + 'HaarTransform', 'Householder', 'LeakyReLUTransform', 'LowerCholeskyAffine', diff --git a/pyro/distributions/transforms/discrete_cosine.py b/pyro/distributions/transforms/discrete_cosine.py index ddbf3c1e56..61dd995d42 100644 --- a/pyro/distributions/transforms/discrete_cosine.py +++ b/pyro/distributions/transforms/discrete_cosine.py @@ -23,8 +23,8 @@ class DiscreteCosineTransform(Transform): noise; when -1 this transforms violet noise to white noise; etc. Any real number is allowed. https://en.wikipedia.org/wiki/Colors_of_noise. """ - domain = constraints.real - codomain = constraints.real + domain = constraints.real_vector + codomain = constraints.real_vector bijective = True def __init__(self, dim=-1, smooth=0., cache_size=0): @@ -72,7 +72,7 @@ def _inverse(self, y): return x def log_abs_det_jacobian(self, x, y): - return x.new_zeros((1,) * self.event_dim) + return x.new_zeros(x.shape[:-self.event_dim]) def with_cache(self, cache_size=1): if self._cache_size == cache_size: diff --git a/pyro/distributions/transforms/haar.py b/pyro/distributions/transforms/haar.py new file mode 100644 index 0000000000..ab8838c498 --- /dev/null +++ b/pyro/distributions/transforms/haar.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from torch.distributions import constraints +from torch.distributions.transforms import Transform + +from pyro.ops.tensor_utils import haar_transform, inverse_haar_transform + + +class HaarTransform(Transform): + """ + Discrete Haar transform. + + This uses :func:`~pyro.ops.tensor_utils.haar_transform` and + :func:`~pyro.ops.tensor_utils.inverse_haar_transform` to compute + (orthonormal) Haar and inverse Haar transforms. The jacobian is 1. + For sequences with length `T` not a power of two, this implementation + is equivalent to a block-structured Haar transform in which block + sizes decrease by factors of one half from left to right. + + :param int dim: Dimension along which to transform. Must be negative. + This is an absolute dim counting from the right. + :param bool flip: Whether to flip the time axis before applying the + Haar transform. Defaults to false. 
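+
+    Example (illustrative round trip; since the transform is orthonormal, the
+    inverse recovers the input up to floating point error)::
+
+        >>> t = HaarTransform()
+        >>> x = torch.randn(8)
+        >>> torch.allclose(t.inv(t(x)), x, atol=1e-6)
+        True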
+ """ + domain = constraints.real_vector + codomain = constraints.real_vector + bijective = True + + def __init__(self, dim=-1, flip=False, cache_size=0): + assert isinstance(dim, int) and dim < 0 + self.event_dim = -dim + self.flip = flip + super().__init__(cache_size=cache_size) + + def __eq__(self, other): + return (type(self) == type(other) and self.event_dim == other.event_dim and + self.flip == other.flip) + + def _call(self, x): + dim = -self.event_dim + if dim != -1: + x = x.transpose(dim, -1) + if self.flip: + x = x.flip(-1) + y = haar_transform(x) + if dim != -1: + y = y.transpose(dim, -1) + return y + + def _inverse(self, y): + dim = -self.event_dim + if dim != -1: + y = y.transpose(dim, -1) + x = inverse_haar_transform(y) + if self.flip: + x = x.flip(-1) + if dim != -1: + x = x.transpose(dim, -1) + return x + + def log_abs_det_jacobian(self, x, y): + return x.new_zeros(x.shape[:-self.event_dim]) diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py index 8d3421fcc4..777dfb8570 100644 --- a/pyro/ops/tensor_utils.py +++ b/pyro/ops/tensor_utils.py @@ -6,6 +6,9 @@ import torch +_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0) + + class _SafeLog(torch.autograd.Function): @staticmethod def forward(ctx, x): @@ -350,6 +353,45 @@ def idct(x, dim=-1): return torch.stack([y, y.flip(-1)], axis=-1).reshape(x.shape[:-1] + (-1,))[..., :N] +def haar_transform(x): + """ + Discrete Haar transform. + + Performs a Haar transform along the final dimension. + This is the inverse of :func:`inverse_haar_transform`. + + :param Tensor x: The input signal. + :rtype: Tensor + """ + n = x.size(-1) // 2 + even, odd, end = x[..., 0:n+n:2], x[..., 1:n+n:2], x[..., n+n:] + hi = _ROOT_TWO_INVERSE * (even - odd) + lo = _ROOT_TWO_INVERSE * (even + odd) + if n >= 2: + lo = haar_transform(lo) + x = torch.cat([lo, hi, end], dim=-1) + return x + + +def inverse_haar_transform(x): + """ + Performs an inverse Haar transform along the final dimension. + This is the inverse of :func:`haar_transform`. + + :param Tensor x: The input signal. 
+ :rtype: Tensor + """ + n = x.size(-1) // 2 + lo, hi, end = x[..., :n], x[..., n:n+n], x[..., n+n:] + if n >= 2: + lo = inverse_haar_transform(lo) + even = _ROOT_TWO_INVERSE * (lo + hi) + odd = _ROOT_TWO_INVERSE * (lo - hi) + even_odd = torch.stack([even, odd], dim=-1).reshape(even.shape[:-1] + (-1,)) + x = torch.cat([even_odd, end], dim=-1) + return x + + def cholesky(x): if x.size(-1) == 1: return x.sqrt() diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py new file mode 100644 index 0000000000..441b6852b7 --- /dev/null +++ b/tests/distributions/test_haar.py @@ -0,0 +1,13 @@ +import torch + +import pytest +from pyro.distributions.transforms import HaarTransform +from tests.common import assert_equal + + +@pytest.mark.parametrize('size', [1, 3, 4, 7, 8, 9]) +def test_haar_ortho(size): + haar = HaarTransform() + eye = torch.eye(size) + mat = haar(eye) + assert_equal(eye, mat @ mat.t()) diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index f2b33dcbf7..d1b0f0ab0b 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -192,6 +192,11 @@ def test_discrete_cosine(self): self._test(lambda input_dim: T.DiscreteCosineTransform(smooth=1.0)) self._test(lambda input_dim: T.DiscreteCosineTransform(smooth=2.0)) + def test_haar_transform(self): + # NOTE: Need following since helper function unimplemented + self._test(lambda input_dim: T.HaarTransform(flip=True)) + self._test(lambda input_dim: T.HaarTransform(flip=False)) + def test_elu(self): # NOTE: Need following since helper function mistakenly doesn't take input dim self._test(lambda input_dim: T.elu()) From 35110fa3d224a0d72638f3011a08640e237c2416 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 19 May 2020 19:30:45 -0700 Subject: [PATCH 5/9] Add a HaarReparam and use it in contrib.epidemiology (#2493) * initial commit * imports * address comments * rev comments * Factor UnitJacobianReparam out of DiscreteCosineReparam * Add a HaarReparam and use it in contrib.epidemiology * Fix haar tests on contrib.epidemiology * Fix HaarTransform.with_cache() * Fix reparam logic in .fit() * Fix typos Co-authored-by: Martin Jankowiak --- docs/source/distributions.rst | 7 ++ docs/source/infer.reparam.rst | 18 +++++ examples/contrib/epidemiology/regional.py | 2 + examples/contrib/epidemiology/sir.py | 5 +- pyro/contrib/epidemiology/compartmental.py | 34 ++++----- pyro/distributions/transforms/haar.py | 5 ++ pyro/infer/reparam/__init__.py | 6 +- pyro/infer/reparam/discrete_cosine.py | 38 ++-------- pyro/infer/reparam/haar.py | 30 ++++++++ pyro/infer/reparam/unit_jacobian.py | 49 +++++++++++++ tests/contrib/epidemiology/test_seir.py | 4 +- tests/contrib/epidemiology/test_sir.py | 9 ++- tests/distributions/test_haar.py | 3 + tests/infer/reparam/test_haar.py | 85 ++++++++++++++++++++++ tests/infer/reparam/test_unit_jacobian.py | 55 ++++++++++++++ tests/test_examples.py | 8 +- 16 files changed, 296 insertions(+), 62 deletions(-) create mode 100644 pyro/infer/reparam/haar.py create mode 100644 pyro/infer/reparam/unit_jacobian.py create mode 100644 tests/infer/reparam/test_haar.py create mode 100644 tests/infer/reparam/test_unit_jacobian.py diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 3bcc822d55..b65cb735bb 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -340,6 +340,13 @@ ELUTransform :undoc-members: :show-inheritance: +HaarTransform +------------- +.. 
autoclass:: pyro.distributions.transforms.HaarTransform + :members: + :undoc-members: + :show-inheritance: + LeakyReLUTransform ------------------ .. autoclass:: pyro.distributions.transforms.LeakyReLUTransform diff --git a/docs/source/infer.reparam.rst b/docs/source/infer.reparam.rst index 36bdc44d48..1b46eda249 100644 --- a/docs/source/infer.reparam.rst +++ b/docs/source/infer.reparam.rst @@ -50,6 +50,24 @@ Discrete Cosine Transform :special-members: __call__ :show-inheritance: +Haar Transform +-------------- +.. automodule:: pyro.infer.reparam.haar + :members: + :undoc-members: + :member-order: bysource + :special-members: __call__ + :show-inheritance: + +Unit Jacobian Transforms +------------------------ +.. automodule:: pyro.infer.reparam.unit_jacobian + :members: + :undoc-members: + :member-order: bysource + :special-members: __call__ + :show-inheritance: + StudentT Distributions ---------------------- .. automodule:: pyro.infer.reparam.studentt diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index e50c321cf8..9a67db44bc 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -57,6 +57,7 @@ def hook_fn(kernel, *unused): num_samples=args.num_samples, max_tree_depth=args.max_tree_depth, num_quant_bins=args.num_bins, + haar=args.haar, hook_fn=hook_fn) mcmc.summary() @@ -135,6 +136,7 @@ def main(args): parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float) parser.add_argument("-tau", "--recovery-time", default=7.0, type=float) parser.add_argument("-rho", "--response-rate", default=0.5, type=float) + parser.add_argument("--haar", action="store_true") parser.add_argument("-n", "--num-samples", default=200, type=int) parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 2ffb2a7982..ba6a4a4034 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -73,7 +73,7 @@ def hook_fn(kernel, *unused): max_tree_depth=args.max_tree_depth, arrowhead_mass=args.arrowhead_mass, num_quant_bins=args.num_bins, - dct=args.dct, + haar=args.haar, hook_fn=hook_fn) mcmc.summary() @@ -204,8 +204,7 @@ def main(args): parser.add_argument("-k", "--concentration", default=math.inf, type=float, help="If finite, use a superspreader model.") parser.add_argument("-rho", "--response-rate", default=0.5, type=float) - parser.add_argument("--dct", type=float, - help="smoothing for discrete cosine reparameterizer") + parser.add_argument("--haar", action="store_true") parser.add_argument("-n", "--num-samples", default=200, type=int) parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 202896a35a..0befaef121 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -15,11 +15,11 @@ import pyro.distributions as dist import pyro.distributions.hmm import pyro.poutine as poutine -from pyro.distributions.transforms import DiscreteCosineTransform +from pyro.distributions.transforms import HaarTransform from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete from pyro.infer.autoguide import init_to_generated, init_to_value from pyro.infer.mcmc 
import ArrowheadMassMatrix -from pyro.infer.reparam import DiscreteCosineReparam +from pyro.infer.reparam import HaarReparam from pyro.infer.smcfilter import SMCFailed from pyro.util import warn_if_nan @@ -300,8 +300,7 @@ def fit(self, **options): :param int num_quant_bins: The number of quantization bins to use. Note that computational cost is exponential in `num_quant_bins`. Defaults to 4. - :param float dct: If provided, use a discrete cosine reparameterizer - with this value as smoothness. + :param bool haar: Whether to use a Haar wavelet reparameterizer. :param int heuristic_num_particles: Passed to :meth:`heuristic` as ``num_particles``. Defaults to 1024. :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``. @@ -309,9 +308,7 @@ def fit(self, **options): """ # Save these options for .predict(). self.num_quant_bins = options.pop("num_quant_bins", 4) - self._dct = options.pop("dct", None) - if self._dct is not None and self.is_regional: - raise NotImplementedError("regional models do not support DiscreteCosineReparam") + haar = options.pop("haar", False) # Heuristically initialze to feasible latents. heuristic_options = {k.replace("heuristic_", ""): options.pop(k) @@ -324,12 +321,12 @@ def heuristic(): assert isinstance(init_values, dict) assert "auxiliary" in init_values, \ ".heuristic() did not define auxiliary value" - if self._dct is not None: - # Also initialize DCT transformed coordinates. + if haar: + # Also initialize Haar transformed coordinates. x = init_values["auxiliary"] x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) - x = DiscreteCosineTransform(smooth=self._dct)(x) - init_values["auxiliary_dct"] = x + x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True)(x) + init_values["auxiliary_haar"] = x logger.info("Heuristic init: {}".format(", ".join( "{}={:0.3g}".format(k, v.item()) for k, v in init_values.items() @@ -341,8 +338,8 @@ def heuristic(): max_tree_depth = options.pop("max_tree_depth", 5) full_mass = options.pop("full_mass", self.full_mass) model = self._vectorized_model - if self._dct is not None: - rep = DiscreteCosineReparam(smooth=self._dct) + if haar: + rep = HaarReparam(dim=-2 if self.is_regional else -1, flip=True) model = poutine.reparam(model, {"auxiliary": rep}) kernel = NUTS(model, full_mass=full_mass, @@ -355,6 +352,13 @@ def heuristic(): mcmc = MCMC(kernel, **options) mcmc.run() self.samples = mcmc.get_samples() + if haar: + # Transform back from Haar coordinates. + x = self.samples.pop("auxiliary_haar") + x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True).inv(x) + x = biject_to(constraints.interval(-0.5, self.population + 0.5))(x) + self.samples["auxiliary"] = x + # Unsqueeze samples to align particle dim for use in poutine.condition. # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples(). self.samples = align_samples(self.samples, model, @@ -391,10 +395,6 @@ def predict(self, forecast=0): model = self._sequential_model model = poutine.condition(model, samples) model = particle_plate(model) - if self._dct is not None: - # Apply the same reparameterizer as during inference. 
- rep = DiscreteCosineReparam(smooth=self._dct) - model = poutine.reparam(model, {"auxiliary": rep}) model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting) trace = poutine.trace(model).get_trace() samples = OrderedDict((name, site["value"]) diff --git a/pyro/distributions/transforms/haar.py b/pyro/distributions/transforms/haar.py index ab8838c498..36d89dcde3 100644 --- a/pyro/distributions/transforms/haar.py +++ b/pyro/distributions/transforms/haar.py @@ -61,3 +61,8 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): return x.new_zeros(x.shape[:-self.event_dim]) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return HaarTransform(-self.event_dim, flip=self.flip, cache_size=cache_size) diff --git a/pyro/infer/reparam/__init__.py b/pyro/infer/reparam/__init__.py index 03b0b38c31..cf4e6bff22 100644 --- a/pyro/infer/reparam/__init__.py +++ b/pyro/infer/reparam/__init__.py @@ -1,18 +1,21 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from .conjugate import ConjugateReparam from .discrete_cosine import DiscreteCosineReparam +from .haar import HaarReparam from .hmm import LinearHMMReparam from .loc_scale import LocScaleReparam from .neutra import NeuTraReparam -from .conjugate import ConjugateReparam from .stable import LatentStableReparam, StableReparam, SymmetricStableReparam from .studentt import StudentTReparam from .transform import TransformReparam +from .unit_jacobian import UnitJacobianReparam __all__ = [ "ConjugateReparam", "DiscreteCosineReparam", + "HaarReparam", "LatentStableReparam", "LinearHMMReparam", "LocScaleReparam", @@ -21,4 +24,5 @@ "StudentTReparam", "SymmetricStableReparam", "TransformReparam", + "UnitJacobianReparam", ] diff --git a/pyro/infer/reparam/discrete_cosine.py b/pyro/infer/reparam/discrete_cosine.py index 023f874364..64796d8266 100644 --- a/pyro/infer/reparam/discrete_cosine.py +++ b/pyro/infer/reparam/discrete_cosine.py @@ -1,24 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from torch.distributions import biject_to -from torch.distributions.transforms import ComposeTransform - -import pyro -import pyro.distributions as dist from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform -from .reparam import Reparam +from .unit_jacobian import UnitJacobianReparam -class DiscreteCosineReparam(Reparam): +class DiscreteCosineReparam(UnitJacobianReparam): """ - Discrete Cosine reparamterizer, using a + Discrete Cosine reparameterizer, using a :class:`~pyro.distributions.transforms.DiscreteCosineTransform` . This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This - reparameterizes to a frequency-domain represetation where posterior + reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC. @@ -37,26 +32,5 @@ class DiscreteCosineReparam(Reparam): real number is allowed. https://en.wikipedia.org/wiki/Colors_of_noise. 
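+
+    Example (a sketch of typical usage via :func:`pyro.poutine.reparam`,
+    assuming ``model`` contains a batched time-series site named "x")::
+
+        reparam_model = poutine.reparam(model, {"x": DiscreteCosineReparam()})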
""" def __init__(self, dim=-1, smooth=0.): - assert isinstance(dim, int) and dim < 0 - self.dim = dim - self.smooth = float(smooth) - - def __call__(self, name, fn, obs): - assert obs is None, "TransformReparam does not support observe statements" - assert fn.event_dim >= -self.dim, ("Cannot transform along batch dimension; " - "try converting a batch dimension to an event dimension") - - # Draw noise from the base distribution. - # TODO Use biject_to(fn.support).inv.with_cache(1) once the following merges: - # https://github.com/probtorch/pytorch/pull/153 - dct = DiscreteCosineTransform(dim=self.dim, smooth=self.smooth, cache_size=1) - transform = ComposeTransform([biject_to(fn.support).inv, dct]) - x_dct = pyro.sample("{}_dct".format(name), - dist.TransformedDistribution(fn, transform)) - - # Differentiably transform. - x = transform.inv(x_dct) # should be free due to transform cache - - # Simulate a pyro.deterministic() site. - new_fn = dist.Delta(x, event_dim=fn.event_dim) - return new_fn, x + transform = DiscreteCosineTransform(dim=dim, smooth=smooth, cache_size=1) + super().__init__(transform, suffix="dct") diff --git a/pyro/infer/reparam/haar.py b/pyro/infer/reparam/haar.py new file mode 100644 index 0000000000..6a17fe9064 --- /dev/null +++ b/pyro/infer/reparam/haar.py @@ -0,0 +1,30 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from pyro.distributions.transforms.haar import HaarTransform + +from .unit_jacobian import UnitJacobianReparam + + +class HaarReparam(UnitJacobianReparam): + """ + Haar wavelet reparameterizer, using a + :class:`~pyro.distributions.transforms.HaarTransform`. + + This is useful for sequential models where coupling along a time-like axis + (e.g. a banded precision matrix) introduces long-range correlation. This + reparameterizes to a frequency-domain representation where posterior + covariance should be closer to diagonal, thereby improving the accuracy of + diagonal guides in SVI and improving the effectiveness of a diagonal mass + matrix in HMC. + + This reparameterization works only for latent variables, not likelihoods. + + :param int dim: Dimension along which to transform. Must be negative. + This is an absolute dim counting from the right. + :param bool flip: Whether to flip the time axis before applying the + Haar transform. Defaults to false. + """ + def __init__(self, dim=-1, flip=False): + transform = HaarTransform(dim=dim, flip=flip, cache_size=1) + super().__init__(transform, suffix="haar") diff --git a/pyro/infer/reparam/unit_jacobian.py b/pyro/infer/reparam/unit_jacobian.py new file mode 100644 index 0000000000..c6dc262caf --- /dev/null +++ b/pyro/infer/reparam/unit_jacobian.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from torch.distributions import biject_to +from torch.distributions.transforms import ComposeTransform + +import pyro +import pyro.distributions as dist + +from .reparam import Reparam + + +# TODO Replace with .with_cache() once the following is released: +# https://github.com/probtorch/pytorch/pull/153 +def _with_cache(t): + return t.with_cache() if hasattr(t, "with_cache") else t + + +class UnitJacobianReparam(Reparam): + """ + Reparameterizer for :class:`~torch.distributions.transforms.Transform` + objects whose Jacobian determinant is one. + + :param transform: A transform whose Jacobian has determinant 1. 
+ :type transform: ~torch.distributions.transforms.Transform + :param str suffix: A suffix to append to the transformed site. + """ + def __init__(self, transform, suffix="transformed"): + self.transform = _with_cache(transform) + self.suffix = suffix + + def __call__(self, name, fn, obs): + assert obs is None, "TransformReparam does not support observe statements" + assert fn.event_dim >= self.transform.event_dim, ( + "Cannot transform along batch dimension; " + "try converting a batch dimension to an event dimension") + + # Draw noise from the base distribution. + transform = ComposeTransform([_with_cache(biject_to(fn.support).inv), + self.transform]) + x_trans = pyro.sample("{}_{}".format(name, self.suffix), + dist.TransformedDistribution(fn, transform)) + + # Differentiably transform. + x = transform.inv(x_trans) # should be free due to transform cache + + # Simulate a pyro.deterministic() site. + new_fn = dist.Delta(x, event_dim=fn.event_dim) + return new_fn, x diff --git a/tests/contrib/epidemiology/test_seir.py b/tests/contrib/epidemiology/test_seir.py index 68747ba681..c0f231110b 100644 --- a/tests/contrib/epidemiology/test_seir.py +++ b/tests/contrib/epidemiology/test_seir.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_simple_smoke(duration, forecast, options): @@ -46,7 +46,7 @@ def test_simple_smoke(duration, forecast, options): @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_overdispersed_smoke(duration, forecast, options): diff --git a/tests/contrib/epidemiology/test_sir.py b/tests/contrib/epidemiology/test_sir.py index 5a99d018c8..5b5381f21a 100644 --- a/tests/contrib/epidemiology/test_sir.py +++ b/tests/contrib/epidemiology/test_sir.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, {"num_quant_bins": 12}, {"num_quant_bins": 16}, @@ -50,7 +50,7 @@ def test_simple_smoke(duration, forecast, options): @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_overdispersed_smoke(duration, forecast, options): @@ -80,7 +80,7 @@ def test_overdispersed_smoke(duration, forecast, options): @pytest.mark.parametrize("forecast", [7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_sparse_smoke(duration, forecast, options): @@ -121,7 +121,7 @@ def test_sparse_smoke(duration, forecast, options): @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, - {"dct": 1.}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): @@ -165,6 +165,7 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ {}, + {"haar": True}, {"num_quant_bins": 8}, ], ids=str) def test_regional_smoke(duration, forecast, options): diff --git a/tests/distributions/test_haar.py b/tests/distributions/test_haar.py index 441b6852b7..63f3daea57 100644 --- a/tests/distributions/test_haar.py +++ b/tests/distributions/test_haar.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + import torch import pytest diff --git a/tests/infer/reparam/test_haar.py b/tests/infer/reparam/test_haar.py new file mode 100644 index 0000000000..b010e079d6 --- /dev/null +++ b/tests/infer/reparam/test_haar.py @@ -0,0 +1,85 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torch.autograd import grad + +import pyro +import pyro.distributions as dist +from pyro import poutine +from pyro.infer.reparam import HaarReparam +from tests.common import assert_close + + +# Test helper to extract central moments from samples. +def get_moments(x): + n = x.size(0) + x = x.reshape(n, -1) + mean = x.mean(0) + x = x - mean + std = (x * x).mean(0).sqrt() + x = x / std + corr = (x.unsqueeze(-1) * x.unsqueeze(-2)).mean(0).reshape(-1) + return torch.cat([mean, std, corr]) + + +@pytest.mark.parametrize("flip", [False, True]) +@pytest.mark.parametrize("shape,dim", [ + ((6,), -1), + ((2, 5,), -1), + ((4, 2), -2), + ((2, 3, 1), -2), +], ids=str) +def test_normal(shape, dim, flip): + loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() + + def model(): + with pyro.plate_stack("plates", shape[:dim]): + with pyro.plate("particles", 10000): + pyro.sample("x", dist.Normal(loc, scale).expand(shape).to_event(-dim)) + + value = poutine.trace(model).get_trace().nodes["x"]["value"] + expected_probe = get_moments(value) + + rep = HaarReparam(dim=dim, flip=flip) + reparam_model = poutine.reparam(model, {"x": rep}) + trace = poutine.trace(reparam_model).get_trace() + assert isinstance(trace.nodes["x_haar"]["fn"], dist.TransformedDistribution) + assert isinstance(trace.nodes["x"]["fn"], dist.Delta) + value = trace.nodes["x"]["value"] + actual_probe = get_moments(value) + assert_close(actual_probe, expected_probe, atol=0.1) + + for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]): + expected_grads = grad(expected_m.sum(), [loc, scale], retain_graph=True) + actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True) + assert_close(actual_grads[0], expected_grads[0], atol=0.05) + assert_close(actual_grads[1], expected_grads[1], atol=0.05) + + +@pytest.mark.parametrize("flip", [False, True]) +@pytest.mark.parametrize("shape,dim", [ + ((6,), -1), + ((2, 5,), -1), + ((4, 2), -2), + ((2, 3, 1), -2), +], ids=str) +def test_uniform(shape, dim, flip): + + def model(): + with pyro.plate_stack("plates", shape[:dim]): + with pyro.plate("particles", 10000): + pyro.sample("x", dist.Uniform(0, 1).expand(shape).to_event(-dim)) + + value = poutine.trace(model).get_trace().nodes["x"]["value"] + expected_probe = get_moments(value) + + reparam_model = poutine.reparam(model, {"x": HaarReparam(dim=dim, flip=flip)}) + trace = poutine.trace(reparam_model).get_trace() + assert isinstance(trace.nodes["x_haar"]["fn"], dist.TransformedDistribution) + assert isinstance(trace.nodes["x"]["fn"], dist.Delta) + value = trace.nodes["x"]["value"] + actual_probe = get_moments(value) + assert_close(actual_probe, expected_probe, atol=0.1) diff --git a/tests/infer/reparam/test_unit_jacobian.py b/tests/infer/reparam/test_unit_jacobian.py new file mode 100644 index 0000000000..da5d07150f --- /dev/null +++ b/tests/infer/reparam/test_unit_jacobian.py @@ -0,0 +1,55 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torch.autograd import grad + +import pyro +import pyro.distributions as dist +from pyro import poutine +from pyro.distributions.transforms import Permute +from pyro.infer.reparam import UnitJacobianReparam +from tests.common import assert_close + + +# Test helper to extract central moments from samples. +def get_moments(x): + n = x.size(0) + x = x.reshape(n, -1) + mean = x.mean(0) + x = x - mean + std = (x * x).mean(0).sqrt() + x = x / std + corr = (x.unsqueeze(-1) * x.unsqueeze(-2)).mean(0).reshape(-1) + return torch.cat([mean, std, corr]) + + +@pytest.mark.parametrize("shape", [(6,), (4, 5), (2, 1, 3)], ids=str) +def test_normal(shape): + loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() + + def model(): + with pyro.plate_stack("plates", shape[:-1]): + with pyro.plate("particles", 10000): + pyro.sample("x", dist.Normal(loc, scale).expand(shape).to_event(1)) + + value = poutine.trace(model).get_trace().nodes["x"]["value"] + expected_probe = get_moments(value) + + transform = Permute(torch.randperm(shape[-1])) + rep = UnitJacobianReparam(transform) + reparam_model = poutine.reparam(model, {"x": rep}) + trace = poutine.trace(reparam_model).get_trace() + assert isinstance(trace.nodes["x_transformed"]["fn"], dist.TransformedDistribution) + assert isinstance(trace.nodes["x"]["fn"], dist.Delta) + value = trace.nodes["x"]["value"] + actual_probe = get_moments(value) + assert_close(actual_probe, expected_probe, atol=0.1) + + for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]): + expected_grads = grad(expected_m.sum(), [loc, scale], retain_graph=True) + actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True) + assert_close(actual_grads[0], expected_grads[0], atol=0.05) + assert_close(actual_grads[1], expected_grads[1], atol=0.05) diff --git a/tests/test_examples.py b/tests/test_examples.py index 6fd5c90b7f..fbebabd1d1 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -37,11 +37,12 @@ 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --haar', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a', 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', @@ -106,8 +107,9 @@ 'contrib/cevae/synthetic.py --num-epochs=1 --cuda', 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1 --cuda', + 'contrib/epidemiology/sir.py -t=2 
-w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', + 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm/dmm.py --num-epochs=1 --cuda', From e652eedf149c20a886869c478f9e0eda091b37b5 Mon Sep 17 00:00:00 2001 From: martinjankowiak Date: Wed, 20 May 2020 06:44:41 -0700 Subject: [PATCH 6/9] add TruncatedPolyaGamma distribution (#2491) --- docs/source/distributions.rst | 7 +++ pyro/distributions/__init__.py | 2 + pyro/distributions/polya_gamma.py | 66 +++++++++++++++++++++++++ tests/distributions/test_polya_gamma.py | 23 +++++++++ 4 files changed, 98 insertions(+) create mode 100644 pyro/distributions/polya_gamma.py create mode 100644 tests/distributions/test_polya_gamma.py diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index b65cb735bb..de31a3c9f1 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -281,6 +281,13 @@ Stable :undoc-members: :show-inheritance: +TruncatedPolyaGamma +------------------- +.. autoclass:: pyro.distributions.TruncatedPolyaGamma + :members: + :undoc-members: + :show-inheritance: + Unit ---- .. autoclass:: pyro.distributions.Unit diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 9070ea08df..53e6d1aa64 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -21,6 +21,7 @@ from pyro.distributions.mixture import MaskedMixture from pyro.distributions.multivariate_studentt import MultivariateStudentT from pyro.distributions.omt_mvn import OMTMultivariateNormal +from pyro.distributions.polya_gamma import TruncatedPolyaGamma from pyro.distributions.rejector import Rejector from pyro.distributions.relaxed_straight_through import (RelaxedBernoulliStraightThrough, RelaxedOneHotCategoricalStraightThrough) @@ -78,6 +79,7 @@ "Stable", "TorchDistribution", "TransformModule", + "TruncatedPolyaGamma", "Unit", "VonMises3D", "ZeroInflatedPoisson", diff --git a/pyro/distributions/polya_gamma.py b/pyro/distributions/polya_gamma.py new file mode 100644 index 0000000000..78df8116d8 --- /dev/null +++ b/pyro/distributions/polya_gamma.py @@ -0,0 +1,66 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +from torch.distributions import constraints + +from pyro.distributions.torch import Exponential +from pyro.distributions.torch_distribution import TorchDistribution + + +class TruncatedPolyaGamma(TorchDistribution): + """ + This is a PolyaGamma(1, 0) distribution truncated to have finite support in + the interval (0, 2.5). See [1] for details. As a consequence of the truncation + the `log_prob` method is only accurate to about six decimal places. In + addition the provided sampler is a rough approximation that is only meant to + be used in contexts where sample accuracy is not important (e.g. in initialization). + Broadly, this implementation is only intended for usage in cases where good + approximations of the `log_prob` are sufficient, as is the case e.g. in HMC. + + :param tensor prototype: A prototype tensor of arbitrary shape used to determine + the `dtype` and `device` returned by `sample` and `log_prob`. + + References + + [1] 'Bayesian inference for logistic models using Polya-Gamma latent variables' + Nicholas G. Polson, James G. Scott, Jesse Windle. 
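+ Journal of the American Statistical Association 108, 2013.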
+ """ + truncation_point = 2.5 + num_log_prob_terms = 7 + num_gamma_variates = 8 + assert num_log_prob_terms % 2 == 1 + + arg_constraints = {} + support = constraints.interval(0.0, truncation_point) + has_rsample = False + + def __init__(self, prototype, validate_args=None): + self.prototype = prototype + super(TruncatedPolyaGamma, self).__init__(batch_shape=(), event_shape=(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(TruncatedPolyaGamma, _instance) + super(TruncatedPolyaGamma, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self.__dict__.get("_validate_args") + new.prototype = self.prototype + return new + + def sample(self, sample_shape=()): + denom = torch.arange(0.5, self.num_gamma_variates, device=self.prototype.device).pow(2.0) + ones = self.prototype.new_ones((self.num_gamma_variates)) + x = Exponential(ones).sample(self.batch_shape + sample_shape) + x = (x / denom).sum(-1) + return torch.clamp(x * (0.5 / math.pi ** 2), max=self.truncation_point) + + def log_prob(self, value): + value = value.unsqueeze(-1) + two_n_plus_one = 2.0 * torch.arange(0, self.num_log_prob_terms, device=self.prototype.device) + 1.0 + log_terms = two_n_plus_one.log() - 1.5 * value.log() - 0.125 * two_n_plus_one.pow(2.0) / value + even_terms = log_terms[..., ::2] + odd_terms = log_terms[..., 1::2] + sum_even = torch.logsumexp(even_terms, dim=-1).exp() + sum_odd = torch.logsumexp(odd_terms, dim=-1).exp() + return (sum_even - sum_odd).log() - 0.5 * math.log(2.0 * math.pi) diff --git a/tests/distributions/test_polya_gamma.py b/tests/distributions/test_polya_gamma.py new file mode 100644 index 0000000000..6abcf77de6 --- /dev/null +++ b/tests/distributions/test_polya_gamma.py @@ -0,0 +1,23 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.distributions import TruncatedPolyaGamma +from tests.common import assert_close + + +@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 1)]) +def test_polya_gamma(batch_shape, num_points=20000): + d = TruncatedPolyaGamma(prototype=torch.ones(1)).expand(batch_shape) + + # test density approximately normalized + x = torch.linspace(1.0e-6, d.truncation_point, num_points).expand(batch_shape + (num_points,)) + prob = (d.truncation_point / num_points) * torch.logsumexp(d.log_prob(x), dim=-1).exp() + assert_close(prob, torch.tensor(1.0).expand(batch_shape), rtol=1.0e-4) + + # test mean of approximate sampler + z = d.sample(sample_shape=(3000,)) + mean = z.mean(-1) + assert_close(mean, torch.tensor(0.25).expand(batch_shape), rtol=0.07) From ce8dd261f8bb1a4c61618b535dae5aa2d5136ede Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 20 May 2020 12:18:20 -0700 Subject: [PATCH 7/9] Fix reraising logic (#2494) --- pyro/poutine/trace_messenger.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index 917dab6b2c..da9a4f60b9 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -166,7 +166,9 @@ def __call__(self, *args, **kwargs): except (ValueError, RuntimeError): exc_type, exc_value, traceback = sys.exc_info() shapes = self.msngr.trace.format_shapes() - raise exc_type(u"{}\n{}".format(exc_value, shapes)).with_traceback(traceback) + exc = exc_type(u"{}\n{}".format(exc_value, shapes)) + exc = exc.with_traceback(traceback) + raise exc from None self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret) return ret From 8cc51fb0e0da9009c0711c9d786966ffce5e5e0a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 20 May 2020 15:45:27 -0700 Subject: [PATCH 8/9] Implement a SplitReparam and use it in contrib.epidemiology (#2495) * Implement SplitReparam and use in contrib.epidemiology * Add -hfm arg to example scripts * Fix spelling * Add unit tests * Rebalance tests * Rebalance tests * Address review comments --- docs/source/infer.reparam.rst | 9 +++ examples/contrib/epidemiology/regional.py | 2 + examples/contrib/epidemiology/sir.py | 2 + pyro/contrib/epidemiology/compartmental.py | 40 ++++++++-- pyro/contrib/epidemiology/util.py | 1 - pyro/infer/mcmc/util.py | 13 +++- pyro/infer/reparam/__init__.py | 3 + pyro/infer/reparam/split.py | 90 ++++++++++++++++++++++ tests/contrib/conftest.py | 2 +- tests/contrib/epidemiology/test_seir.py | 2 + tests/contrib/epidemiology/test_sir.py | 5 ++ tests/infer/reparam/test_split.py | 60 +++++++++++++++ tests/test_examples.py | 3 +- 13 files changed, 220 insertions(+), 12 deletions(-) create mode 100644 pyro/infer/reparam/split.py create mode 100644 tests/infer/reparam/test_split.py diff --git a/docs/source/infer.reparam.rst b/docs/source/infer.reparam.rst index 1b46eda249..f60c4a30ac 100644 --- a/docs/source/infer.reparam.rst +++ b/docs/source/infer.reparam.rst @@ -95,6 +95,15 @@ Hidden Markov Models :special-members: __call__ :show-inheritance: +Site Splitting +-------------- +.. automodule:: pyro.infer.reparam.split + :members: + :undoc-members: + :member-order: bysource + :special-members: __call__ + :show-inheritance: + Neural Transport ---------------- .. 
automodule:: pyro.infer.reparam.neutra diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index 9a67db44bc..02e85d7e26 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -58,6 +58,7 @@ def hook_fn(kernel, *unused): max_tree_depth=args.max_tree_depth, num_quant_bins=args.num_bins, haar=args.haar, + haar_full_mass=args.haar_full_mass, hook_fn=hook_fn) mcmc.summary() @@ -137,6 +138,7 @@ def main(args): parser.add_argument("-tau", "--recovery-time", default=7.0, type=float) parser.add_argument("-rho", "--response-rate", default=0.5, type=float) parser.add_argument("--haar", action="store_true") + parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int) parser.add_argument("-n", "--num-samples", default=200, type=int) parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index ba6a4a4034..49d5653d4d 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -74,6 +74,7 @@ def hook_fn(kernel, *unused): arrowhead_mass=args.arrowhead_mass, num_quant_bins=args.num_bins, haar=args.haar, + haar_full_mass=args.haar_full_mass, hook_fn=hook_fn) mcmc.summary() @@ -205,6 +206,7 @@ def main(args): help="If finite, use a superspreader model.") parser.add_argument("-rho", "--response-rate", default=0.5, type=float) parser.add_argument("--haar", action="store_true") + parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int) parser.add_argument("-n", "--num-samples", default=200, type=int) parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 0befaef121..29a89617bd 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -19,7 +19,7 @@ from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete from pyro.infer.autoguide import init_to_generated, init_to_value from pyro.infer.mcmc import ArrowheadMassMatrix -from pyro.infer.reparam import HaarReparam +from pyro.infer.reparam import HaarReparam, SplitReparam from pyro.infer.smcfilter import SMCFailed from pyro.util import warn_if_nan @@ -301,16 +301,25 @@ def fit(self, **options): that computational cost is exponential in `num_quant_bins`. Defaults to 4. :param bool haar: Whether to use a Haar wavelet reparameterizer. + :param int haar_full_mass: Number of low frequency Haar components to + include in the full mass matrix. If nonzero this implies + ``haar=True``. :param int heuristic_num_particles: Passed to :meth:`heuristic` as ``num_particles``. Defaults to 1024. :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``. :rtype: ~pyro.infer.mcmc.api.MCMC """ - # Save these options for .predict(). + # Parse options, saving some for use in .predict(). self.num_quant_bins = options.pop("num_quant_bins", 4) haar = options.pop("haar", False) - - # Heuristically initialze to feasible latents. 
+ assert isinstance(haar, bool) + haar_full_mass = options.pop("haar_full_mass", 0) + assert isinstance(haar_full_mass, int) + assert haar_full_mass >= 0 + haar_full_mass = min(haar_full_mass, self.duration) + haar = haar or (haar_full_mass > 0) + + # Heuristically initialize to feasible latents. heuristic_options = {k.replace("heuristic_", ""): options.pop(k) for k in list(options) if k.startswith("heuristic_")} @@ -327,6 +336,13 @@ def heuristic(): x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True)(x) init_values["auxiliary_haar"] = x + if haar_full_mass: + # Also split into low- and high-frequency parts. + x0, x1 = init_values["auxiliary_haar"].split( + [haar_full_mass, self.duration - haar_full_mass], + dim=-2 if self.is_regional else -1) + init_values["auxiliary_haar_split_0"] = x0 + init_values["auxiliary_haar_split_1"] = x1 logger.info("Heuristic init: {}".format(", ".join( "{}={:0.3g}".format(k, v.item()) for k, v in init_values.items() @@ -341,9 +357,17 @@ def heuristic(): if haar: rep = HaarReparam(dim=-2 if self.is_regional else -1, flip=True) model = poutine.reparam(model, {"auxiliary": rep}) + if haar_full_mass: + assert full_mass and isinstance(full_mass, list) + full_mass = full_mass[:] + full_mass[0] = full_mass[0] + ("auxiliary_haar_split_0",) + rep = SplitReparam([haar_full_mass, self.duration - haar_full_mass], + dim=-2 if self.is_regional else -1) + model = poutine.reparam(model, {"auxiliary_haar": rep}) kernel = NUTS(model, full_mass=full_mass, init_strategy=init_to_generated(generate=heuristic), + max_plate_nesting=self.max_plate_nesting, max_tree_depth=max_tree_depth) if options.pop("arrowhead_mass", False): kernel.mass_matrix_adapter = ArrowheadMassMatrix() @@ -352,6 +376,12 @@ def heuristic(): mcmc = MCMC(kernel, **options) mcmc.run() self.samples = mcmc.get_samples() + if haar_full_mass: + # Transform back from SplitReparam coordinates. + self.samples["auxiliary_haar"] = torch.cat([ + self.samples.pop("auxiliary_haar_split_0"), + self.samples.pop("auxiliary_haar_split_1"), + ], dim=-2 if self.is_regional else -1) if haar: # Transform back from Haar coordinates. x = self.samples.pop("auxiliary_haar") @@ -361,7 +391,7 @@ def heuristic(): # Unsqueeze samples to align particle dim for use in poutine.condition. # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples(). - self.samples = align_samples(self.samples, model, + self.samples = align_samples(self.samples, self._vectorized_model, particle_dim=-1 - self.max_plate_nesting) return mcmc # E.g. so user can run mcmc.summary(). 
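
For orientation, a minimal sketch (not part of the patch) of the reparameterizer
composition that the ``haar_full_mass`` option sets up above. The toy model and
the sizes are assumptions chosen for illustration; the ``_haar`` and
``_split_{i}`` site-name suffixes follow the conventions in the code above:

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.infer.reparam import HaarReparam, SplitReparam

    duration, haar_full_mass = 20, 4  # illustrative sizes

    def model():
        # Stand-in for the model's "auxiliary" time-series site.
        pyro.sample("auxiliary",
                    dist.Normal(torch.zeros(duration), 1.).to_event(1))

    # Move to Haar wavelet coordinates (creating site "auxiliary_haar"), then
    # split off the low-frequency components (creating "auxiliary_haar_split_0"
    # and "auxiliary_haar_split_1").
    model = poutine.reparam(model, {"auxiliary": HaarReparam(dim=-1, flip=True)})
    model = poutine.reparam(model, {"auxiliary_haar": SplitReparam(
        [haar_full_mass, duration - haar_full_mass], dim=-1)})

    # Under MCMC the split sites are filled in by the initializer; to trace by
    # hand they must be conditioned, since SplitReparam cannot sample them.
    x = torch.randn(duration)
    conditioned = poutine.condition(model, {
        "auxiliary_haar_split_0": x[:haar_full_mass],
        "auxiliary_haar_split_1": x[haar_full_mass:]})
    trace = poutine.trace(conditioned).get_trace()
    assert trace.nodes["auxiliary"]["value"].shape == (duration,)

The low-frequency site is then eligible for a dense mass matrix block, which is
what ``fit()`` arranges via ``full_mass[0] + ("auxiliary_haar_split_0",)``.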
diff --git a/pyro/contrib/epidemiology/util.py b/pyro/contrib/epidemiology/util.py index 004613e58f..b64f441ae6 100644 --- a/pyro/contrib/epidemiology/util.py +++ b/pyro/contrib/epidemiology/util.py @@ -72,7 +72,6 @@ def align_samples(samples, model, particle_dim): raise ValueError("Cannot align samples, try moving particle_dim left") if pad > 0: shape = value.shape[:1] + (1,) * pad + value.shape[1:] - print("DEBUG reshaping {} : {} -> {}".format(name, value.shape, shape)) samples[name] = value.reshape(shape) return samples diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py index 8a976c910d..b8d6d0a821 100644 --- a/pyro/infer/mcmc/util.py +++ b/pyro/infer/mcmc/util.py @@ -302,7 +302,7 @@ def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_option def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params, max_tries_initial_params=100, num_chains=1, - init_strategy=init_to_uniform): + init_strategy=init_to_uniform, trace=None): params = prototype_params # For empty models, exit early @@ -313,7 +313,8 @@ def _find_valid_initial_params(model, model_args, model_kwargs, transforms, pote num_found = 0 model = InitMessenger(init_strategy)(model) for attempt in range(num_chains * max_tries_initial_params): - trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) + if trace is None: + trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) samples = {name: trace.nodes[name]["value"].detach() for name in params} params = {k: transforms[k](v) for k, v in samples.items()} pe_grad, pe = potential_grad(potential_fn, params) @@ -327,6 +328,7 @@ def _find_valid_initial_params(model, model_args, model_kwargs, transforms, pote return {k: v[0] for k, v in params_per_chain.items()} else: return {k: torch.stack(v) for k, v in params_per_chain.items()} + trace = None raise ValueError("Model specification seems incorrect - cannot find valid initial params.") @@ -382,7 +384,8 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max # No-op if model does not have any discrete latents. model = poutine.enum(config_enumerate(model), first_available_dim=-1 - max_plate_nesting) - model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) + prototype_model = poutine.trace(InitMessenger(init_strategy)(model)) + model_trace = prototype_model.get_trace(*model_args, **model_kwargs) has_enumerable_sites = False prototype_samples = {} for name, node in model_trace.iter_stochastic_nodes(): @@ -411,9 +414,11 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()} # Note that we deliberately do not exercise jit compilation here so as to # enable potential_fn to be picklable (a torch._C.Function cannot be pickled). + # We pass model_trace merely for computational savings. 
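+ # (The trace is consumed on the first initialization attempt only; if that
+ # attempt fails, _find_valid_initial_params falls back to re-tracing.)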
initial_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms,
 pe_maker.get_potential_fn(), prototype_params,
- num_chains=num_chains, init_strategy=init_strategy)
+ num_chains=num_chains, init_strategy=init_strategy,
+ trace=model_trace)
 potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
 return initial_params, potential_fn, transforms, model_trace
diff --git a/pyro/infer/reparam/__init__.py b/pyro/infer/reparam/__init__.py
index cf4e6bff22..2d3debd95a 100644
--- a/pyro/infer/reparam/__init__.py
+++ b/pyro/infer/reparam/__init__.py
@@ -1,3 +1,4 @@
+# Copyright Contributors to the Pyro project.
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0

@@ -7,6 +8,7 @@
 from .hmm import LinearHMMReparam
 from .loc_scale import LocScaleReparam
 from .neutra import NeuTraReparam
+from .split import SplitReparam
 from .stable import LatentStableReparam, StableReparam, SymmetricStableReparam
 from .studentt import StudentTReparam
 from .transform import TransformReparam
@@ -20,6 +22,7 @@
 "LinearHMMReparam",
 "LocScaleReparam",
 "NeuTraReparam",
+ "SplitReparam",
 "StableReparam",
 "StudentTReparam",
 "SymmetricStableReparam",
diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py
new file mode 100644
index 0000000000..5a12fc9ba2
--- /dev/null
+++ b/pyro/infer/reparam/split.py
@@ -0,0 +1,90 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from torch.distributions import constraints
+
+import pyro
+import pyro.distributions as dist
+from pyro.distributions.util import broadcast_shape
+
+from .reparam import Reparam
+
+
+class _ImproperUniform(dist.TorchDistribution):
+ """
+ Internal helper distribution with zero :meth:`log_prob` and undefined
+ :meth:`sample`.
+ """
+ arg_constraints = {}
+
+ def __init__(self, support, batch_shape, event_shape):
+ self._support = support
+ super().__init__(batch_shape, event_shape)
+
+ @constraints.dependent_property
+ def support(self):
+ return self._support
+
+ def expand(self, batch_shape, _instance=None):
+ batch_shape = torch.Size(batch_shape)
+ new = self._get_checked_instance(_ImproperUniform, _instance)
+ new._support = self._support
+ super(_ImproperUniform, new).__init__(batch_shape, self.event_shape)
+ return new
+
+ def log_prob(self, value):
+ batch_shape = value.shape[:value.dim() - self.event_dim]
+ batch_shape = broadcast_shape(batch_shape, self.batch_shape)
+ return torch.zeros(()).expand(batch_shape)
+
+ def sample(self, sample_shape=torch.Size()):
+ raise NotImplementedError("SplitReparam does not support sampling")
+
+
+class SplitReparam(Reparam):
+ """
+ Reparameterizer to split a random variable along a dimension, similar to
+ :func:`torch.split`.
+
+ This is useful for treating different parts of a tensor with different
+ reparameterizers or inference methods. For example, when performing HMC
+ inference on a time series, you can first apply
+ :class:`~pyro.infer.reparam.discrete_cosine.DiscreteCosineReparam` or
+ :class:`~pyro.infer.reparam.haar.HaarReparam`, then apply
+ :class:`SplitReparam` to split into low-frequency and high-frequency
+ components, and finally add the low-frequency components to the
+ ``full_mass`` matrix together with globals.
+
+ :param sections: Size of a single chunk or list of sizes for
+ each chunk.
+ :type sections: list(int)
+ :param int dim: Dimension along which to split. Must be a negative
+ integer.
+ """ + def __init__(self, sections, dim): + assert isinstance(dim, int) and dim < 0 + assert isinstance(sections, list) + assert all(isinstance(size, int) for size in sections) + self.event_dim = -dim + self.sections = sections + + def __call__(self, name, fn, obs): + assert fn.event_dim >= self.event_dim + assert obs is None, "SplitReparam does not support observe statements" + + # Draw independent parts. + dim = fn.event_dim - self.event_dim + left_shape = fn.event_shape[:dim] + right_shape = fn.event_shape[1 + dim:] + parts = [] + for i, size in enumerate(self.sections): + event_shape = left_shape + (size,) + right_shape + parts.append(pyro.sample( + "{}_split_{}".format(name, i), + _ImproperUniform(fn.support, fn.batch_shape, event_shape))) + value = torch.cat(parts, dim=-self.event_dim) + + # Combine parts. + log_prob = fn.log_prob(value) + new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_prob) + return new_fn, value diff --git a/tests/contrib/conftest.py b/tests/contrib/conftest.py index 2b6f5c8045..5878446177 100644 --- a/tests/contrib/conftest.py +++ b/tests/contrib/conftest.py @@ -8,6 +8,6 @@ def pytest_collection_modifyitems(items): for item in items: if item.nodeid.startswith("tests/contrib"): if "stage" not in item.keywords: - item.add_marker(pytest.mark.stage("unit")) + item.add_marker(pytest.mark.stage("integration_batch_1")) if "init" not in item.keywords: item.add_marker(pytest.mark.init(rng_seed=123)) diff --git a/tests/contrib/epidemiology/test_seir.py b/tests/contrib/epidemiology/test_seir.py index c0f231110b..ace9f7f228 100644 --- a/tests/contrib/epidemiology/test_seir.py +++ b/tests/contrib/epidemiology/test_seir.py @@ -13,6 +13,7 @@ @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 2}, {"num_quant_bins": 8}, ], ids=str) def test_simple_smoke(duration, forecast, options): @@ -47,6 +48,7 @@ def test_simple_smoke(duration, forecast, options): @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 2}, {"num_quant_bins": 8}, ], ids=str) def test_overdispersed_smoke(duration, forecast, options): diff --git a/tests/contrib/epidemiology/test_sir.py b/tests/contrib/epidemiology/test_sir.py index 5b5381f21a..adfe921e11 100644 --- a/tests/contrib/epidemiology/test_sir.py +++ b/tests/contrib/epidemiology/test_sir.py @@ -18,6 +18,7 @@ @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 2}, {"num_quant_bins": 8}, {"num_quant_bins": 12}, {"num_quant_bins": 16}, @@ -51,6 +52,7 @@ def test_simple_smoke(duration, forecast, options): @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 2}, {"num_quant_bins": 8}, ], ids=str) def test_overdispersed_smoke(duration, forecast, options): @@ -81,6 +83,7 @@ def test_overdispersed_smoke(duration, forecast, options): @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 3}, {"num_quant_bins": 8}, ], ids=str) def test_sparse_smoke(duration, forecast, options): @@ -122,6 +125,7 @@ def test_sparse_smoke(duration, forecast, options): @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 4}, {"num_quant_bins": 8}, ], ids=str) def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): @@ -166,6 +170,7 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): @pytest.mark.parametrize("options", [ {}, {"haar": True}, + {"haar_full_mass": 2}, {"num_quant_bins": 8}, ], ids=str) def test_regional_smoke(duration, forecast, options): diff --git 
a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py new file mode 100644 index 0000000000..1cb050182e --- /dev/null +++ b/tests/infer/reparam/test_split.py @@ -0,0 +1,60 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torch.autograd import grad + +import pyro +import pyro.distributions as dist +from pyro import poutine +from pyro.infer.reparam import SplitReparam +from tests.common import assert_close + + +@pytest.mark.parametrize("event_shape,splits,dim", [ + ((6,), [2, 1, 3], -1), + ((2, 5,), [2, 3], -1), + ((4, 2), [1, 3], -2), + ((2, 3, 1), [1, 2], -2), +], ids=str) +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +def test_normal(batch_shape, event_shape, splits, dim): + shape = batch_shape + event_shape + loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_() + scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_() + + def model(): + with pyro.plate_stack("plates", batch_shape): + pyro.sample("x", + dist.Normal(loc, scale) + .to_event(len(event_shape))) + + # Run without reparam. + trace = poutine.trace(model).get_trace() + expected_value = trace.nodes["x"]["value"] + expected_log_prob = trace.log_prob_sum() + expected_grads = grad(expected_log_prob, [loc, scale], create_graph=True) + + # Run with reparam. + split_values = { + "x_split_{}".format(i): xi + for i, xi in enumerate(expected_value.split(splits, dim))} + rep = SplitReparam(splits, dim) + reparam_model = poutine.reparam(model, {"x": rep}) + reparam_model = poutine.condition(reparam_model, split_values) + trace = poutine.trace(reparam_model).get_trace() + assert all(name in trace.nodes for name in split_values) + assert isinstance(trace.nodes["x"]["fn"], dist.Delta) + assert trace.nodes["x"]["fn"].batch_shape == batch_shape + assert trace.nodes["x"]["fn"].event_shape == event_shape + + # Check values. + actual_value = trace.nodes["x"]["value"] + assert_close(actual_value, expected_value, atol=0.1) + + # Check log prob. 
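+ # (They match exactly: the "x_split_*" sites use an improper uniform with
+ # zero log_prob, while the Delta at site "x" carries the original
+ # distribution's log_prob as its log_density.)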
+ actual_log_prob = trace.log_prob_sum() + assert_close(actual_log_prob, expected_log_prob) + actual_grads = grad(actual_log_prob, [loc, scale], create_graph=True) + assert_close(actual_grads, expected_grads) diff --git a/tests/test_examples.py b/tests/test_examples.py index fbebabd1d1..39c78e3c74 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -39,10 +39,11 @@ 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --haar', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a', 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar', + 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', From a11e170b8cb0347ce22fdcd3bb68de6ac8a217de Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 May 2020 11:47:19 -0700 Subject: [PATCH 9/9] Add pearson correlation plot to examples/.../sir.py (#2497) --- examples/contrib/epidemiology/sir.py | 37 ++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 49d5653d4d..ae2112f9d9 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -10,6 +10,7 @@ import math import torch +from torch.distributions import biject_to, constraints import pyro from pyro.contrib.epidemiology import OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel @@ -90,7 +91,7 @@ def hook_fn(kernel, *unused): return model.samples -def evaluate(args, samples): +def evaluate(args, model, samples): # Print estimated values. names = {"basic_reproduction_number": "R0", "response_rate": "rho"} @@ -107,6 +108,7 @@ def evaluate(args, samples): import matplotlib.pyplot as plt import seaborn as sns + # Plot individual histograms. fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names))) axes[0].set_title("Posterior parameter estimates") for ax, (name, key) in zip(axes, names.items()): @@ -118,6 +120,7 @@ def evaluate(args, samples): ax.legend(loc="best") plt.tight_layout() + # Plot pairwise joint distributions for selected variables. covariates = [(name, samples[name]) for name in names.values()] for i, aux in enumerate(samples["auxiliary"].unbind(-2)): covariates.append(("aux[{},0]".format(i), aux[:, 0])) @@ -137,6 +140,36 @@ def evaluate(args, samples): plt.tight_layout() plt.subplots_adjust(wspace=0, hspace=0) + # Plot Pearson correlation for every pair of unconstrained variables. 
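+ # (Unconstrained coordinates, reached via the inverse of each site's
+ # biject_to transform, are the coordinates in which NUTS actually moves,
+ # so correlations here indicate which blocks might benefit from a full
+ # mass matrix.)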
+ def unconstrain(constraint, value):
+ value = biject_to(constraint).inv(value)
+ return value.reshape(args.num_samples, -1)
+
+ covariates = [
+ ("R0", unconstrain(constraints.positive, samples["R0"])),
+ ("rho", unconstrain(constraints.unit_interval, samples["rho"]))]
+ if "k" in samples:
+ covariates.append(
+ ("k", unconstrain(constraints.positive, samples["k"])))
+ constraint = constraints.interval(-0.5, model.population + 0.5)
+ for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)):
+ covariates.append((name, unconstrain(constraint, aux)))
+ x = torch.cat([v for _, v in covariates], dim=-1)
+ x -= x.mean(0)
+ x /= x.std(0)
+ x = x.t().matmul(x)
+ x /= args.num_samples
+ x.clamp_(min=-1, max=1)
+ plt.figure(figsize=(8, 8))
+ plt.imshow(x, cmap="bwr")
+ ticks = torch.tensor([0] + [v.size(-1) for _, v in covariates]).cumsum(0)
+ ticks = (ticks[1:] + ticks[:-1]) / 2
+ plt.yticks(ticks, [name for name, _ in covariates])
+ plt.xticks(())
+ plt.tick_params(length=0)
+ plt.title("Pearson correlation (unconstrained coordinates)")
+ plt.tight_layout()
+
 
 def predict(args, model, truth):
 samples = model.predict(forecast=args.forecast)
@@ -183,7 +216,7 @@ def main(args):
 samples = infer(args, model)
 
 # Evaluate fit.
- evaluate(args, samples)
+ evaluate(args, model, samples)
 
 # Predict latent time series.
 if args.forecast:
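
As a self-contained sanity check of the recipe used in ``evaluate()`` above
(standardize each unconstrained coordinate, then form the Gram matrix), the
following sketch uses made-up sizes in place of real posterior samples:

    import torch

    num_samples, num_coords = 1000, 5
    x = torch.randn(num_samples, num_coords)  # stand-in for unconstrained samples
    x = x - x.mean(0)
    x = x / x.std(0)
    corr = x.t().matmul(x) / num_samples  # Pearson correlation matrix
    corr.clamp_(min=-1, max=1)
    # Diagonal entries are (num_samples - 1) / num_samples, i.e. ~1, because
    # torch.std defaults to the unbiased estimator.
    print(corr.diag())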