Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Trajectory.to_state method #326

Merged
merged 10 commits into from
Oct 16, 2023
4 changes: 2 additions & 2 deletions chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pyro
import torch

from chirho.dynamical.internals._utils import append
from chirho.dynamical.internals._utils import _trajectory_to_state, append
from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory
from chirho.dynamical.ops import Trajectory
from chirho.indexed.ops import IndexSet, gather, get_index_plates
Expand Down Expand Up @@ -67,4 +67,4 @@ def _pyro_post_simulate(self, msg) -> None:

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
final_state = gather(trajectory, final_idx, name_to_dim=name_to_dim)
msg["value"] = final_state.to_state()
msg["value"] = _trajectory_to_state(final_state)
4 changes: 4 additions & 0 deletions chirho/dynamical/internals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,7 @@ def _append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor:
@functools.lru_cache
def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]:
return tuple(sorted(varnames))


def _trajectory_to_state(traj: Trajectory[T]) -> State[T]:
return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys})
4 changes: 2 additions & 2 deletions chirho/dynamical/internals/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StaticInterruption,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.internals._utils import _var_order
from chirho.dynamical.internals._utils import _trajectory_to_state, _var_order
from chirho.dynamical.internals.solver import (
get_next_interruptions_dynamic,
simulate_point,
Expand Down Expand Up @@ -126,7 +126,7 @@ def torchdiffeq_ode_simulate(

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
final_state_traj = gather(trajectory, final_idx, name_to_dim=name_to_dim)
final_state = final_state_traj.to_state()
final_state = _trajectory_to_state(final_state_traj)
return final_state


Expand Down
14 changes: 2 additions & 12 deletions chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,8 @@ def __getattr__(self, __name: str) -> T:
raise AttributeError(f"{__name} not in {self.__dict__['_values']}")


class _Sliceable(Protocol[T_co]):
def squeeze(self, dim: int) -> "_Sliceable[T_co]":
...


class Trajectory(Generic[T], State[_Sliceable[T]]):
def to_state(self) -> State[T]:
ret: State[T] = State(
# TODO support event_dim > 0
**{k: getattr(self, k).squeeze(-1) for k in self.keys}
)
return ret
class Trajectory(Generic[T], State[T]):
pass


@typing.runtime_checkable
Expand Down
Loading