Add default TorchDistribution.expand method (#2209)
neerajprad authored and fritzo committed Dec 5, 2019
1 parent 5058d39 commit ab6a88b
Showing 2 changed files with 184 additions and 33 deletions.
131 changes: 128 additions & 3 deletions pyro/distributions/torch_distribution.py
@@ -1,11 +1,13 @@
import warnings
from collections import OrderedDict

import torch
from torch.distributions import constraints
from torch.distributions.kl import kl_divergence, register_kl

import pyro.distributions.torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import broadcast_shape, scale_and_mask


@@ -63,6 +65,19 @@ def shape(self, sample_shape=torch.Size()):
"""
return sample_shape + self.batch_shape + self.event_shape

def expand(self, batch_shape, _instance=None):
"""
Returns a new :class:`ExpandedDistribution` instance with batch
dimensions expanded to `batch_shape`.
:param tuple batch_shape: batch shape to expand to.
:param _instance: unused argument for compatibility with
:meth:`torch.distributions.Distribution.expand`
:return: an instance of `ExpandedDistribution`.
:rtype: :class:`ExpandedDistribution`
"""
return ExpandedDistribution(self, batch_shape)

def expand_by(self, sample_shape):
"""
Expands a distribution by adding ``sample_shape`` to the left side of
@@ -74,9 +89,13 @@ def expand_by(self, sample_shape):
:param torch.Size sample_shape: The size of the iid batch to be drawn
from the distribution.
:return: An expanded version of this distribution.
:rtype: :class:`ReshapedDistribution`
:rtype: :class:`ExpandedDistribution`
"""
return self.expand(torch.Size(sample_shape) + self.batch_shape)
try:
expanded_dist = self.expand(torch.Size(sample_shape) + self.batch_shape)
except NotImplementedError:
expanded_dist = TorchDistributionMixin.expand(self, torch.Size(sample_shape) + self.batch_shape)
return expanded_dist
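
As a quick illustration (a sketch for this page, not part of the diff), expand_by prepends the given sample shape on the left of the existing batch shape:

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,)
e = d.expand_by([2])                            # batch_shape becomes (2, 3)
assert e.batch_shape == torch.Size([2, 3])
assert e.sample().shape == torch.Size([2, 3])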

def reshape(self, sample_shape=None, extra_event_dims=None):
raise Exception('''
@@ -197,7 +216,9 @@ class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin
method to improve gradient estimates and set
``.has_enumerate_support = True``.
"""
pass
# Provides a default `.expand` method for Pyro distributions which overrides
# torch.distributions.Distribution.expand (throws a NotImplementedError).
expand = TorchDistributionMixin.expand
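
To see what the default buys subclasses, here is a hedged sketch with a hypothetical ToyDistribution (an illustrative name, not a Pyro distribution) that defines only sample and log_prob; .expand now works via ExpandedDistribution with no per-class implementation:

import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution

class ToyDistribution(TorchDistribution):
    # Hypothetical distribution for illustration; it does not define .expand.
    arg_constraints = {}
    support = constraints.real

    def __init__(self, loc):
        self.loc = loc
        super(ToyDistribution, self).__init__(batch_shape=loc.shape)

    def sample(self, sample_shape=torch.Size()):
        return self.loc.expand(torch.Size(sample_shape) + self.batch_shape).clone()

    def log_prob(self, value):
        return torch.zeros_like(value)

d = ToyDistribution(torch.zeros(2))
e = d.expand([3, 2])  # falls back to the default added in this commit
assert e.batch_shape == torch.Size([3, 2])
assert e.sample().shape == torch.Size([3, 2])
assert e.log_prob(e.sample()).shape == torch.Size([3, 2])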


class MaskedDistribution(TorchDistribution):
@@ -280,6 +301,110 @@ def variance(self):
return self.base_dist.variance


class ExpandedDistribution(TorchDistribution):
arg_constraints = {}

def __init__(self, base_dist, batch_shape=torch.Size()):
self.base_dist = base_dist
super(ExpandedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)
# adjust batch shape
self.expand(batch_shape)

def expand(self, batch_shape, _instance=None):
# Do basic validation. e.g. we should not "unexpand" distributions even if that is possible.
new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
# Record interstitial and expanded dims/sizes w.r.t. the base distribution
new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
new_shape)
self._batch_shape = new_shape
self._expanded_sizes = expanded_sizes
self._interstitial_sizes = interstitial_sizes
return self

@staticmethod
def _broadcast_shape(existing_shape, new_shape):
if len(new_shape) < len(existing_shape):
raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
.format(existing_shape, new_shape))
reversed_shape = list(reversed(existing_shape))
expanded_sizes, interstitial_sizes = [], []
for i, size in enumerate(reversed(new_shape)):
if i >= len(reversed_shape):
reversed_shape.append(size)
expanded_sizes.append((-i - 1, size))
elif reversed_shape[i] == 1:
if size != 1:
reversed_shape[i] = size
interstitial_sizes.append((-i - 1, size))
elif reversed_shape[i] != size:
raise ValueError("Cannot broadcast distribution of shape {} to shape {}"
.format(existing_shape, new_shape))
return tuple(reversed(reversed_shape)), OrderedDict(expanded_sizes), OrderedDict(interstitial_sizes)
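
For intuition (a sketch calling the helper directly; the shapes are made up): expanding a base batch shape (3, 1, 2) to (4, 3, 5, 2) records the new leftmost size-4 dimension as "expanded" and the size-5 dimension that stretches an existing singleton as "interstitial":

from pyro.distributions.torch_distribution import ExpandedDistribution

shape, expanded, interstitial = ExpandedDistribution._broadcast_shape((3, 1, 2), (4, 3, 5, 2))
assert shape == (4, 3, 5, 2)
assert dict(expanded) == {-4: 4}      # brand-new batch dim added on the left
assert dict(interstitial) == {-2: 5}  # existing size-1 dim stretched to 5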

@property
def has_rsample(self):
return self.base_dist.has_rsample

@property
def has_enumerate_support(self):
return self.base_dist.has_enumerate_support

@constraints.dependent_property
def support(self):
return self.base_dist.support

def _sample(self, sample_fn, sample_shape):
interstitial_dims = tuple(self._interstitial_sizes.keys())
interstitial_dims = tuple(i - self.event_dim for i in interstitial_dims)
interstitial_sizes = tuple(self._interstitial_sizes.values())
expanded_sizes = tuple(self._expanded_sizes.values())
batch_shape = expanded_sizes + interstitial_sizes
samples = sample_fn(sample_shape + batch_shape)
interstitial_idx = len(sample_shape) + len(expanded_sizes)
interstitial_sample_dims = tuple(range(interstitial_idx, interstitial_idx + len(interstitial_sizes)))
for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
samples = samples.transpose(dim1, dim2)
return samples.reshape(sample_shape + self.batch_shape + self.event_shape)
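
Concretely (an illustrative sketch reusing the shapes above, with a Normal base chosen only for the example): the interstitial size-5 dimension is drawn alongside the truly new dimensions, transposed back next to the singleton it fills, and the final reshape yields the expanded batch shape:

import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import ExpandedDistribution

base = dist.Normal(torch.zeros(3, 1, 2), torch.ones(3, 1, 2))
expanded = ExpandedDistribution(base, torch.Size([4, 3, 5, 2]))
x = expanded.sample(torch.Size([7]))
# the base draw has shape (7, 4, 5, 3, 1, 2) before the transpose and reshape
assert x.shape == torch.Size([7, 4, 3, 5, 2])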

def sample(self, sample_shape=torch.Size()):
return self._sample(self.base_dist.sample, sample_shape)

def rsample(self, sample_shape=torch.Size()):
return self._sample(self.base_dist.rsample, sample_shape)

def log_prob(self, value):
shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
log_prob = self.base_dist.log_prob(value)
return log_prob.expand(shape)

def score_parts(self, value):
shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
if self.batch_shape != self.base_dist.batch_shape:
log_prob = log_prob.expand(shape)
if isinstance(score_function, torch.Tensor):
score_function = score_function.expand(shape)
if isinstance(entropy_term, torch.Tensor):
entropy_term = entropy_term.expand(shape)
return ScoreParts(log_prob, score_function, entropy_term)

def enumerate_support(self, expand=True):
samples = self.base_dist.enumerate_support(expand=expand)
enum_shape = samples.shape[:1]
samples = samples.reshape(enum_shape + (1,) * len(self.batch_shape))
if expand:
samples = samples.expand(enum_shape + self.batch_shape)
return samples

@property
def mean(self):
return self.base_dist.mean.expand(self.batch_shape + self.event_shape)

@property
def variance(self):
return self.base_dist.variance.expand(self.batch_shape + self.event_shape)
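
Finally, a short end-to-end sketch (the Categorical base and the size-2 expansion are purely illustrative) showing that log_prob and enumerate_support broadcast over the expanded batch dimensions:

import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import ExpandedDistribution

c = dist.Categorical(torch.ones(3))            # batch_shape (), 3 support points
ec = ExpandedDistribution(c, torch.Size([2]))  # batch_shape (2,)
assert ec.log_prob(torch.tensor([0, 2])).shape == torch.Size([2])
assert ec.enumerate_support(expand=True).shape == torch.Size([3, 2])
assert ec.enumerate_support(expand=False).shape == torch.Size([3, 1])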


@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
if p._mask is False or q._mask is False:
86 changes: 56 additions & 30 deletions tests/distributions/test_distributions.py
@@ -4,6 +4,7 @@

import pyro
import pyro.distributions as dist
from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape
from tests.common import assert_equal, xfail_if_not_implemented

@@ -143,25 +144,33 @@ def check_sample_shapes(small, large):
def test_expand_by(dist, sample_shape, shape_type):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand_by(shape_type(sample_shape))
assert large.batch_shape == sample_shape + small.batch_shape
check_sample_shapes(small, large)
large = small.expand_by(shape_type(sample_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)])
@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list])
def test_expand_new_dim(dist, sample_shape, shape_type):
@pytest.mark.parametrize('default', [False, True])
def test_expand_new_dim(dist, sample_shape, shape_type, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand(shape_type(sample_shape + small.batch_shape))
assert large.batch_shape == sample_shape + small.batch_shape
check_sample_shapes(small, large)
if default:
large = TorchDistribution.expand(small, shape_type(sample_shape + small.batch_shape))
else:
with xfail_if_not_implemented():
large = small.expand(shape_type(sample_shape + small.batch_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize('shape_type', [torch.Size, tuple, list])
def test_expand_existing_dim(dist, shape_type):
@pytest.mark.parametrize('default', [False, True])
def test_expand_existing_dim(dist, shape_type, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
for dim, size in enumerate(small.batch_shape):
@@ -170,24 +179,33 @@ def test_expand_existing_dim(dist, shape_type):
batch_shape = list(small.batch_shape)
batch_shape[dim] = 5
batch_shape = torch.Size(batch_shape)
with xfail_if_not_implemented():
large = small.expand(shape_type(batch_shape))
assert large.batch_shape == batch_shape
check_sample_shapes(small, large)
if default:
large = TorchDistribution.expand(small, shape_type(batch_shape))
else:
with xfail_if_not_implemented():
large = small.expand(shape_type(batch_shape))
assert large.batch_shape == batch_shape
if dist.get_test_distribution_name() == 'Stable':
pytest.skip('Stable does not implement a log_prob method.')
check_sample_shapes(small, large)


@pytest.mark.parametrize("sample_shapes", [
[(2, 1), (2, 3)],
[(2, 1, 1), (2, 1, 3), (2, 5, 3)],
])
def test_subsequent_expands_ok(dist, sample_shapes):
@pytest.mark.parametrize('default', [False, True])
def test_subsequent_expands_ok(dist, sample_shapes, default):
for idx in range(dist.get_num_test_data()):
d = dist.pyro_dist(**dist.get_dist_params(idx))
original_batch_shape = d.batch_shape
for shape in sample_shapes:
proposed_batch_shape = torch.Size(shape) + original_batch_shape
with xfail_if_not_implemented():
n = d.expand(proposed_batch_shape)
if default:
n = TorchDistribution.expand(d, proposed_batch_shape)
else:
with xfail_if_not_implemented():
n = d.expand(proposed_batch_shape)
assert n.batch_shape == proposed_batch_shape
with xfail_if_not_implemented():
check_sample_shapes(d, n)
@@ -199,17 +217,21 @@ def test_subsequent_expands_ok(dist, sample_shapes):
[(2, 4), (2, 2, 1)],
[(1, 2, 1), (2, 1)],
])
def test_expand_error(dist, initial_shape, proposed_shape):
@pytest.mark.parametrize("default", [False, True])
def test_expand_error(dist, initial_shape, proposed_shape, default):
for idx in range(dist.get_num_test_data()):
small = dist.pyro_dist(**dist.get_dist_params(idx))
with xfail_if_not_implemented():
large = small.expand(torch.Size(initial_shape) + small.batch_shape)
proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
if dist.get_test_distribution_name() == 'LKJCorrCholesky':
pytest.skip('LKJCorrCholesky can expand to a shape not' +
'broadcastable with its original batch_shape.')
with pytest.raises(RuntimeError):
large.expand(proposed_batch_shape)
if default:
large = TorchDistribution.expand(small, initial_shape + small.batch_shape)
else:
with xfail_if_not_implemented():
large = small.expand(torch.Size(initial_shape) + small.batch_shape)
proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
if dist.get_test_distribution_name() == 'LKJCorrCholesky':
pytest.skip('LKJCorrCholesky can expand to a shape not' +
'broadcastable with its original batch_shape.')
with pytest.raises((RuntimeError, ValueError)):
large.expand(proposed_batch_shape)


@pytest.mark.parametrize("extra_event_dims,expand_shape", [
@@ -218,11 +240,15 @@ def test_expand_error(dist, initial_shape, proposed_shape):
(1, [5, 4, 3, 2]),
(2, [5, 4, 3]),
])
def test_expand_reshaped_distribution(extra_event_dims, expand_shape):
@pytest.mark.parametrize('default', [False, True])
def test_expand_reshaped_distribution(extra_event_dims, expand_shape, default):
probs = torch.ones(1, 6) / 6
d = dist.OneHotCategorical(probs)
reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims)
full_shape = torch.Size([4, 1, 1, 1, 6])
if default:
reshaped_dist = TorchDistribution.expand(d, [4, 1, 1, 1]).to_event(extra_event_dims)
else:
reshaped_dist = d.expand_by([4, 1, 1]).to_event(extra_event_dims)
cut = 4 - extra_event_dims
batch_shape, event_shape = full_shape[:cut], full_shape[cut:]
assert reshaped_dist.batch_shape == batch_shape
@@ -232,9 +258,9 @@ def test_expand_reshaped_distribution(extra_event_dims, expand_shape):
assert large.event_shape == torch.Size(event_shape)

# Throws error when batch shape cannot be broadcasted
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, ValueError)):
reshaped_dist.expand(expand_shape + [3])

# Throws error when trying to shrink existing batch shape
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, ValueError)):
large.expand(expand_shape[1:])
