
Add upsampling option to PositionGroup #1008

Merged: 9 commits, Jul 1, 2024
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -8,8 +8,10 @@

```python
from spyglass.common.common_behav import PositionIntervalMap
from spyglass.decoding.v1.core import PositionGroup

PositionIntervalMap.alter()
PositionGroup.alter()
```

### Infrastructure
@@ -42,7 +44,9 @@ PositionIntervalMap.alter()
- Remove redundant calls to tables in `populate_all_common` #870
- Improve logging clarity in `populate_all_common` #870
- `PositionIntervalMap` now inserts null entries for missing intervals #870
- Decoding: Default values for classes on `ImportError` #966
- Decoding:
- Default values for classes on `ImportError` #966
- Add option to upsample data rate in `PositionGroup` #1008
- Position
- Allow dlc without pre-existing tracking data #973, #975
- Raise `KeyError` for missing input parameters across helper funcs #966
5 changes: 3 additions & 2 deletions notebooks/41_Decoding_Clusterless.ipynb
@@ -700,7 +700,7 @@
"\n",
"We use the `PositionOutput` table to figure out the `merge_id` associated with `nwb_file_name` and retrieve the position data for the NWB file of interest. In this case, we only have one position to insert, but we could insert multiple positions if we wanted to decode from multiple sessions.\n",
"\n",
"Note that the position data sampling frequency is what determines the time step of the decoding. In this case, the position data sampling frequency is 30 Hz, so the time step of the decoding will be 1/30 seconds. In practice, you will want to use a smaller time step such as 500 Hz. This will allow you to decode at a finer time scale. To do this, you will want to interpolate the position data to a higher sampling frequency as shown in the [position trodes notebook](./20_Position_Trodes.ipynb).\n",
"Note that the `upsample_rate` parameter defines the rate, in Hz, to which the position data will be upsampled for decoding. This is useful if we want to decode at a finer time scale than the position data sampling frequency. In practice, a value of 500 Hz is used in many analyses. Omitting this parameter or providing a null value defaults to the position sampling rate.\n",
"\n",
"You will also want to specify the name of the position variables if they are different from the default names. The default names are `position_x` and `position_y`."
]
@@ -981,6 +981,7 @@
" nwb_file_name=nwb_copy_file_name,\n",
" group_name=\"test_group\",\n",
" keys=[{\"pos_merge_id\": merge_id} for merge_id in position_merge_ids],\n",
" upsample_rate=500,\n",
")\n",
"\n",
"PositionGroup & {\n",
@@ -2956,7 +2957,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
3 changes: 2 additions & 1 deletion notebooks/py_scripts/41_Decoding_Clusterless.py
@@ -125,7 +125,7 @@
#
# We use the `PositionOutput` table to figure out the `merge_id` associated with `nwb_file_name` and retrieve the position data for the NWB file of interest. In this case, we only have one position to insert, but we could insert multiple positions if we wanted to decode from multiple sessions.
#
# Note that the position data sampling frequency is what determines the time step of the decoding. In this case, the position data sampling frequency is 30 Hz, so the time step of the decoding will be 1/30 seconds. In practice, you will want to use a smaller time step such as 500 Hz. This will allow you to decode at a finer time scale. To do this, you will want to interpolate the position data to a higher sampling frequency as shown in the [position trodes notebook](./20_Position_Trodes.ipynb).
# Note that the `upsample_rate` parameter defines the rate, in Hz, to which the position data will be upsampled for decoding. This is useful if we want to decode at a finer time scale than the position data sampling frequency. In practice, a value of 500 Hz is used in many analyses. Omitting this parameter or providing a null value defaults to the position sampling rate.
#
# You will also want to specify the name of the position variables if they are different from the default names. The default names are `position_x` and `position_y`.

@@ -181,6 +181,7 @@
nwb_file_name=nwb_copy_file_name,
group_name="test_group",
keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
upsample_rate=500,
)

PositionGroup & {
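The decoding time step is the reciprocal of the (possibly upsampled) position sampling rate, so upsampling from a 30 Hz camera rate to 500 Hz shrinks each decoding bin from roughly 33 ms to 2 ms. A minimal sketch of that relationship (the specific rates are illustrative, not fixed by the API):

```python
def decoding_time_step(sampling_rate_hz: float) -> float:
    """Return the decoding time step in seconds for a sampling rate in Hz."""
    return 1.0 / sampling_rate_hz


raw_step = decoding_time_step(30.0)         # native camera rate
upsampled_step = decoding_time_step(500.0)  # commonly used upsampled rate

print(f"30 Hz  -> {raw_step * 1e3:.1f} ms per decoding bin")       # 33.3 ms
print(f"500 Hz -> {upsampled_step * 1e3:.1f} ms per decoding bin")  # 2.0 ms
```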
19 changes: 4 additions & 15 deletions src/spyglass/decoding/v1/clusterless.py
@@ -385,22 +385,11 @@ def fetch_position_info(key):
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
}
position_variable_names = (PositionGroup & position_group_key).fetch1(
"position_variables"
)

position_info = []
for pos_merge_id in (PositionGroup.Position & position_group_key).fetch(
"pos_merge_id"
):
position_info.append(
(PositionOutput & {"merge_id": pos_merge_id}).fetch1_dataframe()
)

min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)
position_info = (
pd.concat(position_info, axis=0).loc[min_time:max_time].dropna()
)
position_info, position_variable_names = (
PositionGroup & position_group_key
).fetch_position_info(min_time=min_time, max_time=max_time)

return position_info, position_variable_names

@@ -441,7 +430,7 @@ def fetch_linear_position_info(key):
axis=1,
)
.loc[min_time:max_time]
.dropna()
.dropna(subset=position_variable_names)
)

@staticmethod
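Note the change from `.dropna()` to `.dropna(subset=position_variable_names)` in the hunk above: rows are now discarded only when a position variable itself is missing, so NaNs in unrelated columns no longer remove samples. A small pandas sketch of the difference (the column names are illustrative):

```python
import numpy as np
import pandas as pd

# Position columns are complete; an auxiliary column has a gap.
df = pd.DataFrame(
    {
        "position_x": [1.0, 2.0, 3.0],
        "position_y": [4.0, 5.0, 6.0],
        "orientation": [0.1, np.nan, 0.3],  # NaN unrelated to position
    }
)

strict = df.dropna()  # drops the middle row because of the orientation NaN
scoped = df.dropna(subset=["position_x", "position_y"])  # keeps all rows

print(len(strict), len(scoped))  # 2 3
```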
110 changes: 110 additions & 0 deletions src/spyglass/decoding/v1/core.py
@@ -1,4 +1,6 @@
import datajoint as dj
import numpy as np
import pandas as pd
from non_local_detector import (
ContFragClusterlessClassifier,
ContFragSortedSpikesClassifier,
@@ -92,6 +94,7 @@ class PositionGroup(SpyglassMixin, dj.Manual):
position_group_name: varchar(80)
----
position_variables = NULL: longblob # list of position variables to decode
upsample_rate = NULL: float # upsampling rate for position data (Hz)
"""

class Position(SpyglassMixinPart):
@@ -106,6 +109,7 @@ def create_group(
group_name: str,
keys: list[dict],
position_variables: list[str] = ["position_x", "position_y"],
upsample_rate: float = np.nan,
):
group_key = {
"nwb_file_name": nwb_file_name,
@@ -115,6 +119,7 @@
{
**group_key,
"position_variables": position_variables,
"upsample_rate": upsample_rate,
},
skip_duplicates=True,
)
@@ -126,3 +131,108 @@ def create_group(
},
skip_duplicates=True,
)

def fetch_position_info(
self, key: dict = None, min_time: float = None, max_time: float = None
) -> tuple[pd.DataFrame, list[str]]:
"""Fetch position information for decoding.

Parameters
----------
key : dict, optional
restriction to a single entry in PositionGroup, by default None
min_time : float, optional
restrict position information to times greater than min_time, by default None
max_time : float, optional
restrict position information to times less than max_time, by default None

Returns
-------
tuple[pd.DataFrame, list[str]]
position information and names of position variables
"""
if key is None:
key = {}
key = (self & key).fetch1("KEY")
position_variable_names = (self & key).fetch1("position_variables")

position_info = []
upsample_rate = (self & key).fetch1("upsample_rate")
for pos_merge_id in (self.Position & key).fetch("pos_merge_id"):
if not np.isnan(upsample_rate):
position_info.append(
self._upsample(
(
PositionOutput & {"merge_id": pos_merge_id}
).fetch1_dataframe(),
upsampling_sampling_rate=upsample_rate,
)
)
else:
position_info.append(
(
PositionOutput & {"merge_id": pos_merge_id}
).fetch1_dataframe()
)

if min_time is None:
min_time = min([df.index.min() for df in position_info])
if max_time is None:
max_time = max([df.index.max() for df in position_info])
position_info = (
pd.concat(position_info, axis=0)
.loc[min_time:max_time]
.dropna(subset=position_variable_names)
)

return position_info, position_variable_names

@staticmethod
def _upsample(
position_df: pd.DataFrame,
upsampling_sampling_rate: float,
upsampling_interpolation_method: str = "linear",
) -> pd.DataFrame:
"""Upsample position data to a fixed sampling rate.

Parameters
----------
position_df : pd.DataFrame
dataframe containing position data
upsampling_sampling_rate : float
sampling rate to upsample to
upsampling_interpolation_method : str, optional
pandas method for interpolation, by default "linear"

Returns
-------
pd.DataFrame
upsampled position data
"""

upsampling_start_time = position_df.index[0]
upsampling_end_time = position_df.index[-1]

n_samples = (
int(
np.ceil(
(upsampling_end_time - upsampling_start_time)
* upsampling_sampling_rate
)
)
+ 1
)
new_time = np.linspace(
upsampling_start_time, upsampling_end_time, n_samples
)
new_index = pd.Index(
np.unique(np.concatenate((position_df.index, new_time))),
name="time",
)
position_df = (
position_df.reindex(index=new_index)
.interpolate(method=upsampling_interpolation_method)
.reindex(index=new_time)
)

return position_df
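The `_upsample` helper above reindexes onto the union of the original timestamps and a uniform target grid, interpolates, then keeps only the grid, so the original samples pass through exactly. A self-contained sketch of the same pattern on synthetic data (the frame, rate, and column name are made up for illustration):

```python
import numpy as np
import pandas as pd


def upsample(df: pd.DataFrame, rate_hz: float, method: str = "linear") -> pd.DataFrame:
    """Upsample a time-indexed frame to a fixed rate (mirrors _upsample)."""
    start, end = df.index[0], df.index[-1]
    # One sample per 1/rate_hz seconds, inclusive of both endpoints.
    n_samples = int(np.ceil((end - start) * rate_hz)) + 1
    new_time = np.linspace(start, end, n_samples)
    # The union index lets interpolation see the original samples exactly.
    union = pd.Index(np.unique(np.concatenate((df.index, new_time))), name="time")
    return df.reindex(index=union).interpolate(method=method).reindex(index=new_time)


# Three samples at 2 Hz over one second, upsampled to 10 Hz.
pos = pd.DataFrame({"position_x": [0.0, 1.0, 2.0]}, index=[0.0, 0.5, 1.0])
up = upsample(pos, rate_hz=10.0)
print(len(up))  # 11
```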
20 changes: 6 additions & 14 deletions src/spyglass/decoding/v1/sorted_spikes.py
@@ -349,20 +349,12 @@ def fetch_position_info(key):
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
}
position_variable_names = (PositionGroup & position_group_key).fetch1(
"position_variables"
)

position_info = []
for pos_merge_id in (PositionGroup.Position & position_group_key).fetch(
"pos_merge_id"
):
position_info.append(
(PositionOutput & {"merge_id": pos_merge_id}).fetch1_dataframe()
)
min_time, max_time = SortedSpikesDecodingV1._get_interval_range(key)
position_info = (
pd.concat(position_info, axis=0).loc[min_time:max_time].dropna()
position_info, position_variable_names = (
PositionGroup & position_group_key
).fetch_position_info(
min_time=min_time,
max_time=max_time,
)

return position_info, position_variable_names
@@ -402,7 +394,7 @@ def fetch_linear_position_info(key):
axis=1,
)
.loc[min_time:max_time]
.dropna()
.dropna(subset=position_variable_names)
)

@staticmethod