diff --git a/tests/test_utils.py b/tests/test_utils.py index 47ed1ec..96c16d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,6 +23,18 @@ def setUp(self): self.cluster_idx_list = [0, 1] self.cluster_idx_np = np.array([0, 1]) + def test_valid_input_marginals(self): + result = generate_contin_table_with_clustered_AE( + row_marginal=self.contin_table_np.sum(axis=1), + column_marginal=self.contin_table_np.sum(axis=0), + signal_mat=self.signal_mat, + cluster_idx=self.cluster_idx_list, + n=5, + rho=0.5, + ) + self.assertEqual(len(result), 5) + self.assertTrue(all(isinstance(table, np.ndarray) for table in result)) + def test_valid_input_dataframe(self): result = generate_contin_table_with_clustered_AE( row_marginal=None,