Skip to content

Commit

Permalink
Update type annotations in hmm inference code.
Browse files Browse the repository at this point in the history
Major changes:
- Replace `jaxtyping.Int` with `dynamax.types.IntScalar` or `int`
  - this reflects when integer scalar arrays are accepted
  - `jaxtyping.[Dtype]` cannot be used directly for type checking;
    instead, it must be used as part of an array annotation.
- Fix the shape of `transition_matrix`:
  - if `transition_matrix` has a leading timestep axis, it should be of
    length T-1, not of length T.
- Add annotation indicating that `transition_matrix` is an optional argument
- Raise ValueError when neither `transition_matrix` nor `transition_fn`
  is provided.
  • Loading branch information
gileshd committed Sep 24, 2024
1 parent 357a906 commit 5b4f00a
Showing 1 changed file with 47 additions and 37 deletions.
84 changes: 47 additions & 37 deletions dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
from typing import Callable, Optional, Tuple, Union, NamedTuple
from jaxtyping import Int, Float, Array

from dynamax.types import Scalar
from dynamax.types import IntScalar, Scalar

_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

def get_trans_mat(transition_matrix, transition_fn, t):
def get_trans_mat(
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]],
t: IntScalar
) -> Float[Array, "num_states num_states"]:
if transition_fn is not None:
return transition_fn(t)
else:
if transition_matrix.ndim == 3: # (T,K,K)
elif transition_matrix is not None:
if transition_matrix.ndim == 3: # (T-1,K,K)
return transition_matrix[t]
else:
return transition_matrix
else:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

class HMMPosteriorFiltered(NamedTuple):
r"""Simple wrapper for properties of an HMM filtering posterior.
Expand Down Expand Up @@ -49,8 +56,8 @@ class HMMPosterior(NamedTuple):
predicted_probs: Float[Array, "num_timesteps num_states"]
smoothed_probs: Float[Array, "num_timesteps num_states"]
initial_probs: Float[Array, " num_states"]
trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]] = None
trans_probs: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None


def _normalize(u: Array, axis=0, eps=1e-15):
Expand Down Expand Up @@ -96,10 +103,10 @@ def _predict(probs, A):
@partial(jit, static_argnames=["transition_fn"])
def hmm_filter(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> HMMPosteriorFiltered:
r"""Forwards filtering
Expand Down Expand Up @@ -143,8 +150,8 @@ def _step(carry, t):

@partial(jit, static_argnames=["transition_fn"])
def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[int], Float[Array, "num_states num_states"]]]= None
) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]:
Expand Down Expand Up @@ -190,10 +197,10 @@ def _step(carry, t):
@partial(jit, static_argnames=["transition_fn"])
def hmm_two_filter_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using the two-filter
Expand Down Expand Up @@ -244,10 +251,10 @@ def hmm_two_filter_smoother(
@partial(jit, static_argnames=["transition_fn"])
def hmm_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using a general
Expand Down Expand Up @@ -324,11 +331,11 @@ def _step(carry, args):
@partial(jit, static_argnames=["transition_fn", "window_size"])
def hmm_fixed_lag_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
window_size: Int,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
window_size: int,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> HMMPosterior:
r"""Compute the smoothed state probabilities using the fixed-lag smoother.
Expand Down Expand Up @@ -438,10 +445,10 @@ def compute_posterior(filtered_probs, beta):
@partial(jit, static_argnames=["transition_fn"])
def hmm_posterior_mode(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> Int[Array, " num_timesteps"]:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
Expand Down Expand Up @@ -486,10 +493,10 @@ def _forward_pass(state, best_next_state):
def hmm_posterior_sample(
key: Array,
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Tuple[Scalar, Int[Array, " num_timesteps"]]:
r"""Sample a latent sequence from the posterior.
Expand Down Expand Up @@ -542,6 +549,7 @@ def _compute_sum_transition_probs(
transition_matrix: Float[Array, "num_states num_states"],
hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.
Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
Expand Down Expand Up @@ -578,11 +586,13 @@ def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]):


def _compute_all_transition_probs(
transition_matrix: Float[Array, "num_timesteps num_states num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Float[Array, "num_timesteps num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.
Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
Expand All @@ -600,14 +610,12 @@ def _compute_probs(t):
return transition_probs


# TODO: Consider alternative annotation for return type:
# Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims.
# Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd).
# TODO: This is a candidate for @overloading.
def compute_transition_probs(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]:
r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$.
Expand All @@ -620,8 +628,10 @@ def compute_transition_probs(
Returns:
array of smoothed transition probabilities.
"""
reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2
if reduce_sum:
if transition_matrix is None and transition_fn is None:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

if transition_matrix is not None and transition_matrix.ndim == 2:
return _compute_sum_transition_probs(transition_matrix, hmm_posterior)
else:
return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn)

0 comments on commit 5b4f00a

Please sign in to comment.