From a7dea88a40062795ba76d179bd39adf4a333bc01 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 22 Apr 2023 19:05:04 +0000 Subject: [PATCH 1/6] deprecate CorrLCholeskyTransform --- docs/source/distributions.rst | 7 -- pyro/distributions/transforms/__init__.py | 14 +--- pyro/distributions/transforms/cholesky.py | 83 ++--------------------- tests/distributions/test_lkj.py | 2 +- tests/distributions/test_transforms.py | 2 +- 5 files changed, 11 insertions(+), 97 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index ca9b0b5d94..ccb51b088d 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -466,13 +466,6 @@ CholeskyTransform :undoc-members: :show-inheritance: -CorrLCholeskyTransform ----------------------- -.. autoclass:: pyro.distributions.transforms.CorrLCholeskyTransform - :members: - :undoc-members: - :show-inheritance: - CorrMatrixCholeskyTransform --------------------------- .. autoclass:: pyro.distributions.transforms.CorrMatrixCholeskyTransform diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 917da4cffb..ba2b50bafc 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -30,11 +30,7 @@ from .basic import ELUTransform, LeakyReLUTransform, elu, leaky_relu from .batchnorm import BatchNorm, batchnorm from .block_autoregressive import BlockAutoregressive, block_autoregressive -from .cholesky import ( - CholeskyTransform, - CorrLCholeskyTransform, - CorrMatrixCholeskyTransform, -) +from .cholesky import CholeskyTransform, CorrMatrixCholeskyTransform from .discrete_cosine import DiscreteCosineTransform from .generalized_channel_permute import ( ConditionalGeneralizedChannelPermute, @@ -90,17 +86,11 @@ def _transform_to_sphere(constraint): return Normalize() -@biject_to.register(constraints.corr_cholesky) -@transform_to.register(constraints.corr_cholesky) -def _transform_to_corr_cholesky(constraint): - return CorrLCholeskyTransform() - - @biject_to.register(constraints.corr_matrix) @transform_to.register(constraints.corr_matrix) def _transform_to_corr_matrix(constraint): return ComposeTransform( - [CorrLCholeskyTransform(), CorrMatrixCholeskyTransform().inv] + [CorrCholeskyTransform(), CorrMatrixCholeskyTransform().inv] ) diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index d2e4d22684..d3532baa69 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -4,87 +4,18 @@ import math import torch -from torch.distributions.transforms import Transform +from torch.distributions.transforms import CorrCholeskyTransform, Transform from .. import constraints -def _vector_to_l_cholesky(z): - D = (1.0 + math.sqrt(1.0 + 8.0 * z.shape[-1])) / 2.0 - if D % 1 != 0: - raise ValueError("Correlation matrix transformation requires d choose 2 inputs") - D = int(D) - x = torch.zeros(z.shape[:-1] + (D, D), dtype=z.dtype, device=z.device) - - x[..., 0, 0] = 1 - x[..., 1:, 0] = z[..., : (D - 1)] - i = D - 1 - last_squared_x = torch.zeros(z.shape[:-1] + (D,), dtype=z.dtype, device=z.device) - for j in range(1, D): - distance_to_copy = D - 1 - j - last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone() ** 2 - x[..., j, j] = (1 - last_squared_x[..., 0]).sqrt() - x[..., (j + 1) :, j] = ( - z[..., i : (i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt() - ) - i += distance_to_copy - return x - - -class CorrLCholeskyTransform(Transform): - """ - Transforms a vector into the cholesky factor of a correlation matrix. - - The input should have shape `[batch_shape] + [d * (d-1)/2]`. The output will - have shape `[batch_shape] + [d, d]`. - - References: - - [1] Cholesky Factors of Correlation Matrices. Stan Reference Manual v2.18, - Section 10.12. - - """ - - domain = constraints.real_vector - codomain = constraints.corr_cholesky - bijective = True - - def __eq__(self, other): - return isinstance(other, CorrLCholeskyTransform) - - def _call(self, x): - z = x.tanh() - return _vector_to_l_cholesky(z) - - def _inverse(self, y): - if y.shape[-2] != y.shape[-1]: - raise ValueError( - "A matrix that isn't square can't be a Cholesky factor of a correlation matrix" - ) - D = y.shape[-1] - - z_tri = torch.zeros( - y.shape[:-2] + (D - 2, D - 2), dtype=y.dtype, device=y.device - ) - z_stack = [y[..., 1:, 0]] - - for i in range(2, D): - z_tri[..., i - 2, 0 : (i - 1)] = ( - y[..., i, 1:i] / (1 - y[..., i, 0 : (i - 1)].pow(2).cumsum(-1)).sqrt() - ) - for j in range(D - 2): - z_stack.append(z_tri[..., j:, j]) - - z = torch.cat(z_stack, -1) - return torch.log1p((2 * z) / (1 - z)) / 2 - - def log_abs_det_jacobian(self, x, y): - # Note dependence on pytorch 1.0.1 for batched tril - tanpart = x.cosh().log().sum(-1).mul(-2) - matpart = ( - (1 - y.pow(2).cumsum(-1).tril(diagonal=-2)).log().div(2).sum(-1).sum(-1) +class CorrLCholeskyTransform(CorrCholeskyTransform): # DEPRECATED + def __init__(self): + warnings.warn( + "class CorrLCholeskyTransform is deprecated in favor of CorrCholeskyTransform.", + FutureWarning, ) - return tanpart + matpart + super().__init__() class CholeskyTransform(Transform): diff --git a/tests/distributions/test_lkj.py b/tests/distributions/test_lkj.py index 92216fb47c..ffbbe02781 100644 --- a/tests/distributions/test_lkj.py +++ b/tests/distributions/test_lkj.py @@ -39,7 +39,7 @@ def _autograd_log_det(ys, x): @pytest.mark.parametrize("y_shape", [(1,), (3, 1), (6,), (1, 6), (2, 6)]) def test_unconstrained_to_corr_cholesky_transform(y_shape): - transform = transforms.CorrLCholeskyTransform() + transform = transforms.CorrCholeskyTransform() y = torch.empty(y_shape).uniform_(-4, 4).requires_grad_() x = transform(y) diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index 30803d1472..a339f7ea79 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -418,7 +418,7 @@ def test_cholesky_transform(batch_shape, dim, transform): tril_mask = arange < arange.view(-1, 1) else: tril_mask = arange < arange.view(-1, 1) + 1 - x = transform.inv(T.CorrLCholeskyTransform()(z)) # creates corr_matrix + x = transform.inv(T.CorrCholeskyTransform()(z)) # creates corr_matrix def vec_to_mat(x_vec): x_mat = x_vec.new_zeros(batch_shape + (dim, dim)) From 4153551c84a46fd8a30f40e75714aa849b5ca075 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 22 Apr 2023 19:11:00 +0000 Subject: [PATCH 2/6] cache_size --- pyro/distributions/transforms/cholesky.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index d3532baa69..54bf2d8874 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -10,12 +10,12 @@ class CorrLCholeskyTransform(CorrCholeskyTransform): # DEPRECATED - def __init__(self): + def __init__(self, cache_size=0): warnings.warn( "class CorrLCholeskyTransform is deprecated in favor of CorrCholeskyTransform.", FutureWarning, ) - super().__init__() + super().__init__(cache_size=cache_size) class CholeskyTransform(Transform): From 0742bb4fe6375b1668df0075f05fc9e1b6ffecf2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 22 Apr 2023 19:15:00 +0000 Subject: [PATCH 3/6] import warnings --- pyro/distributions/transforms/cholesky.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index 54bf2d8874..39321d1145 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import warnings import torch from torch.distributions.transforms import CorrCholeskyTransform, Transform From a4603d7316d73f960eb83d57ed3bf6c1b7e22556 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 22 Apr 2023 19:17:39 +0000 Subject: [PATCH 4/6] lint --- pyro/distributions/transforms/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index ba2b50bafc..8453d3e460 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -161,7 +161,6 @@ def iterated(repeats, base_fn, *args, **kwargs): "ConditionalRadial", "ConditionalSpline", "ConditionalSplineAutoregressive", - "CorrLCholeskyTransform", "CorrMatrixCholeskyTransform", "DiscreteCosineTransform", "ELUTransform", From 48cac2114c25d1a7c6dcb28b73d527691a0db372 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 22 Apr 2023 19:22:56 +0000 Subject: [PATCH 5/6] lint --- pyro/distributions/transforms/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 8453d3e460..884b4f7b23 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -30,7 +30,11 @@ from .basic import ELUTransform, LeakyReLUTransform, elu, leaky_relu from .batchnorm import BatchNorm, batchnorm from .block_autoregressive import BlockAutoregressive, block_autoregressive -from .cholesky import CholeskyTransform, CorrMatrixCholeskyTransform +from .cholesky import ( + CholeskyTransform, + CorrCholeskyTransform, + CorrMatrixCholeskyTransform, +) from .discrete_cosine import DiscreteCosineTransform from .generalized_channel_permute import ( ConditionalGeneralizedChannelPermute, From 68b6cad05e91bc4edc2323d3730122ba01142ff0 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 24 Apr 2023 02:02:43 +0000 Subject: [PATCH 6/6] address comments --- pyro/distributions/transforms/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 884b4f7b23..a5847ad31d 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -33,6 +33,7 @@ from .cholesky import ( CholeskyTransform, CorrCholeskyTransform, + CorrLCholeskyTransform, CorrMatrixCholeskyTransform, ) from .discrete_cosine import DiscreteCosineTransform @@ -165,6 +166,7 @@ def iterated(repeats, base_fn, *args, **kwargs): "ConditionalRadial", "ConditionalSpline", "ConditionalSplineAutoregressive", + "CorrLCholeskyTransform", "CorrMatrixCholeskyTransform", "DiscreteCosineTransform", "ELUTransform",