From d53c613829816a38cb9d7ec06b5ed75221e911cf Mon Sep 17 00:00:00 2001 From: Toby Dixon Date: Wed, 8 May 2024 15:27:51 +0100 Subject: [PATCH] Cross talk correction code for `build_evt()` (#572) --------- Co-authored-by: ggmarshall Co-authored-by: Luigi Pertoldi --- CITATION.cff | 3 + pyproject.toml | 1 + src/pygama/evt/aggregators.py | 51 +-- src/pygama/evt/build_evt.py | 2 + src/pygama/evt/modules/geds.py | 180 +++++++++- src/pygama/evt/modules/xtalk.py | 362 +++++++++++++++++++++ src/pygama/hit/build_hit.py | 42 +++ src/pygama/math/distributions.py | 4 +- tests/evt/configs/basic-evt-config.yaml | 3 + tests/evt/configs/vov-test-evt-config.json | 3 + tests/evt/test_geds.py | 0 tests/evt/test_xtalk.py | 88 +++++ 12 files changed, 699 insertions(+), 40 deletions(-) create mode 100644 src/pygama/evt/modules/xtalk.py create mode 100644 tests/evt/test_geds.py create mode 100644 tests/evt/test_xtalk.py diff --git a/CITATION.cff b/CITATION.cff index 1a416f427..5f2abe390 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,6 +19,9 @@ authors: - family-names: Marshall given-names: George orcid: https://orcid.org/0000-0002-5470-5132 + - family-names: Dixon + given-names: Toby + orcid: https://orcid.org/0000-0001-8787-6336 - family-names: D'Andrea given-names: Valerio orcid: https://orcid.org/0000-0003-2037-4133 diff --git a/pyproject.toml b/pyproject.toml index f98d87b7d..4f38b3582 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "iminuit", "legend-daq2lh5>=1.2.1", "legend-pydataobj>=1.6", + "pylegendmeta>=0.9", "matplotlib", "numba!=0.53.*,!=0.54.*,!=0.57", "numpy>=1.21", diff --git a/src/pygama/evt/aggregators.py b/src/pygama/evt/aggregators.py index c9adee29b..33044be48 100644 --- a/src/pygama/evt/aggregators.py +++ b/src/pygama/evt/aggregators.py @@ -238,6 +238,7 @@ def evaluate_to_scalar( def evaluate_at_channel( datainfo, tcm, + channels, channels_skip, expr, field_list, @@ -253,6 +254,8 @@ def evaluate_at_channel( input and output LH5 datainfo with HDF5 groups where tables are found. tcm TCM data arrays in an object that can be accessed by attribute. + channels + list of channels to be included for evaluation. channels_skip list of channels to be skipped from evaluation and set to default value. expr @@ -281,7 +284,7 @@ def evaluate_at_channel( evt_ids_ch = np.searchsorted( tcm.cumulative_length, np.where(tcm.id == ch)[0], "right" ) - if table_name not in channels_skip: + if (table_name in channels) and (table_name not in channels_skip): res = utils.get_data_at_channel( datainfo=datainfo, ch=table_name, @@ -307,6 +310,7 @@ def evaluate_at_channel_vov( expr, field_list, ch_comp, + channels, channels_skip, pars_dict=None, default_value=np.nan, @@ -326,6 +330,8 @@ def evaluate_at_channel_vov( list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. ch_comp array of "rawid"s at which the expression is evaluated. + channels + list of channels to be included for evaluation. channels_skip list of channels to be skipped from evaluation and set to default value. pars_dict @@ -335,20 +341,19 @@ def evaluate_at_channel_vov( """ f = utils.make_files_config(datainfo) - # blow up vov to aoesa - out = ak.Array([[] for _ in range(len(ch_comp))]) + ch_comp_channels = np.unique(ch_comp.flattened_data.nda).astype(int) - channels = np.unique(ch_comp.flattened_data.nda).astype(int) - ch_comp = ch_comp.view_as("ak") + out = np.full( + len(ch_comp.flattened_data.nda), default_value, dtype=type(default_value) + ) type_name = None - for ch in channels: + for ch in ch_comp_channels: table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch) - evt_ids_ch = np.searchsorted( tcm.cumulative_length, np.where(tcm.id == ch)[0], "right" ) - if table_name not in channels_skip: + if (table_name in channels) and (table_name not in channels_skip): res = utils.get_data_at_channel( datainfo=datainfo, ch=table_name, @@ -357,23 +362,27 @@ def evaluate_at_channel_vov( field_list=field_list, pars_dict=pars_dict, ) - else: - idx_ch = tcm.idx[tcm.id == ch] - res = np.full(len(idx_ch), default_value) - - # see in which events the current channel is present - mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False) - cv = np.full(len(ch_comp), np.nan) - cv[evt_ids_ch] = res - cv[~mask] = np.nan - cv = ak.drop_none(ak.nan_to_none(ak.Array(cv)[:, None])) + new_evt_ids_ch = np.searchsorted( + ch_comp.cumulative_length, + np.where(ch_comp.flattened_data.nda == ch)[0], + "right", + ) + matches = np.isin(evt_ids_ch, new_evt_ids_ch) + out[ch_comp.flattened_data.nda == ch] = res[matches] - out = ak.concatenate((out, cv), axis=-1) + else: + length = len(np.where(ch_comp.flattened_data.nda == ch)[0]) + res = np.full(length, default_value) + out[ch_comp.flattened_data.nda == ch] = res - if ch == channels[0]: + if ch == ch_comp_channels[0]: + out = out.astype(res.dtype) type_name = res.dtype - return types.VectorOfVectors(ak.values_astype(out, type_name)) + return types.VectorOfVectors( + flattened_data=types.Array(out, dtype=type_name), + cumulative_length=ch_comp.cumulative_length, + ) def evaluate_to_aoesa( diff --git a/src/pygama/evt/build_evt.py b/src/pygama/evt/build_evt.py index 3620dd373..66be91bf6 100644 --- a/src/pygama/evt/build_evt.py +++ b/src/pygama/evt/build_evt.py @@ -497,6 +497,7 @@ def evaluate_expression( return aggregators.evaluate_at_channel( datainfo=datainfo, tcm=tcm, + channels=channels, channels_skip=channels_skip, expr=expr, field_list=field_list, @@ -512,6 +513,7 @@ def evaluate_expression( expr=expr, field_list=field_list, ch_comp=ch_comp, + channels=channels, channels_skip=channels_skip, pars_dict=pars_dict, default_value=default_value, diff --git a/src/pygama/evt/modules/geds.py b/src/pygama/evt/modules/geds.py index 131cdeea8..9efe43dc9 100644 --- a/src/pygama/evt/modules/geds.py +++ b/src/pygama/evt/modules/geds.py @@ -4,10 +4,12 @@ from collections.abc import Sequence +import awkward as ak import numpy as np from lgdo import lh5, types from .. import utils +from . import xtalk def apply_recovery_cut( @@ -26,7 +28,7 @@ def apply_recovery_cut( is_recovering = is_recovering | np.where( ( ((timestamps.nda - tstamp) < time_window) - & ((timestamps.nda - tstamp) > 0) + & ((timestamps.nda - tstamp) >= 0) ), True, False, @@ -41,33 +43,177 @@ def apply_xtalk_correction( tcm: utils.TCMData, table_names: Sequence[str], *, - energy_observable: types.VectorOfVectors, - rawids: types.VectorOfVectors, + return_mode: str, + uncal_energy_expr: str, + cal_energy_expr: str, + multiplicity_expr: str, + xtalk_threshold: float = None, + xtalk_matrix_filename: str = "", + xtalk_rawid_obj: str = "xtc/rawid_index", + xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative", + positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive", +) -> types.VectorOfVectors: + """Applies the cross-talk correction to the energy observable. + The format of `xtalk_matrix_filename` should be currently be a path to a lh5 file. + + The correction is applied using matrix algebra for all triggers above the threshold. + + Parameters + ---------- + datainfo, tcm, table_names + positional arguments automatically supplied by :func:`.build_evt`. + return_mode + string which can be either energy to return corrected energy or tcm_index + uncal_energy_expr + expression for the pulse parameter to be gathered for the uncalibrated energy (used for correction), + can be a combination of different fields. + cal_energy_expr + expression for the pulse parameter to be gathered for the calibrated energy, used for the xtalk threshold, + can be a combination of different fields. + xtalk_threshold + threshold used for xtalk correction, hits below this energy will not + be used to correct the other hits. + xtalk_matrix_filename + name of the file containing the xtalk matrices. + xtalk_matrix_obj + name of the lh5 object containing the xtalk matrix + positive_xtalk_matrix_obj + name of the lh5 object containing the positive polarity xtalk matrix + xtalk_rawids_obj + name of the lh5 object containing the name of the rawids + """ + + xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_obj, xtalk_matrix_filename, "np") + tcm_index_array = xtalk.build_tcm_index_array(tcm, datainfo, xtalk_matrix_rawids) + + energy_corr = xtalk.get_xtalk_correction( + tcm, + datainfo, + uncal_energy_expr, + cal_energy_expr, + xtalk_threshold, + xtalk_matrix_filename, + xtalk_rawid_obj, + xtalk_matrix_obj, + positive_xtalk_matrix_obj, + ) + + multiplicity_mask = xtalk.filter_hits( + datainfo, + tcm, + multiplicity_expr, + energy_corr, + xtalk_matrix_rawids, + ) + energy_corr = ak.from_regular(energy_corr) + multiplicity_mask = ak.from_regular(multiplicity_mask) + tcm_index_array = ak.from_regular(tcm_index_array) + + if return_mode == "energy": + return types.VectorOfVectors(energy_corr[multiplicity_mask]) + elif return_mode == "tcm_index": + return types.VectorOfVectors(tcm_index_array[multiplicity_mask]) + else: + raise ValueError(f"Unknown mode: {return_mode}") + + +def apply_xtalk_correction_and_calibrate( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + return_mode: str, + uncal_energy_expr: str, + cal_energy_expr: str, + cal_par_files: str | Sequence[str], + multiplicity_expr: str, xtalk_matrix_filename: str, + xtalk_threshold: float = None, + xtalk_rawid_obj: str = "xtc/rawid_index", + xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative", + positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive", + uncal_var: str = "dsp.cuspEmax", + recal_var: str = "hit.cuspEmax_ctc_cal", ) -> types.VectorOfVectors: """Applies the cross-talk correction to the energy observable. - The format of `xtalk_matrix_filename` should be... + The correction is applied using matrix algebra for all triggers above the + xalk threshold. Parameters ---------- datainfo, tcm, table_names positional arguments automatically supplied by :func:`.build_evt`. - energy_observable - array of energy values to correct, one event per row. The detector - identifier is stored in `rawids`, which has the same layout. - rawids - array of detector identifiers for each energy in `energy_observable`. + return_mode + string which can be either ``energy`` to return corrected energy or + ``tcm_index``. + uncal_energy_expr + expression for the pulse parameter to be gathered for the uncalibrated + energy (used for correction), can be a combination of different fields. + cal_energy_expr + expression for the pulse parameter to be gathered for the calibrated + energy, used for the xtalk threshold, can be a combination of different + fields. + cal_par_files + path to the generated hit tier par-files defining the calibration + curves. Used to recalibrate the data after xtalk correction. + multiplicity_expr: + expression defining the logic used to compute the event multiplicity. + xtalk_threshold + threshold used for xtalk correction, hits below this energy will not be + used to correct the other hits. xtalk_matrix_filename - name of the file containing the cross-talk matrices. + path to the file containing the xtalk matrices. + xtalk_matrix_obj + name of the lh5 object containing the xtalk matrix. + positive_xtalk_matrix_obj + name of the lh5 object containing the positive polarity xtalk matrix. + xtalk_matrix_rawids + name of the lh5 object containing the name of the rawids. + recal_var + name of the energy variable to use for re-calibration. """ - # read in xtalk matrices - lh5.read_as("", xtalk_matrix_filename, "ak") - # do the correction - energies_corr = ... + xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_obj, xtalk_matrix_filename, "np") + tcm_index_array = xtalk.build_tcm_index_array(tcm, datainfo, xtalk_matrix_rawids) - # return the result as LGDO - return types.VectorOfVectors( - energies_corr, attrs=utils.copy_lgdo_attrs(energy_observable) + energy_corr = xtalk.get_xtalk_correction( + tcm, + datainfo, + uncal_energy_expr, + cal_energy_expr, + xtalk_threshold, + xtalk_matrix_filename, + xtalk_rawid_obj, + xtalk_matrix_obj, + positive_xtalk_matrix_obj, + ) + + calibrated_corr = xtalk.calibrate_energy( + datainfo, + tcm, + energy_corr, + xtalk_matrix_rawids, + cal_par_files, + uncal_var, + recal_var, ) + + multiplicity_mask = xtalk.filter_hits( + datainfo, + tcm, + multiplicity_expr, + calibrated_corr, + xtalk_matrix_rawids, + ) + + calibrated_corr = ak.from_regular(calibrated_corr) + multiplicity_mask = ak.from_regular(multiplicity_mask) + tcm_index_array = ak.from_regular(tcm_index_array) + + if return_mode == "energy": + return types.VectorOfVectors(calibrated_corr[multiplicity_mask]) + elif return_mode == "tcm_index": + return types.VectorOfVectors(tcm_index_array[multiplicity_mask]) + else: + raise ValueError(f"Unknown mode: {return_mode}") diff --git a/src/pygama/evt/modules/xtalk.py b/src/pygama/evt/modules/xtalk.py new file mode 100644 index 000000000..85b9425d4 --- /dev/null +++ b/src/pygama/evt/modules/xtalk.py @@ -0,0 +1,362 @@ +""" +Module for cross talk correction of energies. +""" + +from __future__ import annotations + +import awkward as ak +import numpy as np +from legendmeta.catalog import Props +from lgdo import lh5, types +from numpy.typing import ArrayLike + +from pygama.hit.build_hit import _remove_uneeded_operations, _reorder_table_operations + +from .. import utils + + +def build_tcm_index_array( + tcm: utils.TCMData, datainfo: utils.DataInfo, rawids: np.ndarray +) -> np.ndarray: + """Builds a TCM index array for use in the event tier. + + Parameters + ---------- + datainfo + :class:`.DataInfo` object. + tcm + time-coincidence map object. + rawids + list of channel rawids from the cross talk matrix. + """ + + # initialise the output object + tcm_indexs_out = np.full((len(tcm.cumulative_length), len(rawids)), np.nan) + + # parse observables string. default to hit tier + for idx_chan, channel in enumerate(rawids): + + # get the event indexes + table_id = utils.get_tcm_id_by_pattern( + datainfo._asdict()["dsp"].table_fmt, f"ch{channel}" + ) + tcm_indexs = np.where(tcm.id == table_id)[0] + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == channel)[0], + "right", + ) + tcm_indexs_out[evt_ids_ch, idx_chan] = tcm_indexs + + # transpose to return object where row is events and column rawid idx + return tcm_indexs_out + + +def gather_energy( + observable: str, tcm: utils.TCMData, datainfo: utils.DataInfo, rawids: ArrayLike +) -> ArrayLike: + """Prepares the array of energies for the cross talk correction. + + Parameters + ---------- + observable + expression for the pulse parameter to be gathered, can be a combination + of different fields. + datainfo + :class:`.DataInfo` object + tcm + time-coincidence map object. + rawids + list of channel rawids from the cross talk matrix. + """ + + # replace group. with group___ + for tier in datainfo._asdict(): + group = datainfo._asdict()[tier].group + observable = observable.replace(f"{group}.", f"{group}___") + + observable = observable.replace(".", "__") + + c = compile(observable, "gcc -O3 -ffast-math build_hit.py", "eval") + + tier_params = [] + for name in c.co_names: + if "___" in name: + tier, column = name.split("___") + group = datainfo._asdict()[tier].group + file = datainfo._asdict()[tier].file + if (name, file, group, column) not in tier_params: + tier_params.append((name, file, group, column)) + + # initialise the output object + energy_out = np.full((len(tcm.cumulative_length), len(rawids)), np.nan) + + for idx_chan, channel in enumerate(rawids): + tbl = types.Table() + hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel]) + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == channel)[0], + "right", + ) + + for name, file, group, column in tier_params: + try: + # read the energy data + data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx) + tbl.add_column(name, data) + except (lh5.exceptions.LH5DecodeError, KeyError): + tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan))) + + res = tbl.eval(observable) + energy_out[evt_ids_ch, idx_chan] = res.nda + + return energy_out + + +def filter_hits( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + filter_expr: str, + xtalk_corr_energy: ArrayLike, + rawids: ArrayLike, +) -> ArrayLike: + """Function that removes hits in an event below threshold. + + Parameters: + ----------- + datainfo, tcm + :class:`.DataInfo` and :class:`.TCMData` objects. + filter_expr + string containing the logic used to define which events are above + threshold. this string can also refer to the corrected energy as + `xtalk_corr_energy`. + xtalk_corr_energy + 2D numpy array of correct energy, the row corresponds to the event and + the column the rawid. + rawids + 1D array of the rawids corresponding to each column. + """ + + # find the fields in the string + mask = np.full_like(xtalk_corr_energy, False, dtype=bool) + + # replace group. with group___ + for tier in datainfo._asdict(): + group = datainfo._asdict()[tier].group + filter_expr = filter_expr.replace(f"{group}.", f"{group}___") + + c = compile(filter_expr, "toby dixon was here", "eval") + + tier_params = [] + for name in c.co_names: + if "___" in name: + tier, column = name.split("___") + group = datainfo._asdict()[tier].group + file = datainfo._asdict()[tier].file + if (name, file, group, column) not in tier_params: + tier_params.append((name, file, group, column)) + + for idx_chan, channel in enumerate(rawids): + tbl = types.Table() + + hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel]) + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == channel)[0], + "right", + ) + + for name, file, group, column in tier_params: + try: + # read the energy data + data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx) + + tbl.add_column(name, data) + except (lh5.exceptions.LH5DecodeError, KeyError): + tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan))) + + # add the corrected energy to the table + tbl.add_column( + "xtalk_corr_energy", types.Array(xtalk_corr_energy[evt_ids_ch, idx_chan]) + ) + res = tbl.eval(filter_expr) + mask[evt_ids_ch, idx_chan] = res.nda + + return mask + + +def xtalk_correct_energy_impl( + uncal_energy: ArrayLike, + cal_energy: ArrayLike, + xtalk_matrix: ArrayLike, + xtalk_threshold: float = None, +): + r"""Function to perform the actual xtalk correction of energy. + + 1. The energies are converted to a sparse format where each row corresponds + to a rawid + 2. All energy less than the threshold are set to 0 + 3. The correction is computed as: + + .. math:: + + E_{\text{cor},i} = -\times M_{i,j}E_{j} + + where $M_{i,j}$ is the cross talk matrix element where $i$ is response and + $j$ trigger channel. + + Parameters + ---------- + uncal_energy + 2D numpy array of the uncalibrated energies in each event, the row + corresponds to an event and the column the rawid. + cal_energy + 2D numpy array of the calibrated energies in each event, the row + corresponds to an event and the column the rawid. + xtalk_matrix + 2D numpy array of the cross talk correction matrix, the indices should + correspond to rawids (with same mapping as energies). + xtalk_threshold + threshold below which a hit is not used in xtalk correction. + + """ + # check input shapes and sizes + uncal_energy_no_nan = np.nan_to_num(uncal_energy, 0) + cal_energy_no_nan = np.nan_to_num(cal_energy, 0) + + if xtalk_threshold is not None: + uncal_energy_with_threshold = np.where( + cal_energy_no_nan < xtalk_threshold, 0, uncal_energy_no_nan + ) + else: + uncal_energy_with_threshold = uncal_energy_no_nan + energy_correction = -np.matmul(xtalk_matrix, uncal_energy_with_threshold.T).T + return uncal_energy_no_nan + energy_correction + + +def get_xtalk_correction( + tcm: utils.DataInfo, + datainfo: utils.DataInfo, + uncal_energy_expr: str, + cal_energy_expr: str, + xtalk_threshold: float = None, + xtalk_matrix_filename: str = "", + xtalk_rawid_obj: str = "xtc/rawid_index", + xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative", + positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive", +): + + # read lh5 files to numpy + xtalk_matrix_numpy = lh5.read_as(xtalk_matrix_obj, xtalk_matrix_filename, "np") + xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_obj, xtalk_matrix_filename, "np") + + positive_xtalk_matrix_numpy = lh5.read_as( + positive_xtalk_matrix_obj, xtalk_matrix_filename, "np" + ) + + # Combine positive and negative matrixs + # Now the matrix should have negative values corresponding to negative cross talk + # and positive values corresponding to positive cross talk . + # we also set nan to 0 and we transpose so that the row corresponds to response and column trigger + xtalk_matrix = np.nan_to_num( + np.where( + abs(xtalk_matrix_numpy) > abs(positive_xtalk_matrix_numpy), + xtalk_matrix_numpy, + positive_xtalk_matrix_numpy, + ), + 0, + ).T + + uncal_energy_array = gather_energy( + uncal_energy_expr, tcm, datainfo, xtalk_matrix_rawids + ) + cal_energy_array = gather_energy( + cal_energy_expr, tcm, datainfo, xtalk_matrix_rawids + ) + + energy_corr = xtalk_correct_energy_impl( + uncal_energy_array, cal_energy_array, xtalk_matrix, xtalk_threshold + ) + return energy_corr + + +def calibrate_energy( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + energy_corr: np.ndarray, + xtalk_matrix_rawids: np.ndarray, + par_files: str | list[str], + uncal_energy_var: str = None, + recal_energy_var: str = None, +): + """Function to recalibrate the energy after xtalk correction. + + Parameters + --------- + datainfo + :class:`.DataInfo` object. + tcm + :class:`.TCMData` object. + energy_corr + cross talk corrected (uncal) energies to be recalibrated. + par_files + path to the parameter files. + uncal_energy_var + name of the uncalibrated energy variable. + recal_energy_var + variable to be used for recalibration. + """ + + out_arr = np.full_like(energy_corr, np.nan) + par_dicts = Props.read_from(par_files) + pars = { + chan: chan_dict["pars"]["operations"] for chan, chan_dict in par_dicts.items() + } + + p = uncal_energy_var.split(".") + tier = p[0] if len(p) > 1 else "dsp" + + table_fmt = datainfo._asdict()[tier].table_fmt + file = datainfo._asdict()[tier].file + + for i, chan in enumerate(xtalk_matrix_rawids): + try: + cfg = pars[f"ch{chan}"] + cfg, chan_inputs = _remove_uneeded_operations( + _reorder_table_operations(cfg), recal_energy_var.split(".")[-1] + ) + + chan_inputs.remove(uncal_energy_var.split(".")[-1]) + + # get the event indices + table_id = utils.get_tcm_id_by_pattern(table_fmt, f"ch{chan}") + + hit_idx = ak.to_numpy(tcm.idx[tcm.id == table_id]) + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == table_id)[0], + "right", + ) + + # read the dsp data + outtbl_obj = lh5.read( + f"ch{chan}/dsp/", file, idx=hit_idx, field_mask=chan_inputs + ) + + # add the uncalibrated energy to the table + outtbl_obj.add_column( + uncal_energy_var.split(".")[-1], + types.Array(energy_corr[evt_ids_ch, i]), + ) + + for outname, info in cfg.items(): + outcol = outtbl_obj.eval( + info["expression"], info.get("parameters", None) + ) + outtbl_obj.add_column(outname, outcol) + out_arr[evt_ids_ch, i] = outtbl_obj[recal_energy_var.split(".")[-1]].nda + except KeyError: + out_arr[:, i] = np.nan + + return out_arr diff --git a/src/pygama/hit/build_hit.py b/src/pygama/hit/build_hit.py index bb8c4df5a..0cb8c5e9a 100644 --- a/src/pygama/hit/build_hit.py +++ b/src/pygama/hit/build_hit.py @@ -259,3 +259,45 @@ def _one_pass(config): return new else: current = new + + +def _get_dependencies(config, par, pars=None): + """ + Recursive func to iterate back through tree of input blocks for a given output block + """ + if pars is None: + pars = [] + par_op = config[par] + c = compile(par_op["expression"], "gcc -O3 -ffast-math build_hit.py", "eval") + for p in c.co_names: + if p in par_op["parameters"]: + pass + else: + pars.append(p) + if p in config: + pars = _get_dependencies(config, p, pars) + return pars + + +def _remove_uneeded_operations(config, outpars): + """ + Function that removes any operations not needed to generate outpars from the config dictionary + Returns the config without these blocks as well as a list of input keys from the dsp file + needed to generate outpars + """ + if not isinstance(outpars, list): + outpars = [outpars] + dependent_keys = [*outpars] + inkeys = [] + for par in outpars: + pars = _get_dependencies(config, par) + for p in pars: + if p in config and p not in dependent_keys: + dependent_keys.append(p) + elif p not in config and p not in inkeys: + inkeys.append(p) + + for key in list(config): + if key not in dependent_keys: + config.pop(key) + return config, inkeys diff --git a/src/pygama/math/distributions.py b/src/pygama/math/distributions.py index 4cff70881..74271234f 100644 --- a/src/pygama/math/distributions.py +++ b/src/pygama/math/distributions.py @@ -5,14 +5,14 @@ # nopycln: file +from pygama.math.functions.crystal_ball import crystal_ball # noqa: F401 from pygama.math.functions.crystal_ball import ( # noqa: F401 - crystal_ball, nb_crystal_ball_cdf, nb_crystal_ball_pdf, ) from pygama.math.functions.error_function import nb_erf, nb_erfc # noqa: F401 +from pygama.math.functions.exgauss import exgauss # noqa: F401 from pygama.math.functions.exgauss import ( # noqa: F401 - exgauss, nb_exgauss_cdf, nb_exgauss_pdf, nb_gauss_tail_approx, diff --git a/tests/evt/configs/basic-evt-config.yaml b/tests/evt/configs/basic-evt-config.yaml index bf229504e..08df66fff 100644 --- a/tests/evt/configs/basic-evt-config.yaml +++ b/tests/evt/configs/basic-evt-config.yaml @@ -72,14 +72,17 @@ operations: expression: hit.cuspEmax_ctc_cal initial: 0 is_usable_aoe: + channels: geds_on aggregation_mode: keep_at_ch:evt.energy_id expression: "True" initial: false aoe: + channels: geds_on aggregation_mode: keep_at_ch:evt.energy_id expression: hit.AoE_Classifier initial: np.nan is_aoe_rejected: + channels: geds_on aggregation_mode: keep_at_ch:evt.energy_id expression: ~(hit.AoE_Double_Sided_Cut) initial: false diff --git a/tests/evt/configs/vov-test-evt-config.json b/tests/evt/configs/vov-test-evt-config.json index 6de44075b..1003f40d5 100644 --- a/tests/evt/configs/vov-test-evt-config.json +++ b/tests/evt/configs/vov-test-evt-config.json @@ -55,10 +55,12 @@ "initial": 0 }, "aoe": { + "channels": "geds_on", "aggregation_mode": "keep_at_ch:evt.energy_id", "expression": "hit.AoE_Classifier" }, "aoe_idx": { + "channels": "geds_on", "aggregation_mode": "keep_at_idx:evt.energy_idx", "expression": "hit.AoE_Classifier" }, @@ -71,6 +73,7 @@ "dtype": "int16" }, "is_saturated": { + "channels": "geds_on", "aggregation_mode": "keep_at_ch:evt.energy_id", "expression": "hit.is_saturated" }, diff --git a/tests/evt/test_geds.py b/tests/evt/test_geds.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/evt/test_xtalk.py b/tests/evt/test_xtalk.py new file mode 100644 index 000000000..c1e4e741b --- /dev/null +++ b/tests/evt/test_xtalk.py @@ -0,0 +1,88 @@ +from pathlib import Path + +import numpy as np +import pytest +from lgdo import lh5 + +from pygama.evt import utils +from pygama.evt.modules import xtalk + +config_dir = Path(__file__).parent / "configs" + + +@pytest.fixture(scope="module") +def files_config(lgnd_test_data, tmptestdir): + tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" + outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" + + return { + "tcm": (lgnd_test_data.get_path(tcm_path), "hardware_tcm_1"), + "dsp": (lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), "dsp", "ch{}"), + "hit": (lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), "hit", "ch{}"), + "evt": (outfile, "evt"), + } + + +def test_xtalk_corrected_energy(lgnd_test_data, files_config): + + energy = np.array([[1, 2, 3], [4, 5, 6], [2, 0, 1], [0, 1, 0]]) + matrix = np.array([[0, 0, 1], [1, 0, 2], [0, 2, 0]]) + energy_corrected_zero_threshold = xtalk.xtalk_correct_energy_impl( + energy, energy, matrix, None + ) + + assert np.all( + energy_corrected_zero_threshold + == (energy - np.array([[3, 7, 4], [6, 16, 10], [1, 4, 0], [0, 0, 2]])) + ) + + # test with a 2.1 threshold + energy_corrected_two_threshold = xtalk.xtalk_correct_energy_impl( + energy, energy, matrix, 2.1 + ) + assert np.all( + energy_corrected_two_threshold + == (energy - np.array([[3, 6, 0], [6, 16, 10], [0, 0, 0], [0, 0, 0]])) + ) + + +def test_gather_energy(lgnd_test_data, files_config): + f = utils.make_files_config(files_config) + tcm = utils.TCMData( + id=lh5.read_as(f"/{f.tcm.group}/array_id", f.tcm.file, library="np"), + idx=lh5.read_as(f"/{f.tcm.group}/array_idx", f.tcm.file, library="np"), + cumulative_length=lh5.read_as( + f"/{f.tcm.group}/cumulative_length", f.tcm.file, library="np" + ), + ) + energy = xtalk.gather_energy( + "hit.cuspEmax_ctc_cal", tcm, f, np.array([1084803, 1084804]) + ) + n_rows = np.max(tcm.idx) + 1 + assert isinstance(energy, np.ndarray) + assert energy.ndim == 2 + assert np.shape(energy) == (n_rows, 2) + + +def test_filter_hits(lgnd_test_data, files_config): + f = utils.make_files_config(files_config) + tcm = utils.TCMData( + id=lh5.read_as(f"/{f.tcm.group}/array_id", f.tcm.file, library="np"), + idx=lh5.read_as(f"/{f.tcm.group}/array_idx", f.tcm.file, library="np"), + cumulative_length=lh5.read_as( + f"/{f.tcm.group}/cumulative_length", f.tcm.file, library="np" + ), + ) + n_rows = np.max(tcm.idx) + 1 + + filter = xtalk.filter_hits( + f, + tcm, + "hit.cuspEmax_ctc_cal>5", + np.zeros((n_rows, 2)), + np.array([1084803, 1084804]), + ) + + assert isinstance(filter, np.ndarray) + assert filter.ndim == 2 + assert np.shape(filter) == (n_rows, 2)