Skip to content

Commit

Permalink
Make TransformReparam compatible with .to_event() (#886)
Browse files Browse the repository at this point in the history
* Support .to_event() in TransformReparam

* Strengthen test

* Address review comment
  • Loading branch information
fritzo authored Jan 24, 2021
1 parent 90b74eb commit 7892f2b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 32 deletions.
45 changes: 23 additions & 22 deletions numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax import lax

import numpyro
import numpyro.distributions as dist
Expand All @@ -28,31 +29,31 @@ def __call__(self, name, fn, obs):

def _unwrap(self, fn):
"""
Unwrap Independent(...) distributions.
Unwrap Independent(...) and ExpandedDistribution(...) distributions.
"""
batch_shape = fn.batch_shape
event_dim = fn.event_dim
while isinstance(fn, dist.Independent):
while isinstance(fn, (dist.Independent, dist.ExpandedDistribution)):
fn = fn.base_dist
return fn, event_dim
return fn, batch_shape, event_dim

def _wrap(self, fn, event_dim):
def _wrap(self, fn, batch_shape, event_dim):
"""
Wrap in Independent distributions.
Wrap in Independent and ExpandedDistribution distributions.
"""
# Match batch_shape.
assert fn.event_dim <= event_dim
fn_batch_shape = batch_shape + (1,) * (event_dim - fn.event_dim)
fn_batch_shape = lax.broadcast_shapes(fn_batch_shape, fn.batch_shape)
if fn.batch_shape != fn_batch_shape:
fn = fn.expand(fn_batch_shape)

# Match event_dim.
if fn.event_dim < event_dim:
fn = fn.to_event(event_dim - fn.event_dim)
assert fn.event_dim == event_dim
return fn

def _unexpand(self, fn):
"""
Unexpand ExpandedDistribution(...) distributions.
"""
batch_shape = fn.batch_shape
if isinstance(fn, dist.ExpandedDistribution):
fn = fn.base_dist
return fn, batch_shape


class LocScaleReparam(Reparam):
"""
Expand Down Expand Up @@ -89,8 +90,7 @@ def __call__(self, name, fn, obs):
if is_identically_one(centered):
return name, fn, obs
event_shape = fn.event_shape
fn, event_dim = self._unwrap(fn)
fn, batch_shape = self._unexpand(fn)
fn, batch_shape, event_dim = self._unwrap(fn)

# Apply a partial decentering transform.
params = {key: getattr(fn, key) for key in self.shape_params}
Expand All @@ -100,11 +100,11 @@ def __call__(self, name, fn, obs):
constraint=constraints.unit_interval)
params["loc"] = fn.loc * centered
params["scale"] = fn.scale ** centered
decentered_fn = type(fn)(**params).expand(batch_shape)
decentered_fn = self._wrap(type(fn)(**params), batch_shape, event_dim)

# Draw decentered noise.
decentered_value = numpyro.sample("{}_decentered".format(name),
self._wrap(decentered_fn, event_dim))
decentered_fn)

# Differentiably transform.
delta = decentered_value - centered * fn.loc
Expand All @@ -127,14 +127,15 @@ class TransformReparam(Reparam):
"""
def __call__(self, name, fn, obs):
assert obs is None, "TransformReparam does not support observe statements"
fn, batch_shape = self._unexpand(fn)
fn, batch_shape, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.TransformedDistribution)

# Draw noise from the base distribution.
# We need to make sure that we have the same batch_shape
reinterpreted_batch_ndims = fn.event_dim - fn.base_dist.event_dim
base_event_dim = event_dim
for t in reversed(fn.transforms):
base_event_dim += t.domain.event_dim - t.codomain.event_dim
x = numpyro.sample("{}_base".format(name),
fn.base_dist.to_event(reinterpreted_batch_ndims).expand(batch_shape))
self._wrap(fn.base_dist, batch_shape, base_event_dim))

# Differentiably transform.
for t in fn.transforms:
Expand Down
22 changes: 12 additions & 10 deletions test/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,22 @@ def get_moments(x):
return jnp.stack([m1, m2, m3, m4])


@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
def test_log_normal(shape):
@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize("event_shape", [(), (5,)], ids=str)
def test_log_normal(batch_shape, event_shape):
shape = batch_shape + event_shape
loc = np.random.rand(*shape) * 2 - 1
scale = np.random.rand(*shape) + 0.5

def model():
with numpyro.plate_stack("plates", shape):
fn = dist.TransformedDistribution(
dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale)),
[AffineTransform(loc, scale), ExpTransform()])
if event_shape:
fn = fn.to_event(len(event_shape)).expand_by([100000])
with numpyro.plate_stack("plates", batch_shape):
with numpyro.plate("particles", 100000):
return numpyro.sample("x",
dist.TransformedDistribution(
dist.Normal(jnp.zeros_like(loc),
jnp.ones_like(scale)),
[AffineTransform(loc, scale),
ExpTransform()]).expand_by([100000]))
return numpyro.sample("x", fn)

with handlers.trace() as tr:
value = handlers.seed(model, 0)()
Expand All @@ -56,7 +58,7 @@ def model():
value = handlers.seed(model, 0)()
assert tr["x"]["type"] == "deterministic"
actual_moments = get_moments(jnp.log(value))
assert_allclose(actual_moments, expected_moments, atol=0.05)
assert_allclose(actual_moments, expected_moments, atol=0.05, rtol=0.01)


def neals_funnel(dim):
Expand Down

0 comments on commit 7892f2b

Please sign in to comment.