diff --git a/chirho/dynamical/handlers/trajectory.py b/chirho/dynamical/handlers/trajectory.py index ae7bae5cd..7775288d1 100644 --- a/chirho/dynamical/handlers/trajectory.py +++ b/chirho/dynamical/handlers/trajectory.py @@ -4,7 +4,7 @@ import pyro import torch -from chirho.dynamical.internals._utils import append +from chirho.dynamical.internals._utils import _trajectory_to_state, append from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory from chirho.dynamical.ops import Trajectory from chirho.indexed.ops import IndexSet, gather, get_index_plates @@ -67,4 +67,4 @@ def _pyro_post_simulate(self, msg) -> None: final_idx = IndexSet(**{idx_name: {len(timespan) - 1}}) final_state = gather(trajectory, final_idx, name_to_dim=name_to_dim) - msg["value"] = final_state.to_state() + msg["value"] = _trajectory_to_state(final_state) diff --git a/chirho/dynamical/internals/_utils.py b/chirho/dynamical/internals/_utils.py index da9842f01..47dd17411 100644 --- a/chirho/dynamical/internals/_utils.py +++ b/chirho/dynamical/internals/_utils.py @@ -80,3 +80,7 @@ def _append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor: @functools.lru_cache def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]: return tuple(sorted(varnames)) + + +def _trajectory_to_state(traj: Trajectory[T]) -> State[T]: + return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys}) diff --git a/chirho/dynamical/internals/backends/torchdiffeq.py b/chirho/dynamical/internals/backends/torchdiffeq.py index f9e9939d6..52f0b4be9 100644 --- a/chirho/dynamical/internals/backends/torchdiffeq.py +++ b/chirho/dynamical/internals/backends/torchdiffeq.py @@ -10,7 +10,7 @@ StaticInterruption, ) from chirho.dynamical.handlers.solver import TorchDiffEq -from chirho.dynamical.internals._utils import _var_order +from chirho.dynamical.internals._utils import _trajectory_to_state, _var_order from chirho.dynamical.internals.solver import ( get_next_interruptions_dynamic, simulate_point, @@ -126,7 +126,7 @@ def torchdiffeq_ode_simulate( final_idx = IndexSet(**{idx_name: {len(timespan) - 1}}) final_state_traj = gather(trajectory, final_idx, name_to_dim=name_to_dim) - final_state = final_state_traj.to_state() + final_state = _trajectory_to_state(final_state_traj) return final_state diff --git a/chirho/dynamical/ops.py b/chirho/dynamical/ops.py index f0dbbcf2a..d30c18965 100644 --- a/chirho/dynamical/ops.py +++ b/chirho/dynamical/ops.py @@ -38,18 +38,8 @@ def __getattr__(self, __name: str) -> T: raise AttributeError(f"{__name} not in {self.__dict__['_values']}") -class _Sliceable(Protocol[T_co]): - def squeeze(self, dim: int) -> "_Sliceable[T_co]": - ... - - -class Trajectory(Generic[T], State[_Sliceable[T]]): - def to_state(self) -> State[T]: - ret: State[T] = State( - # TODO support event_dim > 0 - **{k: getattr(self, k).squeeze(-1) for k in self.keys} - ) - return ret +class Trajectory(Generic[T], State[T]): + pass @typing.runtime_checkable