Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to stop PyroModules from sharing parameters #3149

Merged
merged 18 commits into from
Nov 9, 2022

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Oct 23, 2022

Replaces #2996

This PR adds two small related features for easier Pyro-PyTorch integration:

  1. A __call__ method for the base pyro.infer.elbo.ELBO that binds ELBO instances to specific nn.Module model/guide pairs in a Module that exposes their PyTorch parameters
  2. A global setting (off by default for full backwards compatibility) that prevents PyroModule instances from sharing parameter values with one another through the global Pyro parameter store, and a primitive and context manager for toggling it. One context where this is useful is for workflows that involve multiple models and autoguides with overlapping parameter names.

An edge case I haven't handled here is the behavior under the new local parameter setting of regular pyro.param statements (as opposed to PyroParam) within a PyroModule that don't have their data associated with any underlying nn.Module. I've raised an error rather than attempt to get this working, since I think it's usually a PyroModule programming anti-pattern to mix global and local parameter states in this way.

I am also hopeful that these changes will simplify the use of Pyro with the PyTorch JIT and other PyTorch compilers, but I have left testing this for future work, since I suspect it will require additional engineering that is out of scope for this PR.

Tasks:

Tested:

  • Added test cases to some existing PyroModule tests in tests/nn/test_module.py
  • Added a new test that checks the behavior of these two features together
  • Add a test demonstrating expected failing behavior with global pyro.param

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice rethinking towards more idiomatic PyTorch!

pyro/infer/elbo.py Outdated Show resolved Hide resolved
pyro/infer/elbo.py Show resolved Hide resolved
pyro/infer/elbo.py Outdated Show resolved Hide resolved

def __enter__(self):
if not self.active and _is_module_local_param_enabled():
self._param_ctx = pyro.get_param_store().scope(state=self.param_state)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Persisting self.param_state like this (and in _pyro_set_supermodule below) seems to be a reasonable solution for the behavior of vanilla pyro.param statements. Values of these parameters are now local to the outermost PyroModule in a nested PyroModule instance.

pyro/nn/module.py Outdated Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey thanks for your patience in reviewing this subtle PR. The ELBOModule changes look clean. I'm still working through understanding the module_local_param changes...

@@ -562,3 +562,35 @@ def validation_enabled(is_validate=True):
dist.enable_validation(distribution_validation_status)
infer.enable_validation(infer_validation_status)
poutine.enable_validation(poutine_validation_status)


def enable_module_local_param(is_enabled: bool = False) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to make it super clear that users can now decide between (i) a global param store or (ii) local nn.Module style parameters. Like maybe

with pyro.param_storage("local"): ...
with pyro.param_storage("global"): ...

or pyro.disable_param_store(True) or pyro.enable_param_store(False). Whatever we call it I think it would be good in the first docstring sentence to mention the phrase "param store" and the word "nn.Module".

if _is_module_local_param_enabled():
with pyro.get_param_store().scope(
state=self._pyro_context.param_state
) as vanilla_param_state:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Would another word for "vanilla" be "global" or "global-only" or "raw" or "nonmodule" or something? We might want to avoid "vanilla" because PyTorch users new to Pyro might think of "vanilla" as "an nn.Param attribute of an nn.Module".

:class:`~pyro.nn.PyroModule` s.

Users are therefore strongly encouraged to use this interface in conjunction
with :func:`~pyro.enable_module_local_param` which will override the default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: override -> disable or avoid?

pyro/nn/module.py Outdated Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after minor comment on .set() vs .context() in tests

tests/conftest.py Outdated Show resolved Hide resolved
tests/nn/test_module.py Outdated Show resolved Hide resolved
@fritzo fritzo merged commit 8b7e564 into dev Nov 9, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants