diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py
index 2d397964e..decad1946 100644
--- a/sleap/gui/commands.py
+++ b/sleap/gui/commands.py
@@ -201,6 +201,7 @@ class CommandContext:
     def from_labels(cls, labels: Labels) -> "CommandContext":
         """Creates a command context for use independently of GUI app."""
         state = GuiState()
+        state["labels"] = labels
         app = FakeApp(labels)
         return cls(state=state, app=app)
 
@@ -1364,7 +1365,11 @@ def ask(context: CommandContext, params: dict) -> bool:
 
 
 def export_dataset_gui(
-    labels: Labels, filename: str, all_labeled: bool = False, suggested: bool = False
+    labels: Labels,
+    filename: str,
+    all_labeled: bool = False,
+    suggested: bool = False,
+    verbose: bool = True,
 ) -> str:
     """Export dataset with image data and display progress GUI dialog.
 
@@ -1372,12 +1377,15 @@
         labels: `sleap.Labels` dataset to export.
         filename: Output filename. Should end in `.pkg.slp`.
         all_labeled: If `True`, export all labeled frames, including frames with no user
-            instances.
-        suggested: If `True`, include image data for suggested frames.
+            instances. Defaults to `False`.
+        suggested: If `True`, include image data for suggested frames. Defaults to
+            `False`.
+        verbose: If `True`, display progress dialog. Defaults to `True`.
     """
-    win = QtWidgets.QProgressDialog(
-        "Exporting dataset with frame images...", "Cancel", 0, 1
-    )
+    if verbose:
+        win = QtWidgets.QProgressDialog(
+            "Exporting dataset with frame images...", "Cancel", 0, 1
+        )
 
     def update_progress(n, n_total):
         if win.wasCanceled():
@@ -1398,15 +1406,16 @@ def update_progress(n, n_total):
         save_frame_data=True,
         all_labeled=all_labeled,
         suggested=suggested,
-        progress_callback=update_progress,
+        progress_callback=update_progress if verbose else None,
     )
 
-    if win.wasCanceled():
-        # Delete output if saving was canceled.
-        os.remove(filename)
-        return "canceled"
+    if verbose:
+        if win.wasCanceled():
+            # Delete output if saving was canceled.
+            os.remove(filename)
+            return "canceled"
 
-    win.hide()
+        win.hide()
 
     return filename
 
@@ -1422,6 +1431,7 @@ def do_action(cls, context: CommandContext, params: dict):
             filename=params["filename"],
             all_labeled=cls.all_labeled,
             suggested=cls.suggested,
+            verbose=params.get("verbose", True),
         )
 
     @staticmethod
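
The new `verbose` flag is what makes `export_dataset_gui` callable without a progress dialog. A minimal sketch of that headless path, assuming hypothetical input and output filenames (the function itself is the one added above):

    from sleap import Labels
    from sleap.gui.commands import export_dataset_gui

    labels = Labels.load_file("labels.v001.slp")  # hypothetical input project

    # With verbose=False no QProgressDialog is constructed and no progress
    # callback is forwarded to the underlying save call.
    result = export_dataset_gui(
        labels,
        "labels.v001.pkg.slp",  # output package with embedded frame images
        all_labeled=False,      # only frames with user instances
        suggested=True,         # also embed images for suggested frames
        verbose=False,          # skip the Qt progress dialog
    )
    assert result != "canceled"
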
diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py
index 1ba320054..652c931e9 100644
--- a/sleap/io/dataset.py
+++ b/sleap/io/dataset.py
@@ -40,6 +40,7 @@
 import itertools
 import os
 from collections.abc import MutableSequence
+from pathlib import Path
 from typing import (
     Callable,
     List,
@@ -2219,7 +2220,12 @@ def from_deepposekit(
         )
 
     def save_frame_data_imgstore(
-        self, output_dir: str = "./", format: str = "png", all_labels: bool = False
+        self,
+        output_dir: str = "./",
+        format: str = "png",
+        all_labeled: bool = False,
+        suggested: bool = False,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
     ) -> List[ImgStoreVideo]:
         """Write images for labeled frames from all videos to imgstore datasets.
 
@@ -2232,28 +2238,55 @@ def save_frame_data_imgstore(
                 Use "png" for lossless, "jpg" for lossy. Other imgstore
                 formats will probably work as well but have not been
                 tested.
 
-            all_labels: Include any labeled frames, not just the frames
+            all_labeled: Include any labeled frames, not just the frames
                 we'll use for training (i.e., those with `Instance` objects ).
+            suggested: Include suggested frames even if they do not have instances.
+                Useful for inference after training. Defaults to `False`.
+            progress_callback: If provided, this function will be called to report the
+                progress of the frame data saving. This function should be a callable
+                of the form: `fn(n, n_total)` where `n` is the number of frames saved so
+                far and `n_total` is the total number of frames that will be saved. This
+                is called after each video is processed. If the function has a return
+                value and it returns `False`, saving will be canceled and the output
+                deleted.
 
         Returns:
             A list of :class:`ImgStoreVideo` objects with the stored frames.
         """
+
+        # Let's gather all the suggestions by video
+        suggestion_frames_by_video = {video: [] for video in self.videos}
+        if suggested:
+            for suggestion in self.suggestions:
+                suggestion_frames_by_video[suggestion.video].append(
+                    suggestion.frame_idx
+                )
+
         # For each label
         imgstore_vids = []
-        for v_idx, v in enumerate(self.videos):
-            frame_nums = [
-                lf.frame_idx
-                for lf in self.labeled_frames
-                if v == lf.video and (all_labels or lf.has_user_instances)
-            ]
+        total_vids = len(self.videos)
+        for v_idx, video in enumerate(self.videos):
+            lfs_v = self.find(video)
+            frame_nums = {
+                lf.frame_idx for lf in lfs_v if all_labeled or lf.has_user_instances
+            }
+
+            if suggested:
+                frame_nums.update(suggestion_frames_by_video[video])
 
             # Join with "/" instead of os.path.join() since we want
             # path to work on Windows and Posix systems
-            frames_filename = output_dir + f"/frame_data_vid{v_idx}"
-            vid = v.to_imgstore(
-                path=frames_filename, frame_numbers=frame_nums, format=format
+            frames_fn = Path(output_dir, f"frame_data_vid{v_idx}")
+            vid = video.to_imgstore(
+                path=frames_fn.as_posix(), frame_numbers=frame_nums, format=format
             )
+            if progress_callback is not None:
+                # Notify update callback.
+                ret = progress_callback(v_idx, total_vids)
+                if ret == False:
+                    vid.close()
+                    return []
 
             # Close the video for now
             vid.close()
@@ -2296,23 +2329,30 @@ def save_frame_data_hdf5(
         Returns:
             A list of :class:`HDF5Video` objects with the stored frames.
         """
+
+        # Let's gather all the suggestions by video
+        suggestion_frames_by_video = {video: [] for video in self.videos}
+        if suggested:
+            for suggestion in self.suggestions:
+                suggestion_frames_by_video[suggestion.video].append(
+                    suggestion.frame_idx
+                )
+
         # Build list of frames to save.
         vids = []
         frame_idxs = []
         for video in self.videos:
             lfs_v = self.find(video)
-            frame_nums = [
+            frame_nums = {
                 lf.frame_idx
                 for lf in lfs_v
                 if all_labeled or (user_labeled and lf.has_user_instances)
-            ]
+            }
+
             if suggested:
-                frame_nums += [
-                    suggestion.frame_idx
-                    for suggestion in self.suggestions
-                    if suggestion.video == video
-                ]
-            frame_nums = sorted(list(set(frame_nums)))
+                frame_nums.update(suggestion_frames_by_video[video])
+
+            frame_nums = sorted(list(frame_nums))
 
             vids.append(video)
             frame_idxs.append(frame_nums)
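
For reference, a hedged sketch of the new `progress_callback` hook on `Labels.save_frame_data_imgstore`: the callback receives `(n, n_total)` once per video and may return `False` to cancel, as documented above. The paths and the print-based reporter are illustrative assumptions, not part of the patch:

    from sleap import Labels

    labels = Labels.load_file("labels.v001.slp")  # hypothetical input project

    def report_progress(n, n_total):
        print(f"Wrote frame images for {n}/{n_total} videos")
        return True  # returning False cancels the export

    imgstore_vids = labels.save_frame_data_imgstore(
        output_dir="frame_data",  # hypothetical output directory
        format="png",
        all_labeled=False,        # only frames with user instances
        suggested=True,           # also write images for suggested frames
        progress_callback=report_progress,
    )
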
diff --git a/sleap/io/format/dispatch.py b/sleap/io/format/dispatch.py
index e4803a87d..43f879627 100644
--- a/sleap/io/format/dispatch.py
+++ b/sleap/io/format/dispatch.py
@@ -5,6 +5,7 @@
 """
 
 import attr
+from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
 from sleap.io.format.adaptor import Adaptor, SleapObjectType
@@ -77,7 +78,9 @@ def write(self, filename: str, source_object: object, *args, **kwargs):
             if adaptor.can_write_filename(filename):
                 return adaptor.write(filename, source_object, *args, **kwargs)
 
-        raise TypeError("No file format adaptor could write this file.")
+        raise TypeError(
+            f"No file format adaptor could write this file: {Path(filename).name}."
+        )
 
     def write_safely(self, *args, **kwargs) -> Optional[BaseException]:
         """Wrapper for writing file without throwing exception."""
diff --git a/sleap/io/format/labels_json.py b/sleap/io/format/labels_json.py
index 50fa7d18d..f284731a6 100644
--- a/sleap/io/format/labels_json.py
+++ b/sleap/io/format/labels_json.py
@@ -241,9 +241,11 @@ def write(
         compress: Optional[bool] = None,
         save_frame_data: bool = False,
         frame_data_format: str = "png",
+        all_labeled: bool = False,
+        suggested: bool = False,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
     ):
-        """
-        Save a Labels instance to a JSON format.
+        """Save a Labels instance to a JSON format.
 
         Args:
             filename: The filename to save the data to.
@@ -276,6 +278,11 @@
                 Note: 'h264/mkv' and 'avc1/mp4' require separate installation of these
                 codecs on your system. They are excluded from SLEAP because of their
                 GPL license.
+            all_labeled: Whether to save all frames or just the labeled frames to use in
+                training.
+            suggested: Whether to save the suggested labels along with the training
+                labels.
+            progress_callback: A function that will be called with the current progress.
 
         Returns:
             None
@@ -299,7 +306,11 @@
             # of the videos. We will only include the labeled frames though. We will
             # then replace each video with this new video
             new_videos = labels.save_frame_data_imgstore(
-                output_dir=tmp_dir, format=frame_data_format
+                output_dir=tmp_dir,
+                format=frame_data_format,
+                all_labeled=all_labeled,
+                suggested=suggested,
+                progress_callback=progress_callback,
             )
 
             # Make video paths relative
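
The same three keyword arguments reach `LabelsJsonAdaptor.write` when a zip package is written through the dispatcher. A sketch of that route via `Labels.save_file`, assuming the usual kwarg pass-through to the adaptor and hypothetical filenames:

    from sleap import Labels

    labels = Labels.load_file("labels.v001.slp")  # hypothetical input project

    # A ".json.zip" target is handled by the JSON adaptor; save_frame_data=True
    # triggers save_frame_data_imgstore with the new all_labeled/suggested/
    # progress_callback arguments shown above.
    Labels.save_file(
        labels,
        "labels.v001.json.zip",
        save_frame_data=True,
        all_labeled=False,
        suggested=True,
        progress_callback=None,  # or a fn(n, n_total) as documented in dataset.py
    )
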
diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py
index 6f1ed7cd3..fa3ff3d9c 100644
--- a/tests/gui/test_commands.py
+++ b/tests/gui/test_commands.py
@@ -1,16 +1,17 @@
-from pathlib import PurePath, Path
+import pytest
 import shutil
 import sys
-from typing import List
+import time
 
-import pytest
-from qtpy.QtWidgets import QComboBox
+from pathlib import PurePath, Path
+from typing import List
 
-from sleap import Skeleton, Track
+from sleap import Skeleton, Track, PredictedInstance
 from sleap.gui.commands import (
     CommandContext,
-    ImportDeepLabCutFolder,
     ExportAnalysisFile,
+    ExportDatasetWithImages,
+    ImportDeepLabCutFolder,
     RemoveVideo,
     ReplaceVideo,
     OpenSkeleton,
@@ -826,3 +827,80 @@ def load_and_assert_changes(new_video_path: Path):
         load_and_assert_changes(search_path)
     finally:
         # Move video back to original location - for ease of re-testing
         shutil.move(new_video_path, expected_video_path)
+
+
+@pytest.mark.parametrize("export_extension", [".json.zip", ".slp"])
+def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir):
+    def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False):
+        """Assert that the loaded labels are similar to the original."""
+
+        # Load the labels, but first copy file to a location (which pytest can and will
+        # keep in memory, but won't affect our re-use of the original file name)
+        filename_for_pytest_to_hoard: Path = path_to_pkg.with_name(
+            f"pytest_labels_{time.perf_counter_ns()}{export_extension}"
+        )
+        shutil.copyfile(path_to_pkg.as_posix(), filename_for_pytest_to_hoard.as_posix())
+        labels_reload: Labels = Labels.load_file(
+            filename_for_pytest_to_hoard.as_posix()
+        )
+
+        assert len(labels_reload.labeled_frames) == len(centered_pair_labels)
+        assert len(labels_reload.videos) == len(centered_pair_labels.videos)
+        assert len(labels_reload.suggestions) == len(centered_pair_labels.suggestions)
+        assert len(labels_reload.tracks) == len(centered_pair_labels.tracks)
+        assert len(labels_reload.skeletons) == len(centered_pair_labels.skeletons)
+        assert (
+            len(
+                set(labels_reload.skeleton.node_names)
+                - set(centered_pair_labels.skeleton.node_names)
+            )
+            == 0
+        )
+        num_images = len(labels_reload)
+        if sugg:
+            num_images += len(lfs_sugg)
+        if not pred:
+            num_images -= len(lfs_pred)
+        assert labels_reload.video.num_frames == num_images
+
+    # Set-up CommandContext
+    path_to_pkg = Path(tmpdir, "test_exportLabelsPackage.ext")
+    path_to_pkg = path_to_pkg.with_suffix(export_extension)
+
+    def no_gui_ask(cls, context, params):
+        """No GUI version of `ExportDatasetWithImages.ask`."""
+        params["filename"] = path_to_pkg.as_posix()
+        params["verbose"] = False
+        return True
+
+    ExportDatasetWithImages.ask = no_gui_ask
+
+    # Remove frames we want to use for suggestions and predictions
+    lfs_sugg = [centered_pair_labels[idx] for idx in [-1, -2]]
+    lfs_pred = [centered_pair_labels[idx] for idx in [-3, -4]]
+    centered_pair_labels.remove_frames(lfs_sugg)
+
+    # Add suggestions
+    for lf in lfs_sugg:
+        centered_pair_labels.add_suggestion(centered_pair_labels.video, lf.frame_idx)
+
+    # Add predictions and remove user instances from those frames
+    for lf in lfs_pred:
+        predicted_inst = PredictedInstance.from_instance(lf.instances[0], score=0.5)
+        centered_pair_labels.add_instance(lf, predicted_inst)
+        for inst in lf.user_instances:
+            centered_pair_labels.remove_instance(lf, inst)
+    context = CommandContext.from_labels(centered_pair_labels)
+
+    # Case 1: Export user-labeled frames with image data into a single SLP file.
+    context.exportUserLabelsPackage()
+    assert path_to_pkg.exists()
+    assert_loaded_package_similar(path_to_pkg)
+
+    # Case 2: Export user-labeled frames and suggested frames with image data.
+    context.exportTrainingPackage()
+    assert_loaded_package_similar(path_to_pkg, sugg=True)
+
+    # Case 3: Export all frames and suggested frames with image data.
+    context.exportFullPackage()
+    assert_loaded_package_similar(path_to_pkg, sugg=True, pred=True)
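
Outside of pytest, the headless flow exercised by the new test can be reproduced by patching `ExportDatasetWithImages.ask` and driving one of the export commands on a GUI-less `CommandContext`. A condensed sketch with placeholder filenames, following the same pattern as the test above:

    from sleap import Labels
    from sleap.gui.commands import CommandContext, ExportDatasetWithImages

    labels = Labels.load_file("labels.v001.slp")  # hypothetical input project

    def no_gui_ask(cls, context, params):
        # Mirror the test: choose the output name and skip the progress dialog.
        params["filename"] = "labels.v001.pkg.slp"
        params["verbose"] = False
        return True

    ExportDatasetWithImages.ask = no_gui_ask

    context = CommandContext.from_labels(labels)
    context.exportTrainingPackage()  # user-labeled + suggested frames with images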