Implement parallel enumeration over discrete sample sites #776

Merged
martinjankowiak merged 28 commits into dev from enumerate-parallel on Feb 25, 2018

fritzo
Member

@fritzo fritzo commented Feb 15, 2018

See Design Doc | Closes #742 | Replaces #227

This implements parallel enumeration for discrete sample sites. The new notation is:

- SVI(model, guide, optim, "ELBO", enum_discrete=True)
+ SVI(model, config_enumerate(guide), optim, "ELBO")

This allows finer-grained control over each discrete sample site. For example, we can freely mix Monte Carlo sampling, sequential enumeration, and parallel enumeration:

def guide():
    pyro.sample("x", Categorical(p), infer={"enumerate": "sequential"})
    pyro.sample("y", Categorical(p), infer={"enumerate": "parallel"})
    pyro.sample("z", Categorical(p))  # Monte Carlo

Additionally we can use the @config_enumerate decorator to annotate an entire guide:

@config_enumerate
def guide():
    pyro.sample("x", Categorical(p))  # sequential
    pyro.sample("y", Categorical(p), infer={"enumerate": "parallel"})
    pyro.sample("z", Categorical(p), infer={"enumerate": None})  # Monte Carlo

We can also make parallel enumeration the default:

@config_enumerate(default="parallel")
def guide():
    ...
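
The bare `@config_enumerate` form and the `@config_enumerate(default="parallel")` form can share one function if the first argument is optional. Below is a minimal pure-Python sketch of that pattern. It is not Pyro's implementation: `tag_default` and the `enumerate_default` attribute are hypothetical names used only for illustration.

```python
import functools


def tag_default(fn=None, default="sequential"):
    """Hypothetical stand-in for a decorator usable with or without arguments."""
    if fn is None:
        # Called as @tag_default(default=...): return the actual decorator.
        return lambda f: tag_default(f, default=default)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    # Record the chosen default on the wrapper (illustrative only).
    wrapper.enumerate_default = default
    return wrapper


@tag_default
def guide_a():
    return "a"


@tag_default(default="parallel")
def guide_b():
    return "b"
```

Both call styles end up in the same code path; the only difference is whether the function receives the guide directly or returns a decorator first.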

Tasks

  • Add tests for max_iarange_nesting
  • Fix batch_log_pdf.shape errors
  • Add tests for nested enumeration (mixing sequential with parallel)
  • Add tests for correctness of gradients

@fritzo force-pushed the enumerate-parallel branch from d352788 to c3c0c80 on February 15, 2018 17:55
@fritzo fritzo mentioned this pull request Feb 20, 2018
Member Author

fritzo commented Feb 21, 2018

@eb8680 Could you take a look at the new interface using enumerate_discrete(guide) and your Poutine rather than the enum_discrete=True kwarg? (there are still some test failures, but the interface is mostly complete).

Member

eb8680 commented Feb 21, 2018

> Could you take a look at the new interface using enumerate_discrete(guide) and your Poutine rather than the enum_discrete=True kwarg?

It mostly looks fine, but maybe we should change the name of enumerate_discrete to config_enumerate or something similar to make it clear that it does not modify the semantics of the guide on its own (i.e. calling enumerate_discrete(guide)(...) with no enclosing EnumerateMessenger context will sample the latent variables rather than enumerating).

with xfail_if_not_implemented():
    inference.step(data)


@pytest.mark.parametrize("enum_discrete", [None, "sequential", "parallel"])
@pytest.mark.parametrize("trace_graph", [False, True], ids=["dense", "flat"])
def test_bern_elbo_gradient(enum_discrete, trace_graph):
Member Author

@fritzo fritzo Feb 23, 2018

This and the following test use analytic grad(kl_divergence(-, -), -) to test correctness of ELBO gradients.
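
As a rough illustration of this kind of check (not the actual test, which is not shown here), consider a single Bernoulli site with no observed data. Enumerating both outcomes gives the exact ELBO, which equals the negative KL divergence, so its gradient can be compared with the analytic KL gradient. The sketch below uses a finite difference in place of autograd, and all function names are hypothetical.

```python
import math


def elbo_by_enumeration(q, p):
    """Exact ELBO for a Bernoulli(q) guide and Bernoulli(p) model, no data.

    Summing over both outcomes gives E_q[log p(x) - log q(x)] = -KL(q || p).
    """
    total = 0.0
    for x in (0, 1):
        q_x = q if x == 1 else 1.0 - q
        p_x = p if x == 1 else 1.0 - p
        total += q_x * (math.log(p_x) - math.log(q_x))
    return total


def kl_grad(q, p):
    """Analytic d/dq of KL(Bernoulli(q) || Bernoulli(p))."""
    return math.log(q / p) - math.log((1.0 - q) / (1.0 - p))


def finite_difference(f, x, eps=1e-5):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


q, p = 0.3, 0.6
numeric = finite_difference(lambda t: elbo_by_enumeration(t, p), q)
analytic = -kl_grad(q, p)
```

Here `numeric` and `analytic` should agree up to the finite-difference error.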

@@ -190,9 +190,6 @@ def _get_traces(self, model, guide, *args, **kwargs):
        """

        for i in range(self.num_particles):
            if self.enum_discrete:
Member Author

@fritzo fritzo Feb 23, 2018

Enumeration is now controlled by hints in the site['infer'] dict. If an inference algorithm does not implement enumeration, it can safely ignore those hints and sample rather than enumerate.
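
A toy sketch of what "safely ignore those hints" can mean, using plain dicts instead of Pyro traces. The site layout and both `run_*` functions are hypothetical and only mimic the idea: one algorithm never looks at `site["infer"]`, the other enumerates when a hint is present and samples otherwise.

```python
import random


def make_site(name, probs, enumerate_hint=None):
    """A toy sample site: a name, outcome probabilities, and inference hints."""
    return {"name": name, "probs": probs, "infer": {"enumerate": enumerate_hint}}


def run_sampling_only(site, rng):
    """An algorithm without enumeration support: it never reads site['infer']."""
    outcomes = range(len(site["probs"]))
    value = rng.choices(outcomes, weights=site["probs"])[0]
    return [(1.0, value)]


def run_with_enumeration(site, rng):
    """An algorithm that honours the hint and falls back to sampling otherwise."""
    if site["infer"].get("enumerate") in ("sequential", "parallel"):
        # One weighted entry per outcome instead of a single random draw.
        return [(prob, value) for value, prob in enumerate(site["probs"])]
    return run_sampling_only(site, rng)
```

Both functions return a list of `(weight, value)` pairs, so downstream code does not need to know whether the site was enumerated or sampled.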

q_fn = poutine.queue(fn, queue=queue)
full_trace = poutine.trace(
    q_fn, graph_type=graph_type).get_trace(*args, **kwargs)
q_fn = poutine.queue(fn, queue=queue, escape_fn=_iter_discrete_escape)
Member

Should this be outside the while loop replacing q_fn above?

Member Author

That looks right to me. Done.

@@ -90,40 +90,58 @@ def sum_rightmost(value, dim):
"""
Sum out ``dim`` many rightmost dimensions of a given tensor.

If ``dim`` is 0, no dimensions are summed out.
If ``dim`` is ``float('inf')``, then all dimensions are summed out.
If ``dim`` is 1, the leftmost 1 dimension is summed out.
Member

leftmost -> rightmost and reverse below.
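
For reference, the intended behaviour (summing out the rightmost `dim` dimensions) can be sketched on nested Python lists. This is an illustrative stand-in, not the tensor implementation under review; the helpers `ndim` and `total` are hypothetical and assume non-empty, rectangular nested lists.

```python
def ndim(value):
    """Nesting depth of a rectangular, non-empty nested list (0 for a scalar)."""
    depth = 0
    while isinstance(value, list):
        value = value[0]
        depth += 1
    return depth


def total(value):
    """Sum every scalar inside a nested list."""
    if isinstance(value, list):
        return sum(total(item) for item in value)
    return value


def sum_rightmost(value, dim):
    """Sum out the `dim` rightmost dimensions; float('inf') sums out all of them."""
    depth = ndim(value)
    if dim >= depth:
        return total(value)
    keep = depth - dim  # number of leading dimensions left untouched

    def reduce_below(item, levels):
        if levels == 0:
            return total(item)
        return [reduce_below(child, levels - 1) for child in item]

    return reduce_below(value, keep)
```

With `x = [[1, 2, 3], [4, 5, 6]]`, `sum_rightmost(x, 1)` sums within each row, while `sum_rightmost(x, float('inf'))` reduces to a single number.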



def site_is_discrete(name, site):
    return getattr(site["fn"], "enumerable", False)
def _iter_discrete_filter(name, msg):
Member

Can we remove the name from the function args?

Member Author

I can't remove this arg without changes to Trace.compute_batch_log_pdf(). Let's clean this up later when we refactor the Trace.compute_() methods.

from .poutine import Messenger, Poutine


def _iter_discrete_filter(name, msg):
Member

Same here. Can we remove name?

Member Author

Done.

Member Author

fritzo commented Feb 24, 2018

Woo hoo, tests finally pass!

@@ -25,9 +29,9 @@ class ELBO(object):

     def __init__(self,
                  num_particles=1,
-                 enum_discrete=False):
+                 max_iarange_nesting=float('inf')):
Collaborator

is it weird to mix floats and ints?

Member Author

It seems pretty safe to me. We're taking care to consume this number only through comparisons (in functions like sum_rightmost() and check_sites() in #806), and we only read its value if it is less than some other finite number. This float('inf') solution seems cleaner to me than alternatives like INT_MAX or -1, which would need logic that is even less readable.
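
A small sketch of the "comparison only" discipline described here. The function `effective_nesting` is hypothetical; it shows a float sentinel that is compared against an integer and returned only when it is the smaller, finite value.

```python
def effective_nesting(available_dims, max_iarange_nesting=float("inf")):
    """Return the usable number of dims, reading the limit only if it is smaller."""
    if max_iarange_nesting < available_dims:
        return max_iarange_nesting
    return available_dims


default_result = effective_nesting(3)      # limit is inf, so it is never returned
limited_result = effective_nesting(3, 1)   # finite limit wins

# Using the sentinel as a count instead of comparing it fails loudly.
try:
    range(float("inf"))
    used_as_count = True
except TypeError:
    used_as_count = False
```

Because the infinite default is never returned, callers still receive an ordinary integer in both cases.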


def config_enumerate(guide=None, default="sequential"):
    """
    Configures each enumerable site a guide to enumerate with given method,
Collaborator

should the docstring mention that it doesn't override?

Member Author

Sure
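
The "doesn't override" behaviour can be pictured with plain dicts: a default is filled in only where a site has no `"enumerate"` entry, so an explicit `None` (Monte Carlo) is preserved. This is a sketch under that assumption; `apply_default` is a hypothetical helper, not the function in this PR.

```python
def apply_default(sites, default="sequential"):
    """Fill in site['infer']['enumerate'] only where the key is absent."""
    for site in sites:
        infer = site.setdefault("infer", {})
        # setdefault leaves existing entries alone, including an explicit None.
        infer.setdefault("enumerate", default)
    return sites


sites = apply_default([
    {"name": "x"},
    {"name": "y", "infer": {"enumerate": "parallel"}},
    {"name": "z", "infer": {"enumerate": None}},
])
```

After the call, only `x` has picked up the default; `y` and `z` keep the values they were given.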

weight.dim() > 0 and \
weight.size(0) > 1
# iterate over a bag of traces, one trace per particle
for scale, guide_trace in iter_discrete_traces("flat", self.max_iarange_nesting, guide, *args, **kwargs):
Collaborator

so beautiful
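
The loop shape above, iterating over `(scale, trace)` pairs, can be sketched without Pyro. The generator below is a hypothetical toy: it assumes independent discrete sites given as lists of probabilities and uses a queue of partial assignments to visit every joint value, with `scale` equal to that value's probability.

```python
from collections import deque


def iter_assignments(sites):
    """Yield (scale, assignment) for every joint value of the discrete sites.

    `sites` maps a site name to its list of outcome probabilities.
    """
    names = list(sites)
    queue = deque([(1.0, {})])  # partial assignments still to be extended
    while queue:
        scale, partial = queue.popleft()
        if len(partial) == len(names):
            yield scale, partial
            continue
        name = names[len(partial)]
        for value, prob in enumerate(sites[name]):
            queue.append((scale * prob, {**partial, name: value}))


sites = {"x": [0.5, 0.5], "y": [0.1, 0.2, 0.7]}
traces = list(iter_assignments(sites))
expected_y = sum(scale * assignment["y"] for scale, assignment in traces)
```

Weighting each term by `scale` turns the sum over enumerated assignments into an exact expectation rather than a sampled estimate.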

from .poutine import Messenger, Poutine


def _iter_discrete_filter(msg):
Collaborator

this is defined in two places

Collaborator

this = _iter_discrete_filter

Member Author

They differ, but I'll rename them to make that clear...


# Ensure enumeration happens at an available tensor dimension.
event_dim = len(msg["fn"].event_shape)
actual_dim = value.dim() - event_dim - 1
Collaborator

can we have one or two comments here? I find this confusing

actual_dim = value.dim() - event_dim - 1

Member Author

Ok, I've simplified and added some comments.
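
One possible reading of the `actual_dim` arithmetic, worked on shape tuples rather than tensors. This sketch assumes, without confirmation from the diff shown here, that enumerated values occupy the leftmost dimension and that the last `event_dim` dimensions are event dimensions; `place_enum_dim` and `target_dim` are hypothetical names.

```python
def place_enum_dim(value_shape, event_dim, target_dim):
    """Pad the shape so `target_dim` batch dims sit right of the leftmost dim.

    Assumption: the leftmost dim indexes enumerated values and the last
    `event_dim` dims belong to the event shape.
    """
    # Batch dims currently between the enumeration dim and the event dims.
    actual_dim = len(value_shape) - event_dim - 1
    if actual_dim > target_dim:
        raise ValueError("enumeration dim would collide with an existing batch dim")
    padding = (1,) * (target_dim - actual_dim)
    return value_shape[:1] + padding + value_shape[1:]
```

For example, a shape `(3,)` with no event dims and `target_dim=2` would become `(3, 1, 1)`, leaving two singleton batch dimensions to the right of the enumeration dimension.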

@@ -49,6 +50,7 @@ def model():
def test_iter_discrete_traces_vector(graph_type):
Collaborator

what's the deal with test_iter_discrete_traces_vector?

Member Author

Fixed and removed the @xfail

Collaborator

@martinjankowiak martinjankowiak left a comment

lgtm!

great work!!!

Member Author

fritzo commented Feb 25, 2018

Finally ready to merge!

@martinjankowiak martinjankowiak merged commit 1bbfb48 into dev Feb 25, 2018
@fritzo fritzo mentioned this pull request Feb 28, 2018
4 tasks
@fritzo fritzo deleted the enumerate-parallel branch March 6, 2018 23:51

4 participants