Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear .unconstrained weakrefs before pickling; rebuild them more often #3212

Merged
merged 3 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyro/contrib/easyguide/easyguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(self, guide, sites):
self.event_shape = torch.Size([sum(self._site_sizes.values())])

def __getstate__(self):
    """Prepare this object's state for pickling.

    The ``_guide`` attribute is stored as a weakref at runtime; weakrefs
    cannot be pickled, so it is replaced by a strong reference here.
    """
    # Prefer the superclass __getstate__ when one exists (torch>=2.0
    # nn.Module defines it); otherwise fall back to a shallow dict copy.
    state = getattr(super(), "__getstate__", self.__dict__.copy)()
    state["_guide"] = state["_guide"]()  # weakref -> strong ref
    return state

Expand Down
20 changes: 20 additions & 0 deletions pyro/contrib/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import weakref
from collections import OrderedDict

import torch
Expand Down Expand Up @@ -75,3 +76,22 @@ def iter_plates_to_shape(shape):
# Go backwards (right to left)
for i, s in enumerate(shape[::-1]):
yield pyro.plate("plate_" + str(i), s)


def check_no_weakref(obj, path="", avoid_ids=None):
    """Attempts to check that an object has no weakrefs.

    Recursively walks dicts, lists, tuples, and ``__dict__`` attributes,
    raising at the first :class:`weakref.ref` encountered. This is a debug
    helper for diagnosing pickling failures.

    :param obj: the object to inspect.
    :param str path: human-readable access path used in the error message.
    :param set avoid_ids: ``id()`` values already visited on the current
        path; used to terminate on cyclic object graphs.
    :raises ValueError: if a weakref is found, reporting its path.
    """
    if avoid_ids is None:
        avoid_ids = set()
    if id(obj) in avoid_ids:
        return
    # Copy-on-extend keeps the visited set path-local, so sharing a
    # subobject across branches does not suppress checking it.
    avoid_ids = avoid_ids | {id(obj)}

    if isinstance(obj, weakref.ref):
        raise ValueError(f"Weakref found at {path}")
    elif isinstance(obj, dict):
        for k, v in obj.items():
            if id(v) not in avoid_ids:
                # Propagate avoid_ids so cycles deeper than one level
                # cannot cause infinite recursion (bug in prior version).
                check_no_weakref(v, path + f"[{k}]", avoid_ids)
    elif isinstance(obj, (list, tuple)):
        for i, v in enumerate(obj):
            if id(v) not in avoid_ids:
                check_no_weakref(v, path + f"[{i}]", avoid_ids)
    elif hasattr(obj, "__dict__"):
        check_no_weakref(obj.__dict__, path, avoid_ids)
11 changes: 5 additions & 6 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,12 @@ def decorator(new_fn):
# TODO: Move upstream to allow for pickle serialization of transforms
@patch_dependency("torch.distributions.transforms.Transform.__getstate__")
def _Transform__getstate__(self):
    """Return picklable state for a Transform, masking weakref attributes.

    Transforms cache inverse transforms via weakrefs, which cannot be
    pickled; those entries are replaced by ``None`` here.
    """
    # Prefer the base-class __getstate__ when torch provides one
    # (torch>=2.0); otherwise fall back to a shallow copy of __dict__.
    super_ = super(torch.distributions.transforms.Transform, self)
    state = getattr(super_, "__getstate__", self.__dict__.copy)()
    for k, v in state.items():
        if isinstance(v, weakref.ref):
            state[k] = None
    return state


# TODO move upstream
Expand Down
7 changes: 4 additions & 3 deletions pyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ class ELBO(object, metaclass=ABCMeta):
behave unexpectedly relative to standard PyTorch when working with
:class:`~pyro.nn.PyroModule` s.

Users are therefore strongly encouraged to use this interface in conjunction
with :func:`~pyro.enable_module_local_param` which will override the default
implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances.
Users are therefore strongly encouraged to use this interface in
conjunction with ``pyro.settings.set(module_local_params=True)`` which
will override the default implicit sharing of parameters across
:class:`~pyro.nn.PyroModule` instances.

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
Expand Down
11 changes: 10 additions & 1 deletion pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
import functools
import inspect
import weakref
from collections import OrderedDict, namedtuple

import torch
Expand Down Expand Up @@ -503,7 +504,9 @@ def __getattr__(self, name):
_PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
return pyro.param(fullname, event_dim=event_dim)
else: # Cannot determine supermodule and hence cannot compute fullname.
return transform_to(constraint)(unconstrained_value)
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
return constrained_value

# PyroSample trigger pyro.sample statements.
if "_pyro_samples" in self.__dict__:
Expand Down Expand Up @@ -661,6 +664,12 @@ def __delattr__(self, name):

super().__delattr__(name)

def __getstate__(self):
    """Return picklable state, stripping ``.unconstrained`` weakrefs.

    Constrained parameter values carry a weakref back to their
    unconstrained storage; weakrefs cannot be pickled, so they are
    removed from every parameter first.
    """
    for p in self.parameters(recurse=True):
        p.__dict__.pop("unconstrained", None)
    # torch>=2.0 nn.Module defines __getstate__; on older versions we
    # fall back to a shallow copy of the instance dict.
    fallback = self.__dict__.copy
    return getattr(super(), "__getstate__", fallback)()


def pyro_method(fn):
"""
Expand Down
9 changes: 5 additions & 4 deletions pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,11 @@ def get_state(self) -> dict:
"""
Get the ParamStore state.
"""
state = {
"params": self._params.copy(),
"constraints": self._constraints.copy(),
}
params = self._params.copy()
# Remove weakrefs in preparation for pickling.
for param in params.values():
param.__dict__.pop("unconstrained", None)
state = {"params": params, "constraints": self._constraints.copy()}
return state

def set_state(self, state: dict):
Expand Down
6 changes: 6 additions & 0 deletions pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def __init__(self, model: Callable):
def model(self):
return self._model[0]

def __getstate__(self):
    """Return picklable state without the (run-specific) trace."""
    # The trace holds execution-time objects; drop it rather than
    # attempt to pickle it.
    state = super().__getstate__()
    state.pop("trace")
    return state

def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
"""
Draws posterior samples from the guide and replays the model against
Expand Down
4 changes: 3 additions & 1 deletion tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ def forward(self):

# mutate
for _, x in m.named_pyro_params():
x.unconstrained().data += torch.randn(())
if hasattr(x, "unconstrained"):
x = x.unconstrained()
x.data += torch.randn(x.shape)
state1 = m()
for x, y in zip(state0, state1):
assert not (x == y).all()
Expand Down