Why?

Pyro currently provides an .expand_by() method to aid/burden users when reshaping distributions. However, .expand_by() is limited to expanding only by adding new dimensions on the left. It appears that to automate some of this expansion (#238, #1115), we'll need to additionally support broadcasting at internal dimensions, as is done by torch.Tensor.expand().
Here is an example demonstrating the necessity of .expand() (regardless of dim ordering):
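A minimal sketch of the problem, assuming plain torch.distributions. It is an illustration, not necessarily the example intended here. A distribution whose batch_shape has a singleton internal dim cannot be broadcast along that dim by prepending dims alone. Broadcasting the parameters the way torch.Tensor.expand() does handles it directly.

```python
import torch
from torch.distributions import Normal

# batch_shape (3, 1, 2): the internal dim -2 is a singleton.
d = Normal(torch.zeros(3, 1, 2), torch.ones(3, 1, 2))

# Adding dims on the left (the .expand_by() style) can only prepend,
# so the internal singleton remains.
left = Normal(d.loc.expand(torch.Size([4]) + d.batch_shape),
              d.scale.expand(torch.Size([4]) + d.batch_shape))
print(left.batch_shape)  # torch.Size([4, 3, 1, 2])

# Broadcasting the internal singleton, as torch.Tensor.expand() does.
internal = Normal(d.loc.expand(3, 4, 2), d.scale.expand(3, 4, 2))
print(internal.batch_shape)  # torch.Size([3, 4, 2])
```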
How?

First, .expand() will operate only on batch_shape. Note that .expand_by() is agnostic about which shape is being expanded (the batch_shape of .log_prob() results versus the batch_shape + event_shape of .sample() results). In contrast, .expand() will need to decide which of the two shapes is being expanded. Here are three reasons why batch_shape is more sensible:

1. It is unclear how to expand event_shape in general (e.g. how do you expand a normalized set of probabilities?).
2. .expand() can only change dimensions of size 1 into dimensions of size > 1, yet none of our event shapes have meaningful dimensions of size 1 (e.g. a OneHotCategorical of size 1 is degenerate, and a MultivariateNormal of size 1 is just a Normal).
3. The current use cases only need to expand batch_shape.
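The shape asymmetry behind this choice can be seen with plain torch.distributions. This is a sketch using MultivariateNormal. Dims added on the left appear in both .sample() and .log_prob() results, while event_shape appears only in samples. An internal expansion therefore has to say which shape it indexes into, and batch_shape is the part the two results share.

```python
import torch
from torch.distributions import MultivariateNormal

# batch_shape (3,), event_shape (2,)
d = MultivariateNormal(torch.zeros(3, 2), torch.eye(2))
x = d.sample(torch.Size([5]))
lp = d.log_prob(x)
print(x.shape)   # torch.Size([5, 3, 2]) = sample_shape + batch_shape + event_shape
print(lp.shape)  # torch.Size([5, 3])    = sample_shape + batch_shape
```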
Mechanically, .expand() can attempt to apply .expand_by() under the hood. As a last resort, it can combine .expand_by() with a PermutedDistribution that continues to use .sample(sample_shape) to add batching on the left, but then permutes that batching into internal dims. Since PermutedDistribution may be expensive and lead to non-contiguous tensors, we should implement overrides in wrappers for the most commonly used distributions, e.g.
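A rough sketch of that last-resort mechanism, applied to samples only. `sample_expanded_internal` is a hypothetical helper, not a Pyro API. `Tensor.movedim` assumes a reasonably recent PyTorch. The sketch also shows the non-contiguity concern: the permuted result is a strided view, not a fresh contiguous tensor.

```python
import torch
from torch.distributions import Normal


def sample_expanded_internal(dist, dim, size):
    """Hypothetical helper: sample `dist` as if batch dim `dim` (a negative
    index into batch_shape whose size is 1) had been expanded to `size`."""
    assert dim < 0 and dist.batch_shape[dim] == 1
    event_dims = len(dist.event_shape)
    # 1. Add batching on the left, as .sample(sample_shape) does.
    x = dist.sample(torch.Size([size]))  # (size,) + batch_shape + event_shape
    # 2. Drop the singleton batch dim; from the right it sits past the event dims.
    x = x.squeeze(dim - event_dims)
    # 3. Permute the new left dim into the vacated internal batch position.
    return x.movedim(0, len(dist.batch_shape) + dim)


d = Normal(torch.zeros(3, 1, 2), torch.ones(3, 1, 2))
x = sample_expanded_internal(d, dim=-2, size=4)
print(x.shape)            # torch.Size([3, 4, 2])
print(x.is_contiguous())  # False: the permuted view is non-contiguous
```

Overriding .expand() per distribution avoids both the extra permutation and the non-contiguous layout.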
```python
# in pyro/distributions/torch.py
import torch
from pyro.distributions.torch_distribution import TorchDistributionMixin


class Normal(torch.distributions.Normal, TorchDistributionMixin):
    def expand(self, batch_shape):
        # Broadcast parameters to the new batch_shape (singleton dims may be internal).
        loc = self.loc.expand(batch_shape)
        scale = self.scale.expand(batch_shape)
        return Normal(loc, scale)


class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin):
    def expand(self, batch_shape):
        batch_shape = torch.Size(batch_shape)
        loc = self.loc.expand(batch_shape + self.event_shape)
        # we'll want to check for covariance_matrix and precision_matrix, but roughly:
        scale_tril = self.scale_tril.expand(batch_shape + 2 * self.event_shape)
        return MultivariateNormal(loc, scale_tril=scale_tril)
```
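To see the override pattern broadcast an internal dim, here is a pure-PyTorch sketch. `ExpandableMVN` is a hypothetical stand-in for the Pyro wrapper with no TorchDistributionMixin. The `_instance` argument only mirrors the signature of `Distribution.expand` in recent PyTorch releases.

```python
import torch
from torch.distributions import MultivariateNormal


class ExpandableMVN(MultivariateNormal):
    """Hypothetical stand-in for the wrapper's expand() override."""

    def expand(self, batch_shape, _instance=None):
        batch_shape = torch.Size(batch_shape)
        loc = self.loc.expand(batch_shape + self.event_shape)
        # scale_tril has two trailing (event, event) matrix dims.
        scale_tril = self.scale_tril.expand(batch_shape + self.event_shape + self.event_shape)
        return ExpandableMVN(loc, scale_tril=scale_tril)


# batch_shape (3, 1): the rightmost batch dim is a singleton, followed by event dim 2.
d = ExpandableMVN(torch.zeros(3, 1, 2), scale_tril=torch.eye(2).expand(3, 1, 2, 2))
e = d.expand([3, 4])
print(e.batch_shape, e.event_shape)  # torch.Size([3, 4]) torch.Size([2])
print(e.sample().shape)              # torch.Size([3, 4, 2])
```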