-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement parallel enumeration over discrete sample sites #776
Conversation
d352788
to
c3c0c80
Compare
@eb8680 Could you take a look at the new interface using |
It mostly looks fine, but maybe we should change the name of |
626c59b
to
37b6ac9
Compare
tests/infer/test_enum.py
Outdated
with xfail_if_not_implemented(): | ||
inference.step(data) | ||
|
||
|
||
@pytest.mark.parametrize("enum_discrete", [None, "sequential", "parallel"]) | ||
@pytest.mark.parametrize("trace_graph", [False, True], ids=["dense", "flat"]) | ||
def test_bern_elbo_gradient(enum_discrete, trace_graph): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This and the following test use analytic grad(kl_divergence(-, -), -)
to test correctness of ELBO gradients.
@@ -190,9 +190,6 @@ def _get_traces(self, model, guide, *args, **kwargs): | |||
""" | |||
|
|||
for i in range(self.num_particles): | |||
if self.enum_discrete: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Enumeration is now controlled by hints in the site['infer']
dict. If an inference algorithm does not implement enumeration, it can safely ignore those hints and sample rather than enumerate.
pyro/infer/enum.py
Outdated
q_fn = poutine.queue(fn, queue=queue) | ||
full_trace = poutine.trace( | ||
q_fn, graph_type=graph_type).get_trace(*args, **kwargs) | ||
q_fn = poutine.queue(fn, queue=queue, escape_fn=_iter_discrete_escape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be outside the while loop replacing q_fn
above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks right to me. Done.
pyro/distributions/util.py
Outdated
@@ -90,40 +90,58 @@ def sum_rightmost(value, dim): | |||
""" | |||
Sum out ``dim`` many rightmost dimensions of a given tensor. | |||
|
|||
If ``dim`` is 0, no dimensions are summed out. | |||
If ``dim`` is ``float('inf')``, then all dimensions are summed out. | |||
If ``dim`` is 1, the leftmost 1 dimension is summed out. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leftmost
-> rightmost
and reverse below.
|
||
|
||
def site_is_discrete(name, site): | ||
return getattr(site["fn"], "enumerable", False) | ||
def _iter_discrete_filter(name, msg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove the name from the function args?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't remove this arg without changes to Trace.compute_batch_log_pdf()
. Let's clean this up later when we refactor the Trace.compute_()
methods.
pyro/poutine/enumerate_poutine.py
Outdated
from .poutine import Messenger, Poutine | ||
|
||
|
||
def _iter_discrete_filter(name, msg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. Can we remove name
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Woo hoo, tests finally pass! |
@@ -25,9 +29,9 @@ class ELBO(object): | |||
|
|||
def __init__(self, | |||
num_particles=1, | |||
enum_discrete=False): | |||
max_iarange_nesting=float('inf')): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it weird to mix floats and ints?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems pretty safe to me. We're taking care to consume this number only through comparisons (in functions like sum_rightmost()
and check_sites()
in #806 ) and we only read its value if it is less than some other finite number. This float('inf')
solution seems cleaner to me than the alternatives like INT_MAX
or -1
for which we would need logic that is even less readable.
|
||
def config_enumerate(guide=None, default="sequential"): | ||
""" | ||
Configures each enumerable site a guide to enumerate with given method, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should the docstring mention that it doesn't override?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure
weight.dim() > 0 and \ | ||
weight.size(0) > 1 | ||
# iterate over a bag of traces, one trace per particle | ||
for scale, guide_trace in iter_discrete_traces("flat", self.max_iarange_nesting, guide, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so beautiful
pyro/poutine/enumerate_poutine.py
Outdated
from .poutine import Messenger, Poutine | ||
|
||
|
||
def _iter_discrete_filter(msg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is defined in two places
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this = _iter_discrete_filter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They differ, but I'll rename them to make that clear...
pyro/poutine/enumerate_poutine.py
Outdated
|
||
# Ensure enumeration happens at an available tensor dimension. | ||
event_dim = len(msg["fn"].event_shape) | ||
actual_dim = value.dim() - event_dim - 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we have one or two comments here i find this confusing
actual_dim = value.dim() - event_dim - 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I've simplified and added some comments.
@@ -49,6 +50,7 @@ def model(): | |||
def test_iter_discrete_traces_vector(graph_type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the deal with test_iter_discrete_traces_vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed and removed the @xfail
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
great work!!!
Finally ready to merge! |
See Design Doc | Closes #742 | Replaces #227
This implements parallel enumeration for discrete sample sites. The new notation is
This allows finer-grained control over each discrete sample site, e.g. we can freely mix monte carlo sampling, sequential enumeration, and parallel enumeration:
Additionally we can use the
@config_enumerate
decorator to annotate an entire guide:We can also set parallel enumeration by default
Tasks
max_iarange_nesting
batch_log_pdf.shape
errors