Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes for the DLC module #967

Merged
merged 16 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Upcoming

* Fixed bug with empty `NWBInterface` out of `DeeplabcutInterface` conversion; sped up `DeeplabcutInterface` conversion when timestamps are specified. [PR #967](https://github.com/catalystneuro/neuroconv/pull/967)

## v0.5.0 (July 17, 2024)

Expand Down
1 change: 1 addition & 0 deletions requirements-testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ ndx-miniscope
spikeinterface[qualitymetrics]>=0.100.0
zarr<2.18.0 # Error with Blosc (read-only during decode) in numcodecs on May 7; check later if resolved
pytest-xdist
ndx-pose
vigji marked this conversation as resolved.
Show resolved Hide resolved
vigji marked this conversation as resolved.
Show resolved Hide resolved
73 changes: 46 additions & 27 deletions src/neuroconv/datainterfaces/behavior/deeplabcut/_dlc_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib
import os
import pickle
import warnings
from pathlib import Path
Expand All @@ -8,6 +7,7 @@
import numpy as np
import pandas as pd
import yaml
from ndx_pose import PoseEstimation, PoseEstimationSeries
from pynwb import NWBFile
from ruamel.yaml import YAML

Expand All @@ -20,11 +20,11 @@ def _read_config(config_file_path):
"""
ruamelFile = YAML()
path = Path(config_file_path)
if os.path.exists(path):
if path.exists():
try:
with open(path, "r") as f:
cfg = ruamelFile.load(f)
curr_dir = os.path.dirname(config_file_path)
curr_dir = config_file_path.parent
if cfg["project_path"] != curr_dir:
cfg["project_path"] = curr_dir
except Exception as err:
Expand Down Expand Up @@ -58,14 +58,11 @@ def _get_movie_timestamps(movie_file, VARIABILITYBOUND=1000, infer_timestamps=Tr
n_frames = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
fps = reader.get(cv2.CAP_PROP_FPS)

warnings.warn("Inferring timestamps from video. This might take a while (to speed up, set timestamps)")
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
for _ in range(n_frames):
_ = reader.read()
timestamps.append(reader.get(cv2.CAP_PROP_POS_MSEC))

for _ in range(len(reader)):
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
_ = reader.read()
timestamps.append(reader.get(cv2.CAP_PROP_POS_MSEC))

timestamps = np.array(timestamps) / 1000 # Convert to seconds

if np.nanvar(np.diff(timestamps)) < 1.0 / fps * 1.0 / VARIABILITYBOUND:
Expand Down Expand Up @@ -110,39 +107,51 @@ def _infer_nan_timestamps(timestamps):
return timestamps


def _ensure_individuals_in_header(df, dummy_name):
def _ensure_individuals_in_header(df, individual_name):
if "individuals" not in df.columns.names:
# Single animal project -> add individual row to
# the header of single animal projects.
temp = pd.concat({dummy_name: df}, names=["individuals"], axis=1)
temp = pd.concat({individual_name: df}, names=["individuals"], axis=1)
df = temp.reorder_levels(["scorer", "individuals", "bodyparts", "coords"], axis=1)
return df


def _get_pes_args(config_file, h5file, individual_name, infer_timestamps=True):
if "DLC" not in h5file or not h5file.endswith(".h5"):
def _get_pes_args(
*,
config_file: Path,
h5file: Path,
individual_name: str,
timestamps_available: bool = False,
infer_timestamps: bool = True,
):
config_file = Path(config_file)
h5file = Path(h5file)

if "DLC" not in h5file.name or not h5file.suffix == ".h5":
raise IOError("The file passed in is not a DeepLabCut h5 data file.")

cfg = _read_config(config_file)

vidname, scorer = os.path.split(h5file)[-1].split("DLC")
scorer = "DLC" + os.path.splitext(scorer)[0]
vidname, scorer = h5file.stem.split("DLC")
scorer = "DLC" + scorer
video = None

df = _ensure_individuals_in_header(pd.read_hdf(h5file), individual_name)

# Fetch the corresponding metadata pickle file
paf_graph = []
filename, _ = os.path.splitext(h5file)
filename = str(h5file.parent / h5file.stem)
for i, c in enumerate(filename[::-1]):
if c.isnumeric():
break
if i > 0:
filename = filename[:-i]
metadata_file = filename + "_meta.pickle"
if os.path.isfile(metadata_file):
metadata_file = Path(filename + "_meta.pickle")

if metadata_file.exists():
with open(metadata_file, "rb") as file:
metadata = pickle.load(file)

test_cfg = metadata["data"]["DLC-model-config file"]
paf_graph = test_cfg.get("partaffinityfield_graph", [])
if paf_graph:
Expand All @@ -157,13 +166,19 @@ def _get_pes_args(config_file, h5file, individual_name, infer_timestamps=True):
video = video_path, params["crop"]
break

# find timestamps only if required:
if timestamps_available:
timestamps = None
else:
if video is None:
timestamps = df.index.tolist() # setting timestamps to dummy TODO: extract timestamps in DLC?
else:
timestamps = _get_movie_timestamps(video[0], infer_timestamps=infer_timestamps)

if video is None:
warnings.warn(f"The video file corresponding to {h5file} could not be found...")
video = "fake_path", "0, 0, 0, 0"

timestamps = df.index.tolist() # setting timestamps to dummy TODO: extract timestamps in DLC?
else:
timestamps = _get_movie_timestamps(video[0], infer_timestamps=infer_timestamps)
return scorer, df, video, paf_graph, timestamps, cfg


Expand All @@ -178,13 +193,11 @@ def _write_pes_to_nwbfile(
exclude_nans,
pose_estimation_container_kwargs: Optional[dict] = None,
):
from ndx_pose import PoseEstimation, PoseEstimationSeries
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved

pose_estimation_container_kwargs = pose_estimation_container_kwargs or dict()

pose_estimation_series = []
for kpt, xyp in df_animal.groupby(level="bodyparts", axis=1, sort=False):
data = xyp.to_numpy()
for keypoint in df_animal.columns.get_level_values("bodyparts").unique():
data = df_animal.xs(keypoint, level="bodyparts", axis=1).to_numpy()

if exclude_nans:
# exclude_nans is inverse infer_timestamps. if not infer, there may be nans
Expand All @@ -194,8 +207,8 @@ def _write_pes_to_nwbfile(
timestamps_cleaned = timestamps

pes = PoseEstimationSeries(
name=f"{animal}_{kpt}",
description=f"Keypoint {kpt} from individual {animal}.",
name=f"{animal}_{keypoint}" if animal else keypoint,
description=f"Keypoint {keypoint} from individual {animal}.",
data=data[:, :2],
unit="pixels",
reference_frame="(0,0) corresponds to the bottom left corner of the video.",
Expand Down Expand Up @@ -269,11 +282,17 @@ def add_subject_to_nwbfile(
nwbfile : pynwb.NWBFile
nwbfile with pes written in the behavior module
"""
scorer, df, video, paf_graph, dlc_timestamps, _ = _get_pes_args(config_file, h5file, individual_name)
timestamps_available = timestamps is not None
scorer, df, video, paf_graph, dlc_timestamps, _ = _get_pes_args(
config_file=config_file,
h5file=h5file,
individual_name=individual_name,
timestamps_available=timestamps_available,
)
if timestamps is None:
timestamps = dlc_timestamps

df_animal = df.groupby(level="individuals", axis=1).get_group(individual_name)
df_animal = df.xs(individual_name, level="individuals", axis=1)

return _write_pes_to_nwbfile(
nwbfile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from pynwb.file import NWBFile

from ._dlc_utils import add_subject_to_nwbfile
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, this line here indirectly makes it exposed since a parent __init__.py somewhere along the chain will do a directed import from this public submodule

Meaning anyone working in a minimal installation, or an installation using extras that do not include DLC, will encounter import errors

Copy link
Contributor Author

@vigji vigji Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, apparently this is not the crucial change as long as imports in ._dlc_utils happen at the top of the file

from ....basetemporalalignmentinterface import BaseTemporalAlignmentInterface
from ....utils import FilePathType

Expand Down Expand Up @@ -103,7 +104,6 @@ def add_to_nwbfile(
metadata: dict
metadata info for constructing the nwb file (optional).
"""
from ._dlc_utils import add_subject_to_nwbfile

add_subject_to_nwbfile(
nwbfile=nwbfile,
Expand Down
Loading