Skip to content

Commit

Permalink
Consistent sorting for BlueiceExtendedModel (#149)
Browse files Browse the repository at this point in the history
* sort returns of BEM

* make a test

---------

Co-authored-by: Dacheng Xu <dx2227@columbia.edu>
  • Loading branch information
hammannr and dachengx authored Mar 20, 2024
1 parent 08280b8 commit 34f98f0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions alea/models/blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def get_expectation_values(self, per_likelihood_term=False, **kwargs) -> dict:
mus = self.data_generators[ll_index].mus
for n, mu in zip(self.likelihood_list[ll_index].source_name_list, mus):
ret[ll_name][n] = mu
# sort by source name
ret[ll_name] = dict(sorted(ret[ll_name].items(), key=lambda item: item[0]))
if not per_likelihood_term:
# sum over sources with same names of all likelihood terms
ret = {
Expand Down Expand Up @@ -254,6 +256,9 @@ def get_source_histograms(self, likelihood_name: str, expected_events=False, **k
for hist in source_histograms.values():
hist.histogram /= hist.bin_volumes()

# sort the source_histograms by source name
source_histograms = dict(sorted(source_histograms.items(), key=lambda item: item[0]))

return source_histograms

def _process_blueice_config(self, config, template_folder_list):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,19 @@ def test_get_source_histograms(self):
# check that invalid likelihood names fail
with self.assertRaises(ValueError):
model.get_source_histograms("alea_iacta_est")

def test_sorted_returns(self):
"""Test if sources are sorted in the same way for all return dicts."""
for model in self.models:
mus_per_ll = model.get_expectation_values(per_likelihood_term=True)
mus = model.get_expectation_values()
hist_per_ll = {}
for ll_name in model.likelihood_names[:-1]:
hist_per_ll[ll_name] = model.get_source_histograms(ll_name)
# check that keys are the same for each SR
for ll_name in model.likelihood_names[:-1]:
self.assertEqual(mus_per_ll[ll_name].keys(), hist_per_ll[ll_name].keys())
# check that global keys are the same
all_keys = {v for d in mus_per_ll.values() for v in d.keys()}
all_keys = sorted(all_keys)
self.assertEqual(all_keys, sorted(mus.keys()))

0 comments on commit 34f98f0

Please sign in to comment.