diff --git a/gmmx/fit.py b/gmmx/fit.py index 6349e37..70daea1 100644 --- a/gmmx/fit.py +++ b/gmmx/fit.py @@ -100,7 +100,7 @@ def m_step( Updated Gaussian mixture model instance. """ x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar)) - return gmm.update_parameters(x, jnp.exp(log_resp), reg_covar=self.reg_covar) + return gmm.from_responsibilities(x, jnp.exp(log_resp), reg_covar=self.reg_covar) @jax.jit def fit(self, x: jax.Array, gmm: GaussianMixtureModelJax) -> EMFitterResult: