diff --git a/gmmx/fit.py b/gmmx/fit.py index 70daea1..c17357f 100644 --- a/gmmx/fit.py +++ b/gmmx/fit.py @@ -100,7 +100,12 @@ def m_step( Updated Gaussian mixture model instance. """ x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar)) - return gmm.from_responsibilities(x, jnp.exp(log_resp), reg_covar=self.reg_covar) + return gmm.from_responsibilities( + x, + jnp.exp(log_resp), + reg_covar=self.reg_covar, + covariance_type=gmm.covariances.type, + ) @jax.jit def fit(self, x: jax.Array, gmm: GaussianMixtureModelJax) -> EMFitterResult: diff --git a/gmmx/gmm.py b/gmmx/gmm.py index 23e4497..83a4502 100644 --- a/gmmx/gmm.py +++ b/gmmx/gmm.py @@ -51,7 +51,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Union +from typing import Any, ClassVar, Union import jax import numpy as np @@ -109,6 +109,7 @@ class FullCovariances: """ values: jax.Array + type: ClassVar[CovarianceType] = CovarianceType.full def __post_init__(self) -> None: check_shape(self.values, (1, None, None, None)) @@ -273,7 +274,7 @@ def precisions_cholesky(self) -> jax.Array: COVARIANCE: dict[CovarianceType, Any] = { - CovarianceType.full: FullCovariances, + FullCovariances.type: FullCovariances, } # keep this mapping separate, as names in sklearn might change