Fix AutoGaussian and tests #2948

Merged · 20 commits · Oct 27, 2021
1 change: 1 addition & 0 deletions pyro/distributions/unit.py
@@ -30,6 +30,7 @@ def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
+batch_shape = torch.Size(batch_shape)
new = self._get_checked_instance(Unit, _instance)
new.log_factor = self.log_factor.expand(batch_shape)
if "has_rsample" in self.__dict__:
91 changes: 69 additions & 22 deletions pyro/infer/autoguide/gaussian.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
+from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from types import SimpleNamespace
@@ -14,7 +15,7 @@
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
-from pyro.infer.inspect import get_dependencies
+from pyro.infer.inspect import get_dependencies, is_sample_site
from pyro.nn.module import PyroModule, PyroParam
from pyro.poutine.runtime import am_i_wrapped, get_plates
from pyro.poutine.util import site_is_subsample
@@ -30,7 +31,7 @@
# AutoGaussianDense(model)
# The intent is to avoid proliferation of subclasses and docstrings,
# and provide a single interface AutoGaussian(...).
-class AutoGaussianMeta(type(AutoGuide)):
+class AutoGaussianMeta(type(AutoGuide), ABCMeta):
backends = {}
default_backend = "dense"

@@ -41,8 +42,9 @@ def __init__(cls, *args, **kwargs):
cls.backends[key] = cls

def __call__(cls, *args, **kwargs):
-backend = kwargs.pop("backend", cls.default_backend)
-cls = cls.backends[backend]
+if cls is AutoGaussian:
+    backend = kwargs.pop("backend", cls.default_backend)
+    cls = cls.backends[backend]
return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)
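The new guard means only calls on the public AutoGaussian class dispatch on backend=...; constructing a registered subclass directly now bypasses the lookup. A self-contained sketch of the pattern (class names illustrative, not Pyro's):

class Meta(type):
    backends = {}
    default_backend = "dense"

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        if bases:  # register concrete subclasses under a backend key
            Meta.backends[name.lower()] = cls

    def __call__(cls, *args, **kwargs):
        if cls is Gaussian:  # dispatch only from the public entry point
            backend = kwargs.pop("backend", cls.default_backend)
            cls = cls.backends[backend]
        return super(Meta, cls).__call__(*args, **kwargs)

class Gaussian(metaclass=Meta):
    pass

class Dense(Gaussian):
    pass

class Funsor(Gaussian):
    pass

assert type(Gaussian()) is Dense                    # default backend
assert type(Gaussian(backend="funsor")) is Funsor   # explicit backend
assert type(Dense()) is Dense                       # direct construction works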


@@ -117,6 +119,12 @@ def __init__(
model = InitMessenger(init_loc_fn)(model)
super().__init__(model)

+@staticmethod
+def _prototype_hide_fn(msg):
+    # In contrast to the AutoGuide base class, this includes observation
+    # sites and excludes deterministic sites.
+    return not is_sample_site(msg)

def _setup_prototype(self, *args, **kwargs) -> None:
super()._setup_prototype(*args, **kwargs)
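For reference, a rough paraphrase of the predicate negated in _prototype_hide_fn above (a hypothetical helper, not the exact pyro.infer.inspect.is_sample_site implementation, which also handles subsample sites and masked observations):

def looks_like_sample_site(msg):
    # Paraphrase only: keep genuine sample statements, latent or observed.
    if msg["type"] != "sample":
        return False                               # params, plates, ...
    if (msg.get("infer") or {}).get("_deterministic"):
        return False                               # pyro.deterministic sites
    return True                                    # latent and observed sites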

@@ -135,6 +143,12 @@ def _setup_prototype(self, *args, **kwargs) -> None:
"prior_dependencies"
]

+# Eliminate observations with no upstream latents.
+for d, upstreams in list(self.dependencies.items()):
+    if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
+        del self.dependencies[d]
+        del self.prototype_trace.nodes[d]
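A toy example of what gets pruned (the model is illustrative): "y" has the latent upstream "x" and is kept, while "z" depends on no latent and is dropped from the guide's view of the model.

import torch
import pyro
import pyro.distributions as dist

def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.5))    # kept
    pyro.sample("z", dist.Normal(0.0, 1.0), obs=torch.tensor(1.0))  # pruned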

# Collect factors and plates.
for d, site in self.prototype_trace.nodes.items():
# Prune non-essential parts of the trace to save memory.
@@ -153,14 +167,19 @@ def _setup_prototype(self, *args, **kwargs) -> None:
"Are you missing a pyro.plate() or .to_event()?"
)
if site["is_observed"]:
-# Eagerly eliminate irrelevant observation plates.
-plates &= frozenset.union(
+# Break irrelevant observation plates.
+plates &= frozenset().union(
*(self._plates[u] for u in self.dependencies[d] if u != d)
)
self._plates[d] = plates

# Create location-scale parameters, one per latent variable.
if site["is_observed"]:
# This may slightly overestimate, e.g. for Multinomial.
self._event_numel[d] = site["fn"].event_shape.numel()
# Account for broken irrelevant observation plates.
for f in set(site["cond_indep_stack"]) - plates:
self._event_numel[d] *= f.size
continue
with helpful_support_errors(site):
init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
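The switch from frozenset.union(...) to frozenset().union(...) is a real fix, not a style change: the unbound classmethod form fails when the generator yields no upstream plate sets.

sets = []                          # an observed site with no distinct upstreams
print(frozenset().union(*sets))    # frozenset(), the intended empty result
# frozenset.union(*sets)           # TypeError: descriptor 'union' needs an argument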
@@ -184,16 +203,23 @@ def _setup_prototype(self, *args, **kwargs) -> None:
for d, site in self._factors.items():
u_size = 0
for u in self.dependencies[d]:
-broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
-u_size += broken_shape.numel() * self._event_numel[u]
+if not self._factors[u]["is_observed"]:
+    broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
+    u_size += broken_shape.numel() * self._event_numel[u]
d_size = self._event_numel[d]
if site["is_observed"]:
d_size = min(d_size, u_size) # just an optimization
batch_shape = _plates_to_shape(self._plates[d])

# Create a square root parameter (full, not lower triangular).
-sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size))
-if d in self.dependencies[d]:
+# We initialize with noise to avoid singular gradient.
+sqrt = torch.rand(
+    batch_shape + (u_size, d_size),
+    dtype=init_loc.dtype,
+    device=init_loc.device,
+)
+sqrt.sub_(0.5).mul_(self._init_scale)
+if not site["is_observed"]:
# Initialize the [d,d] block to the identity matrix.
sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
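A numeric sketch of this initialization (shapes and scale illustrative): small uniform noise keeps the gradient of sqrt @ sqrt.T nonsingular at the start, and the identity diagonal makes the initial precision close to I for latent sites.

import torch

batch_shape, u_size, d_size, init_scale = torch.Size([3]), 5, 5, 0.1
sqrt = torch.rand(batch_shape + (u_size, d_size))
sqrt.sub_(0.5).mul_(init_scale)           # uniform in [-init_scale/2, init_scale/2]
sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)  # [d,d] block starts at identity
precision = sqrt @ sqrt.transpose(-1, -2)
assert torch.allclose(precision, torch.eye(5).expand(3, 5, 5), atol=0.2)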
@@ -223,6 +249,8 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
# Replay via Pyro primitives.
plates = self._create_plates(*args, **kwargs)
for name, site in self._factors.items():
if site["is_observed"]:
continue
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
stack.enter_context(plates[frame.name])
@@ -253,6 +281,8 @@ def _transform_values(
log_densities = defaultdict(float)
compute_density = am_i_wrapped() and poutine.get_mask() is not False
for name, site in self._factors.items():
if site["is_observed"]:
continue
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
unconstrained = aux_values[name] * scale + loc
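This method maps white-noise auxiliary values through an affine transform and then into each site's support. A sketch of the change-of-variables bookkeeping for one site (values and constraint illustrative):

import torch
from torch.distributions import biject_to, constraints

aux = torch.randn(())                        # auxiliary standard-normal draw
loc, scale = torch.tensor(0.3), torch.tensor(0.5)
unconstrained = aux * scale + loc
transform = biject_to(constraints.positive)  # e.g. a positive-valued site
value = transform(unconstrained)
# log |d value / d aux| combines the affine and support transforms:
log_detJ = scale.log() + transform.log_abs_det_jacobian(unconstrained, value)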
@@ -268,6 +298,7 @@

return values, log_densities

+@abstractmethod
def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError
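With ABCMeta mixed into AutoGaussianMeta above, this decorator makes a backend that forgets to override _sample_aux_values fail at construction time rather than at sampling time. A minimal standalone illustration:

from abc import ABCMeta, abstractmethod

class Base(metaclass=ABCMeta):
    @abstractmethod
    def _sample_aux_values(self):
        raise NotImplementedError

class Incomplete(Base):
    pass

try:
    Incomplete()
except TypeError as e:
    print(e)  # Can't instantiate abstract class Incomplete ...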

@@ -305,15 +336,18 @@ def _setup_prototype(self, *args, **kwargs):
index = torch.zeros(precision_shape, dtype=torch.long)

# Collect local offsets.
+upstreams = [
+    u for u in self.dependencies[d] if not self._factors[u]["is_observed"]
+]
local_offsets = {}
pos = 0
-for u in self.dependencies[d]:
+for u in upstreams:
local_offsets[u] = pos
broken_plates = self._plates[u] - self._plates[d]
pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()

# Create indices blockwise.
-for u, v in itertools.product(self.dependencies[d], self.dependencies[d]):
+for u, v in itertools.product(upstreams, upstreams):
u_index = global_indices[u]
v_index = global_indices[v]

@@ -333,17 +367,17 @@
self._dense_scatter[d] = index.reshape(-1)

def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
-# Sample from a dense joint Gaussian over flattened variables.
-precision = self._get_precision()
-loc = precision.new_zeros(self._dense_size)
flat_samples = pyro.sample(
f"_{self._pyro_name}",
dist.MultivariateNormal(loc, precision_matrix=precision),
f"_{self._pyro_name}_latent",
self._dense_get_mvn(),
infer={"is_auxiliary": True},
)
-sample_shape = flat_samples.shape[:-1]
+samples = self._dense_unflatten(flat_samples)
return samples

-# Convert flat to shaped tensors.
+def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, torch.Tensor]:
+    # Convert a single flattened sample to a dict of shaped samples.
+    sample_shape = flat_samples.shape[:-1]
samples = {}
pos = 0
for d, (batch_shape, event_shape) in self._dense_shapes.items():
@@ -356,14 +390,25 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
)
return samples

-def _get_precision(self):
+def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
+    # Convert a dict of shaped samples to a single flattened sample.
+    flat_samples = []
+    for d, (batch_shape, event_shape) in self._dense_shapes.items():
+        shape = samples[d].shape
+        sample_shape = shape[: len(shape) - len(batch_shape) - len(event_shape)]
+        flat_samples.append(samples[d].reshape(sample_shape + (-1,)))
+    return torch.cat(flat_samples, dim=-1)
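A toy round trip of the two conversions (shapes illustrative): each variable is reshaped to sample_shape + (-1,) and concatenated along the last dimension, then sliced and reshaped back.

import torch

shapes = {"a": (torch.Size([2]), torch.Size([3])),  # (batch_shape, event_shape)
          "b": (torch.Size([]), torch.Size([4]))}
samples = {"a": torch.randn(2, 3), "b": torch.randn(4)}

flat = torch.cat([samples[k].reshape(-1) for k in shapes], dim=-1)  # flatten
pos, out = 0, {}
for name, (batch_shape, event_shape) in shapes.items():            # unflatten
    numel = (batch_shape + event_shape).numel()
    out[name] = flat[pos:pos + numel].reshape(batch_shape + event_shape)
    pos += numel
assert all(torch.equal(out[k], samples[k]) for k in shapes)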

+def _dense_get_mvn(self):
+    # Create a dense joint Gaussian over flattened variables.
flat_precision = torch.zeros(self._dense_size ** 2)
for d, index in self._dense_scatter.items():
sqrt = deep_getattr(self.factors, d)
precision = sqrt @ sqrt.transpose(-1, -2)
flat_precision.scatter_add_(0, index, precision.reshape(-1))
precision = flat_precision.reshape(self._dense_size, self._dense_size)
-return precision
+loc = precision.new_zeros(self._dense_size)
+return dist.MultivariateNormal(loc, precision_matrix=precision)
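A toy sketch of how _dense_get_mvn assembles the joint precision (sizes illustrative): each factor contributes sqrt @ sqrt.T, scattered into the flattened dense matrix at that factor's global indices, with overlapping blocks summing.

import torch

dense_size = 3
flat_precision = torch.zeros(dense_size ** 2)
for idx in (torch.tensor([0, 1]), torch.tensor([1, 2])):  # two factors
    sqrt = torch.randn(2, 2)
    block = sqrt @ sqrt.transpose(-1, -2)                 # PSD per-factor block
    flat_index = (idx[:, None] * dense_size + idx).reshape(-1)
    flat_precision.scatter_add_(0, flat_index, block.reshape(-1))
precision = flat_precision.reshape(dense_size, dense_size)
assert torch.allclose(precision, precision.transpose(0, 1))  # symmetric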


class AutoGaussianFunsor(AutoGaussian):
@@ -403,11 +448,13 @@ def _setup_prototype(self, *args, **kwargs):
plate_to_dim: Dict[str, int] = {}
for d, site in self._factors.items():
inputs = OrderedDict()
for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim):
for f in sorted(self._plates[d], key=lambda f: f.dim):
plate_to_dim[f.name] = f.dim
inputs[f.name] = funsor.Bint[f.size]
eliminate.add(f.name)
for u in self.dependencies[d]:
if self._factors[u]["is_observed"]:
continue
inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]]
eliminate.add(u)
factor_inputs[d] = inputs
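For reference, a sketch of the typing convention used here, assuming a funsor installation with the torch backend (names and sizes illustrative): plates become bounded-integer inputs, latents become real-valued inputs.

from collections import OrderedDict

import funsor

funsor.set_backend("torch")
inputs = OrderedDict()
inputs["i"] = funsor.Bint[3]   # a plate of size 3
inputs["x"] = funsor.Reals[2]  # a latent with unconstrained event shape (2,)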
10 changes: 5 additions & 5 deletions pyro/infer/autoguide/guides.py
@@ -149,9 +149,11 @@ def _create_plates(self, *args, **kwargs):
self.plates = self.master().plates
return self.plates

+_prototype_hide_fn = staticmethod(prototype_hide_fn)

def _setup_prototype(self, *args, **kwargs):
# run the model so we can inspect its structure
-model = poutine.block(self.model, prototype_hide_fn)
+model = poutine.block(self.model, self._prototype_hide_fn)
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
*args, **kwargs
)
@@ -193,9 +195,7 @@ class AutoGuideList(AutoGuide, nn.ModuleList):
"""

def _check_prototype(self, part_trace):
-for name, part_site in part_trace.nodes.items():
-    if part_site["type"] != "sample":
-        continue
+for name, part_site in part_trace.iter_stochastic_nodes():
self_site = self.prototype_trace.nodes[name]
assert part_site["fn"].batch_shape == self_site["fn"].batch_shape
assert part_site["fn"].event_shape == self_site["fn"].event_shape
@@ -1187,7 +1187,7 @@ class AutoDiscreteParallel(AutoGuide):

def _setup_prototype(self, *args, **kwargs):
# run the model so we can inspect its structure
-model = poutine.block(config_enumerate(self.model), prototype_hide_fn)
+model = poutine.block(config_enumerate(self.model), self._prototype_hide_fn)
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
*args, **kwargs
)