Skip to content

Commit

Permalink
Add tests, update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Oct 3, 2021
1 parent 670fdda commit 5db1109
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
18 changes: 10 additions & 8 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from typing import Callable, Dict, Set, Tuple, Union
from typing import Callable, Dict, Optional, Set, Tuple, Union

import torch
from torch.distributions import biject_to
Expand Down Expand Up @@ -54,12 +54,14 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
the model [1]. Depending on model structure, this can have asymptotically
better statistical efficiency than :class:`AutoMultivariateNormal` .
The default "dense" backend should have similar computational complexity to
:class:`AutoMultivariateNormal` . The experimental "funsor" backend can be
asymptotically cheaper in terms of time and space (using Gaussian tensor
variable elimination [2,3]), but incurs large constant overhead. The
"funsor" backend requires `funsor <https://funsor.pyro.ai>`_ which can be
installed via ``pip install pyro-ppl[funsor]``.
This guide implements multiple backends for computation. All backends use
the same statistically optimal parametrization. The default "dense" backend
has computational complexity similar to :class:`AutoMultivariateNormal` .
The experimental "funsor" backend can be asymptotically cheaper in terms of
time and space (using Gaussian tensor variable elimination [2,3]), but
incurs large constant overhead. The "funsor" backend requires `funsor
<https://funsor.pyro.ai>`_ which can be installed via ``pip install
pyro-ppl[funsor]``.
The guide currently does not depend on the model's ``*args, **kwargs``.
Expand Down Expand Up @@ -105,7 +107,7 @@ def __init__(
*,
init_loc_fn: Callable = init_to_feasible,
init_scale: float = 0.1,
backend=None,
backend: Optional[str] = None, # used only by metaclass
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError(f"Expected init_scale > 0. but got {init_scale}")
Expand Down
38 changes: 36 additions & 2 deletions tests/infer/autoguide/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyro.infer.autoguide.gaussian import _break_plates
from pyro.infer.reparam import LocScaleReparam
from pyro.optim import Adam
from tests.common import assert_equal
from tests.common import assert_equal, xfail_if_not_implemented

BACKENDS = [
"dense",
Expand Down Expand Up @@ -180,6 +180,40 @@ def model():
check_structure(model, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_broken_plates_smoke(backend):
def model():
with pyro.plate("i", 2):
x = pyro.sample("x", dist.Normal(0, 1))
pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step()
guide()
predictive = Predictive(model, guide=guide, num_samples=2)
predictive()


@pytest.mark.parametrize("backend", BACKENDS)
def test_intractable_smoke(backend):
def model():
with pyro.plate("i", 2):
x = pyro.sample("x", dist.Normal(0, 1))
pyro.sample("y", dist.Normal(x.mean(-1), 1), obs=torch.tensor(0.0))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step()
guide()
predictive = Predictive(model, guide=guide, num_samples=2)
predictive()


# Simplified from https://github.com/pyro-cov/tree/master/pyrocov/mutrans.py
def pyrocov_model(dataset):
# Tensor shapes are commented at the end of some lines.
Expand Down Expand Up @@ -486,7 +520,7 @@ def test_profile(backend, n=1, num_steps=1):
"""
Helper function for profiling.
"""
model = pyrocov_model_plated
model = pyrocov_model_poisson
T, P, S, F = 2 * n, 3 * n, 4 * n, 5 * n
dataset = {
"features": torch.randn(S, F),
Expand Down

0 comments on commit 5db1109

Please sign in to comment.