Add default TorchDistribution.expand method #2209

Merged
merged 2 commits on Dec 5, 2019
131 changes: 128 additions & 3 deletions pyro/distributions/torch_distribution.py
@@ -1,11 +1,13 @@
import warnings
from collections import OrderedDict

import torch
from torch.distributions import constraints
from torch.distributions.kl import kl_divergence, register_kl

import pyro.distributions.torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import broadcast_shape, scale_and_mask


@@ -63,6 +65,19 @@ def shape(self, sample_shape=torch.Size()):
"""
return sample_shape + self.batch_shape + self.event_shape

def expand(self, batch_shape, _instance=None):
"""
Returns a new :class:`ExpandedDistribution` instance with batch
dimensions expanded to `batch_shape`.

:param tuple batch_shape: batch shape to expand to.
:param _instance: unused argument for compatibility with
:meth:`torch.distributions.Distribution.expand`
:return: an instance of `ExpandedDistribution`.
:rtype: :class:`ExpandedDistribution`
"""
return ExpandedDistribution(self, batch_shape)
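As an aside (not part of this diff), the fallback can be exercised explicitly even on a distribution that already has a native `.expand`, by calling the mixin method directly, which is the same trick the updated tests below use. A minimal usage sketch, assuming Pyro with this patch applied:

import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin

d = dist.Normal(torch.zeros(3), torch.ones(3))             # batch_shape == (3,)
# Bypass Normal's native .expand and force the generic fallback wrapper.
e = TorchDistributionMixin.expand(d, torch.Size((5, 3)))   # ExpandedDistribution
assert e.batch_shape == torch.Size((5, 3))
assert e.sample().shape == torch.Size((5, 3))
assert e.log_prob(e.sample()).shape == torch.Size((5, 3))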

def expand_by(self, sample_shape):
"""
Expands a distribution by adding ``sample_shape`` to the left side of
@@ -74,9 +89,13 @@ def expand_by(self, sample_shape):
:param torch.Size sample_shape: The size of the iid batch to be drawn
from the distribution.
:return: An expanded version of this distribution.
:rtype: :class:`ReshapedDistribution`
:rtype: :class:`ExpandedDistribution`
"""
return self.expand(torch.Size(sample_shape) + self.batch_shape)
try:
expanded_dist = self.expand(torch.Size(sample_shape) + self.batch_shape)
except NotImplementedError:
expanded_dist = TorchDistributionMixin.expand(self, torch.Size(sample_shape) + self.batch_shape)
return expanded_dist

def reshape(self, sample_shape=None, extra_event_dims=None):
raise Exception('''
@@ -197,7 +216,9 @@ class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin
method to improve gradient estimates and set
``.has_enumerate_support = True``.
"""
pass
# Provides a default `.expand` method for Pyro distributions, overriding
# torch.distributions.Distribution.expand (which raises NotImplementedError).
expand = TorchDistributionMixin.expand
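To illustrate what this buys downstream code, here is a hypothetical user-defined distribution (not part of this diff) that implements only `sample` and `log_prob`; with this patch it inherits a working `.expand` instead of hitting torch's NotImplementedError. A sketch under that assumption:

import math
import torch
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution

class Flip(TorchDistribution):
    """Fair +/-1 coin flip; deliberately does not define .expand."""
    arg_constraints = {}
    support = constraints.real

    def sample(self, sample_shape=torch.Size()):
        shape = sample_shape + self.batch_shape + self.event_shape
        return torch.where(torch.rand(shape) < 0.5, torch.tensor(1.), torch.tensor(-1.))

    def log_prob(self, value):
        return torch.full(value.shape, math.log(0.5))

d = Flip()                       # batch_shape == ()
e = d.expand(torch.Size((4,)))   # inherited default -> ExpandedDistribution
assert e.batch_shape == torch.Size((4,))
assert e.sample().shape == torch.Size((4,))
assert e.log_prob(e.sample()).shape == torch.Size((4,))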


class MaskedDistribution(TorchDistribution):
@@ -280,6 +301,110 @@ def variance(self):
return self.base_dist.variance


class ExpandedDistribution(TorchDistribution):
arg_constraints = {}

def __init__(self, base_dist, batch_shape=torch.Size()):
self.base_dist = base_dist
super(ExpandedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)
# adjust batch shape
self.expand(batch_shape)

def expand(self, batch_shape, _instance=None):
# Do basic validation; e.g. we should not "unexpand" distributions even if that is possible.
new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
# Record interstitial and expanded dims/sizes w.r.t. the base distribution
new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
new_shape)
self._batch_shape = new_shape
self._expanded_sizes = expanded_sizes
self._interstitial_sizes = interstitial_sizes
return self

@staticmethod
def _broadcast_shape(existing_shape, new_shape):
if len(new_shape) < len(existing_shape):
raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
.format(existing_shape, new_shape))
reversed_shape = list(reversed(existing_shape))
expanded_sizes, interstitial_sizes = [], []
for i, size in enumerate(reversed(new_shape)):
if i >= len(reversed_shape):
reversed_shape.append(size)
expanded_sizes.append((-i - 1, size))
elif reversed_shape[i] == 1:
if size != 1:
reversed_shape[i] = size
interstitial_sizes.append((-i - 1, size))
elif reversed_shape[i] != size:
raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
.format(existing_shape, new_shape))
return tuple(reversed(reversed_shape)), OrderedDict(expanded_sizes), OrderedDict(interstitial_sizes)
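A worked example of what this helper computes (calling the private static method directly, purely for illustration): genuinely new leading dims are recorded as "expanded", while singleton dims inside the existing batch shape that grow are recorded as "interstitial", keyed by their negative position.

from pyro.distributions.torch_distribution import ExpandedDistribution

shape, expanded, interstitial = ExpandedDistribution._broadcast_shape((3, 1, 2), (5, 3, 4, 2))
# The new leading 5 is an expanded dim; the singleton 1 growing to 4 is interstitial.
assert shape == (5, 3, 4, 2)
assert dict(expanded) == {-4: 5}
assert dict(interstitial) == {-2: 4}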

@property
def has_rsample(self):
return self.base_dist.has_rsample

@property
def has_enumerate_support(self):
return self.base_dist.has_enumerate_support

@constraints.dependent_property
def support(self):
return self.base_dist.support

def _sample(self, sample_fn, sample_shape):
interstitial_dims = tuple(self._interstitial_sizes.keys())
interstitial_dims = tuple(i - self.event_dim for i in interstitial_dims)
interstitial_sizes = tuple(self._interstitial_sizes.values())
expanded_sizes = tuple(self._expanded_sizes.values())
batch_shape = expanded_sizes + interstitial_sizes
samples = sample_fn(sample_shape + batch_shape)
interstitial_idx = len(sample_shape) + len(expanded_sizes)
interstitial_sample_dims = tuple(range(interstitial_idx, interstitial_idx + len(interstitial_sizes)))
for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
samples = samples.transpose(dim1, dim2)
return samples.reshape(sample_shape + self.batch_shape + self.event_shape)
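The transpose dance above only matters for interstitial dims: the base distribution is sampled with those sizes placed next to `sample_shape` on the left, and each extra dim is then moved back beside the singleton it enlarges before the final reshape. A shape-only sketch, not part of the diff:

import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin

base = dist.Normal(torch.zeros(3, 1, 2), torch.ones(3, 1, 2))       # batch (3, 1, 2)
e = TorchDistributionMixin.expand(base, torch.Size((5, 3, 4, 2)))   # batch (5, 3, 4, 2)
x = e.sample(torch.Size((7,)))
# Internally the base is sampled with shape (7, 5, 4, 3, 1, 2); the size-4 dim
# is transposed back next to the singleton it replaces, then reshaped.
assert x.shape == torch.Size((7, 5, 3, 4, 2))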

def sample(self, sample_shape=torch.Size()):
return self._sample(self.base_dist.sample, sample_shape)

def rsample(self, sample_shape=torch.Size()):
return self._sample(self.base_dist.rsample, sample_shape)

def log_prob(self, value):
shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
log_prob = self.base_dist.log_prob(value)
return log_prob.expand(shape)

def score_parts(self, value):
shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
if self.batch_shape != self.base_dist.batch_shape:
log_prob = log_prob.expand(shape)
if isinstance(score_function, torch.Tensor):
score_function = score_function.expand(shape)
if isinstance(entropy_term, torch.Tensor):
entropy_term = entropy_term.expand(shape)
return ScoreParts(log_prob, score_function, entropy_term)

def enumerate_support(self, expand=True):
samples = self.base_dist.enumerate_support(expand=expand)
enum_shape = samples.shape[:1]
samples = samples.reshape(enum_shape + (1,) * len(self.batch_shape))
if expand:
samples = samples.expand(enum_shape + self.batch_shape)
return samples
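Enumeration computes the base support once and either leaves the batch dims as singletons (`expand=False`) or broadcasts them over the expanded batch shape (`expand=True`). A small shape check, assuming this patch:

import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin

d = dist.Bernoulli(torch.tensor(0.3))                      # batch_shape == ()
e = TorchDistributionMixin.expand(d, torch.Size((2, 3)))   # batch_shape == (2, 3)
assert e.enumerate_support(expand=True).shape == torch.Size((2, 2, 3))
assert e.enumerate_support(expand=False).shape == torch.Size((2, 1, 1))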

@property
def mean(self):
return self.base_dist.mean.expand(self.batch_shape + self.event_shape)

@property
def variance(self):
return self.base_dist.variance.expand(self.batch_shape + self.event_shape)


@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
if p._mask is False or q._mask is False:
86 changes: 56 additions & 30 deletions tests/distributions/test_distributions.py
@@ -4,6 +4,7 @@

import pyro
import pyro.distributions as dist
from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape
from tests.common import assert_equal, xfail_if_not_implemented

@@ -143,25 +144,33 @@ def check_sample_shapes(small, large):
def test_expand_by(dist, sample_shape, shape_type):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand_by(shape_type(sample_shape))
assert large.batch_shape == sample_shape + small.batch_shape
check_sample_shapes(small, large)
large = small.expand_by(shape_type(sample_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)])
@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list])
def test_expand_new_dim(dist, sample_shape, shape_type):
@pytest.mark.parametrize('default', [False, True])
def test_expand_new_dim(dist, sample_shape, shape_type, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand(shape_type(sample_shape + small.batch_shape))
assert large.batch_shape == sample_shape + small.batch_shape
check_sample_shapes(small, large)
if default:
large = TorchDistribution.expand(small, shape_type(sample_shape + small.batch_shape))
else:
with xfail_if_not_implemented():
large = small.expand(shape_type(sample_shape + small.batch_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list])
def test_expand_existing_dim(dist, shape_type):
@pytest.mark.parametrize('default', [False, True])
def test_expand_existing_dim(dist, shape_type, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
for dim, size in enumerate(small.batch_shape):
@@ -170,24 +179,33 @@ def test_expand_existing_dim(dist, shape_type):
batch_shape = list(small.batch_shape)
batch_shape[dim] = 5
batch_shape = torch.Size(batch_shape)
with xfail_if_not_implemented():
large = small.expand(shape_type(batch_shape))
assert large.batch_shape == batch_shape
check_sample_shapes(small, large)
if default:
large = TorchDistribution.expand(small, shape_type(batch_shape))
else:
with xfail_if_not_implemented():
large = small.expand(shape_type(batch_shape))
assert large.batch_shape == batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize("sample_shapes", [
[(2, 1), (2, 3)],
[(2, 1, 1), (2, 1, 3), (2, 5, 3)],
])
def test_subsequent_expands_ok(dist, sample_shapes):
@pytest.mark.parametrize('default', [False, True])
def test_subsequent_expands_ok(dist, sample_shapes, default):
for idx in range(dist.get_num_test_data()):
d = dist.pyro_dist(**dist.get_dist_params(idx))
original_batch_shape = d.batch_shape
for shape in sample_shapes:
proposed_batch_shape = torch.Size(shape) + original_batch_shape
with xfail_if_not_implemented():
n = d.expand(proposed_batch_shape)
if default:
n = TorchDistribution.expand(d, proposed_batch_shape)
else:
with xfail_if_not_implemented():
n = d.expand(proposed_batch_shape)
assert n.batch_shape == proposed_batch_shape
with xfail_if_not_implemented():
check_sample_shapes(d, n)
@@ -199,17 +217,21 @@ def test_subsequent_expands_ok(dist, sample_shapes):
[(2, 4), (2, 2, 1)],
[(1, 2, 1), (2, 1)],
])
def test_expand_error(dist, initial_shape, proposed_shape):
@pytest.mark.parametrize("default", [False, True])
def test_expand_error(dist, initial_shape, proposed_shape, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand(torch.Size(initial_shape) + small.batch_shape)
proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
if dist.get_test_distribution_name() == 'LKJCorrCholesky':
pytest.skip('LKJCorrCholesky can expand to a shape not' +
'broadcastable with its original batch_shape.')
with pytest.raises(RuntimeError):
large.expand(proposed_batch_shape)
if default:
large = TorchDistribution.expand(small, initial_shape + small.batch_shape)
else:
with xfail_if_not_implemented():
large = small.expand(torch.Size(initial_shape) + small.batch_shape)
proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
if dist.get_test_distribution_name() == 'LKJCorrCholesky':
pytest.skip('LKJCorrCholesky can expand to a shape not' +
'broadcastable with its original batch_shape.')
with pytest.raises((RuntimeError, ValueError)):
large.expand(proposed_batch_shape)


@pytest.mark.parametrize("extra_event_dims,expand_shape", [
@@ -218,11 +240,15 @@ def test_expand_error(dist, initial_shape, proposed_shape):
(1, [5, 4, 3, 2]),
(2, [5, 4, 3]),
])
def test_expand_reshaped_distribution(extra_event_dims, expand_shape):
@pytest.mark.parametrize('default', [False, True])
def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default):
probs = torch.ones(1, 6) / 6
d = dist.OneHotCategorical(probs)
reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims)
full_shape = torch.Size([4, 1, 1, 1, 6])
if default:
reshaped_dist = TorchDistribution.expand(d, [4, 1, 1, 1]).to_event(extra_event_dims)
else:
reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims)
cut = 4 - extra_event_dims
batch_shape, event_shape = full_shape[:cut], full_shape[cut:]
assert reshaped_dist.batch_shape == batch_shape
Expand All @@ -232,9 +258,9 @@ def test_expand_reshaped_distribution(extra_event_dims, expand_shape):
assert large.event_shape == torch.Size(event_shape)

# Throws error when batch shape cannot be broadcasted
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, ValueError)):
reshaped_dist.expand(expand_shape + [3])

# Throws error when trying to shrink existing batch shape
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, ValueError)):
large.expand(expand_shape[1:])
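Aside: to exercise the new default-expand code paths locally, the parametrized tests above can be selected by keyword, e.g. pytest tests/distributions/test_distributions.py -k expand (the exact invocation is an assumption, not part of this PR).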