Consolidate escape and enumerate into continuation, centralize _get_traces/_traces logic #950
…ppoutine and poutine.indep
Looks very clean so far.
Any progress on this? It would be nice to be able to use this, and I'll end up pulling pieces out of this PR until it merges.
I guess this PR can be merged mostly as-is after the merge conflicts are fixed.
...still reviewing continuation logic...
:param num_particles: number of non-parallel importance samples to take
:param graph_type: the type of the graph, e.g. "flat" or "dense".
:param max_iarange_nesting: maximum depth of iaranges
:returns: An iterator over model/guide trace pairs
I guess this returns a callable that returns a generator?
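To make the question concrete, here is a minimal sketch of the "callable that returns a generator" shape. The names `iter_traces_sketch` and `_get_traces` are hypothetical stand-ins, and plain Python functions take the place of Pyro models and guides; this is not the PR's implementation.

```python
def iter_traces_sketch(num_particles):
    """Hypothetical stand-in for iter_importance_traces: configure once, call later."""
    def _get_traces(model, guide, *args, **kwargs):
        # Calling _get_traces does not run anything yet; it returns a generator
        # that lazily yields one (model result, guide result) pair per particle.
        for _ in range(num_particles):
            guide_out = guide(*args, **kwargs)
            model_out = model(*args, **kwargs)
            yield model_out, guide_out
    return _get_traces


# Stage 1: configuration produces a callable.
get_traces = iter_traces_sketch(num_particles=3)
# Stage 2: calling it with model/guide/args produces a generator.
pairs = list(get_traces(lambda x: x + 1, lambda x: x * 2, 10))
```

If this is the intended shape, the `:returns:` line would more accurately describe a callable than an iterator.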
check_site_shape(site, self.max_iarange_nesting)

yield model_trace, guide_trace
return iter_importance_traces(num_particles=self.num_particles,
nit: it might clean up style throughout uses of iter_importance_traces if we consistently introduced a temporary get_traces:
get_traces = iter_importance_traces(num_particles=self.num_particles, graph_type="flat")
return get_traces(model, guide, *args, **kwargs)
return ContinuationPoutine(fn, escape_fn, util.escape_cont_fn)


def enum(fn, first_available_dim):
Do we really want to expose this? I'm tempted to move all the enumeration-related helpers to pyro.infer.enum. Since that library is undocumented (i.e. absent from sphinx), we can reduce guarantees of backwards compatibility.
This is just for API compatibility during the refactor, will remove later
@@ -126,6 +143,25 @@ def condition(fn, data):
return ConditionPoutine(fn, data=data)


def continuation(fn, escape_fn, continuation_fn, first_available_dim=None):
Do we really want to expose this? It seems weird exposing a detail like first_available_dim in the poutine interface, and I can't think of examples beyond EnumeratePoutine.
Also nit: argument should be named cont_fn.
I was just following the convention of the rest of poutine/__init__.py in exposing all constructor args in the alias, will remove it and rename continuation_fn.
@@ -128,3 +128,40 @@ def all_escape(trace, msg):
return (msg["type"] == "sample") and \
(not msg["is_observed"]) and \
(msg["name"] not in trace)


def broadcast_enum_filter(msg):
It seems like these helpers would naturally live in pyro.infer.enum, where most of the enumeration logic lives.
return msg["infer"].get("enumerate") == "parallel"


def broadcast_enum_cont(msg):
ditto
"""
Increments the next available expansion dimension.
"""
if "next_available_dim" in msg["infer"] and \
This seems pretty specific for default behavior.
It's necessary in this iteration of the refactor for escape and enum to play nicely together, I think
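For readers following along, the two predicates quoted in the hunks above can be exercised on hand-built messages. This is only a sketch: a plain `set` of site names stands in for a Pyro trace and plain dicts stand in for messages (both are assumptions), and the function bodies mirror the quoted lines rather than the full file.

```python
def all_escape(trace, msg):
    # Mirrors the quoted hunk: escape at unobserved sample sites not yet in the trace.
    return (msg["type"] == "sample") and \
        (not msg["is_observed"]) and \
        (msg["name"] not in trace)


def broadcast_enum_filter(msg):
    # Mirrors the quoted hunk: select sites marked for parallel enumeration.
    return msg["infer"].get("enumerate") == "parallel"


# Assumption: a set of already-visited site names stands in for the trace.
seen_sites = {"x"}
latent = {"type": "sample", "name": "z", "is_observed": False,
          "infer": {"enumerate": "parallel"}}
observed = {"type": "sample", "name": "y", "is_observed": True, "infer": {}}

escape_latent = all_escape(seen_sites, latent)
escape_observed = all_escape(seen_sites, observed)
enum_latent = broadcast_enum_filter(latent)
enum_observed = broadcast_enum_filter(observed)
```

Both predicates depend only on the message (plus the trace for `all_escape`), which is what makes it plausible to relocate them without touching the messenger classes.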
Generalizes and replaces EscapeMessenger and EnumerateMessenger.
"""
def __init__(self, escape_fn, cont_fn, first_available_dim):
It seems like it would be cleaner to split this into two classes: a base class ContinuationMessenger to handle escape_fn and cont_fn, and a derived class EnumerateMessenger to handle first_available_dim and next_available_dim. It seems conflated to manage that dimension logic in e.g. poutine.escape. Also it will be easier to refactor dimension logic if we isolate that to the messenger that needs it, i.e. EnumerateMessenger.
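A rough sketch of the proposed split, for illustration only. The class names carry a `Sketch` suffix and the `process` method is hypothetical; none of this uses Pyro's real Messenger base class or message protocol, and incrementing the dimension by one is an assumption based on the quoted docstring.

```python
class ContinuationMessengerSketch:
    """Hypothetical base class: knows only about escape_fn and cont_fn."""

    def __init__(self, escape_fn, cont_fn):
        self.escape_fn = escape_fn
        self.cont_fn = cont_fn

    def process(self, msg):
        # Hand the message to the continuation only when escape_fn fires.
        if self.escape_fn(msg):
            return self.cont_fn(msg)
        return msg


class EnumerateMessengerSketch(ContinuationMessengerSketch):
    """Hypothetical derived class: the only place dimension bookkeeping lives."""

    def __init__(self, escape_fn, cont_fn, first_available_dim):
        super().__init__(escape_fn, cont_fn)
        self.next_available_dim = first_available_dim

    def process(self, msg):
        if self.escape_fn(msg):
            # Record the dimension for this site, then advance the counter.
            msg["infer"]["next_available_dim"] = self.next_available_dim
            self.next_available_dim += 1
        return super().process(msg)


is_parallel = lambda msg: msg["infer"].get("enumerate") == "parallel"
mark_escaped = lambda msg: dict(msg, escaped=True)

base = ContinuationMessengerSketch(is_parallel, mark_escaped)
enum = EnumerateMessengerSketch(is_parallel, mark_escaped, first_available_dim=0)
out = enum.process({"name": "z", "infer": {"enumerate": "parallel"}})
```

With this layout, an escape-only wrapper would construct the base class and never see `first_available_dim`.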
I kept it this way because this is mid-refactor; in the next version (if I ever figure out dependency tracking...) ContinuationMessenger will maintain an internal stack of IndepMessengers to mark multiple invocations of the rest of the program independent. These will be used for both escape and enumeration.
Basically, there's no reason for ContinuationMessenger to exist if it's not handling this logic.
Ok, that's fine if it's mid-refactor. The main thing I'm trying to avoid is the leaking of specific details (like first_available_dim) into general interfaces. I found this to be awkward in e.g. poutine.queue which takes args num_samples and max_tries. I'd like to make a general effort to keep specific details of derived classes from leaking into base classes, so as to avoid all derived classes having to know about details of their siblings (and substitute oo with fp or your favorite abstraction mechanism).
Also btw it sure seems like we're getting lots of circular dependencies in this PR, e.g. importing pyro.infer from within submodules of pyro.infer. Feel free to move things from some_module/__init__.py to some_module/util.py to break these dependencies.
Part of refactoring mentioned in design doc... Will update this description with more details today