Commit
Merge pull request #229 from donghoon-shin/my_change
For recordings with more than (2 ** 31 - 1) frames, the data type of two variables in prepare_spikesortingview_data needs to change from np.int32 to np.int64.
magland authored Jan 15, 2024
2 parents c3d164a + 90f8f70 commit 015612c
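For context, a rough sketch of where the int32 limit bites (fs = 20000 is taken from the NOTE in the diff; the rest is illustrative, not part of the commit):

    import numpy as np

    max_frame = np.iinfo(np.int32).max   # 2147483647, the largest frame index int32 can hold
    print(max_frame / 20_000 / 3600)     # ~29.8 hours at fs = 20000; longer recordings overflow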
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions sortingview/SpikeSortingView/prepare_spikesortingview_data.py
@@ -19,6 +19,15 @@ def prepare_spikesortingview_data(
channel_neighborhood_size: int,
bandpass_filter: bool = False,
) -> str:
+# NOTE(DS): for data longer than ~25 hours at fs = 20000, num_frames is too large for int32
+if recording.get_num_frames() > (2 ** 31 - 1):
+    int_type = np.int64
+else:
+    int_type = np.int32
+# int_type = np.int64  # NOTE(DS): to test on a short recording
+
+print(f"int_type: {int_type}")
+
if bandpass_filter:
recording = spre.bandpass_filter(recording)
unit_ids = np.array(sorting.get_unit_ids()).astype(np.int32)
@@ -27,13 +36,14 @@ def prepare_spikesortingview_data(
num_frames = recording.get_num_frames()
num_frames_per_segment = math.ceil(segment_duration_sec * sampling_frequency)
num_segments = math.ceil(num_frames / num_frames_per_segment)

with kcl.TemporaryDirectory() as tmpdir:
output_file_name = tmpdir + "/spikesortingview.h5"
with h5py.File(output_file_name, "w") as f:
f.create_dataset("unit_ids", data=unit_ids)
f.create_dataset("sampling_frequency", data=np.array([sampling_frequency]).astype(np.float32))
f.create_dataset("channel_ids", data=channel_ids)
f.create_dataset("num_frames", data=np.array([num_frames]).astype(np.int32))
f.create_dataset("num_frames", data=np.array([num_frames]).astype(int_type))
channel_locations = recording.get_channel_locations()
f.create_dataset("channel_locations", data=np.array(channel_locations))
f.create_dataset("num_segments", data=np.array([num_segments]).astype(np.int32))
@@ -65,7 +75,7 @@ def prepare_spikesortingview_data(
spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame)
assert isinstance(spike_train, np.ndarray)
if len(spike_train) > 0:
-values = traces_with_padding[spike_train.astype(np.int32) - start_frame_with_padding, :]
+values = traces_with_padding[spike_train - start_frame_with_padding, :].astype(np.int32)
avg_value = np.mean(values, axis=0)
peak_channel_ind = np.argmax(np.abs(avg_value))
peak_channel_id = channel_ids[peak_channel_ind]
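The change above removes the np.int32 cast from spike_train before it is used as an index: numpy casts wrap silently on overflow, so a spike time past 2**31 - 1 would index the wrong sample. A quick illustration (the frame value is made up):

    import numpy as np

    spike_train = np.array([2_500_000_000])   # spike time beyond 2**31 - 1 frames
    print(spike_train.astype(np.int32))       # [-1794967296] -- silently wrapped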
@@ -94,7 +104,7 @@ def prepare_spikesortingview_data(
start_frame_with_padding = max(start_frame - snippet_len[0], 0)
end_frame_with_padding = min(end_frame + snippet_len[1], num_frames)
traces_with_padding = recording.get_traces(start_frame=start_frame_with_padding, end_frame=end_frame_with_padding)
-traces_sample = traces_with_padding[start_frame - start_frame_with_padding : start_frame - start_frame_with_padding + int(sampling_frequency * 1), :]
+traces_sample = traces_with_padding[start_frame - start_frame_with_padding: start_frame - start_frame_with_padding + int(sampling_frequency * 1), :]
f.create_dataset(f"segment/{iseg}/traces_sample", data=traces_sample)
all_subsampled_spike_trains = []
for unit_id in unit_ids:
@@ -103,7 +113,7 @@ def prepare_spikesortingview_data(
peak_channel_id = fallback_unit_peak_channel_ids.get(str(unit_id), None)
if peak_channel_id is None:
raise Exception(f"Peak channel not found for unit {unit_id}. This is probably because no spikes were found in any segment for this unit.")
-spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame).astype(np.int32)
+spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame).astype(int_type)
f.create_dataset(f"segment/{iseg}/unit/{unit_id}/spike_train", data=spike_train)
channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
peak_channel_ind = channel_ids.tolist().index(peak_channel_id)
@@ -127,7 +137,7 @@ def prepare_spikesortingview_data(
channel_neighborhood = unit_channel_neighborhoods[str(unit_id)]
channel_neighborhood_indices = [channel_ids.tolist().index(ch_id) for ch_id in channel_neighborhood]
num = len(all_subsampled_spike_trains[ii])
-spike_snippets = spike_snippets_concat[index : index + num, :, channel_neighborhood_indices]
+spike_snippets = spike_snippets_concat[index: index + num, :, channel_neighborhood_indices]
index = index + num
f.create_dataset(f"segment/{iseg}/unit/{unit_id}/subsampled_spike_snippets", data=spike_snippets)
uri = kcl.store_file_local(output_file_name)
@@ -155,7 +165,7 @@ def subsample(x: np.ndarray, num: int):
if num >= len(x):
return x
stride = math.floor(len(x) / num)
-return x[0 : stride * num : stride]
+return x[0: stride * num: stride]


def extract_spike_snippets(*, traces: np.ndarray, times: np.ndarray, snippet_len: Tuple[int, int]):
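For reference, subsample (touched by the last hunk) keeps every stride-th element so that at most num samples survive; a standalone copy with a quick usage check (the example array is made up):

    import math
    import numpy as np

    def subsample(x: np.ndarray, num: int):
        if num >= len(x):
            return x
        stride = math.floor(len(x) / num)
        return x[0: stride * num: stride]

    print(subsample(np.arange(10), 3))  # [0 3 6]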
