Skip to content

Commit

Permalink
Conditional Parameters (#186)
Browse files Browse the repository at this point in the history
* add ConditionalParameter class

* properly initialize and link the conditioning param

* make Parameters prepoerties work with conditional params

* add example config

* add all those properties to the ConditionalParam

* revert separation in Parameters properties

* fix parameter evaluation

* add some missing methods & run pre-commit hook

* add unittest

* some suggestions from pylint

* oops

* make Parameters __str__ more readable for ConditionalParameters
  • Loading branch information
hammannr authored Aug 7, 2024
1 parent 5c1ee1e commit a6da2f5
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
parameter_definition:
wimp_mass:
nominal_value: 50
fittable: false
description: WIMP mass in GeV/c^2

livetime_sr0:
nominal_value: 0.2
ptype: livetime
fittable: false
description: Livetime of SR0 in years

livetime_sr1:
nominal_value: 1.0
ptype: livetime
fittable: false
description: Livetime of SR1 in years

wimp_rate_multiplier:
nominal_value: 1.0
ptype: rate
fittable: true
fit_limits:
- 0
- null
parameter_interval_bounds:
- 0
- null

er_rate_multiplier:
nominal_value: 1.0
ptype: rate
uncertainty: 0.2
relative_uncertainty: true
fittable: true
fit_limits:
- 0
- null
fit_guess: 1.0

signal_efficiency:
conditioning_parameter_name: wimp_mass
nominal_value: 1.0
ptype: efficiency
uncertainty:
5: 0.3
10: 0.2
50: 0.15
100: 0.1
500: 0.1
relative_uncertainty: true
fittable: true
fit_limits:
- 0
- 10.
fit_guess: 1.0
description: Parameter to account for the uncertain signal expectation given a certain cross-section

er_band_shift:
nominal_value: 0
ptype: shape
uncertainty: 'stats.uniform(loc=-2, scale=4)'
# relative_uncertainty: false
fittable: true
blueice_anchors:
- -2
- -1
- 0
- 1
- 2
fit_limits:
- -2
- 2
description: ER band shape parameter (shifts the ER band up and down)

likelihood_config:
likelihood_weights: [1, 1, 1]
template_folder: null # will try to find the templates in alea
likelihood_terms:
# SR0
- name: sr0
default_source_class: alea.template_source.TemplateSource
likelihood_type: blueice.likelihood.UnbinnedLogLikelihood
analysis_space:
- cs1: np.linspace(0, 100, 51)
- cs2: np.geomspace(100, 100000, 51)
in_events_per_bin: true
livetime_parameter: livetime_sr0
slice_args: {}
sources:
- name: er
histname: er_template # TODO: implement a default histname based on the source name
parameters:
- er_rate_multiplier
- er_band_shift
named_parameters:
- er_band_shift
template_filename: er_template_{er_band_shift}.ii.h5
histogram_scale_factor: 1

- name: wimp
histname: wimp_template
parameters:
- wimp_mass
- wimp_rate_multiplier
- signal_efficiency
named_parameters:
- wimp_mass
template_filename: wimp{wimp_mass:d}gev_template.ii.h5
apply_efficiency: True
efficiency_name: signal_efficiency

# SR1
- name: sr1
default_source_class: alea.template_source.TemplateSource
likelihood_type: blueice.likelihood.UnbinnedLogLikelihood
analysis_space:
- cs1: np.linspace(0, 100, 51)
- cs2: np.geomspace(100, 100000, 51)
in_events_per_bin: true
livetime_parameter: livetime_sr1
slice_args: {}
sources:
- name: er
histname: er_template
parameters:
- er_rate_multiplier
- er_band_shift
named_parameters:
- er_band_shift
template_filename: er_template_{er_band_shift}.ii.h5
histogram_scale_factor: 2

- name: wimp
histname: wimp_template
parameters:
- wimp_mass
- wimp_rate_multiplier
- signal_efficiency
named_parameters:
- wimp_mass
template_filename: wimp{wimp_mass:d}gev_template.ii.h5
apply_efficiency: True
efficiency_name: signal_efficiency
167 changes: 160 additions & 7 deletions alea/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,134 @@ def _check_parameter_consistency(self):
)


