From 4c1fa6862bee395c4ca64749c6114257df048476 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Fri, 26 Nov 2021 15:43:04 -0800 Subject: [PATCH 1/7] initial commit --- .../distributions/transforms/unit_cholesky.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 pyro/distributions/transforms/unit_cholesky.py diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py new file mode 100644 index 0000000000..717a612adf --- /dev/null +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -0,0 +1,44 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.distributions import constraints +from torch.distributions.transforms import Transform + + +class _UnitLowerCholesky(torch.distributions.Constraint): + """ + Constrain to lower-triangular square matrices with all ones diagonals. + """ + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + + ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0] + return lower_triangular & ones_diagonal + + +class UnitLowerCholeskyTransform(Transform): + """ + Transform from unconstrained matrices to lower-triangular matrices with + all ones diagonals. + """ + domain = constraints.independent(constraints.real, 2) + codomain = constraints.unit_lower_cholesky + + def __eq__(self, other): + return isinstance(other, UnitLowerCholeskyTransform) + + def _call(self, x): + return x.tril(-1) + torch.eye(x.size(-1), device=x.device, dtype=x.dtype) + + def _inverse(self, y): + return y.tril(-1) + + +__all__ = [ + "_UnitLowerCholesky", + "UnitLowerCholeskyTransform", +] From e73186c46a8956546e5ad8d95f90232e0cf1574a Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Fri, 26 Nov 2021 15:52:20 -0800 Subject: [PATCH 2/7] add constraint --- pyro/distributions/constraints.py | 15 +++++++++++++++ pyro/distributions/transforms/__init__.py | 6 ++++++ pyro/distributions/transforms/unit_cholesky.py | 15 --------------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index d6bd399459..67bcfe2e16 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -101,6 +101,20 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): pass +class _UnitLowerCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with all ones diagonals. + """ + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + + ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0] + return lower_triangular & ones_diagonal + + corr_matrix = _CorrMatrix() integer = _Integer() ordered_vector = _OrderedVector() @@ -108,6 +122,7 @@ class _SoftplusLowerCholesky(type(lower_cholesky)): sphere = _Sphere() softplus_positive = _SoftplusPositive() softplus_lower_cholesky = _SoftplusLowerCholesky() +unit_lower_cholesky = _UnitLowerCholesky() corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED __all__ = [ diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 298c9c2ade..a2939c1039 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -70,6 +70,7 @@ from .power import PositivePowerTransform from .radial import ConditionalRadial, Radial, conditional_radial, radial from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform +from .unit_cholesky import UnitLowerCholeskyTransform from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import ( ConditionalSplineAutoregressive, @@ -132,6 +133,11 @@ def _transform_to_softplus_lower_cholesky(constraint): return SoftplusLowerCholeskyTransform() +@transform_to.register(constraints.unit_lower_cholesky) +def _transform_to_unit_lower_cholesky(constraint): + return UnitLowerCholeskyTransform() + + def iterated(repeats, base_fn, *args, **kwargs): """ Helper function to compose a sequence of bijective transforms with potentially diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py index 717a612adf..edd33a1ab1 100644 --- a/pyro/distributions/transforms/unit_cholesky.py +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -6,20 +6,6 @@ from torch.distributions.transforms import Transform -class _UnitLowerCholesky(torch.distributions.Constraint): - """ - Constrain to lower-triangular square matrices with all ones diagonals. - """ - event_dim = 2 - - def check(self, value): - value_tril = value.tril() - lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] - - ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0] - return lower_triangular & ones_diagonal - - class UnitLowerCholeskyTransform(Transform): """ Transform from unconstrained matrices to lower-triangular matrices with @@ -39,6 +25,5 @@ def _inverse(self, y): __all__ = [ - "_UnitLowerCholesky", "UnitLowerCholeskyTransform", ] From c54309041717b4ccca0af1111d509edd9b5df089 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Fri, 26 Nov 2021 16:00:45 -0800 Subject: [PATCH 3/7] add test --- pyro/distributions/transforms/unit_cholesky.py | 4 +++- pyro/infer/autoguide/guides.py | 2 +- tests/distributions/test_transforms.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py index edd33a1ab1..b5102129cc 100644 --- a/pyro/distributions/transforms/unit_cholesky.py +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -5,6 +5,8 @@ from torch.distributions import constraints from torch.distributions.transforms import Transform +from pyro.distributions.constraints import unit_lower_cholesky + class UnitLowerCholeskyTransform(Transform): """ @@ -12,7 +14,7 @@ class UnitLowerCholeskyTransform(Transform): all ones diagonals. """ domain = constraints.independent(constraints.real, 2) - codomain = constraints.unit_lower_cholesky + codomain = unit_lower_cholesky def __eq__(self, other): return isinstance(other, UnitLowerCholeskyTransform) diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 467e9eb003..292e3bcd19 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -863,7 +863,7 @@ class AutoMultivariateNormal(AutoContinuous): """ scale_constraint = constraints.softplus_positive - scale_tril_constraint = constraints.softplus_lower_cholesky + scale_tril_constraint = constraints.unit_lower_cholesky def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1): if not isinstance(init_scale, float) or not (init_scale > 0): diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 4344eb56d2..560d3a189e 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -456,6 +456,7 @@ def transform_to_vec(x_vec): [ T.LowerCholeskyTransform(), T.SoftplusLowerCholeskyTransform(), + T.UnitLowerCholeskyTransform(), ], ids=lambda t: type(t).__name__, ) From a88eb885ec1b04ade47bdccf2b02cb0a7d8a0bab Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Fri, 26 Nov 2021 16:01:31 -0800 Subject: [PATCH 4/7] make format --- pyro/distributions/constraints.py | 5 ++++- pyro/distributions/transforms/__init__.py | 2 +- pyro/distributions/transforms/unit_cholesky.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index 67bcfe2e16..c6136f4af3 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -105,11 +105,14 @@ class _UnitLowerCholesky(Constraint): """ Constrain to lower-triangular square matrices with all ones diagonals. """ + event_dim = 2 def check(self, value): value_tril = value.tril() - lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).min(-1)[0] return lower_triangular & ones_diagonal diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index a2939c1039..917da4cffb 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -70,7 +70,6 @@ from .power import PositivePowerTransform from .radial import ConditionalRadial, Radial, conditional_radial, radial from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform -from .unit_cholesky import UnitLowerCholeskyTransform from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import ( ConditionalSplineAutoregressive, @@ -80,6 +79,7 @@ ) from .spline_coupling import SplineCoupling, spline_coupling from .sylvester import Sylvester, sylvester +from .unit_cholesky import UnitLowerCholeskyTransform ######################################## # register transforms diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py index b5102129cc..5f6e6fa2f8 100644 --- a/pyro/distributions/transforms/unit_cholesky.py +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -13,6 +13,7 @@ class UnitLowerCholeskyTransform(Transform): Transform from unconstrained matrices to lower-triangular matrices with all ones diagonals. """ + domain = constraints.independent(constraints.real, 2) codomain = unit_lower_cholesky From ff9827786898b909816159d8be3e7c4e3e88519c Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Fri, 26 Nov 2021 16:04:01 -0800 Subject: [PATCH 5/7] add docstring --- docs/source/distributions.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index ec58a174a3..b3d1008673 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -522,6 +522,13 @@ SoftplusLowerCholeskyTransform :undoc-members: :show-inheritance: +UnitLowerCholeskyTransform +-------------------------- +.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform + :members: + :undoc-members: + :show-inheritance: + SoftplusTransform ----------------- .. autoclass:: pyro.distributions.transforms.SoftplusTransform From a288dd431f06fbd663881bb16d732fba0e28a48e Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 28 Nov 2021 09:18:54 -0800 Subject: [PATCH 6/7] address comments --- docs/source/distributions.rst | 12 ++++++------ pyro/distributions/transforms/unit_cholesky.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index b3d1008673..b8ed1c6d93 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -522,16 +522,16 @@ SoftplusLowerCholeskyTransform :undoc-members: :show-inheritance: -UnitLowerCholeskyTransform --------------------------- -.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform +SoftplusTransform +----------------- +.. autoclass:: pyro.distributions.transforms.SoftplusTransform :members: :undoc-members: :show-inheritance: -SoftplusTransform ------------------ -.. autoclass:: pyro.distributions.transforms.SoftplusTransform +UnitLowerCholeskyTransform +-------------------------- +.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform :members: :undoc-members: :show-inheritance: diff --git a/pyro/distributions/transforms/unit_cholesky.py b/pyro/distributions/transforms/unit_cholesky.py index 5f6e6fa2f8..8262f858fb 100644 --- a/pyro/distributions/transforms/unit_cholesky.py +++ b/pyro/distributions/transforms/unit_cholesky.py @@ -24,7 +24,7 @@ def _call(self, x): return x.tril(-1) + torch.eye(x.size(-1), device=x.device, dtype=x.dtype) def _inverse(self, y): - return y.tril(-1) + return y __all__ = [ From 15cf6b278f4a85eaad2f85eea7986ad538716bb1 Mon Sep 17 00:00:00 2001 From: martin jankowiak Date: Sun, 28 Nov 2021 13:00:02 -0800 Subject: [PATCH 7/7] fixdocs --- pyro/distributions/constraints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index c6136f4af3..711b295cc2 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -137,6 +137,7 @@ def check(self, value): "softplus_lower_cholesky", "softplus_positive", "sphere", + "unit_lower_cholesky", ] __all__.extend(torch_constraints)