Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix interpolation of nans in decoding position #1033

Merged
merged 12 commits into from
Aug 6, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ PositionGroup.alter()
- Decoding:
- Default values for classes on `ImportError` #966
- Add option to upsample data rate in `PositionGroup` #1008
- Avoid interpolating over large `nan` intervals in position #1033
- Position
- Allow dlc without pre-existing tracking data #973, #975
- Raise `KeyError` for missing input parameters across helper funcs #966
Expand Down
15 changes: 7 additions & 8 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def make(self, key):
position_info.index <= interval_end,
)
] = True
is_training[
position_info[position_variable_names].isna().values.max(axis=1)
edeno marked this conversation as resolved.
Show resolved Hide resolved
] = False
if "is_training" not in decoding_kwargs:
decoding_kwargs["is_training"] = is_training

Expand Down Expand Up @@ -427,14 +430,10 @@ def fetch_linear_position_info(key):

min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)

return (
pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
return pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
Expand Down
27 changes: 22 additions & 5 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,7 @@ def fetch_position_info(
min_time = min([df.index.min() for df in position_info])
if max_time is None:
max_time = max([df.index.max() for df in position_info])
position_info = (
pd.concat(position_info, axis=0)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
position_info = pd.concat(position_info, axis=0).loc[min_time:max_time]

return position_info, position_variable_names

Expand Down Expand Up @@ -239,10 +235,31 @@ def _upsample(
np.unique(np.concatenate((position_df.index, new_time))),
name="time",
)

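# Record, per column, the (start, end) index labels of each run of NaN
# samples so they can be re-masked after interpolation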
nan_intervals = {}
for column in position_df.columns:
edeno marked this conversation as resolved.
Show resolved Hide resolved
is_nan = position_df[column].isna().values.astype(int)
edeno marked this conversation as resolved.
Show resolved Hide resolved
edeno marked this conversation as resolved.
Show resolved Hide resolved
st = np.where(np.diff(is_nan) == 1)[0] + 1
en = np.where(np.diff(is_nan) == -1)[0]
if is_nan[0]:
st = np.insert(st, 0, 0)
if is_nan[-1]:
en = np.append(en, len(is_nan) - 1)
st = position_df.index[st].values
en = position_df.index[en].values
nan_intervals[column] = list(zip(st, en))

# Upsample: reindex onto the union of old and new times, interpolate,
# then keep only the new time points
position_df = (
position_df.reindex(index=new_index)
.interpolate(method=upsampling_interpolation_method)
.reindex(index=new_time)
)

# Restore NaN over the originally-NaN intervals, which the interpolation
# above filled in
for column, intervals in nan_intervals.items():
for st, en in intervals:
position_df[column][st:en] = np.nan

return position_df
17 changes: 8 additions & 9 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def make(self, key):
position_info.index <= interval_end,
)
] = True
is_training[
position_info[position_variable_names].isna().values.max(axis=1)
] = False

if "is_training" not in decoding_kwargs:
decoding_kwargs["is_training"] = is_training

Expand Down Expand Up @@ -388,15 +392,10 @@ def fetch_linear_position_info(key):
edge_spacing=environment.edge_spacing,
)
min_time, max_time = SortedSpikesDecodingV1._get_interval_range(key)

return (
pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
return pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(
Expand Down
Loading