Skip to content

Commit

Permalink
Fix labels export for json (#1410)
Browse files Browse the repository at this point in the history
* wip: fix labels export for json

* Add test for json.zip labels pkg

* Add test for .slp labels pkg

* Make linter happy
  • Loading branch information
roomrys authored Jul 27, 2023
1 parent 90c012d commit b2ad203
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 41 deletions.
34 changes: 22 additions & 12 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class CommandContext:
def from_labels(cls, labels: Labels) -> "CommandContext":
"""Creates a command context for use independently of GUI app."""
state = GuiState()
state["labels"] = labels
app = FakeApp(labels)
return cls(state=state, app=app)

Expand Down Expand Up @@ -1364,20 +1365,27 @@ def ask(context: CommandContext, params: dict) -> bool:


def export_dataset_gui(
labels: Labels, filename: str, all_labeled: bool = False, suggested: bool = False
labels: Labels,
filename: str,
all_labeled: bool = False,
suggested: bool = False,
verbose: bool = True,
) -> str:
"""Export dataset with image data and display progress GUI dialog.
Args:
labels: `sleap.Labels` dataset to export.
filename: Output filename. Should end in `.pkg.slp`.
all_labeled: If `True`, export all labeled frames, including frames with no user
instances.
suggested: If `True`, include image data for suggested frames.
instances. Defaults to `False`.
suggested: If `True`, include image data for suggested frames. Defaults to
`False`.
verbose: If `True`, display progress dialog. Defaults to `True`.
"""
win = QtWidgets.QProgressDialog(
"Exporting dataset with frame images...", "Cancel", 0, 1
)
if verbose:
win = QtWidgets.QProgressDialog(
"Exporting dataset with frame images...", "Cancel", 0, 1
)

def update_progress(n, n_total):
if win.wasCanceled():
Expand All @@ -1398,15 +1406,16 @@ def update_progress(n, n_total):
save_frame_data=True,
all_labeled=all_labeled,
suggested=suggested,
progress_callback=update_progress,
progress_callback=update_progress if verbose else None,
)

if win.wasCanceled():
# Delete output if saving was canceled.
os.remove(filename)
return "canceled"
if verbose:
if win.wasCanceled():
# Delete output if saving was canceled.
os.remove(filename)
return "canceled"

win.hide()
win.hide()

return filename

Expand All @@ -1422,6 +1431,7 @@ def do_action(cls, context: CommandContext, params: dict):
filename=params["filename"],
all_labeled=cls.all_labeled,
suggested=cls.suggested,
verbose=params.get("verbose", True),
)

