Add enforcement for np.sort and np.argsort #918

Merged
merged 59 commits into from
Nov 14, 2024
59 commits
9a81558
set mergesort as default and disable unstable kinds
yuema137 Oct 23, 2024
339277f
add unittest
yuema137 Oct 23, 2024
b127ee9
formatting
yuema137 Oct 23, 2024
cda8c43
formatting
yuema137 Oct 23, 2024
fb1419d
change name to sort_enforcement
yuema137 Oct 23, 2024
c1e8246
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
05c594c
break long error messages
yuema137 Oct 23, 2024
cdeb036
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 23, 2024
fda2638
keep the original sorting in numpy
yuema137 Oct 25, 2024
94f9159
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
1989b73
reemove unused import
yuema137 Oct 30, 2024
d0aa707
always use stablesort
yuema137 Oct 31, 2024
348f8e3
add numba-supported version of stableargsort
yuema137 Oct 31, 2024
c776020
use better naming for stablesort
yuema137 Oct 31, 2024
939777a
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
b609e71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8f23ef5
use jitable to allow both regular function and numba-decorated functi…
yuema137 Oct 31, 2024
3347799
remove redundant numba_sort
yuema137 Oct 31, 2024
f8ac388
explicitly import stablesort from strax for numba decorated functions
yuema137 Oct 31, 2024
2cf9753
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
bfc89e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8a7c00d
consistent import style within one module
yuema137 Oct 31, 2024
2fae5ce
remove unused import
yuema137 Oct 31, 2024
4b933e0
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
ab9ff71
add sorting error
yuema137 Oct 31, 2024
66f4d2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
8f759c6
disable numba support for stable_sort
yuema137 Oct 31, 2024
478a61f
consistent import style for stable sort
yuema137 Oct 31, 2024
92a7807
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Oct 31, 2024
0be548f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
3b236d3
add kwargs
yuema137 Nov 1, 2024
074d597
merge master
yuema137 Nov 1, 2024
b9e203d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
024989a
modify docstring for stable_sort
yuema137 Nov 1, 2024
6e9491f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
66839ae
remove kwargs
yuema137 Nov 13, 2024
edfc4a6
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
4c11074
update variable name
yuema137 Nov 13, 2024
0168bc8
update test_sort with hypothesis
yuema137 Nov 13, 2024
413c8c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
898a8cf
rewrite hithest_density_region to decoupld stable_sort from numba part
yuema137 Nov 13, 2024
f942bc6
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
a9006f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
507a6cd
remove unused import
yuema137 Nov 13, 2024
93e4b5b
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
cf9d276
break long lines
yuema137 Nov 13, 2024
ddb242d
Merge branch 'master' into set_default_as_mergesort
yuema137 Nov 13, 2024
4ad97ae
remove numba decorator for the main function
yuema137 Nov 13, 2024
6d0b394
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
c311544
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
586d7df
fix typo
yuema137 Nov 13, 2024
f4aefe8
rewrite hitlets to use non-numba HDR region
yuema137 Nov 13, 2024
e32d49d
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
b0fb1ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
3d9cc73
format hitlets.py
yuema137 Nov 13, 2024
13cc0d7
unify growing_result import to fix mypy error
yuema137 Nov 13, 2024
5a3beab
Merge branch 'set_default_as_mergesort' of github.com:AxFoundation/st…
yuema137 Nov 13, 2024
48d4e93
remove redundant space
yuema137 Nov 14, 2024
248a2a8
Remove unnecessary indent
dachengx Nov 14, 2024
1 change: 1 addition & 0 deletions strax/__init__.py
@@ -4,6 +4,7 @@
# Glue the package together
# See https://www.youtube.com/watch?v=0oTh1CXRaQ0 if this confuses you
# The order of subpackes is not invariant, since we use strax.xxx inside strax
from .sort_enforcement import *
from .utils import *
from .chunk import *
from .dtypes import *
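
The diff imports `sort_enforcement` but does not show its contents. As a rough sketch only, a module of this kind might wrap `np.sort` and `np.argsort` so that only stable kinds are accepted. `stable_sort` and `stable_argsort` are the names used in this PR; `SortingError`, `STABLE_KINDS`, the signatures, and the error text are assumptions for illustration, not the merged code. This plain-NumPy version also makes no attempt at numba compatibility.

```python
import numpy as np

# Hypothetical constant: sort kinds that NumPy documents as stable.
STABLE_KINDS = ("mergesort", "stable")


class SortingError(ValueError):
    """Hypothetical error raised when an unstable sort kind is requested."""


def _check_kind(kind):
    if kind not in STABLE_KINDS:
        raise SortingError(
            f"kind={kind!r} is not a stable sort; use one of {STABLE_KINDS}."
        )


def stable_sort(arr, kind="mergesort", order=None):
    """np.sort restricted to stable algorithms (illustrative sketch)."""
    _check_kind(kind)
    return np.sort(arr, kind=kind, order=order)


def stable_argsort(arr, kind="mergesort", order=None):
    """np.argsort restricted to stable algorithms (illustrative sketch)."""
    _check_kind(kind)
    return np.argsort(arr, kind=kind, order=order)
```

With a wrapper like this, a call such as `stable_sort(x, kind="quicksort")` fails loudly instead of silently producing an order that may differ between runs or platforms.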
11 changes: 6 additions & 5 deletions strax/processing/general.py
@@ -4,6 +4,7 @@
# for these fundamental functions, we throw warnings each time they are called

import strax
from strax import stable_sort, stable_argsort
import numba
from numba.typed import List
import numpy as np
@@ -37,23 +38,23 @@ def sort_by_time(x):
# Faster sorting:
x = _sort_by_time_and_channel(x, channel, channel.max() + 1)
elif "channel" in x.dtype.names:
x = np.sort(x, order=("time", "channel"))
x = stable_sort(x, order=("time", "channel"))
else:
x = np.sort(x, order=("time",))
x = stable_sort(x, order=("time",))
return x


@numba.jit(nopython=True, nogil=True, cache=True)
def _sort_by_time_and_channel(x, channel, max_channel_plus_one, sort_kind="mergesort"):
"""Assumes you have no more than 10k channels, and records don't span more than 11 days.

(5-10x) faster than np.sort(order=...), as np.sort looks at all fields
(5-10x) faster than strax.stable_sort(order=...), as strax.stable_sort looks at all fields

"""
# I couldn't get fast argsort on multiple keys to work in numba
# So, let's make a single key...
sort_key = (x["time"] - x["time"].min()) * max_channel_plus_one + channel
sort_i = np.argsort(sort_key, kind=sort_kind)
sort_i = stable_argsort(sort_key, kind=sort_kind)
return x[sort_i]
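
The single-key trick in `_sort_by_time_and_channel` can be checked outside numba. The sketch below uses plain NumPy and a made-up four-record array (the `area` field and all values are assumptions for illustration) to show that a stable argsort on the combined key gives the same order as sorting on `("time", "channel")`.

```python
import numpy as np

# Hypothetical records; only "time" and "channel" matter for the ordering.
dtype = [("time", np.int64), ("channel", np.int16), ("area", np.float32)]
x = np.array(
    [(20, 1, 1.0), (10, 3, 2.0), (20, 0, 3.0), (10, 1, 4.0)],
    dtype=dtype,
)

# Fold (time, channel) into one integer key, as in the diff above:
# the time offset is scaled so that channel only ever breaks ties in time.
max_channel_plus_one = int(x["channel"].max()) + 1
sort_key = (x["time"] - x["time"].min()) * max_channel_plus_one + x["channel"]

by_key = x[np.argsort(sort_key, kind="mergesort")]
by_fields = np.sort(x, order=("time", "channel"), kind="mergesort")
```

Both `by_key` and `by_fields` come out ordered by time first and channel second. The key only works while `channel < max_channel_plus_one` and the scaled time offset fits in the integer type, which is what the docstring's limits on channel count and record span refer to.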


@@ -426,7 +427,7 @@ def _touching_windows(
thing_start, thing_end, container_start, container_end, window=0, endtime_sort_kind="mergesort"
):
n = len(thing_start)
container_end_argsort = np.argsort(container_end, kind=endtime_sort_kind)
container_end_argsort = stable_argsort(container_end, kind=endtime_sort_kind)

# we search twice, first for the beginning of the interval, then for the end
left_i = right_i = 0
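
What a stable argsort guarantees is easiest to see with ties. This is a minimal illustration with invented end times, not data from strax: entries with equal values keep their original relative order, so the result is fully determined by the input.

```python
import numpy as np

# Hypothetical container end times with two pairs of ties (30 and 10).
container_end = np.array([30, 10, 30, 10, 20])

# A stable kind keeps tied entries in input order:
# index 1 stays ahead of 3 (both 10), index 0 stays ahead of 2 (both 30).
order = np.argsort(container_end, kind="mergesort")
```

Here `order` is `[1, 3, 4, 0, 2]`. NumPy's default `kind="quicksort"` makes no such promise for tied entries, which is presumably why the PR routes these calls through `stable_argsort`.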
103 changes: 59 additions & 44 deletions strax/processing/hitlets.py
@@ -93,7 +93,7 @@ def concat_overlapping_hits(hits, extensions, pmt_channels, start, end):
return hits


@strax.utils.growing_result(strax.hit_dtype, chunk_size=int(1e4))
@strax.growing_result(strax.hit_dtype, chunk_size=int(1e4))
@numba.njit(nogil=True, cache=True)
def _concat_overlapping_hits(
hits,
@@ -499,23 +499,56 @@ def _conditional_entropy(hitlets, template, flat=False, square_data=False):
return res


@export
@numba.njit(cache=True)
def _compute_simple_edges(interval_indices, dt):
"""Compute edges without fractional edges using numba."""
left = interval_indices[0, 0] * dt
right = interval_indices[1, np.argmax(interval_indices[1, :])] * dt
return left, right


@export
@numba.njit(cache=True)
def _compute_fractional_edges(interval_indices, data, area_fraction_amplitude, dt):
"""Compute edges with fractional consideration using numba."""
left = interval_indices[0, 0]
right = interval_indices[1, np.argmax(interval_indices[1, :])] - 1

left_amp = data[left]
right_amp = data[right]

next_left_amp = 0
if (left - 1) >= 0:
next_left_amp = data[left - 1]
next_right_amp = 0
if (right + 1) < len(data):
next_right_amp = data[right + 1]

fl = (left_amp - area_fraction_amplitude) / (left_amp - next_left_amp)
fr = (right_amp - area_fraction_amplitude) / (right_amp - next_right_amp)

left_edge = (left + 0.5 - fl) * dt
right_edge = (right + 0.5 + fr) * dt
return left_edge, right_edge


@export
def highest_density_region_width(
data, fractions_desired, dt=1, fractionl_edges=False, _buffer_size=100
data, fractions_desired, dt=1, fractional_edges=False, _buffer_size=100
):
"""Function which computes the left and right edge based on the outer most sample for the
highest density region of a signal.

Defines a 100% fraction as the sum over all positive samples in a waveform.
Args:
data: Data of a signal, e.g. hitlet or peak including zero length encoding
fractions_desired: Area fractions for which HDR should be computed
dt: Sample length in ns
fractional_edges: If true computes width as fractional time
_buffer_size: Maximal number of allowed intervals

:param data: Data of a signal, e.g. hitlet or peak including zero length encoding.
:param fractions_desired: Area fractions for which the highest density region should be
computed.
:param dt: Sample length in ns.
:param fractionl_edges: If true computes width as fractional time depending on the covered area
between the current and next sample.
:param _buffer_size: Maximal number of allowed intervals. If signal exceeds number e.g. due to
noise width computation is skipped.
Returns:
np.ndarray: Array of shape (len(fractions_desired), 2) containing left and right edges

"""
res = np.zeros((len(fractions_desired), 2), dtype=np.float32)
@@ -525,49 +558,31 @@ def highest_density_region_width(
res[:] = np.nan
return res

inter, amps = strax.highest_density_region(
# Use the pure-python implementation for HDR computation
intervals, amps = strax.highest_density_region(
data,
fractions_desired,
only_upper_part=True,
_buffer_size=_buffer_size,
)

for index_area_fraction, (interval_indicies, area_fraction_amplitude) in enumerate(
zip(inter, amps)
# Deal with each area fraction separately
for index_area_fraction, (interval_indices, area_fraction_amplitude) in enumerate(
zip(intervals, amps)
):
if np.all(interval_indicies[:] == -1):
if np.all(interval_indices[:] == -1):
res[index_area_fraction, :] = np.nan
continue

if not fractionl_edges:
res[index_area_fraction, 0] = interval_indicies[0, 0] * dt
res[index_area_fraction, 1] = (
interval_indicies[1, np.argmax(interval_indicies[1, :])] * dt
)
if not fractional_edges:
left, right = _compute_simple_edges(interval_indices, dt)
res[index_area_fraction, 0] = left
res[index_area_fraction, 1] = right
else:
left = interval_indicies[0, 0]
# -1 since value corresponds to outer edge:
right = interval_indicies[1, np.argmax(interval_indicies[1, :])] - 1

# Get amplitudes of outer most samples
# and amplitudes of adjacent samples (if any)
left_amp = data[left]
right_amp = data[right]

next_left_amp = 0
if (left - 1) >= 0:
next_left_amp = data[left - 1]
next_right_amp = 0
if (right + 1) < len(data):
next_right_amp = data[right + 1]

# Compute fractions and new left and right edges, the case
# left_amp == next_left_amp cannot occure by the definition
# of the highest density region.
fl = (left_amp - area_fraction_amplitude) / (left_amp - next_left_amp)
fr = (right_amp - area_fraction_amplitude) / (right_amp - next_right_amp)

res[index_area_fraction, 0] = (left + 0.5 - fl) * dt
res[index_area_fraction, 1] = (right + 0.5 + fr) * dt
left, right = _compute_fractional_edges(
interval_indices, data, area_fraction_amplitude, dt
)
res[index_area_fraction, 0] = left
res[index_area_fraction, 1] = right

return res
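
To make the edge arithmetic concrete, here is a plain-NumPy restatement of the two helpers from this file's diff, run on a hand-made triangular waveform. The function names `simple_edges` and `fractional_edges`, the sample data, and the use of `-1` for unused interval slots are assumptions for illustration; the numba decorators are left out on purpose.

```python
import numpy as np


def simple_edges(interval_indices, dt):
    """Outermost sample edges, in units of dt."""
    left = interval_indices[0, 0] * dt
    right = interval_indices[1].max() * dt
    return left, right


def fractional_edges(interval_indices, data, area_fraction_amplitude, dt):
    """Edges interpolated between the outermost sample and its neighbour."""
    left = interval_indices[0, 0]
    right = interval_indices[1].max() - 1  # -1: stored value is the outer edge

    # Neighbouring amplitudes, or 0 at the waveform boundary.
    next_left_amp = data[left - 1] if left - 1 >= 0 else 0
    next_right_amp = data[right + 1] if right + 1 < len(data) else 0

    fl = (data[left] - area_fraction_amplitude) / (data[left] - next_left_amp)
    fr = (data[right] - area_fraction_amplitude) / (data[right] - next_right_amp)
    return (left + 0.5 - fl) * dt, (right + 0.5 + fr) * dt


# Hypothetical waveform and one HDR interval covering samples 1..3.
data = np.array([0.0, 2.0, 4.0, 2.0, 0.0])
interval_indices = np.array([[1, -1], [4, -1]])

simple = simple_edges(interval_indices, dt=10)
fractional = fractional_edges(interval_indices, data, 1.5, dt=10)
```

For this input the simple edges are `(10, 40)`. With a threshold amplitude of 1.5, `fl = fr = (2 - 1.5) / (2 - 0) = 0.25`, so the fractional edges are `(12.5, 37.5)`.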
5 changes: 2 additions & 3 deletions strax/processing/peak_building.py
@@ -2,14 +2,13 @@
import numba

import strax
from strax import utils
from strax.dtypes import peak_dtype, DIGITAL_SUM_WAVEFORM_CHANNEL
from strax.dtypes import DIGITAL_SUM_WAVEFORM_CHANNEL

export, __all__ = strax.exporter()


@export
@utils.growing_result(dtype=peak_dtype(), chunk_size=int(1e4))
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True, cache=True)
def find_peaks(
hits,
2 changes: 1 addition & 1 deletion strax/processing/peak_properties.py
@@ -96,7 +96,7 @@ def compute_widths(peaks, select_peaks_indices=None):
desired_fr = np.concatenate([0.5 - desired_widths / 2, 0.5 + desired_widths / 2])

# We lose the 50% fraction with this operation, let's add it back
desired_fr = np.sort(np.unique(np.append(desired_fr, [0.5])))
desired_fr = strax.stable_sort(np.unique(np.append(desired_fr, [0.5])))

fr_times = index_of_fraction(peaks[select_peaks_indices], desired_fr)
fr_times *= peaks["dt"][select_peaks_indices].reshape(-1, 1)
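
The fraction bookkeeping in `compute_widths` can be traced with a small example. The widths below are invented, and `np.sort(..., kind="mergesort")` stands in for `strax.stable_sort` so that the snippet runs without strax.

```python
import numpy as np

# Hypothetical central widths: 50% and 90% of the peak area.
desired_widths = np.array([0.5, 0.9])

# Each width w maps to the pair of area fractions 0.5 - w/2 and 0.5 + w/2.
desired_fr = np.concatenate([0.5 - desired_widths / 2, 0.5 + desired_widths / 2])

# Add the median (50%) fraction, drop duplicates, and sort with a stable kind.
desired_fr = np.sort(np.unique(np.append(desired_fr, [0.5])), kind="mergesort")
```

The result is approximately `[0.05, 0.25, 0.5, 0.75, 0.95]`. `np.unique` already returns its output sorted, so the outer sort does not change the order here; using the stable wrapper simply keeps every sort call in the codebase on the same enforced path.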