Implement TorchDistribution.expand() #1119

Closed
fritzo opened this issue May 3, 2018 · 0 comments · Fixed by #1120
fritzo commented May 3, 2018

Why?

Pyro currently provides an .expand_by() method to aid/burden users when reshaping distributions. However, .expand_by() can only expand by adding new dimensions on the left. To automate some of this expansion (#238, #1115), it appears we will additionally need to support broadcasting at internal dimensions, as torch.Tensor.expand() does.

Here is an example demonstrating the necessity of .expand() (regardless of dim ordering):

import pyro
import pyro.distributions as dist

X, Y = 320, 200
x_axis = pyro.iarange("x_axis", X, dim=-2)
y_axis = pyro.iarange("y_axis", Y, dim=-1)
with x_axis:
    x = pyro.sample("x", dist.Normal(0, 1).expand_by([X, 1]))  # .expand_by() suffices
with y_axis:
    y = pyro.sample("y", dist.Normal(0, 1).expand_by([Y]))     # .expand_by() suffices
with x_axis, y_axis:
    yx = pyro.sample("yx", dist.Normal(y, 1).expand_by([X]))   # .expand_by() suffices
    xy = pyro.sample("xy", dist.Normal(x, 1).expand([X, Y]))   # .expand() is needed
    ...
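
The shape bookkeeping in the example can be checked without Pyro. This is a minimal sketch, assuming that Normal(y, 1) and Normal(x, 1) take the shapes of y and x as their batch_shape; batch shapes are plain tuples, and `expand_by` here is a hypothetical stand-in that models only how .expand_by() prepends dimensions.

```python
# Shape bookkeeping only; expand_by is a hypothetical stand-in, not Pyro's method.
X, Y = 320, 200


def expand_by(batch_shape, sizes):
    """Prepend `sizes` to `batch_shape`, mirroring how .expand_by() adds dims."""
    return tuple(sizes) + tuple(batch_shape)


# Normal(0, 1) starts with an empty batch_shape.
x_shape = expand_by((), [X, 1])   # (X, 1): x_axis is dim=-2
y_shape = expand_by((), [Y])      # (Y,):   y_axis is dim=-1

# Normal(y, 1) has batch_shape (Y,); prepending X fills dim=-2.
yx_shape = expand_by(y_shape, [X])
assert yx_shape == (X, Y)

# Normal(x, 1) has batch_shape (X, 1). Whatever is prepended, the two
# rightmost dims stay (X, 1), so dim=-1 can never become Y.
for sizes in ([], [Y], [X, Y], [Y, X]):
    assert expand_by(x_shape, sizes)[-2:] == (X, 1)
```

Prepending alone reaches (X, Y) for yx but not for xy, which is why the last site needs broadcasting of an internal size-1 dim.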

How?

First, .expand() will operate only on batch_shape. Note that .expand_by() is agnostic about which shape is being expanded (the batch_shape of .log_prob() results versus the batch_shape + event_shape of .sample() results). In contrast, .expand() must decide which of the two shapes it expands. There are three reasons why batch_shape is the more sensible choice:

  1. It is unclear how to expand event_shape in general (e.g. how do you expand a normalized set of probabilities?)
  2. .expand() can only change dimensions of size 1 into dimensions of size > 1, yet none of our event shapes have meaningful dimensions of size 1 (e.g. a OneHotCategorical of size 1 is degenerate, and a MultivariateNormal of size 1 is just a Normal).
  3. The current use cases only need to expand batch_shape.
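
The rule that .expand() touches only batch_shape can be written as shape arithmetic. This is a minimal sketch, assuming torch.Tensor.expand()'s rules (new dims are added on the left; an existing dim may change only if it has size 1) and ignoring the -1 shorthand; `expand_batch_shape` and `result_shapes` are hypothetical helpers. Because event_shape passes through untouched, .sample() results have shape sample_shape + batch_shape + event_shape, while .log_prob() results have shape sample_shape + batch_shape.

```python
def expand_batch_shape(batch_shape, new_batch_shape):
    """Validate new_batch_shape using torch.Tensor.expand()-style rules."""
    old, new = tuple(batch_shape), tuple(new_batch_shape)
    if len(new) < len(old):
        raise ValueError("expand() cannot remove batch dims")
    # New dims may only be added on the left, treated as size-1 placeholders.
    padded = (1,) * (len(new) - len(old)) + old
    for o, n in zip(padded, new):
        # Only a size-1 dim may change; every other dim must match exactly.
        if o != 1 and o != n:
            raise ValueError(f"cannot expand batch dim of size {o} to {n}")
    return new


def result_shapes(batch_shape, event_shape, new_batch_shape, sample_shape=()):
    """Shapes of .sample(sample_shape) and .log_prob() after .expand(new_batch_shape)."""
    batch = expand_batch_shape(batch_shape, new_batch_shape)
    # event_shape is never touched by .expand().
    sample = tuple(sample_shape) + batch + tuple(event_shape)
    log_prob = tuple(sample_shape) + batch
    return sample, log_prob


# A hypothetical MultivariateNormal with batch_shape (4, 1) and event_shape (2,):
print(result_shapes((4, 1), (2,), (4, 3)))                      # ((4, 3, 2), (4, 3))
print(result_shapes((4, 1), (2,), (4, 3), sample_shape=(10,)))  # ((10, 4, 3, 2), (10, 4, 3))
```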

Mechanically, .expand() can first try to apply .expand_by() under the hood. As a last resort, it can combine .expand_by() with a PermutedDistribution that still uses .sample(sample_shape) to add batch dims on the left, but then permutes those dims into internal positions. Since PermutedDistribution may be expensive and may produce non-contiguous tensors, we should override .expand() in the wrappers for the most commonly used distributions, e.g.:

# in pyro/distributions/torch.py
import torch

from pyro.distributions.torch_distribution import TorchDistributionMixin

class Normal(torch.distributions.Normal, TorchDistributionMixin):
    def expand(self, batch_shape):
        # Normal has an empty event_shape, so each parameter expands to batch_shape.
        loc = self.loc.expand(batch_shape)
        scale = self.scale.expand(batch_shape)
        return Normal(loc, scale)

class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin):
    def expand(self, batch_shape):
        batch_shape = torch.Size(batch_shape)
        # loc has shape batch_shape + event_shape.
        loc = self.loc.expand(batch_shape + self.event_shape)
        # we'll want to check for covariance_matrix and precision_matrix, but roughly:
        # scale_tril has shape batch_shape + event_shape + event_shape.
        scale_tril = self.scale_tril.expand(batch_shape + self.event_shape + self.event_shape)
        return MultivariateNormal(loc, scale_tril=scale_tril)
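
The PermutedDistribution fallback described above can be sketched on plain arrays. This is a minimal sketch, assuming a scalar Normal and using NumPy as a stand-in for torch; `permuted_expand_sample` is a hypothetical helper, not a Pyro API. It draws every new batch dim on the left (as .sample(sample_shape) would), drops the size-1 placeholders those dims replace, and transposes them into their internal positions.

```python
import numpy as np


def permuted_expand_sample(loc, scale, target_shape, rng):
    """Draw Normal(loc, scale) samples whose batch shape is expanded to target_shape.

    Hypothetical helper: batch on the left, then permute the left dims
    into the internal size-1 positions they are meant to expand.
    """
    batch_shape = np.broadcast(loc, scale).shape
    target = tuple(target_shape)
    if len(target) < len(batch_shape):
        raise ValueError("target_shape has fewer dims than batch_shape")
    # Left-pad with size-1 dims, as prepending via .expand_by() would.
    padded = (1,) * (len(target) - len(batch_shape)) + batch_shape
    for old, new in zip(padded, target):
        if old != 1 and old != new:
            raise ValueError(f"cannot expand size {old} to {new}")
    grow = [i for i in range(len(target)) if padded[i] == 1 and target[i] != 1]
    keep = [i for i in range(len(target)) if i not in grow]
    sample_shape = tuple(target[i] for i in grow)
    k = len(sample_shape)
    # Step 1: add all new batching on the left, like .sample(sample_shape).
    draws = loc + scale * rng.standard_normal(sample_shape + padded)
    # Step 2: drop the size-1 placeholders that the left dims will replace.
    draws = draws.reshape(sample_shape + tuple(padded[i] for i in keep))
    # Step 3: move left dim j to internal position grow[j]; other dims keep their order.
    perm = [grow.index(p) if p in grow else k + keep.index(p) for p in range(len(target))]
    return np.transpose(draws, perm)


rng = np.random.default_rng(0)
loc = np.arange(4.0).reshape(4, 1)  # batch_shape (4, 1), like x in the example
out = permuted_expand_sample(loc, 0.0, (4, 3), rng)
print(out.shape)                    # (4, 3)
print(out.flags["C_CONTIGUOUS"])    # False: transpose returns a strided view
```

The transpose returns a non-contiguous view, and scoring such values with .log_prob() would need the inverse permutation; per-distribution overrides like the ones above avoid both costs by expanding the parameters directly.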