Get model_name from model class on loading checkpoint (#221)
* bonus docs fix

* Get model_name from model class

* add tests

* formatting, linting, add format make command

* makefile docs

* when training from ckpt, model has name now

* fix test cleaning

* add save_dir as tmp_path for tests

* return copy of hparams from cache so new ref each time

* format

Co-authored-by: Emily Miller <ejm714@gmail.com>
pjbull and ejm714 authored Aug 26, 2022
1 parent e30c2ae commit 17d291b
Showing 11 changed files with 137 additions and 50 deletions.
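The core idea of this commit, per the title and the new `_default_model_name` attribute added in `tests/conftest.py` below, is that a config built from only a checkpoint can recover `model_name` from the model class instead of leaving it `None`. The following is a minimal sketch of that pattern with simplified stand-in classes and a hypothetical `resolve_model_name` helper; it is not zamba's exact implementation.

```python
from typing import Optional, Type


class ZambaVideoClassificationLightningModule:
    # Subclasses override this; it is used to look up default configuration for checkpoints.
    _default_model_name: Optional[str] = None


class DummyZambaVideoClassificationLightningModule(ZambaVideoClassificationLightningModule):
    _default_model_name = "dummy_model"


def resolve_model_name(
    model_name: Optional[str],
    model_class: Type[ZambaVideoClassificationLightningModule],
) -> Optional[str]:
    """Prefer an explicit model_name; otherwise fall back to the class default."""
    return model_name if model_name is not None else model_class._default_model_name


# A config given only a dummy checkpoint now reports "dummy_model" rather than None,
# which is what tests/test_config.py asserts further down in this diff.
assert resolve_model_name(None, DummyZambaVideoClassificationLightningModule) == "dummy_model"
assert resolve_model_name("slowfast", DummyZambaVideoClassificationLightningModule) == "slowfast"
```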
9 changes: 7 additions & 2 deletions Makefile
@@ -44,6 +44,7 @@ clean-pyc: ## remove Python file artifacts
find . -name '__pycache__' -exec rm -fr {} +

clean-test: ## remove test and coverage artifacts
find . -name ".DS_Store" -type f -delete # breaks tests on MacOS
rm -fr .tox/
rm -f .coverage
rm -f coverage.xml
@@ -55,13 +56,17 @@ dist: clean ## builds source and wheel package
python setup.py bdist_wheel
ls -l dist

## Lint using flake8
## Format using black
format:
black zamba tests

## Lint using flake8 + black
lint:
flake8 zamba tests
black --check zamba tests

## Generate assets and run tests
tests:
tests: clean-test
pytest tests -vv

## Run the tests that are just for densepose
4 changes: 2 additions & 2 deletions docs/docs/yaml-config.md
@@ -20,7 +20,7 @@ train_config:

predict_config:
model_name: time_distributed
data_directoty: example_vids/
data_dir: example_vids/
# other training parameters, eg. batch_size
```

@@ -34,7 +34,7 @@ video_loader_config:

predict_config:
model_name: time_distributed
data_directoty: example_vids/
data_dir: example_vids/
```
## Required arguments
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -30,6 +30,8 @@
class DummyZambaVideoClassificationLightningModule(ZambaVideoClassificationLightningModule):
"""A dummy model whose linear weights start out as all zeros."""

_default_model_name = "dummy_model" # used to look up default configuration for checkpoints

def __init__(
self,
num_frames: int,
124 changes: 88 additions & 36 deletions tests/test_config.py
@@ -29,24 +29,30 @@ def test_train_data_dir_only():

def test_train_data_dir_and_labels(tmp_path, labels_relative_path, labels_absolute_path):
# correct data dir
config = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path)
config = TrainConfig(
data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path, save_dir=tmp_path / "my_model"
)
assert config.data_dir is not None
assert config.labels is not None

# data dir ignored if absolute path provided in filepath
config = TrainConfig(data_dir=tmp_path, labels=labels_absolute_path)
config = TrainConfig(
data_dir=tmp_path, labels=labels_absolute_path, save_dir=tmp_path / "my_model"
)
assert config.data_dir is not None
assert config.labels is not None
assert not config.labels.filepath.str.startswith(str(tmp_path)).any()

# incorrect data dir with relative filepaths
with pytest.raises(ValidationError) as error:
TrainConfig(data_dir=ASSETS_DIR, labels=labels_relative_path)
TrainConfig(
data_dir=ASSETS_DIR, labels=labels_relative_path, save_dir=tmp_path / "my_model"
)
assert "None of the video filepaths exist" in error.value.errors()[0]["msg"]


def test_train_labels_only(labels_absolute_path):
config = TrainConfig(labels=labels_absolute_path)
def test_train_labels_only(labels_absolute_path, tmp_path):
config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model")
assert config.labels is not None


@@ -89,7 +95,7 @@ def test_filepath_column(tmp_path, labels_absolute_path):

# train: labels
with pytest.raises(ValidationError) as error:
TrainConfig(labels=tmp_path / "bad_filepath_column.csv")
TrainConfig(labels=tmp_path / "bad_filepath_column.csv", save_dir=tmp_path / "my_model")
assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"]


@@ -98,7 +104,7 @@ def test_label_column(tmp_path, labels_absolute_path):
tmp_path / "bad_label_column.csv"
)
with pytest.raises(ValidationError) as error:
TrainConfig(labels=tmp_path / "bad_label_column.csv")
TrainConfig(labels=tmp_path / "bad_label_column.csv", save_dir=tmp_path / "my_model")
assert "must contain `filepath` and `label` columns" in error.value.errors()[0]["msg"]


@@ -111,7 +117,7 @@ def test_extra_column(tmp_path, labels_absolute_path):
index=False,
)
# this column is not one hot encoded
config = TrainConfig(labels=tmp_path / "extra_species_col.csv")
config = TrainConfig(labels=tmp_path / "extra_species_col.csv", save_dir=tmp_path / "my_model")
assert list(config.labels.columns) == [
"filepath",
"split",
@@ -138,7 +144,9 @@ def test_one_video_does_not_exist(tmp_path, labels_absolute_path, caplog):
# one fewer file than in original list since bad file is skipped
assert len(config.filepaths) == (len(files_df) - 1)

config = TrainConfig(labels=tmp_path / "labels_with_fake_video.csv")
config = TrainConfig(
labels=tmp_path / "labels_with_fake_video.csv", save_dir=tmp_path / "my_model"
)
assert "Skipping 1 file(s) that could not be found" in caplog.text
assert len(config.labels) == (len(files_df) - 1)

@@ -159,7 +167,9 @@ def test_videos_cannot_be_loaded(tmp_path, labels_absolute_path, caplog):
assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
assert len(config.filepaths) == (len(files_df) - 2)

config = TrainConfig(labels=tmp_path / "labels_with_non_loadable_videos.csv")
config = TrainConfig(
labels=tmp_path / "labels_with_non_loadable_videos.csv", save_dir=tmp_path / "my_model"
)
assert "Skipping 2 file(s) that could not be loaded with ffmpeg" in caplog.text
assert len(config.labels) == (len(files_df) - 2)

@@ -183,43 +193,43 @@ def test_early_stopping_mode():
assert "Provided mode max is incorrect for val_loss monitor." == error.value.errors()[0]["msg"]


def test_labels_with_all_null_species(labels_absolute_path):
def test_labels_with_all_null_species(labels_absolute_path, tmp_path):
labels = pd.read_csv(labels_absolute_path)
labels["label"] = np.nan
with pytest.raises(ValueError) as error:
TrainConfig(labels=labels)
TrainConfig(labels=labels, save_dir=tmp_path / "my_model")
assert "Species cannot be null for all videos." == error.value.errors()[0]["msg"]


def test_labels_with_partially_null_species(labels_absolute_path, caplog):
def test_labels_with_partially_null_species(labels_absolute_path, caplog, tmp_path):
labels = pd.read_csv(labels_absolute_path)
labels.loc[0, "label"] = np.nan
TrainConfig(labels=labels)
TrainConfig(labels=labels, save_dir=tmp_path / "my_model")
assert "Found 1 filepath(s) with no label. Will skip." in caplog.text


def test_labels_with_all_null_split(labels_absolute_path, caplog):
def test_labels_with_all_null_split(labels_absolute_path, caplog, tmp_path):
labels = pd.read_csv(labels_absolute_path)
labels["split"] = np.nan
TrainConfig(labels=labels)
TrainConfig(labels=labels, save_dir=tmp_path / "my_model")
assert "Split column is entirely null. Will generate splits automatically" in caplog.text


def test_labels_with_partially_null_split(labels_absolute_path):
def test_labels_with_partially_null_split(labels_absolute_path, tmp_path):
labels = pd.read_csv(labels_absolute_path)
labels.loc[0, "split"] = np.nan
with pytest.raises(ValueError) as error:
TrainConfig(labels=labels)
TrainConfig(labels=labels, save_dir=tmp_path / "my_model")
assert (
"Found 1 row(s) with null `split`. Fill in these rows with either `train`, `val`, or `holdout`"
) in error.value.errors()[0]["msg"]


def test_labels_with_invalid_split(labels_absolute_path):
def test_labels_with_invalid_split(labels_absolute_path, tmp_path):
labels = pd.read_csv(labels_absolute_path)
labels.loc[0, "split"] = "test"
with pytest.raises(ValueError) as error:
TrainConfig(labels=labels)
TrainConfig(labels=labels, save_dir=tmp_path / "my_model")
assert (
"Found the following invalid values for `split`: {'test'}. `split` can only contain `train`, `val`, or `holdout.`"
) == error.value.errors()[0]["msg"]
@@ -262,14 +272,24 @@ def test_labels_split_proportions(labels_no_splits, tmp_path):
assert config.labels.split.value_counts().to_dict() == {"a": 13, "b": 6}


def test_from_scratch(labels_absolute_path):
config = TrainConfig(labels=labels_absolute_path, from_scratch=True, checkpoint=None)
def test_from_scratch(labels_absolute_path, tmp_path):
config = TrainConfig(
labels=labels_absolute_path,
from_scratch=True,
checkpoint=None,
save_dir=tmp_path / "my_model",
)
assert config.model_name == "time_distributed"
assert config.from_scratch
assert config.checkpoint is None

with pytest.raises(ValueError) as error:
TrainConfig(labels=labels_absolute_path, from_scratch=True, model_name=None)
TrainConfig(
labels=labels_absolute_path,
from_scratch=True,
model_name=None,
save_dir=tmp_path / "my_model",
)
assert "If from_scratch=True, model_name cannot be None." == error.value.errors()[0]["msg"]


@@ -297,11 +317,11 @@ def test_predict_filepaths_with_duplicates(labels_absolute_path, tmp_path, caplo


def test_model_cache_dir(labels_absolute_path, tmp_path):
config = TrainConfig(labels=labels_absolute_path)
config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model")
assert config.model_cache_dir == Path(appdirs.user_cache_dir()) / "zamba"

os.environ["MODEL_CACHE_DIR"] = str(tmp_path)
config = TrainConfig(labels=labels_absolute_path)
config = TrainConfig(labels=labels_absolute_path, save_dir=tmp_path / "my_model")
assert config.model_cache_dir == tmp_path

config = PredictConfig(filepaths=labels_absolute_path, model_cache_dir=tmp_path / "my_cache")
@@ -365,23 +385,32 @@ def test_predict_save(labels_absolute_path, tmp_path, dummy_trained_model_checkp
assert config.save_dir == save_dir


def test_validate_scheduler(labels_absolute_path):
def test_validate_scheduler(labels_absolute_path, tmp_path):
# None gets transformed into SchedulerConfig
config = TrainConfig(
labels=labels_absolute_path, scheduler_config=None, skip_load_validation=True
labels=labels_absolute_path,
scheduler_config=None,
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
assert config.scheduler_config == SchedulerConfig(scheduler=None, scheduler_params=None)

# default is valid
config = TrainConfig(
labels=labels_absolute_path, scheduler_config="default", skip_load_validation=True
labels=labels_absolute_path,
scheduler_config="default",
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
assert config.scheduler_config == "default"

# other strings are not
with pytest.raises(ValueError) as error:
TrainConfig(
labels=labels_absolute_path, scheduler_config="StepLR", skip_load_validation=True
labels=labels_absolute_path,
scheduler_config="StepLR",
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
assert (
"Scheduler can either be 'default', None, or a SchedulerConfig."
@@ -393,27 +422,40 @@ def test_validate_scheduler(labels_absolute_path):
labels=labels_absolute_path,
scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.2}),
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
assert config.scheduler_config == SchedulerConfig(
scheduler="StepLR", scheduler_params={"gamma": 0.2}
)


def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog):
def test_dry_run_and_skip_load_validation(labels_absolute_path, caplog, tmp_path):
# check dry_run is True sets skip_load_validation to True
config = TrainConfig(labels=labels_absolute_path, dry_run=True, skip_load_validation=False)
config = TrainConfig(
labels=labels_absolute_path,
dry_run=True,
skip_load_validation=False,
save_dir=tmp_path / "my_model",
)
assert config.skip_load_validation
assert "Turning off video loading check since dry_run=True." in caplog.text

# if dry run is False, skip_load_validation is unchanged
config = TrainConfig(labels=labels_absolute_path, dry_run=False, skip_load_validation=False)
config = TrainConfig(
labels=labels_absolute_path,
dry_run=False,
skip_load_validation=False,
save_dir=tmp_path / "my_model",
)
assert not config.skip_load_validation


def test_default_video_loader_config(labels_absolute_path):
def test_default_video_loader_config(labels_absolute_path, tmp_path):
# if no video loader is specified, use default for model
config = ModelConfig(
train_config=TrainConfig(labels=labels_absolute_path, skip_load_validation=True),
train_config=TrainConfig(
labels=labels_absolute_path, skip_load_validation=True, save_dir=tmp_path / "my_model"
),
video_loader_config=None,
)
assert config.video_loader_config is not None
@@ -425,10 +467,20 @@ def test_default_video_loader_config(labels_absolute_path):
assert config.video_loader_config is not None


def test_checkpoint_sets_model_to_none(labels_absolute_path, dummy_trained_model_checkpoint):
def test_checkpoint_sets_model_to_default(
labels_absolute_path, dummy_trained_model_checkpoint, tmp_path
):
config = TrainConfig(
labels=labels_absolute_path,
checkpoint=dummy_trained_model_checkpoint,
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
assert config.model_name == "dummy_model"

config = PredictConfig(
filepaths=labels_absolute_path,
checkpoint=dummy_trained_model_checkpoint,
skip_load_validation=True,
)
assert config.model_name is None
assert config.model_name == "dummy_model"
14 changes: 12 additions & 2 deletions tests/test_instantiate_model.py
@@ -151,7 +151,12 @@ def test_head_replaced_for_new_species(dummy_trained_model_checkpoint, tmp_path)

@pytest.mark.parametrize("model", ["time_distributed", "slowfast", "european"])
def test_finetune_new_labels(labels_absolute_path, model, tmp_path):
config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
config = TrainConfig(
labels=labels_absolute_path,
model_name=model,
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
model = instantiate_model(
checkpoint=config.checkpoint,
weight_download_region=config.weight_download_region,
@@ -164,7 +169,12 @@ def test_resume_subset_labels(labels_absolute_path, model, tmp_path):

@pytest.mark.parametrize("model", ["time_distributed", "slowfast", "european"])
def test_resume_subset_labels(labels_absolute_path, model, tmp_path):
config = TrainConfig(labels=labels_absolute_path, model_name=model, skip_load_validation=True)
config = TrainConfig(
labels=labels_absolute_path,
model_name=model,
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
model = instantiate_model(
checkpoint=config.checkpoint,
weight_download_region=config.weight_download_region,
3 changes: 1 addition & 2 deletions tests/test_model_manager.py
@@ -169,12 +169,11 @@ def test_train_save_dir_overwrite(

assert not any([f.name.startswith("version_") for f in config.save_dir.iterdir()])

# when training from checkpoint, model_name is None so get PTL default ckpt name
for f in [
"train_configuration.yaml",
"test_metrics.json",
"val_metrics.json",
"epoch=0-step=11.ckpt",
"dummy_model.ckpt",
]:
assert (config.save_dir / f).exists()

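The change in the expected filename above (from PyTorch Lightning's default `epoch=0-step=11.ckpt` to `dummy_model.ckpt`) follows from the model now having a name when training from a checkpoint. A hedged sketch of that wiring is below; the callback construction is illustrative and zamba's actual trainer code may differ.

```python
from pytorch_lightning.callbacks import ModelCheckpoint


def build_checkpoint_callback(save_dir, model_name=None):
    # With filename=None, Lightning falls back to its default "epoch=...-step=..." pattern;
    # passing a model_name produces "<model_name>.ckpt" instead.
    return ModelCheckpoint(dirpath=save_dir, filename=model_name)
```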