diff --git a/chirho/dynamical/handlers/event_loop.py b/chirho/dynamical/handlers/event_loop.py index d87845ce2..c698b5507 100644 --- a/chirho/dynamical/handlers/event_loop.py +++ b/chirho/dynamical/handlers/event_loop.py @@ -7,6 +7,7 @@ from chirho.dynamical.handlers.interruption import Interruption from chirho.dynamical.internals.solver import ( apply_interruptions, + get_next_interruptions, get_solver, simulate_to_interruption, ) @@ -29,13 +30,18 @@ def _pyro_simulate(self, msg) -> None: with pyro.poutine.messenger.block_messengers( lambda m: m is self or (isinstance(m, Interruption) and m.used) ): - state, terminal_interruptions, start_time = simulate_to_interruption( + terminal_interruptions, interruption_time = get_next_interruptions( + solver, dynamics, state, start_time, end_time + ) + + state = simulate_to_interruption( solver, dynamics, state, start_time, - end_time, + interruption_time, ) + start_time = interruption_time for h in terminal_interruptions: h.used = True @@ -46,6 +52,4 @@ def _pyro_simulate(self, msg) -> None: dynamics, state = apply_interruptions(dynamics, state) msg["value"] = state - msg["stop"] = True msg["done"] = True - msg["in_SEL"] = True diff --git a/chirho/dynamical/handlers/interruption.py b/chirho/dynamical/handlers/interruption.py index dd807510e..a467ca0e0 100644 --- a/chirho/dynamical/handlers/interruption.py +++ b/chirho/dynamical/handlers/interruption.py @@ -7,7 +7,6 @@ from chirho.dynamical.handlers.trajectory import LogTrajectory from chirho.dynamical.ops import State -from chirho.indexed.ops import get_index_plates, indices_of from chirho.interventional.ops import Intervention, intervene from chirho.observational.ops import Observation, observe @@ -23,7 +22,7 @@ def __enter__(self): self.used = False return super().__enter__() - def _pyro_simulate_to_interruption(self, msg) -> None: + def _pyro_get_next_interruptions(self, msg) -> None: raise NotImplementedError("shouldn't be here!") @@ -34,7 +33,7 @@ def __init__(self, time: R): self.time = torch.as_tensor(time) # TODO enforce this where it is needed super().__init__() - def _pyro_simulate_to_interruption(self, msg) -> None: + def _pyro_get_next_interruptions(self, msg) -> None: _, _, _, start_time, end_time = msg["args"] if start_time < self.time < end_time: @@ -67,7 +66,7 @@ def __init__(self, event_f: Callable[[R, State[T]], R]): self.event_f = event_f super().__init__() - def _pyro_simulate_to_interruption(self, msg) -> None: + def _pyro_get_next_interruptions(self, msg) -> None: msg["kwargs"].setdefault("dynamic_interruptions", []).append(self) @@ -157,21 +156,4 @@ def __init__( super().__init__(times, eps=eps) def _pyro_post_simulate(self, msg) -> None: - super()._pyro_post_simulate(msg) - - # This checks whether the simulate has already redirected in a InterruptionEventLoop. - # If so, we don't want to run the observation again. - if msg.setdefault("in_SEL", False): - return - - # TODO remove this redundant check by fixing semantics of LogTrajectory and simulate - name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()} - name_to_dim["__time"] = -1 - len_traj = ( - 0 - if len(self.trajectory.keys()) == 0 - else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"]) - ) - - if len_traj == len(self.times): - msg["value"] = observe(self.trajectory, self.observation) + self.trajectory = observe(self.trajectory, self.observation) diff --git a/chirho/dynamical/handlers/trajectory.py b/chirho/dynamical/handlers/trajectory.py index 888167b9a..cf3e0aab1 100644 --- a/chirho/dynamical/handlers/trajectory.py +++ b/chirho/dynamical/handlers/trajectory.py @@ -1,11 +1,10 @@ -import typing from typing import Generic, TypeVar import pyro import torch from chirho.dynamical.internals._utils import _squeeze_time_dim, append -from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory +from chirho.dynamical.internals.solver import simulate_trajectory from chirho.dynamical.ops import State from chirho.indexed.ops import IndexSet, gather, get_index_plates @@ -30,16 +29,12 @@ def __enter__(self) -> "LogTrajectory[T]": self.trajectory: State[T] = State() return super().__enter__() - def _pyro_simulate(self, msg) -> None: + def _pyro_simulate_to_interruption(self, msg) -> None: msg["done"] = True - def _pyro_post_simulate(self, msg) -> None: + def _pyro_post_simulate_to_interruption(self, msg) -> None: # Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times - dynamics, initial_state, start_time, end_time = msg["args"] - if msg["kwargs"].get("solver", None) is not None: - solver = typing.cast(Solver, msg["kwargs"]["solver"]) - else: - solver = get_solver() + solver, dynamics, initial_state, start_time, end_time = msg["args"] filtered_timespan = self.times[ (self.times >= start_time) & (self.times <= end_time) diff --git a/chirho/dynamical/internals/solver.py b/chirho/dynamical/internals/solver.py index 3efd03019..1649a3f15 100644 --- a/chirho/dynamical/internals/solver.py +++ b/chirho/dynamical/internals/solver.py @@ -8,7 +8,7 @@ import pyro import torch -from chirho.dynamical.ops import Dynamics, State, simulate +from chirho.dynamical.ops import Dynamics, State if typing.TYPE_CHECKING: from chirho.dynamical.handlers.interruption import ( @@ -80,37 +80,14 @@ def simulate_to_interruption( start_state: State[T], start_time: R, end_time: R, - *, - next_static_interruption: Optional[StaticInterruption] = None, - dynamic_interruptions: List[DynamicInterruption] = [], **kwargs, -) -> Tuple[State[T], Tuple[Interruption, ...], R]: +) -> State[T]: """ - Simulate a dynamical system until the next interruption. This will be either one of the passed - dynamic interruptions, the next static interruption, or the end time, whichever comes first. + Simulate a dynamical system until the next interruption. - :returns: the final state, a collection of interruptions that ended the simulation - (this will usually just be a single interruption), and the time the interruption occurred. + :returns: the final state """ - - interruptions, interruption_time = get_next_interruptions( - solver, - dynamics, - start_state, - start_time, - end_time, - next_static_interruption=next_static_interruption, - dynamic_interruptions=dynamic_interruptions, - **kwargs, - ) - # TODO: consider memoizing results of `get_next_interruptions` to avoid recomputing - # the solver in the dynamic setting. The interactions are a bit tricky here though, as we couldn't be in - # a LogTrajectory context. - event_state = simulate( - dynamics, start_state, start_time, interruption_time, solver=solver - ) - - return event_state, interruptions, interruption_time + return simulate_point(solver, dynamics, start_state, start_time, end_time, **kwargs) @pyro.poutine.runtime.effectful(type="apply_interruptions") @@ -124,6 +101,7 @@ def apply_interruptions( return dynamics, start_state +@pyro.poutine.runtime.effectful(type="get_next_interruptions") def get_next_interruptions( solver: Solver, dynamics: Dynamics[T], diff --git a/chirho/dynamical/ops.py b/chirho/dynamical/ops.py index f484e8368..2cbe85d1f 100644 --- a/chirho/dynamical/ops.py +++ b/chirho/dynamical/ops.py @@ -30,9 +30,13 @@ def simulate( """ Simulate a dynamical system. """ - from chirho.dynamical.internals.solver import Solver, get_solver, simulate_point + from chirho.dynamical.internals.solver import ( + Solver, + get_solver, + simulate_to_interruption, + ) solver_: Solver = get_solver() if solver is None else typing.cast(Solver, solver) - return simulate_point( + return simulate_to_interruption( solver_, dynamics, initial_state, start_time, end_time, **kwargs )