From 54e338b705b2b6632c314569a0ef1d3b866bf372 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 16 Jul 2024 13:08:18 -0400 Subject: [PATCH] Move Predictive handlers to observational module (#550) * Move cut posterior handlers to robust module * moving nmc and predictive functionality from robust to observational * format and lint * docstring ref * restore nmc * inline bind_leftmost_dim in nmc * format * revert format * revert format * Update utils.py --------- Co-authored-by: Sam Witty --- chirho/indexed/handlers.py | 24 +++ .../handlers/predictive.py | 93 ++++++++++- chirho/observational/internals.py | 89 ++++++++++ chirho/robust/internals/linearize.py | 2 +- chirho/robust/internals/nmc.py | 117 +++++--------- chirho/robust/internals/utils.py | 152 +----------------- chirho/robust/ops.py | 2 +- tests/robust/test_handlers/test_MC_EIF.py | 5 +- tests/robust/test_handlers/test_estimators.py | 5 +- tests/robust/test_internals_compositions.py | 12 +- tests/robust/test_internals_linearize.py | 2 +- tests/robust/test_performance.py | 6 +- 12 files changed, 261 insertions(+), 248 deletions(-) rename chirho/{robust => observational}/handlers/predictive.py (58%) diff --git a/chirho/indexed/handlers.py b/chirho/indexed/handlers.py index decfcdfec..2ab71a34c 100644 --- a/chirho/indexed/handlers.py +++ b/chirho/indexed/handlers.py @@ -4,6 +4,7 @@ import pyro import torch +from typing_extensions import ParamSpec from chirho.indexed.internals import ( _LazyPlateMessenger, @@ -12,6 +13,8 @@ ) from chirho.indexed.ops import union +P = ParamSpec("P") + class IndexPlatesMessenger(pyro.poutine.messenger.Messenger): plates: Dict[Hashable, pyro.poutine.indep_messenger.IndepMessenger] @@ -137,3 +140,24 @@ def _pyro_sample(self, msg: Dict[str, Any]) -> None: msg["fn"] = msg["fn"].expand( torch.broadcast_shapes(msg["fn"].batch_shape, mask.shape) ) + + +@pyro.poutine.block() +@pyro.validation_enabled(False) +@torch.no_grad() +def guess_max_plate_nesting( + model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs +) -> int: + """ + Guesses the maximum plate nesting level by running `pyro.infer.Trace_ELBO` + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. 
+ :type guide: Callable[P, Any] + :return: maximum plate nesting level + :rtype: int + """ + elbo = pyro.infer.Trace_ELBO() + elbo._guess_max_plate_nesting(model, guide, args, kwargs) + return elbo.max_plate_nesting diff --git a/chirho/robust/handlers/predictive.py b/chirho/observational/handlers/predictive.py similarity index 58% rename from chirho/robust/handlers/predictive.py rename to chirho/observational/handlers/predictive.py index f73c38bfd..7fd03383b 100644 --- a/chirho/robust/handlers/predictive.py +++ b/chirho/observational/handlers/predictive.py @@ -1,18 +1,101 @@ -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, Mapping, Optional, TypeVar import pyro import torch from typing_extensions import ParamSpec from chirho.indexed.handlers import IndexPlatesMessenger -from chirho.robust.internals.nmc import BatchedLatents -from chirho.robust.internals.utils import bind_leftmost_dim -from chirho.robust.ops import Point +from chirho.indexed.ops import indices_of +from chirho.observational.handlers.condition import Observations +from chirho.observational.internals import ( + bind_leftmost_dim, + site_is_delta, + unbind_leftmost_dim, +) +from chirho.observational.ops import Observation P = ParamSpec("P") S = TypeVar("S") T = TypeVar("T") +Point = Mapping[str, Observation[T]] + + +class BatchedLatents(pyro.poutine.messenger.Messenger): + """ + Effect handler that adds a fresh batch dimension to all latent ``sample`` sites. + Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery + in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension + based on the ``name`` argument to ``BatchedLatents`` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param int num_particles: Number of particles to use for parallelization. + :param str name: Name of the fresh batch dimension. + """ + + num_particles: int + name: str + + def __init__(self, num_particles: int, *, name: str = "__particles_mc"): + assert num_particles > 0 + assert len(name) > 0 + self.num_particles = num_particles + self.name = name + super().__init__() + + def _pyro_sample(self, msg: dict) -> None: + if ( + self.num_particles > 1 + and msg["value"] is None + and not pyro.poutine.util.site_is_factor(msg) + and not pyro.poutine.util.site_is_subsample(msg) + and not site_is_delta(msg) + and self.name not in indices_of(msg["fn"]) + ): + msg["fn"] = unbind_leftmost_dim( + msg["fn"].expand((1,) + msg["fn"].batch_shape), + self.name, + size=self.num_particles, + ) + + +class BatchedObservations(Generic[T], Observations[T]): + """ + Effect handler that takes a dictionary of observation values for ``sample`` sites + that are assumed to be batched along their leftmost dimension, adds a fresh named + dimension using the machinery in ``chirho.indexed``, and reshapes the observation + values so that the new ``chirho.observational.observe`` sites are batched along + the fresh named dimension. + + Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary + of values whose leftmost dimension is a batch dimension over independent samples. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param Point[T] data: Dictionary of observation values. + :param str name: Name of the fresh batch dimension. 
+ """ + + name: str + + def __init__(self, data: Point[T], *, name: str = "__particles_data"): + assert len(name) > 0 + self.name = name + super().__init__(data) + + def _pyro_observe(self, msg: dict) -> None: + super()._pyro_observe(msg) + if msg["kwargs"]["name"] in self.data: + rv, obs = msg["args"] + event_dim = ( + len(rv.event_shape) + if hasattr(rv, "event_shape") + else msg["kwargs"].get("event_dim", 0) + ) + batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim) + msg["args"] = (rv, batch_obs) + class PredictiveModel(Generic[P, T], torch.nn.Module): """ @@ -76,7 +159,7 @@ class PredictiveFunctional(Generic[P, T], torch.nn.Module): the returned values are batched along their leftmost positional dimension. Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` - when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct + when :class:`~chirho.observational.handlers.predictive.PredictiveModel` is used to construct the ``model`` argument and infer the ``sample`` sites whose values should be returned, and uses :class:`~BatchedLatents` to parallelize over samples from the model. diff --git a/chirho/observational/internals.py b/chirho/observational/internals.py index d61a13b71..bf483ec5f 100644 --- a/chirho/observational/internals.py +++ b/chirho/observational/internals.py @@ -1,13 +1,18 @@ from __future__ import annotations +import functools from typing import Mapping, Optional, TypeVar import pyro import pyro.distributions import torch +from typing_extensions import ParamSpec +from chirho.indexed.handlers import add_indices +from chirho.indexed.ops import IndexSet, get_index_plates, indices_of from chirho.observational.ops import AtomicObservation, observe +P = ParamSpec("P") K = TypeVar("K") T = TypeVar("T") @@ -67,3 +72,87 @@ class ObserveNameMessenger(pyro.poutine.messenger.Messenger): def _pyro_observe(self, msg): if "name" not in msg["kwargs"]: msg["kwargs"]["name"] = msg["name"] + + +def site_is_delta(msg: dict) -> bool: + d = msg["fn"] + while hasattr(d, "base_dist"): + d = d.base_dist + return isinstance(d, pyro.distributions.Delta) + + +@functools.singledispatch +def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs): + """ + Helper function to move the leftmost dimension of a ``torch.Tensor`` + or ``pyro.distributions.Distribution`` or other batched value + into a fresh named dimension using the machinery in ``chirho.indexed`` , + allocating a new dimension with the given name if necessary + via an enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param v: Batched value. + :param name: Name of the fresh dimension. + :param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` . 
+ """ + raise NotImplementedError + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_tensor( + v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0 +) -> torch.Tensor: + size = max(size, v.shape[0]) + v = v.expand((size,) + v.shape[1:]) + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.shape + while new_dim - event_dim < -len(v.shape): + v = v[None] + if v.shape[0] == 1 and orig_shape[0] != 1: + v = torch.transpose(v, -len(orig_shape), new_dim - event_dim) + return v + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_distribution( + v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs +) -> pyro.distributions.Distribution: + size = max(size, v.batch_shape[0]) + if v.batch_shape[0] != 1: + raise NotImplementedError("Cannot freely reshape distribution") + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.batch_shape + + new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:] + return v.expand(new_shape) + + +@functools.singledispatch +def bind_leftmost_dim(v, name: str, **kwargs): + """ + Helper function to move a named dimension managed by ``chirho.indexed`` + into a new unnamed dimension to the left of all named dimensions in the value. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + """ + raise NotImplementedError + + +@bind_leftmost_dim.register +def _bind_leftmost_dim_tensor( + v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs +) -> torch.Tensor: + if name not in indices_of(v, event_dim=event_dim): + return v + return torch.transpose( + v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim + ) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 29447c736..4ea1d51f0 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -276,7 +276,7 @@ def linearize( import pyro.distributions as dist import torch - from chirho.robust.handlers.predictive import PredictiveModel + from chirho.observational.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import linearize pyro.settings.set(module_local_params=True) diff --git a/chirho/robust/internals/nmc.py b/chirho/robust/internals/nmc.py index d12468432..1b8b93e76 100644 --- a/chirho/robust/internals/nmc.py +++ b/chirho/robust/internals/nmc.py @@ -1,21 +1,15 @@ import collections import math import typing -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar import pyro import torch from typing_extensions import ParamSpec from chirho.indexed.handlers import IndexPlatesMessenger -from chirho.indexed.ops import get_index_plates, indices_of -from chirho.observational.handlers.condition import Observations -from chirho.robust.internals.utils import ( - bind_leftmost_dim, - get_importance_traces, - site_is_delta, - unbind_leftmost_dim, -) +from chirho.indexed.ops import get_index_plates +from chirho.observational.handlers.predictive import BatchedLatents, BatchedObservations from chirho.robust.ops import Point pyro.settings.set(module_local_params=True) @@ -26,80 +20,43 @@ T = TypeVar("T") -class BatchedLatents(pyro.poutine.messenger.Messenger): +def get_importance_traces( + model: Callable[P, Any], + guide: 
Optional[Callable[P, Any]] = None, +) -> Callable[P, Tuple[pyro.poutine.Trace, pyro.poutine.Trace]]: """ - Effect handler that adds a fresh batch dimension to all latent ``sample`` sites. - Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery - in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension - based on the ``name`` argument to ``BatchedLatents`` . - - .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - - :param int num_particles: Number of particles to use for parallelization. - :param str name: Name of the fresh batch dimension. + Thin functional wrapper around :func:`~pyro.infer.enum.get_importance_trace` + that cleans up the original interface to avoid unnecessary arguments + and efficiently supports using the prior in a model as a default guide. + + :param model: Model to run. + :param guide: Guide to run. If ``None``, use the prior in ``model`` as a guide. + :returns: A function that takes the same arguments as ``model`` and ``guide`` and returns + a tuple of importance traces ``(model_trace, guide_trace)``. """ - num_particles: int - name: str - - def __init__(self, num_particles: int, *, name: str = "__particles_mc"): - assert num_particles > 0 - assert len(name) > 0 - self.num_particles = num_particles - self.name = name - super().__init__() - - def _pyro_sample(self, msg: dict) -> None: - if ( - self.num_particles > 1 - and msg["value"] is None - and not pyro.poutine.util.site_is_factor(msg) - and not pyro.poutine.util.site_is_subsample(msg) - and not site_is_delta(msg) - and self.name not in indices_of(msg["fn"]) - ): - msg["fn"] = unbind_leftmost_dim( - msg["fn"].expand((1,) + msg["fn"].batch_shape), - self.name, - size=self.num_particles, + def _fn( + *args: P.args, **kwargs: P.kwargs + ) -> Tuple[pyro.poutine.Trace, pyro.poutine.Trace]: + if guide is not None: + model_trace, guide_trace = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, guide, args, kwargs + ) + return model_trace, guide_trace + else: # use prior as default guide, but don't run model twice + model_trace, _ = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, lambda *_, **__: None, args, kwargs ) + guide_trace = model_trace.copy() + for name, node in list(guide_trace.nodes.items()): + if node["type"] != "sample": + del model_trace.nodes[name] + elif pyro.poutine.util.site_is_factor(node) or node["is_observed"]: + del guide_trace.nodes[name] + return model_trace, guide_trace -class BatchedObservations(Generic[T], Observations[T]): - """ - Effect handler that takes a dictionary of observation values for ``sample`` sites - that are assumed to be batched along their leftmost dimension, adds a fresh named - dimension using the machinery in ``chirho.indexed``, and reshapes the observation - values so that the new ``chirho.observational.observe`` sites are batched along - the fresh named dimension. - - Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary - of values whose leftmost dimension is a batch dimension over independent samples. - - .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - - :param Point[T] data: Dictionary of observation values. - :param str name: Name of the fresh batch dimension. 
- """ - - name: str - - def __init__(self, data: Point[T], *, name: str = "__particles_data"): - assert len(name) > 0 - self.name = name - super().__init__(data) - - def _pyro_observe(self, msg: dict) -> None: - super()._pyro_observe(msg) - if msg["kwargs"]["name"] in self.data: - rv, obs = msg["args"] - event_dim = ( - len(rv.event_shape) - if hasattr(rv, "event_shape") - else msg["kwargs"].get("event_dim", 0) - ) - batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim) - msg["args"] = (rv, batch_obs) + return _fn class BatchedNMCLogMarginalLikelihood(Generic[P, T], torch.nn.Module): @@ -205,7 +162,11 @@ def forward( # move data plate dimension to the left for name in reversed(plate_name_to_dim.keys()): - log_weights = bind_leftmost_dim(log_weights, name) + log_weights = torch.transpose( + log_weights[None], + -len(log_weights.shape) - 1, + plate_name_to_dim[name].dim, + ) # pack log_weights by squeezing out rightmost dimensions for _ in range(len(log_weights.shape) - len(plate_name_to_dim)): diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index e8934fb7d..39c40d0fc 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,8 +1,6 @@ import contextlib -import functools -import math from math import prod -from typing import Any, Callable, List, Mapping, Optional, Tuple, TypeVar +from typing import Callable, List, Mapping, Optional, Tuple, TypeVar import pyro import torch @@ -16,9 +14,6 @@ ) from typing_extensions import Concatenate, ParamSpec -from chirho.indexed.handlers import add_indices -from chirho.indexed.ops import IndexSet, get_index_plates, indices_of - P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") @@ -153,7 +148,6 @@ def recurse_to_flattened_sub_tspec( for flat_jac_output_subtree in recurse_to_flattened_sub_tspec( pytree=jac, sub_tspec=param_tspec ): - flat_sub_out: List[torch.Tensor] = [] # Then map that subtree (with tree structure matching that of params) onto the params and batched_vector. @@ -224,27 +218,6 @@ def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: return param_dict, mod_func -@pyro.poutine.block() -@pyro.validation_enabled(False) -@torch.no_grad() -def guess_max_plate_nesting( - model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs -) -> int: - """ - Guesses the maximum plate nesting level by running `pyro.infer.Trace_ELBO` - - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] - :param guide: Python callable containing Pyro primitives. - :type guide: Callable[P, Any] - :return: maximum plate nesting level - :rtype: int - """ - elbo = pyro.infer.Trace_ELBO() - elbo._guess_max_plate_nesting(model, guide, args, kwargs) - return elbo.max_plate_nesting - - @contextlib.contextmanager def reset_rng_state(rng_state: T): """ @@ -255,126 +228,3 @@ def reset_rng_state(rng_state: T): yield pyro.util.set_rng_state(rng_state) finally: pyro.util.set_rng_state(prev_rng_state) - - -@functools.singledispatch -def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs): - """ - Helper function to move the leftmost dimension of a ``torch.Tensor`` - or ``pyro.distributions.Distribution`` or other batched value - into a fresh named dimension using the machinery in ``chirho.indexed`` , - allocating a new dimension with the given name if necessary - via an enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - - .. 
warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - - :param v: Batched value. - :param name: Name of the fresh dimension. - :param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` . - """ - raise NotImplementedError - - -@unbind_leftmost_dim.register -def _unbind_leftmost_dim_tensor( - v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0 -) -> torch.Tensor: - size = max(size, v.shape[0]) - v = v.expand((size,) + v.shape[1:]) - - if name not in get_index_plates(): - add_indices(IndexSet(**{name: set(range(size))})) - - new_dim: int = get_index_plates()[name].dim - orig_shape = v.shape - while new_dim - event_dim < -len(v.shape): - v = v[None] - if v.shape[0] == 1 and orig_shape[0] != 1: - v = torch.transpose(v, -len(orig_shape), new_dim - event_dim) - return v - - -@unbind_leftmost_dim.register -def _unbind_leftmost_dim_distribution( - v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs -) -> pyro.distributions.Distribution: - size = max(size, v.batch_shape[0]) - if v.batch_shape[0] != 1: - raise NotImplementedError("Cannot freely reshape distribution") - - if name not in get_index_plates(): - add_indices(IndexSet(**{name: set(range(size))})) - - new_dim: int = get_index_plates()[name].dim - orig_shape = v.batch_shape - - new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:] - return v.expand(new_shape) - - -@functools.singledispatch -def bind_leftmost_dim(v, name: str, **kwargs): - """ - Helper function to move a named dimension managed by ``chirho.indexed`` - into a new unnamed dimension to the left of all named dimensions in the value. - - .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - """ - raise NotImplementedError - - -@bind_leftmost_dim.register -def _bind_leftmost_dim_tensor( - v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs -) -> torch.Tensor: - if name not in indices_of(v, event_dim=event_dim): - return v - return torch.transpose( - v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim - ) - - -def get_importance_traces( - model: Callable[P, Any], - guide: Optional[Callable[P, Any]] = None, -) -> Callable[P, Tuple[pyro.poutine.Trace, pyro.poutine.Trace]]: - """ - Thin functional wrapper around :func:`~pyro.infer.enum.get_importance_trace` - that cleans up the original interface to avoid unnecessary arguments - and efficiently supports using the prior in a model as a default guide. - - :param model: Model to run. - :param guide: Guide to run. If ``None``, use the prior in ``model`` as a guide. - :returns: A function that takes the same arguments as ``model`` and ``guide`` and returns - a tuple of importance traces ``(model_trace, guide_trace)``. 
- """ - - def _fn( - *args: P.args, **kwargs: P.kwargs - ) -> Tuple[pyro.poutine.Trace, pyro.poutine.Trace]: - if guide is not None: - model_trace, guide_trace = pyro.infer.enum.get_importance_trace( - "flat", math.inf, model, guide, args, kwargs - ) - return model_trace, guide_trace - else: # use prior as default guide, but don't run model twice - model_trace, _ = pyro.infer.enum.get_importance_trace( - "flat", math.inf, model, lambda *_, **__: None, args, kwargs - ) - - guide_trace = model_trace.copy() - for name, node in list(guide_trace.nodes.items()): - if node["type"] != "sample": - del model_trace.nodes[name] - elif pyro.poutine.util.site_is_factor(node) or node["is_observed"]: - del guide_trace.nodes[name] - return model_trace, guide_trace - - return _fn - - -def site_is_delta(msg: dict) -> bool: - d = msg["fn"] - while hasattr(d, "base_dist"): - d = d.base_dist - return isinstance(d, pyro.distributions.Delta) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 6416b07fb..6b82243c1 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -41,7 +41,7 @@ def influence_fn( import pyro.distributions as dist import torch - from chirho.robust.handlers.predictive import PredictiveModel + from chirho.observational.handlers.predictive import PredictiveModel from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator from chirho.robust.ops import influence_fn diff --git a/tests/robust/test_handlers/test_MC_EIF.py b/tests/robust/test_handlers/test_MC_EIF.py index e550df1cd..e8a05b771 100644 --- a/tests/robust/test_handlers/test_MC_EIF.py +++ b/tests/robust/test_handlers/test_MC_EIF.py @@ -7,8 +7,11 @@ import torch from typing_extensions import ParamSpec +from chirho.observational.handlers.predictive import ( + PredictiveFunctional, + PredictiveModel, +) from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator -from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from chirho.robust.ops import influence_fn from ..robust_fixtures import SimpleGuide, SimpleModel diff --git a/tests/robust/test_handlers/test_estimators.py b/tests/robust/test_handlers/test_estimators.py index 8d20949d6..9ed671799 100644 --- a/tests/robust/test_handlers/test_estimators.py +++ b/tests/robust/test_handlers/test_estimators.py @@ -6,11 +6,14 @@ import torch from typing_extensions import ParamSpec +from chirho.observational.handlers.predictive import ( + PredictiveFunctional, + PredictiveModel, +) from chirho.robust.handlers.estimators import ( MonteCarloInfluenceEstimator, one_step_corrected_estimator, ) -from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from ..robust_fixtures import SimpleGuide, SimpleModel diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index ade682726..32c957d5a 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -7,16 +7,16 @@ from chirho.indexed.handlers import IndexPlatesMessenger from chirho.indexed.ops import indices_of -from chirho.robust.handlers.predictive import PredictiveModel +from chirho.observational.handlers.predictive import ( + BatchedLatents, + BatchedObservations, + PredictiveModel, +) from chirho.robust.internals.linearize import ( conjugate_gradient_solve, make_empirical_fisher_vp, ) -from chirho.robust.internals.nmc import ( - BatchedLatents, - BatchedNMCLogMarginalLikelihood, - BatchedObservations, -) +from chirho.robust.internals.nmc import 
BatchedNMCLogMarginalLikelihood from chirho.robust.internals.utils import make_functional_call, reset_rng_state from .robust_fixtures import SimpleGuide, SimpleModel diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 5a887efe0..145ac9f79 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -8,7 +8,7 @@ from pyro.infer.predictive import Predictive from typing_extensions import ParamSpec -from chirho.robust.handlers.predictive import PredictiveModel +from chirho.observational.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import ( conjugate_gradient_solve, linearize, diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py index 34d5e4d02..ff04ca44a 100644 --- a/tests/robust/test_performance.py +++ b/tests/robust/test_performance.py @@ -9,12 +9,12 @@ import torch from typing_extensions import ParamSpec -from chirho.indexed.handlers import DependentMaskMessenger +from chirho.indexed.handlers import DependentMaskMessenger, guess_max_plate_nesting from chirho.observational.handlers import condition -from chirho.robust.handlers.predictive import PredictiveModel +from chirho.observational.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import make_empirical_fisher_vp from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood -from chirho.robust.internals.utils import guess_max_plate_nesting, make_functional_call +from chirho.robust.internals.utils import make_functional_call from chirho.robust.ops import Point from .robust_fixtures import SimpleGuide, SimpleModel
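
The ``guess_max_plate_nesting`` helper that this patch moves into chirho/indexed/handlers.py keeps the signature shown in the hunk above. Below is a minimal usage sketch; the toy model, guide, and plate are made up for illustration and are not part of the patch:

import pyro
import pyro.distributions as dist
import torch

from chirho.indexed.handlers import guess_max_plate_nesting


def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", 10):
        pyro.sample("obs", dist.Normal(loc, 1.0))


def guide():
    loc_loc = pyro.param("loc_loc", torch.tensor(0.0))
    pyro.sample("loc", dist.Delta(loc_loc))


# Runs Trace_ELBO's plate-nesting heuristic with gradients and validation
# disabled, as in the moved implementation; for this toy pair it returns 1.
max_plate_nesting = guess_max_plate_nesting(model, guide)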
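
As their docstrings above state, ``BatchedLatents`` and ``BatchedObservations`` (now exported from chirho/observational/handlers/predictive.py) must run inside an ``IndexPlatesMessenger``. The sketch below shows one way they might be composed; the toy model, the data, and the ``first_available_dim`` value are assumptions made for illustration, not something specified by this patch:

import pyro
import pyro.distributions as dist
import torch

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.handlers.predictive import (
    BatchedLatents,
    BatchedObservations,
)


def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))


# Observations are assumed to be batched along their leftmost dimension,
# e.g. seven independent draws of "y".
data = {"y": torch.randn(7)}

with IndexPlatesMessenger(first_available_dim=-3):
    # Fresh named particle dimension for latent sample sites.
    with BatchedLatents(num_particles=5, name="__particles_mc"):
        # Fresh named dimension for the batched observations.
        with BatchedObservations(data, name="__particles_data"):
            model()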
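
The relocated ``unbind_leftmost_dim`` and ``bind_leftmost_dim`` helpers in chirho/observational/internals.py are ``functools.singledispatch`` functions over tensors and distributions. A rough round-trip sketch on a tensor follows; the plate name and shapes are arbitrary, and the exact output shapes depend on which dimension the enclosing ``IndexPlatesMessenger`` allocates:

import torch

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.internals import bind_leftmost_dim, unbind_leftmost_dim

x = torch.randn(5, 3)  # leftmost dimension holds 5 particles

with IndexPlatesMessenger(first_available_dim=-3):
    # Move the leftmost positional dimension into a fresh named plate "particles".
    x_named = unbind_leftmost_dim(x, "particles", size=5)
    # Move the named dimension back out to an unnamed leftmost batch dimension.
    x_unnamed = bind_leftmost_dim(x_named, "particles")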
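
``get_importance_traces``, which this patch keeps in chirho/robust/internals/nmc.py, wraps ``pyro.infer.enum.get_importance_trace`` and falls back to the model's prior when no guide is passed, as the new docstring describes. A short sketch with a hypothetical model:

import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.nmc import get_importance_traces


def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.5))


get_traces = get_importance_traces(model)  # prior acts as the default guide
model_trace, guide_trace = get_traces()

# guide_trace keeps only the unobserved latent sites (here "x"), while
# model_trace keeps every sample site, so importance log-weights can be
# computed as model log-prob minus guide log-prob.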