Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a SplitReparam and use it in contrib.epidemiology #2495

Merged
merged 7 commits into from
May 20, 2020

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented May 20, 2020

Addresses #2426

This adds a SplitReparam reparameterizer to split a sample site tensor into multiple other tensors, as suggested by @fehiepsi. The motivating use case (also implemented in this PR) is splitting Haar-reparameterized sample sites into low- and high-frequency parts, then adding the low frequency parts to the full_mass matrix in HMC.

Note this reparameterizer is quite limited: it cannot generate samples because there is no standard way to split a distribution into multiple independent distribution. However in HMC and SVI require only .log_prob() to be implemented, not .sample(). Actually this PR needed to change HMC internals to avoid calling .sample() inadvertently during prototype tracing.

This PR also rebalances some test since unit was taking 37 minutes vs integration_batch_1 taking 15 minutes.

Tested

  • unit tests
  • added epidemiology tests
  • added to epidemiology examples and tests

Comment on lines +387 to +388
prototype_model = poutine.trace(InitMessenger(init_strategy)(model))
model_trace = prototype_model.get_trace(*model_args, **model_kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi could you please review changes in this file? The changes are needed to

  1. avoid calling .sample() in constructing the prototype model_trace, and
  2. avoid duplicating expensive initialization work by reusing that model_trace as the first trace in _find_valid_initial_params().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, then with e.g. init_to_uniform strategy, we still call sample but won't for some other strategies? If so, the change looks great to me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct, init_to_uniform still calls sample. Thanks for reviewing!

@fritzo fritzo removed the WIP label May 20, 2020
Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm! (didn't review mcmc refactor)

pyro/contrib/epidemiology/compartmental.py Outdated Show resolved Hide resolved
pyro/infer/reparam/split.py Show resolved Hide resolved
return torch.zeros(()).expand(batch_shape)

def sample(self, sample_shape=torch.Size()):
raise NotImplementedError("SplitReparam does not support sampling")
Copy link
Member

@fehiepsi fehiepsi May 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: _ImproperUniform does not support sampling

I think we can move this to the main distributions module. Probably with an sample_fn arg to generate prototype samples for HMC. We might add a warning to the docs mention that sample_fn does not actually generate uniform samples in the support, but only to generate protype values for the inference. WDYT?

Copy link
Member Author

@fritzo fritzo May 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: _ImproperUniform does not support sampling

This user-facing exception mentions the user-facing class; _ImproperUniform is an implementation detail.

I think we can move this to the main distributions module.

I feel like this is too bespoke for general use. Until we find another use case I'd prefer to keep it private. We could implement this other ways, e.g. Delta(nan).mask(False) would raise a Nan error rather than a NotImplementedError; I found this way helpful because it raised an exception early.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with an sample_fn arg to generate prototype samples for HMC.

I believe this functionality is already cleanly accomplished by InitMessenger. Do you have another use case in mind?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, I think dist.Foo(...).mask(False) can serve the same purpose and initial values can be drawn directly from Foo distribution. This class is useful when we don't know the support before-hand.

@martinjankowiak martinjankowiak merged commit 8cc51fb into dev May 20, 2020
@fritzo fritzo mentioned this pull request Jun 2, 2020
2 tasks
@fritzo fritzo deleted the split-reparam branch June 5, 2020 15:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants