diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 3cfc8a3990..fa3394a813 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -459,6 +459,20 @@ Permute
     :undoc-members:
     :show-inheritance:
 
+SoftplusLowerCholeskyTransform
+------------------------------
+.. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+SoftplusTransform
+-----------------
+.. autoclass:: pyro.distributions.transforms.SoftplusTransform
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 TransformModules
 ~~~~~~~~~~~~~~~~
 
diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py
index d8c4fb6253..60db57f9b6 100644
--- a/pyro/distributions/constraints.py
+++ b/pyro/distributions/constraints.py
@@ -8,7 +8,12 @@
 import torch
 from torch.distributions.constraints import Constraint
 from torch.distributions.constraints import __all__ as torch_constraints
-from torch.distributions.constraints import independent, positive, positive_definite
+from torch.distributions.constraints import (
+    independent,
+    lower_cholesky,
+    positive,
+    positive_definite,
+)
 
 
 # TODO move this upstream to torch.distributions
@@ -81,11 +86,22 @@ def check(self, value):
         return ordered_vector.check(value) & independent(positive, 1).check(value)
 
 
+class _SoftplusPositive(type(positive)):
+    def __init__(self):
+        super().__init__(lower_bound=0.0)
+
+
+class _SoftplusLowerCholesky(type(lower_cholesky)):
+    pass
+
+
 corr_matrix = _CorrMatrix()
 integer = _Integer()
 ordered_vector = _OrderedVector()
 positive_ordered_vector = _PositiveOrderedVector()
 sphere = _Sphere()
+softplus_positive = _SoftplusPositive()
+softplus_lower_cholesky = _SoftplusLowerCholesky()
 corr_cholesky_constraint = corr_cholesky  # noqa: F405 DEPRECATED
 
 __all__ = [
@@ -94,6 +110,8 @@ def check(self, value):
     'integer',
     'ordered_vector',
     'positive_ordered_vector',
+    'softplus_lower_cholesky',
+    'softplus_positive',
     'sphere',
 ]
 
diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py
index a56f189756..398403532e 100644
--- a/pyro/distributions/transforms/__init__.py
+++ b/pyro/distributions/transforms/__init__.py
@@ -68,6 +68,7 @@
 from .planar import ConditionalPlanar, Planar, conditional_planar, planar
 from .polynomial import Polynomial, polynomial
 from .radial import ConditionalRadial, Radial, conditional_radial, radial
+from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
 from .spline import ConditionalSpline, Spline, conditional_spline, spline
 from .spline_autoregressive import (
     ConditionalSplineAutoregressive,
@@ -117,6 +118,17 @@ def _transform_to_positive_definite(constraint):
     return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])
 
 
+@biject_to.register(constraints.softplus_positive)
+@transform_to.register(constraints.softplus_positive)
+def _transform_to_softplus_positive(constraint):
+    return SoftplusTransform()
+
+
+@transform_to.register(constraints.softplus_lower_cholesky)
+def _transform_to_softplus_lower_cholesky(constraint):
+    return SoftplusLowerCholeskyTransform()
+
+
 def iterated(repeats, base_fn, *args, **kwargs):
     """
     Helper function to compose a sequence of bijective transforms with potentially
@@ -167,6 +179,8 @@ def iterated(repeats, base_fn, *args, **kwargs):
     'Planar',
     'Polynomial',
     'Radial',
+    'SoftplusLowerCholeskyTransform',
+    'SoftplusTransform',
     'Spline',
     'SplineAutoregressive',
     'SplineCoupling',
diff --git a/pyro/distributions/transforms/softplus.py b/pyro/distributions/transforms/softplus.py
new file mode 100644
index 0000000000..b58a889612
--- /dev/null
+++ b/pyro/distributions/transforms/softplus.py
@@ -0,0 +1,60 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from torch.distributions import constraints
+from torch.distributions.transforms import Transform
+from torch.nn.functional import softplus
+
+
+def softplus_inv(y):
+    return y + y.neg().expm1().neg().log()
+
+
+# Backport of https://github.com/pytorch/pytorch/pull/52300
+class SoftplusTransform(Transform):
+    r"""
+    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, SoftplusTransform)
+
+    def _call(self, x):
+        return softplus(x)
+
+    def _inverse(self, y):
+        return softplus_inv(y)
+
+    def log_abs_det_jacobian(self, x, y):
+        return -softplus(-x)
+
+
+class SoftplusLowerCholeskyTransform(Transform):
+    """
+    Transform from unconstrained matrices to lower-triangular matrices with
+    nonnegative diagonal entries. This is useful for parameterizing positive
+    definite matrices in terms of their Cholesky factorization.
+    """
+    domain = constraints.independent(constraints.real, 2)
+    codomain = constraints.lower_cholesky
+
+    def __eq__(self, other):
+        return isinstance(other, SoftplusLowerCholeskyTransform)
+
+    def _call(self, x):
+        diag = softplus(x.diagonal(dim1=-2, dim2=-1))
+        return x.tril(-1) + diag.diag_embed()
+
+    def _inverse(self, y):
+        diag = softplus_inv(y.diagonal(dim1=-2, dim2=-1))
+        return y.tril(-1) + diag.diag_embed()
+
+
+__all__ = [
+    'SoftplusTransform',
+    'SoftplusLowerCholeskyTransform',
+]
diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 618e53baa6..741b0e801b 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -9,7 +9,7 @@
     def model():
         ...
 
-    guide = AutoDiagonalNormal(model)  # a mean field guide
+    guide = AutoNormal(model)  # a mean field guide
     svi = SVI(model, guide, Adam({'lr': 1e-3}), Trace_ELBO())
 
 Automatic guides can also be combined using :func:`pyro.poutine.block` and
@@ -23,11 +23,12 @@ def model():
 
 import torch
 from torch import nn
-from torch.distributions import biject_to, constraints
+from torch.distributions import biject_to
 
 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
+from pyro.distributions import constraints
 from pyro.distributions.transforms import affine_autoregressive, iterated
 from pyro.distributions.util import broadcast_shape, eye_like, sum_rightmost
 from pyro.infer.autoguide.initialization import (
@@ -431,6 +432,9 @@ class AutoNormal(AutoGuide):
         or iterable of plates. Plates not returned will be created
         automatically as usual. This is useful for data subsampling.
     """
+
+    scale_constraint = constraints.softplus_positive
+
     def __init__(self, model, *,
                  init_loc_fn=init_to_feasible,
                  init_scale=0.1,
@@ -472,7 +476,8 @@ def _setup_prototype(self, *args, **kwargs):
                 init_scale = torch.full_like(init_loc, self._init_scale)
 
             _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim))
-            _deep_setattr(self.scales, name, PyroParam(init_scale, constraints.positive, event_dim))
+            _deep_setattr(self.scales, name,
+                          PyroParam(init_scale, self.scale_constraint, event_dim))
 
     def _get_loc_and_scale(self, name):
         site_loc = _deep_getattr(self.locs, name)
@@ -813,6 +818,8 @@ class AutoMultivariateNormal(AutoContinuous):
         (unconstrained transformed) latent variable.
     """
 
+    scale_tril_constraint = constraints.softplus_lower_cholesky
+
     def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
         if not isinstance(init_scale, float) or not (init_scale > 0):
             raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -824,7 +831,7 @@ def _setup_prototype(self, *args, **kwargs):
         # Initialize guide params
         self.loc = nn.Parameter(self._init_loc())
         self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim) * self._init_scale,
-                                    constraints.lower_cholesky)
+                                    self.scale_tril_constraint)
 
     def get_base_dist(self):
         return dist.Normal(torch.zeros_like(self.loc), torch.zeros_like(self.loc)).to_event(1)
@@ -863,6 +870,8 @@ class AutoDiagonalNormal(AutoContinuous):
         (unconstrained transformed) latent variable.
     """
 
+    scale_constraint = constraints.softplus_positive
+
     def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
         if not isinstance(init_scale, float) or not (init_scale > 0):
             raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -874,7 +883,7 @@ def _setup_prototype(self, *args, **kwargs):
         # Initialize guide params
         self.loc = nn.Parameter(self._init_loc())
         self.scale = PyroParam(self.loc.new_full((self.latent_dim,), self._init_scale),
-                               constraints.positive)
+                               self.scale_constraint)
 
     def get_base_dist(self):
         return dist.Normal(torch.zeros_like(self.loc), torch.zeros_like(self.loc)).to_event(1)
@@ -918,6 +927,8 @@ class AutoLowRankMultivariateNormal(AutoContinuous):
         deviation of each (unconstrained transformed) latent variable.
     """
 
+    scale_constraint = constraints.softplus_positive
+
    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1, rank=None):
         if not isinstance(init_scale, float) or not (init_scale > 0):
             raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -935,7 +946,7 @@ def _setup_prototype(self, *args, **kwargs):
             self.rank = int(round(self.latent_dim ** 0.5))
         self.scale = PyroParam(
             self.loc.new_full((self.latent_dim,), 0.5 ** 0.5 * self._init_scale),
-            constraint=constraints.positive)
+            constraint=self.scale_constraint)
         self.cov_factor = nn.Parameter(
             self.loc.new_empty(self.latent_dim, self.rank).normal_(0, 1 / self.rank ** 0.5))
 
diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py
index 98de4ca43e..3059349a5d 100644
--- a/tests/contrib/autoguide/test_inference.py
+++ b/tests/contrib/autoguide/test_inference.py
@@ -85,7 +85,7 @@ def do_test_auto(self, N, reparameterized, n_steps):
                                         AutoLowRankMultivariateNormal, AutoLaplaceApproximation])
 @pytest.mark.parametrize('Elbo', [Trace_ELBO, TraceMeanField_ELBO])
 def test_auto_diagonal_gaussians(auto_class, Elbo):
-    n_steps = 3501 if auto_class == AutoDiagonalNormal else 6001
+    n_steps = 3001
 
     def model():
         pyro.sample("x", dist.Normal(-0.2, 1.2))
@@ -95,7 +95,8 @@ def model():
         guide = auto_class(model, rank=1)
     else:
         guide = auto_class(model)
-    adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
+    adam = optim.ClippedAdam({"lr": .01, "betas": (0.95, 0.999),
+                              "lrd": 0.1 ** (1 / n_steps)})
     svi = SVI(model, guide, adam, loss=Elbo())
 
     for k in range(n_steps):
diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py
index 47c3b1d315..e326aec6ca 100644
--- a/tests/distributions/test_transforms.py
+++ b/tests/distributions/test_transforms.py
@@ -345,13 +345,18 @@ def test_sylvester(self):
         self._test(T.sylvester, inverse=False)
 
     def test_normalize_transform(self):
-        for p in [1., 2.]:
-            self._test(lambda p: T.Normalize(p=p), autodiff=False)
+        self._test(lambda p: T.Normalize(p=p), autodiff=False)
+
+    def test_softplus(self):
+        self._test(lambda _: T.SoftplusTransform(), autodiff=False)
 
 
 @pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)])
 @pytest.mark.parametrize('dim', [2, 3, 5])
-@pytest.mark.parametrize('transform', [T.CholeskyTransform(), T.CorrMatrixCholeskyTransform()])
+@pytest.mark.parametrize('transform', [
+    T.CholeskyTransform(),
+    T.CorrMatrixCholeskyTransform(),
+], ids=lambda t: type(t).__name__)
 def test_cholesky_transform(batch_shape, dim, transform):
     arange = torch.arange(dim)
     domain = transform.domain
@@ -385,3 +390,21 @@ def transform_to_vec(x_vec):
     assert log_det.shape == batch_shape
     assert_close(y, x_mat.cholesky())
     assert_close(transform.inv(y), x_mat)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)])
+@pytest.mark.parametrize('dim', [2, 3, 5])
+@pytest.mark.parametrize('transform', [
+    T.LowerCholeskyTransform(),
+    T.SoftplusLowerCholeskyTransform(),
+], ids=lambda t: type(t).__name__)
+def test_lower_cholesky_transform(transform, batch_shape, dim):
+    shape = batch_shape + (dim, dim)
+    x = torch.randn(shape)
+    y = transform(x)
+    assert y.shape == shape
+    x2 = transform.inv(y)
+    assert x2.shape == shape
+    y2 = transform(x2)
+    assert y2.shape == shape
+    assert_close(y, y2)
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index 6ea9384faa..085b1be688 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -240,7 +240,7 @@ def model():
         pyro.sample("z", dist.Beta(2.0, 2.0))
 
     guide = auto_class(model)
-    optim = Adam({'lr': 0.05, 'betas': (0.8, 0.99)})
+    optim = Adam({'lr': 0.02, 'betas': (0.8, 0.99)})
     elbo = Elbo(strict_enumeration_warning=False, num_particles=100, vectorize_particles=True)
     infer = SVI(model, guide, optim, elbo)
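
Usage sketch (not part of the patch): a minimal sanity check of the new transforms, assuming only the code added above plus a standard PyTorch install; the tensor values are arbitrary illustrations.

import torch

import pyro.distributions.transforms as T
from pyro.distributions import constraints

t = T.SoftplusTransform()
x = torch.randn(5)
y = t(x)                                        # softplus(x), strictly positive
assert (y > 0).all()
assert torch.allclose(t.inv(y), x, atol=1e-5)   # round-trip through softplus_inv

# log|dy/dx| = -softplus(-x) = log(sigmoid(x)), the log-derivative of softplus
assert torch.allclose(t.log_abs_det_jacobian(x, y), torch.sigmoid(x).log(), atol=1e-5)

# The registrations above route transform_to through the new transform.
from torch.distributions import transform_to
assert isinstance(transform_to(constraints.softplus_positive), T.SoftplusTransform)

# SoftplusLowerCholeskyTransform yields a lower-triangular factor with positive diagonal.
m = T.SoftplusLowerCholeskyTransform()(torch.randn(3, 3))
assert (m.triu(1) == 0).all() and (m.diagonal() > 0).all()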
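
Because the guides now read their scale parameterization from overridable class attributes (scale_constraint, scale_tril_constraint) instead of hard-coding constraints.positive or constraints.lower_cholesky, reverting a single guide to the previous exp-based parameterization is a two-line subclass. A sketch; the subclass name is hypothetical, not part of the PR:

from pyro.distributions import constraints
from pyro.infer.autoguide import AutoDiagonalNormal

class ExpScaleAutoDiagonalNormal(AutoDiagonalNormal):
    # constraints.positive transforms through exp, the pre-patch behavior
    scale_constraint = constraints.positive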