Port implementation of SimplexToOrderedTransform from numpyro #3320

Merged · 2 commits · Feb 7, 2024
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -536,6 +536,13 @@ PositivePowerTransform
:undoc-members:
:show-inheritance:

+SimplexToOrderedTransform
+-------------------------
+.. autoclass:: pyro.distributions.transforms.SimplexToOrderedTransform
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
SoftplusLowerCholeskyTransform
------------------------------
.. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform
2 changes: 2 additions & 0 deletions pyro/distributions/transforms/__init__.py
@@ -70,6 +70,7 @@
from .polynomial import Polynomial, polynomial
from .power import PositivePowerTransform
from .radial import ConditionalRadial, Radial, conditional_radial, radial
+from .simplex_to_ordered import SimplexToOrderedTransform
from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
from .spline import ConditionalSpline, Spline, conditional_spline, spline
from .spline_autoregressive import (
@@ -184,6 +185,7 @@ def iterated(repeats, base_fn, *args, **kwargs):
"Polynomial",
"PositivePowerTransform",
"Radial",
"SimplexToOrderedTransform",
"SoftplusLowerCholeskyTransform",
"SoftplusTransform",
"Spline",
70 changes: 70 additions & 0 deletions pyro/distributions/transforms/simplex_to_ordered.py
@@ -0,0 +1,70 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions.transforms import Transform
from torch.special import expit, logit

from .. import constraints


# This class is a port of https://num.pyro.ai/en/stable/_modules/numpyro/distributions/transforms.html#SimplexToOrderedTransform
class SimplexToOrderedTransform(Transform):
"""
    Transform a simplex into an ordered vector, via differences of the logistic CDF
    evaluated at consecutive cutpoints.
    Used in [1] to induce a prior on latent cutpoints by transforming ordered category
    probabilities.

    :param anchor_point: a nuisance parameter that improves the identifiability of
        the transform. For simplicity, we assume it is a scalar value, but it is
        broadcastable to ``x.shape[:-1]``. For more details please refer to
        Section 2.2 in [1].

**References:**

1. *Ordinal Regression Case Study, section 2.2*,
M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

"""

domain = constraints.simplex
codomain = constraints.ordered_vector

def __init__(self, anchor_point=None):
super().__init__()
self.anchor_point = (
anchor_point if anchor_point is not None else torch.tensor(0.0)
)

def _call(self, x):
        s = torch.cumsum(x[..., :-1], dim=-1)
y = logit(s) + torch.unsqueeze(self.anchor_point, -1)
return y

def _inverse(self, y):
y = y - torch.unsqueeze(self.anchor_point, -1)
s = expit(y)
# x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1]
# add two boundary points 0 and 1
s = torch.concat(
[torch.zeros_like(s)[..., :1], s, torch.ones_like(s)[..., :1]], dim=-1
)
x = s[..., 1:] - s[..., :-1]
return x

    def log_abs_det_jacobian(self, x, y):
        # In the notation of [1], |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(z)),
        # where z = y - anchor_point recovers s = expit(z), and the cumsum step has
        # unit Jacobian determinant. Since log expit'(z) = -softplus(z) - softplus(-z),
        # the forward log-determinant is log|dy/dx| = sum(softplus(z) + softplus(-z)).
        z = y - torch.unsqueeze(self.anchor_point, -1)
        J_logdet = (
            torch.nn.functional.softplus(z) + torch.nn.functional.softplus(-z)
        ).sum(-1)
        return J_logdet

def __eq__(self, other):
if not isinstance(other, SimplexToOrderedTransform):
return False
        return torch.equal(self.anchor_point, other.anchor_point)

def forward_shape(self, shape):
return shape[:-1] + (shape[-1] - 1,)

def inverse_shape(self, shape):
return shape[:-1] + (shape[-1] + 1,)
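
A quick sanity sketch of what the new file implements (not part of the PR; it assumes a pyro build that includes this change, and names like t are illustrative). It checks the round trip g^-1(g(x)) = x, the K -> K - 1 shape contraction, and that log_abs_det_jacobian agrees with an autograd Jacobian taken over the K - 1 free simplex coordinates:

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import SimplexToOrderedTransform

t = SimplexToOrderedTransform()
x = dist.Dirichlet(torch.ones(5)).sample()  # a point on the 4-simplex, shape (5,)
y = t(x)                                    # 4 ordered cutpoints, shape (4,)
assert y.shape == (4,) and (y[1:] > y[:-1]).all()
assert torch.allclose(t.inv(y), x, atol=1e-5)

# Only the first K - 1 simplex coordinates are free (the last is determined by
# the others), so the forward map has a square Jacobian over x[:-1].
jac = torch.autograd.functional.jacobian(
    lambda free: t(torch.cat([free, (1 - free.sum()).unsqueeze(0)])), x[:-1]
)
assert torch.allclose(t.log_abs_det_jacobian(x, y), torch.logdet(jac), atol=1e-5)
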
56 changes: 47 additions & 9 deletions tests/distributions/test_transforms.py
@@ -112,10 +112,15 @@ def nonzero(x):
assert diag_sum == float(input_dim)
assert lower_sum == float(0.0)

-    def _test_inverse(self, shape, transform):
+    def _test_inverse(self, shape, transform, base_dist_type="normal"):
# Test g^{-1}(g(x)) = x
# NOTE: Calling _call and _inverse directly bypasses caching
-        base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape))
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(shape) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
x_true = base_dist.sample(torch.Size([10]))
y = transform._call(x_true)
x_calculated = transform._inverse(y)
@@ -132,8 +137,13 @@ def _test_inverse(self, shape, transform):
J_2 = transform.log_abs_det_jacobian(x_calculated, y)
assert (J_1 - J_2).abs().max().item() < self.delta

-    def _test_shape(self, base_shape, transform):
-        base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
+    def _test_shape(self, base_shape, transform, base_dist_type="normal"):
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(base_shape) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
sample = dist.TransformedDistribution(base_dist, [transform]).sample()
assert sample.shape == base_shape

@@ -148,7 +158,9 @@ def _test_shape(self, base_shape, transform):
assert transform.inverse_shape(output_event_shape) == input_event_shape
assert transform.inverse_shape(output_shape) == base_shape

-    def _test_autodiff(self, input_dim, transform, inverse=False):
+    def _test_autodiff(
+        self, input_dim, transform, inverse=False, base_dist_type="normal"
+    ):
"""
This method essentially tests whether autodiff will not throw any errors
when you're doing maximum-likelihood learning with the transform. Many
@@ -159,7 +171,12 @@ def _test_autodiff(self, input_dim, transform, inverse=False):
if inverse:
transform = transform.inv

-        base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(input_dim) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
flow_dist = dist.TransformedDistribution(base_dist, [transform])
optimizer = torch.optim.Adam(temp_transform.parameters())
x = torch.rand(1, input_dim)
@@ -177,6 +194,7 @@ def _test(
inverse=True,
autodiff=True,
event_dim=1,
base_dist_type="normal",
):
for event_shape in [(2,), (5,)]:
if event_dim > 1:
@@ -186,19 +204,26 @@
)

if inverse:
-            self._test_inverse(event_shape, transform)
+            self._test_inverse(
+                event_shape, transform, base_dist_type=base_dist_type
+            )
if shape:
for shape in [(3,), (3, 4), (3, 4, 5)]:
base_shape = shape + event_shape
-                self._test_shape(base_shape, transform)
+                self._test_shape(
+                    base_shape, transform, base_dist_type=base_dist_type
+                )
if jacobian and transform.bijective:
if event_dim > 1:
transform = Flatten(transform, event_shape)
self._test_jacobian(reduce(operator.mul, event_shape, 1), transform)
if isinstance(transform, dist.TransformModule) and autodiff:
# If the function doesn't have an explicit inverse, then use the forward op for autodiff
self._test_autodiff(
-                reduce(operator.mul, event_shape, 1), transform, inverse=not inverse
+                reduce(operator.mul, event_shape, 1),
+                transform,
+                inverse=not inverse,
+                base_dist_type=base_dist_type,
)

def _test_conditional(
@@ -376,6 +401,19 @@ def test_polynomial(self):
def test_radial(self):
self._test(T.radial, inverse=False)

+    def test_simplex_to_ordered(self):
+        self._test(
+            lambda event_shape: T.SimplexToOrderedTransform(),
+            shape=False,
+            autodiff=False,
+            base_dist_type="dirichlet",
+        )
+        # Unlike most transforms, this one changes the event shape: a simplex
+        # over K categories maps to K - 1 ordered cutpoints, and vice versa.
+        assert T.SimplexToOrderedTransform().forward_shape((4, 3, 3)) == (4, 3, 2)
+        assert T.SimplexToOrderedTransform().forward_shape((2,)) == (1,)
+        assert T.SimplexToOrderedTransform().inverse_shape((2,)) == (3,)
+        assert T.SimplexToOrderedTransform().inverse_shape((4, 3, 3)) == (4, 3, 4)

def test_spline(self):
for order in ["linear", "quadratic"]:
self._test(partial(T.spline, order=order))
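
To connect the change back to the motivating use case cited in the docstring (Betancourt's ordinal regression case study), the transform can push a Dirichlet prior on ordered category probabilities forward onto latent cutpoints. A minimal sketch, assuming a pyro build that includes this PR; the concentration and anchor values are illustrative:

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import SimplexToOrderedTransform

K = 4  # number of ordinal categories
prior = dist.TransformedDistribution(
    dist.Dirichlet(torch.ones(K)),
    [SimplexToOrderedTransform(anchor_point=torch.tensor(0.0))],
)
cutpoints = prior.sample()  # K - 1 = 3 ordered cutpoints
assert cutpoints.shape == (K - 1,) and (cutpoints[1:] > cutpoints[:-1]).all()
# log_prob is well-defined because the transform supplies log_abs_det_jacobian.
print(prior.log_prob(cutpoints))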