diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index ccb51b088d..aee80f6cc4 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -536,6 +536,13 @@ PositivePowerTransform :undoc-members: :show-inheritance: +SimplexToOrderedTransform +------------------------- +.. autoclass:: pyro.distributions.transforms.SimplexToOrderedTransform + :members: + :undoc-members: + :show-inheritance: + SoftplusLowerCholeskyTransform ------------------------------ .. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index a5847ad31d..d2a2382974 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -70,6 +70,7 @@ from .polynomial import Polynomial, polynomial from .power import PositivePowerTransform from .radial import ConditionalRadial, Radial, conditional_radial, radial +from .simplex_to_ordered import SimplexToOrderedTransform from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import ( @@ -184,6 +185,7 @@ def iterated(repeats, base_fn, *args, **kwargs): "Polynomial", "PositivePowerTransform", "Radial", + "SimplexToOrderedTransform", "SoftplusLowerCholeskyTransform", "SoftplusTransform", "Spline", diff --git a/pyro/distributions/transforms/simplex_to_ordered.py b/pyro/distributions/transforms/simplex_to_ordered.py new file mode 100644 index 0000000000..26791f7ff9 --- /dev/null +++ b/pyro/distributions/transforms/simplex_to_ordered.py @@ -0,0 +1,70 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.distributions.transforms import Transform +from torch.special import expit, logit + +from .. import constraints + + +# This class is a port of https://num.pyro.ai/en/stable/_modules/numpyro/distributions/transforms.html#SimplexToOrderedTransform +class SimplexToOrderedTransform(Transform): + """ + Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints) + Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities. + + :param anchor_point: Anchor point is a nuisance parameter to improve the identifiability of the transform. + For simplicity, we assume it is a scalar value, but it is broadcastable x.shape[:-1]. + For more details please refer to Section 2.2 in [1] + + **References:** + + 1. *Ordinal Regression Case Study, section 2.2*, + M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html + + """ + + domain = constraints.simplex + codomain = constraints.ordered_vector + + def __init__(self, anchor_point=None): + super().__init__() + self.anchor_point = ( + anchor_point if anchor_point is not None else torch.tensor(0.0) + ) + + def _call(self, x): + s = torch.cumsum(x[..., :-1], axis=-1) + y = logit(s) + torch.unsqueeze(self.anchor_point, -1) + return y + + def _inverse(self, y): + y = y - torch.unsqueeze(self.anchor_point, -1) + s = expit(y) + # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1] + # add two boundary points 0 and 1 + s = torch.concat( + [torch.zeros_like(s)[..., :1], s, torch.ones_like(s)[..., :1]], dim=-1 + ) + x = s[..., 1:] - s[..., :-1] + return x + + def log_abs_det_jacobian(self, x, y): + # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y)) + # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)` + J_logdet = ( + torch.nn.functional.softplus(y) + torch.nn.functional.softplus(-y) + ).sum(-1) + return J_logdet + + def __eq__(self, other): + if not isinstance(other, SimplexToOrderedTransform): + return False + return torch.all(torch.equal(self.anchor_point, other.anchor_point)) + + def forward_shape(self, shape): + return shape[:-1] + (shape[-1] - 1,) + + def inverse_shape(self, shape): + return shape[:-1] + (shape[-1] + 1,) diff --git a/tests/distributions/test_transforms.py b/tests/distributions/test_transforms.py index a339f7ea79..67f4c5167b 100644 --- a/tests/distributions/test_transforms.py +++ b/tests/distributions/test_transforms.py @@ -112,10 +112,15 @@ def nonzero(x): assert diag_sum == float(input_dim) assert lower_sum == float(0.0) - def _test_inverse(self, shape, transform): + def _test_inverse(self, shape, transform, base_dist_type="normal"): # Test g^{-1}(g(x)) = x # NOTE: Calling _call and _inverse directly bypasses caching - base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape)) + if base_dist_type == "dirichlet": + base_dist = dist.Dirichlet(torch.ones(shape) * 10.0) + elif base_dist_type == "normal": + base_dist = dist.Normal(torch.zeros(shape), torch.ones(shape)) + else: + raise ValueError(f"Unknown base distribution type: {base_dist_type}") x_true = base_dist.sample(torch.Size([10])) y = transform._call(x_true) x_calculated = transform._inverse(y) @@ -132,8 +137,13 @@ def _test_inverse(self, shape, transform): J_2 = transform.log_abs_det_jacobian(x_calculated, y) assert (J_1 - J_2).abs().max().item() < self.delta - def _test_shape(self, base_shape, transform): - base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) + def _test_shape(self, base_shape, transform, base_dist_type="normal"): + if base_dist_type == "dirichlet": + base_dist = dist.Dirichlet(torch.ones(base_shape) * 10.0) + elif base_dist_type == "normal": + base_dist = dist.Normal(torch.zeros(base_shape), torch.ones(base_shape)) + else: + raise ValueError(f"Unknown base distribution type: {base_dist_type}") sample = dist.TransformedDistribution(base_dist, [transform]).sample() assert sample.shape == base_shape @@ -148,7 +158,9 @@ def _test_shape(self, base_shape, transform): assert transform.inverse_shape(output_event_shape) == input_event_shape assert transform.inverse_shape(output_shape) == base_shape - def _test_autodiff(self, input_dim, transform, inverse=False): + def _test_autodiff( + self, input_dim, transform, inverse=False, base_dist_type="normal" + ): """ This method essentially tests whether autodiff will not throw any errors when you're doing maximum-likelihood learning with the transform. Many @@ -159,7 +171,12 @@ def _test_autodiff(self, input_dim, transform, inverse=False): if inverse: transform = transform.inv - base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim)) + if base_dist_type == "dirichlet": + base_dist = dist.Dirichlet(torch.ones(input_dim) * 10.0) + elif base_dist_type == "normal": + base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim)) + else: + raise ValueError(f"Unknown base distribution type: {base_dist_type}") flow_dist = dist.TransformedDistribution(base_dist, [transform]) optimizer = torch.optim.Adam(temp_transform.parameters()) x = torch.rand(1, input_dim) @@ -177,6 +194,7 @@ def _test( inverse=True, autodiff=True, event_dim=1, + base_dist_type="normal", ): for event_shape in [(2,), (5,)]: if event_dim > 1: @@ -186,11 +204,15 @@ def _test( ) if inverse: - self._test_inverse(event_shape, transform) + self._test_inverse( + event_shape, transform, base_dist_type=base_dist_type + ) if shape: for shape in [(3,), (3, 4), (3, 4, 5)]: base_shape = shape + event_shape - self._test_shape(base_shape, transform) + self._test_shape( + base_shape, transform, base_dist_type=base_dist_type + ) if jacobian and transform.bijective: if event_dim > 1: transform = Flatten(transform, event_shape) @@ -198,7 +220,10 @@ def _test( if isinstance(transform, dist.TransformModule) and autodiff: # If the function doesn't have an explicit inverse, then use the forward op for autodiff self._test_autodiff( - reduce(operator.mul, event_shape, 1), transform, inverse=not inverse + reduce(operator.mul, event_shape, 1), + transform, + inverse=not inverse, + base_dist_type=base_dist_type, ) def _test_conditional( @@ -376,6 +401,19 @@ def test_polynomial(self): def test_radial(self): self._test(T.radial, inverse=False) + def test_simplex_to_ordered(self): + self._test( + lambda event_shape: T.SimplexToOrderedTransform(), + shape=False, + autodiff=False, + base_dist_type="dirichlet", + ) + # Unique shape behavior: + assert T.SimplexToOrderedTransform().forward_shape((4, 3, 3)) == (4, 3, 2) + assert T.SimplexToOrderedTransform().forward_shape((2,)) == (1,) + assert T.SimplexToOrderedTransform().inverse_shape((2,)) == (3,) + assert T.SimplexToOrderedTransform().inverse_shape((4, 3, 3)) == (4, 3, 4) + def test_spline(self): for order in ["linear", "quadratic"]: self._test(partial(T.spline, order=order))