From e5f8f9e36dc57dcc8d666faccf276040175049b4 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Tue, 17 Dec 2024 11:36:51 -0500 Subject: [PATCH] Further adapt tests --- tests/test_gmm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 82bf0ad..ac02e14 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -99,11 +99,11 @@ def test_fit(gmm_jax, gmm_jax_init): random_state = np.random.RandomState(827392) x, _ = gmm_jax.to_sklearn(random_state=random_state).sample(16_000) - fitter = EMFitter(tol=1e-6) + fitter = EMFitter(tol=1e-4) result = fitter.fit(x=x, gmm=gmm_jax_init) - assert int(result.n_iter) == 13 - assert_allclose(result.log_likelihood, -4.368584, rtol=1e-6) + assert int(result.n_iter) == 6 + assert_allclose(result.log_likelihood, -4.3686, rtol=1e-4) assert_allclose(result.log_likelihood_diff, 9.536743e-07, atol=fitter.tol) assert_allclose(result.gmm.weights_numpy, [0.2, 0.8], rtol=0.03)