class ConditionalParameter:
"""This class is used to define a parameter that depends on another parameter. It has the same
attributes as the Parameter class but each of them can be a dictionary with keys being the
values of the conditioning parameter and values being the corresponding values of the
conditional parameter. Calling the object with the conditioning parameter value as an argument
will return a corresponding Parameter object with the correct values.
Attributes:
name (str): The name of the parameter.
conditioning_parameter_name (str): The name of the conditioning parameter.
"""

def __init__(self, name: str, conditioning_parameter_name: str, **kwargs):
self.name = name
self.conditioning_name = conditioning_parameter_name
self.conditions_dict = self._unpack_conditions(kwargs)
self.conditioning_param = None

def __repr__(self) -> str:
parameter_str = ", ".join([f"{k}={v}" for k, v in self.__dict__.items() if v is not None])
_repr = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
_repr += f"({parameter_str})"
return _repr

@staticmethod
def _unpack_conditions(kwargs):
# 1) collect all condition keys and check for consistency
all_keys = set()
for value in kwargs.values():
if isinstance(value, dict):
if not all_keys:
all_keys = set(value.keys())
elif all_keys != set(value.keys()):
raise ValueError("Inconsistent condition keys across dictionaries.")

# 2) create the conditions dictionary
conditions_dict = {key: {} for key in all_keys}
for key, value in kwargs.items():
if isinstance(value, dict):
for condition_key, condition_value in value.items():
conditions_dict[condition_key][key] = condition_value
else:
for condition_key in all_keys:
conditions_dict[condition_key][key] = value

return conditions_dict

@property
def uncertainty(self) -> Any:
"""Return the uncertainty of the parameter (cominal condition)"""
return self().uncertainty

@property
def blueice_anchors(self) -> Any:
"""Return the blueice_anchors of the parameter (cominal condition)"""
return self().blueice_anchors

@property
def fit_guess(self) -> Optional[float]:
"""Return the initial guess for fitting the parameter (cominal condition)"""
return self().fit_guess

@property
def parameter_interval_bounds(self) -> Optional[Tuple[float, float]]:
"""Return the parameter_interval_bounds of the parameter (cominal condition)"""
return self().parameter_interval_bounds

@property
def nominal_value(self) -> Optional[float]:
"""Return the nominal value of the parameter (cominal condition)"""
return self().nominal_value

@property
def needs_reinit(self) -> bool:
"""Return True if the parameter needs re-initialization (for ptype ``needs_reinit``)."""
return self().needs_reinit

@property
def fittable(self) -> bool:
"""Return the fittable attribute of the parameter (cominal condition)"""
return self().fittable

@property
def ptype(self) -> Optional[str]:
"""Return the ptype of the parameter (cominal condition)"""
return self().ptype

@property
def relative_uncertainty(self) -> Optional[bool]:
"""Return the relative_uncertainty of the parameter (cominal condition)"""
return self().relative_uncertainty

@property
def fit_limits(self) -> Optional[Tuple[float, float]]:
"""Return the fit_limits of the parameter (cominal condition)"""
return self().fit_limits

def __eq__(self, other: object) -> bool:
"""Return True if all attributes are equal."""
if isinstance(other, ConditionalParameter):
return all(getattr(self, k) == getattr(other, k) for k in self.__dict__)
return False

def value_in_fit_limits(self, value: float) -> bool:
"""Returns True if value under cominal condition is within fit_limits."""
return self().value_in_fit_limits(value)

