Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix interpolation of nans in decoding position #1033

Merged
merged 12 commits into from
Aug 6, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ PositionGroup.alter()

- Default values for classes on `ImportError` #966
- Add option to upsample data rate in `PositionGroup` #1008
- Avoid interpolating over large `nan` intervals in position #1033

- Position

Expand Down
15 changes: 7 additions & 8 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def make(self, key):
position_info.index <= interval_end,
)
] = True
is_training[
position_info[position_variable_names].isna().values.max(axis=1)
edeno marked this conversation as resolved.
Show resolved Hide resolved
] = False
if "is_training" not in decoding_kwargs:
decoding_kwargs["is_training"] = is_training

Expand Down Expand Up @@ -426,14 +429,10 @@ def fetch_linear_position_info(key):

min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)

return (
pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
return pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
Expand Down
34 changes: 29 additions & 5 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def fetch_position_info(
PositionOutput & {"merge_id": pos_merge_id}
).fetch1_dataframe(),
upsampling_sampling_rate=upsample_rate,
position_variable_names=position_variable_names,
)
)
else:
Expand All @@ -189,11 +190,7 @@ def fetch_position_info(
min_time = min([df.index.min() for df in position_info])
if max_time is None:
max_time = max([df.index.max() for df in position_info])
position_info = (
pd.concat(position_info, axis=0)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
position_info = pd.concat(position_info, axis=0).loc[min_time:max_time]

return position_info, position_variable_names

Expand All @@ -202,6 +199,7 @@ def _upsample(
position_df: pd.DataFrame,
upsampling_sampling_rate: float,
upsampling_interpolation_method: str = "linear",
position_variable_names: list[str] = None,
) -> pd.DataFrame:
"""upsample position data to a fixed sampling rate

Expand All @@ -213,6 +211,9 @@ def _upsample(
sampling rate to upsample to
upsampling_interpolation_method : str, optional
pandas method for interpolation, by default "linear"
position_variable_names : list[str], optional
names of position variables of focus, for which nan values will not be
interpolated, by default None includes all columns

Returns
-------
Expand All @@ -239,10 +240,33 @@ def _upsample(
np.unique(np.concatenate((position_df.index, new_time))),
name="time",
)

# Find NaN intervals
nan_intervals = {}
if position_variable_names is None:
position_variable_names = position_df.columns
for column in position_variable_names:
is_nan = position_df[column].isna().to_numpy().astype(int)
st = np.where(np.diff(is_nan) == 1)[0] + 1
en = np.where(np.diff(is_nan) == -1)[0]
if is_nan[0]:
st = np.insert(st, 0, 0)
if is_nan[-1]:
en = np.append(en, len(is_nan) - 1)
st = position_df.index[st].to_numpy()
en = position_df.index[en].to_numpy()
nan_intervals[column] = list(zip(st, en))

# upsample and interpolate
position_df = (
position_df.reindex(index=new_index)
.interpolate(method=upsampling_interpolation_method)
.reindex(index=new_time)
)

# Fill NaN intervals
for column, intervals in nan_intervals.items():
for st, en in intervals:
position_df[column][st:en] = np.nan

return position_df
17 changes: 8 additions & 9 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def make(self, key):
position_info.index <= interval_end,
)
] = True
is_training[
position_info[position_variable_names].isna().values.max(axis=1)
] = False

if "is_training" not in decoding_kwargs:
decoding_kwargs["is_training"] = is_training

Expand Down Expand Up @@ -387,15 +391,10 @@ def fetch_linear_position_info(key):
edge_spacing=environment.edge_spacing,
)
min_time, max_time = SortedSpikesDecodingV1._get_interval_range(key)

return (
pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)
return pd.concat(
[linear_position_df.set_index(position_df.index), position_df],
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(
Expand Down
Loading