Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent sorting for BlueiceExtendedModel #149

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions alea/models/blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def get_expectation_values(self, per_likelihood_term=False, **kwargs) -> dict:
mus = self.data_generators[ll_index].mus
for n, mu in zip(self.likelihood_list[ll_index].source_name_list, mus):
ret[ll_name][n] = mu
# sort by source name
ret[ll_name] = dict(sorted(ret[ll_name].items(), key=lambda item: item[0]))
if not per_likelihood_term:
# sum over sources with same names of all likelihood terms
ret = {
Expand Down Expand Up @@ -254,6 +256,9 @@ def get_source_histograms(self, likelihood_name: str, expected_events=False, **k
for hist in source_histograms.values():
hist.histogram /= hist.bin_volumes()

# sort the source_histograms by source name
source_histograms = dict(sorted(source_histograms.items(), key=lambda item: item[0]))

return source_histograms

def _process_blueice_config(self, config, template_folder_list):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,19 @@ def test_get_source_histograms(self):
# check that invalid likelihood names fail
with self.assertRaises(ValueError):
model.get_source_histograms("alea_iacta_est")

def test_sorted_returns(self):
"""Test if sources are sorted in the same way for all return dicts."""
for model in self.models:
mus_per_ll = model.get_expectation_values(per_likelihood_term=True)
mus = model.get_expectation_values()
hist_per_ll = {}
for ll_name in model.likelihood_names[:-1]:
hist_per_ll[ll_name] = model.get_source_histograms(ll_name)
# check that keys are the same for each SR
for ll_name in model.likelihood_names[:-1]:
self.assertEqual(mus_per_ll[ll_name].keys(), hist_per_ll[ll_name].keys())
# check that global keys are the same
all_keys = {v for d in mus_per_ll.values() for v in d.keys()}
all_keys = sorted(all_keys)
self.assertEqual(all_keys, sorted(mus.keys()))
Loading