Support more batch distributions in HaarReparam #2731

Merged · 7 commits · Jan 6, 2021
47 changes: 43 additions & 4 deletions pyro/contrib/forecast/util.py
@@ -1,13 +1,14 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

+import numbers
 from functools import singledispatch

 import torch
 from torch.distributions import transform_to, transforms

 import pyro.distributions as dist
-from pyro.infer.reparam import HaarReparam, DiscreteCosineReparam
+from pyro.infer.reparam import DiscreteCosineReparam, HaarReparam
 from pyro.poutine.messenger import Messenger
 from pyro.poutine.util import site_is_subsample
 from pyro.primitives import get_param_store
@@ -190,6 +191,7 @@ def _pyro_sample(self, msg):
     dist.Poisson: ("rate",),
     dist.Stable: ("stability", "skew", "scale", "loc"),
     dist.StudentT: ("df", "loc", "scale"),
+    dist.Uniform: ("low", "high"),
     dist.ZeroInflatedPoisson: ("rate", "gate"),
     dist.ZeroInflatedNegativeBinomial: ("total_count", "logits", "gate"),
 }
@@ -241,6 +243,14 @@ def _(d, data):
     return dist.FoldedDistribution(base_dist)


+@prefix_condition.register(dist.TransformedDistribution)
+def _(d, data):
+    for t in reversed(d.transforms):
+        data = t.inv(data)
+    base_dist = prefix_condition(d.base_dist, data)
+    return dist.TransformedDistribution(base_dist, d.transforms)
+
+
 def _prefix_condition_univariate(d, data):
     t = data.size(-2)
     params = {name: getattr(d, name)[..., t:, :]
@@ -281,7 +291,9 @@ def reshape_batch(d, batch_shape):

 @reshape_batch.register(dist.MaskedDistribution)
 def _(d, batch_shape):
-    mask = d._mask.reshape(batch_shape)
+    mask = d._mask
+    if not isinstance(d._mask, numbers.Number):
+        mask = mask.reshape(batch_shape)
     base_dist = reshape_batch(d.base_dist, batch_shape)
     return base_dist.mask(mask)

@@ -306,6 +318,16 @@ def _(d, batch_shape):
     return dist.FoldedDistribution(base_dist)


+@reshape_batch.register(dist.TransformedDistribution)
+def _(d, batch_shape):
+    base_dist = reshape_batch(d.base_dist, batch_shape)
+    old_shape = d.base_dist.shape()
+    new_shape = base_dist.shape()
+    transforms = [reshape_transform_batch(t, old_shape, new_shape)
+                  for t in d.transforms]
+    return dist.TransformedDistribution(base_dist, transforms)
+
+
 def _reshape_batch_univariate(d, batch_shape):
     batch_shape = batch_shape + (-1,) * d.event_dim
     params = {name: getattr(d, name).reshape(batch_shape)
@@ -412,12 +434,29 @@ def reshape_transform_batch(t, old_shape, new_shape):


 def _reshape_batch_univariate_transform(t, old_shape, new_shape):
-    params = {name: getattr(t, name).expand(old_shape).reshape(new_shape)
-              for name in UNIVARIATE_TRANSFORMS[type(t)]}
+    params = {}
+    for name in UNIVARIATE_TRANSFORMS[type(t)]:
+        value = getattr(t, name)
+        if not isinstance(value, numbers.Number):
+            value = value.expand(old_shape).reshape(new_shape)
+        params[name] = value
     params["cache_size"] = t._cache_size
     return type(t)(**params)


+@reshape_transform_batch.register(torch.distributions.transforms._InverseTransform)
+def _(t, old_shape, new_shape):
+    return reshape_transform_batch(t.inv, old_shape, new_shape).inv
+
+
+@reshape_transform_batch.register(dist.transforms.ComposeTransform)
+def _(t, old_shape, new_shape):
+    return dist.transforms.ComposeTransform([
+        reshape_transform_batch(part, old_shape, new_shape)
+        for part in t.parts
+    ])
+
+
 for _type in UNIVARIATE_TRANSFORMS:
     reshape_transform_batch.register(_type)(_reshape_batch_univariate_transform)

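For reference, a minimal sketch of what the new TransformedDistribution support enables, with illustrative shapes following this module's (..., duration, obs_dim) convention:

import torch
import pyro.distributions as dist
from pyro.contrib.forecast.util import prefix_condition

# A log-normal noise distribution over duration=10, obs_dim=1.
base = dist.Normal(torch.zeros(10, 1), torch.ones(10, 1))
d = dist.TransformedDistribution(base, [dist.transforms.ExpTransform()])

# Condition on the first 3 time steps: the transforms are inverted on the
# data, the base is prefix-conditioned, and the transforms are reapplied.
data = d.sample()[..., :3, :]
d_rest = prefix_condition(d, data)
assert d_rest.batch_shape == (7, 1)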
3 changes: 3 additions & 0 deletions pyro/infer/reparam/discrete_cosine.py
@@ -30,6 +30,9 @@ class DiscreteCosineReparam(UnitJacobianReparam):
         noise to white noise; when 1 this transforms Brownian noise to white
         noise; when -1 this transforms violet noise to white noise; etc. Any
         real number is allowed. https://en.wikipedia.org/wiki/Colors_of_noise.
+    :param bool experimental_allow_batch: EXPERIMENTAL allow coupling across a
+        batch dimension. The targeted batch dimension and all batch dimensions
+        to the right will be converted to event dimensions. Defaults to False.
     """
     def __init__(self, dim=-1, smooth=0., *,
                  experimental_allow_batch=False):
3 changes: 3 additions & 0 deletions pyro/infer/reparam/haar.py
@@ -24,6 +24,9 @@ class HaarReparam(UnitJacobianReparam):
         This is an absolute dim counting from the right.
     :param bool flip: Whether to flip the time axis before applying the
         Haar transform. Defaults to False.
+    :param bool experimental_allow_batch: EXPERIMENTAL allow coupling across a
+        batch dimension. The targeted batch dimension and all batch dimensions
+        to the right will be converted to event dimensions. Defaults to False.
     """
     def __init__(self, dim=-1, flip=False, *,
                  experimental_allow_batch=False):
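A minimal sketch of the new flag in action (hypothetical model and site names):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import HaarReparam

def model():
    with pyro.plate("time", 8):
        pyro.sample("z", dist.Normal(0., 1.))

# dim=-1 of "z" is a plate (batch) dimension here, so without the flag the
# reparam raises the new ValueError; with it, that batch dimension is
# converted to an event dimension and Haar-transformed.
config = {"z": HaarReparam(dim=-1, experimental_allow_batch=True)}
tr = poutine.trace(poutine.reparam(model, config=config)).get_trace()
print(tr.nodes["z_haar"]["value"].shape)  # e.g. torch.Size([8])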
22 changes: 10 additions & 12 deletions pyro/infer/reparam/unit_jacobian.py
@@ -13,12 +13,6 @@
 from .reparam import Reparam


-# TODO Replace with .with_cache() once the following is released:
-# https://github.com/probtorch/pytorch/pull/153
-def _with_cache(t):
-    return t.with_cache() if hasattr(t, "with_cache") else t
-
-
 class UnitJacobianReparam(Reparam):
     """
     Reparameterizer for :class:`~torch.distributions.transforms.Transform`
@@ -27,10 +21,13 @@ class UnitJacobianReparam(Reparam):
     :param transform: A transform whose Jacobian has determinant 1.
     :type transform: ~torch.distributions.transforms.Transform
     :param str suffix: A suffix to append to the transformed site.
+    :param bool experimental_allow_batch: EXPERIMENTAL allow coupling across a
+        batch dimension. The targeted batch dimension and all batch dimensions
+        to the right will be converted to event dimensions. Defaults to False.
     """
     def __init__(self, transform, suffix="transformed", *,
                  experimental_allow_batch=False):
-        self.transform = _with_cache(transform)
+        self.transform = transform.with_cache()
         self.suffix = suffix
         self.experimental_allow_batch = experimental_allow_batch

@@ -41,9 +38,10 @@ def __call__(self, name, fn, obs):
         with ExitStack() as stack:
             shift = max(0, transform.event_dim - event_dim)
             if shift:
-                assert self.experimental_allow_batch, (
-                    "Cannot transform along batch dimension; "
-                    "try converting a batch dimension to an event dimension")
+                if not self.experimental_allow_batch:
+                    raise ValueError("Cannot transform along batch dimension; try either "
+                                     "converting a batch dimension to an event dimension, or "
+                                     "setting experimental_allow_batch=True.")

                 # Reshape and mute plates using block_plate.
                 from pyro.contrib.forecast.util import reshape_batch, reshape_transform_batch
@@ -54,10 +52,10 @@ def __call__(self, name, fn, obs):
                                                     old_shape + fn.event_shape,
                                                     new_shape + fn.event_shape)
                 for dim in range(-shift, 0):
-                    stack.enter_context(block_plate(dim=dim))
+                    stack.enter_context(block_plate(dim=dim, strict=False))

             # Draw noise from the base distribution.
-            transform = ComposeTransform([_with_cache(biject_to(fn.support).inv),
+            transform = ComposeTransform([biject_to(fn.support).inv.with_cache(),
                                           self.transform])
             x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                                   dist.TransformedDistribution(fn, transform))
18 changes: 16 additions & 2 deletions pyro/ops/tensor_utils.py
@@ -10,6 +10,20 @@
 _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)


+def as_complex(x):
+    """
+    Similar to :func:`torch.view_as_complex` but copies data in case strides
+    are not multiples of two.
+    """
+    if any(stride % 2 for stride in x.stride()[:-1]):
+        # First try to normalize strides.
+        x = x.squeeze().reshape(x.shape)
+    if any(stride % 2 for stride in x.stride()[:-1]):
+        # Fall back to copying data.
+        x = x.clone()
+    return torch.view_as_complex(x)
+
+
 def block_diag_embed(mat):
     """
     Takes a tensor of shape (..., B, M, N) and returns a block diagonal tensor
@@ -278,7 +292,7 @@ def dct(x, dim=-1):
     coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
     M = Y.size(-1)
     coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1)
-    X = torch.view_as_complex(coef) * Y
+    X = as_complex(coef) * Y
     # NB: if we use the full-length version Y_full = fft(y, n=N), then
     # the real part of the later half of X will be the flip
     # of the negative of the imaginary part of the first half
@@ -320,7 +334,7 @@ def idct(x, dim=-1):
     X = torch.stack([x[..., :M], xi], dim=-1)
     coef_real = torch.cos(torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype, device=x.device))
     coef = torch.stack([coef_real[:M], coef_real[-M:].flip(-1)], dim=-1)
-    Y = torch.view_as_complex(coef) * torch.view_as_complex(X)
+    Y = as_complex(coef) * as_complex(X)
     # Step 2
     y = irfft(Y, n=N)
     # Step 3
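A quick sketch of the case as_complex handles (illustrative tensor):

import torch
from pyro.ops.tensor_utils import as_complex

x = torch.randn(3, 3)[:, :2]  # shape (3, 2), but batch stride 3 is odd
# torch.view_as_complex(x) rejects odd strides outside the last dim;
# as_complex first tries to normalize strides and otherwise copies.
z = as_complex(x)
assert z.shape == (3,) and z.is_complex()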
12 changes: 8 additions & 4 deletions pyro/poutine/plate_messenger.py
@@ -25,7 +25,7 @@ def __enter__(self):


 @contextmanager
-def block_plate(name=None, dim=None):
+def block_plate(name=None, dim=None, *, strict=True):
     """
     EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

@@ -50,7 +50,9 @@ def model_2(data):

     :param str name: Optional name of plate to match.
     :param int dim: Optional dim of plate to match. Must be negative.
-    :raises: ValueError if no enclosing plate was found.
+    :param bool strict: Whether to error if no matching plate is found.
+        Defaults to True.
+    :raises: ValueError if no enclosing plate was found and ``strict=True``.
     """
     if (name is not None) == (dim is not None):
         raise ValueError("Exactly one of name,dim must be specified")
@@ -69,6 +71,8 @@ def predicate(messenger):
             return messenger.dim == dim

     with block_messengers(predicate) as matches:
-        if len(matches) != 1:
-            raise ValueError("block_plate matched {} messengers".format(len(matches)))
+        if strict and len(matches) != 1:
+            raise ValueError(f"block_plate matched {len(matches)} messengers. "
+                             "Try either removing the block_plate or "
+                             "setting strict=False.")
         yield
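A minimal sketch of the non-strict behavior (hypothetical model):

import pyro
import pyro.distributions as dist
from pyro.poutine.plate_messenger import block_plate

def model():
    with pyro.plate("time", 5):
        # No enclosing plate sits at dim=-2, so the default strict=True
        # would raise a ValueError; strict=False makes this a no-op.
        with block_plate(dim=-2, strict=False):
            return pyro.sample("x", dist.Normal(0., 1.))

model()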
18 changes: 17 additions & 1 deletion tests/contrib/forecast/test_util.py
@@ -32,6 +32,8 @@
     dist.NegativeBinomial,
     dist.Normal,
     dist.StudentT,
+    dist.TransformedDistribution,
+    dist.Uniform,
     dist.ZeroInflatedPoisson,
     dist.ZeroInflatedNegativeBinomial,
 ]
@@ -40,10 +42,20 @@
 def random_dist(Dist, shape, transform=None):
     if Dist is dist.FoldedDistribution:
         return Dist(random_dist(dist.Normal, shape))
-    if Dist is dist.MaskedDistribution:
+    elif Dist is dist.MaskedDistribution:
         base_dist = random_dist(dist.Normal, shape)
         mask = torch.empty(shape, dtype=torch.bool).bernoulli_(0.5)
         return base_dist.mask(mask)
+    elif Dist is dist.TransformedDistribution:
+        base_dist = random_dist(dist.Normal, shape)
+        transforms = [
+            dist.transforms.ExpTransform(),
+            dist.transforms.ComposeTransform([
+                dist.transforms.AffineTransform(1, 1),
+                dist.transforms.ExpTransform().inv,
+            ]),
+        ]
+        return dist.TransformedDistribution(base_dist, transforms)
     elif Dist in (dist.GaussianHMM, dist.LinearHMM):
         batch_shape, duration, obs_dim = shape[:-2], shape[-2], shape[-1]
         hidden_dim = obs_dim + 1
@@ -63,6 +75,10 @@ def random_dist(Dist, shape, transform=None):
         return Dist(base_dist)
     elif Dist is dist.MultivariateNormal:
         return random_mvn(shape[:-1], shape[-1])
+    elif Dist is dist.Uniform:
+        low = torch.randn(shape)
+        high = low + torch.randn(shape).exp()
+        return Dist(low, high)
     else:
         params = {
             name: transform_to(Dist.arg_constraints[name])(torch.rand(shape) - 0.5)
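A sketch of the round trip these additions exercise, using this file's random_dist helper (shapes are illustrative):

import pyro.distributions as dist
from pyro.contrib.forecast.util import reshape_batch

d = random_dist(dist.TransformedDistribution, (5, 10, 1))
d2 = reshape_batch(d, (5, 1, 10, 1))
assert d2.batch_shape == (5, 1, 10, 1)
# The scalar AffineTransform(1, 1) parameters pass through untouched,
# via the new numbers.Number branch in _reshape_batch_univariate_transform.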