diff --git a/pyro/__init__.py b/pyro/__init__.py index b92471913e..bd3b143c58 100644 --- a/pyro/__init__.py +++ b/pyro/__init__.py @@ -43,6 +43,7 @@ "condition", "deterministic", "do", + "enable_module_local_param", "enable_validation", "factor", "get_param_store", @@ -51,6 +52,7 @@ "log", "markov", "module", + "module_local_param_enabled", "param", "plate", "plate", diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 3abe07d748..05ebd6f2a9 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -5,6 +5,8 @@ import warnings from abc import ABCMeta, abstractmethod +import torch + import pyro import pyro.poutine as poutine from pyro.infer.util import is_validation_enabled @@ -12,6 +14,17 @@ from pyro.util import check_site_shape +class ELBOModule(torch.nn.Module): + def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"): + super().__init__() + self.model = model + self.guide = guide + self.elbo = elbo + + def forward(self, *args, **kwargs): + return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs) + + class ELBO(object, metaclass=ABCMeta): """ :class:`ELBO` is the top-level interface for stochastic variational @@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta): :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`. + .. note:: Derived classes now provide a more idiomatic PyTorch interface via + :meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s, + which is useful for integrating Pyro's variational inference tooling with + standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s + and the large ecosystem of libraries like PyTorch Lightning + and the PyTorch JIT that work with these interfaces:: + + model = Model() + guide = pyro.infer.autoguide.AutoNormal(model) + + elbo_ = pyro.infer.Trace_ELBO(num_particles=10) + + # Fix the model/guide pair + elbo = elbo_(model, guide) + + # perform any data-dependent initialization + elbo(data) + + optim = torch.optim.Adam(elbo.parameters(), lr=0.001) + + for _ in range(100): + optim.zero_grad() + loss = elbo(data) + loss.backward() + optim.step() + + Note that Pyro's global parameter store may cause this new interface to + behave unexpectedly relative to standard PyTorch when working with + :class:`~pyro.nn.PyroModule` s. + + Users are therefore strongly encouraged to use this interface in conjunction + with :func:`~pyro.enable_module_local_param` which will override the default + implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances. + :param num_particles: The number of particles/samples used to form the ELBO (gradient) estimators. :param int max_plate_nesting: Optional bound on max number of nested @@ -86,6 +133,13 @@ def __init__( self.jit_options = jit_options self.tail_adaptive_beta = tail_adaptive_beta + def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ELBOModule: + """ + Given a model and guide, returns a :class:`~torch.nn.Module` which + computes the ELBO loss when called with arguments to the model and guide. 
+ """ + return ELBOModule(model, guide, self) + def _guess_max_plate_nesting(self, model, guide, args, kwargs): """ Guesses max_plate_nesting by running the (model,guide) pair once diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 7a87631e2e..8168461f0c 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -23,6 +23,17 @@ from pyro.ops.provenance import detach_provenance from pyro.poutine.runtime import _PYRO_PARAM_STORE +_MODULE_LOCAL_PARAMS: bool = False + + +@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS") +def _validate_module_local_params(value: bool) -> None: + assert isinstance(value, bool) + + +def _is_module_local_param_enabled() -> bool: + return pyro.settings.get("module_local_params") + class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))): """ @@ -178,8 +189,13 @@ def __init__(self): self.active = 0 self.cache = {} self.used = False + if _is_module_local_param_enabled(): + self.param_state = {"params": {}, "constraints": {}} def __enter__(self): + if not self.active and _is_module_local_param_enabled(): + self._param_ctx = pyro.get_param_store().scope(state=self.param_state) + self.param_state = self._param_ctx.__enter__() self.active += 1 self.used = True @@ -187,6 +203,9 @@ def __exit__(self, type, value, traceback): self.active -= 1 if not self.active: self.cache.clear() + if _is_module_local_param_enabled(): + self._param_ctx.__exit__(type, value, traceback) + del self._param_ctx def get(self, name): if self.active: @@ -409,6 +428,8 @@ def named_pyro_params(self, prefix="", recurse=True): yield elem def _pyro_set_supermodule(self, name, context): + if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"): + self._check_module_local_param_usage() self._pyro_name = name self._pyro_context = context for key, value in self._modules.items(): @@ -424,7 +445,26 @@ def _pyro_get_fullname(self, name): def __call__(self, *args, **kwargs): with self._pyro_context: - return super().__call__(*args, **kwargs) + result = super().__call__(*args, **kwargs) + if ( + pyro.settings.get("validate_poutine") + and not self._pyro_context.active + and _is_module_local_param_enabled() + ): + self._check_module_local_param_usage() + return result + + def _check_module_local_param_usage(self) -> None: + self_nn_params = set(id(p) for p in self.parameters()) + self_pyro_params = set( + id(p if not hasattr(p, "unconstrained") else p.unconstrained()) + for p in self._pyro_context.param_state["params"].values() + ) + if not self_pyro_params <= self_nn_params: + raise NotImplementedError( + "Support for global pyro.param statements in PyroModules " + "with local param mode enabled is not yet implemented." + ) def __getattr__(self, name): # PyroParams trigger pyro.param statements. 
diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 41198b9f2a..90b8af0eec 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -66,7 +66,97 @@ def forward(self, *args, **kwargs): svi.step(data) -def test_names(): +@pytest.mark.parametrize("local_params", [True, False]) +@pytest.mark.parametrize("num_particles", [1, 2]) +@pytest.mark.parametrize("vectorize_particles", [True, False]) +@pytest.mark.parametrize("Autoguide", [pyro.infer.autoguide.AutoNormal]) +def test_svi_elbomodule_interface( + local_params, num_particles, vectorize_particles, Autoguide +): + class Model(PyroModule): + def __init__(self): + super().__init__() + self.loc = nn.Parameter(torch.zeros(2)) + self.scale = PyroParam(torch.ones(2), constraint=constraints.positive) + self.z = PyroSample( + lambda self: dist.Normal(self.loc, self.scale).to_event(1) + ) + + def forward(self, data): + loc, log_scale = self.z.unbind(-1) + with pyro.plate("data"): + pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data) + + with pyro.settings.context(module_local_params=local_params): + data = torch.randn(5) + model = Model() + model(data) # initialize + + guide = Autoguide(model) + guide(data) # initialize + + elbo = Trace_ELBO( + vectorize_particles=vectorize_particles, num_particles=num_particles + ) + elbo = elbo(model, guide) + assert isinstance(elbo, torch.nn.Module) + assert set( + k[: -len("_unconstrained")] if k.endswith("_unconstrained") else k + for k, v in elbo.named_parameters() + ) == set("model." + k for k, v in model.named_pyro_params()) | set( + "guide." + k for k, v in guide.named_pyro_params() + ) + + adam = torch.optim.Adam(elbo.parameters(), lr=0.0001) + for _ in range(3): + adam.zero_grad() + loss = elbo(data) + loss.backward() + adam.step() + + guide2 = Autoguide(model) + guide2(data) # initialize + if local_params: + assert set(pyro.get_param_store().keys()) == set() + for (name, p), (name2, p2) in zip( + guide.named_parameters(), guide2.named_parameters() + ): + assert name == name2 + assert not torch.allclose(p, p2) + else: + assert set(pyro.get_param_store().keys()) != set() + for (name, p), (name2, p2) in zip( + guide.named_parameters(), guide2.named_parameters() + ): + assert name == name2 + assert torch.allclose(p, p2) + + +@pytest.mark.parametrize("local_params", [True, False]) +def test_local_param_global_behavior_fails(local_params): + class Model(PyroModule): + def __init__(self): + super().__init__() + self.global_nn_param = nn.Parameter(torch.zeros(2)) + + def forward(self): + global_param = pyro.param("_global_param", lambda: torch.randn(2)) + global_nn_param = pyro.param("global_nn_param", self.global_nn_param) + return global_param, global_nn_param + + with pyro.settings.context(module_local_params=local_params): + model = Model() + if local_params: + assert pyro.settings.get("module_local_params") + with pytest.raises(NotImplementedError): + model() + else: + assert not pyro.settings.get("module_local_params") + model() + + +@pytest.mark.parametrize("local_params", [True, False]) +def test_names(local_params): class Model(PyroModule): def __init__(self): super().__init__() @@ -86,34 +176,39 @@ def forward(self): self.p.v self.p.w - model = Model() - - # Check named_parameters. - expected = { - "x", - "y_unconstrained", - "m.u", - "p.v", - "p.w_unconstrained", - } - actual = set(name for name, _ in model.named_parameters()) - assert actual == expected - - # Check pyro.param names. 
- expected = {"x", "y", "m$$$u", "p.v", "p.w"} - with poutine.trace(param_only=True) as param_capture: - model() - actual = { - name - for name, site in param_capture.trace.nodes.items() - if site["type"] == "param" - } - assert actual == expected - - # Check pyro_parameters method - expected = {"x", "y", "m.u", "p.v", "p.w"} - actual = set(k for k, v in model.named_pyro_params()) - assert actual == expected + with pyro.settings.context(module_local_params=local_params): + model = Model() + + # Check named_parameters. + expected = { + "x", + "y_unconstrained", + "m.u", + "p.v", + "p.w_unconstrained", + } + actual = set(name for name, _ in model.named_parameters()) + assert actual == expected + + # Check pyro.param names. + expected = {"x", "y", "m$$$u", "p.v", "p.w"} + with poutine.trace(param_only=True) as param_capture: + model() + actual = { + name + for name, site in param_capture.trace.nodes.items() + if site["type"] == "param" + } + assert actual == expected + if local_params: + assert set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == expected + + # Check pyro_parameters method + expected = {"x", "y", "m.u", "p.v", "p.w"} + actual = set(k for k, v in model.named_pyro_params()) + assert actual == expected def test_delete(): @@ -258,7 +353,8 @@ def test_constraints(shape, constraint_): assert not hasattr(module, "x_unconstrained") -def test_clear(): +@pytest.mark.parametrize("local_params", [True, False]) +def test_clear(local_params): class Model(PyroModule): def __init__(self): super().__init__() @@ -272,28 +368,43 @@ def __init__(self): def forward(self): return [x.clone() for x in [self.x, self.m.weight, self.m.bias, self.p.x]] - assert set(pyro.get_param_store().keys()) == set() - m = Model() - state0 = m() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - - # mutate - for x in pyro.get_param_store().values(): - x.unconstrained().data += torch.randn(()) - state1 = m() - for x, y in zip(state0, state1): - assert not (x == y).all() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - - clear(m) - del m - assert set(pyro.get_param_store().keys()) == set() - - m = Model() - state2 = m() - assert set(pyro.get_param_store().keys()) == {"x", "m$$$weight", "m$$$bias", "p.x"} - for actual, expected in zip(state2, state0): - assert_equal(actual, expected) + with pyro.settings.context(module_local_params=local_params): + m = Model() + state0 = m() + + # mutate + for _, x in m.named_pyro_params(): + x.unconstrained().data += torch.randn(()) + state1 = m() + for x, y in zip(state0, state1): + assert not (x == y).all() + + if local_params: + assert set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == { + "x", + "m$$$weight", + "m$$$bias", + "p.x", + } + clear(m) + del m + assert set(pyro.get_param_store().keys()) == set() + + m = Model() + state2 = m() + if local_params: + assert set(pyro.get_param_store().keys()) == set() + else: + assert set(pyro.get_param_store().keys()) == { + "x", + "m$$$weight", + "m$$$bias", + "p.x", + } + for actual, expected in zip(state2, state0): + assert_equal(actual, expected) def test_sample(): @@ -532,48 +643,65 @@ def randomize(model): assert_identical(actual, expected) -def test_torch_serialize_attributes(): - module = PyroModule() - module.x = PyroParam(torch.tensor(1.234), constraints.positive) - module.y = nn.Parameter(torch.randn(3)) - assert isinstance(module.x, torch.Tensor) - - # Work 
around https://github.com/pytorch/pytorch/issues/27972 - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - f = io.BytesIO() - torch.save(module, f) - pyro.clear_param_store() - f.seek(0) - actual = torch.load(f) - - assert_equal(actual.x, module.x) - actual_names = {name for name, _ in actual.named_parameters()} - expected_names = {name for name, _ in module.named_parameters()} - assert actual_names == expected_names - - -def test_torch_serialize_decorators(): - module = DecoratorModel(3) - module() # initialize - - # Work around https://github.com/pytorch/pytorch/issues/27972 - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - f = io.BytesIO() - torch.save(module, f) - pyro.clear_param_store() - f.seek(0) - actual = torch.load(f) - - assert_equal(actual.x, module.x) - assert_equal(actual.y, module.y) - assert_equal(actual.z, module.z) - assert actual.s.shape == module.s.shape - assert actual.t.shape == module.t.shape - actual_names = {name for name, _ in actual.named_parameters()} - expected_names = {name for name, _ in module.named_parameters()} - assert actual_names == expected_names +@pytest.mark.parametrize("local_params", [True, False]) +def test_torch_serialize_attributes(local_params): + with pyro.settings.context(module_local_params=local_params): + module = PyroModule() + module.x = PyroParam(torch.tensor(1.234), constraints.positive) + module.y = nn.Parameter(torch.randn(3)) + assert isinstance(module.x, torch.Tensor) + + # Work around https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(module, f) + pyro.clear_param_store() + f.seek(0) + actual = torch.load(f) + + assert_equal(actual.x, module.x) + actual_names = {name for name, _ in actual.named_parameters()} + expected_names = {name for name, _ in module.named_parameters()} + assert actual_names == expected_names + + +@pytest.mark.parametrize("local_params", [True, False]) +def test_torch_serialize_decorators(local_params): + with pyro.settings.context(module_local_params=local_params): + module = DecoratorModel(3) + module() # initialize + + module2 = DecoratorModel(3) + module2() # initialize + + # Work around https://github.com/pytorch/pytorch/issues/27972 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + f = io.BytesIO() + torch.save(module, f) + pyro.clear_param_store() + f.seek(0) + actual = torch.load(f) + + assert_equal(actual.x, module.x) + assert_equal(actual.y, module.y) + assert_equal(actual.z, module.z) + assert actual.s.shape == module.s.shape + assert actual.t.shape == module.t.shape + actual_names = {name for name, _ in actual.named_parameters()} + expected_names = {name for name, _ in module.named_parameters()} + assert actual_names == expected_names + + actual() + if local_params: + assert len(set(pyro.get_param_store().keys())) == 0 + assert torch.all(module.y != module2.y) + assert torch.all(actual.y != module2.y) + else: + assert len(set(pyro.get_param_store().keys())) > 0 + assert_equal(module.y, module2.y) + assert_equal(actual.y, module2.y) def test_pyro_serialize():
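Taken together, the tests above exercise the two features of this diff end to end: the ELBOModule returned by ELBO.__call__ and module-local parameters. A condensed sketch of the intended workflow (illustrative only, not part of the patch; the Model class is adapted from the test fixture above, and the optimizer settings are arbitrary)::

    import torch
    from torch import nn
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist
    from pyro.infer import Trace_ELBO
    from pyro.infer.autoguide import AutoNormal
    from pyro.nn import PyroModule, PyroParam, PyroSample


    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.loc = nn.Parameter(torch.zeros(2))
            self.scale = PyroParam(torch.ones(2), constraint=constraints.positive)
            self.z = PyroSample(
                lambda self: dist.Normal(self.loc, self.scale).to_event(1)
            )

        def forward(self, data):
            loc, log_scale = self.z.unbind(-1)
            with pyro.plate("data"):
                pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data)


    with pyro.settings.context(module_local_params=True):
        data = torch.randn(5)
        model = Model()
        guide = AutoNormal(model)

        # ELBO.__call__ binds the (model, guide) pair into an ELBOModule whose
        # forward pass returns the differentiable loss.
        elbo = Trace_ELBO()(model, guide)
        elbo(data)  # one call to trigger lazy parameter initialization

        optim = torch.optim.Adam(elbo.parameters(), lr=0.01)
        for _ in range(10):
            optim.zero_grad()
            loss = elbo(data)
            loss.backward()
            optim.step()

        # With module_local_params enabled, training leaves the global
        # parameter store empty; all state lives on the model and guide.
        assert set(pyro.get_param_store().keys()) == set()

Because the ELBOModule is a plain torch.nn.Module, the same object can be handed to standard PyTorch tooling (here an Optimizer over elbo.parameters()) without going through pyro.optim or SVI.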