Skip to content

Commit

Permalink
Rename update parameters to from_responsibilities
Browse files Browse the repository at this point in the history
  • Loading branch information
adonath committed Dec 18, 2024
1 parent e0e4cf8 commit 99578f6
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions gmmx/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def log_prob(self, x: jax.Array, means: jax.Array) -> jax.Array:
)

@classmethod
def update_parameters(
def from_responsibilities(
cls,
x: jax.Array,
means: jax.Array,
Expand Down Expand Up @@ -393,7 +393,7 @@ def from_squeezed(
return cls(weights=weights, means=means, covariances=covariances) # type: ignore [arg-type]

@classmethod
def update_parameters(
def from_responsibilities(
cls,
x: jax.Array,
resp: jax.Array,
Expand All @@ -420,14 +420,19 @@ def update_parameters(
"""
nk = jnp.sum(resp, axis=Axis.batch, keepdims=True)
means = jnp.matmul(resp.T, x.T.mT).T / nk
covariances = COVARIANCE[covariance_type].update_parameters(
covariances = COVARIANCE[covariance_type].from_responsibilities(
x=x, means=means, resp=resp, nk=nk, reg_covar=reg_covar
)
return cls(weights=nk / nk.sum(), means=means, covariances=covariances)

@classmethod
def from_k_means(
cls, x: jax.Array, n_components: int, reg_covar: float = 1e-6, **kwargs
cls,
x: jax.Array,
n_components: int,
reg_covar: float = 1e-6,
covariance_type: CovarianceType = CovarianceType.full,
**kwargs,
) -> None:
"""Init from k-means clustering
Expand All @@ -439,6 +444,8 @@ def from_k_means(
Number of components
reg_covar : float, optional
Regularization for the covariance matrix, by default 1e6
covariance_type : str, optional
Covariance type, by default "full"
**kwargs : dict
Additional arguments passed to `~sklearn.cluster.KMeans`
Expand All @@ -460,7 +467,9 @@ def from_k_means(

xp = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar))
resp = jnp.expand_dims(resp, axis=(Axis.features, Axis.features_covar))
return cls.update_parameters(xp, resp, reg_covar=reg_covar)
return cls.from_responsibilities(
xp, resp, reg_covar=reg_covar, covariance_type=covariance_type
)

@property
def n_features(self) -> int:
Expand Down

0 comments on commit 99578f6

Please sign in to comment.