From 17d291b3a6074e85d7ad8b980ef7837d0fd00821 Mon Sep 17 00:00:00 2001 From: Peter Bull Date: Thu, 25 Aug 2022 17:04:31 -0700 Subject: [PATCH] Get model_name from model class on loading checkpoint (#221) * bonus docs fix * Get model_name from model class * add tests * formatting, linting, add format make command * makefile docs * when training from ckpt, model has name now * fix test cleaning * add save_dir as tmp_path for tests * return copy of hparams from cache so new ref each time * format Co-authored-by: Emily Miller --- Makefile | 9 +- docs/docs/yaml-config.md | 4 +- tests/conftest.py | 2 + tests/test_config.py | 124 ++++++++++++++++++++-------- tests/test_instantiate_model.py | 14 +++- tests/test_model_manager.py | 3 +- zamba/models/config.py | 8 +- zamba/models/efficientnet_models.py | 4 + zamba/models/model_manager.py | 5 +- zamba/models/slowfast_models.py | 2 + zamba/models/utils.py | 12 +++ 11 files changed, 137 insertions(+), 50 deletions(-) diff --git a/Makefile b/Makefile index a3ce0f41..940e311c 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,7 @@ clean-pyc: ## remove Python file artifacts find . -name '__pycache__' -exec rm -fr {} + clean-test: ## remove test and coverage artifacts + find . -name ".DS_Store" -type f -delete # breaks tests on MacOS rm -fr .tox/ rm -f .coverage rm -f coverage.xml @@ -55,13 +56,17 @@ dist: clean ## builds source and wheel package python setup.py bdist_wheel ls -l dist -## Lint using flake8 +## Format using black +format: + black zamba tests + +## Lint using flake8 + black lint: flake8 zamba tests black --check zamba tests ## Generate assets and run tests -tests: +tests: clean-test pytest tests -vv ## Run the tests that are just for densepose diff --git a/docs/docs/yaml-config.md b/docs/docs/yaml-config.md index 6baffabc..044b573b 100644 --- a/docs/docs/yaml-config.md +++ b/docs/docs/yaml-config.md @@ -20,7 +20,7 @@ train_config: predict_config: model_name: time_distributed - data_directoty: example_vids/ + data_dir: example_vids/ # other training parameters, eg. batch_size ``` @@ -34,7 +34,7 @@ video_loader_config: predict_config: model_name: time_distributed - data_directoty: example_vids/ + data_dir: example_vids/ ``` ## Required arguments diff --git a/tests/conftest.py b/tests/conftest.py index 8549bc21..60211de9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,8 @@ class DummyZambaVideoClassificationLightningModule(ZambaVideoClassificationLightningModule): """A dummy model whose linear weights start out as all zeros.""" + _default_model_name = "dummy_model" # used to look up default configuration for checkpoints + def __init__( self, num_frames: int, diff --git a/tests/test_config.py b/tests/test_config.py index e97134a9..0fc0991b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -29,24 +29,30 @@ def test_train_data_dir_only(): def test_train_data_dir_and_labels(tmp_path, labels_relative_path, labels_absolute_path): # correct data dir - config = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path) + config = TrainConfig( + data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path, save_dir=tmp_path / "my_model" + ) assert config.data_dir is not None assert config.labels is not None # data dir ignored if absolute path provided in filepath - config = TrainConfig(data_dir=tmp_path, labels=labels_absolute_path) + config = TrainConfig( + data_dir=tmp_path, labels=labels_absolute_path, save_dir=tmp_path / "my_model" + ) assert config.data_dir is not None assert config.labels is not None assert not config.labels.filepath.str.startswith(str(tmp_path)).any() # incorrect data dir with relative filepaths with pytest.raises(ValidationError) as error: - TrainConfig(data_dir=ASSETS_DIR, labels=labels_relative_path) + TrainConfig( + data_dir=ASSETS_DIR, labels=labels_relative_path, save_dir=tmp_path / "my_model" + ) assert "None of the video filepaths exist" in error.value.errors()[0]["msg"] -def test_train_labels_only(labels_absolute_path): - config = TrainConfig(labels=labels_absolute_path) +def test_train_labels_only(labels_absolute_path, tmp_path): + config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model") assert config.labels is not None @@ -89,7 +95,7 @@ def test_filepath_column(tmp_path, labels_absolute_path): # train: labels with pytest.raises(ValidationError) as error: - TrainConfig(labels=tmp_path / "bad_filepath_column.csv") + TrainConfig(labels=tmp_path / "bad_filepath_column.csv", save_dir=tmp_path / "my_model") assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"] @@ -98,7 +104,7 @@ def test_label_column(tmp_path, labels_absolute_path): tmp_path / "bad_label_column.csv" ) with pytest.raises(ValidationError) as error: - TrainConfig(labels=tmp_path / "bad_label_column.csv") + TrainConfig(labels=tmp_path / "bad_label_column.csv", save_dir=tmp_path / "my_model") assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"] @@ -111,7 +117,7 @@ def test_extra_column(tmp_path, labels_absolute_path): index=False, ) # this column is not one hot encoded - config = TrainConfig(labels=tmp_path / "extra_species_col.csv") + config = TrainConfig(labels=tmp_path / "extra_species_col.csv", save_dir=tmp_path / "my_model") assert list(config.labels.columns) == [ "filepath", "split", @@ -138,7 +144,9 @@ def test_one_video_does_not_exist(tmp_path, labels_absolute_path, caplog): # one fewer file than in original list since bad file is skipped assert len(config.filepaths) == (len(files_df) - 1) - config = TrainConfig(labels=tmp_path / "labels_with_fake_video.csv") + config = TrainConfig( + labels=tmp_path / "labels_with_fake_video.csv", save_dir=tmp_path / "my_model" + ) assert "Skipping 1 file(s) that could not be found" in caplog.text assert len(config.labels) == (len(files_df) - 1) @@ -159,7 +167,9 @@ def test_videos_cannot_be_loaded(tmp_path, labels_absolute_path, caplog): assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text assert len(config.filepaths) == (len(files_df) - 2) - config = TrainConfig(labels=tmp_path / "labels_with_non_loadable_videos.csv") + config = TrainConfig( + labels=tmp_path / "labels_with_non_loadable_videos.csv", save_dir=tmp_path / "my_model" + ) assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text assert len(config.labels) == (len(files_df) - 2) @@ -183,43 +193,43 @@ def test_early_stopping_mode(): assert "Provided mode max is incorrect for val_loss monitor." == error.value.errors()[0]["msg"] -def test_labels_with_all_null_species(labels_absolute_path): +def test_labels_with_all_null_species(labels_absolute_path, tmp_path): labels = pd.read_csv(labels_absolute_path) labels["label"] = np.nan with pytest.raises(ValueError) as error: - TrainConfig(labels=labels) + TrainConfig(labels=labels, save_dir=tmp_path / "my_model") assert "Species cannot be null for all videos." == error.value.errors()[0]["msg"] -def test_labels_with_partially_null_species(labels_absolute_path, caplog): +def test_labels_with_partially_null_species(labels_absolute_path, caplog, tmp_path): labels = pd.read_csv(labels_absolute_path) labels.loc[0, "label"] = np.nan - TrainConfig(labels=labels) + TrainConfig(labels=labels, save_dir=tmp_path / "my_model") assert "Found 1 filepath(s) with no label. Will skip." in caplog.text -def test_labels_with_all_null_split(labels_absolute_path, caplog): +def test_labels_with_all_null_split(labels_absolute_path, caplog, tmp_path): labels = pd.read_csv(labels_absolute_path) labels["split"] = np.nan - TrainConfig(labels=labels) + TrainConfig(labels=labels, save_dir=tmp_path / "my_model") assert "Split column is entirely null. Will generate splits automatically" in caplog.text -def test_labels_with_partially_null_split(labels_absolute_path): +def test_labels_with_partially_null_split(labels_absolute_path, tmp_path): labels = pd.read_csv(labels_absolute_path) labels.loc[0, "split"] = np.nan with pytest.raises(ValueError) as error: - TrainConfig(labels=labels) + TrainConfig(labels=labels, save_dir=tmp_path / "my_model") assert ( "Found 1 row(s) with null `split`. Fill in these rows with either `train`, `val`, or `holdout`" ) in error.value.errors()[0]["msg"] -def test_labels_with_invalid_split(labels_absolute_path): +def test_labels_with_invalid_split(labels_absolute_path, tmp_path): labels = pd.read_csv(labels_absolute_path) labels.loc[0, "split"] = "test" with pytest.raises(ValueError) as error: - TrainConfig(labels=labels) + TrainConfig(labels=labels, save_dir=tmp_path / "my_model") assert ( "Found the following invalid values for `split`: {'test'}. `split` can only contain `train`, `val`, or `holdout.`" ) == error.value.errors()[0]["msg"] @@ -262,14 +272,24 @@ def test_labels_split_proportions(labels_no_splits, tmp_path): assert config.labels.split.value_counts().to_dict() == {"a": 13, "b": 6} -def test_from_scratch(labels_absolute_path): - config = TrainConfig(labels=labels_absolute_path, from_scratch=True, checkpoint=None) +def test_from_scratch(labels_absolute_path, tmp_path): + config = TrainConfig( + labels=labels_absolute_path, + from_scratch=True, + checkpoint=None, + save_dir=tmp_path / "my_model", + ) assert config.model_name == "time_distributed" assert config.from_scratch assert config.checkpoint is None with pytest.raises(ValueError) as error: - TrainConfig(labels=labels_absolute_path, from_scratch=True, model_name=None) + TrainConfig( + labels=labels_absolute_path, + from_scratch=True, + model_name=None, + save_dir=tmp_path / "my_model", + ) assert "If from_scratch=True, model_name cannot be None." == error.value.errors()[0]["msg"] @@ -297,11 +317,11 @@ def test_predict_filepaths_with_duplicates(labels_absolute_path, tmp_path, caplo def test_model_cache_dir(labels_absolute_path, tmp_path): - config = TrainConfig(labels=labels_absolute_path) + config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model") assert config.model_cache_dir == Path(appdirs.user_cache_dir()) / "zamba" os.environ["MODEL_CACHE_DIR"] = str(tmp_path) - config = TrainConfig(labels=labels_absolute_path) + config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model") assert config.model_cache_dir == tmp_path config = PredictConfig(filepaths=labels_absolute_path, model_cache_dir=tmp_path / "my_cache") @@ -365,23 +385,32 @@ def test_predict_save(labels_absolute_path, tmp_path, dummy_trained_model_checkp assert config.save_dir == save_dir -def test_validate_scheduler(labels_absolute_path): +def test_validate_scheduler(labels_absolute_path, tmp_path): # None gets transformed into SchedulerConfig config = TrainConfig( - labels=labels_absolute_path, scheduler_config=None, skip_load_validation=True + labels=labels_absolute_path, + scheduler_config=None, + skip_load_validation=True, + save_dir=tmp_path / "my_model", ) assert config.scheduler_config == SchedulerConfig(scheduler=None, scheduler_params=None) # default is valid config = TrainConfig( - labels=labels_absolute_path, scheduler_config="default", skip_load_validation=True + labels=labels_absolute_path, + scheduler_config="default", + skip_load_validation=True, + save_dir=tmp_path / "my_model", ) assert config.scheduler_config == "default" # other strings are not with pytest.raises(ValueError) as error: TrainConfig( - labels=labels_absolute_path, scheduler_config="StepLR", skip_load_validation=True + labels=labels_absolute_path, + scheduler_config="StepLR", + skip_load_validation=True, + save_dir=tmp_path / "my_model", ) assert ( "Scheduler can either be 'default', None, or a SchedulerConfig." @@ -393,27 +422,40 @@ def test_validate_scheduler(labels_absolute_path): labels=labels_absolute_path, scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.2}), skip_load_validation=True, + save_dir=tmp_path / "my_model", ) assert config.scheduler_config == SchedulerConfig( scheduler="StepLR", scheduler_params={"gamma": 0.2} ) -def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog): +def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog, tmp_path): # check dry_run is True sets skip_load_validation to True - config = TrainConfig(labels=labels_absolute_path, dry_run=True, skip_load_validation=False) + config = TrainConfig( + labels=labels_absolute_path, + dry_run=True, + skip_load_validation=False, + save_dir=tmp_path / "my_model", + ) assert config.skip_load_validation assert "Turning off video loading check since dry_run=True." in caplog.text # if dry run is False, skip_load_validation is unchanged - config = TrainConfig(labels=labels_absolute_path, dry_run=False, skip_load_validation=False) + config = TrainConfig( + labels=labels_absolute_path, + dry_run=False, + skip_load_validation=False, + save_dir=tmp_path / "my_model", + ) assert not config.skip_load_validation -def test_default_video_loader_config(labels_absolute_path): +def test_default_video_loader_config(labels_absolute_path, tmp_path): # if no video loader is specified, use default for model config = ModelConfig( - train_config=TrainConfig(labels=labels_absolute_path, skip_load_validation=True), + train_config=TrainConfig( + labels=labels_absolute_path, skip_load_validation=True, save_dir=tmp_path / "my_model" + ), video_loader_config=None, ) assert config.video_loader_config is not None @@ -425,10 +467,20 @@ def test_default_video_loader_config(labels_absolute_path): assert config.video_loader_config is not None -def test_checkpoint_sets_model_to_none(labels_absolute_path, dummy_trained_model_checkpoint): +def test_checkpoint_sets_model_to_default( + labels_absolute_path, dummy_trained_model_checkpoint, tmp_path +): config = TrainConfig( labels=labels_absolute_path, checkpoint=dummy_trained_model_checkpoint, skip_load_validation=True, + save_dir=tmp_path / "my_model", + ) + assert config.model_name == "dummy_model" + + config = PredictConfig( + filepaths=labels_absolute_path, + checkpoint=dummy_trained_model_checkpoint, + skip_load_validation=True, ) - assert config.model_name is None + assert config.model_name == "dummy_model" diff --git a/tests/test_instantiate_model.py b/tests/test_instantiate_model.py index 710078ab..cf1cb88b 100644 --- a/tests/test_instantiate_model.py +++ b/tests/test_instantiate_model.py @@ -151,7 +151,12 @@ def test_head_replaced_for_new_species(dummy_trained_model_checkpoint, tmp_path) @pytest.mark.parametrize("model", ["time_distributed", "slowfast", "european"]) def test_finetune_new_labels(labels_absolute_path, model, tmp_path): - config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True) + config = TrainConfig( + labels=labels_absolute_path, + model_name=model, + skip_load_validation=True, + save_dir=tmp_path / "my_model", + ) model = instantiate_model( checkpoint=config.checkpoint, weight_download_region=config.weight_download_region, @@ -164,7 +169,12 @@ def test_finetune_new_labels(labels_absolute_path, model, tmp_path): @pytest.mark.parametrize("model", ["time_distributed", "slowfast", "european"]) def test_resume_subset_labels(labels_absolute_path, model, tmp_path): - config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True) + config = TrainConfig( + labels=labels_absolute_path, + model_name=model, + skip_load_validation=True, + save_dir=tmp_path / "my_model", + ) model = instantiate_model( checkpoint=config.checkpoint, weight_download_region=config.weight_download_region, diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index d0ef4b83..497a4d30 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -169,12 +169,11 @@ def test_train_save_dir_overwrite( assert not any([f.name.startswith("version_") for f in config.save_dir.iterdir()]) - # when training from checkpoint, model_name is None so get PTL default ckpt name for f in [ "train_configuration.yaml", "test_metrics.json", "val_metrics.json", - "epoch=0-step=11.ckpt", + "dummy_model.ckpt", ]: assert (config.save_dir / f).exists() diff --git a/zamba/models/config.py b/zamba/models/config.py index 251895a7..426a126e 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -18,7 +18,8 @@ from zamba.data.metadata import create_site_specific_splits from zamba.data.video import VideoLoaderConfig from zamba.exceptions import ZambaFfmpegException -from zamba.models.utils import RegionEnum, get_model_checkpoint_filename +from zamba.models.registry import available_models +from zamba.models.utils import RegionEnum, get_checkpoint_hparams, get_model_checkpoint_filename from zamba.pytorch.transforms import zamba_image_model_transforms, slowfast_transforms from zamba.settings import SPLIT_SEED, VIDEO_SUFFIXES @@ -164,8 +165,9 @@ def validate_model_name_and_checkpoint(cls, values): # checkpoint supercedes model elif checkpoint is not None and model_name is not None: logger.info(f"Using checkpoint file: {checkpoint}.") - # set model name to None so proper model class is retrieved from ckpt up upon instantiation - values["model_name"] = None + # get model name from checkpoint so it can be used for the video loader config + hparams = get_checkpoint_hparams(checkpoint) + values["model_name"] = available_models[hparams["model_class"]]._default_model_name elif checkpoint is None and model_name is not None: if not values.get("from_scratch"): diff --git a/zamba/models/efficientnet_models.py b/zamba/models/efficientnet_models.py index e16de391..d2d7b1d1 100644 --- a/zamba/models/efficientnet_models.py +++ b/zamba/models/efficientnet_models.py @@ -12,6 +12,10 @@ @register_model class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule): + _default_model_name = ( + "time_distributed" # used to look up default configuration for checkpoints + ) + def __init__( self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs ): diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py index fd0041e6..0ee3d6be 100644 --- a/zamba/models/model_manager.py +++ b/zamba/models/model_manager.py @@ -13,7 +13,6 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins import DDPPlugin -import torch from zamba import MODELS_DIRECTORY from zamba.data.video import VideoLoaderConfig @@ -27,7 +26,7 @@ RegionEnum, ) from zamba.models.registry import available_models -from zamba.models.utils import download_weights +from zamba.models.utils import download_weights, get_checkpoint_hparams from zamba.pytorch.finetuning import BackboneFinetuning from zamba.pytorch_lightning.utils import ZambaDataModule, ZambaVideoClassificationLightningModule @@ -87,7 +86,7 @@ def instantiate_model( destination_dir=model_cache_dir, ) - hparams = torch.load(checkpoint, map_location=torch.device("cpu"))["hyper_parameters"] + hparams = get_checkpoint_hparams(checkpoint) model_class = available_models[hparams["model_class"]] diff --git a/zamba/models/slowfast_models.py b/zamba/models/slowfast_models.py index 254fe46c..fed19fa1 100644 --- a/zamba/models/slowfast_models.py +++ b/zamba/models/slowfast_models.py @@ -22,6 +22,8 @@ class SlowFast(ZambaVideoClassificationLightningModule): _backbone_output_dim (int): Dimensionality of the backbone output (and head input). """ + _default_model_name = "slowfast" # used to look up default configuration for checkpoints + def __init__( self, backbone_mode: str = "train", diff --git a/zamba/models/utils.py b/zamba/models/utils.py index e6f2ff39..2570f515 100644 --- a/zamba/models/utils.py +++ b/zamba/models/utils.py @@ -1,9 +1,12 @@ +import copy from enum import Enum +from functools import lru_cache import os from pathlib import Path from typing import Union from cloudpathlib import S3Client, S3Path +import torch import yaml from zamba import MODELS_DIRECTORY @@ -45,3 +48,12 @@ def get_model_checkpoint_filename(model_name): with config_file.open() as f: config_dict = yaml.safe_load(f) return Path(config_dict["public_checkpoint"]) + + +def get_checkpoint_hparams(checkpoint): + return copy.deepcopy(_cached_hparams(checkpoint)) + + +@lru_cache() +def _cached_hparams(checkpoint): + return torch.load(checkpoint, map_location=torch.device("cpu"))["hyper_parameters"]