Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 17, 2024
1 parent 85d47ab commit 1b8e08b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def save_nwb(
if append and Path(filename).exists():
nwb.append_nwb_training(labels, filename)
else:
nwb.write_nwb_training(labels, filename)
nwb.write_nwb(labels, filename, None, None, True)

else:
if append and Path(filename).exists():
Expand Down
19 changes: 11 additions & 8 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: igno
Args:
skeleton: An NWB skeleton.
Returns:
A SLEAP skeleton.
"""
Expand Down Expand Up @@ -211,7 +211,7 @@ def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignor
external_file=[video.filename],
dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
starting_frame=[0],
rate=30.0, # TODO - change to `video.backend.fps` when available
rate=30.0, # TODO - change to `video.backend.fps` when available
)
source_videos.append(image_series)
return SourceVideos(image_series=source_videos)
Expand All @@ -233,7 +233,6 @@ def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
img_path = save_path / f"frame_{i}.png"
imwrite(img_path, labeled_frame.image)
img_paths.append(img_path)



def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
Expand Down Expand Up @@ -491,13 +490,15 @@ def append_nwb_data(
return nwbfile


def append_nwb_training(labels: Labels, nwbfile_path: str) -> NWBFile: # type: ignore[return]
def append_nwb_training(
labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
"""Append a PoseTraining object to an existing NWB data file.
Args:
pose_training: A PoseTraining object.
nwbfile_path: The path to the NWB file.
Returns:
An in-memory NWB file with the PoseTraining data appended.
"""
Expand All @@ -507,8 +508,10 @@ def append_nwb_training(labels: Labels, nwbfile_path: str) -> NWBFile: # type:


def append_nwb(
labels: Labels, filename: str, pose_estimation_metadata: Optional[dict] = None,
as_training: Optional[bool] = None
labels: Labels,
filename: str,
pose_estimation_metadata: Optional[dict] = None,
as_training: Optional[bool] = None,
):
"""Append a SLEAP `Labels` object to an existing NWB data file.
Expand Down

0 comments on commit 1b8e08b

Please sign in to comment.