Skip to content

Commit

Permalink
Add FullCovariances.type class var
Browse files Browse the repository at this point in the history
  • Loading branch information
adonath committed Dec 18, 2024
1 parent 614f510 commit 5f173a1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
7 changes: 6 additions & 1 deletion gmmx/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def m_step(
Updated Gaussian mixture model instance.
"""
x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar))
return gmm.from_responsibilities(x, jnp.exp(log_resp), reg_covar=self.reg_covar)
return gmm.from_responsibilities(
x,
jnp.exp(log_resp),
reg_covar=self.reg_covar,
covariance_type=gmm.covariances.type,
)

@jax.jit
def fit(self, x: jax.Array, gmm: GaussianMixtureModelJax) -> EMFitterResult:
Expand Down
5 changes: 3 additions & 2 deletions gmmx/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Union
from typing import Any, ClassVar, Union

import jax
import numpy as np
Expand Down Expand Up @@ -109,6 +109,7 @@ class FullCovariances:
"""

values: jax.Array
type: ClassVar[CovarianceType] = CovarianceType.full

def __post_init__(self) -> None:
check_shape(self.values, (1, None, None, None))
Expand Down Expand Up @@ -273,7 +274,7 @@ def precisions_cholesky(self) -> jax.Array:


COVARIANCE: dict[CovarianceType, Any] = {
CovarianceType.full: FullCovariances,
FullCovariances.type: FullCovariances,
}

# keep this mapping separate, as names in sklearn might change
Expand Down

0 comments on commit 5f173a1

Please sign in to comment.