Skip to content

Commit

Permalink
Vectorize ELBO computation over num particles (#1176)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fritzo committed Jun 8, 2018
1 parent 75854fa commit 865bff6
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 138 deletions.
70 changes: 70 additions & 0 deletions pyro/infer/elbo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from __future__ import absolute_import, division, print_function

from abc import abstractmethod, ABCMeta

from six import add_metaclass

import pyro
import pyro.poutine as poutine


@add_metaclass(ABCMeta)
class ELBO(object):
"""
:class:`ELBO` is the top-level interface for stochastic variational
Expand All @@ -15,6 +23,12 @@ class ELBO(object):
:func:`pyro.iarange` contexts. This is only required to enumerate over
sample sites in parallel, e.g. if a site sets
``infer={"enumerate": "parallel"}``.
:param bool vectorize_particles: Whether to vectorize the ELBO computation
over `num_particles`. Defaults to False. This requires static structure
in model and guide. In addition, this wraps the model and guide inside a
:class:`~pyro.poutine.broadcast` poutine for automatic broadcasting of
sample site batch shapes, and requires specifying a finite value for
`max_iarange_nesting`.
:param bool strict_enumeration_warning: Whether to warn about possible
misuse of enumeration, i.e. that
:class:`pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there
Expand All @@ -32,7 +46,63 @@ class ELBO(object):
def __init__(self,
num_particles=1,
max_iarange_nesting=float('inf'),
vectorize_particles=False,
strict_enumeration_warning=True):
self.num_particles = num_particles
self.max_iarange_nesting = max_iarange_nesting
self.vectorize_particles = vectorize_particles
if self.vectorize_particles:
if self.num_particles > 1:
if self.max_iarange_nesting == float('inf'):
raise ValueError("Automatic vectorization over num_particles requires " +
"a finite value for `max_iarange_nesting` arg.")
self.max_iarange_nesting += 1
self.strict_enumeration_warning = strict_enumeration_warning

def _vectorized_num_particles(self, fn):
"""
Wraps a callable inside an outermost :class:`~pyro.iarange` to parallelize
ELBO computation over `num_particles`, and a :class:`~pyro.poutine.broadcast`
poutine to broadcast batch shapes of sample site functions in accordance
with the `~pyro.iarange` contexts within which they are embedded.
:param fn: arbitrary callable containing Pyro primitives.
:return: wrapped callable.
"""

def wrapped_fn(*args, **kwargs):
if self.num_particles == 1:
return fn(*args, **kwargs)
with pyro.iarange("num_particles_vectorized", self.num_particles, dim=-self.max_iarange_nesting):
return fn(*args, **kwargs)

return poutine.broadcast(wrapped_fn)

def _get_vectorized_trace(self, model, guide, *args, **kwargs):
"""
Wraps the model and guide to vectorize ELBO computation over
``num_particles``, and returns a single trace from the wrapped model
and guide.
"""
return self._get_trace(self._vectorized_num_particles(model),
self._vectorized_num_particles(guide),
*args, **kwargs)

@abstractmethod
def _get_trace(self, model, guide, *args, **kwargs):
"""
Returns a single trace from the guide, and the model that is run
against it.
"""
raise NotImplementedError

def _get_traces(self, model, guide, *args, **kwargs):
"""
Runs the guide and runs the model against the guide with
the result packaged as a trace generator.
"""
if self.vectorize_particles:
yield self._get_vectorized_trace(model, guide, *args, **kwargs)
else:
for i in range(self.num_particles):
yield self._get_trace(model, guide, *args, **kwargs)
57 changes: 28 additions & 29 deletions pyro/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,36 @@ class Trace_ELBO(ELBO):
Rajesh Ranganath, Sean Gerrish, David M. Blei
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a trace generator
Returns a single trace from the guide, and the model that is run
against it.
"""
for i in range(self.num_particles):
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'Trace_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

yield model_trace, guide_trace
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'Trace_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

return model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
85 changes: 50 additions & 35 deletions pyro/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,60 @@ class TraceEnum_ELBO(ELBO):
variables inside that :class:`~pyro.iarange`.
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a trace generator
Returns a single trace from the guide, and the model that is run
against it.
"""
# enable parallel enumeration
guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

for i in range(self.num_particles):
for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
any_enumerated = False
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
if site["infer"].get("enumerate"):
any_enumerated = True
if self.strict_enumeration_warning and not any_enumerated:
warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
'If you want to enumerate sites, you need to @config_enumerate or set '
'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
'If you do not want to enumerate, consider using Trace_ELBO instead.')
for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
any_enumerated = False
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
if site["infer"].get("enumerate"):
any_enumerated = True
if self.strict_enumeration_warning and not any_enumerated:
warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
'If you want to enumerate sites, you need to @config_enumerate or set '
'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
'If you do not want to enumerate, consider using Trace_ELBO instead.')

yield model_trace, guide_trace

def _get_traces(self, model, guide, *args, **kwargs):
"""
Runs the guide and runs the model against the guide with
the result packaged as a trace generator.
"""
if self.vectorize_particles:
# enable parallel enumeration over the vectorized guide.
guide = poutine.enum(self._vectorized_num_particles(guide),
first_available_dim=self.max_iarange_nesting)
model = self._vectorized_num_particles(model)
for model_trace, guide_trace in self._get_trace(model, guide, *args, **kwargs):
yield model_trace, guide_trace
else:
# enable parallel enumeration.
guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)
for i in range(self.num_particles):
for model_trace, guide_trace in self._get_trace(model, guide, *args, **kwargs):
yield model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
46 changes: 22 additions & 24 deletions pyro/infer/tracegraph_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,32 +187,30 @@ class TraceGraph_ELBO(ELBO):
Andriy Mnih, Karol Gregor
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a tracegraph generator
Returns a single trace from the guide, and the model that is run
against it.
"""

for i in range(self.num_particles):
guide_trace = poutine.trace(guide,
graph_type="dense").get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="dense").get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'TraceGraph_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

weight = 1.0 / self.num_particles
yield weight, model_trace, guide_trace
guide_trace = poutine.trace(guide,
graph_type="dense").get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="dense").get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'TraceGraph_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

weight = 1.0 / self.num_particles
return weight, model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
5 changes: 2 additions & 3 deletions tests/infer/test_conjugate_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain


# TODO increase precision and number of particles once latter is parallelized properly
class ConjugateChainGradientTests(GaussianChain):

def test_gradients(self):
Expand All @@ -17,15 +16,15 @@ def do_test_gradients(self, N, reparameterized):
pyro.clear_param_store()
self.setup_chain(N)

elbo = TraceGraph_ELBO(num_particles=1000)
elbo = TraceGraph_ELBO(num_particles=10000, vectorize_particles=True, max_iarange_nesting=1)
elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized)

for i in range(1, N + 1):
for param_prefix in ["loc_q_%d", "log_sig_q_%d", "kappa_q_%d"]:
if i == N and param_prefix == 'kappa_q_%d':
continue
actual_grad = pyro.param(param_prefix % i).grad
assert_equal(actual_grad, 0.0 * actual_grad, prec=0.20, msg="".join([
assert_equal(actual_grad, 0.0 * actual_grad, prec=0.10, msg="".join([
"parameter %s%d" % (param_prefix[:-2], i),
"\nexpected = zero vector",
"\n actual = {}".format(actual_grad.detach().cpu().numpy())]))
Loading

0 comments on commit 865bff6

Please sign in to comment.