Skip to content

Commit

Permalink
Adapt fit.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adonath committed Dec 18, 2024
1 parent 99578f6 commit 614f510
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gmmx/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def m_step(
Updated Gaussian mixture model instance.
"""
x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar))
return gmm.update_parameters(x, jnp.exp(log_resp), reg_covar=self.reg_covar)
return gmm.from_responsibilities(x, jnp.exp(log_resp), reg_covar=self.reg_covar)

@jax.jit
def fit(self, x: jax.Array, gmm: GaussianMixtureModelJax) -> EMFitterResult:
Expand Down

0 comments on commit 614f510

Please sign in to comment.