@staticmethod
Expand Down
78 changes: 59 additions & 19 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import itertools
import os
from collections.abc import MutableSequence
from pathlib import Path
from typing import (
Callable,
List,
Expand Down Expand Up @@ -2219,7 +2220,12 @@ def from_deepposekit(
)

def save_frame_data_imgstore(
self, output_dir: str = "./", format: str = "png", all_labels: bool = False
self,
output_dir: str = "./",
format: str = "png",
all_labeled: bool = False,
suggested: bool = False,
progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[ImgStoreVideo]:
"""Write images for labeled frames from all videos to imgstore datasets.
Expand All @@ -2232,28 +2238,55 @@ def save_frame_data_imgstore(
Use "png" for lossless, "jpg" for lossy.
Other imgstore formats will probably work as well but
have not been tested.
all_labels: Include any labeled frames, not just the frames
all_labeled: Include any labeled frames, not just the frames
we'll use for training (i.e., those with `Instance` objects ).
suggested: Include suggested frames even if they do not have instances.
Useful for inference after training. Defaults to `False`.
progress_callback: If provided, this function will be called to report the
progress of the frame data saving. This function should be a callable
of the form: `fn(n, n_total)` where `n` is the number of frames saved so
far and `n_total` is the total number of frames that will be saved. This
is called after each video is processed. If the function has a return
value and it returns `False`, saving will be canceled and the output
deleted.
Returns:
A list of :class:`ImgStoreVideo` objects with the stored
frames.
"""

# Lets gather all the suggestions by video
suggestion_frames_by_video = {video: [] for video in self.videos}
if suggested:
for suggestion in self.suggestions:
suggestion_frames_by_video[suggestion.video].append(
suggestion.frame_idx
)

# For each label
imgstore_vids = []
for v_idx, v in enumerate(self.videos):
frame_nums = [
lf.frame_idx
for lf in self.labeled_frames
if v == lf.video and (all_labels or lf.has_user_instances)
]
total_vids = len(self.videos)
for v_idx, video in enumerate(self.videos):
lfs_v = self.find(video)
frame_nums = {
lf.frame_idx for lf in lfs_v if all_labeled or lf.has_user_instances
}

if suggested:
frame_nums.update(suggestion_frames_by_video[video])

# Join with "/" instead of os.path.join() since we want
# path to work on Windows and Posix systems
frames_filename = output_dir + f"/frame_data_vid{v_idx}"
vid = v.to_imgstore(
path=frames_filename, frame_numbers=frame_nums, format=format
frames_fn = Path(output_dir, f"frame_data_vid{v_idx}")
vid = video.to_imgstore(
path=frames_fn.as_posix(), frame_numbers=frame_nums, format=format
)
if progress_callback is not None:
# Notify update callback.
ret = progress_callback(v_idx, total_vids)
if ret == False:
vid.close()
return []

# Close the video for now
vid.close()
Expand Down Expand Up @@ -2296,23 +2329,30 @@ def save_frame_data_hdf5(
Returns:
A list of :class:`HDF5Video` objects with the stored frames.
"""

# Lets gather all the suggestions by video
suggestion_frames_by_video = {video: [] for video in self.videos}
if suggested:
for suggestion in self.suggestions:
suggestion_frames_by_video[suggestion.video].append(
suggestion.frame_idx
)

# Build list of frames to save.
vids = []
frame_idxs = []
for video in self.videos:
lfs_v = self.find(video)
frame_nums = [
frame_nums = {
lf.frame_idx
for lf in lfs_v
if all_labeled or (user_labeled and lf.has_user_instances)
]
}

if suggested:
frame_nums += [
suggestion.frame_idx
for suggestion in self.suggestions
if suggestion.video == video
]
frame_nums = sorted(list(set(frame_nums)))
frame_nums.update(suggestion_frames_by_video[video])

frame_nums = sorted(list(frame_nums))
vids.append(video)
frame_idxs.append(frame_nums)

Expand Down
5 changes: 4 additions & 1 deletion sleap/io/format/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import attr
from pathlib import Path
from typing import List, Optional, Tuple, Union

from sleap.io.format.adaptor import Adaptor, SleapObjectType
Expand Down Expand Up @@ -77,7 +78,9 @@ def write(self, filename: str, source_object: object, *args, **kwargs):
if adaptor.can_write_filename(filename):
return adaptor.write(filename, source_object, *args, **kwargs)

raise TypeError("No file format adaptor could write this file.")
raise TypeError(
f"No file format adaptor could write this file: {Path(filename).name}."
)

def write_safely(self, *args, **kwargs) -> Optional[BaseException]:
"""Wrapper for writing file without throwing exception."""
Expand Down
17 changes: 14 additions & 3 deletions sleap/io/format/labels_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,11 @@ def write(
compress: Optional[bool] = None,
save_frame_data: bool = False,
frame_data_format: str = "png",
all_labeled: bool = False,
suggested: bool = False,
progress_callback: Optional[Callable[[int, int], None]] = None,
):
"""
Save a Labels instance to a JSON format.
"""Save a Labels instance to a JSON format.
Args:
filename: The filename to save the data to.
Expand Down Expand Up @@ -276,6 +278,11 @@ def write(
Note: 'h264/mkv' and 'avc1/mp4' require separate installation
of these codecs on your system. They are excluded from SLEAP
because of their GPL license.
all_labeled: Whether to save all frames or just the labeled frames to use in
training.
suggested: Whether to save the suggested labels along with the training
labels.
progress_callback: A function that will be called with the current progress.
Returns:
None
Expand All @@ -299,7 +306,11 @@ def write(
# of the videos. We will only include the labeled frames though. We will
# then replace each video with this new video
new_videos = labels.save_frame_data_imgstore(
output_dir=tmp_dir, format=frame_data_format
output_dir=tmp_dir,
format=frame_data_format,
all_labeled=all_labeled,
suggested=suggested,
progress_callback=progress_callback,
)

# Make video paths relative
Expand Down
90 changes: 84 additions & 6 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from pathlib import PurePath, Path
import pytest
import shutil
import sys
from typing import List
import time

import pytest
from qtpy.QtWidgets import QComboBox
from pathlib import PurePath, Path
from typing import List

from sleap import Skeleton, Track
from sleap import Skeleton, Track, PredictedInstance
from sleap.gui.commands import (
CommandContext,
ImportDeepLabCutFolder,
ExportAnalysisFile,
ExportDatasetWithImages,
ImportDeepLabCutFolder,
RemoveVideo,
ReplaceVideo,
OpenSkeleton,
Expand Down Expand Up @@ -826,3 +827,80 @@ def load_and_assert_changes(new_video_path: Path):
load_and_assert_changes(search_path)
finally: # Move video back to original location - for ease of re-testing
shutil.move(new_video_path, expected_video_path)


@pytest.mark.parametrize("export_extension", [".json.zip", ".slp"])
def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir):
def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False):
"""Assert that the loaded labels are similar to the original."""

# Load the labels, but first copy file to a location (which pytest can and will
# keep in memory, but won't affect our re-use of the original file name)
filename_for_pytest_to_hoard: Path = path_to_pkg.with_name(
f"pytest_labels_{time.perf_counter_ns()}{export_extension}"
)
shutil.copyfile(path_to_pkg.as_posix(), filename_for_pytest_to_hoard.as_posix())
labels_reload: Labels = Labels.load_file(
filename_for_pytest_to_hoard.as_posix()
)

assert len(labels_reload.labeled_frames) == len(centered_pair_labels)
assert len(labels_reload.videos) == len(centered_pair_labels.videos)
assert len(labels_reload.suggestions) == len(centered_pair_labels.suggestions)
assert len(labels_reload.tracks) == len(centered_pair_labels.tracks)
assert len(labels_reload.skeletons) == len(centered_pair_labels.skeletons)
assert (
len(
set(labels_reload.skeleton.node_names)
- set(centered_pair_labels.skeleton.node_names)
)
== 0
)
num_images = len(labels_reload)
if sugg:
num_images += len(lfs_sugg)
if not pred:
num_images -= len(lfs_pred)
assert labels_reload.video.num_frames == num_images

# Set-up CommandContext
path_to_pkg = Path(tmpdir, "test_exportLabelsPackage.ext")
path_to_pkg = path_to_pkg.with_suffix(export_extension)

def no_gui_ask(cls, context, params):
"""No GUI version of `ExportDatasetWithImages.ask`."""
params["filename"] = path_to_pkg.as_posix()
params["verbose"] = False
return True

ExportDatasetWithImages.ask = no_gui_ask

# Remove frames we want to use for suggestions and predictions
lfs_sugg = [centered_pair_labels[idx] for idx in [-1, -2]]
lfs_pred = [centered_pair_labels[idx] for idx in [-3, -4]]
centered_pair_labels.remove_frames(lfs_sugg)

# Add suggestions
for lf in lfs_sugg:
centered_pair_labels.add_suggestion(centered_pair_labels.video, lf.frame_idx)

# Add predictions and remove user instances from those frames
for lf in lfs_pred:
predicted_inst = PredictedInstance.from_instance(lf.instances[0], score=0.5)
centered_pair_labels.add_instance(lf, predicted_inst)
for inst in lf.user_instances:
centered_pair_labels.remove_instance(lf, inst)
context = CommandContext.from_labels(centered_pair_labels)

# Case 1: Export user-labeled frames with image data into a single SLP file.
context.exportUserLabelsPackage()
assert path_to_pkg.exists()
assert_loaded_package_similar(path_to_pkg)

# Case 2: Export user-labeled frames and suggested frames with image data.
context.exportTrainingPackage()
assert_loaded_package_similar(path_to_pkg, sugg=True)

# Case 3: Export all frames and suggested frames with image data.
context.exportFullPackage()
assert_loaded_package_similar(path_to_pkg, sugg=True, pred=True)

0 comments on commit b2ad203

Please sign in to comment.