diff --git a/docs/source/inference_algos.rst b/docs/source/inference_algos.rst
index 03493f6ba2..f1f58fbc27 100644
--- a/docs/source/inference_algos.rst
+++ b/docs/source/inference_algos.rst
@@ -13,21 +13,25 @@ ELBO
     :members:
     :undoc-members:
     :show-inheritance:
+    :member-order: bysource
 
 .. automodule:: pyro.infer.trace_elbo
     :members:
     :undoc-members:
     :show-inheritance:
+    :member-order: bysource
 
 .. automodule:: pyro.infer.tracegraph_elbo
     :members:
     :undoc-members:
     :show-inheritance:
+    :member-order: bysource
 
 .. automodule:: pyro.infer.traceenum_elbo
     :members:
    :undoc-members:
     :show-inheritance:
+    :member-order: bysource
 
 Importance
 ----------
diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index fac99bc882..544c89e3ce 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -5,9 +5,9 @@
 from pyro.infer.enum import config_enumerate
 from pyro.infer.importance import Importance
 from pyro.infer.svi import SVI
-from pyro.infer.trace_elbo import Trace_ELBO
-from pyro.infer.traceenum_elbo import TraceEnum_ELBO
-from pyro.infer.tracegraph_elbo import TraceGraph_ELBO
+from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
+from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
+from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
 from pyro.infer.util import enable_validation, is_validation_enabled
 
 __all__ = [
@@ -15,12 +15,15 @@
     "enable_validation",
     "is_validation_enabled",
     "ELBO",
-    "Importance",
     "EmpiricalMarginal",
-    "TracePredictive",
+    "Importance",
+    "JitTraceEnum_ELBO",
+    "JitTraceGraph_ELBO",
+    "JitTrace_ELBO",
     "SVI",
-    "TracePosterior",
-    "Trace_ELBO",
     "TraceEnum_ELBO",
     "TraceGraph_ELBO",
+    "TracePosterior",
+    "TracePredictive",
+    "Trace_ELBO",
 ]
diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py
index 0f6fd4e1fd..486c3ff5fe 100644
--- a/pyro/infer/trace_elbo.py
+++ b/pyro/infer/trace_elbo.py
@@ -1,7 +1,11 @@
 from __future__ import absolute_import, division, print_function
 
 import warnings
+import weakref
 
+import torch
+
+import pyro
 import pyro.poutine as poutine
 from pyro.distributions.util import is_identically_zero
 from pyro.infer.elbo import ELBO
@@ -142,3 +146,81 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
         if torch_isnan(loss):
             warnings.warn('Encountered NAN loss')
         return loss
+
+
+class JitTrace_ELBO(Trace_ELBO):
+    """
+    Like :class:`Trace_ELBO` but uses :func:`torch.jit.compile` to compile
+    :meth:`loss_and_grads`.
+
+    This works only for a limited set of models:
+
+    - Models must have static structure.
+    - Models must not depend on any global data (except the param store).
+    - All model inputs that are tensors must be passed in via ``*args``.
+    - All model inputs that are *not* tensors must be passed in via
+      ``**kwargs``, and these will be fixed to their values on the first
+      call to :meth:`loss_and_grads`.
+
+    .. warning:: Experimental. Interface subject to change.
+    """
+    def loss_and_grads(self, model, guide, *args, **kwargs):
+        if getattr(self, '_loss_and_surrogate_loss', None) is None:
+            # populate param store
+            with poutine.block():
+                with poutine.trace(param_only=True) as param_capture:
+                    for _ in self._get_traces(model, guide, *args, **kwargs):
+                        pass
+            self._param_names = list(param_capture.trace.nodes.keys())
+
+            # build a closure for loss_and_surrogate_loss
+            weakself = weakref.ref(self)
+
+            @torch.jit.compile(nderivs=1)
+            def loss_and_surrogate_loss(args_list, param_list):
+                self = weakself()
+                loss = 0.0
+                surrogate_loss = 0.0
+                for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
+                    elbo_particle = 0
+                    surrogate_elbo_particle = 0
+                    log_r = None
+
+                    # compute elbo and surrogate elbo
+                    for name, site in model_trace.nodes.items():
+                        if site["type"] == "sample":
+                            elbo_particle = elbo_particle + site["log_prob_sum"]
+                            surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]
+
+                    for name, site in guide_trace.nodes.items():
+                        if site["type"] == "sample":
+                            log_prob, score_function_term, entropy_term = site["score_parts"]
+
+                            elbo_particle = elbo_particle - site["log_prob_sum"]
+
+                            if not is_identically_zero(entropy_term):
+                                surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()
+
+                            if not is_identically_zero(score_function_term):
+                                if log_r is None:
+                                    log_r = _compute_log_r(model_trace, guide_trace)
+                                site = log_r.sum_to(site["cond_indep_stack"])
+                                surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()
+
+                    loss = loss - elbo_particle / self.num_particles
+                    surrogate_loss = surrogate_loss - surrogate_elbo_particle / self.num_particles
+
+                return loss, surrogate_loss
+
+            self._loss_and_surrogate_loss = loss_and_surrogate_loss
+
+        # invoke _loss_and_surrogate_loss
+        args_list = list(args)
+        param_list = [pyro.param(name).unconstrained() for name in self._param_names]
+        loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list)
+        surrogate_loss.backward()  # this line triggers jit compilation
+        loss = loss.item()
+
+        if torch_isnan(loss):
+            warnings.warn('Encountered NAN loss')
+        return loss
diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index 117f180862..fe69002bdd 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -1,7 +1,11 @@
 from __future__ import absolute_import, division, print_function
 
 import warnings
+import weakref
 
+import torch
+
+import pyro
 import pyro.poutine as poutine
 from pyro.distributions.util import is_identically_zero
 from pyro.infer.elbo import ELBO
@@ -147,3 +151,53 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
         if torch_isnan(loss):
             warnings.warn('Encountered NAN loss')
         return loss
+
+
+class JitTraceEnum_ELBO(TraceEnum_ELBO):
+    """
+    Like :class:`TraceEnum_ELBO` but uses :func:`torch.jit.compile` to
+    compile :meth:`loss_and_grads`.
+
+    This works only for a limited set of models:
+
+    - Models must have static structure.
+    - Models must not depend on any global data (except the param store).
+    - All model inputs that are tensors must be passed in via ``*args``.
+    - All model inputs that are *not* tensors must be passed in via
+      ``**kwargs``, and these will be fixed to their values on the first
+      call to :meth:`loss_and_grads`.
+
+    .. warning:: Experimental. Interface subject to change.
+    """
+    def loss_and_grads(self, model, guide, *args, **kwargs):
+        if getattr(self, '_differentiable_loss', None) is None:
+            # populate param store
+            with poutine.block():
+                with poutine.trace(param_only=True) as param_capture:
+                    for _ in self._get_traces(model, guide, *args, **kwargs):
+                        pass
+            self._param_names = list(param_capture.trace.nodes.keys())
+
+            # build a closure for differentiable_loss
+            weakself = weakref.ref(self)
+
+            @torch.jit.compile(nderivs=1)
+            def differentiable_loss(args_list, param_list):
+                self = weakself()
+                elbo = 0.0
+                for model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs):
+                    elbo += _compute_dice_elbo(model_trace, guide_trace)
+                return elbo * (-1.0 / self.num_particles)
+
+            self._differentiable_loss = differentiable_loss
+
+        # invoke _differentiable_loss
+        args_list = list(args)
+        param_list = [pyro.param(name).unconstrained() for name in self._param_names]
+        differentiable_loss = self._differentiable_loss(args_list, param_list)
+        differentiable_loss.backward()  # this line triggers jit compilation
+        loss = differentiable_loss.item()
+
+        if torch_isnan(loss):
+            warnings.warn('Encountered NAN loss')
+        return loss
diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py
index 2249d9d224..ad8cb68f18 100644
--- a/pyro/infer/tracegraph_elbo.py
+++ b/pyro/infer/tracegraph_elbo.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
 import warnings
+import weakref
 from operator import itemgetter
 
 import networkx
@@ -99,13 +100,13 @@ def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
     # deal with log p(z|...) terms
     for name, site in model_trace.nodes.items():
         if site["type"] == "sample":
-            elbo += torch_item(site["log_prob_sum"])
+            elbo += site["log_prob_sum"]
             surrogate_elbo += site["log_prob_sum"]
 
     # deal with log q(z|...) terms
     for name, site in guide_trace.nodes.items():
         if site["type"] == "sample":
-            elbo -= torch_item(site["log_prob_sum"])
+            elbo -= site["log_prob_sum"]
             entropy_term = site["score_parts"].entropy_term
             if not is_identically_zero(entropy_term):
                 surrogate_elbo -= entropy_term.sum()
@@ -277,7 +278,84 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
         surrogate_loss = -surrogate_elbo
         torch_backward(weight * (surrogate_loss + baseline_loss))
 
-        loss = -elbo
+        loss = -torch_item(elbo)
         if torch_isnan(loss):
             warnings.warn('Encountered NAN loss')
         return weight * loss
+
+
+class JitTraceGraph_ELBO(TraceGraph_ELBO):
+    """
+    Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.compile` to
+    compile :meth:`loss_and_grads`.
+
+    This works only for a limited set of models:
+
+    - Models must have static structure.
+    - Models must not depend on any global data (except the param store).
+    - All model inputs that are tensors must be passed in via ``*args``.
+    - All model inputs that are *not* tensors must be passed in via
+      ``**kwargs``, and these will be fixed to their values on the first
+      call to :meth:`loss_and_grads`.
+
+    .. warning:: Experimental. Interface subject to change.
+ """ + + def loss_and_grads(self, model, guide, *args, **kwargs): + if getattr(self, '_loss_and_surrogate_loss', None) is None: + # populate param store + with poutine.block(): + with poutine.trace(param_only=True) as param_capture: + for _ in self._get_traces(model, guide, *args, **kwargs): + pass + self._param_names = list(param_capture.trace.nodes.keys()) + + # build a closure for loss_and_surrogate_loss + weakself = weakref.ref(self) + + @torch.jit.compile(nderivs=1) + def loss_and_surrogate_loss(args_list, param_list): + self = weakself() + loss = 0.0 + surrogate_loss = 0.0 + for weight, model_trace, guide_trace in self._get_traces(model, guide, *args_list, **kwargs): + model_trace.compute_log_prob() + guide_trace.compute_score_parts() + if is_validation_enabled(): + for site in model_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + for site in guide_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + + # compute elbo for reparameterized nodes + non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes) + elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes) + + # the following computations are only necessary if we have non-reparameterizable nodes + baseline_loss = 0.0 + if non_reparam_nodes: + downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes) + surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace, + non_reparam_nodes, + downstream_costs) + surrogate_elbo += surrogate_elbo_term + + loss = loss - weight * elbo + surrogate_loss = surrogate_loss - weight * surrogate_elbo + + return loss, surrogate_loss + + self._loss_and_surrogate_loss = loss_and_surrogate_loss + + # invoke _loss_and_surrogate_loss + args_list = list(args) + param_list = [pyro.param(name).unconstrained() for name in self._param_names] + loss, surrogate_loss = self._loss_and_surrogate_loss(args_list, param_list) + surrogate_loss.backward() # this line triggers jit compilation + loss = loss.item() + + if torch_isnan(loss): + warnings.warn('Encountered NAN loss') + return loss diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index df68be325e..5ce1d23700 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -10,9 +10,10 @@ import pyro import pyro.distributions as dist from pyro.distributions.testing import fakes -from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO +from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, + TraceGraph_ELBO) from pyro.optim import Adam -from tests.common import assert_equal +from tests.common import assert_equal, xfail_param logger = logging.getLogger(__name__) @@ -112,7 +113,17 @@ def guide(): @pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) -@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +@pytest.mark.parametrize("Elbo", [ + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, + xfail_param(JitTrace_ELBO, + reason="jit RuntimeError: Unsupported op descriptor: index-2"), + xfail_param(JitTraceGraph_ELBO, + reason="jit RuntimeError: Unsupported op descriptor: index-2"), + xfail_param(JitTraceEnum_ELBO, + reason="jit RuntimeError: Unsupported op descriptor: index-2"), +]) def 
test_subsample_gradient_sequential(Elbo, reparameterized, subsample): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) @@ -134,11 +145,15 @@ def guide(): pyro.sample("z", Normal(loc[ind], scale)) optim = Adam({"lr": 0.1}) - elbo = Elbo(num_particles=num_particles, strict_enumeration_warning=False) - inference = SVI(model, guide, optim, loss=elbo) - inference.loss_and_grads(model, guide) + elbo = Elbo(num_particles=10, strict_enumeration_warning=False) + inference = SVI(model, guide, optim, elbo) + iters = num_particles // 10 + for _ in range(iters): + inference.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) - actual_grads = {name: param.grad.detach().cpu().numpy() for name, param in params.items()} + actual_grads = {name: param.grad.detach().cpu().numpy() / iters + for name, param in params.items()} expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} for name in sorted(params): diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 2355be40f2..fda0e7a4f0 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -12,8 +12,9 @@ import pyro.optim as optim from pyro.distributions.testing import fakes from pyro.distributions.testing.rejection_gamma import ShapeAugmentedGamma -from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO -from tests.common import assert_equal +from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, + TraceGraph_ELBO) +from tests.common import assert_equal, xfail_param def param_mse(name, target): @@ -205,7 +206,14 @@ def guide(): @pytest.mark.stage("integration", "integration_batch_1") -@pytest.mark.parametrize('elbo_impl', [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +@pytest.mark.parametrize('elbo_impl', [ + xfail_param(JitTrace_ELBO, reason="incorrect gradients", run=False), + xfail_param(JitTraceGraph_ELBO, reason="incorrect gradients", run=False), + xfail_param(JitTraceEnum_ELBO, reason="incorrect gradients", run=False), + Trace_ELBO, + TraceGraph_ELBO, + TraceEnum_ELBO, +]) @pytest.mark.parametrize('gamma_dist,n_steps', [ (dist.Gamma, 5000), (fakes.NonreparameterizedGamma, 10000), @@ -223,13 +231,13 @@ def test_exponential_gamma(gamma_dist, n_steps, elbo_impl): alpha_n = alpha0 + torch.tensor(float(n_data)) # posterior alpha beta_n = beta0 + torch.sum(data) # posterior beta - def model(): + def model(alpha0, beta0, alpha_n, beta_n): lambda_latent = pyro.sample("lambda_latent", gamma_dist(alpha0, beta0)) with pyro.iarange("data", n_data): pyro.sample("obs", dist.Exponential(lambda_latent), obs=data) return lambda_latent - def guide(): + def guide(alpha0, beta0, alpha_n, beta_n): alpha_q = pyro.param("alpha_q", alpha_n * math.exp(0.17), constraint=constraints.positive) beta_q = pyro.param("beta_q", beta_n / math.exp(0.143), constraint=constraints.positive) pyro.sample("lambda_latent", gamma_dist(alpha_q, beta_q)) @@ -239,7 +247,7 @@ def guide(): svi = SVI(model, guide, adam, loss=elbo, max_iarange_nesting=1) for k in range(n_steps): - svi.step() + svi.step(alpha0, beta0, alpha_n, beta_n) assert_equal(pyro.param("alpha_q"), alpha_n, prec=0.15, msg='{} vs {}'.format( pyro.param("alpha_q").detach().cpu().numpy(), alpha_n.detach().cpu().numpy())) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py new file mode 100644 index 0000000000..d4fba32aa9 --- /dev/null +++ b/tests/infer/test_jit.py @@ -0,0 +1,234 @@ +from __future__ import absolute_import, division, 
print_function + +import pytest +import torch +from torch.autograd import grad +from torch.distributions import constraints, kl_divergence + +import pyro +import pyro.distributions as dist +from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, + TraceGraph_ELBO) +from pyro.optim import Adam +from tests.common import assert_equal, xfail_param + + +def test_simple(): + y = torch.ones(2) + + @torch.jit.compile(nderivs=0) + def f(x): + print('Inside f') + assert x is y + return y + 1.0 + + print('Calling f(y)') + assert_equal(f(y), y.new_tensor([2, 2])) + print('Calling f(y)') + assert_equal(f(y), y.new_tensor([2, 2])) + print('Calling f(torch.zeros(2))') + assert_equal(f(torch.zeros(2)), y.new_tensor([1, 1])) + with pytest.raises(AssertionError): + assert_equal(f(torch.ones(5)), y.new_tensor([2, 2, 2, 2, 2])) + + +def test_backward(): + y = torch.ones(2, requires_grad=True) + + @torch.jit.compile(nderivs=1) + def f(x): + print('Inside f') + assert x is y + return (y + 1.0).sum() + + print('Calling f(y)') + f(y).backward() + print('Calling f(y)') + f(y) + print('Calling f(torch.zeros(2))') + f(torch.zeros(2, requires_grad=True)) + with pytest.raises(AssertionError): + f(torch.ones(5, requires_grad=True)) + + +def test_grad(): + + @torch.jit.compile(nderivs=0) + def f(x, y): + print('Inside f') + loss = (x - y).pow(2).sum() + return torch.autograd.grad(loss, [x, y], allow_unused=True) + + print('Invoking f') + f(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) + print('Invoking f') + f(torch.zeros(2, requires_grad=True), torch.zeros(2, requires_grad=True)) + + +@pytest.mark.xfail(reason='RuntimeError: ' + 'saved_variables() needed but not implemented in ExpandBackward') +def test_grad_expand(): + + @torch.jit.compile(nderivs=0) + def f(x, y): + print('Inside f') + loss = (x - y).pow(2).sum() + return torch.autograd.grad(loss, [x, y], allow_unused=True) + + print('Invoking f') + f(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) + print('Invoking f') + f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) + + +@pytest.mark.parametrize('num_particles', [1, 10]) +@pytest.mark.parametrize('Elbo', [ + Trace_ELBO, + JitTrace_ELBO, + TraceGraph_ELBO, + JitTraceGraph_ELBO, + TraceEnum_ELBO, + JitTraceEnum_ELBO, +]) +def test_svi(Elbo, num_particles): + pyro.clear_param_store() + data = torch.arange(10) + + def model(data): + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) + pyro.sample("x", dist.Normal(loc, scale).expand_by(data.shape).independent(1), obs=data) + + def guide(data): + pass + + elbo = Elbo(num_particles=num_particles, strict_enumeration_warning=False) + inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) + for i in range(100): + inference.step(data) + + +@pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) +@pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) +@pytest.mark.parametrize("irange_dim", [1, 2]) +@pytest.mark.parametrize('Elbo', [ + Trace_ELBO, + JitTrace_ELBO, + TraceGraph_ELBO, + JitTraceGraph_ELBO, + TraceEnum_ELBO, + JitTraceEnum_ELBO, +]) +def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2): + pyro.clear_param_store() + num_particles = 10 + q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval) + p = 0.2693204236205713 # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5 + + def model(): + pyro.sample("x", 
dist.Bernoulli(p))
+        for i in pyro.irange("irange", irange_dim):
+            pyro.sample("y_{}".format(i), dist.Bernoulli(p))
+
+    def guide():
+        q = pyro.param("q")
+        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
+        for i in pyro.irange("irange", irange_dim):
+            pyro.sample("y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2})
+
+    kl = (1 + irange_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
+    expected_loss = kl.item()
+    expected_grad = grad(kl, [q.unconstrained()])[0]
+
+    inner_particles = 2
+    outer_particles = num_particles // inner_particles
+    elbo = Elbo(max_iarange_nesting=0,
+                strict_enumeration_warning=any([enumerate1, enumerate2]),
+                num_particles=inner_particles)
+    actual_loss = sum(elbo.loss_and_grads(model, guide)
+                      for i in range(outer_particles)) / outer_particles
+    actual_grad = q.unconstrained().grad / outer_particles
+
+    assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([
+        "\nexpected loss = {}".format(expected_loss),
+        "\n  actual loss = {}".format(actual_loss),
+    ]))
+    assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([
+        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
+        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
+    ]))
+
+
+@pytest.mark.parametrize('vectorized', [False, True])
+@pytest.mark.parametrize('Elbo', [
+    TraceEnum_ELBO,
+    xfail_param(JitTraceEnum_ELBO,
+                reason="jit RuntimeError: Unsupported op descriptor: stack-2-dim_i"),
+])
+def test_beta_bernoulli(Elbo, vectorized):
+    pyro.clear_param_store()
+    data = torch.tensor([1.0] * 6 + [0.0] * 4)
+
+    def model1(data):
+        alpha0 = torch.tensor(10.0)
+        beta0 = torch.tensor(10.0)
+        f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
+        for i in pyro.irange("irange", len(data)):
+            pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
+
+    def model2(data):
+        alpha0 = torch.tensor(10.0)
+        beta0 = torch.tensor(10.0)
+        f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
+        pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1),
+                    obs=data)
+
+    model = model2 if vectorized else model1
+
+    def guide(data):
+        alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
+                             constraint=constraints.positive)
+        beta_q = pyro.param("beta_q", torch.tensor(15.0),
+                            constraint=constraints.positive)
+        pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
+
+    elbo = Elbo(num_particles=7, strict_enumeration_warning=False)
+    optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
+    svi = SVI(model, guide, optim, elbo)
+    for step in range(40):
+        svi.step(data)
+
+
+@pytest.mark.parametrize('vectorized', [False, True])
+@pytest.mark.parametrize('Elbo', [
+    TraceEnum_ELBO,
+    xfail_param(JitTraceEnum_ELBO, reason="jit RuntimeError in Dirichlet.rsample"),
+])
+def test_dirichlet_bernoulli(Elbo, vectorized):
+    pyro.clear_param_store()
+    data = torch.tensor([1.0] * 6 + [0.0] * 4)
+
+    def model1(data):
+        concentration0 = torch.tensor([10.0, 10.0])
+        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
+        for i in pyro.irange("irange", len(data)):
+            pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
+
+    def model2(data):
+        concentration0 = torch.tensor([10.0, 10.0])
+        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
+        pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1),
+                    obs=data)
+
+    model = model2 if vectorized else model1
+
+    def guide(data):
+        concentration_q = pyro.param("concentration_q", torch.tensor([15.0, 15.0]),
+                                     constraint=constraints.positive)
+        pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))
+
+    elbo = Elbo(num_particles=7, strict_enumeration_warning=False)
+    optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
+    svi = SVI(model, guide, optim, elbo)
+    for step in range(40):
+        svi.step(data)
diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py
index 29adc05e5b..c2f311eb4d 100644
--- a/tests/integration_tests/test_conjugate_gaussian_models.py
+++ b/tests/integration_tests/test_conjugate_gaussian_models.py
@@ -152,7 +152,10 @@ def array_to_string(y):
 
         pyro.clear_param_store()
         adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
-        svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())
+        elbo = TraceGraph_ELBO()
+        loss_and_grads = elbo.loss_and_grads
+        # loss_and_grads = elbo.jit_loss_and_grads  # This fails.
+        svi = SVI(self.model, self.guide, adam, loss=elbo.loss, loss_and_grads=loss_and_grads)
 
         for step in range(n_steps):
             t0 = time.time()
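
A minimal usage sketch of the new interface, distilled from test_svi in tests/infer/test_jit.py above. It assumes a PyTorch build that provides torch.jit.compile (which the Jit*ELBO classes call internally); the model, guide, data, and learning rate are illustrative only, not part of this patch.

    import torch
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, JitTrace_ELBO
    from pyro.optim import Adam

    pyro.clear_param_store()
    data = torch.ones(10)  # tensor inputs must be passed through *args

    def model(data):
        # static structure: the same param and sample sites appear on every call
        loc = pyro.param("loc", torch.tensor(0.0))
        scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(loc, scale).expand_by(data.shape).independent(1), obs=data)

    def guide(data):
        pass

    # loss_and_grads is compiled on the first step; later steps reuse the compiled graph
    elbo = JitTrace_ELBO(num_particles=1)
    svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=elbo)
    for step in range(100):
        svi.step(data)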