Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Vectorize ELBO computation over num particles #1176

Merged
merged 7 commits into from
Jun 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions pyro/infer/elbo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from __future__ import absolute_import, division, print_function

from abc import abstractmethod, ABCMeta

from six import add_metaclass

import pyro
import pyro.poutine as poutine


@add_metaclass(ABCMeta)
class ELBO(object):
"""
:class:`ELBO` is the top-level interface for stochastic variational
Expand All @@ -15,6 +23,12 @@ class ELBO(object):
:func:`pyro.iarange` contexts. This is only required to enumerate over
sample sites in parallel, e.g. if a site sets
``infer={"enumerate": "parallel"}``.
:param bool vectorize_particles: Whether to vectorize the ELBO computation
over `num_particles`. Defaults to False. This requires static structure
in model and guide. In addition, this wraps the model and guide inside a
:class:`~pyro.poutine.broadcast` poutine for automatic broadcasting of
sample site batch shapes, and requires specifying a finite value for
`max_iarange_nesting`.
:param bool strict_enumeration_warning: Whether to warn about possible
misuse of enumeration, i.e. that
:class:`pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there
Expand All @@ -32,7 +46,63 @@ class ELBO(object):
def __init__(self,
num_particles=1,
max_iarange_nesting=float('inf'),
vectorize_particles=False,
strict_enumeration_warning=True):
self.num_particles = num_particles
self.max_iarange_nesting = max_iarange_nesting
self.vectorize_particles = vectorize_particles
if self.vectorize_particles:
if self.num_particles > 1:
if self.max_iarange_nesting == float('inf'):
raise ValueError("Automatic vectorization over num_particles requires " +
"a finite value for `max_iarange_nesting` arg.")
self.max_iarange_nesting += 1
self.strict_enumeration_warning = strict_enumeration_warning

def _vectorized_num_particles(self, fn):
    """
    Wraps a callable inside an outermost :class:`~pyro.iarange` to
    parallelize ELBO computation over ``num_particles``, and in a
    :class:`~pyro.poutine.broadcast` poutine to broadcast batch shapes of
    sample sites according to the ``iarange`` contexts enclosing them.

    :param fn: arbitrary callable containing Pyro primitives.
    :return: wrapped callable.
    """

    def vectorized(*args, **kwargs):
        # A single particle needs no extra batch dimension: run fn as-is.
        if self.num_particles == 1:
            return fn(*args, **kwargs)
        # Reserve the outermost dim (-max_iarange_nesting) for particles.
        particle_iarange = pyro.iarange("num_particles_vectorized",
                                        self.num_particles,
                                        dim=-self.max_iarange_nesting)
        with particle_iarange:
            return fn(*args, **kwargs)

    return poutine.broadcast(vectorized)
Copy link
Member

@fritzo fritzo Jun 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth checking whether fn is already wrapped in poutine.broadcast and avoid wrapping a second time? Alternatively, this avoidance check could be done inside poutine.broadcast.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I'll check it in poutine.broadcast itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add fn.__broadcasted__ as an attribute to mark a function that has already been broadcasted.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't broadcast be idempotent?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes broadcast should be idempotent, I'm asking only whether it is worth avoiding double wrapping for efficiency purposes. If it costs more than 5 lines of code, then it's probably not worth avoiding the double wrap.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I'll remove this.


def _get_vectorized_trace(self, model, guide, *args, **kwargs):
    """
    Wraps the model and guide to vectorize ELBO computation over
    ``num_particles``, and returns a single trace from the wrapped model
    and guide.
    """
    wrapped_model = self._vectorized_num_particles(model)
    wrapped_guide = self._vectorized_num_particles(guide)
    return self._get_trace(wrapped_model, wrapped_guide, *args, **kwargs)

@abstractmethod
def _get_trace(self, model, guide, *args, **kwargs):
    """
    Returns a single trace from the guide, and the model that is run
    against it.

    Subclasses must implement this; :meth:`_get_traces` calls it once per
    particle (or once on the vectorized model/guide pair).
    """
    raise NotImplementedError

def _get_traces(self, model, guide, *args, **kwargs):
"""
Runs the guide and runs the model against the guide with
the result packaged as a trace generator.
"""
if self.vectorize_particles:
yield self._get_vectorized_trace(model, guide, *args, **kwargs)
else:
for i in range(self.num_particles):
yield self._get_trace(model, guide, *args, **kwargs)
57 changes: 28 additions & 29 deletions pyro/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,36 @@ class Trace_ELBO(ELBO):
Rajesh Ranganath, Sean Gerrish, David M. Blei
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a trace generator
Returns a single trace from the guide, and the model that is run
against it.
"""
for i in range(self.num_particles):
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'Trace_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

