Skip to content

Commit

Permalink
Use numpy and strax native dtypes, not "<i8" or "<f4" (#1502)
Browse files Browse the repository at this point in the history
* Use numpy native dtypes, not `"<i8"` or `"<f4"`

* Use `np.bool_` inside numba decorated function

* Debug
  • Loading branch information
dachengx authored Dec 18, 2024
1 parent 7fc3dbd commit 4d2c89e
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 47 deletions.
4 changes: 2 additions & 2 deletions straxen/analyses/bokeh_waveform_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def event_display_interactive(
raise ValueError("Found an event without peaks this should not had have happened.")

# Select main/alt S1/S2s based on time and endtime in event:
m_other_peaks = np.ones(len(peaks), dtype=np.bool_) # To select non-event peaks
m_other_peaks = np.ones(len(peaks), dtype=bool) # To select non-event peaks
endtime = strax.endtime(peaks)

signal = {}
Expand Down Expand Up @@ -570,7 +570,7 @@ def plot_pmt_array(
# Plotting PMTs:
pmts = straxen.pmt_positions()
if plot_all_pmts:
mask_pmts = np.zeros(len(pmts), dtype=np.bool_)
mask_pmts = np.zeros(len(pmts), dtype=bool)
else:
mask_pmts = to_pe == 0
pmts_on = pmts[~mask_pmts]
Expand Down
2 changes: 1 addition & 1 deletion straxen/analyses/waveform_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def plot_records_matrix(
# labels in the case of strings.
# Make a dict that converts the label to an int
int_labels = {h: i for i, h in enumerate(set(ylabs))}
mask = np.ones(len(ylabs), dtype=np.bool_)
mask = np.ones(len(ylabs), dtype=bool)
# If the label (int) is different wrt. its neighbour, show it
mask[1:] = np.abs(np.diff([int_labels[y] for y in ylabs])) > 0
# Only label the selection
Expand Down
28 changes: 12 additions & 16 deletions straxen/plugins/afterpulses/afterpulse_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,45 +412,41 @@ def dtype_afterpulses():
- The afterpulse datatype
"""
dtype_ap = [
(("Channel/PMT number", "channel"), "<i2"),
(("Time resolution in ns", "dt"), "<i2"),
(("Start time of the interval (ns since unix epoch)", "time"), "<i8"),
(("Length of the interval in samples", "length"), "<i4"),
(("Integral in ADC x samples", "area"), "<i4"),
(("Pulse area in PE", "area_pe"), "<f4"),
(("Sample index in which hit starts", "left"), "<i2"),
dtype_ap = strax.interval_dtype + [
(("Integral in ADC x samples", "area"), np.int32),
(("Pulse area in PE", "area_pe"), np.float32),
(("Sample index in which hit starts", "left"), np.int16),
(
(
"Sample index in which hit area succeeds 10% of total area",
"sample_10pc_area",
),
"<i2",
np.int16,
),
(
(
"Sample index in which hit area succeeds 50% of total area",
"sample_50pc_area",
),
"<i2",
np.int16,
),
(("Sample index of hit maximum", "max"), "<i2"),
(("Sample index of hit maximum", "max"), np.int16),
(
(
"Index of first sample in record just beyond hit (exclusive bound)",
"right",
),
"<i2",
np.int16,
),
(("Height of hit in ADC counts", "height"), "<i4"),
(("Height of hit in PE", "height_pe"), "<f4"),
(("Delay of hit w.r.t. LED hit in same WF, in samples", "tdelay"), "<i2"),
(("Height of hit in ADC counts", "height"), np.int32),
(("Height of hit in PE", "height_pe"), np.float32),
(("Delay of hit w.r.t. LED hit in same WF, in samples", "tdelay"), np.int16),
(
(
"Internal (temporary) index of fragment in which hit was found",
"record_i",
),
"<i4",
np.int32,
),
(
("Index of sample in record where integration starts", "left_integration"),
Expand Down
8 changes: 3 additions & 5 deletions straxen/plugins/events/event_basics_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def infer_dtype(self):
# Basic event properties
self._set_posrec_save()
self._set_dtype_requirements()
dtype = []
dtype += strax.time_fields
dtype += [
dtype = strax.time_fields + [
("n_peaks", np.int32, "Number of peaks in the event"),
("drift_time", np.float32, "Drift time between main S1 and S2 in ns"),
("event_number", np.int64, "Event number in this dataset"),
Expand Down Expand Up @@ -128,8 +126,8 @@ def _set_dtype_requirements(self):
("range_90p_area", np.float32, "width, 90% area [ns]"),
("rise_time", np.float32, "time between 10% and 50% area quantiles [ns]"),
("area_fraction_top", np.float32, "fraction of area seen by the top PMT array"),
("tight_coincidence", np.int16, "Channel within tight range of mean"),
("n_saturated_channels", np.int16, "Total number of saturated channels"),
("tight_coincidence", np.int16, "channel within tight range of mean"),
("n_saturated_channels", np.int16, "total number of saturated channels"),
)

def setup(self):
Expand Down
30 changes: 13 additions & 17 deletions straxen/plugins/led_cal/led_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
"""

from immutabledict import immutabledict
import strax
import straxen
import numba
import numpy as np
import scipy.stats as sps
import strax
import straxen

# This makes sure shorthands for only the necessary functions
# are made available under straxen.[...]
Expand Down Expand Up @@ -162,15 +162,11 @@ class LEDCalibration(strax.Plugin):
help=("Minimum hit amplitude in numbers of baseline_rms above baseline. "),
)

dtype = [
dtype = strax.interval_dtype + [
(("Area averaged in integration windows", "area"), np.float32),
(("Area averaged in noise integration windows", "area_noise"), np.float32),
(("Amplitude in LED window", "amplitude_led"), np.float32),
(("Amplitude in off LED window", "amplitude_noise"), np.float32),
(("Channel", "channel"), np.int16),
(("Start time of the interval (ns since unix epoch)", "time"), np.int64),
(("Time resolution in ns", "dt"), np.int16),
(("Length of the interval in samples", "length"), np.int32),
(("Whether there was a hit found in the record", "triggered"), bool),
(("Sample index of the hit that defines the window position", "hit_position"), np.uint8),
(("Window used for integration", "integration_window"), np.uint8, (2,)),
Expand Down Expand Up @@ -281,28 +277,28 @@ def get_records(raw_records, baseline_window, led_cal_record_length):

record_length_padded = np.shape(raw_records.dtype["data"])[0]

_dtype = [
(("Start time since unix epoch [ns]", "time"), "<i8"),
(("Length of the interval in samples", "length"), "<i4"),
(("Width of one sample [ns]", "dt"), "<i2"),
(("Channel/PMT number", "channel"), "<i2"),
_dtype = strax.interval_dtype + [
(
(
"Length of pulse to which the record belongs (without zero-padding)",
"pulse_length",
),
"<i4",
np.int32,
),
(("Fragment number in the pulse", "record_i"), "<i2"),
(("Fragment number in the pulse", "record_i"), np.int16),
(
("Baseline in ADC counts. data = int(baseline) - data_orig", "baseline"),
"f4",
np.float32,
),
(
("Baseline RMS in ADC counts. data = baseline - data_orig", "baseline_rms"),
"f4",
np.float32,
),
(
("Waveform data in raw ADC counts with 0 padding", "data"),
np.float32,
(record_length_padded,),
),
(("Waveform data in raw ADC counts with 0 padding", "data"), "f4", (record_length_padded,)),
]

records = np.zeros(len(raw_records), dtype=_dtype)
Expand Down
6 changes: 2 additions & 4 deletions straxen/plugins/peaks/peak_basics_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ class PeakBasicsVanilla(strax.Plugin):
)

def infer_dtype(self):
dtype = [
(("Start time of the peak (ns since unix epoch)", "time"), np.int64),
(("End time of the peak (ns since unix epoch)", "endtime"), np.int64),
(("Weighted center time of the peak (ns since unix epoch)", "center_time"), np.int64),
dtype = strax.time_fields + [
(("Weighted center time of the peak [ns]", "center_time"), np.int64),
(("Peak integral in PE", "area"), np.float32),
(("Number of hits contributing at least one sample to the peak", "n_hits"), np.int32),
(("Number of PMTs contributing to the peak", "n_channels"), np.int16),
Expand Down
4 changes: 2 additions & 2 deletions straxen/plugins/raw_records_coin_nv/nveto_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def pulse_in_interval(raw_records, record_links, start_times, end_times):
"""
nrr = len(raw_records)
result = np.zeros(nrr, np.bool_)
result = np.zeros(nrr, bool)

last_interval_seen = 0
for ind, rr in enumerate(raw_records):
Expand Down Expand Up @@ -430,7 +430,7 @@ def _coincidence(rr, nfold=4, resolving_time=300):
"""
# 1. estimate time difference between fragments:
start_times = rr["time"]
mask = np.zeros(len(start_times), dtype=np.bool_)
mask = np.zeros(len(start_times), dtype=bool)
t_diff = np.diff(start_times, prepend=start_times[0])

# 2. Now we have to check if n-events are within resolving time:
Expand Down

0 comments on commit 4d2c89e

Please sign in to comment.