diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 7e8dda4979..7f86c413d7 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -1,6 +1,14 @@ from __future__ import absolute_import, division, print_function +from abc import abstractmethod, ABCMeta +from six import add_metaclass + +import pyro +import pyro.poutine as poutine + + +@add_metaclass(ABCMeta) class ELBO(object): """ :class:`ELBO` is the top-level interface for stochastic variational @@ -15,6 +23,12 @@ class ELBO(object): :func:`pyro.iarange` contexts. This is only required to enumerate over sample sites in parallel, e.g. if a site sets ``infer={"enumerate": "parallel"}``. + :param bool vectorize_particles: Whether to vectorize the ELBO computation + over `num_particles`. Defaults to False. This requires static structure + in model and guide. In addition, this wraps the model and guide inside a + :class:`~pyro.poutine.broadcast` poutine for automatic broadcasting of + sample site batch shapes, and requires specifying a finite value for + `max_iarange_nesting`. :param bool strict_enumeration_warning: Whether to warn about possible misuse of enumeration, i.e. that :class:`pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there @@ -32,7 +46,63 @@ class ELBO(object): def __init__(self, num_particles=1, max_iarange_nesting=float('inf'), + vectorize_particles=False, strict_enumeration_warning=True): self.num_particles = num_particles self.max_iarange_nesting = max_iarange_nesting + self.vectorize_particles = vectorize_particles + if self.vectorize_particles: + if self.num_particles > 1: + if self.max_iarange_nesting == float('inf'): + raise ValueError("Automatic vectorization over num_particles requires " + + "a finite value for `max_iarange_nesting` arg.") + self.max_iarange_nesting += 1 self.strict_enumeration_warning = strict_enumeration_warning + + def _vectorized_num_particles(self, fn): + """ + Wraps a callable inside an outermost :class:`~pyro.iarange` to parallelize + ELBO computation over `num_particles`, and a :class:`~pyro.poutine.broadcast` + poutine to broadcast batch shapes of sample site functions in accordance + with the `~pyro.iarange` contexts within which they are embedded. + + :param fn: arbitrary callable containing Pyro primitives. + :return: wrapped callable. + """ + + def wrapped_fn(*args, **kwargs): + if self.num_particles == 1: + return fn(*args, **kwargs) + with pyro.iarange("num_particles_vectorized", self.num_particles, dim=-self.max_iarange_nesting): + return fn(*args, **kwargs) + + return poutine.broadcast(wrapped_fn) + + def _get_vectorized_trace(self, model, guide, *args, **kwargs): + """ + Wraps the model and guide to vectorize ELBO computation over + ``num_particles``, and returns a single trace from the wrapped model + and guide. + """ + return self._get_trace(self._vectorized_num_particles(model), + self._vectorized_num_particles(guide), + *args, **kwargs) + + @abstractmethod + def _get_trace(self, model, guide, *args, **kwargs): + """ + Returns a single trace from the guide, and the model that is run + against it. + """ + raise NotImplementedError + + def _get_traces(self, model, guide, *args, **kwargs): + """ + Runs the guide and runs the model against the guide with + the result packaged as a trace generator. + """ + if self.vectorize_particles: + yield self._get_vectorized_trace(model, guide, *args, **kwargs) + else: + for i in range(self.num_particles): + yield self._get_trace(model, guide, *args, **kwargs) diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py index a064f12849..1857c2e84e 100644 --- a/pyro/infer/trace_elbo.py +++ b/pyro/infer/trace_elbo.py @@ -45,37 +45,36 @@ class Trace_ELBO(ELBO): Rajesh Ranganath, Sean Gerrish, David M. Blei """ - def _get_traces(self, model, guide, *args, **kwargs): + def _get_trace(self, model, guide, *args, **kwargs): """ - runs the guide and runs the model against the guide with - the result packaged as a trace generator + Returns a single trace from the guide, and the model that is run + against it. """ - for i in range(self.num_particles): - guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) - model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs) - if is_validation_enabled(): - check_model_guide_match(model_trace, guide_trace) - enumerated_sites = [name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and site["infer"].get("enumerate")] - if enumerated_sites: - warnings.warn('\n'.join([ - 'Trace_ELBO found sample sites configured for enumeration:' - ', '.join(enumerated_sites), - 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) - guide_trace = prune_subsample_sites(guide_trace) - model_trace = prune_subsample_sites(model_trace) - - model_trace.compute_log_prob() - guide_trace.compute_score_parts() - if is_validation_enabled(): - for site in model_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, self.max_iarange_nesting) - for site in guide_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, self.max_iarange_nesting) - - yield model_trace, guide_trace + guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) + model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs) + if is_validation_enabled(): + check_model_guide_match(model_trace, guide_trace) + enumerated_sites = [name for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and site["infer"].get("enumerate")] + if enumerated_sites: + warnings.warn('\n'.join([ + 'Trace_ELBO found sample sites configured for enumeration:' + ', '.join(enumerated_sites), + 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) + guide_trace = prune_subsample_sites(guide_trace) + model_trace = prune_subsample_sites(model_trace) + + model_trace.compute_log_prob() + guide_trace.compute_score_parts() + if is_validation_enabled(): + for site in model_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + for site in guide_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + + return model_trace, guide_trace def loss(self, model, guide, *args, **kwargs): """ diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 9d316f5493..d1cb0f0da5 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -67,45 +67,60 @@ class TraceEnum_ELBO(ELBO): variables inside that :class:`~pyro.iarange`. """ - def _get_traces(self, model, guide, *args, **kwargs): + def _get_trace(self, model, guide, *args, **kwargs): """ - runs the guide and runs the model against the guide with - the result packaged as a trace generator + Returns a single trace from the guide, and the model that is run + against it. """ - # enable parallel enumeration - guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting) - - for i in range(self.num_particles): - for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs): - model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), - graph_type="flat").get_trace(*args, **kwargs) - - if is_validation_enabled(): - check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting) - guide_trace = prune_subsample_sites(guide_trace) - model_trace = prune_subsample_sites(model_trace) - if is_validation_enabled(): - check_traceenum_requirements(model_trace, guide_trace) - - model_trace.compute_log_prob() - guide_trace.compute_score_parts() - if is_validation_enabled(): - for site in model_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, self.max_iarange_nesting) - any_enumerated = False - for site in guide_trace.nodes.values(): - if site["type"] == "sample": - check_site_shape(site, self.max_iarange_nesting) - if site["infer"].get("enumerate"): - any_enumerated = True - if self.strict_enumeration_warning and not any_enumerated: - warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. ' - 'If you want to enumerate sites, you need to @config_enumerate or set ' - 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' - 'If you do not want to enumerate, consider using Trace_ELBO instead.') + for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs): + model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), + graph_type="flat").get_trace(*args, **kwargs) + + if is_validation_enabled(): + check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting) + guide_trace = prune_subsample_sites(guide_trace) + model_trace = prune_subsample_sites(model_trace) + if is_validation_enabled(): + check_traceenum_requirements(model_trace, guide_trace) + + model_trace.compute_log_prob() + guide_trace.compute_score_parts() + if is_validation_enabled(): + for site in model_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + any_enumerated = False + for site in guide_trace.nodes.values(): + if site["type"] == "sample": + check_site_shape(site, self.max_iarange_nesting) + if site["infer"].get("enumerate"): + any_enumerated = True + if self.strict_enumeration_warning and not any_enumerated: + warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. ' + 'If you want to enumerate sites, you need to @config_enumerate or set ' + 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' + 'If you do not want to enumerate, consider using Trace_ELBO instead.') + + yield model_trace, guide_trace + def _get_traces(self, model, guide, *args, **kwargs): + """ + Runs the guide and runs the model against the guide with + the result packaged as a trace generator. + """ + if self.vectorize_particles: + # enable parallel enumeration over the vectorized guide. + guide = poutine.enum(self._vectorized_num_particles(guide), + first_available_dim=self.max_iarange_nesting) + model = self._vectorized_num_particles(model) + for model_trace, guide_trace in self._get_trace(model, guide, *args, **kwargs): yield model_trace, guide_trace + else: + # enable parallel enumeration. + guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting) + for i in range(self.num_particles): + for model_trace, guide_trace in self._get_trace(model, guide, *args, **kwargs): + yield model_trace, guide_trace def loss(self, model, guide, *args, **kwargs): """ diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index ff91c90b2e..518ab3a587 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -187,32 +187,30 @@ class TraceGraph_ELBO(ELBO): Andriy Mnih, Karol Gregor """ - def _get_traces(self, model, guide, *args, **kwargs): + def _get_trace(self, model, guide, *args, **kwargs): """ - runs the guide and runs the model against the guide with - the result packaged as a tracegraph generator + Returns a single trace from the guide, and the model that is run + against it. """ - - for i in range(self.num_particles): - guide_trace = poutine.trace(guide, - graph_type="dense").get_trace(*args, **kwargs) - model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), - graph_type="dense").get_trace(*args, **kwargs) - if is_validation_enabled(): - check_model_guide_match(model_trace, guide_trace) - enumerated_sites = [name for name, site in guide_trace.nodes.items() - if site["type"] == "sample" and site["infer"].get("enumerate")] - if enumerated_sites: - warnings.warn('\n'.join([ - 'TraceGraph_ELBO found sample sites configured for enumeration:' - ', '.join(enumerated_sites), - 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) - - guide_trace = prune_subsample_sites(guide_trace) - model_trace = prune_subsample_sites(model_trace) - - weight = 1.0 / self.num_particles - yield weight, model_trace, guide_trace + guide_trace = poutine.trace(guide, + graph_type="dense").get_trace(*args, **kwargs) + model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), + graph_type="dense").get_trace(*args, **kwargs) + if is_validation_enabled(): + check_model_guide_match(model_trace, guide_trace) + enumerated_sites = [name for name, site in guide_trace.nodes.items() + if site["type"] == "sample" and site["infer"].get("enumerate")] + if enumerated_sites: + warnings.warn('\n'.join([ + 'TraceGraph_ELBO found sample sites configured for enumeration:' + ', '.join(enumerated_sites), + 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) + + guide_trace = prune_subsample_sites(guide_trace) + model_trace = prune_subsample_sites(model_trace) + + weight = 1.0 / self.num_particles + return weight, model_trace, guide_trace def loss(self, model, guide, *args, **kwargs): """ diff --git a/tests/infer/test_conjugate_gradients.py b/tests/infer/test_conjugate_gradients.py index 8ce0771c82..4a4f70ce8b 100644 --- a/tests/infer/test_conjugate_gradients.py +++ b/tests/infer/test_conjugate_gradients.py @@ -5,7 +5,6 @@ from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain -# TODO increase precision and number of particles once latter is parallelized properly class ConjugateChainGradientTests(GaussianChain): def test_gradients(self): @@ -17,7 +16,7 @@ def do_test_gradients(self, N, reparameterized): pyro.clear_param_store() self.setup_chain(N) - elbo = TraceGraph_ELBO(num_particles=1000) + elbo = TraceGraph_ELBO(num_particles=10000, vectorize_particles=True, max_iarange_nesting=1) elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized) for i in range(1, N + 1): @@ -25,7 +24,7 @@ def do_test_gradients(self, N, reparameterized): if i == N and param_prefix == 'kappa_q_%d': continue actual_grad = pyro.param(param_prefix % i).grad - assert_equal(actual_grad, 0.0 * actual_grad, prec=0.20, msg="".join([ + assert_equal(actual_grad, 0.0 * actual_grad, prec=0.10, msg="".join([ "parameter %s%d" % (param_prefix[:-2], i), "\nexpected = zero vector", "\n actual = {}".format(actual_grad.detach().cpu().numpy())])) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 325101925d..d4e70b75c0 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -288,26 +288,26 @@ def test_elbo_berns(enumerate1, enumerate2, enumerate3): q = pyro.param("q", torch.tensor(0.75, requires_grad=True)) def model(): - with pyro.iarange("particles", num_particles): - pyro.sample("x1", dist.Bernoulli(0.1).expand_by([num_particles])) - pyro.sample("x2", dist.Bernoulli(0.2).expand_by([num_particles])) - pyro.sample("x3", dist.Bernoulli(0.3).expand_by([num_particles])) + pyro.sample("x1", dist.Bernoulli(0.1)) + pyro.sample("x2", dist.Bernoulli(0.2)) + pyro.sample("x3", dist.Bernoulli(0.3)) def guide(): q = pyro.param("q") - with pyro.iarange("particles", num_particles): - pyro.sample("x1", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate1}) - pyro.sample("x2", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate2}) - pyro.sample("x3", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate3}) + pyro.sample("x1", dist.Bernoulli(q), infer={"enumerate": enumerate1}) + pyro.sample("x2", dist.Bernoulli(q), infer={"enumerate": enumerate2}) + pyro.sample("x3", dist.Bernoulli(q), infer={"enumerate": enumerate3}) kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3]) expected_loss = kl.item() expected_grad = grad(kl, [q])[0] - elbo = TraceEnum_ELBO(max_iarange_nesting=1, + elbo = TraceEnum_ELBO(max_iarange_nesting=0, + num_particles=num_particles, + vectorize_particles=True, strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) - actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = q.grad / num_particles + actual_loss = elbo.loss_and_grads(model, guide) + actual_grad = q.grad assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([ "\nexpected loss = {}".format(expected_loss), @@ -462,31 +462,31 @@ def test_elbo_iarange_iarange(outer_dim, inner_dim, enumerate1, enumerate2, enum q = pyro.param("q", torch.tensor(0.75, requires_grad=True)) p = 0.2693204236205713 # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5 + @poutine.broadcast def model(): d = dist.Bernoulli(p) - with pyro.iarange("particles", num_particles): - context1 = pyro.iarange("outer", outer_dim, dim=-2) - context2 = pyro.iarange("inner", inner_dim, dim=-3) - pyro.sample("w", d.expand_by([num_particles])) - with context1: - pyro.sample("x", d.expand_by([outer_dim, num_particles])) - with context2: - pyro.sample("y", d.expand_by([inner_dim, 1, num_particles])) - with context1, context2: - pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles])) + context1 = pyro.iarange("outer", outer_dim, dim=-1) + context2 = pyro.iarange("inner", inner_dim, dim=-2) + pyro.sample("w", d) + with context1: + pyro.sample("x", d) + with context2: + pyro.sample("y", d) + with context1, context2: + pyro.sample("z", d) + @poutine.broadcast def guide(): d = dist.Bernoulli(pyro.param("q")) - with pyro.iarange("particles", num_particles): - context1 = pyro.iarange("outer", outer_dim, dim=-2) - context2 = pyro.iarange("inner", inner_dim, dim=-3) - pyro.sample("w", d.expand_by([num_particles]), infer={"enumerate": enumerate1}) - with context1: - pyro.sample("x", d.expand_by([outer_dim, num_particles]), infer={"enumerate": enumerate2}) - with context2: - pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]), infer={"enumerate": enumerate3}) - with context1, context2: - pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]), infer={"enumerate": enumerate4}) + context1 = pyro.iarange("outer", outer_dim, dim=-1) + context2 = pyro.iarange("inner", inner_dim, dim=-2) + pyro.sample("w", d, infer={"enumerate": enumerate1}) + with context1: + pyro.sample("x", d, infer={"enumerate": enumerate2}) + with context2: + pyro.sample("y", d, infer={"enumerate": enumerate3}) + with context1, context2: + pyro.sample("z", d, infer={"enumerate": enumerate4}) kl_node = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node @@ -494,9 +494,11 @@ def guide(): expected_grad = grad(kl, [q])[0] elbo = TraceEnum_ELBO(max_iarange_nesting=3, + num_particles=num_particles, + vectorize_particles=True, strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3])) - actual_loss = elbo.loss_and_grads(model, guide) / num_particles - actual_grad = pyro.param('q').grad / num_particles + actual_loss = elbo.loss_and_grads(model, guide) + actual_grad = pyro.param('q').grad assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([ "\nexpected loss = {}".format(expected_loss), diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index cd24451771..011665a559 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -26,27 +26,29 @@ def test_subsample_gradient(Elbo, reparameterized, subsample): pyro.clear_param_store() data = torch.tensor([-0.5, 2.0]) subsample_size = 1 if subsample else len(data) - num_particles = 50000 precision = 0.06 Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal + @poutine.broadcast def model(subsample): - with pyro.iarange("particles", num_particles): - with pyro.iarange("data", len(data), subsample_size, subsample) as ind: - x = data[ind].unsqueeze(-1).expand(-1, num_particles) - z = pyro.sample("z", Normal(0, 1).expand_by(x.shape)) - pyro.sample("x", Normal(z, 1), obs=x) + with pyro.iarange("data", len(data), subsample_size, subsample) as ind: + x = data[ind] + z = pyro.sample("z", Normal(0, 1)) + pyro.sample("x", Normal(z, 1), obs=x) + @poutine.broadcast def guide(subsample): loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True)) scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True)) - with pyro.iarange("particles", num_particles): - with pyro.iarange("data", len(data), subsample_size, subsample) as ind: - loc_ind = loc[ind].unsqueeze(-1).expand(-1, num_particles) - pyro.sample("z", Normal(loc_ind, scale)) + with pyro.iarange("data", len(data), subsample_size, subsample) as ind: + loc_ind = loc[ind] + pyro.sample("z", Normal(loc_ind, scale)) optim = Adam({"lr": 0.1}) - elbo = Elbo(strict_enumeration_warning=False) + elbo = Elbo(max_iarange_nesting=1, + num_particles=50000, + vectorize_particles=True, + strict_enumeration_warning=False) inference = SVI(model, guide, optim, loss=elbo) if subsample_size == 1: inference.loss_and_grads(model, guide, subsample=torch.LongTensor([0])) @@ -54,7 +56,7 @@ def guide(subsample): else: inference.loss_and_grads(model, guide, subsample=torch.LongTensor([0, 1])) params = dict(pyro.get_param_store().named_parameters()) - normalizer = 2 * num_particles / subsample_size + normalizer = 2 if subsample else 1 actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()} expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} @@ -112,6 +114,56 @@ def guide(): assert_equal(actual_grads, expected_grads, prec=precision) +@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) +@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +def test_iarange_elbo_vectorized_particles(Elbo, reparameterized): + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + num_particles = 20000 + precision = 0.06 + Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal + + @poutine.broadcast + def model(): + data_iarange = pyro.iarange("data", len(data)) + + pyro.sample("nuisance_a", Normal(0, 1)) + with data_iarange: + z = pyro.sample("z", Normal(0, 1)) + pyro.sample("nuisance_b", Normal(2, 3)) + with data_iarange: + pyro.sample("x", Normal(z, 1), obs=data) + pyro.sample("nuisance_c", Normal(4, 5)) + + @poutine.broadcast + def guide(): + loc = pyro.param("loc", torch.zeros(len(data))) + scale = pyro.param("scale", torch.tensor([1.])) + + pyro.sample("nuisance_c", Normal(4, 5)) + with pyro.iarange("data", len(data)): + pyro.sample("z", Normal(loc, scale)) + pyro.sample("nuisance_b", Normal(2, 3)) + pyro.sample("nuisance_a", Normal(0, 1)) + + optim = Adam({"lr": 0.1}) + loss = Elbo(max_iarange_nesting=1, + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=False) + inference = SVI(model, guide, optim, loss=loss) + inference.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach().cpu().numpy() + for name, param in params.items()} + + expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])} + for name in sorted(params): + logger.info('expected {} = {}'.format(name, expected_grads[name])) + logger.info('actual {} = {}'.format(name, actual_grads[name])) + assert_equal(actual_grads, expected_grads, prec=precision) + + @pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"]) @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"]) @pytest.mark.parametrize("Elbo", [ diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py index d17fc2d406..01c10578c7 100644 --- a/tests/infer/test_valid_models.py +++ b/tests/infer/test_valid_models.py @@ -926,3 +926,59 @@ def model(): guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model assert_error(model, guide, Elbo()) + + +@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) +def test_vectorized_num_particles(Elbo): + data = torch.ones(1000, 2) + + @poutine.broadcast + def model(): + with pyro.iarange("components", 2): + p = pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1))) + assert p.shape == torch.Size((10, 1, 2)) + with pyro.iarange("data", data.shape[0]): + pyro.sample("obs", dist.Bernoulli(p), obs=data) + + @poutine.broadcast + def guide(): + with pyro.iarange("components", 2): + pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1))) + + pyro.clear_param_store() + guide = config_enumerate(guide) if Elbo is TraceEnum_ELBO else guide + assert_ok(model, guide, Elbo(num_particles=10, + vectorize_particles=True, + max_iarange_nesting=2, + strict_enumeration_warning=False)) + + +@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"]) +@pytest.mark.parametrize('num_particles', [1, 50]) +def test_enum_discrete_vectorized_num_particles(enumerate_, num_particles): + + @poutine.broadcast + @config_enumerate(default=enumerate_) + def model(): + x_iarange = pyro.iarange("x_iarange", 10, 5, dim=-1) + y_iarange = pyro.iarange("y_iarange", 11, 6, dim=-2) + with x_iarange: + b = pyro.sample("b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1))) + assert b.shape == torch.Size((num_particles, 1, 5) if num_particles > 1 else (5,)) + with y_iarange: + c = pyro.sample("c", dist.Bernoulli(0.5)) + if enumerate_ == "parallel": + assert c.shape == torch.Size((2, num_particles, 6, 1) if num_particles > 1 else (2, 6, 1)) + else: + assert c.shape == torch.Size((num_particles, 6, 1) if num_particles > 1 else (6, 1)) + with x_iarange, y_iarange: + d = pyro.sample("d", dist.Bernoulli(b)) + if enumerate_ == "parallel": + assert d.shape == torch.Size((2, 1, num_particles, 6, 5) if num_particles > 1 else (2, 1, 6, 5)) + else: + assert d.shape == torch.Size((num_particles, 6, 5) if num_particles > 1 else (6, 5)) + + assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=2, + num_particles=num_particles, + vectorize_particles=True, + strict_enumeration_warning=(enumerate_ == "parallel"))) diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py index c2f311eb4d..c846bd999f 100644 --- a/tests/integration_tests/test_conjugate_gaussian_models.py +++ b/tests/integration_tests/test_conjugate_gaussian_models.py @@ -75,8 +75,8 @@ def model(self, reparameterized, difficulty=0.0): loc_N = next_mean with pyro.iarange("data", self.data.size(0)): - pyro.sample("obs", dist.Normal(loc_N.expand_as(self.data), - torch.pow(self.lambdas[self.N], -0.5).expand_as(self.data)), obs=self.data) + pyro.sample("obs", dist.Normal(loc_N, + torch.pow(self.lambdas[self.N], -0.5)), obs=self.data) return loc_N def guide(self, reparameterized, difficulty=0.0):