Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to get all sources names from all likelihoods #111

Merged
merged 7 commits into from
Nov 14, 2023
19 changes: 17 additions & 2 deletions alea/models/blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Dict, Callable, Optional, Union, cast
from copy import deepcopy
from pydoc import locate
import itertools

import numpy as np
import scipy.stats as stats
Expand Down Expand Up @@ -138,6 +139,21 @@ def get_source_name_list(self, likelihood_name: str) -> list:
ll_index = self.likelihood_names.index(likelihood_name)
return self.likelihood_list[ll_index].source_name_list

@property
def all_source_names(self) -> set:
"""Return a set of possible source names from all likelihood terms.

Args:
likelihood_name (str): Name of the likelihood.
Returns:
set: set of source names.

"""
source_names = set(
itertools.chain.from_iterable([ll.source_name_list for ll in self.likelihood_list[:-1]])
)
return source_names

@property
def likelihood_list(self) -> List:
"""Return a list of likelihood terms."""
Expand Down Expand Up @@ -200,10 +216,9 @@ def get_expectation_values(self, per_likelihood_term=False, **kwargs) -> dict:
ret[ll_name][n] = mu
if not per_likelihood_term:
# sum over sources with same names of all likelihood terms
all_source_names = {name for sublist in ret.values() for name in sublist}
ret = {
n: sum([ret[ll_name].get(n, 0.0) for ll_name in ret.keys()]) # type: ignore
for n in all_source_names
for n in self.all_source_names
}

return ret
Expand Down
9 changes: 9 additions & 0 deletions tests/test_blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def get_expectation_values(self):
expectation_values.append(this_expectation_dict)
return expectation_values

def test_all_source_names(self):
"""Test of the all_source_names method."""
for config, model in zip(self.configs, self.models):
_source_names = set()
for ll_t in config["likelihood_config"]["likelihood_terms"]:
_source_names.update([s["name"] for s in ll_t["sources"]])
source_names = model.all_source_names
self.assertEqual(source_names, _source_names)

def test_expectation_values(self):
"""Test of the expectation_values method."""

Expand Down
Loading