diff --git a/tests/test_inference/test_initialization.py b/tests/test_inference/test_initialization.py index e8d58f5..d68ff8f 100644 --- a/tests/test_inference/test_initialization.py +++ b/tests/test_inference/test_initialization.py @@ -12,35 +12,6 @@ class InitTestCase(unittest.TestCase): def setUp(self) -> None: set_seed(42) - def test_baum_welch_cluster_init(self): - config = Config(n_nodes=4, n_states=5, n_cells=100, chain_length=200, wis_sample_size=100, debug=True) - data = simulate_full_dataset(config) - # get trees - fix_qt = qT(config, true_params={ - "tree": data['tree'] - }) - trees_sample, trees_weights = fix_qt.get_trees_sample() - - # get eps - fix_qeps = qEpsilonMulti(config, true_params={ - "eps": data['eps'] - }) - - # random init - qc_rand = qC(config).initialize(method='random') - rand_elbo = qc_rand.compute_elbo(trees_sample, trees_weights, fix_qeps) - - # Baum-Welch init - qc_bw = qC(config).initialize(method='bw-cluster', obs=data['obs'], clusters=data['z']) - bw_elbo = qc_bw.compute_elbo(trees_sample, trees_weights, fix_qeps) - - # print(bw_elbo, rand_elbo) - self.assertGreater(bw_elbo, rand_elbo) - - fix_qc = qC(config, true_params={ - "c": data['c'] - }) - # @unittest.skip("clonal init not working") def test_eps_init_from_data(self): config = Config(n_nodes=5, n_cells=300, chain_length=1000, step_size=.2, debug=True)