Consolidate escape and enumerate into continuation, centralize _get_traces/_traces logic #950

Closed
wants to merge 26 commits into from

Member

eb8680 commented Mar 29, 2018

Part of refactoring mentioned in design doc... Will update this description with more details today

eb8680 added the WIP label Mar 29, 2018
Member

fritzo commented Mar 29, 2018

Looks very clean so far.

Member

fritzo commented Apr 4, 2018

Any progress on this? It would be nice to be able to use this, and I'll end up pulling pieces out of this PR until it merges.

Member Author

eb8680 commented Apr 4, 2018

> Any progress on this? It would be nice to be able to use this, and I'll end up pulling pieces out of this PR until it merges.

I guess this PR can be merged mostly as-is after the merge conflicts are fixed.

eb8680 requested a review from fritzo April 5, 2018 01:35
eb8680 self-assigned this Apr 5, 2018
eb8680 added awaiting review and removed WIP labels Apr 5, 2018
Member

fritzo left a comment

...still reviewing continuation logic...

:param num_particles: number of non-parallel importance samples to take
:param graph_type: the type of the graph, e.g. "flat" or "dense".
:param max_iarange_nesting: maximum depth of iaranges
:returns: An iterator over model/guide trace pairs
Member

I guess this returns a callable that returns a generator?

check_site_shape(site, self.max_iarange_nesting)

yield model_trace, guide_trace
return iter_importance_traces(num_particles=self.num_particles,
Member

nit: it might clean up style throughout uses of iter_importance_traces if we consistently introduced a temporary get_traces:

get_traces = iter_importance_traces(num_particles=self.num_particles, graph_type="flat")
return get_traces(model, guide, *args, **kwargs)
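To make the suggested pattern concrete, here is a minimal sketch. It assumes that iter_importance_traces is a factory that returns a generator function, as the earlier question guesses. The trace dicts are hypothetical stand-ins rather than Pyro's Trace objects, and the body is not the PR's implementation.

```python
def iter_importance_traces(num_particles=1, graph_type="flat"):
    """Toy factory: returns a function that yields (model_trace, guide_trace) pairs."""

    def _get_traces(model, guide, *args, **kwargs):
        for _ in range(num_particles):
            # Stand-in "traces": plain dicts recording each function's return value.
            guide_trace = {"graph_type": graph_type, "value": guide(*args, **kwargs)}
            model_trace = {"graph_type": graph_type, "value": model(*args, **kwargs)}
            yield model_trace, guide_trace

    return _get_traces


# Configure once, then call with the model, the guide, and their arguments.
get_traces = iter_importance_traces(num_particles=2, graph_type="flat")
pairs = list(get_traces(lambda x: x + 1, lambda x: x * 2, 3))
```

Splitting configuration from invocation this way keeps each call site to two short lines, which is the readability gain the nit is after.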

return ContinuationPoutine(fn, escape_fn, util.escape_cont_fn)


def enum(fn, first_available_dim):
Member

Do we really want to expose this? I'm tempted to move all the enumeration-related helpers to pyro.infer.enum. Since that library is undocumented (i.e. absent from sphinx), we can reduce guarantees of backwards compatibility.

Member Author

This is just for API compatibility during the refactor, will remove later

@@ -126,6 +143,25 @@ def condition(fn, data):
    return ConditionPoutine(fn, data=data)


def continuation(fn, escape_fn, continuation_fn, first_available_dim=None):
Member

Do we really want to expose this? It seems weird exposing a detail like first_available_dim in the poutine interface, and I can't think of examples beyond EnumeratePoutine

Also nit: argument should be named cont_fn

Member Author

I was just following the convention of the rest of poutine/__init__.py in exposing all constructor args in the alias, will remove it and rename continuation_fn.
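For illustration only, a toy sketch of what the trimmed alias could look like: the argument is named cont_fn, first_available_dim is not exposed, and escape is expressed as a continuation whose cont_fn raises. Everything here is an assumption: the list-of-messages runtime, the NonlocalExit stand-in, and the body of escape_cont_fn are hypothetical and are not Pyro's poutine machinery.

```python
class NonlocalExit(Exception):
    """Stand-in exception that abandons execution at a site."""

    def __init__(self, site):
        super().__init__(site["name"])
        self.site = site


def escape_cont_fn(msg):
    # Assumed behaviour: escaping means exiting nonlocally at this site.
    raise NonlocalExit(msg)


def continuation(fn, escape_fn, cont_fn):
    """Toy alias: fn handles one message; cont_fn runs first when escape_fn fires."""

    def wrapped(messages):
        for msg in messages:
            if escape_fn(msg):
                cont_fn(msg)
            fn(msg)

    return wrapped


def escape(fn, escape_fn):
    # Escape is the special case whose continuation raises.
    return continuation(fn, escape_fn, escape_cont_fn)


handled = []
run = escape(handled.append, lambda msg: msg["name"] == "b")
try:
    run([{"name": "a"}, {"name": "b"}, {"name": "c"}])
    escaped_at = None
except NonlocalExit as exit_:
    escaped_at = exit_.site["name"]
```

In this toy, only site "a" is handled before the exit at "b"; no dimension bookkeeping appears in the alias signature.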

@@ -128,3 +128,40 @@ def all_escape(trace, msg):
    return (msg["type"] == "sample") and \
        (not msg["is_observed"]) and \
        (msg["name"] not in trace)


def broadcast_enum_filter(msg):
Member

It seems like these helpers would naturally live in pyro.infer.enum, where most of the enumeration logic lives.

return msg["infer"].get("enumerate") == "parallel"
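
The two predicates quoted in this hunk are plain functions over a message dict, so they can be exercised without any inference machinery. A small sketch, assuming the return line shown is the whole body of broadcast_enum_filter and using hypothetical message dicts and a dict as a stand-in trace:

```python
def all_escape(trace, msg):
    return (msg["type"] == "sample") and \
        (not msg["is_observed"]) and \
        (msg["name"] not in trace)


def broadcast_enum_filter(msg):
    return msg["infer"].get("enumerate") == "parallel"


# Hypothetical inputs; only the keys the predicates read are filled in.
trace = {"x": None}  # stand-in for a trace that already contains site "x"
latent = {"type": "sample", "name": "z", "is_observed": False,
          "infer": {"enumerate": "parallel"}}
observed = {"type": "sample", "name": "y", "is_observed": True, "infer": {}}
seen = {"type": "sample", "name": "x", "is_observed": False, "infer": {}}

results = {
    "latent": (all_escape(trace, latent), broadcast_enum_filter(latent)),
    "observed": (all_escape(trace, observed), broadcast_enum_filter(observed)),
    "seen": (all_escape(trace, seen), broadcast_enum_filter(seen)),
}
```

Only the new, unobserved site marked for parallel enumeration satisfies both predicates; the observed site and the site already in the trace satisfy neither.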


def broadcast_enum_cont(msg):
Member

ditto

"""
Increments the next available expansion dimension.
"""
if "next_available_dim" in msg["infer"] and \
Member

This seems pretty specific for default behavior.

Member Author

It's necessary in this iteration of the refactor for escape and enum to play nicely together, I think


Generalizes and replaces EscapeMessenger and EnumerateMessenger.
"""
def __init__(self, escape_fn, cont_fn, first_available_dim):
Member

It seems like it would be cleaner to split this into two classes: a base class ContinuationMessenger to handle escape_fn and cont_fn, and a derived class EnumerateMessenger to handle first_available_dim and next_available_dim. It seems conflated to manage that dimension logic in e.g. poutine.escape. Also it will be easier to refactor dimension logic if we isolate that to the messenger that needs it, i.e. EnumerateMessenger.
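
A rough sketch of that split, using stand-in classes rather than Pyro's actual messengers. The process_message method, the private helper names, and the direction in which the dimension counter moves are all assumptions made for illustration.

```python
class ContinuationMessenger:
    """Base class: knows only about escape_fn and cont_fn."""

    def __init__(self, escape_fn, cont_fn):
        self.escape_fn = escape_fn
        self.cont_fn = cont_fn

    def process_message(self, msg):
        # Hand the message to the continuation whenever the predicate fires.
        if self.escape_fn(msg):
            self.cont_fn(msg)
        return msg


class EnumerateMessenger(ContinuationMessenger):
    """Derived class: the only place that tracks enumeration dimensions."""

    def __init__(self, first_available_dim):
        super().__init__(escape_fn=self._is_parallel, cont_fn=self._assign_dim)
        self.next_available_dim = first_available_dim

    @staticmethod
    def _is_parallel(msg):
        return msg["infer"].get("enumerate") == "parallel"

    def _assign_dim(self, msg):
        # Record the dimension for this site, then advance the counter.
        msg["infer"]["next_available_dim"] = self.next_available_dim
        self.next_available_dim += 1


messenger = EnumerateMessenger(first_available_dim=1)
enum_site = {"name": "z", "infer": {"enumerate": "parallel"}}
plain_site = {"name": "w", "infer": {}}
messenger.process_message(enum_site)
messenger.process_message(plain_site)
```

With this layout, an escape-only messenger can be built from the base class alone and never sees first_available_dim.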

Member Author

eb8680 Apr 5, 2018

I kept it this way because this is mid-refactor; in the next version (if I ever figure out dependency tracking...) ContinuationMessenger will maintain an internal stack of IndepMessengers to mark multiple invocations of the rest of the program independent. These will be used for both escape and enumeration.

Basically, there's no reason for ContinuationMessenger to exist if it's not handling this logic.

Member

Ok, that's fine if it's mid-refactor. The main thing I'm trying to avoid is the leaking of specific details (like first_available_dim) into general interfaces. I found this to be awkward in e.g. poutine.queue which takes args num_samples and max_tries. I'd like to make a general effort to keep specific details of derived classes from leaking into base classes, so as to avoid all derived classes having to know about details of their siblings (and substitute oo with fp or your favorite abstraction mechanism).

Member

fritzo commented Apr 5, 2018

also btw it sure seems like we're getting lots of circular dependencies in this PR, e.g. importing pyro.infer from within submodules of pyro.infer. Feel free to move things from some_module/__init__.py to some_module/util.py to break these dependencies.

eb8680 added the WIP label Apr 5, 2018
eb8680 closed this Apr 11, 2018
jpchen deleted the only-continuation-poutine branch December 9, 2018 21:13
2 participants