Skip to content

Commit

Permalink
cab
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 18, 2024
1 parent 1b8e08b commit aab5a79
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining: # type:
training_frame_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
training_frame_name = name_generator("training_frame")
training_frame_annotator = f"{training_frame_name}{i}"
training_frame_annotator = f"{training_frame_name}_{i}"
skeleton_instances_list = []
for instance in labeled_frame.instances:
if isinstance(instance, PredictedInstance):
Expand Down Expand Up @@ -202,9 +202,9 @@ def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignor
An NWB SourceVideos object.
"""
source_videos = []
for i, video in enumerate(videos):
for video in videos:
image_series = ImageSeries(
name=f"video_{i}",
name=name_generator("video"),
description="Video file",
unit="NA",
format="external",
Expand All @@ -224,7 +224,8 @@ def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
filename: The path to the SLEAP package.
labels: The SLEAP Labels object.
"""
assert filename.endswith(".pkg.slp")
if not filename.endswith(".pkg.slp"):
raise ValueError("The filename must end with '.pkg.slp'.")

path = filename.split(".slp")[0]
save_path = Path(path + ".nwb_images")
Expand Down Expand Up @@ -502,8 +503,19 @@ def append_nwb_training(
Returns:
An in-memory NWB file with the PoseTraining data appended.
"""
pose_training = labels_to_pose_training(labels)
pose_estimation_metadata = pose_estimation_metadata or dict()
provenance = labels.provenance
default_metadata = dict(scorer=str(provenance))
sleap_version = provenance.get("sleap_version", None)
default_metadata["source_software_version"] = sleap_version
pose_training = labels_to_pose_training(labels)

for lf in labels.labeled_frames:
if lf.has_predicted_instances:
labels_data_df = convert_predictions_to_dataframe(labels)
break
else:
labels_data_df = pd.DataFrame()
raise NotImplementedError


Expand Down

0 comments on commit aab5a79

Please sign in to comment.