Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use functional interface for dynamical systems #341

Merged
merged 5 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions chirho/dynamical/internals/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,32 @@
simulate_point,
simulate_trajectory,
)
from chirho.dynamical.ops import InPlaceDynamics, State, get_keys
from chirho.dynamical.ops import Dynamics, State, get_keys
from chirho.indexed.ops import IndexSet, gather, get_index_plates

S = TypeVar("S")
T = TypeVar("T")


def _deriv(
dynamics: InPlaceDynamics[torch.Tensor],
dynamics: Dynamics[torch.Tensor],
var_order: Tuple[str, ...],
time: torch.Tensor,
state: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
ddt: State[torch.Tensor] = State()
env: State[torch.Tensor] = State()
for var, value in zip(var_order, state):
setattr(env, var, value)

assert "t" not in get_keys(env), "variable name t is reserved for time"
env.t = time

dynamics.diff(ddt, env)
ddt: State[torch.Tensor] = dynamics(env)
return tuple(getattr(ddt, var, torch.tensor(0.0)) for var in var_order)


def _torchdiffeq_ode_simulate_inner(
dynamics: InPlaceDynamics[torch.Tensor],
dynamics: Dynamics[torch.Tensor],
initial_state: State[torch.Tensor],
timespan,
**odeint_kwargs,
Expand Down Expand Up @@ -108,7 +107,7 @@ def _batched_odeint(
@simulate_point.register(TorchDiffEq)
def torchdiffeq_ode_simulate(
solver: TorchDiffEq,
dynamics: InPlaceDynamics[torch.Tensor],
dynamics: Dynamics[torch.Tensor],
initial_state: State[torch.Tensor],
start_time: torch.Tensor,
end_time: torch.Tensor,
Expand All @@ -132,7 +131,7 @@ def torchdiffeq_ode_simulate(
@simulate_trajectory.register(TorchDiffEq)
def torchdiffeq_ode_simulate_trajectory(
solver: TorchDiffEq,
dynamics: InPlaceDynamics[torch.Tensor],
dynamics: Dynamics[torch.Tensor],
initial_state: State[torch.Tensor],
timespan: torch.Tensor,
) -> State[torch.Tensor]:
Expand All @@ -144,7 +143,7 @@ def torchdiffeq_ode_simulate_trajectory(
@get_next_interruptions_dynamic.register(TorchDiffEq)
def torchdiffeq_get_next_interruptions_dynamic(
solver: TorchDiffEq,
dynamics: InPlaceDynamics[torch.Tensor],
dynamics: Dynamics[torch.Tensor],
start_state: State[torch.Tensor],
start_time: torch.Tensor,
next_static_interruption: StaticInterruption,
Expand Down
16 changes: 8 additions & 8 deletions chirho/dynamical/internals/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyro
import torch

from chirho.dynamical.ops import InPlaceDynamics, State, simulate
from chirho.dynamical.ops import Dynamics, State, simulate

if typing.TYPE_CHECKING:
from chirho.dynamical.handlers.interruption import (
Expand Down Expand Up @@ -42,7 +42,7 @@ def get_solver() -> Solver:
@functools.singledispatch
def simulate_point(
solver: Solver,
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
initial_state: State[T],
start_time: R,
end_time: R,
Expand All @@ -59,7 +59,7 @@ def simulate_point(
@functools.singledispatch
def simulate_trajectory(
solver: Solver,
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
initial_state: State[T],
timespan: R,
**kwargs,
Expand All @@ -76,7 +76,7 @@ def simulate_trajectory(
@pyro.poutine.runtime.effectful(type="simulate_to_interruption")
def simulate_to_interruption(
solver: Solver,
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
start_state: State[T],
start_time: R,
end_time: R,
Expand Down Expand Up @@ -115,8 +115,8 @@ def simulate_to_interruption(

@pyro.poutine.runtime.effectful(type="apply_interruptions")
def apply_interruptions(
dynamics: InPlaceDynamics[T], start_state: State[T]
) -> Tuple[InPlaceDynamics[T], State[T]]:
dynamics: Dynamics[T], start_state: State[T]
) -> Tuple[Dynamics[T], State[T]]:
"""
Apply the effects of an interruption to a dynamical system.
"""
Expand All @@ -126,7 +126,7 @@ def apply_interruptions(

def get_next_interruptions(
solver: Solver,
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
start_state: State[T],
start_time: R,
end_time: R,
Expand Down Expand Up @@ -162,7 +162,7 @@ def get_next_interruptions(
@functools.singledispatch
def get_next_interruptions_dynamic(
solver: Solver,
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
start_state: State[T],
start_time: R,
next_static_interruption: StaticInterruption,
Expand Down
9 changes: 3 additions & 6 deletions chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numbers
import typing
from typing import FrozenSet, Generic, Optional, Protocol, TypeVar, Union
from typing import Callable, FrozenSet, Generic, Optional, TypeVar, Union

import pyro
import torch
Expand Down Expand Up @@ -36,15 +36,12 @@ def get_keys(state: State[T]) -> FrozenSet[str]:
return frozenset(state.__dict__["_values"].keys())


@typing.runtime_checkable
class InPlaceDynamics(Protocol[S]):
def diff(self, __dstate: State[S], __state: State[S]) -> None:
...
Dynamics = Callable[[State[T]], State[T]]


@pyro.poutine.runtime.effectful(type="simulate")
def simulate(
dynamics: InPlaceDynamics[T],
dynamics: Dynamics[T],
initial_state: State[T],
start_time: R,
end_time: R,
Expand Down
9 changes: 7 additions & 2 deletions tests/dynamical/dynamical_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

from chirho.dynamical.ops import State, get_keys

pyro.settings.set(module_local_params=True)

T = TypeVar("T")


class UnifiedFixtureDynamics:
class UnifiedFixtureDynamics(pyro.nn.PyroModule):
def __init__(self, beta=None, gamma=None):
super().__init__()

Expand All @@ -21,21 +23,24 @@ def __init__(self, beta=None, gamma=None):
if self.gamma is None:
self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)

def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
def forward(self, X: State[torch.Tensor]):
dX: State[torch.Tensor] = State()
beta = self.beta * (
1.0 + 0.1 * torch.sin(0.1 * X.t)
) # beta oscilates slowly in time.

dX.S = -beta * X.S * X.I
dX.I = beta * X.S * X.I - self.gamma * X.I # noqa
dX.R = self.gamma * X.I
return dX

def _unit_measurement_error(self, name: str, x: torch.Tensor):
if x.ndim == 0:
return pyro.sample(name, Normal(x, 1))
else:
return pyro.sample(name, Normal(x, 1).to_event(1))

@pyro.nn.pyro_method
def observation(self, X: State[torch.Tensor]):
self._unit_measurement_error("S_obs", X.S)
self._unit_measurement_error("I_obs", X.I)
Expand Down
20 changes: 9 additions & 11 deletions tests/dynamical/test_dynamic_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
LogTrajectory,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import InPlaceDynamics, State, get_keys, simulate
from chirho.dynamical.ops import State, get_keys, simulate
from chirho.indexed.ops import IndexSet, gather, indices_of, union

from .dynamical_fixtures import UnifiedFixtureDynamics
Expand Down Expand Up @@ -403,16 +403,14 @@ def test_split_twinworld_dynamic_matches_output(


def test_grad_of_dynamic_intervention_event_f_params():
class Model(InPlaceDynamics):
def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
dX.x = tt(1.0)
dX.z = X.dz
dX.dz = tt(0.0) # also a constant, this gets set by interventions.
dX.param = tt(
0.0
) # this is a constant event function parameter, so no change.

model = Model()
def model(X: State[torch.Tensor]):
dX = State()
dX.x = tt(1.0)
dX.z = X.dz
dX.dz = tt(0.0) # also a constant, this gets set by interventions.
dX.param = tt(0.0) # this is a constant event function parameter, so no change.
return dX

param = torch.nn.Parameter(tt(5.0))
# Param has to be part of the state in order to take gradients with respect to it.
s0 = State(x=tt(0.0), z=tt(0.0), dz=tt(0.0), param=param)
Expand Down
4 changes: 2 additions & 2 deletions tests/dynamical/test_static_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ class RandBetaUnifiedFixtureDynamics(UnifiedFixtureDynamics):
def beta(self):
return pyro.distributions.Beta(1, 1)

def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
super().diff(dX, X)
def forward(self, X: State[torch.Tensor]):
assert torch.allclose(self.beta, self.beta)
return super().forward(X)

model = RandBetaUnifiedFixtureDynamics()

Expand Down
Loading