Skip to content

Commit

Permalink
Separate simulate from simulate_to_interruption (#356)
Browse files Browse the repository at this point in the history
* refactored simulate

* revert accidental change from rebase
  • Loading branch information
SamWitty authored Oct 20, 2023
1 parent db9801e commit 235d277
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 65 deletions.
12 changes: 8 additions & 4 deletions chirho/dynamical/handlers/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chirho.dynamical.handlers.interruption import Interruption
from chirho.dynamical.internals.solver import (
apply_interruptions,
get_next_interruptions,
get_solver,
simulate_to_interruption,
)
Expand All @@ -29,13 +30,18 @@ def _pyro_simulate(self, msg) -> None:
with pyro.poutine.messenger.block_messengers(
lambda m: m is self or (isinstance(m, Interruption) and m.used)
):
state, terminal_interruptions, start_time = simulate_to_interruption(
terminal_interruptions, interruption_time = get_next_interruptions(
solver, dynamics, state, start_time, end_time
)

state = simulate_to_interruption(
solver,
dynamics,
state,
start_time,
end_time,
interruption_time,
)
start_time = interruption_time
for h in terminal_interruptions:
h.used = True

Expand All @@ -46,6 +52,4 @@ def _pyro_simulate(self, msg) -> None:
dynamics, state = apply_interruptions(dynamics, state)

msg["value"] = state
msg["stop"] = True
msg["done"] = True
msg["in_SEL"] = True
26 changes: 4 additions & 22 deletions chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from chirho.dynamical.handlers.trajectory import LogTrajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import get_index_plates, indices_of
from chirho.interventional.ops import Intervention, intervene
from chirho.observational.ops import Observation, observe

Expand All @@ -23,7 +22,7 @@ def __enter__(self):
self.used = False
return super().__enter__()

def _pyro_simulate_to_interruption(self, msg) -> None:
def _pyro_get_next_interruptions(self, msg) -> None:
raise NotImplementedError("shouldn't be here!")


Expand All @@ -34,7 +33,7 @@ def __init__(self, time: R):
self.time = torch.as_tensor(time) # TODO enforce this where it is needed
super().__init__()

def _pyro_simulate_to_interruption(self, msg) -> None:
def _pyro_get_next_interruptions(self, msg) -> None:
_, _, _, start_time, end_time = msg["args"]

if start_time < self.time < end_time:
Expand Down Expand Up @@ -67,7 +66,7 @@ def __init__(self, event_f: Callable[[R, State[T]], R]):
self.event_f = event_f
super().__init__()

def _pyro_simulate_to_interruption(self, msg) -> None:
def _pyro_get_next_interruptions(self, msg) -> None:
msg["kwargs"].setdefault("dynamic_interruptions", []).append(self)


Expand Down Expand Up @@ -157,21 +156,4 @@ def __init__(
super().__init__(times, eps=eps)

def _pyro_post_simulate(self, msg) -> None:
super()._pyro_post_simulate(msg)

# This checks whether the simulate has already redirected in a InterruptionEventLoop.
# If so, we don't want to run the observation again.
if msg.setdefault("in_SEL", False):
return

# TODO remove this redundant check by fixing semantics of LogTrajectory and simulate
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
name_to_dim["__time"] = -1
len_traj = (
0
if len(self.trajectory.keys()) == 0
else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"])
)

if len_traj == len(self.times):
msg["value"] = observe(self.trajectory, self.observation)
self.trajectory = observe(self.trajectory, self.observation)
13 changes: 4 additions & 9 deletions chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import typing
from typing import Generic, TypeVar

import pyro
import torch

from chirho.dynamical.internals._utils import _squeeze_time_dim, append
from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory
from chirho.dynamical.internals.solver import simulate_trajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import IndexSet, gather, get_index_plates

Expand All @@ -30,16 +29,12 @@ def __enter__(self) -> "LogTrajectory[T]":
self.trajectory: State[T] = State()
return super().__enter__()

def _pyro_simulate(self, msg) -> None:
def _pyro_simulate_to_interruption(self, msg) -> None:
msg["done"] = True

def _pyro_post_simulate(self, msg) -> None:
def _pyro_post_simulate_to_interruption(self, msg) -> None:
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times
dynamics, initial_state, start_time, end_time = msg["args"]
if msg["kwargs"].get("solver", None) is not None:
solver = typing.cast(Solver, msg["kwargs"]["solver"])
else:
solver = get_solver()
solver, dynamics, initial_state, start_time, end_time = msg["args"]

filtered_timespan = self.times[
(self.times >= start_time) & (self.times <= end_time)
Expand Down
34 changes: 6 additions & 28 deletions chirho/dynamical/internals/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyro
import torch

from chirho.dynamical.ops import Dynamics, State, simulate
from chirho.dynamical.ops import Dynamics, State

if typing.TYPE_CHECKING:
from chirho.dynamical.handlers.interruption import (
Expand Down Expand Up @@ -80,37 +80,14 @@ def simulate_to_interruption(
start_state: State[T],
start_time: R,
end_time: R,
*,
next_static_interruption: Optional[StaticInterruption] = None,
dynamic_interruptions: List[DynamicInterruption] = [],
**kwargs,
) -> Tuple[State[T], Tuple[Interruption, ...], R]:
) -> State[T]:
"""
Simulate a dynamical system until the next interruption. This will be either one of the passed
dynamic interruptions, the next static interruption, or the end time, whichever comes first.
Simulate a dynamical system until the next interruption.
:returns: the final state, a collection of interruptions that ended the simulation
(this will usually just be a single interruption), and the time the interruption occurred.
:returns: the final state
"""

interruptions, interruption_time = get_next_interruptions(
solver,
dynamics,
start_state,
start_time,
end_time,
next_static_interruption=next_static_interruption,
dynamic_interruptions=dynamic_interruptions,
**kwargs,
)
# TODO: consider memoizing results of `get_next_interruptions` to avoid recomputing
# the solver in the dynamic setting. The interactions are a bit tricky here though, as we couldn't be in
# a LogTrajectory context.
event_state = simulate(
dynamics, start_state, start_time, interruption_time, solver=solver
)

return event_state, interruptions, interruption_time
return simulate_point(solver, dynamics, start_state, start_time, end_time, **kwargs)


@pyro.poutine.runtime.effectful(type="apply_interruptions")
Expand All @@ -124,6 +101,7 @@ def apply_interruptions(
return dynamics, start_state


@pyro.poutine.runtime.effectful(type="get_next_interruptions")
def get_next_interruptions(
solver: Solver,
dynamics: Dynamics[T],
Expand Down
8 changes: 6 additions & 2 deletions chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ def simulate(
"""
Simulate a dynamical system.
"""
from chirho.dynamical.internals.solver import Solver, get_solver, simulate_point
from chirho.dynamical.internals.solver import (
Solver,
get_solver,
simulate_to_interruption,
)

solver_: Solver = get_solver() if solver is None else typing.cast(Solver, solver)
return simulate_point(
return simulate_to_interruption(
solver_, dynamics, initial_state, start_time, end_time, **kwargs
)

0 comments on commit 235d277

Please sign in to comment.