Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sine Skewed Toridial distribution #2826

Merged
merged 42 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
585beb9
Bump to version 1.5.2 (#2755)
fritzo Feb 1, 2021
260d05b
Merge branch 'dev'
fritzo Mar 4, 2021
09bcbc0
Added sine skewed distribution and tests.
OlaRonning Apr 27, 2021
b7ae4d1
Added repr.
OlaRonning Apr 27, 2021
85e352c
Fixed shape tests and minor fixes to docstring.
OlaRonning Apr 27, 2021
7ee6643
Fixed lint.
OlaRonning Apr 27, 2021
789f550
Updated docstring with uniform prior.
OlaRonning Apr 28, 2021
be91d9a
Fixed skewness shape assertion.
OlaRonning Apr 28, 2021
2a200d3
ensure `SineSkewed` is on the torus.
OlaRonning Apr 30, 2021
e7a1a74
Reverted `infer_shapes` in `sine_skewed` and `# isort: split` in `dis…
OlaRonning May 1, 2021
1b2e1ca
Merge branch 'feature/ss_dist' of github.com:aleatory-science/pyro in…
OlaRonning May 1, 2021
e802aa4
Sketched `SineSkewed.expand`
OlaRonning May 1, 2021
d1801b9
Fixed `SineSkewed.log_prob`.
OlaRonning May 2, 2021
3b44ebe
Added pep exception to `distributions.__init__`
OlaRonning May 2, 2021
84ac72e
Fixed `SineSkewed` on cuda.
OlaRonning May 3, 2021
c92ef62
Restricted `event_dim=2`
OlaRonning May 5, 2021
906211a
Fixed doc_string and updated tests.
OlaRonning May 5, 2021
e237531
fixed linting
OlaRonning May 5, 2021
5e4020a
fixed arg_constraints
OlaRonning May 5, 2021
bd93a2b
cleaned __repr__
OlaRonning May 5, 2021
4646ce6
Fixed comments.
OlaRonning May 5, 2021
ffe50e6
Fixed `n_dim=1` and updated `test_sine_skewed`; missing updated fixtu…
OlaRonning May 7, 2021
d935f74
Added fixture.
OlaRonning May 7, 2021
9ca8a45
Fixed tests.
OlaRonning May 9, 2021
e98bb2b
Merge branch 'feature/ss_fix_dim' into feature/ss_dist
OlaRonning May 9, 2021
51f6365
Merge branch 'dev' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning May 9, 2021
dd461fd
removed deprecated add_stylesheet
OlaRonning May 9, 2021
5cfee34
reverted to `add_stylesheet`
OlaRonning May 9, 2021
6d79eb3
Removed raise from sine_skewed.py
OlaRonning May 10, 2021
ff26ce9
Added equation references.
OlaRonning May 10, 2021
4427e37
Fixed sampling bound in `SineSkewed`.
OlaRonning May 10, 2021
a7bf5fe
Fixed prior on `SineSkewed` to avoid `AffineTransform`.
OlaRonning May 10, 2021
8d5d684
Merge remote-tracking branch 'origin/feature/ss_dist' into feature/ss…
OlaRonning May 10, 2021
c9ede43
Merged origin.
OlaRonning May 10, 2021
b1e1408
Merge branch 'master' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning Jun 2, 2021
fac5864
removed import all pyro distributions
OlaRonning Jun 2, 2021
21af6e5
Merged upstream and fixed docstring for `SineSkewed`.
OlaRonning Jun 2, 2021
9fc36ff
Fixed tests for SineSkewed with wrapper class.
OlaRonning Jun 2, 2021
9b78a7a
Removed unused import from conftest.py
OlaRonning Jun 2, 2021
0b5af73
Removed xfail int test_cuda for `SineSkewed`
OlaRonning Jun 2, 2021
4d44aa0
Fixed DocString example.
OlaRonning Jun 4, 2021
3143b45
Fixed psi_phi name in docstring.
OlaRonning Jun 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ Rejector
:undoc-members:
:show-inheritance:

SineSkewed
----------
.. autoclass:: pyro.distributions.SineSkewed
:members:
:undoc-members:
:show-inheritance:

SoftLaplace
-------------
.. autoclass:: pyro.distributions.SoftLaplace
Expand Down
3 changes: 3 additions & 0 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@
RelaxedBernoulliStraightThrough,
RelaxedOneHotCategoricalStraightThrough,
)
from pyro.distributions.sine_skewed import SineSkewed
from pyro.distributions.softlaplace import SoftLaplace
from pyro.distributions.spanning_tree import SpanningTree
from pyro.distributions.stable import Stable
from pyro.distributions.torch import * # noqa F403
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
from pyro.distributions.torch import __all__ as torch_dists
from pyro.distributions.torch_distribution import (
ExpandedDistribution,
Expand Down Expand Up @@ -128,6 +130,7 @@
"Rejector",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategoricalStraightThrough",
"SineSkewed",
"SoftLaplace",
"SpanningTree",
"Stable",
Expand Down
96 changes: 96 additions & 0 deletions pyro/distributions/sine_skewed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import warnings
from math import pi

import torch
from torch import broadcast_shapes
from torch.distributions import Uniform

from pyro.distributions import constraints

from .torch_distribution import TorchDistribution


class SineSkewed(TorchDistribution):
"""The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over
the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the
1-torus is a circle and the 2-tours is commonly associated with the donut shape (some may object to this simile).

The skewness parameter can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`.
For example, the following will produce a uniform prior over skewness for the 2-torus,::

def model(...):
...
skew_phi = pyro.sample(f'skew_phi', Uniform(-1., 1.))
psi_bound = 1 - skewness_phi.abs()
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
skew_psi = pyro.sample(f'skew_psi', Uniform(-1, 1.))
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=0)
...

