Skip to content

Commit

Permalink
add UnitLowerCholeskyTransform; change default parameterization for A…
Browse files Browse the repository at this point in the history
…utoMultivariateNormal (#2972)
  • Loading branch information
martinjankowiak authored Nov 28, 2021
1 parent 3c6a591 commit f9b7d3d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 1 deletion.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,13 @@ SoftplusTransform
:undoc-members:
:show-inheritance:

UnitLowerCholeskyTransform
--------------------------
.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform
:members:
:undoc-members:
:show-inheritance:

TransformModules
~~~~~~~~~~~~~~~~

Expand Down
19 changes: 19 additions & 0 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,31 @@ class _SoftplusLowerCholesky(type(lower_cholesky)):
pass


class _UnitLowerCholesky(Constraint):
"""
Constrain to lower-triangular square matrices with all ones diagonals.
"""

event_dim = 2

def check(self, value):
value_tril = value.tril()
lower_triangular = (
(value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
)

ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0]
return lower_triangular & ones_diagonal


corr_matrix = _CorrMatrix()
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
sphere = _Sphere()
softplus_positive = _SoftplusPositive()
softplus_lower_cholesky = _SoftplusLowerCholesky()
unit_lower_cholesky = _UnitLowerCholesky()
corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED

__all__ = [
Expand All @@ -119,6 +137,7 @@ class _SoftplusLowerCholesky(type(lower_cholesky)):
"softplus_lower_cholesky",
"softplus_positive",
"sphere",
"unit_lower_cholesky",
]

__all__.extend(torch_constraints)
Expand Down
6 changes: 6 additions & 0 deletions pyro/distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
)
from .spline_coupling import SplineCoupling, spline_coupling
from .sylvester import Sylvester, sylvester
from .unit_cholesky import UnitLowerCholeskyTransform

########################################
# register transforms
Expand Down Expand Up @@ -132,6 +133,11 @@ def _transform_to_softplus_lower_cholesky(constraint):
return SoftplusLowerCholeskyTransform()


@transform_to.register(constraints.unit_lower_cholesky)
def _transform_to_unit_lower_cholesky(constraint):
return UnitLowerCholeskyTransform()


def iterated(repeats, base_fn, *args, **kwargs):
"""
Helper function to compose a sequence of bijective transforms with potentially
Expand Down
32 changes: 32 additions & 0 deletions pyro/distributions/transforms/unit_cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

from pyro.distributions.constraints import unit_lower_cholesky


class UnitLowerCholeskyTransform(Transform):
"""
Transform from unconstrained matrices to lower-triangular matrices with
all ones diagonals.
"""

domain = constraints.independent(constraints.real, 2)
codomain = unit_lower_cholesky

def __eq__(self, other):
return isinstance(other, UnitLowerCholeskyTransform)

def _call(self, x):
return x.tril(-1) + torch.eye(x.size(-1), device=x.device, dtype=x.dtype)

def _inverse(self, y):
return y


__all__ = [
"UnitLowerCholeskyTransform",
]
2 changes: 1 addition & 1 deletion pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ class AutoMultivariateNormal(AutoContinuous):
"""

scale_constraint = constraints.softplus_positive
scale_tril_constraint = constraints.softplus_lower_cholesky
scale_tril_constraint = constraints.unit_lower_cholesky

def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
if not isinstance(init_scale, float) or not (init_scale > 0):
Expand Down
1 change: 1 addition & 0 deletions tests/distributions/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def transform_to_vec(x_vec):
[
T.LowerCholeskyTransform(),
T.SoftplusLowerCholeskyTransform(),
T.UnitLowerCholeskyTransform(),
],
ids=lambda t: type(t).__name__,
)
Expand Down

0 comments on commit f9b7d3d

Please sign in to comment.