diff --git a/docs/guides/cli.md b/docs/guides/cli.md index c29270299..ab62f3130 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -36,8 +36,8 @@ optional arguments: ```none usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS] - [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] - [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] + [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] + [--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] [--suffix SUFFIX] training_job_path [labels_path] @@ -68,6 +68,8 @@ optional arguments: --save_viz Enable saving of prediction visualizations to the run folder if not already specified in the training job config. + --keep_viz Keep prediction visualization images in the run + folder after training if --save_viz is enabled. --zmq Enable ZMQ logging (for GUI) if not already specified in the training job config. --run_name RUN_NAME Run name to use when saving file, overrides other run diff --git a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb index b0211bbca..4e26cb286 100644 --- a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb +++ b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb @@ -335,7 +335,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", @@ -727,7 +727,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index be9e272c7..c730fa9c4 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -286,6 +286,11 @@ training: type: bool default: true +- name: _keep_viz + label: Keep Prediction Visualization Images After Training + type: bool + default: false + - name: _predict_frames label: Predict On type: list diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index a2e84788c..7569607a0 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -1,4 +1,5 @@ """Run training/inference in background process via CLI.""" + import abc import attr import os @@ -500,9 +501,11 @@ def write_pipeline_files( "data_path": os.path.basename(data_path), "models": [Path(p).as_posix() for p in new_cfg_filenames], "output_path": prediction_output_path, - "type": "labels" - if type(item_for_inference) == DatasetItemForInference - else "video", + "type": ( + "labels" + if type(item_for_inference) == DatasetItemForInference + else "video" + ), "only_suggested_frames": only_suggested_frames, "tracking": tracking_args, } @@ -544,6 +547,7 @@ def run_learning_pipeline( """ save_viz = inference_params.get("_save_viz", False) + keep_viz = inference_params.get("_keep_viz", False) if "movenet" in inference_params["_pipeline"]: trained_job_paths = [inference_params["_pipeline"]] @@ -557,6 +561,7 @@ def run_learning_pipeline( inference_params=inference_params, gui=True, save_viz=save_viz, + keep_viz=keep_viz, ) # Check that all the models were trained @@ -585,6 +590,7 @@ def run_gui_training( inference_params: Dict[str, Any], gui: bool = True, save_viz: bool = False, + keep_viz: bool = False, ) -> Dict[Text, Text]: """ Runs training for each training job. @@ -594,6 +600,7 @@ def run_gui_training( config_info_list: List of ConfigFileInfo with configs for training. gui: Whether to show gui windows and process gui events. save_viz: Whether to save visualizations from training. + keep_viz: Whether to keep prediction visualization images after training. Returns: Dictionary, keys are head name, values are path to trained config. @@ -683,6 +690,7 @@ def waiting(): video_paths=video_path_list, waiting_callback=waiting, save_viz=save_viz, + keep_viz=keep_viz, ) if ret == "success": @@ -825,6 +833,7 @@ def train_subprocess( video_paths: Optional[List[Text]] = None, waiting_callback: Optional[Callable] = None, save_viz: bool = False, + keep_viz: bool = False, ): """Runs training inside subprocess.""" run_path = job_config.outputs.run_path @@ -853,6 +862,8 @@ def train_subprocess( if save_viz: cli_args.append("--save_viz") + if keep_viz: + cli_args.append("--keep_viz") # Use cli arg since cli ignores setting in config if job_config.outputs.tensorboard.write_logs: diff --git a/sleap/nn/config/outputs.py b/sleap/nn/config/outputs.py index ffb0d76e4..ccb6077b1 100644 --- a/sleap/nn/config/outputs.py +++ b/sleap/nn/config/outputs.py @@ -151,8 +151,8 @@ class OutputsConfig: save_visualizations: If True, will render and save visualizations of the model predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the split is one of "train", "validation", "test". - delete_viz_images: If True, delete the saved visualizations after training - completes. This is useful to reduce the model folder size if you do not need + keep_viz_images: If True, keep the saved visualization images after training + completes. This is useful unchecked to reduce the model folder size if you do not need to keep the visualization images. zip_outputs: If True, compress the run folder to a zip file. This will be named "{run_folder}.zip". @@ -170,7 +170,7 @@ class OutputsConfig: runs_folder: Text = "models" tags: List[Text] = attr.ib(factory=list) save_visualizations: bool = True - delete_viz_images: bool = True + keep_viz_images: bool = False zip_outputs: bool = False log_to_csv: bool = True checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 6a64e43b6..9e4245b88 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -946,7 +946,7 @@ def train(self): if self.config.outputs.save_outputs: if ( self.config.outputs.save_visualizations - and self.config.outputs.delete_viz_images + and not self.config.outputs.keep_viz_images ): self.cleanup() @@ -997,7 +997,7 @@ def cleanup(self): def package(self): """Package model folder into a zip file for portability.""" - if self.config.outputs.delete_viz_images: + if not self.config.outputs.keep_viz_images: self.cleanup() logger.info(f"Packaging results to: {self.run_path}.zip") shutil.make_archive( @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None): "already specified in the training job config." ), ) + parser.add_argument( + "--keep_viz", + action="store_true", + help=( + "Keep prediction visualization images in the run folder after training when " + "--save_viz is enabled." + ), + ) parser.add_argument( "--zmq", action="store_true", @@ -1949,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None): if args.suffix != "": job_config.outputs.run_name_suffix = args.suffix job_config.outputs.save_visualizations |= args.save_viz + job_config.outputs.keep_viz_images = args.keep_viz if args.labels_path == "": args.labels_path = None args.video_paths = args.video_paths.split(",") diff --git a/sleap/training_profiles/baseline.centroid.json b/sleap/training_profiles/baseline.centroid.json index 933989ecf..3a54db25c 100755 --- a/sleap/training_profiles/baseline.centroid.json +++ b/sleap/training_profiles/baseline.centroid.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.bottomup.json b/sleap/training_profiles/baseline_large_rf.bottomup.json index ea45c9b25..18fb3104f 100644 --- a/sleap/training_profiles/baseline_large_rf.bottomup.json +++ b/sleap/training_profiles/baseline_large_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.single.json b/sleap/training_profiles/baseline_large_rf.single.json index 75e97b1a6..3feeccd69 100644 --- a/sleap/training_profiles/baseline_large_rf.single.json +++ b/sleap/training_profiles/baseline_large_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.topdown.json b/sleap/training_profiles/baseline_large_rf.topdown.json index 9b17f6832..38e96594b 100644 --- a/sleap/training_profiles/baseline_large_rf.topdown.json +++ b/sleap/training_profiles/baseline_large_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.bottomup.json b/sleap/training_profiles/baseline_medium_rf.bottomup.json index 1cc35330a..61b08515c 100644 --- a/sleap/training_profiles/baseline_medium_rf.bottomup.json +++ b/sleap/training_profiles/baseline_medium_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.single.json b/sleap/training_profiles/baseline_medium_rf.single.json index 579f6c8c3..0951bc761 100644 --- a/sleap/training_profiles/baseline_medium_rf.single.json +++ b/sleap/training_profiles/baseline_medium_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.topdown.json b/sleap/training_profiles/baseline_medium_rf.topdown.json index 9e3a0bde5..9eccb76c1 100755 --- a/sleap/training_profiles/baseline_medium_rf.topdown.json +++ b/sleap/training_profiles/baseline_medium_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.bottomup.json b/sleap/training_profiles/pretrained.bottomup.json index 3e4f3935f..57b7398b5 100644 --- a/sleap/training_profiles/pretrained.bottomup.json +++ b/sleap/training_profiles/pretrained.bottomup.json @@ -122,6 +122,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.centroid.json b/sleap/training_profiles/pretrained.centroid.json index a5df5e48a..74c43d3e2 100644 --- a/sleap/training_profiles/pretrained.centroid.json +++ b/sleap/training_profiles/pretrained.centroid.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.single.json b/sleap/training_profiles/pretrained.single.json index 7ca907007..615f0de4d 100644 --- a/sleap/training_profiles/pretrained.single.json +++ b/sleap/training_profiles/pretrained.single.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.topdown.json b/sleap/training_profiles/pretrained.topdown.json index aeeaebbd8..be0d97de8 100644 --- a/sleap/training_profiles/pretrained.topdown.json +++ b/sleap/training_profiles/pretrained.topdown.json @@ -114,6 +114,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json index 7e52d1703..2ae0e925c 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json @@ -128,6 +128,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json index bcb2f26d5..7b6f817aa 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json @@ -191,6 +191,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json index 045890b21..5d8081628 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json @@ -141,7 +141,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json index 070e9d3c0..9591e5b52 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json @@ -208,7 +208,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json index 8e39fea3f..68e4f894e 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json @@ -127,6 +127,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json index d1fb718ba..e3bfbc5f8 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json @@ -192,6 +192,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json index 739d8e3e7..f4914aae4 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json @@ -119,6 +119,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json index 7b6782a68..e747f6862 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json @@ -179,6 +179,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json index 41d8ac8c3..977654b2e 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json @@ -118,6 +118,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/training_config.json b/tests/data/models/minimal_instance.UNet.centroid/training_config.json index 2d2280a31..02e9683e1 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/training_config.json +++ b/tests/data/models/minimal_instance.UNet.centroid/training_config.json @@ -175,6 +175,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json index cb2e4f353..f2bb907fa 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json @@ -120,6 +120,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json index 66901c9f0..dffecc1d9 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json @@ -180,6 +180,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/gui/test_dialogs.py b/tests/gui/test_dialogs.py index 4455550fb..611a73c85 100644 --- a/tests/gui/test_dialogs.py +++ b/tests/gui/test_dialogs.py @@ -1,6 +1,5 @@ """Module to test the dialogs of the GUI (contained in sleap/gui/dialogs).""" - import os from pathlib import Path diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index b6696e819..72db17bb5 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -123,34 +123,61 @@ def test_train_load_single_instance( assert (w == w2).all() -def test_train_single_instance(min_labels_robot, cfg): +def test_train_single_instance(min_labels_robot, cfg, tmp_path): cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=False ) + + # Set save directory + cfg.outputs.run_name = "test_run" + cfg.outputs.runs_folder = str(tmp_path / "training_runs") # ensure it's a string + cfg.outputs.save_visualizations = True + cfg.outputs.keep_viz_images = True + cfg.outputs.save_outputs = True # enable saving + trainer = SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() + + run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name) + viz_path = run_path / "viz" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) + assert viz_path.exists() -def test_train_single_instance_with_offset(min_labels_robot, cfg): +def test_train_single_instance_with_offset(min_labels_robot, cfg, tmp_path): cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=True ) + + # Set save directory + cfg.outputs.run_name = "test_run" + cfg.outputs.runs_folder = str(tmp_path / "training_runs") # ensure it's a string + cfg.outputs.save_visualizations = False + cfg.outputs.keep_viz_images = False + cfg.outputs.save_outputs = True # enable saving + trainer = SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() + + run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name) + viz_path = run_path / "viz" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) assert trainer.keras_model.output_names[1] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[1].shape) == (None, 320, 560, 4) + assert not viz_path.exists() + def test_train_centroids(training_labels, cfg): cfg.model.heads.centroid = CentroidsHeadConfig( @@ -360,3 +387,26 @@ def test_resume_training_cli( trainer = sleap_train(cli_args) assert trainer.config.model.base_checkpoint == base_checkpoint_path + + +@pytest.mark.parametrize("keep_viz_cli", ["", "--keep_viz"]) +def test_keep_viz_cli( + keep_viz_cli, + min_single_instance_robot_model_path: str, + tmp_path: str, +): + """Test training CLI for --keep_viz option.""" + cfg_dir = min_single_instance_robot_model_path + cfg = TrainingJobConfig.load_json(str(Path(cfg_dir, "training_config.json"))) + + # Save training config to tmp folder + cfg_path = str(Path(tmp_path, "training_config.json")) + cfg.save_json(cfg_path) + + cli_args = [cfg_path, keep_viz_cli] + trainer = sleap_train(cli_args) + + # Check that --keep_viz is set correctly + assert trainer.config.outputs.keep_viz_images == ( + True if keep_viz_cli == "--keep_viz" else False + )