Stabilize autoguide scale parameters via SoftplusTransform #2767

Merged Feb 28, 2021 · 25 commits
15 changes: 14 additions & 1 deletion pyro/distributions/constraints.py
@@ -5,7 +5,7 @@
from torch.distributions.constraints import * # noqa F403
from torch.distributions.constraints import Constraint
from torch.distributions.constraints import __all__ as torch_constraints
-from torch.distributions.constraints import independent, positive, positive_definite
+from torch.distributions.constraints import independent, lower_cholesky, positive, positive_definite


# TODO move this upstream to torch.distributions
@@ -78,11 +78,22 @@ def check(self, value):
        return ordered_vector.check(value) & independent(positive, 1).check(value)


class _SoftplusPositive(type(positive)):
    def __init__(self):
        super().__init__(lower_bound=0.0)


class _SoftplusLowerCholesky(type(lower_cholesky)):
    pass


corr_matrix = _CorrMatrix()
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
sphere = _Sphere()
softplus_positive = _SoftplusPositive()
softplus_lower_cholesky = _SoftplusLowerCholesky()
corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED

__all__ = [
@@ -91,6 +102,8 @@ def check(self, value):
    'integer',
    'ordered_vector',
    'positive_ordered_vector',
    'softplus_lower_cholesky',
    'softplus_positive',
    'sphere',
]

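A small usage sketch, not part of the diff: the new constraints subclass the same classes as their torch counterparts, so `check` behaves identically and only the transform registered for them (see the next file) differs.

import torch
from pyro.distributions import constraints

# check() is inherited from the torch constraint classes, so membership
# tests are unchanged; the difference is purely in which transform
# transform_to/biject_to return for these constraints.
print(constraints.softplus_positive.check(torch.tensor([0.5, 2.0])))  # tensor([True, True])
print(constraints.softplus_positive.check(torch.tensor([-1.0])))      # tensor([False])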
14 changes: 14 additions & 0 deletions pyro/distributions/transforms/__init__.py
@@ -31,6 +31,7 @@
from .planar import ConditionalPlanar, Planar, conditional_planar, planar
from .polynomial import Polynomial, polynomial
from .radial import ConditionalRadial, Radial, conditional_radial, radial
from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
from .spline import ConditionalSpline, Spline, conditional_spline, spline
from .spline_autoregressive import (ConditionalSplineAutoregressive, SplineAutoregressive,
                                     conditional_spline_autoregressive, spline_autoregressive)
@@ -76,6 +77,17 @@ def _transform_to_positive_definite(constraint):
    return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])


@biject_to.register(constraints.softplus_positive)
@transform_to.register(constraints.softplus_positive)
def _transform_to_softplus_positive(constraint):
    return SoftplusTransform()


@transform_to.register(constraints.softplus_lower_cholesky)
def _transform_to_softplus_lower_cholesky(constraint):
    return SoftplusLowerCholeskyTransform()


def iterated(repeats, base_fn, *args, **kwargs):
"""
Helper function to compose a sequence of bijective transforms with potentially
@@ -126,6 +138,8 @@ def iterated(repeats, base_fn, *args, **kwargs):
    'Planar',
    'Polynomial',
    'Radial',
    'SoftplusLowerCholeskyTransform',
    'SoftplusTransform',
    'Spline',
    'SplineAutoregressive',
    'SplineCoupling',
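To see what the registrations above buy, a quick sketch that is not from the PR: dispatching `transform_to` on the new constraint returns the softplus-based transform instead of the default exp-based parameterization used for `constraints.positive`.

import torch
from torch.distributions import transform_to
from pyro.distributions import constraints

t = transform_to(constraints.softplus_positive)
print(type(t).__name__)  # SoftplusTransform
x = torch.randn(3)
y = t(x)
assert (y > 0).all()
assert torch.allclose(t.inv(y), x, atol=1e-5)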
60 changes: 60 additions & 0 deletions pyro/distributions/transforms/softplus.py
@@ -0,0 +1,60 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.nn.functional import softplus


def softplus_inv(y):
    return y + y.neg().expm1().neg().log()


# Backport of https://github.com/pytorch/pytorch/pull/52300
class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        return softplus(x)

    def _inverse(self, y):
        return softplus_inv(y)

    def log_abs_det_jacobian(self, x, y):
        return -softplus(-x)


class SoftplusLowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries. This is useful for parameterizing positive
    definite matrices in terms of their Cholesky factorization.
    """
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, SoftplusLowerCholeskyTransform)

    def _call(self, x):
        diag = softplus(x.diagonal(dim1=-2, dim2=-1))
        return x.tril(-1) + diag.diag_embed()

    def _inverse(self, y):
        diag = softplus_inv(y.diagonal(dim1=-2, dim2=-1))
        return y.tril(-1) + diag.diag_embed()


__all__ = [
    'SoftplusTransform',
    'SoftplusLowerCholeskyTransform',
]
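For intuition, a numerical check (plain torch, not part of the PR) of the two identities this file relies on: `softplus_inv` computes log(exp(y) - 1) in a numerically stable form, and the log-Jacobian uses log sigmoid(x) = -softplus(-x).

import torch
from torch.nn.functional import softplus

def softplus_inv(y):
    # log(exp(y) - 1) = y + log(1 - exp(-y)), stable for large y
    return y + y.neg().expm1().neg().log()

x = torch.randn(100)
assert torch.allclose(softplus_inv(softplus(x)), x, atol=1e-5)
# d/dx softplus(x) = sigmoid(x), and log sigmoid(x) = -softplus(-x),
# matching SoftplusTransform.log_abs_det_jacobian above.
assert torch.allclose(torch.log(torch.sigmoid(x)), -softplus(-x), atol=1e-6)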
23 changes: 17 additions & 6 deletions pyro/infer/autoguide/guides.py
@@ -9,7 +9,7 @@
    def model():
        ...

-    guide = AutoDiagonalNormal(model)  # a mean field guide
+    guide = AutoNormal(model)  # a mean field guide
    svi = SVI(model, guide, Adam({'lr': 1e-3}), Trace_ELBO())

Automatic guides can also be combined using :func:`pyro.poutine.block` and
@@ -23,11 +23,12 @@ def model():

import torch
from torch import nn
-from torch.distributions import biject_to, constraints
+from torch.distributions import biject_to

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import affine_autoregressive, iterated
from pyro.distributions.util import broadcast_shape, eye_like, sum_rightmost
from pyro.infer.autoguide.initialization import InitMessenger, init_to_feasible, init_to_median
@@ -427,6 +428,9 @@ class AutoNormal(AutoGuide):
        or iterable of plates. Plates not returned will be created
        automatically as usual. This is useful for data subsampling.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(self, model, *,
                 init_loc_fn=init_to_feasible,
                 init_scale=0.1,
@@ -468,7 +472,8 @@ def _setup_prototype(self, *args, **kwargs):
            init_scale = torch.full_like(init_loc, self._init_scale)

            _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim))
-            _deep_setattr(self.scales, name, PyroParam(init_scale, constraints.positive, event_dim))
+            _deep_setattr(self.scales, name,
+                          PyroParam(init_scale, self.scale_constraint, event_dim))

    def _get_loc_and_scale(self, name):
        site_loc = _deep_getattr(self.locs, name)
@@ -809,6 +814,8 @@ class AutoMultivariateNormal(AutoContinuous):
        (unconstrained transformed) latent variable.
    """

    scale_tril_constraint = constraints.softplus_lower_cholesky

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -820,7 +827,7 @@ def _setup_prototype(self, *args, **kwargs):
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim) * self._init_scale,
-                                    constraints.lower_cholesky)
+                                    self.scale_tril_constraint)

    def get_base_dist(self):
        return dist.Normal(torch.zeros_like(self.loc), torch.ones_like(self.loc)).to_event(1)
@@ -859,6 +866,8 @@ class AutoDiagonalNormal(AutoContinuous):
        (unconstrained transformed) latent variable.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -870,7 +879,7 @@ def _setup_prototype(self, *args, **kwargs):
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        self.scale = PyroParam(self.loc.new_full((self.latent_dim,), self._init_scale),
-                               constraints.positive)
+                               self.scale_constraint)

    def get_base_dist(self):
        return dist.Normal(torch.zeros_like(self.loc), torch.ones_like(self.loc)).to_event(1)
@@ -914,6 +923,8 @@ class AutoLowRankMultivariateNormal(AutoContinuous):
        deviation of each (unconstrained transformed) latent variable.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1, rank=None):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
@@ -931,7 +942,7 @@ def _setup_prototype(self, *args, **kwargs):
            self.rank = int(round(self.latent_dim ** 0.5))
        self.scale = PyroParam(
            self.loc.new_full((self.latent_dim,), 0.5 ** 0.5 * self._init_scale),
-            constraint=constraints.positive)
+            constraint=self.scale_constraint)
        self.cov_factor = nn.Parameter(
            self.loc.new_empty(self.latent_dim, self.rank).normal_(0, 1 / self.rank ** 0.5))

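Because the constraints are class attributes rather than hard-coded in `_setup_prototype`, reverting a single guide to the old behavior is a one-line subclass. A hypothetical example (the subclass name is made up, not from the PR):

from pyro.distributions import constraints
from pyro.infer.autoguide import AutoNormal

class AutoNormalExpScale(AutoNormal):
    # Hypothetical subclass: opt back into the previous exp/log
    # parameterization of the posterior scales.
    scale_constraint = constraints.positive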
29 changes: 26 additions & 3 deletions tests/distributions/test_transforms.py
@@ -345,13 +345,18 @@ def test_sylvester(self):
        self._test(T.sylvester, inverse=False)

    def test_normalize_transform(self):
-        for p in [1., 2.]:
-            self._test(lambda p: T.Normalize(p=p), autodiff=False)
+        self._test(lambda p: T.Normalize(p=p), autodiff=False)

    def test_softplus(self):
        self._test(lambda _: T.SoftplusTransform(), autodiff=False)


@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)])
@pytest.mark.parametrize('dim', [2, 3, 5])
-@pytest.mark.parametrize('transform', [T.CholeskyTransform(), T.CorrMatrixCholeskyTransform()])
+@pytest.mark.parametrize('transform', [
+    T.CholeskyTransform(),
+    T.CorrMatrixCholeskyTransform(),
+], ids=lambda t: type(t).__name__)
def test_cholesky_transform(batch_shape, dim, transform):
    arange = torch.arange(dim)
    domain = transform.domain
@@ -385,3 +390,21 @@ def transform_to_vec(x_vec):
    assert log_det.shape == batch_shape
    assert_close(y, x_mat.cholesky())
    assert_close(transform.inv(y), x_mat)


@pytest.mark.parametrize('batch_shape', [(), (7,), (6, 5)])
@pytest.mark.parametrize('dim', [2, 3, 5])
@pytest.mark.parametrize('transform', [
    T.LowerCholeskyTransform(),
    T.SoftplusLowerCholeskyTransform(),
], ids=lambda t: type(t).__name__)
def test_lower_cholesky_transform(transform, batch_shape, dim):
    shape = batch_shape + (dim, dim)
    x = torch.randn(shape)
    y = transform(x)
    assert y.shape == shape
    x2 = transform.inv(y)
    assert x2.shape == shape
    y2 = transform(x2)
    assert y2.shape == shape
    assert_close(y, y2)
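Note the round trip is tested starting from y, not x: the transform discards the strict upper triangle, so it is a projection rather than a bijection on arbitrary matrices. A short sketch, not part of the PR, of why `assert_close(y, y2)` is the right check:

import torch
import pyro.distributions.transforms as T

t = T.SoftplusLowerCholeskyTransform()
x = torch.randn(3, 3)
# Inputs differing only in the strict upper triangle map to the same output,
# so x cannot be recovered, but y -> inv -> forward is an identity.
x_alt = x + torch.triu(torch.ones(3, 3), diagonal=1)
assert torch.allclose(t(x), t(x_alt))
assert torch.allclose(t(t.inv(t(x))), t(x))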
2 changes: 1 addition & 1 deletion tests/infer/test_autoguide.py
@@ -227,7 +227,7 @@ def model():
pyro.sample("z", dist.Beta(2.0, 2.0))

    guide = auto_class(model)
-    optim = Adam({'lr': 0.05, 'betas': (0.8, 0.99)})
+    optim = Adam({'lr': 0.02, 'betas': (0.8, 0.99)})
    elbo = Elbo(strict_enumeration_warning=False,
                num_particles=100, vectorize_particles=True)
    infer = SVI(model, guide, optim, elbo)
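Finally, a hypothetical end-to-end sketch (model, data, and hyperparameters are made up, not from the PR) showing that user-facing code is unchanged; only the internal parameterization of the scales moves from exp to softplus:

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

def model():
    pyro.sample("z", dist.Normal(0., 1.))

guide = AutoDiagonalNormal(model, init_scale=0.1)
svi = SVI(model, guide, Adam({"lr": 1e-2}), Trace_ELBO())
for step in range(10):
    svi.step()

# The unconstrained parameter behind guide.scale is now softplus_inv(scale)
# rather than log(scale), so gradient steps shrink the scale only linearly
# near zero instead of geometrically, which is the stabilization the PR
# title refers to.
print(guide.scale)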