diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index ec58a174a3..b8ed1c6d93 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -529,6 +529,13 @@ SoftplusTransform :undoc-members: :show-inheritance: +UnitLowerCholeskyTransform +-------------------------- +.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform + :members: + :undoc-members: + :show-inheritance: + TransformModules ~~~~~~~~~~~~~~~~ diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index d6bd399459..711b295cc2 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -101,6 +101,23 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): pass +class _UnitLowerCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with all ones diagonals. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) + + ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0] + return lower_triangular & ones_diagonal + + corr_matrix = _CorrMatrix() integer = _Integer() ordered_vector = _OrderedVector() @@ -108,6 +125,7 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): sphere = _Sphere() softplus_positive = _SoftplusPositive() softplus_lower_cholesky = _SoftplusLowerCholesky() +unit_lower_cholesky = _UnitLowerCholesky() corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED __all__ = [ @@ -119,6 +137,7 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): "softplus_lower_cholesky", "softplus_positive", "sphere", + "unit_lower_cholesky", ] __all__.extend(torch_constraints) diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 298c9c2ade..917da4cffb 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -79,6 +79,7 @@ ) from .spline_coupling import SplineCoupling, spline_coupling from .sylvester import Sylvester, sylvester +from .unit_cholesky import UnitLowerCholeskyTransform ######################################## # register transforms @@ -132,6 +133,11 @@ def _transform_to_softplus_lower_cholesky(constraint): return SoftplusLowerCholeskyTransform() +@transform_to.register(constraints.unit_lower_cholesky) +def _transform_to_unit_lower_cholesky(constraint): + return UnitLowerCholeskyTransform() + + def iterated(repeats, base_fn, *args, **kwargs): """ Helper function to compose a sequence of bijective transforms with potentially diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py new file mode 100644 index 0000000000..8262f858fb --- /dev/null +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -0,0 +1,32 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.distributions import constraints +from torch.distributions.transforms import Transform + +from pyro.distributions.constraints import unit_lower_cholesky + + +class UnitLowerCholeskyTransform(Transform): + """ + Transform from unconstrained matrices to lower-triangular matrices with + all ones diagonals. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = unit_lower_cholesky + + def __eq__(self, other): + return isinstance(other, UnitLowerCholeskyTransform) + + def _call(self, x): + return x.tril(-1) + torch.eye(x.size(-1), device=x.device, dtype=x.dtype) + + def _inverse(self, y): + return y + + +__all__ = [ + "UnitLowerCholeskyTransform", +] diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 467e9eb003..292e3bcd19 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -863,7 +863,7 @@ class AutoMultivariateNormal(AutoContinuous): """ scale_constraint = constraints.softplus_positive - scale_tril_constraint = constraints.softplus_lower_cholesky + scale_tril_constraint = constraints.unit_lower_cholesky def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1): if not isinstance(init_scale, float) or not (init_scale > 0): diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 4344eb56d2..560d3a189e 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -456,6 +456,7 @@ def transform_to_vec(x_vec): [ T.LowerCholeskyTransform(), T.SoftplusLowerCholeskyTransform(), + T.UnitLowerCholeskyTransform(), ], ids=lambda t: type(t).__name__, )