From a6da2f583da8d8b77d8afb189c263375a21e4e24 Mon Sep 17 00:00:00 2001 From: Robert Hammann <53221264+hammannr@users.noreply.github.com> Date: Wed, 7 Aug 2024 08:12:27 +0200 Subject: [PATCH] Conditional Parameters (#186) * add ConditionalParameter class * properly initialize and link the conditioning param * make Parameters prepoerties work with conditional params * add example config * add all those properties to the ConditionalParam * revert separation in Parameters properties * fix parameter evaluation * add some missing methods & run pre-commit hook * add unittest * some suggestions from pylint * oops * make Parameters __str__ more readable for ConditionalParameters --- ...tical_model_mass_dependent_efficiency.yaml | 144 +++++++++++++++ alea/parameters.py | 167 +++++++++++++++++- tests/test_parameters.py | 66 +++++-- 3 files changed, 356 insertions(+), 21 deletions(-) create mode 100644 alea/examples/configs/unbinned_wimp_statistical_model_mass_dependent_efficiency.yaml diff --git a/alea/examples/configs/unbinned_wimp_statistical_model_mass_dependent_efficiency.yaml b/alea/examples/configs/unbinned_wimp_statistical_model_mass_dependent_efficiency.yaml new file mode 100644 index 00000000..2c72a7f0 --- /dev/null +++ b/alea/examples/configs/unbinned_wimp_statistical_model_mass_dependent_efficiency.yaml @@ -0,0 +1,144 @@ +parameter_definition: + wimp_mass: + nominal_value: 50 + fittable: false + description: WIMP mass in GeV/c^2 + + livetime_sr0: + nominal_value: 0.2 + ptype: livetime + fittable: false + description: Livetime of SR0 in years + + livetime_sr1: + nominal_value: 1.0 + ptype: livetime + fittable: false + description: Livetime of SR1 in years + + wimp_rate_multiplier: + nominal_value: 1.0 + ptype: rate + fittable: true + fit_limits: + - 0 + - null + parameter_interval_bounds: + - 0 + - null + + er_rate_multiplier: + nominal_value: 1.0 + ptype: rate + uncertainty: 0.2 + relative_uncertainty: true + fittable: true + fit_limits: + - 0 + - null + fit_guess: 
1.0 + + signal_efficiency: + conditioning_parameter_name: wimp_mass + nominal_value: 1.0 + ptype: efficiency + uncertainty: + 5: 0.3 + 10: 0.2 + 50: 0.15 + 100: 0.1 + 500: 0.1 + relative_uncertainty: true + fittable: true + fit_limits: + - 0 + - 10. + fit_guess: 1.0 + description: Parameter to account for the uncertain signal expectation given a certain cross-section + + er_band_shift: + nominal_value: 0 + ptype: shape + uncertainty: 'stats.uniform(loc=-2, scale=4)' + # relative_uncertainty: false + fittable: true + blueice_anchors: + - -2 + - -1 + - 0 + - 1 + - 2 + fit_limits: + - -2 + - 2 + description: ER band shape parameter (shifts the ER band up and down) + +likelihood_config: + likelihood_weights: [1, 1, 1] + template_folder: null # will try to find the templates in alea + likelihood_terms: + # SR0 + - name: sr0 + default_source_class: alea.template_source.TemplateSource + likelihood_type: blueice.likelihood.UnbinnedLogLikelihood + analysis_space: + - cs1: np.linspace(0, 100, 51) + - cs2: np.geomspace(100, 100000, 51) + in_events_per_bin: true + livetime_parameter: livetime_sr0 + slice_args: {} + sources: + - name: er + histname: er_template # TODO: implement a default histname based on the source name + parameters: + - er_rate_multiplier + - er_band_shift + named_parameters: + - er_band_shift + template_filename: er_template_{er_band_shift}.ii.h5 + histogram_scale_factor: 1 + + - name: wimp + histname: wimp_template + parameters: + - wimp_mass + - wimp_rate_multiplier + - signal_efficiency + named_parameters: + - wimp_mass + template_filename: wimp{wimp_mass:d}gev_template.ii.h5 + apply_efficiency: True + efficiency_name: signal_efficiency + + # SR1 + - name: sr1 + default_source_class: alea.template_source.TemplateSource + likelihood_type: blueice.likelihood.UnbinnedLogLikelihood + analysis_space: + - cs1: np.linspace(0, 100, 51) + - cs2: np.geomspace(100, 100000, 51) + in_events_per_bin: true + livetime_parameter: livetime_sr1 + slice_args: {} + sources: 
+ - name: er + histname: er_template + parameters: + - er_rate_multiplier + - er_band_shift + named_parameters: + - er_band_shift + template_filename: er_template_{er_band_shift}.ii.h5 + histogram_scale_factor: 2 + + - name: wimp + histname: wimp_template + parameters: + - wimp_mass + - wimp_rate_multiplier + - signal_efficiency + named_parameters: + - wimp_mass + template_filename: wimp{wimp_mass:d}gev_template.ii.h5 + apply_efficiency: True + efficiency_name: signal_efficiency diff --git a/alea/parameters.py b/alea/parameters.py index ac8b7dc3..ea18bc41 100644 --- a/alea/parameters.py +++ b/alea/parameters.py @@ -204,6 +204,134 @@ def _check_parameter_consistency(self): ) +class ConditionalParameter: + """This class is used to define a parameter that depends on another parameter. It has the same + attributes as the Parameter class but each of them can be a dictionary with keys being the + values of the conditioning parameter and values being the corresponding values of the + conditional parameter. Calling the object with the conditioning parameter value as an argument + will return a corresponding Parameter object with the correct values. + + Attributes: + name (str): The name of the parameter. + conditioning_parameter_name (str): The name of the conditioning parameter. 
+ + """ + + def __init__(self, name: str, conditioning_parameter_name: str, **kwargs): + self.name = name + self.conditioning_name = conditioning_parameter_name + self.conditions_dict = self._unpack_conditions(kwargs) + self.conditioning_param = None + + def __repr__(self) -> str: + parameter_str = ", ".join([f"{k}={v}" for k, v in self.__dict__.items() if v is not None]) + _repr = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + _repr += f"({parameter_str})" + return _repr + + @staticmethod + def _unpack_conditions(kwargs): + # 1) collect all condition keys and check for consistency + all_keys = set() + for value in kwargs.values(): + if isinstance(value, dict): + if not all_keys: + all_keys = set(value.keys()) + elif all_keys != set(value.keys()): + raise ValueError("Inconsistent condition keys across dictionaries.") + + # 2) create the conditions dictionary + conditions_dict = {key: {} for key in all_keys} + for key, value in kwargs.items(): + if isinstance(value, dict): + for condition_key, condition_value in value.items(): + conditions_dict[condition_key][key] = condition_value + else: + for condition_key in all_keys: + conditions_dict[condition_key][key] = value + + return conditions_dict + + @property + def uncertainty(self) -> Any: + """Return the uncertainty of the parameter (nominal condition)""" + return self().uncertainty + + @property + def blueice_anchors(self) -> Any: + """Return the blueice_anchors of the parameter (nominal condition)""" + return self().blueice_anchors + + @property + def fit_guess(self) -> Optional[float]: + """Return the initial guess for fitting the parameter (nominal condition)""" + return self().fit_guess + + @property + def parameter_interval_bounds(self) -> Optional[Tuple[float, float]]: + """Return the parameter_interval_bounds of the parameter (nominal condition)""" + return self().parameter_interval_bounds + + @property + def nominal_value(self) -> Optional[float]: + """Return the nominal value of the 
parameter (nominal condition)""" + return self().nominal_value + + @property + def needs_reinit(self) -> bool: + """Return True if the parameter needs re-initialization (for ptype ``needs_reinit``).""" + return self().needs_reinit + + @property + def fittable(self) -> bool: + """Return the fittable attribute of the parameter (nominal condition)""" + return self().fittable + + @property + def ptype(self) -> Optional[str]: + """Return the ptype of the parameter (nominal condition)""" + return self().ptype + + @property + def relative_uncertainty(self) -> Optional[bool]: + """Return the relative_uncertainty of the parameter (nominal condition)""" + return self().relative_uncertainty + + @property + def fit_limits(self) -> Optional[Tuple[float, float]]: + """Return the fit_limits of the parameter (nominal condition)""" + return self().fit_limits + + def __eq__(self, other: object) -> bool: + """Return True if all attributes are equal.""" + if isinstance(other, ConditionalParameter): + return all(getattr(self, k) == getattr(other, k) for k in self.__dict__) + return False + + def value_in_fit_limits(self, value: float) -> bool: + """Returns True if value under nominal condition is within fit_limits.""" + return self().value_in_fit_limits(value) + + def __call__(self, **kwargs) -> Parameter: + if self.conditioning_name in kwargs: + cond_val = kwargs[self.conditioning_name] + elif self.conditioning_param is not None: + cond_val = self.conditioning_param.nominal_value + else: + err_msg = ( + f"Conditioning parameter '{self.conditioning_name}' is missing. Can't fall back to " + "nominal value because conditioning parameter it is not set. " + ) + raise ValueError(err_msg) + # check if the conditioning value is in the conditions dictionary + if cond_val not in self.conditions_dict: + raise ValueError( + f"Conditioning value '{cond_val}' not found in the conditions dictionary." 
+ + f"Available values are: {sorted(list(self.conditions_dict.keys()))}" + ) + return Parameter(name=self.name, **self.conditions_dict[cond_val]) + + class Parameters: """Represents a collection of parameters. @@ -246,9 +374,17 @@ def from_config(cls, config: Dict[str, dict]): """ parameters = cls() + parameter: Union[Parameter, ConditionalParameter] for name, param_config in config.items(): - parameter = Parameter(name=name, **param_config) + if "conditioning_parameter_name" in param_config: + parameter = ConditionalParameter(name, **param_config) + else: + parameter = Parameter(name=name, **param_config) parameters.add_parameter(parameter) + # set conditioning parameters + for param in parameters: + if isinstance(param, ConditionalParameter): + param.conditioning_param = parameters[param.conditioning_name] return parameters @classmethod @@ -279,7 +415,15 @@ def __str__(self) -> str: """Return an overview table of all parameters.""" par_list = [] for p in self: - par_dict = {} + if isinstance(p, ConditionalParameter): + par_dict = { + "conditioning_name": p.conditioning_name, + "conditions": sorted(p.conditions_dict.keys()), + } + # get nominal-condition parameter + p = p() + else: + par_dict = {} for k, v in p.__dict__.items(): # replace hidden attributes with non-hidden properties if k.startswith("_"): @@ -295,7 +439,7 @@ def __str__(self) -> str: return df.to_string() - def add_parameter(self, parameter: Parameter) -> None: + def add_parameter(self, parameter: Union[Parameter, ConditionalParameter]) -> None: """Adds a Parameter object to the Parameters collection. Args: @@ -355,7 +499,8 @@ def uncertainties(self) -> dict: def with_uncertainty(self) -> "Parameters": """Return parameters with a not-NaN uncertainty. - The parameters are the same objects as in the original Parameters object, not a copy. + The parameters are the same objects as in the original Parameters object, not a copy. 
For + conditional parameters, the parameters under the nominal condition are returned. """ param_dict = {k: i for k, i in self.parameters.items() if i.uncertainty is not None} @@ -391,6 +536,11 @@ def set_fit_guesses(self, **fit_guesses): for name, value in fit_guesses.items(): self.parameters[name].fit_guess = value + def _evaluate_parameter(self, parameter: Parameter, **kwargs): + if isinstance(parameter, ConditionalParameter): + return parameter(**kwargs) + return parameter + def __call__( self, return_fittable: Optional[bool] = False, **kwargs: Optional[Dict] ) -> Dict[str, float]: @@ -419,6 +569,7 @@ def __call__( raise ValueError(f"Parameter '{name}' not found.") for name, param in self.parameters.items(): + param = self._evaluate_parameter(param, **kwargs) new_val = kwargs.get(name, None) if param.needs_reinit and new_val != param.nominal_value and new_val is not None: raise ValueError( @@ -491,6 +642,8 @@ def values_in_fit_limits(self, **kwargs: Dict) -> bool: bool: True if all values are within the fit limits. 
""" - return all( - self.parameters[name].value_in_fit_limits(value) for name, value in kwargs.items() - ) + for name, value in kwargs.items(): + param = self._evaluate_parameter(self.parameters[name], **kwargs) + if not param.value_in_fit_limits(value): + return False + return True diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 77b0b775..53ee205c 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -8,27 +8,65 @@ class TestParameters(TestCase): """Test of the Parameters class.""" - @classmethod - def setUp(cls): + def setUp(self): """Initialise the Parameters instance.""" - cls.config = load_yaml("unbinned_wimp_statistical_model.yaml") - cls.parameters = Parameters.from_config(cls.config["parameter_definition"]) + filenames = [ + "unbinned_wimp_statistical_model.yaml", + "unbinned_wimp_statistical_model_mass_dependent_efficiency.yaml", + ] + configs = [] + for fn in filenames: + configs.append(load_yaml(fn)["parameter_definition"]) + self.configs = configs + + parameters_list = [] + for config in self.configs: + parameters_list.append(Parameters.from_config(config)) + self.parameters_list = parameters_list def test_from_list(self): """Test of the from_list method.""" - only_name_parameters = Parameters.from_list(self.config["parameter_definition"].keys()) - # it is false because only names are assigned - self.assertFalse(only_name_parameters == self.parameters) + for config, parameters in zip(self.configs, self.parameters_list): + only_name_parameters = Parameters.from_list(config.keys()) + # it is false because only names are assigned + self.assertFalse(only_name_parameters == parameters) def test___repr__(self): """Test of the __repr__ method.""" - for p in self.parameters: - if not isinstance(repr(p), str): - raise ValueError("The __repr__ method does not return the correct string.") - if not isinstance(repr(self.parameters), str): - raise TypeError("The __repr__ method does not return a string.") + for parameters in 
self.parameters_list: + for p in parameters: + if not isinstance(repr(p), str): + raise ValueError("The __repr__ method does not return the correct string.") + if not isinstance(repr(parameters), str): + raise TypeError("The __repr__ method does not return a string.") def test_deep_copyable(self): """Test of whether Parameters instance can be deepcopied.""" - if deepcopy(self.parameters) != self.parameters: - raise ValueError("Parameters instance cannot be correctly deepcopied.") + for parameters in self.parameters_list: + if deepcopy(parameters) != parameters: + raise ValueError("Parameters instance cannot be correctly deepcopied.") + + def test_conditional_parameter(self): + """Test of the ConditionalParameter class.""" + config = self.configs[1] + parameters = self.parameters_list[1] + nominal_wimp_mass = config["wimp_mass"]["nominal_value"] + signal_eff_uncert_dict = config["signal_efficiency"]["uncertainty"] + + # Directly accessing the property should return the value + # under nominal conditions + val = parameters.signal_efficiency.uncertainty + expected_val = signal_eff_uncert_dict[nominal_wimp_mass] + self.assertEqual(val, expected_val) + + # Calling without kwargs should return the value + # under nominal conditions + val = parameters.signal_efficiency().uncertainty + expected_val = signal_eff_uncert_dict[nominal_wimp_mass] + self.assertEqual(val, expected_val) + + # Calling with kwargs should return the value under + # the specified conditions + for wimp_mass, expected_val in signal_eff_uncert_dict.items(): + val = parameters.signal_efficiency(wimp_mass=wimp_mass).uncertainty + self.assertEqual(val, expected_val)