From e9a28b9a5e21a47182c99c87d422ff2faf3f87d0 Mon Sep 17 00:00:00 2001 From: Emily Miller Date: Fri, 10 Feb 2023 16:40:12 -0800 Subject: [PATCH] Support recent versions of black and lightning (#260) * use most recent black release * add needed attributes to dictionary for v1.9 --- tests/test_cli.py | 1 - tests/test_load_video_frames.py | 1 - zamba/cli.py | 1 - zamba/models/config.py | 4 ---- zamba/models/depth_estimation/depth_manager.py | 8 ++++---- zamba/models/efficientnet_models.py | 1 - zamba/models/model_manager.py | 2 -- zamba/pytorch_lightning/utils.py | 9 +++++++++ 8 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index c7dd62ec..180a8202 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -64,7 +64,6 @@ def test_shared_cli_options(mocker, minimum_valid_train, minimum_valid_predict): mocker.patch("zamba.cli.ModelManager.predict", pred_mock) for command in [minimum_valid_train, minimum_valid_predict]: - # check default model is time distributed one result = runner.invoke(app, command) assert result.exit_code == 0 diff --git a/tests/test_load_video_frames.py b/tests/test_load_video_frames.py index 91a29392..92fdf94a 100644 --- a/tests/test_load_video_frames.py +++ b/tests/test_load_video_frames.py @@ -65,7 +65,6 @@ def assert_megadetector_total_or_none(original_video_metadata, video_shape, **kw def assert_no_frames_or_correct_shape(original_video_metadata, video_shape, **kwargs): - return (video_shape["frames"] == 0) or ( (video_shape["height"] == kwargs["frame_selection_height"]) and (video_shape["width"] == kwargs["frame_selection_width"]) diff --git a/zamba/cli.py b/zamba/cli.py index 90fa0733..f8cf1ae9 100644 --- a/zamba/cli.py +++ b/zamba/cli.py @@ -600,5 +600,4 @@ def depth( if __name__ == "__main__": - app() diff --git a/zamba/models/config.py b/zamba/models/config.py index e23f7c78..0c3ba4e4 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -148,7 +148,6 @@ def check_files_exist_and_load( bad_load = [] if not skip_load_validation: - logger.info( "Checking that all videos can be loaded. If you're very confident all your videos can be loaded, you can skip this with `skip_load_validation`, but it's not recommended." ) @@ -503,7 +502,6 @@ def validate_filepaths_and_labels(cls, values): # validate split column has no partial nulls or invalid values if "split" in labels.columns: - # if split is entirely null, warn, drop column, and generate splits automatically if labels.split.isnull().all(): logger.warning( @@ -559,7 +557,6 @@ def validate_provided_species_and_use_default_model_labels(cls, values): ) if not provided_species.issubset(model_species): - # if labels are not a subset, user cannot set use_default_model_labels to True if values["use_default_model_labels"]: raise ValueError( @@ -677,7 +674,6 @@ def make_split(labels, values): species_df = labels[labels[c] > 0] if len(species_df): - # within each species, seed splits by putting one video in each set and then allocate videos based on split proportions labels.loc[species_df.index, "split"] = expected_splits + random.choices( list(values["split_proportions"].keys()), diff --git a/zamba/models/depth_estimation/depth_manager.py b/zamba/models/depth_estimation/depth_manager.py index 142208fd..2774e40c 100644 --- a/zamba/models/depth_estimation/depth_manager.py +++ b/zamba/models/depth_estimation/depth_manager.py @@ -36,7 +36,6 @@ def depth_transforms(size): class DepthDataset(torch.utils.data.Dataset): def __init__(self, filepaths): - # these are hardcoded because they depend on the trained model weights used for inference self.height = 270 self.width = 480 @@ -55,7 +54,6 @@ def __init__(self, filepaths): logger.info(f"Running object detection on {len(filepaths)} videos.") for video_filepath in tqdm(filepaths): - # get video array at 1 fps, use full size for detecting objects logger.debug(f"Loading video: {video_filepath}") try: @@ -73,7 +71,6 @@ def __init__(self, filepaths): # iterate over frames for frame_idx, (detections, scores) in enumerate(detections_per_frame): - # if anything is detected in the frame, save out relevant frames if len(detections) > 0: logger.debug(f"{len(detections)} detection(s) found at second {frame_idx}.") @@ -234,7 +231,10 @@ def predict(self, filepaths): for d, vid, t in zip(distance.cpu().numpy(), filepath, time): predictions.append((vid, t, d)) - predictions = pd.DataFrame(predictions, columns=["filepath", "time", "distance"],).round( + predictions = pd.DataFrame( + predictions, + columns=["filepath", "time", "distance"], + ).round( {"distance": 1} ) # round to useful number of decimal places diff --git a/zamba/models/efficientnet_models.py b/zamba/models/efficientnet_models.py index d2d7b1d1..f36c2021 100644 --- a/zamba/models/efficientnet_models.py +++ b/zamba/models/efficientnet_models.py @@ -19,7 +19,6 @@ class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule): def __init__( self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs ): - super().__init__(**kwargs) if finetune_from is None: diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py index 8f9f100c..12b8d680 100644 --- a/zamba/models/model_manager.py +++ b/zamba/models/model_manager.py @@ -390,7 +390,6 @@ def predict_model( } if predict_config.save is not False: - config_path = predict_config.save_dir / "predict_configuration.yaml" logger.info(f"Writing out full configuration to {config_path}.") with config_path.open("w") as fp: @@ -415,7 +414,6 @@ def predict_model( df = df.round(5) if predict_config.save is not False: - preds_path = predict_config.save_dir / "zamba_predictions.csv" logger.info(f"Saving out predictions to {preds_path}.") with preds_path.open("w") as fp: diff --git a/zamba/pytorch_lightning/utils.py b/zamba/pytorch_lightning/utils.py index 40473dca..be397b76 100644 --- a/zamba/pytorch_lightning/utils.py +++ b/zamba/pytorch_lightning/utils.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +import pytorch_lightning as pl from pytorch_lightning import LightningDataModule, LightningModule from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score import torch @@ -273,9 +274,17 @@ def configure_optimizers(self): } def to_disk(self, path: os.PathLike): + """Save out model weights to a checkpoint file on disk. + + Note: this does not include callbacks, optimizer_states, or lr_schedulers. + To include those, use `Trainer.save_checkpoint()` instead. + """ + checkpoint = { "state_dict": self.state_dict(), "hyper_parameters": self.hparams, + "global_step": self.global_step, + "pytorch-lightning_version": pl.__version__, } torch.save(checkpoint, path)