Implement ProjectedNormal distribution and reparametrizer #2736

Merged · 17 commits · Jan 13, 2021
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -281,6 +281,13 @@ OrderedLogistic
:undoc-members:
:show-inheritance:

ProjectedNormal
---------------
.. autoclass:: pyro.distributions.ProjectedNormal
:members:
:undoc-members:
:show-inheritance:

RelaxedBernoulliStraightThrough
-------------------------------
.. autoclass:: pyro.distributions.RelaxedBernoulliStraightThrough
9 changes: 9 additions & 0 deletions docs/source/infer.reparam.rst
@@ -95,6 +95,15 @@ Stable Distributions
:special-members: __call__
:show-inheritance:

Projected Normal Distributions
------------------------------
.. automodule:: pyro.infer.reparam.projected_normal
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:

Hidden Markov Models
--------------------
.. automodule:: pyro.infer.reparam.hmm
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -48,6 +48,7 @@
from pyro.distributions.one_two_matching import OneTwoMatching
from pyro.distributions.ordered_logistic import OrderedLogistic
from pyro.distributions.polya_gamma import TruncatedPolyaGamma
from pyro.distributions.projected_normal import ProjectedNormal
from pyro.distributions.rejector import Rejector
from pyro.distributions.relaxed_straight_through import (
RelaxedBernoulliStraightThrough,
@@ -113,6 +114,7 @@
"OneOneMatching",
"OneTwoMatching",
"OrderedLogistic",
"ProjectedNormal",
"Rejector",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategoricalStraightThrough",
24 changes: 23 additions & 1 deletion pyro/distributions/constraints.py
@@ -43,6 +43,25 @@ def __repr__(self):
return self.__class__.__name__[1:]


class _Sphere(Constraint):
"""
Constrain to the Euclidean sphere of any dimension.
"""
reltol = 10. # Relative to finfo.eps.

def check(self, value):
eps = torch.finfo(value.dtype).eps
try:
norm = torch.linalg.norm(value, dim=-1) # torch 1.7+
except AttributeError:
norm = value.norm(dim=-1) # torch 1.6
error = (norm - 1).abs()
return error < self.reltol * eps * value.size(-1) ** 0.5

def __repr__(self):
return self.__class__.__name__[1:]


class _CorrCholesky(Constraint):
"""
Constrains to lower-triangular square matrices with positive diagonals and
@@ -73,12 +92,14 @@ def check(self, value):
corr_cholesky_constraint = _CorrCholesky()
integer = _Integer()
ordered_vector = _OrderedVector()
sphere = _Sphere()

__all__ = [
'IndependentConstraint',
'corr_cholesky_constraint',
'integer',
'ordered_vector',
'sphere',
]

__all__.extend(torch_constraints)
@@ -100,7 +121,8 @@ def check(self, value):
_name,
"alias of :class:`torch.distributions.constraints.{}`".format(_name)
if globals()[_name].__module__.startswith("torch") else
".. autoclass:: {}".format(_name)
".. autoclass:: {}".format(_name if type(globals()[_name]) is type else
type(globals()[_name]).__name__)
)
for _name in sorted(__all__)
])
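
For illustration, a minimal sketch (editor's example, not part of the diff) of how the new `sphere` constraint behaves on batched vectors:

```python
import torch
from pyro.distributions import constraints

v = torch.randn(4, 3)
v = v / v.norm(dim=-1, keepdim=True)    # project each row onto the unit 2-sphere
print(constraints.sphere.check(v))      # tensor([True, True, True, True])
print(constraints.sphere.check(2 * v))  # tensor([False, False, False, False])
```

The check is tolerance-based (`reltol * eps * sqrt(dim)`) rather than exact, since normalized floating-point vectors have norms only approximately equal to 1.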
158 changes: 158 additions & 0 deletions pyro/distributions/projected_normal.py
@@ -0,0 +1,158 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch

from . import constraints
from .torch_distribution import TorchDistribution


def safe_project(x):
"""
Safely project a vector onto the sphere. This avoids the singularity at zero
by mapping it to the vector ``[1, 0, 0, ..., 0]``.

:param Tensor x: A vector
:returns: A normalized version ``x / ||x||_2``.
:rtype: Tensor
"""
try:
norm = torch.linalg.norm(x, dim=-1, keepdim=True) # torch 1.7+
except AttributeError:
norm = x.norm(dim=-1, keepdim=True) # torch 1.6
x = x / norm.clamp(min=torch.finfo(x.dtype).tiny)
x.data[..., 0][x.data.eq(0).all(dim=-1)] = 1 # Avoid the singularity.
return x


class ProjectedNormal(TorchDistribution):
"""
Projected isotropic normal distribution of arbitrary dimension.

This distribution over directional data is qualitatively similar to the von
Mises and von Mises-Fisher distributions, but permits tractable variational
inference via reparametrized gradients.

To use this distribution with autoguides, use ``poutine.reparam`` with a
:class:`~pyro.infer.reparam.projected_normal.ProjectedNormalReparam`
reparametrizer in the model, e.g.::

@poutine.reparam(config={"direction": ProjectedNormalReparam()})
def model():
direction = pyro.sample("direction",
ProjectedNormal(torch.zeros(3)))
...

.. note:: This implements :meth:`log_prob` only for dimensions {2,3}.

[1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
"The General Projected Normal Distribution of Arbitrary Dimension:
Modeling and Bayesian Inference"
https://projecteuclid.org/euclid.ba/1453211962
"""
arg_constraints = {"concentration": constraints.real_vector}
support = constraints.sphere
has_rsample = True
_log_prob_impls = {} # maps dim -> function(concentration, value)

def __init__(self, concentration, *, validate_args=None):
assert concentration.dim() >= 1
self.concentration = concentration
batch_shape = concentration.shape[:-1]
event_shape = concentration.shape[-1:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
new = self._get_checked_instance(ProjectedNormal, _instance)
new.concentration = self.concentration.expand(batch_shape + (-1,))
super(ProjectedNormal, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self.__dict__.get('_validate_args')
return new

@property
def mean(self):
"""
Note this is the mean in the sense of a centroid in the submanifold
that minimizes expected squared geodesic distance.
"""
return safe_project(self.concentration)

@property
def mode(self):
return safe_project(self.concentration)

def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
x = self.concentration.new_empty(shape).normal_()
x = x + self.concentration
x = safe_project(x)
return x

def log_prob(self, value):
if self._validate_args:
event_shape = value.shape[-1:]
if event_shape != self.event_shape:
raise ValueError(f"Expected event shape {self.event_shape}, "
f"but got {event_shape}")
self._validate_sample(value)
dim = int(self.concentration.size(-1))
try:
impl = self._log_prob_impls[dim]
except KeyError:
msg = f"ProjectedNormal.log_prob() is not implemented for dim = {dim}."
if value.requires_grad: # For latent variables but not observations.
msg += " Consider using poutine.reparam with ProjectedNormalReparam."
raise NotImplementedError(msg)
return impl(self.concentration, value)

@classmethod
def _register_log_prob(cls, dim, fn=None):
if fn is None:
return lambda fn: cls._register_log_prob(dim, fn)
cls._log_prob_impls[dim] = fn
return fn


def _dot(x, y):
return (x[..., None, :] @ y[..., None])[..., 0, 0]


@ProjectedNormal._register_log_prob(dim=2)
def _log_prob_2(concentration, value):
# We integrate along a ray, factorizing the integrand as a product of:
# a truncated normal distribution over coordinate t parallel to the ray, and
# a univariate normal distribution over coordinate r perpendicular to the ray.
t = _dot(concentration, value)
t2 = t.square()
r2 = _dot(concentration, concentration) - t2
perp_part = r2.mul(-0.5) - 0.5 * math.log(2 * math.pi)

# This is the log of a definite integral, computed by mathematica:
# Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
# = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2
para_part = (t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5)
+ t * (1 + (t * 0.5 ** 0.5).erf())).mul(0.5).log()

return para_part + perp_part


@ProjectedNormal._register_log_prob(dim=3)
def _log_prob_3(concentration, value):
# We integrate along a ray, factorizing the integrand as a product of:
# a truncated normal distribution over coordinate t parallel to the ray, and
# a bivariate normal distribution over coordinate r perpendicular to the ray.
t = _dot(concentration, value)
t2 = t.square()
r2 = _dot(concentration, concentration) - t2
perp_part = r2.mul(-0.5) - math.log(2 * math.pi)

# This is the log of a definite integral, computed by mathematica:
# Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
# = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2
para_part = (t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
+ (1 + t2) * (1 + (t * 0.5 ** 0.5).erf()) / 2).log()

return para_part + perp_part
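
Both registered implementations evaluate the ray integral p(v) = ∫₀^∞ N(t·v; concentration, I) t^{d−1} dt in closed form, matching the Mathematica integrals quoted in the comments (t^{d−1} is the Jacobian factor of polar coordinates). A minimal usage sketch (editor's example, not part of the diff):

```python
import torch
from pyro.distributions import ProjectedNormal

concentration = torch.tensor([2.0, 0.0, 0.0])  # points along the x-axis
d = ProjectedNormal(concentration)

x = d.rsample(torch.Size([5]))  # five reparametrized unit vectors on the 2-sphere
print(x.norm(dim=-1))           # all approximately 1.0
print(d.log_prob(x).shape)      # torch.Size([5]); dim=3 has a registered impl
print(d.mode)                   # tensor([1., 0., 0.]), i.e. safe_project(concentration)
```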
9 changes: 5 additions & 4 deletions pyro/distributions/von_mises_3d.py
@@ -4,9 +4,9 @@
import math

import torch
from torch.distributions import constraints

from pyro.distributions import TorchDistribution
from . import constraints
from .torch_distribution import TorchDistribution


class VonMises3D(TorchDistribution):
@@ -19,7 +19,8 @@ class VonMises3D(TorchDistribution):
must be a normalized 3-vector that lies on the 2-sphere.

See :class:`~pyro.distributions.VonMises` for a 2D polar coordinate cousin
of this distribution.
of this distribution. See :class:`~pyro.distributions.ProjectedNormal` for
a qualitatively similar distribution that implements more functionality.

Currently only :meth:`log_prob` is implemented.

@@ -28,7 +29,7 @@
magnitude is the concentration.
"""
arg_constraints = {'concentration': constraints.real}
support = constraints.real # TODO implement constraints.sphere or similar
support = constraints.sphere

def __init__(self, concentration, validate_args=None):
if concentration.dim() < 1 or concentration.shape[-1] != 3:
12 changes: 8 additions & 4 deletions pyro/infer/autoguide/guides.py
@@ -31,13 +31,14 @@ def model():
from pyro.distributions.transforms import affine_autoregressive, iterated
from pyro.distributions.util import broadcast_shape, eye_like, sum_rightmost
from pyro.infer.autoguide.initialization import InitMessenger, init_to_feasible, init_to_median
from pyro.infer.autoguide.utils import _product
from pyro.infer.enum import config_enumerate
from pyro.nn import PyroModule, PyroParam
from pyro.ops.hessian import hessian
from pyro.ops.tensor_utils import periodic_repeat
from pyro.poutine.util import site_is_subsample

from .utils import _product, helpful_support_errors


def _deep_setattr(obj, key, val):
"""
@@ -355,7 +356,8 @@ def _setup_prototype(self, *args, **kwargs):
value = periodic_repeat(value, full_size, dim).contiguous()

value = PyroParam(value, site["fn"].support, event_dim)
_deep_setattr(self, name, value)
with helpful_support_errors(site):
_deep_setattr(self, name, value)

def forward(self, *args, **kwargs):
"""
@@ -447,7 +449,8 @@ def _setup_prototype(self, *args, **kwargs):
# Initialize guide params
for name, site in self.prototype_trace.iter_stochastic_nodes():
# Collect unconstrained event_dims, which may differ from constrained event_dims.
init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
with helpful_support_errors(site):
init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
self._event_dims[name] = event_dim

@@ -591,7 +594,8 @@ def _setup_prototype(self, *args, **kwargs):
for name, site in self.prototype_trace.iter_stochastic_nodes():
# Collect the shapes of unconstrained values.
# These may differ from the shapes of constrained values.
self._unconstrained_shapes[name] = biject_to(site["fn"].support).inv(site["value"]).shape
with helpful_support_errors(site):
self._unconstrained_shapes[name] = biject_to(site["fn"].support).inv(site["value"]).shape

# Collect independence contexts.
self._cond_indep_stacks[name] = site["cond_indep_stack"]
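
These `helpful_support_errors` wrappers turn a bare `NotImplementedError` from `biject_to` into actionable advice. A hedged sketch of the pattern that advice recommends, assuming `ProjectedNormalReparam` is importable from `pyro.infer.reparam.projected_normal` as documented in the infer.reparam.rst change above:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoNormal
from pyro.infer.reparam.projected_normal import ProjectedNormalReparam

def model():
    # A latent direction on the 2-sphere; plain AutoNormal cannot handle this site.
    return pyro.sample("direction", dist.ProjectedNormal(torch.zeros(3)))

# Reparametrize the spherical site so the autoguide sees an unconstrained auxiliary.
reparam_model = poutine.reparam(model, config={"direction": ProjectedNormalReparam()})
guide = AutoNormal(reparam_model)
```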
4 changes: 3 additions & 1 deletion pyro/infer/autoguide/initialization.py
@@ -20,9 +20,11 @@
from pyro.poutine.messenger import Messenger
from pyro.util import torch_isnan

from .utils import helpful_support_errors

# TODO: move this file out of `autoguide` in a minor release


def _is_multivariate(d):
while isinstance(d, (Independent, MaskedDistribution)):
d = d.base_dist
@@ -180,7 +182,7 @@ def __init__(self, init_fn):
def _pyro_sample(self, msg):
if msg["done"] or msg["is_observed"] or type(msg["fn"]).__name__ == "_Subsample":
return
with torch.no_grad():
with torch.no_grad(), helpful_support_errors(msg):
value = self.init_fn(msg)
if is_validation_enabled() and msg["value"] is not None:
if not isinstance(value, type(msg["value"])):
26 changes: 26 additions & 0 deletions pyro/infer/autoguide/utils.py
@@ -1,6 +1,8 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager

from pyro import poutine


@@ -33,3 +35,27 @@ def mean_field_entropy(model, args, whitelist=None):
if whitelist is None or name in whitelist:
entropy += site["fn"].entropy()
return entropy


@contextmanager
def helpful_support_errors(site):
try:
yield
except NotImplementedError as e:
support_name = repr(site["fn"].support).lower()
if "integer" in support_name or "boolean" in support_name:
name = site["name"]
raise ValueError(
f"Continuous inference cannot handle discrete sample site '{name}'. "
"Consider enumerating that variable as documented in "
"https://pyro.ai/examples/enumeration.html . "
"If you are already enumerating, take care to hide this site when "
"constructing an autoguide, e.g. "
f"guide = AutoNormal(poutine.block(model, hide=['{name}'])).")
if "sphere" in support_name:
name = site["name"]
raise ValueError(
f"Continuous inference cannot handle spherical sample site '{name}'. "
"Consider using ProjectedNormal distribution together with "
"poutine.reparam and a ProjectedNormalReparam reparametrizer.")
raise e from None
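
A small sketch (editor's example; the `site` dict below is a hypothetical stand-in carrying only the two keys the context manager reads) of how the error translation works:

```python
import torch
import pyro.distributions as dist
from pyro.infer.autoguide.utils import helpful_support_errors

site = {"name": "direction", "fn": dist.ProjectedNormal(torch.zeros(3))}
try:
    with helpful_support_errors(site):
        raise NotImplementedError  # stand-in for a failing biject_to(constraints.sphere)
except ValueError as e:
    print(e)  # Continuous inference cannot handle spherical sample site 'direction'. ...
```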