Move Predictive handlers to observational module (#550)
* Move cut posterior handlers to robust module

* moving nmc and predictive functionality from robust to observational

* format and lint

* docstring ref

* restore nmc

* inline bind_leftmost_dim in nmc

* format

* revert format

* revert format

* Update utils.py

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>
eb8680 and SamWitty authored Jul 16, 2024
1 parent 8ba2f2c commit 54e338b
Showing 12 changed files with 261 additions and 248 deletions.
24 changes: 24 additions & 0 deletions chirho/indexed/handlers.py
@@ -4,6 +4,7 @@

import pyro
import torch
from typing_extensions import ParamSpec

from chirho.indexed.internals import (
_LazyPlateMessenger,
@@ -12,6 +13,8 @@
)
from chirho.indexed.ops import union

P = ParamSpec("P")


class IndexPlatesMessenger(pyro.poutine.messenger.Messenger):
plates: Dict[Hashable, pyro.poutine.indep_messenger.IndepMessenger]
@@ -137,3 +140,24 @@ def _pyro_sample(self, msg: Dict[str, Any]) -> None:
msg["fn"] = msg["fn"].expand(
torch.broadcast_shapes(msg["fn"].batch_shape, mask.shape)
)


@pyro.poutine.block()
@pyro.validation_enabled(False)
@torch.no_grad()
def guess_max_plate_nesting(
model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> int:
"""
Guesses the maximum plate nesting level by running `pyro.infer.Trace_ELBO`

:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param guide: Python callable containing Pyro primitives.
:type guide: Callable[P, Any]
:return: maximum plate nesting level
:rtype: int
"""
elbo = pyro.infer.Trace_ELBO()
elbo._guess_max_plate_nesting(model, guide, args, kwargs)
return elbo.max_plate_nesting
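
A minimal usage sketch for the new guess_max_plate_nesting helper (the toy model and guide below are illustrative, not part of this commit):

import pyro
import pyro.distributions as dist
import torch

from chirho.indexed.handlers import guess_max_plate_nesting


def toy_model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", 10):
        pyro.sample("obs", dist.Normal(loc, 1.0))


def toy_guide():
    loc_q = pyro.param("loc_q", torch.tensor(0.0))
    pyro.sample("loc", dist.Delta(loc_q))


# traces the (model, guide) pair once via Trace_ELBO and reads off the result;
# the single "data" plate gives a nesting depth of 1
assert guess_max_plate_nesting(toy_model, toy_guide) == 1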
@@ -1,18 +1,101 @@
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Any, Callable, Generic, Mapping, Optional, TypeVar

import pyro
import torch
from typing_extensions import ParamSpec

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.robust.internals.nmc import BatchedLatents
from chirho.robust.internals.utils import bind_leftmost_dim
from chirho.robust.ops import Point
from chirho.indexed.ops import indices_of
from chirho.observational.handlers.condition import Observations
from chirho.observational.internals import (
bind_leftmost_dim,
site_is_delta,
unbind_leftmost_dim,
)
from chirho.observational.ops import Observation

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")

Point = Mapping[str, Observation[T]]


class BatchedLatents(pyro.poutine.messenger.Messenger):
"""
Effect handler that adds a fresh batch dimension to all latent ``sample`` sites.
Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery
in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension
based on the ``name`` argument to ``BatchedLatents`` .

.. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .

:param int num_particles: Number of particles to use for parallelization.
:param str name: Name of the fresh batch dimension.
"""

num_particles: int
name: str

def __init__(self, num_particles: int, *, name: str = "__particles_mc"):
assert num_particles > 0
assert len(name) > 0
self.num_particles = num_particles
self.name = name
super().__init__()

def _pyro_sample(self, msg: dict) -> None:
if (
self.num_particles > 1
and msg["value"] is None
and not pyro.poutine.util.site_is_factor(msg)
and not pyro.poutine.util.site_is_subsample(msg)
and not site_is_delta(msg)
and self.name not in indices_of(msg["fn"])
):
msg["fn"] = unbind_leftmost_dim(
msg["fn"].expand((1,) + msg["fn"].batch_shape),
self.name,
size=self.num_particles,
)
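
A rough usage sketch of BatchedLatents as defined above (the toy model, plate depth, and shape bookkeeping are illustrative assumptions; the handler's new import path is not shown in this excerpt):

import pyro
import pyro.distributions as dist

from chirho.indexed.handlers import IndexPlatesMessenger


def toy_model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))


# run 100 vectorized copies of the model; the fresh batch dimension is
# allocated and tracked by the enclosing IndexPlatesMessenger
with IndexPlatesMessenger(first_available_dim=-2):
    with pyro.poutine.trace() as tr:
        with BatchedLatents(100, name="__particles_mc"):
            toy_model()

# each latent sample site now carries a batch dimension of size 100 at the
# plate position allocated for "__particles_mc"
assert tr.trace.nodes["x"]["value"].shape[-2] == 100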


class BatchedObservations(Generic[T], Observations[T]):
"""
Effect handler that takes a dictionary of observation values for ``sample`` sites
that are assumed to be batched along their leftmost dimension, adds a fresh named
dimension using the machinery in ``chirho.indexed``, and reshapes the observation
values so that the new ``chirho.observational.observe`` sites are batched along
the fresh named dimension.

Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary
of values whose leftmost dimension is a batch dimension over independent samples.

.. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .

:param Point[T] data: Dictionary of observation values.
:param str name: Name of the fresh batch dimension.
"""

name: str

def __init__(self, data: Point[T], *, name: str = "__particles_data"):
assert len(name) > 0
self.name = name
super().__init__(data)

def _pyro_observe(self, msg: dict) -> None:
super()._pyro_observe(msg)
if msg["kwargs"]["name"] in self.data:
rv, obs = msg["args"]
event_dim = (
len(rv.event_shape)
if hasattr(rv, "event_shape")
else msg["kwargs"].get("event_dim", 0)
)
batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim)
msg["args"] = (rv, batch_obs)
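
A rough sketch of how BatchedObservations combines with BatchedLatents (data shapes, plate depth, and names are illustrative assumptions): each handler gets its own named dimension, so the latent particles and the batched observations broadcast against each other.

import pyro
import pyro.distributions as dist
import torch

from chirho.indexed.handlers import IndexPlatesMessenger


def toy_model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(loc, 1.0))


# e.g. 50 draws of "y" batched along the leftmost dimension, as returned by
# pyro.infer.Predictive(toy_model, num_samples=50)()
data = {"y": torch.randn(50)}

with IndexPlatesMessenger(first_available_dim=-2):
    with BatchedLatents(50, name="__particles_mc"):
        # observations are reshaped so the data batch lives on its own named
        # dim, separate from the "__particles_mc" dim of the latents
        with BatchedObservations(data, name="__particles_data"):
            toy_model()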


class PredictiveModel(Generic[P, T], torch.nn.Module):
"""
@@ -76,7 +159,7 @@ class PredictiveFunctional(Generic[P, T], torch.nn.Module):
the returned values are batched along their leftmost positional dimension.
Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)``
when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct
when :class:`~chirho.observational.handlers.predictive.PredictiveModel` is used to construct
the ``model`` argument and infer the ``sample`` sites whose values should be returned,
and uses :class:`~BatchedLatents` to parallelize over samples from the model.
89 changes: 89 additions & 0 deletions chirho/observational/internals.py
@@ -1,13 +1,18 @@
from __future__ import annotations

import functools
from typing import Mapping, Optional, TypeVar

import pyro
import pyro.distributions
import torch
from typing_extensions import ParamSpec

from chirho.indexed.handlers import add_indices
from chirho.indexed.ops import IndexSet, get_index_plates, indices_of
from chirho.observational.ops import AtomicObservation, observe

P = ParamSpec("P")
K = TypeVar("K")
T = TypeVar("T")

@@ -67,3 +72,87 @@ class ObserveNameMessenger(pyro.poutine.messenger.Messenger):
def _pyro_observe(self, msg):
if "name" not in msg["kwargs"]:
msg["kwargs"]["name"] = msg["name"]


def site_is_delta(msg: dict) -> bool:
d = msg["fn"]
while hasattr(d, "base_dist"):
d = d.base_dist
return isinstance(d, pyro.distributions.Delta)
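
A small illustrative check of the predicate (the wrapped distributions are hypothetical examples):

import pyro.distributions as dist
import torch

# site_is_delta unwraps .base_dist, so a Delta hidden inside an Independent
# (or similar) wrapper is still detected
wrapped = dist.Delta(torch.zeros(3)).to_event(1)
assert site_is_delta({"fn": wrapped})
assert not site_is_delta({"fn": dist.Normal(0.0, 1.0)})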


@functools.singledispatch
def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs):
"""
Helper function to move the leftmost dimension of a ``torch.Tensor``
or ``pyro.distributions.Distribution`` or other batched value
into a fresh named dimension using the machinery in ``chirho.indexed`` ,
allocating a new dimension with the given name if necessary
via an enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .

.. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .

:param v: Batched value.
:param name: Name of the fresh dimension.
:param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` .
"""
raise NotImplementedError


@unbind_leftmost_dim.register
def _unbind_leftmost_dim_tensor(
v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0
) -> torch.Tensor:
size = max(size, v.shape[0])
v = v.expand((size,) + v.shape[1:])

if name not in get_index_plates():
add_indices(IndexSet(**{name: set(range(size))}))

new_dim: int = get_index_plates()[name].dim
orig_shape = v.shape
while new_dim - event_dim < -len(v.shape):
v = v[None]
if v.shape[0] == 1 and orig_shape[0] != 1:
v = torch.transpose(v, -len(orig_shape), new_dim - event_dim)
return v
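
A shape-level sketch of the tensor overload (plate depth and names are illustrative; an enclosing IndexPlatesMessenger is required):

import torch

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import indices_of
from chirho.observational.internals import unbind_leftmost_dim

x = torch.randn(7, 3)  # the size-7 leftmost dimension will become a named dimension

with IndexPlatesMessenger(first_available_dim=-2):
    x_named = unbind_leftmost_dim(x, "particles")
    # the size-7 dimension now sits at the plate dim allocated for "particles"
    assert "particles" in indices_of(x_named, event_dim=0)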


@unbind_leftmost_dim.register
def _unbind_leftmost_dim_distribution(
v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs
) -> pyro.distributions.Distribution:
size = max(size, v.batch_shape[0])
if v.batch_shape[0] != 1:
raise NotImplementedError("Cannot freely reshape distribution")

if name not in get_index_plates():
add_indices(IndexSet(**{name: set(range(size))}))

new_dim: int = get_index_plates()[name].dim
orig_shape = v.batch_shape

new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:]
return v.expand(new_shape)
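
An analogous sketch for the distribution overload, which can only broadcast a singleton leftmost batch dimension up to the requested size (names and shapes are illustrative):

import pyro.distributions as dist
import torch

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.internals import unbind_leftmost_dim

with IndexPlatesMessenger(first_available_dim=-2):
    d = dist.Normal(torch.zeros(1, 3), 1.0)  # batch_shape (1, 3)
    d_named = unbind_leftmost_dim(d, "particles", size=10)
    # the singleton leftmost batch dim is expanded to 10 and tracked as "particles"
    assert d_named.batch_shape == (10, 3)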


@functools.singledispatch
def bind_leftmost_dim(v, name: str, **kwargs):
"""
Helper function to move a named dimension managed by ``chirho.indexed``
into a new unnamed dimension to the left of all named dimensions in the value.

.. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .
"""
raise NotImplementedError


@bind_leftmost_dim.register
def _bind_leftmost_dim_tensor(
v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs
) -> torch.Tensor:
if name not in indices_of(v, event_dim=event_dim):
return v
return torch.transpose(
v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim
)
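
A round-trip sketch for the tensor overload: bind_leftmost_dim undoes unbind_leftmost_dim by moving the named dimension back to a fresh leftmost positional dimension (names and shapes are illustrative):

import torch

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.internals import bind_leftmost_dim, unbind_leftmost_dim

x = torch.randn(5, 2)

with IndexPlatesMessenger(first_available_dim=-2):
    x_named = unbind_leftmost_dim(x, "particles")     # size-5 dim moved to the "particles" plate
    x_back = bind_leftmost_dim(x_named, "particles")  # named dim moved back to the left
    # x_back has shape (5, 1, 2): the size-5 batch is leftmost again, with a
    # singleton left behind at the old plate position
    assert x_back.shape == (5, 1, 2)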
2 changes: 1 addition & 1 deletion chirho/robust/internals/linearize.py
@@ -276,7 +276,7 @@ def linearize(
import pyro.distributions as dist
import torch
from chirho.robust.handlers.predictive import PredictiveModel
from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.internals.linearize import linearize
pyro.settings.set(module_local_params=True)