Add an option to stop PyroModules from sharing parameters #3149

Merged · 18 commits · Nov 9, 2022
2 changes: 2 additions & 0 deletions pyro/__init__.py
@@ -43,6 +43,7 @@
"condition",
"deterministic",
"do",
"enable_module_local_param",
"enable_validation",
"factor",
"get_param_store",
@@ -51,6 +52,7 @@
"log",
"markov",
"module",
"module_local_param_enabled",
"param",
"plate",
"plate",
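
The two new exports wrap a boolean flag registered with pyro.settings in pyro/nn/module.py (shown below). Their exact signatures are not part of this diff, so here is a minimal sketch that flips the switch through the underlying settings API instead:

    import pyro

    # Opt in to module-local parameters; PyroModule instances then stop
    # sharing parameters implicitly through the global param store.
    pyro.settings.set(module_local_params=True)
    assert pyro.settings.get("module_local_params")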
54 changes: 54 additions & 0 deletions pyro/infer/elbo.py
@@ -5,13 +5,26 @@
import warnings
from abc import ABCMeta, abstractmethod

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import is_validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape


class ELBOModule(torch.nn.Module):
def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"):
super().__init__()
self.model = model
self.guide = guide
self.elbo = elbo

def forward(self, *args, **kwargs):
return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)


class ELBO(object, metaclass=ABCMeta):
"""
:class:`ELBO` is the top-level interface for stochastic variational
@@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta):
:class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or
:class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.

.. note:: Derived classes now provide a more idiomatic PyTorch interface via
:meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s,
which is useful for integrating Pyro's variational inference tooling with
standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s
and the large ecosystem of libraries like PyTorch Lightning
and the PyTorch JIT that work with these interfaces::

model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)

elbo_ = pyro.infer.Trace_ELBO(num_particles=10)

# Fix the model/guide pair
elbo = elbo_(model, guide)

# perform any data-dependent initialization
elbo(data)

optim = torch.optim.Adam(elbo.parameters(), lr=0.001)

for _ in range(100):
optim.zero_grad()
loss = elbo(data)
loss.backward()
optim.step()

Note that Pyro's global parameter store may cause this new interface to
behave unexpectedly relative to standard PyTorch when working with
:class:`~pyro.nn.PyroModule` s.

Users are therefore strongly encouraged to use this interface in conjunction
with :func:`~pyro.enable_module_local_param`, which will override the default
implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances.

Reviewer comment (Member): nit: override -> disable or avoid?

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param int max_plate_nesting: Optional bound on max number of nested
@@ -86,6 +133,13 @@ def __init__(
self.jit_options = jit_options
self.tail_adaptive_beta = tail_adaptive_beta

def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ELBOModule:
"""
Given a model and guide, returns a :class:`~torch.nn.Module` which
computes the ELBO loss when called with arguments to the model and guide.
"""
return ELBOModule(model, guide, self)

def _guess_max_plate_nesting(self, model, guide, args, kwargs):
"""
Guesses max_plate_nesting by running the (model,guide) pair once
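
For reference, a runnable version of the docstring example above; the Model class here is an illustrative stand-in and is not defined anywhere in this PR:

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.nn import PyroModule, PyroSample

    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.loc = PyroSample(dist.Normal(0.0, 1.0))  # latent site "loc"

        def forward(self, data):
            loc = self.loc  # triggers pyro.sample("loc", ...)
            with pyro.plate("data", len(data)):
                return pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    model = Model()
    guide = pyro.infer.autoguide.AutoNormal(model)

    # Fix the model/guide pair; the first call performs any
    # data-dependent initialization of guide parameters.
    elbo = pyro.infer.Trace_ELBO(num_particles=10)(model, guide)
    data = torch.randn(8)
    elbo(data)

    optim = torch.optim.Adam(elbo.parameters(), lr=0.001)
    for _ in range(100):
        optim.zero_grad()
        loss = elbo(data)
        loss.backward()
        optim.step()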
42 changes: 41 additions & 1 deletion pyro/nn/module.py
@@ -23,6 +23,17 @@
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE

_MODULE_LOCAL_PARAMS: bool = False


@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS")
def _validate_module_local_params(value: bool) -> None:
assert isinstance(value, bool)


def _is_module_local_param_enabled() -> bool:
return pyro.settings.get("module_local_params")


class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
"""
@@ -178,15 +189,23 @@ def __init__(self):
self.active = 0
self.cache = {}
self.used = False
if _is_module_local_param_enabled():
self.param_state = {"params": {}, "constraints": {}}

def __enter__(self):
if not self.active and _is_module_local_param_enabled():
self._param_ctx = pyro.get_param_store().scope(state=self.param_state)

Author comment: Persisting self.param_state like this (and in _pyro_set_supermodule below) seems to be a reasonable solution for the behavior of vanilla pyro.param statements. Values of these parameters are now local to the outermost PyroModule in a nested PyroModule instance.

self.param_state = self._param_ctx.__enter__()
self.active += 1
self.used = True

def __exit__(self, type, value, traceback):
self.active -= 1
if not self.active:
self.cache.clear()
if _is_module_local_param_enabled():
self._param_ctx.__exit__(type, value, traceback)
del self._param_ctx

def get(self, name):
if self.active:
@@ -409,6 +428,8 @@ def named_pyro_params(self, prefix="", recurse=True):
yield elem

def _pyro_set_supermodule(self, name, context):
if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"):
self._check_module_local_param_usage()
self._pyro_name = name
self._pyro_context = context
for key, value in self._modules.items():
@@ -424,7 +445,26 @@ def _pyro_get_fullname(self, name):

def __call__(self, *args, **kwargs):
with self._pyro_context:
return super().__call__(*args, **kwargs)
result = super().__call__(*args, **kwargs)
if (
pyro.settings.get("validate_poutine")
and not self._pyro_context.active
and _is_module_local_param_enabled()
):
self._check_module_local_param_usage()
return result

def _check_module_local_param_usage(self) -> None:
self_nn_params = set(id(p) for p in self.parameters())
self_pyro_params = set(
id(p if not hasattr(p, "unconstrained") else p.unconstrained())
for p in self._pyro_context.param_state["params"].values()
)
if not self_pyro_params <= self_nn_params:
raise NotImplementedError(
"Support for global pyro.param statements in PyroModules "
"with local param mode enabled is not yet implemented."
)

def __getattr__(self, name):
# PyroParams trigger pyro.param statements.
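
A minimal sketch (not from the PR's test suite) of the behavior these changes enable: with module_local_params set, two instances of the same PyroModule no longer resolve a parameter name to one shared tensor in the global store. Note also that, per _check_module_local_param_usage above, raw global pyro.param statements inside a PyroModule raise NotImplementedError in this mode.

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.nn import PyroModule, PyroParam

    pyro.settings.set(module_local_params=True)

    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.loc = PyroParam(torch.zeros(2))

        def forward(self):
            return pyro.sample("x", dist.Normal(self.loc, 1.0).to_event(1))

    m1, m2 = Model(), Model()
    tr1 = pyro.poutine.trace(m1).get_trace()
    tr2 = pyro.poutine.trace(m2).get_trace()

    # Each instance now owns its "loc"; with the default global param
    # store, both traces would record the same tensor for this site.
    assert tr1.nodes["loc"]["value"] is not tr2.nodes["loc"]["value"]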
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -34,6 +34,7 @@ def pytest_runtest_setup(item):
if test_initialize_marker:
rng_seed = test_initialize_marker.kwargs["rng_seed"]
pyro.set_rng_seed(rng_seed)
pyro.settings.set(module_local_params=False)


def pytest_addoption(parser):