diff --git a/strax/__init__.py b/strax/__init__.py
index 7e65f880..fde0f7cc 100644
--- a/strax/__init__.py
+++ b/strax/__init__.py
@@ -4,6 +4,7 @@
 # Glue the package together
 # See https://www.youtube.com/watch?v=0oTh1CXRaQ0 if this confuses you
 # The order of subpackes is not invariant, since we use strax.xxx inside strax
+from .sort_enforcement import *
 from .utils import *
 from .chunk import *
 from .dtypes import *
diff --git a/strax/processing/general.py b/strax/processing/general.py
index 461bb04d..8aa912ad 100644
--- a/strax/processing/general.py
+++ b/strax/processing/general.py
@@ -4,6 +4,7 @@
 # for these fundamental functions, we throw warnings each time they are called
 
 import strax
+from strax import stable_sort, stable_argsort
 import numba
 from numba.typed import List
 import numpy as np
@@ -37,9 +38,9 @@ def sort_by_time(x):
         # Faster sorting:
         x = _sort_by_time_and_channel(x, channel, channel.max() + 1)
     elif "channel" in x.dtype.names:
-        x = np.sort(x, order=("time", "channel"))
+        x = stable_sort(x, order=("time", "channel"))
     else:
-        x = np.sort(x, order=("time",))
+        x = stable_sort(x, order=("time",))
     return x
 
 
@@ -47,13 +48,13 @@ def sort_by_time(x):
 def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind="mergesort"):
     """Assumes you have no more than 10k channels, and records don't span more than 11 days.
 
-    (5-10x) faster than np.sort(order=...), as np.sort looks at all fields
+    (5-10x) faster than strax.stable_sort(order=...), as strax.stable_sort looks at all fields
 
     """
     # I couldn't get fast argsort on multiple keys to work in numba
     # So, let's make a single key...
     sort_key = (x["time"] - x["time"].min()) * max_channel_plus_one + channel
-    sort_i = np.argsort(sort_key, kind=sort_kind)
+    sort_i = stable_argsort(sort_key, kind=sort_kind)
     return x[sort_i]
 
 
@@ -426,7 +427,7 @@ def _touching_windows(
     thing_start, thing_end, container_start, container_end, window=0, endtime_sort_kind="mergesort"
 ):
     n = len(thing_start)
-    container_end_argsort = np.argsort(container_end, kind=endtime_sort_kind)
+    container_end_argsort = stable_argsort(container_end, kind=endtime_sort_kind)
 
     # we search twice, first for the beginning of the interval, then for the end
     left_i = right_i = 0
diff --git a/strax/processing/hitlets.py b/strax/processing/hitlets.py
index 763300a7..b8dfa03d 100644
--- a/strax/processing/hitlets.py
+++ b/strax/processing/hitlets.py
@@ -93,7 +93,7 @@ def concat_overlapping_hits(hits, extensions, pmt_channels, start, end):
     return hits
 
 
-@strax.utils.growing_result(strax.hit_dtype, chunk_size=int(1e4))
+@strax.growing_result(strax.hit_dtype, chunk_size=int(1e4))
 @numba.njit(nogil=True, cache=True)
 def _concat_overlapping_hits(
     hits,
@@ -499,23 +499,58 @@ def _conditional_entropy(hitlets, template, flat=False, square_data=False):
     return res
 
 
+@export
+@numba.njit(cache=True)
+def _compute_simple_edges(interval_indices, dt):
+    """Compute the HDR edges without fractional-edge correction (numba-compiled)."""
+    left = interval_indices[0, 0] * dt
+    right = interval_indices[1, np.argmax(interval_indices[1, :])] * dt
+    return left, right
+
+
+@export
 @numba.njit(cache=True)
+def _compute_fractional_edges(interval_indices, data, area_fraction_amplitude, dt):
+    """Compute the HDR edges with fractional-edge correction (numba-compiled)."""
+    left = interval_indices[0, 0]
+    right = interval_indices[1, np.argmax(interval_indices[1, :])] - 1  # -1: stored value is the exclusive outer edge
+
+    left_amp = data[left]
+    right_amp = data[right]
+
+    next_left_amp = 0
+    if (left - 1) >= 0:
+        next_left_amp = data[left - 1]
+    next_right_amp = 0
+    if (right + 1) < len(data):
+        next_right_amp = data[right + 1]
+
+    # The divisions are safe: left_amp == next_left_amp (and the right-hand
+    # analogue) cannot occur by the definition of the highest density region.
+    fl = (left_amp - area_fraction_amplitude) / (left_amp - next_left_amp)
+    fr = (right_amp - area_fraction_amplitude) / (right_amp - next_right_amp)
+
+    left_edge = (left + 0.5 - fl) * dt
+    right_edge = (right + 0.5 + fr) * dt
+    return left_edge, right_edge
+
+
+@export
 def highest_density_region_width(
-    data, fractions_desired, dt=1, fractionl_edges=False, _buffer_size=100
+    data, fractions_desired, dt=1, fractional_edges=False, _buffer_size=100
 ):
     """Function which computes the left and right edge based on the outer most sample for the
     highest density region of a signal.
 
-    Defines a 100% fraction as the sum over all positive samples in a waveform.
+    Args:
+        data: Data of a signal, e.g. hitlet or peak, including zero length encoding
+        fractions_desired: Area fractions for which the HDR should be computed (100% = sum over all positive samples)
+        dt: Sample length in ns
+        fractional_edges: If True, computes the width as fractional time based on the area covered between adjacent samples
+        _buffer_size: Maximal number of allowed intervals; if exceeded (e.g. due to noise) the width computation is skipped
 
-    :param data: Data of a signal, e.g. hitlet or peak including zero length encoding.
-    :param fractions_desired: Area fractions for which the highest density region should be
-        computed.
-    :param dt: Sample length in ns.
-    :param fractionl_edges: If true computes width as fractional time depending on the covered area
-        between the current and next sample.
-    :param _buffer_size: Maximal number of allowed intervals. If signal exceeds number e.g. due to
-        noise width computation is skipped.
+    Returns:
+        np.ndarray: Array of shape (len(fractions_desired), 2) containing left and right edges (in ns)
 
     """
 
     res = np.zeros((len(fractions_desired), 2), dtype=np.float32)
@@ -525,49 +560,31 @@ def highest_density_region_width(
         res[:] = np.nan
         return res
 
-    inter, amps = strax.highest_density_region(
+    # Compute the HDR intervals and amplitudes (Python driver around numba helpers)
+    intervals, amps = strax.highest_density_region(
         data,
         fractions_desired,
         only_upper_part=True,
         _buffer_size=_buffer_size,
     )
 
-    for index_area_fraction, (interval_indicies, area_fraction_amplitude) in enumerate(
-        zip(inter, amps)
+    # Deal with each area fraction separately
+    for index_area_fraction, (interval_indices, area_fraction_amplitude) in enumerate(
+        zip(intervals, amps)
     ):
-        if np.all(interval_indicies[:] == -1):
+        if np.all(interval_indices[:] == -1):
             res[index_area_fraction, :] = np.nan
             continue
 
-        if not fractionl_edges:
-            res[index_area_fraction, 0] = interval_indicies[0, 0] * dt
-            res[index_area_fraction, 1] = (
-                interval_indicies[1, np.argmax(interval_indicies[1, :])] * dt
-            )
+        if not fractional_edges:
+            left, right = _compute_simple_edges(interval_indices, dt)
+            res[index_area_fraction, 0] = left
+            res[index_area_fraction, 1] = right
         else:
-            left = interval_indicies[0, 0]
-            # -1 since value corresponds to outer edge:
-            right = interval_indicies[1, np.argmax(interval_indicies[1, :])] - 1
-
-            # Get amplitudes of outer most samples
-            # and amplitudes of adjacent samples (if any)
-            left_amp = data[left]
-            right_amp = data[right]
-
-            next_left_amp = 0
-            if (left - 1) >= 0:
-                next_left_amp = data[left - 1]
-            next_right_amp = 0
-            if (right + 1) < len(data):
-                next_right_amp = data[right + 1]
-
-            # Compute fractions and new left and right edges, the case
-            # left_amp == next_left_amp cannot occure by the definition
-            # of the highest density region.
- fl = (left_amp - area_fraction_amplitude) / (left_amp - next_left_amp) - fr = (right_amp - area_fraction_amplitude) / (right_amp - next_right_amp) - - res[index_area_fraction, 0] = (left + 0.5 - fl) * dt - res[index_area_fraction, 1] = (right + 0.5 + fr) * dt + left, right = _compute_fractional_edges( + interval_indices, data, area_fraction_amplitude, dt + ) + res[index_area_fraction, 0] = left + res[index_area_fraction, 1] = right return res diff --git a/strax/processing/peak_building.py b/strax/processing/peak_building.py index 9a2a9966..abf07929 100644 --- a/strax/processing/peak_building.py +++ b/strax/processing/peak_building.py @@ -2,14 +2,13 @@ import numba import strax -from strax import utils -from strax.dtypes import peak_dtype, DIGITAL_SUM_WAVEFORM_CHANNEL +from strax.dtypes import DIGITAL_SUM_WAVEFORM_CHANNEL export, __all__ = strax.exporter() @export -@utils.growing_result(dtype=peak_dtype(), chunk_size=int(1e4)) +@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4)) @numba.jit(nopython=True, nogil=True, cache=True) def find_peaks( hits, diff --git a/strax/processing/peak_properties.py b/strax/processing/peak_properties.py index 76cb2693..7bc814ba 100644 --- a/strax/processing/peak_properties.py +++ b/strax/processing/peak_properties.py @@ -96,7 +96,7 @@ def compute_widths(peaks, select_peaks_indices=None): desired_fr = np.concatenate([0.5 - desired_widths / 2, 0.5 + desired_widths / 2]) # We lose the 50% fraction with this operation, let's add it back - desired_fr = np.sort(np.unique(np.append(desired_fr, [0.5]))) + desired_fr = strax.stable_sort(np.unique(np.append(desired_fr, [0.5]))) fr_times = index_of_fraction(peaks[select_peaks_indices], desired_fr) fr_times *= peaks["dt"][select_peaks_indices].reshape(-1, 1) diff --git a/strax/processing/statistics.py b/strax/processing/statistics.py index 963cace5..25d81483 100644 --- a/strax/processing/statistics.py +++ b/strax/processing/statistics.py @@ -1,36 +1,18 @@ import numpy as np import numba + import strax +from strax.sort_enforcement import stable_argsort, stable_sort export, __all__ = strax.exporter() @export -@numba.njit(cache=True) -def highest_density_region(data, fractions_desired, only_upper_part=False, _buffer_size=10): - """Computes for a given sampled distribution the highest density region of the desired - fractions. Does not assume anything on the normalisation of the data. - - :param data: Sampled distribution - :param fractions_desired: numpy.array Area/probability for which - the hdr should be computed. - :param _buffer_size: Size of the result buffer. The size is - equivalent to the maximal number of allowed intervals. - :param only_upper_part: Boolean, if true only computes - area/probability between maximum and current height. - :return: two arrays: The first one stores the start and inclusive - endindex of the highest density region. The second array holds - the amplitude for which the desired fraction was reached. - Note: - Also goes by the name highest posterior density. Please note, - that the right edge corresponds to the right side of the sample. - Hence the corresponding index is -= 1. - - """ - fi = 0 # number of fractions seen - # Buffer for the result if we find more then _buffer_size edges the function fails. - # User can then manually increase the buffer if needed. 
+@numba.jit(nopython=True, nogil=True, cache=True)
+def _compute_hdr_core(data, fractions_desired, only_upper_part=False, _buffer_size=10):
+    """Initialize the result buffers and the amplitude-sorted index for the HDR computation."""
+    fi = 0
     res = np.zeros((len(fractions_desired), 2, _buffer_size), dtype=np.int32)
     res_amp = np.zeros(len(fractions_desired), dtype=np.float32)
 
@@ -41,90 +23,153 @@ def highest_density_region(data, fractions_desired, only_upper_part=False, _buff
            "with a total probability of less-equal 0."
        )
 
-    # Need an index which sorted by amplitude
-    max_to_min = np.argsort(data, kind="mergesort")[::-1]
+    max_to_min = stable_argsort(data)[::-1]  # index sorted by amplitude, maximum first
+    return max_to_min, area_tot, res, res_amp, fi
+
+
+@export
+@numba.jit(nopython=True, nogil=True, cache=True)
+def _process_intervals_numba(ind, gaps, fi, res, g0, _buffer_size):
+    """Store the edges of the HDR intervals in the result buffer (numba-compiled).
+
+    Args:
+        ind: Sorted indices
+        gaps: Gap indices
+        fi: Current fraction index
+        res: Result buffer
+        g0: Start index
+        _buffer_size: Maximum number of intervals
+
+    Returns:
+        tuple: (fi + 1, res) Updated fraction index and result buffer
+
+    """
+    if len(gaps) > _buffer_size:  # more intervals than the buffer can hold: flag with -1
+        res[fi, 0, :] = -1
+        res[fi, 1, :] = -1
+        return fi + 1, res
+
+    g_ind = -1
+    for g_ind, g in enumerate(gaps):
+        interval = ind[g0:g]
+        res[fi, 0, g_ind] = interval[0]
+        res[fi, 1, g_ind] = interval[-1] + 1
+        g0 = g
+
+    interval = ind[g0:]
+    res[fi, 0, g_ind + 1] = interval[0]
+    res[fi, 1, g_ind + 1] = interval[-1] + 1
+    return fi + 1, res
+
+
+@export
+@numba.jit(nopython=True, nogil=True, cache=True)
+def _compute_fraction_seen(data, max_to_min, j, lowest_sample_seen, area_tot, only_upper_part):
+    """Compute the area fraction seen down to the j-th highest sample (numba-compiled).
+
+    Args:
+        data: Input distribution
+        max_to_min: Sorted indices from max to min
+        j: Current index
+        lowest_sample_seen: Current lowest sample
+        area_tot: Total area
+        only_upper_part: If True, only compute area between max and current height
+
+    Returns:
+        tuple: (fraction_seen, sorted_data_max_to_j, actual_lowest)
+
+    """
+    lowest_sample_seen *= int(only_upper_part)
+    sorted_data_max_to_j = data[max_to_min[:j]]
+    return (
+        np.sum(sorted_data_max_to_j - lowest_sample_seen) / area_tot,
+        sorted_data_max_to_j,
+        lowest_sample_seen,
+    )
+
+
+@export
+@numba.jit(nopython=True, nogil=True, cache=True)
+def _compute_true_height(sorted_data_sum, j, g, lowest_sample_seen):
+    """Compute the height at which the desired fraction is exactly reached (numba-compiled).
+
+    Args:
+        sorted_data_sum: Sum of the j highest samples
+        j: Current index
+        g: Ratio fraction_desired / fraction_seen
+        lowest_sample_seen: Current lowest sample
+
+    Returns:
+        float: The height h solving fraction_desired = sum_{i<j}(y_i - h) / area_tot
+
+    """
+    return (1 - g) * sorted_data_sum / j + g * lowest_sample_seen
+
+
+@export
+def highest_density_region(data, fractions_desired, only_upper_part=False, _buffer_size=10):
+    """Compute the highest density region (HDR) of a sampled distribution; no normalisation is assumed.
+
+    Also known as the highest posterior density. Only the stable sort runs in Python (object)
+    mode; all other computations stay in numba-compiled helpers. Right interval indices are exclusive.
+ + Args: + data: Sampled distribution + fractions_desired: Area/probability for which HDR should be computed + only_upper_part: If True, only compute area between max and current height + _buffer_size: Size of result buffer (max number of allowed intervals) + + Returns: + tuple: (res, res_amp) where res contains interval indices and res_amp contains + amplitudes for desired fractions + + """ + # Initialize using numba + max_to_min, area_tot, res, res_amp, fi = _compute_hdr_core( + data, fractions_desired, only_upper_part, _buffer_size + ) lowest_sample_seen = np.inf for j in range(1, len(data)): - # Loop over indices compute fractions from max to min if lowest_sample_seen == data[max_to_min[j]]: - # We saw this sample height already, so no need to repeat continue lowest_sample_seen = data[max_to_min[j]] - lowest_sample_seen *= int(only_upper_part) - sorted_data_max_to_j = data[max_to_min[:j]] - fraction_seen = np.sum(sorted_data_max_to_j - lowest_sample_seen) / area_tot - # Check if this height step exceeded at least one of the desired - # fractions + # Compute fraction seen (numba) + fraction_seen, sorted_data_max_to_j, actual_lowest = _compute_fraction_seen( + data, max_to_min, j, lowest_sample_seen, area_tot, only_upper_part + ) + m = fractions_desired[fi:] <= fraction_seen if not np.any(m): - # If we do not exceed go to the next sample. continue for fraction_desired in fractions_desired[fi : fi + np.sum(m)]: - # Since we loop always to the height of the next highest sample - # it might happen that we overshoot the desired fraction. Similar - # to the area deciles algorithm we have now to figure out at which - # height we actually reached the desired fraction and store the - # corresponding height: g = fraction_desired / fraction_seen - - # The following gives the true height, to get here one has to - # solve for h: - # 1. fraction_seen = sum_{i=0}^j (y_i - y_j) / a_total - # 2. fraction_desired = sum_{i=0}^j (y_i - h) / a_total - # 3. g = fraction_desired/fraction_seen - # j == number of seen samples - # n == number of total samples in distribution - true_height = (1 - g) * np.sum(sorted_data_max_to_j) / j + g * lowest_sample_seen + # Compute true height (numba) + true_height = _compute_true_height(np.sum(sorted_data_max_to_j), j, g, actual_lowest) res_amp[fi] = true_height - # Find gaps and get edges of hdr intervals: - ind = np.sort(max_to_min[:j]) - gaps = np.arange(1, len(ind) + 1) + # Only stable_sort in Python mode + with numba.objmode(ind="int64[:]"): + ind = stable_sort(max_to_min[:j]) - g0 = 0 - g_ind = -1 + # Rest stays in numba mode + gaps = np.arange(1, len(ind) + 1) diff = ind[1:] - ind[:-1] gaps = gaps[:-1][diff > 1] - if len(gaps) > _buffer_size: - # This signal has more boundaries than the buffer can hold - # hence set all entries to -1 instead. 
-                res[fi, 0, :] = -1
-                res[fi, 1, :] = -1
-                fi += 1
-            else:
-                for g_ind, g in enumerate(gaps):
-                    # Loop over all gaps and get outer edges:
-                    interval = ind[g0:g]
-                    res[fi, 0, g_ind] = interval[0]
-                    res[fi, 1, g_ind] = interval[-1] + 1
-                    g0 = g
-
-                # Now we have to do the last interval:
-                interval = ind[g0:]
-                res[fi, 0, g_ind + 1] = interval[0]
-                res[fi, 1, g_ind + 1] = interval[-1] + 1
-                fi += 1
-
-        if fi == (len(fractions_desired)):
-            # Found all fractions so we are done
+            # Process intervals with numba
+            fi, res = _process_intervals_numba(ind, gaps, fi, res, 0, _buffer_size)
+
+        if fi == len(fractions_desired):
             return res, res_amp
 
-    # If we end up here this might be due to an offset
-    # of the distribution with respect to zero. In that case it can
-    # happen that we do not find all desired fractions.
-    # Hence we have to enforce to compute the last step from the last
-    # lowest hight we have seen to zero.
-    # Left and right edge is by definition 0 and len(data):
+    # Handle fractions not reached (e.g. distribution offset from zero); edges default to 0 and len(data)
     res[fi:, 0, 0] = 0
     res[fi:, 1, 0] = len(data)
 
-    # Now we have to compute the heights for the fractions we have not
-    # seen yet, since lowest_sample_seen == 0 and j == len(data)
-    # the formula above reduces to:
     for ind, fraction_desired in enumerate(fractions_desired[fi:]):
         res_amp[fi + ind] = (1 - fraction_desired) * np.sum(data) / len(data)
+
     return res, res_amp
diff --git a/strax/run_selection.py b/strax/run_selection.py
index ee9dee4c..638bf517 100644
--- a/strax/run_selection.py
+++ b/strax/run_selection.py
@@ -10,6 +10,7 @@ import pytz
 import datetime
 
 import strax
+from strax import stable_argsort
 
 # use tqdm as loaded in utils (from tqdm.notebook when in a juypyter env)
 tqdm = strax.utils.tqdm
@@ -441,7 +442,7 @@ def define_run(
     run_md["comments"] = [{"comment": comment} for comment in comments]
 
     # Make sure subruns are sorted in time
-    sort_index = np.argsort(starts)
+    sort_index = stable_argsort(starts)
     data = {keys[i]: data[keys[i]] for i in sort_index}
 
     # Superrun names must start with an underscore
diff --git a/strax/sort_enforcement.py b/strax/sort_enforcement.py
new file mode 100644
index 00000000..84bda23d
--- /dev/null
+++ b/strax/sort_enforcement.py
@@ -0,0 +1,46 @@
+import numpy as np
+from numba.extending import register_jitable
+
+# Define error message as a constant
+UNSTABLE_SORT_MESSAGE = (
+    "quicksort and heapsort are not allowed due to non-deterministic behavior.\n"
+    "Please use mergesort for deterministic sorting behavior."
+)
+
+
+# Define custom exception for sorting errors
+class SortingError(Exception):
+    pass
+
+
+def stable_sort(arr, kind="mergesort", **kwargs):
+    """Stable sort function using mergesort, without numba optimization.
+
+    Args:
+        arr: numpy array to sort
+        kind: sorting algorithm to use (only 'mergesort' is allowed)
+
+    Returns:
+        Sorted array using mergesort algorithm
+
+    """
+    if kind != "mergesort":
+        raise SortingError(UNSTABLE_SORT_MESSAGE)
+    return np.sort(arr, kind="mergesort", **kwargs)
+
+
+@register_jitable
+def stable_argsort(arr, kind="mergesort"):
+    """Numba-optimized stable argsort function using mergesort.
+ + Args: + arr: numpy array to sort + kind: sorting algorithm to use (only 'mergesort' is allowed) + + Returns: + Indices that would sort the array using mergesort algorithm + + """ + if kind != "mergesort": + raise SortingError(UNSTABLE_SORT_MESSAGE) + return np.argsort(arr, kind="mergesort") diff --git a/strax/utils.py b/strax/utils.py index f532a46f..f824a7ec 100644 --- a/strax/utils.py +++ b/strax/utils.py @@ -11,6 +11,7 @@ import typing as ty from hashlib import sha1 import strax +from strax import stable_argsort, stable_sort import numexpr import dill import numba @@ -523,7 +524,7 @@ def multi_run( # This will autocast all run ids to Unicode fixed-width run_id_numpy = np.array(run_ids) - run_id_numpy = np.sort(run_id_numpy) + run_id_numpy = stable_sort(run_id_numpy) _is_superrun = np.any([r.startswith("_") for r in run_id_numpy]) # Get from kwargs whether output should contain a run_id field. @@ -612,7 +613,7 @@ def multi_run( pbar.close() return None - final_result = [final_result[ind] for ind in np.argsort(run_id_output)] + final_result = [final_result[ind] for ind in stable_argsort(run_id_output)] pbar.close() if ignore_errors and len(failures): log.warning( diff --git a/tests/test_general_processing.py b/tests/test_general_processing.py index ab130aa4..5a141389 100644 --- a/tests/test_general_processing.py +++ b/tests/test_general_processing.py @@ -111,7 +111,7 @@ def test_get_empty_container_ids(full_container_ids): :return: """ - full_container_ids = np.sort(full_container_ids) + full_container_ids = strax.stable_sort(full_container_ids) if len(full_container_ids): n_containers = np.max(full_container_ids) @@ -155,7 +155,7 @@ def test_split(things, split_indices): :param split_indices: Indices at which things should be split. """ - split_indices = np.sort(split_indices) + split_indices = strax.stable_sort(split_indices) split_things = strax.processing.general._split(things, split_indices) split_things_np = np.split(things, split_indices) @@ -312,18 +312,18 @@ def test_sort_by_time(time, channel): dummy_array2["time"] = time res1 = strax.sort_by_time(dummy_array) - res2 = np.sort(dummy_array, order="time") + res2 = strax.stable_sort(dummy_array, order="time") assert np.all(res1 == res2) res1 = strax.sort_by_time(dummy_array2) - res2 = np.sort(dummy_array2, order="time") + res2 = strax.stable_sort(dummy_array2, order="time") assert np.all(res1 == res2) # Test again with random channels dummy_array3 = dummy_array2.copy() dummy_array3["channel"] = channel res1 = strax.sort_by_time(dummy_array3) - res2 = np.sort(dummy_array3, order=("time", "channel")) + res2 = strax.stable_sort(dummy_array3, order=("time", "channel")) assert np.all(res1 == res2) # Create some large time difference that would cause @@ -332,7 +332,7 @@ def test_sort_by_time(time, channel): dummy_array3["time"][0] = np.iinfo(np.int64).min // 2 + 1 dummy_array3["time"][-1] = np.iinfo(np.int64).max // 2 - 1 res1 = strax.sort_by_time(dummy_array3) - res2 = np.sort(dummy_array3, order=("time", "channel")) + res2 = strax.stable_sort(dummy_array3, order=("time", "channel")) assert np.all(res1 == res2) _test_sort_by_time_peaks(time) @@ -347,7 +347,7 @@ def _test_sort_by_time_peaks(time): dummy_array["channel"] = -1 res1 = strax.sort_by_time(dummy_array) - res2 = np.sort(dummy_array, order="time") + res2 = strax.stable_sort(dummy_array, order="time") assert np.all(res1 == res2) diff --git a/tests/test_hitlet.py b/tests/test_hitlet.py index a3f29ac4..5321e50a 100644 --- a/tests/test_hitlet.py +++ b/tests/test_hitlet.py @@ 
-290,14 +290,14 @@ def test_highest_density_region_width(): # Check that negative data does not raise: res = strax.processing.hitlets.highest_density_region_width( - np.array([0, -1, -2]), np.array([0.5]), fractionl_edges=True + np.array([0, -1, -2]), np.array([0.5]), fractional_edges=True ) assert np.all(np.isnan(res)), "For empty data HDR is not defined, should return np.nan!" def _test_highest_density_region_width(distribution, truth_dict): res = strax.processing.hitlets.highest_density_region_width( - distribution, np.array(list(truth_dict.keys())), fractionl_edges=True + distribution, np.array(list(truth_dict.keys())), fractional_edges=True ) for ind, (fraction, truth) in enumerate(truth_dict.items()): @@ -390,7 +390,7 @@ def test_conditional_entropy(data, size_template_and_ind_max_template): """ hitlet = np.zeros(1, dtype=strax.hitlet_with_data_dtype(n_samples=10)) - ind_max_template, size_template = np.sort(size_template_and_ind_max_template) + ind_max_template, size_template = strax.stable_sort(size_template_and_ind_max_template) # Make dummy hitlet: data = data.astype(np.float32) diff --git a/tests/test_mailbox.py b/tests/test_mailbox.py index 5ef96b2a..0224ec5b 100644 --- a/tests/test_mailbox.py +++ b/tests/test_mailbox.py @@ -35,7 +35,7 @@ def mailbox_tester( numbers = np.arange(len(messages)) if expected_result is None: messages = np.asarray(messages) - expected_result = messages[np.argsort(numbers)] + expected_result = messages[strax.stable_argsort(numbers)] mb = strax.Mailbox(max_messages=max_messages, timeout=timeout, lazy=lazy) diff --git a/tests/test_peak_processing.py b/tests/test_peak_processing.py index 2040f593..149df172 100644 --- a/tests/test_peak_processing.py +++ b/tests/test_peak_processing.py @@ -41,7 +41,7 @@ def test_find_peaks(hits, min_channels, min_area): ends = peaks["time"] + peaks["length"] * peaks["dt"] assert np.all(ends[:-1] + gap_threshold <= starts[1:]) - assert np.all(starts == np.sort(starts)), "Not sorted" + assert np.all(starts == strax.stable_sort(starts)), "Not sorted" assert np.all(peaks["time"] < strax.endtime(peaks)), "Non+ peak length" diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..4b04005d --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,95 @@ +import unittest +import numpy as np +import warnings +from hypothesis import given, strategies +from hypothesis.extra.numpy import arrays, integer_dtypes +from strax.sort_enforcement import SortingError, stable_sort, stable_argsort + + +class TestSortEnforcement(unittest.TestCase): + @given(arrays(integer_dtypes(), strategies.integers(1, 100))) + def test_explicit_stable_sort(self, arr): + """Test explicit stable_sort function with generated arrays.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + sorted_arr = stable_sort(arr) + np.testing.assert_array_equal(sorted_arr, np.sort(arr, kind="mergesort")) + # Verify the array is actually sorted + self.assertTrue(np.all(sorted_arr[:-1] <= sorted_arr[1:])) + + @given(arrays(integer_dtypes(), strategies.integers(1, 100))) + def test_explicit_stable_argsort(self, arr): + """Test explicit stable_argsort function with generated arrays.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + sorted_indices = stable_argsort(arr) + np.testing.assert_array_equal(sorted_indices, np.argsort(arr, kind="mergesort")) + # Verify the indices actually sort the array + sorted_arr = arr[sorted_indices] + 
self.assertTrue(np.all(sorted_arr[:-1] <= sorted_arr[1:])) + + @given( + arrays(integer_dtypes(), strategies.integers(1, 100)), + strategies.sampled_from(["quicksort", "heapsort"]), + ) + def test_wrapped_quicksort_rejection(self, arr, sort_kind): + """Test that quicksort and heapsort raise errors in wrapped functions.""" + with self.assertRaises(SortingError): + stable_sort(arr, kind=sort_kind) + with self.assertRaises(SortingError): + stable_argsort(arr, kind=sort_kind) + + @given(arrays(integer_dtypes(), strategies.integers(1, 100))) + def test_original_numpy_unaffected(self, arr): + """Test that original numpy sort functions still work with quicksort.""" + try: + quicksort_result = np.sort(arr, kind="quicksort") + self.assertTrue(np.all(quicksort_result[:-1] <= quicksort_result[1:])) + + quicksort_indices = np.argsort(arr, kind="quicksort") + sorted_arr = arr[quicksort_indices] + self.assertTrue(np.all(sorted_arr[:-1] <= sorted_arr[1:])) + except Exception as e: + self.fail(f"numpy sort with quicksort raised an unexpected exception: {e}") + + @given( + strategies.lists( + strategies.tuples( + strategies.integers(1, 10), # num field + strategies.text(min_size=1, max_size=1), # letter field + ), + min_size=1, + max_size=100, + ) + ) + def test_sort_stability(self, data): + """Test that wrapped sorting is stable using generated structured arrays.""" + # Convert list of tuples to structured array + arr = np.array(data, dtype=[("num", int), ("letter", "U1")]) + + # First sort by letter to establish initial order + arr_by_letter = stable_sort(arr, order="letter") + # Then sort by number - if sort is stable, items with same number + # should maintain their relative order from the letter sort + final_sort = stable_sort(arr_by_letter, order="num") + + # Verify sorting works correctly + for i in range(len(final_sort) - 1): + # Check primary sort key (number) + self.assertTrue( + final_sort[i]["num"] <= final_sort[i + 1]["num"], + f"Primary sort failed: {final_sort[i]} should come before {final_sort[i + 1]}", + ) + + # If numbers are equal, check that letter order is preserved + if final_sort[i]["num"] == final_sort[i + 1]["num"]: + self.assertTrue( + final_sort[i]["letter"] <= final_sort[i + 1]["letter"], + f"Stability violated: for equal numbers {final_sort[i]['num']}, " + f"letter {final_sort[i]['letter']} should come " + f"before or equal to {final_sort[i + 1]['letter']}", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
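Reviewer note: a minimal usage sketch of the new sort-enforcement API, based only on the calls exercised in tests/test_sort.py above (the array values are made up for illustration):

```python
import numpy as np
from strax.sort_enforcement import SortingError, stable_sort, stable_argsort

arr = np.array([3, 1, 2, 1], dtype=np.int64)

print(stable_sort(arr))     # [1 1 2 3] -- always mergesort under the hood
print(stable_argsort(arr))  # [1 3 2 0] -- ties keep their original order (stable)

# Any non-mergesort kind is rejected instead of silently changing behavior:
try:
    stable_sort(arr, kind="quicksort")
except SortingError as err:
    print(err)
```

Because `stable_argsort` is decorated with `@register_jitable`, the same call works both in plain Python and inside `@numba.njit` functions, which is how the compiled helpers in processing/general.py and processing/statistics.py use it.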
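For completeness, a sketch of the refactored HDR entry points, mirroring the calls in tests/test_hitlet.py; the data, fraction, and dt values are hypothetical, and the return-value layout follows the helper functions introduced above:

```python
import numpy as np
import strax

data = np.array([0, 1, 3, 4, 2, 0, 0, 2, 1, 0], dtype=np.float32)
fractions = np.array([0.5], dtype=np.float32)

# Interval buffer and heights: res[0, 0, k] / res[0, 1, k] hold the start and
# exclusive end index of the k-th interval covering the 50% area fraction.
res, res_amp = strax.highest_density_region(data, fractions)

# Left/right edge of the same region in ns, with fractional-edge correction:
edges = strax.processing.hitlets.highest_density_region_width(
    data, fractions, dt=10, fractional_edges=True
)
```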