Port implementation of SimplexToOrderedTransform from numpyro (#3320)
peblair committed Feb 7, 2024
1 parent 6d2a56f commit a52338c
Showing 4 changed files with 126 additions and 9 deletions.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -536,6 +536,13 @@ PositivePowerTransform
    :undoc-members:
    :show-inheritance:

SimplexToOrderedTransform
-------------------------
.. autoclass:: pyro.distributions.transforms.SimplexToOrderedTransform
    :members:
    :undoc-members:
    :show-inheritance:

SoftplusLowerCholeskyTransform
------------------------------
.. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform
2 changes: 2 additions & 0 deletions pyro/distributions/transforms/__init__.py
@@ -70,6 +70,7 @@
from .polynomial import Polynomial, polynomial
from .power import PositivePowerTransform
from .radial import ConditionalRadial, Radial, conditional_radial, radial
from .simplex_to_ordered import SimplexToOrderedTransform
from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
from .spline import ConditionalSpline, Spline, conditional_spline, spline
from .spline_autoregressive import (
@@ -184,6 +185,7 @@ def iterated(repeats, base_fn, *args, **kwargs):
"Polynomial",
"PositivePowerTransform",
"Radial",
"SimplexToOrderedTransform",
"SoftplusLowerCholeskyTransform",
"SoftplusTransform",
"Spline",
70 changes: 70 additions & 0 deletions pyro/distributions/transforms/simplex_to_ordered.py
@@ -0,0 +1,70 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions.transforms import Transform
from torch.special import expit, logit

from .. import constraints


# This class is a port of https://num.pyro.ai/en/stable/_modules/numpyro/distributions/transforms.html#SimplexToOrderedTransform
class SimplexToOrderedTransform(Transform):
    """
    Transform a simplex into an ordered vector (via differences of the logistic CDF
    between cutpoints).
    Used in [1] to induce a prior on latent cutpoints by transforming ordered category
    probabilities.

    :param anchor_point: The anchor point is a nuisance parameter that improves the
        identifiability of the transform. For simplicity, we assume it is a scalar
        value, but it is broadcastable over ``x.shape[:-1]``. For more details please
        refer to Section 2.2 in [1].

    **References:**

    1. *Ordinal Regression Case Study, Section 2.2*,
       M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
    """

    domain = constraints.simplex
    codomain = constraints.ordered_vector

    def __init__(self, anchor_point=None):
        super().__init__()
        self.anchor_point = (
            anchor_point if anchor_point is not None else torch.tensor(0.0)
        )

    def _call(self, x):
        # Partial sums s_k = x_0 + ... + x_k are the logistic CDF values at the
        # cutpoints; logit maps them to a strictly increasing sequence of reals.
        s = torch.cumsum(x[..., :-1], dim=-1)
        y = logit(s) + torch.unsqueeze(self.anchor_point, -1)
        return y

    def _inverse(self, y):
        y = y - torch.unsqueeze(self.anchor_point, -1)
        s = expit(y)
        # x0 = s0, x1 = s1 - s0, x2 = s2 - s1, ..., xn = 1 - s[n-1];
        # add the two boundary points 0 and 1 before differencing.
        s = torch.cat(
            [torch.zeros_like(s)[..., :1], s, torch.ones_like(s)[..., :1]], dim=-1
        )
        x = s[..., 1:] - s[..., :-1]
        return x

    def log_abs_det_jacobian(self, x, y):
        # For the inverse map, |dx/dy| = prod(ds/dy) = prod(expit'(y)), and
        # log expit'(y) = -softplus(y) - softplus(-y); the forward log-det is
        # its negation, log|dy/dx| = sum(softplus(y) + softplus(-y)).
        J_logdet = (
            torch.nn.functional.softplus(y) + torch.nn.functional.softplus(-y)
        ).sum(-1)
        return J_logdet

    def __eq__(self, other):
        if not isinstance(other, SimplexToOrderedTransform):
            return False
        # torch.equal already returns a plain bool, so no torch.all reduction is needed
        return torch.equal(self.anchor_point, other.anchor_point)

    def forward_shape(self, shape):
        return shape[:-1] + (shape[-1] - 1,)

    def inverse_shape(self, shape):
        return shape[:-1] + (shape[-1] + 1,)
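
As a quick sanity check of the port, a minimal sketch (not part of the commit; it assumes only the module above, and the tolerances are illustrative):

import torch

import pyro.distributions as dist
from pyro.distributions.transforms import SimplexToOrderedTransform

torch.manual_seed(0)
t = SimplexToOrderedTransform()

# Forward: a point on the 5-simplex maps to 4 strictly increasing cutpoints.
probs = dist.Dirichlet(torch.ones(5)).sample()
cutpoints = t(probs)
assert cutpoints.shape == (4,)
assert torch.all(cutpoints[1:] > cutpoints[:-1])

# Inverse: round-trips back to the simplex point up to floating-point error.
assert torch.allclose(t.inv(cutpoints), probs, atol=1e-5)

# Jacobian: the analytic log-det matches autograd over the K - 1 free coordinates
# (the forward map ignores the last, redundant simplex coordinate).
def forward_from_free(x_free):
    return t(torch.cat([x_free, (1 - x_free.sum()).unsqueeze(-1)]))

jac = torch.autograd.functional.jacobian(forward_from_free, probs[:-1])
assert torch.allclose(
    t.log_abs_det_jacobian(probs, cutpoints),
    torch.slogdet(jac).logabsdet,
    atol=1e-4,
)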
56 changes: 47 additions & 9 deletions tests/distributions/test_transforms.py
@@ -112,10 +112,15 @@ def nonzero(x):
        assert diag_sum == float(input_dim)
        assert lower_sum == float(0.0)

-    def _test_inverse(self, shape, transform):
+    def _test_inverse(self, shape, transform, base_dist_type="normal"):
        # Test g^{-1}(g(x)) = x
        # NOTE: Calling _call and _inverse directly bypasses caching
-        base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape))
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(shape) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
        x_true = base_dist.sample(torch.Size([10]))
        y = transform._call(x_true)
        x_calculated = transform._inverse(y)
@@ -132,8 +137,13 @@ def _test_inverse(self, shape, transform):
        J_2 = transform.log_abs_det_jacobian(x_calculated, y)
        assert (J_1 - J_2).abs().max().item() < self.delta

-    def _test_shape(self, base_shape, transform):
-        base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
+    def _test_shape(self, base_shape, transform, base_dist_type="normal"):
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(base_shape) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
        sample = dist.TransformedDistribution(base_dist, [transform]).sample()
        assert sample.shape == base_shape

@@ -148,7 +158,9 @@ def _test_shape(self, base_shape, transform):
        assert transform.inverse_shape(output_event_shape) == input_event_shape
        assert transform.inverse_shape(output_shape) == base_shape

-    def _test_autodiff(self, input_dim, transform, inverse=False):
+    def _test_autodiff(
+        self, input_dim, transform, inverse=False, base_dist_type="normal"
+    ):
        """
        This method essentially tests whether autodiff will not throw any errors
        when you're doing maximum-likelihood learning with the transform. Many
@@ -159,7 +171,12 @@ def _test_autodiff(self, input_dim, transform, inverse=False):
        if inverse:
            transform = transform.inv

-        base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
+        if base_dist_type == "dirichlet":
+            base_dist = dist.Dirichlet(torch.ones(input_dim) * 10.0)
+        elif base_dist_type == "normal":
+            base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
+        else:
+            raise ValueError(f"Unknown base distribution type: {base_dist_type}")
        flow_dist = dist.TransformedDistribution(base_dist, [transform])
        optimizer = torch.optim.Adam(temp_transform.parameters())
        x = torch.rand(1, input_dim)
@@ -177,6 +194,7 @@ def _test(
        inverse=True,
        autodiff=True,
        event_dim=1,
+        base_dist_type="normal",
    ):
        for event_shape in [(2,), (5,)]:
            if event_dim > 1:
@@ -186,19 +204,26 @@
                )

            if inverse:
-                self._test_inverse(event_shape, transform)
+                self._test_inverse(
+                    event_shape, transform, base_dist_type=base_dist_type
+                )
            if shape:
                for shape in [(3,), (3, 4), (3, 4, 5)]:
                    base_shape = shape + event_shape
-                    self._test_shape(base_shape, transform)
+                    self._test_shape(
+                        base_shape, transform, base_dist_type=base_dist_type
+                    )
            if jacobian and transform.bijective:
                if event_dim > 1:
                    transform = Flatten(transform, event_shape)
                self._test_jacobian(reduce(operator.mul, event_shape, 1), transform)
            if isinstance(transform, dist.TransformModule) and autodiff:
                # If the function doesn't have an explicit inverse, then use the forward op for autodiff
                self._test_autodiff(
-                    reduce(operator.mul, event_shape, 1), transform, inverse=not inverse
+                    reduce(operator.mul, event_shape, 1),
+                    transform,
+                    inverse=not inverse,
+                    base_dist_type=base_dist_type,
                )

    def _test_conditional(
@@ -376,6 +401,19 @@ def test_polynomial(self):
    def test_radial(self):
        self._test(T.radial, inverse=False)

+    def test_simplex_to_ordered(self):
+        self._test(
+            lambda event_shape: T.SimplexToOrderedTransform(),
+            shape=False,
+            autodiff=False,
+            base_dist_type="dirichlet",
+        )
+        # Unusual shape behavior: forward drops one element, inverse adds one.
+        assert T.SimplexToOrderedTransform().forward_shape((4, 3, 3)) == (4, 3, 2)
+        assert T.SimplexToOrderedTransform().forward_shape((2,)) == (1,)
+        assert T.SimplexToOrderedTransform().inverse_shape((2,)) == (3,)
+        assert T.SimplexToOrderedTransform().inverse_shape((4, 3, 3)) == (4, 3, 4)

    def test_spline(self):
        for order in ["linear", "quadratic"]:
            self._test(partial(T.spline, order=order))
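The Dirichlet base distribution in these tests mirrors the transform's intended use: pushing a Dirichlet prior on category probabilities forward into a prior over ordered latent cutpoints, as in Betancourt's case study. A minimal sketch of that use (not part of the commit; names and parameters are illustrative):

import torch

import pyro.distributions as dist
from pyro.distributions.transforms import SimplexToOrderedTransform

# A Dirichlet prior over K category probabilities induces a prior over
# K - 1 ordered cutpoints; TransformedDistribution handles the shape change
# via the transform's forward_shape.
K = 4
base = dist.Dirichlet(torch.ones(K))
cutpoint_prior = dist.TransformedDistribution(base, [SimplexToOrderedTransform()])

sample = cutpoint_prior.sample()
assert sample.shape == (K - 1,)
assert torch.all(sample[1:] > sample[:-1])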