yield model_trace, guide_trace
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this method identical sans the for loop over num_particles? the diff shows the whole method has changed but from a cursory glance it looks the same

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are identical. have you tried the old "cross your eyes at the side-by-side diff" trick?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i do the "cmd-f a randomly sampled substring and see if it appears on both sides". 90% of the time, it works every time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's right! Mostly Trace_ELBO, TraceGraph_ELBO and TraceEnum_ELBO will only differ in how they generate a single trace. This is reused when we want to generate traces vectorized over num_particles.

model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'Trace_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

return model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
85 changes: 50 additions & 35 deletions pyro/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,60 @@ class TraceEnum_ELBO(ELBO):
variables inside that :class:`~pyro.iarange`.
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a trace generator
Returns a single trace from the guide, and the model that is run
against it.
"""
# enable parallel enumeration
guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

for i in range(self.num_particles):
for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
any_enumerated = False
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
if site["infer"].get("enumerate"):
any_enumerated = True
if self.strict_enumeration_warning and not any_enumerated:
warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
'If you want to enumerate sites, you need to @config_enumerate or set '
'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
'If you do not want to enumerate, consider using Trace_ELBO instead.')
for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)

model_trace.compute_log_prob()
guide_trace.compute_score_parts()
if is_validation_enabled():
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
any_enumerated = False
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
if site["infer"].get("enumerate"):
any_enumerated = True
if self.strict_enumeration_warning and not any_enumerated:
warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
'If you want to enumerate sites, you need to @config_enumerate or set '
'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
'If you do not want to enumerate, consider using Trace_ELBO instead.')

yield model_trace, guide_trace

def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with
    the result packaged as a trace generator.

    In the vectorized case the guide is first wrapped to iterate particles
    in parallel, then wrapped in :func:`~pyro.poutine.enum` so enumeration
    dims are allocated above the particle dim; otherwise enumeration is
    enabled on the bare guide and traces are drawn once per particle.
    """
    if self.vectorize_particles:
        # Enable parallel enumeration over the vectorized guide.
        model = self._vectorized_num_particles(model)
        guide = poutine.enum(self._vectorized_num_particles(guide),
                             first_available_dim=self.max_iarange_nesting)
        num_runs = 1
    else:
        # Enable parallel enumeration.
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)
        num_runs = self.num_particles
    # Loop index is unused: each run draws an independent set of traces.
    for _ in range(num_runs):
        for model_trace, guide_trace in self._get_trace(model, guide, *args, **kwargs):
            yield model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
46 changes: 22 additions & 24 deletions pyro/infer/tracegraph_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,32 +187,30 @@ class TraceGraph_ELBO(ELBO):
Andriy Mnih, Karol Gregor
"""

def _get_traces(self, model, guide, *args, **kwargs):
def _get_trace(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a tracegraph generator
Returns a single trace from the guide, and the model that is run
against it.
"""

for i in range(self.num_particles):
guide_trace = poutine.trace(guide,
graph_type="dense").get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="dense").get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'TraceGraph_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

weight = 1.0 / self.num_particles
yield weight, model_trace, guide_trace
guide_trace = poutine.trace(guide,
graph_type="dense").get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
graph_type="dense").get_trace(*args, **kwargs)
if is_validation_enabled():
check_model_guide_match(model_trace, guide_trace)
enumerated_sites = [name for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["infer"].get("enumerate")]
if enumerated_sites:
warnings.warn('\n'.join([
'TraceGraph_ELBO found sample sites configured for enumeration:'
', '.join(enumerated_sites),
'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

weight = 1.0 / self.num_particles
return weight, model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand Down
5 changes: 2 additions & 3 deletions tests/infer/test_conjugate_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain


# TODO increase precision and number of particles once latter is parallelized properly
class ConjugateChainGradientTests(GaussianChain):

def test_gradients(self):
Expand All @@ -17,15 +16,15 @@ def do_test_gradients(self, N, reparameterized):
pyro.clear_param_store()
self.setup_chain(N)

elbo = TraceGraph_ELBO(num_particles=1000)
elbo = TraceGraph_ELBO(num_particles=10000, vectorize_particles=True, max_iarange_nesting=1)
elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized)

for i in range(1, N + 1):
for param_prefix in ["loc_q_%d", "log_sig_q_%d", "kappa_q_%d"]:
if i == N and param_prefix == 'kappa_q_%d':
continue
actual_grad = pyro.param(param_prefix % i).grad
assert_equal(actual_grad, 0.0 * actual_grad, prec=0.20, msg="".join([
assert_equal(actual_grad, 0.0 * actual_grad, prec=0.10, msg="".join([
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

"parameter %s%d" % (param_prefix[:-2], i),
"\nexpected = zero vector",
"\n actual = {}".format(actual_grad.detach().cpu().numpy())]))
Loading