Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate CorrLCholeskyTransform in favor of upstream CorrCholeskyTransform #3199

Merged
merged 6 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,6 @@ CholeskyTransform
:undoc-members:
:show-inheritance:

CorrLCholeskyTransform
----------------------
.. autoclass:: pyro.distributions.transforms.CorrLCholeskyTransform
:members:
:undoc-members:
:show-inheritance:

CorrMatrixCholeskyTransform
---------------------------
.. autoclass:: pyro.distributions.transforms.CorrMatrixCholeskyTransform
Expand Down
9 changes: 2 additions & 7 deletions pyro/distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .block_autoregressive import BlockAutoregressive, block_autoregressive
from .cholesky import (
CholeskyTransform,
CorrCholeskyTransform,
CorrLCholeskyTransform,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd lean towards preserving the old import so as to avoid breaking existing code, but I'll defer to your judgement if you think we should delete both. (I think of "deprecation" as meaning "it still works but we no longer recommend it", whereas removing from __init__.py would be a truly breaking change.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think it makes sense to leave the old import.

CorrMatrixCholeskyTransform,
)
Expand Down Expand Up @@ -90,17 +91,11 @@ def _transform_to_sphere(constraint):
return Normalize()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
return CorrLCholeskyTransform()


@biject_to.register(constraints.corr_matrix)
@transform_to.register(constraints.corr_matrix)
def _transform_to_corr_matrix(constraint):
return ComposeTransform(
[CorrLCholeskyTransform(), CorrMatrixCholeskyTransform().inv]
[CorrCholeskyTransform(), CorrMatrixCholeskyTransform().inv]
)


Expand Down
84 changes: 8 additions & 76 deletions pyro/distributions/transforms/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,21 @@
# SPDX-License-Identifier: Apache-2.0

import math
import warnings

import torch
from torch.distributions.transforms import Transform
from torch.distributions.transforms import CorrCholeskyTransform, Transform

from .. import constraints


def _vector_to_l_cholesky(z):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woo hoo less math to maintain 📉

D = (1.0 + math.sqrt(1.0 + 8.0 * z.shape[-1])) / 2.0
if D % 1 != 0:
raise ValueError("Correlation matrix transformation requires d choose 2 inputs")
D = int(D)
x = torch.zeros(z.shape[:-1] + (D, D), dtype=z.dtype, device=z.device)

x[..., 0, 0] = 1
x[..., 1:, 0] = z[..., : (D - 1)]
i = D - 1
last_squared_x = torch.zeros(z.shape[:-1] + (D,), dtype=z.dtype, device=z.device)
for j in range(1, D):
distance_to_copy = D - 1 - j
last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone() ** 2
x[..., j, j] = (1 - last_squared_x[..., 0]).sqrt()
x[..., (j + 1) :, j] = (
z[..., i : (i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt()
)
i += distance_to_copy
return x


class CorrLCholeskyTransform(Transform):
"""
Transforms a vector into the cholesky factor of a correlation matrix.

The input should have shape `[batch_shape] + [d * (d-1)/2]`. The output will
have shape `[batch_shape] + [d, d]`.

References:

[1] Cholesky Factors of Correlation Matrices. Stan Reference Manual v2.18,
Section 10.12.

"""

domain = constraints.real_vector
codomain = constraints.corr_cholesky
bijective = True

def __eq__(self, other):
return isinstance(other, CorrLCholeskyTransform)

def _call(self, x):
z = x.tanh()
return _vector_to_l_cholesky(z)

def _inverse(self, y):
if y.shape[-2] != y.shape[-1]:
raise ValueError(
"A matrix that isn't square can't be a Cholesky factor of a correlation matrix"
)
D = y.shape[-1]

z_tri = torch.zeros(
y.shape[:-2] + (D - 2, D - 2), dtype=y.dtype, device=y.device
)
z_stack = [y[..., 1:, 0]]

for i in range(2, D):
z_tri[..., i - 2, 0 : (i - 1)] = (
y[..., i, 1:i] / (1 - y[..., i, 0 : (i - 1)].pow(2).cumsum(-1)).sqrt()
)
for j in range(D - 2):
z_stack.append(z_tri[..., j:, j])

z = torch.cat(z_stack, -1)
return torch.log1p((2 * z) / (1 - z)) / 2

def log_abs_det_jacobian(self, x, y):
# Note dependence on pytorch 1.0.1 for batched tril
tanpart = x.cosh().log().sum(-1).mul(-2)
matpart = (
(1 - y.pow(2).cumsum(-1).tril(diagonal=-2)).log().div(2).sum(-1).sum(-1)
class CorrLCholeskyTransform(CorrCholeskyTransform): # DEPRECATED
def __init__(self, cache_size=0):
warnings.warn(
"class CorrLCholeskyTransform is deprecated in favor of CorrCholeskyTransform.",
FutureWarning,
)
return tanpart + matpart
super().__init__(cache_size=cache_size)


class CholeskyTransform(Transform):
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_lkj.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _autograd_log_det(ys, x):

@pytest.mark.parametrize("y_shape", [(1,), (3, 1), (6,), (1, 6), (2, 6)])
def test_unconstrained_to_corr_cholesky_transform(y_shape):
transform = transforms.CorrLCholeskyTransform()
transform = transforms.CorrCholeskyTransform()
y = torch.empty(y_shape).uniform_(-4, 4).requires_grad_()
x = transform(y)

Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def test_cholesky_transform(batch_shape, dim, transform):
tril_mask = arange < arange.view(-1, 1)
else:
tril_mask = arange < arange.view(-1, 1) + 1
x = transform.inv(T.CorrLCholeskyTransform()(z)) # creates corr_matrix
x = transform.inv(T.CorrCholeskyTransform()(z)) # creates corr_matrix

def vec_to_mat(x_vec):
x_mat = x_vec.new_zeros(batch_shape + (dim, dim))
Expand Down