Add further type annotations to gmhmms
- Add annotations to the *GaussianMixtureHMMEmissions classes
gileshd committed Sep 20, 2024
1 parent da492c3 commit 38e7912
Showing 1 changed file with 81 additions and 41 deletions.
122 changes: 81 additions & 41 deletions dynamax/hidden_markov_model/models/gmm_hmm.py
@@ -1,3 +1,4 @@
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
@@ -10,16 +11,15 @@
from dynamax.utils.distributions import NormalInverseWishart
from dynamax.utils.distributions import nig_posterior_update
from dynamax.utils.distributions import niw_posterior_update
from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar
from typing import NamedTuple, Optional, Tuple, Union


# Types
class ParamsGaussianMixtureHMMEmissions(NamedTuple):
weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
@@ -48,14 +48,14 @@ class ParamsDiagonalGaussianMixtureHMM(NamedTuple):
class GaussianMixtureHMMEmissions(HMMEmissions):

def __init__(self,
num_states,
num_components,
emission_dim,
emission_weights_concentration=1.1,
emission_prior_mean=0.,
emission_prior_mean_concentration=1e-4,
emission_prior_extra_df=1e-4,
emission_prior_scale=0.1):
num_states: int,
num_components: int,
emission_dim: int,
emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.,
emission_prior_mean_concentration: Scalar=1e-4,
emission_prior_extra_df: Scalar=1e-4,
emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=0.1):
self.num_states = num_states
self.num_components = num_components
self.emission_dim = emission_dim
@@ -69,12 +69,14 @@ def __init__(self,
def emission_shape(self):
return (self.emission_dim,)

def initialize(self, key=jr.PRNGKey(0),
method="prior",
emission_weights=None,
emission_means=None,
emission_covariances=None,
emissions=None):
def initialize(self,
key: Array=jr.PRNGKey(0),
method: str="prior",
emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
emission_covariances: Optional[Float[Array, "num_states num_components emission_dim emission_dim"]]=None,
emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
) -> Tuple[ParamsGaussianMixtureHMMEmissions, ParamsGaussianMixtureHMMEmissions]:
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
@@ -111,21 +113,30 @@ def initialize(self, key=jr.PRNGKey(0),
covs=ParameterProperties(constrainer=RealToPSDBijector()))
return params, props

def distribution(self, params, state, inputs=None):
def distribution(self,
params: ParamsGaussianMixtureHMMEmissions,
state: int,
inputs: Optional[Array] = None
) -> tfd.Distribution:
return tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=params.weights[state]),
components_distribution=tfd.MultivariateNormalFullCovariance(
loc=params.means[state], covariance_matrix=params.covs[state]))

def log_prior(self, params):
def log_prior(self, params: ParamsGaussianMixtureHMMEmissions) -> Float[Array, ""]:
lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
params.weights).sum()
lp += NormalInverseWishart(self.emission_prior_mean, self.emission_prior_mean_concentration,
self.emission_prior_df, self.emission_prior_scale).log_prob(
(params.covs, params.means)).sum()
return lp

def collect_suff_stats(self, params, posterior, emissions, inputs=None):
def collect_suff_stats(self,
params: ParamsGaussianMixtureHMMEmissions,
posterior: HMMPosterior,
emissions: Float[Array, "num_timesteps emission_dim"],
inputs: Optional[Array] = None
) -> Dict[str, Float[Array, "..."]]:
def prob_fn(x):
logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalFullCovariance(
loc=mus, covariance_matrix=sigmas).log_prob(x) + jnp.log(weights))(
@@ -141,10 +152,20 @@ def prob_fn(x):
N = weights.sum(axis=0)
return dict(N=N, Sx=Sx, SxxT=SxxT)

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(
self,
params: ParamsGaussianMixtureHMMEmissions,
props: ParamsGaussianMixtureHMMEmissions
) -> None:
return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: ParamsGaussianMixtureHMMEmissions,
props: ParamsGaussianMixtureHMMEmissions,
batch_stats: Dict[str, Float[Array, "..."]],
m_step_state: Any
) -> Tuple[ParamsGaussianMixtureHMMEmissions, Any]:
assert props.weights.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
assert props.means.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
assert props.covs.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
@@ -268,14 +289,14 @@ def initialize(self,

class DiagonalGaussianMixtureHMMEmissions(HMMEmissions):
def __init__(self,
num_states,
num_components,
emission_dim,
emission_weights_concentration=1.1,
emission_prior_mean=0.,
emission_prior_mean_concentration=1e-4,
emission_prior_shape=1.,
emission_prior_scale=1.):
num_states: int,
num_components: int,
emission_dim: int,
emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.,
emission_prior_mean_concentration: Scalar=1e-4,
emission_prior_shape: Scalar=1.,
emission_prior_scale: Union[Scalar, Float[Array, " emission_dim"]]=1.):
self.num_states = num_states
self.num_components = num_components
self.emission_dim = emission_dim
@@ -288,15 +309,17 @@ def __init__(self,
self.emission_prior_scale = emission_prior_scale

@property
def emission_shape(self):
def emission_shape(self) -> Tuple[int]:
return (self.emission_dim,)

def initialize(self, key=jr.PRNGKey(0),
method="prior",
emission_weights=None,
emission_means=None,
emission_scale_diags=None,
emissions=None):
def initialize(self,
key: Array=jr.PRNGKey(0),
method: str="prior",
emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
emission_scale_diags: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
) -> Tuple[ParamsDiagonalGaussianMixtureHMMEmissions, ParamsDiagonalGaussianMixtureHMMEmissions]:
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
@@ -333,14 +356,18 @@
scale_diags=ParameterProperties(constrainer=tfb.Softplus()))
return params, props

def distribution(self, params, state, inputs=None):
def distribution(self,
params: ParamsDiagonalGaussianMixtureHMMEmissions,
state: int,
inputs: Optional[Array] = None
) -> tfd.Distribution:
return tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=params.weights[state]),
components_distribution=tfd.MultivariateNormalDiag(
loc=params.means[state],
scale_diag=params.scale_diags[state]))

def log_prior(self, params):
def log_prior(self, params: ParamsDiagonalGaussianMixtureHMMEmissions) -> Float[Array, ""]:
lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
params.weights).sum()
lp += NormalInverseGamma(self.emission_prior_mean, self.emission_prior_mean_concentration,
@@ -349,7 +376,12 @@ def log_prior(self, params):
return lp

# Expectation-maximization (EM) code
def collect_suff_stats(self, params, posterior, emissions, inputs=None):
def collect_suff_stats(self,
params: ParamsDiagonalGaussianMixtureHMMEmissions,
posterior: HMMPosterior,
emissions: Float[Array, "num_timesteps emission_dim"],
inputs: Optional[Array] = None
) -> Dict[str, Float[Array, "..."]]:
# Evaluate the posterior probability of each discrete class
def prob_fn(x):
logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalDiag(
Expand All @@ -367,10 +399,18 @@ def prob_fn(x):
N = weights.sum(axis=0)
return dict(N=N, Sx=Sx, Sxsq=Sxsq)

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self,
params: ParamsDiagonalGaussianMixtureHMMEmissions,
props: ParamsDiagonalGaussianMixtureHMMEmissions
) -> None:
return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(self,
params: ParamsDiagonalGaussianMixtureHMMEmissions,
props: ParamsDiagonalGaussianMixtureHMMEmissions,
batch_stats: Dict[str, Float[Array, "..."]],
m_step_state: None
) -> Tuple[ParamsDiagonalGaussianMixtureHMMEmissions, None]:
assert props.weights.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
assert props.means.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
assert props.scale_diags.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