In the context of :class:`~pyro.infer.SVI`, this distribution can be freely used as a likelihood, but use as a
latent variables will lead to slow inference for 2 and higher order toruses. This is because the base_dist
cannot be reparameterized.

.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
must be less than or equal to one. See eq. 2.1 in [1].

** References: **
1. Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)

:param base_dist: base density on a d-dimensional torus.
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
:param skewness: skewness of the distribution.
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
"""
arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)}
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

support = constraints.independent(constraints.real, 1)

def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None):
if (skewness.abs().sum(-1) > 1.).any():
warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning)

OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
event_shape = skewness.shape[-1:]
self.skewness = skewness.broadcast_to(batch_shape + event_shape)
self.base_dist = base_dist.expand(batch_shape)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

if self._validate_args and base_dist.mean.device != skewness.device:
raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed "
f"must be on same device.")

def __repr__(self):
args_string = ', '.join(['{}: {}'.format(p, getattr(self, p)
if getattr(self, p).numel() == 1
else getattr(self, p).size()) for p in self.arg_constraints.keys()])
return self.__class__.__name__ + '(' + f'base_density: {str(self.base_dist)}, ' + args_string + ')'

def sample(self, sample_shape=torch.Size()):
bd = self.base_dist
ys = bd.sample(sample_shape)
u = Uniform(0., torch.ones(torch.Size([]), device=self.skewness.device)).sample(sample_shape + self.batch_shape)
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

# Section 2.3 step 3 in [1]
mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1)
mask = mask[..., None]
samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
return samples

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

# Eq. 2.1 in [1]
skew_prob = torch.log(1 + (self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1))
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
return self.base_dist.log_prob(value) + skew_prob

def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
new = self._get_checked_instance(SineSkewed, _instance)
base_dist = self.base_dist.expand(batch_shape)
new.base_dist = base_dist
new.skewness = self.skewness.expand(batch_shape + (-1,))
super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
36 changes: 23 additions & 13 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
from math import pi

import numpy as np
import pytest
Expand All @@ -15,7 +16,7 @@
ShapeAugmentedDirichlet,
ShapeAugmentedGamma,
)
from tests.distributions.dist_fixture import Fixture
from tests.distributions.dist_fixture import Fixture, tensor_wrap


class FoldedNormal(dist.FoldedDistribution):
Expand Down Expand Up @@ -187,7 +188,7 @@ def __init__(self, rate, *, validate_args=None):
],
# This hack seems to be the best option right now, as 'scale' is not handled well by get_scipy_batch_logpdf
scipy_arg_fn=lambda loc, covariance_matrix=None:
((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}),
((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}),
prec=0.01,
min_samples=500000),
Fixture(pyro_dist=dist.LowRankMultivariateNormal,
Expand All @@ -197,7 +198,7 @@ def __init__(self, rate, *, validate_args=None):
'test_data': [[2.0, 1.0], [9.0, 3.4]]},
],
scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None:
((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}),
((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}),
prec=0.01,
min_samples=500000),
Fixture(pyro_dist=FoldedNormal,
Expand Down Expand Up @@ -280,12 +281,12 @@ def __init__(self, rate, *, validate_args=None):
Fixture(pyro_dist=dist.LKJ,
examples=[
{'dim': 3, 'concentration': 1., 'test_data':
[[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]],
[[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]],
[[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]],
[[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]],
[[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]},
]),
[[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]],
[[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]],
[[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]],
[[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]],
[[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]},
]),
Fixture(pyro_dist=dist.LKJCholesky,
examples=[
{
Expand All @@ -305,19 +306,19 @@ def __init__(self, rate, *, validate_args=None):
examples=[
{'stability': [1.5], 'skew': 0.1, 'test_data': [-10.]},
{'stability': [1.5], 'skew': 0.1, 'scale': 2.0, 'loc': -2.0, 'test_data': [10.]},
]),
]),
Fixture(pyro_dist=dist.MultivariateStudentT,
examples=[
{'df': 1.5, 'loc': [0.2, 0.3], 'scale_tril': [[0.8, 0.0], [1.3, 0.4]],
'test_data': [-3., 2]},
]),
]),
Fixture(pyro_dist=dist.ProjectedNormal,
examples=[
{'concentration': [0., 0.], 'test_data': [1., 0.]},
{'concentration': [2., 3.], 'test_data': [0., 1.]},
{'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]},
{'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]},
]),
]),
Fixture(pyro_dist=dist.SoftLaplace,
examples=[
{'loc': [2.0], 'scale': [4.0],
Expand All @@ -328,7 +329,16 @@ def __init__(self, rate, *, validate_args=None):
'test_data': [[[2.0]]]},
{'loc': [2.0, 50.0], 'scale': [4.0, 100.0],
'test_data': [[2.0, 50.0], [2.0, 50.0]]},
]),
]),
Fixture(pyro_dist=dist.SineSkewed,
examples=[
{'base_dist': dist.VonMises(*tensor_wrap([0.], [1.])).to_event(1),
'skewness': [.342355], 'test_data': [.1]},
{'base_dist': dist.Uniform(*tensor_wrap([-pi, -pi], [pi, pi])).to_event(1),
'skewness': [-pi / 4, .1], 'test_data': [pi / 2, -2 * pi / 3]},
{'base_dist': dist.VonMises(*tensor_wrap([0., -1.234], [1., 10.])).to_event(1),
'skewness': [[.342355, -.0001], [.91, 0.09]], 'test_data': [[.1, -3.2], [-2., 0.]]},
])
]

discrete_dists = [
Expand Down
4 changes: 4 additions & 0 deletions tests/distributions/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

@requires_cuda
def test_sample(dist):
if dist.pyro_dist.__name__ == 'SineSkewed':
pytest.xfail(reason="Fixture with distribution param not handled.")
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
for idx in range(len(dist.dist_params)):

# Compute CPU value.
Expand Down Expand Up @@ -77,6 +79,8 @@ def test_rsample(dist):

@requires_cuda
def test_log_prob(dist):
if dist.pyro_dist.__name__ == 'SineSkewed':
pytest.xfail(reason="Fixture with distribution param not handled.")
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
for idx in range(len(dist.dist_params)):

# Compute CPU value.
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_support_shape(dist):


def test_infer_shapes(dist):
if "LKJ" in dist.pyro_dist.__name__:
if "LKJ" in dist.pyro_dist.__name__ or "SineSkewed" == dist.pyro_dist.__name__:
pytest.xfail(reason="cannot statically compute shape")
for idx in range(dist.get_num_test_data()):
dist_params = dist.get_dist_params(idx)
Expand Down
79 changes: 79 additions & 0 deletions tests/distributions/test_sine_skewed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from math import pi

import pytest
import torch

import pyro
from pyro.distributions import Normal, SineSkewed, Uniform, VonMises, constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tests.common import assert_equal

BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))]


def _skewness(event_shape):
skewness = torch.zeros(event_shape.numel())
done = False
while not done:
for i in range(event_shape.numel()):
max_ = 1. - skewness.abs().sum(-1)
if torch.any(max_ < 1e-15):
break
skewness[i] = Uniform(-max_, max_).sample()
done = not torch.any(max_ < 1e-15)

if event_shape == tuple():
skewness = skewness.reshape(event_shape)
else:
skewness = skewness.view(event_shape)
return skewness


@pytest.mark.parametrize('expand_shape',
[(1,), (2,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)])
@pytest.mark.parametrize('dist', BASE_DISTS)
def test_ss_multidim_log_prob(expand_shape, dist):
base_dist = dist[0](*(torch.tensor(param).expand(expand_shape) for param in dist[1])).to_event(1)

loc = base_dist.sample((10,)) + Normal(0., 1e-3).sample()

base_prob = base_dist.log_prob(loc)
skewness = _skewness(base_dist.event_shape)

ss = SineSkewed(base_dist, skewness)
assert_equal(base_prob.shape, ss.log_prob(loc).shape)
assert_equal(ss.sample().shape, torch.Size(expand_shape))


@pytest.mark.parametrize('dist', BASE_DISTS)
@pytest.mark.parametrize('dim', [1, 2])
def test_ss_mle(dim, dist):
base_dist = dist[0](*(torch.tensor(param).expand((dim,)) for param in dist[1])).to_event(1)

skewness_tar = _skewness(base_dist.event_shape)
data = SineSkewed(base_dist, skewness_tar).sample((1000,))

def model(data, batch_shape):
skews = []
for i in range(dim):
skews.append(pyro.param(f'skew{i}', .5 * torch.ones(batch_shape), constraint=constraints.interval(-1, 1)))

skewness = torch.stack(skews, dim=-1)
with pyro.plate("data", data.size(-len(data.size()))):
pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data)

def guide(data, batch_shape):
pass

pyro.clear_param_store()
adam = Adam({"lr": .1})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

losses = []
steps = 80
for step in range(steps):
losses.append(svi.step(data, base_dist.batch_shape))

act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k], dim=-1)
assert_equal(act_skewness, skewness_tar, 1e-1)