From a9006f1c83fa68abd28c5e470a405dd861e933c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 06:51:17 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strax/processing/statistics.py | 45 ++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/strax/processing/statistics.py b/strax/processing/statistics.py index 54e6122a..64c9ab5f 100644 --- a/strax/processing/statistics.py +++ b/strax/processing/statistics.py @@ -12,7 +12,9 @@ @register_jitable def _compute_hdr_core(data, fractions_desired, only_upper_part=False, _buffer_size=10): """Core computation for highest density region. + Returns the data needed for interval computation and the result arrays. + """ fi = 0 # number of fractions seen res = np.zeros((len(fractions_desired), 2, _buffer_size), dtype=np.int32) @@ -29,49 +31,54 @@ def _compute_hdr_core(data, fractions_desired, only_upper_part=False, _buffer_si max_to_min = stable_argsort(data)[::-1] return max_to_min, area_tot, res, res_amp, fi + @export def _process_hdr_intervals(ind, gaps, fi, res, g0, _buffer_size): """Process the intervals for highest density region. + This function handles the stable sorting part outside of numba. + """ if len(gaps) > _buffer_size: res[fi, 0, :] = -1 res[fi, 1, :] = -1 return fi + 1, res - + g_ind = -1 for g_ind, g in enumerate(gaps): interval = ind[g0:g] res[fi, 0, g_ind] = interval[0] res[fi, 1, g_ind] = interval[-1] + 1 g0 = g - + # Last interval interval = ind[g0:] res[fi, 0, g_ind + 1] = interval[0] res[fi, 1, g_ind + 1] = interval[-1] + 1 return fi + 1, res + @export @register_jitable def highest_density_region(data, fractions_desired, only_upper_part=False, _buffer_size=10): """Computes for a given sampled distribution the highest density region of the desired fractions. Does not assume anything on the normalisation of the data. - + :param data: Sampled distribution - :param fractions_desired: numpy.array Area/probability for which - the hdr should be computed. - :param _buffer_size: Size of the result buffer. The size is - equivalent to the maximal number of allowed intervals. - :param only_upper_part: Boolean, if true only computes - area/probability between maximum and current height. - :return: two arrays: The first one stores the start and inclusive - endindex of the highest density region. The second array holds - the amplitude for which the desired fraction was reached. + :param fractions_desired: numpy.array Area/probability for which the hdr should be computed. + :param _buffer_size: Size of the result buffer. The size is equivalent to the maximal number of + allowed intervals. + :param only_upper_part: Boolean, if true only computes area/probability between maximum and + current height. + :return: two arrays: The first one stores the start and inclusive endindex of the highest + density region. The second array holds the amplitude for which the desired fraction was + reached. + """ max_to_min, area_tot, res, res_amp, fi = _compute_hdr_core( - data, fractions_desired, only_upper_part, _buffer_size) - + data, fractions_desired, only_upper_part, _buffer_size + ) + lowest_sample_seen = np.inf for j in range(1, len(data)): if lowest_sample_seen == data[max_to_min[j]]: @@ -92,15 +99,15 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff res_amp[fi] = true_height # This part needs stable_sort - switch to object mode - with numba.objmode(ind='int64[:]'): + with numba.objmode(ind="int64[:]"): ind = stable_sort(max_to_min[:j]) - + gaps = np.arange(1, len(ind) + 1) diff = ind[1:] - ind[:-1] gaps = gaps[:-1][diff > 1] # Process intervals outside numba - with numba.objmode(fi='int64', res='int32[:, :, :]'): + with numba.objmode(fi="int64", res="int32[:, :, :]"): fi, res = _process_hdr_intervals(ind, gaps, fi, res, 0, _buffer_size) if fi == len(fractions_desired): @@ -112,5 +119,5 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff res[fi:, 1, 0] = len(data) for ind, fraction_desired in enumerate(fractions_desired[fi:]): res_amp[fi + ind] = (1 - fraction_desired) * np.sum(data) / len(data) - - return res, res_amp \ No newline at end of file + + return res, res_amp