def __call__(self, **kwargs) -> Parameter:
if self.conditioning_name in kwargs:
cond_val = kwargs[self.conditioning_name]
elif self.conditioning_param is not None:
cond_val = self.conditioning_param.nominal_value
else:
err_msg = (
f"Conditioning parameter '{self.conditioning_name}' is missing. Can't fall back to "
"nominal value because conditioning parameter it is not set. "
)
raise ValueError(err_msg)
# check if the conditioning value is in the conditions dictionary
if cond_val not in self.conditions_dict:
raise ValueError(
f"Conditioning value '{cond_val}' not found in the conditions dictionary."
+ f"Available values are: {sorted(list(self.conditions_dict.keys()))}"
)
return Parameter(name=self.name, **self.conditions_dict[cond_val])


class Parameters:
"""Represents a collection of parameters.
Expand Down Expand Up @@ -246,9 +374,17 @@ def from_config(cls, config: Dict[str, dict]):
"""
parameters = cls()
parameter: Union[Parameter, ConditionalParameter]
for name, param_config in config.items():
parameter = Parameter(name=name, **param_config)
if "conditioning_parameter_name" in param_config:
parameter = ConditionalParameter(name, **param_config)
else:
parameter = Parameter(name=name, **param_config)
parameters.add_parameter(parameter)
# set conditioning parameters
for param in parameters:
if isinstance(param, ConditionalParameter):
param.conditioning_param = parameters[param.conditioning_name]
return parameters

@classmethod
Expand Down Expand Up @@ -279,7 +415,15 @@ def __str__(self) -> str:
"""Return an overview table of all parameters."""
par_list = []
for p in self:
par_dict = {}
if isinstance(p, ConditionalParameter):
par_dict = {
"conditioning_name": p.conditioning_name,
"conditions": sorted(p.conditions_dict.keys()),
}
# get nominal-condition parameter
p = p()
else:
par_dict = {}
for k, v in p.__dict__.items():
# replace hidden attributes with non-hidden properties
if k.startswith("_"):
Expand All @@ -295,7 +439,7 @@ def __str__(self) -> str:

return df.to_string()

def add_parameter(self, parameter: Parameter) -> None:
def add_parameter(self, parameter: Union[Parameter, ConditionalParameter]) -> None:
"""Adds a Parameter object to the Parameters collection.
Args:
Expand Down Expand Up @@ -355,7 +499,8 @@ def uncertainties(self) -> dict:
def with_uncertainty(self) -> "Parameters":
"""Return parameters with a not-NaN uncertainty.
The parameters are the same objects as in the original Parameters object, not a copy.
The parameters are the same objects as in the original Parameters object, not a copy. For
conditional parameters, the parameters under the nominal condition are returned.
"""
param_dict = {k: i for k, i in self.parameters.items() if i.uncertainty is not None}
Expand Down Expand Up @@ -391,6 +536,11 @@ def set_fit_guesses(self, **fit_guesses):
for name, value in fit_guesses.items():
self.parameters[name].fit_guess = value

def _evaluate_parameter(self, parameter: Parameter, **kwargs):
if isinstance(parameter, ConditionalParameter):
return parameter(**kwargs)
return parameter

def __call__(
self, return_fittable: Optional[bool] = False, **kwargs: Optional[Dict]
) -> Dict[str, float]:
Expand Down Expand Up @@ -419,6 +569,7 @@ def __call__(
raise ValueError(f"Parameter '{name}' not found.")

for name, param in self.parameters.items():
param = self._evaluate_parameter(param, **kwargs)
new_val = kwargs.get(name, None)
if param.needs_reinit and new_val != param.nominal_value and new_val is not None:
raise ValueError(
Expand Down Expand Up @@ -491,6 +642,8 @@ def values_in_fit_limits(self, **kwargs: Dict) -> bool:
bool: True if all values are within the fit limits.
"""
return all(
self.parameters[name].value_in_fit_limits(value) for name, value in kwargs.items()
)
for name, value in kwargs.items():
param = self._evaluate_parameter(self.parameters[name], **kwargs)
if not param.value_in_fit_limits(value):
return False
return True
Loading

0 comments on commit a6da2f5

Please sign in to comment.