From 31a2291f46508b5f49adba874d2e7e76d463c887 Mon Sep 17 00:00:00 2001 From: dachengx Date: Wed, 18 Sep 2024 09:08:21 -0500 Subject: [PATCH] Add `spectrum_axis` configuration for `SpectrumTemplateSource` --- ...nbinned_wimp_statistical_model_template_source_test.yaml | 1 + alea/template_source.py | 6 +++--- alea/utils.py | 6 ++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/alea/examples/configs/unbinned_wimp_statistical_model_template_source_test.yaml b/alea/examples/configs/unbinned_wimp_statistical_model_template_source_test.yaml index 5c3dce8..b7862b0 100644 --- a/alea/examples/configs/unbinned_wimp_statistical_model_template_source_test.yaml +++ b/alea/examples/configs/unbinned_wimp_statistical_model_template_source_test.yaml @@ -126,4 +126,5 @@ likelihood_config: - signal_efficiency template_filename: wimp50gev_template.ii.h5 spectrum_name: test_cs1_spectrum.json + spectrum_axis: 1 efficiency_name: signal_efficiency diff --git a/alea/template_source.py b/alea/template_source.py index e8eef18..b502358 100644 --- a/alea/template_source.py +++ b/alea/template_source.py @@ -373,14 +373,14 @@ def build_histogram(self): if "spectrum_name" not in self.config: raise ValueError("spectrum_name not in config") + if "spectrum_axis" not in self.config: + raise ValueError("spectrum_axis not in config") spectrum = self._get_json_spectrum( self.config["spectrum_name"].format(**self.format_named_parameters) ) - # Perform scaling, the first axis is assumed to be reweighted # The spectrum is assumed to be probability density (in per the unit of first axis). - axis = 0 - # h = h.normalize(axis=axis) + axis = self.config["spectrum_axis"] bin_edges = h.bin_edges[axis] bin_centers = h.bin_centers(axis=axis) slices = [None] * h.histogram.ndim diff --git a/alea/utils.py b/alea/utils.py index efdcb65..ed076a3 100644 --- a/alea/utils.py +++ b/alea/utils.py @@ -206,6 +206,12 @@ def dump_yaml(file_name: str, data: dict): yaml.safe_dump(data, file) +def dump_json(file_name: str, data: dict): + """Dump data to a json file.""" + with open(file_name, "w") as file: + json.dump(data, file, indent=4) + + def _get_abspath(file_name): """Get the abspath of the file.