Skip to content

Commit

Permalink
Support recent versions of black and lightning (#260)
Browse files Browse the repository at this point in the history
* use most recent black release

* add needed attributes to dictionary for v1.9
  • Loading branch information
ejm714 authored Feb 11, 2023
1 parent 56db26a commit e9a28b9
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 14 deletions.
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_shared_cli_options(mocker, minimum_valid_train, minimum_valid_predict):
mocker.patch("zamba.cli.ModelManager.predict", pred_mock)

for command in [minimum_valid_train, minimum_valid_predict]:

# check default model is time distributed one
result = runner.invoke(app, command)
assert result.exit_code == 0
Expand Down
1 change: 0 additions & 1 deletion tests/test_load_video_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def assert_megadetector_total_or_none(original_video_metadata, video_shape, **kw


def assert_no_frames_or_correct_shape(original_video_metadata, video_shape, **kwargs):

return (video_shape["frames"] == 0) or (
(video_shape["height"] == kwargs["frame_selection_height"])
and (video_shape["width"] == kwargs["frame_selection_width"])
Expand Down
1 change: 0 additions & 1 deletion zamba/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,5 +600,4 @@ def depth(


if __name__ == "__main__":

app()
4 changes: 0 additions & 4 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def check_files_exist_and_load(

bad_load = []
if not skip_load_validation:

logger.info(
"Checking that all videos can be loaded. If you're very confident all your videos can be loaded, you can skip this with `skip_load_validation`, but it's not recommended."
)
Expand Down Expand Up @@ -503,7 +502,6 @@ def validate_filepaths_and_labels(cls, values):

# validate split column has no partial nulls or invalid values
if "split" in labels.columns:

# if split is entirely null, warn, drop column, and generate splits automatically
if labels.split.isnull().all():
logger.warning(
Expand Down Expand Up @@ -559,7 +557,6 @@ def validate_provided_species_and_use_default_model_labels(cls, values):
)

if not provided_species.issubset(model_species):

# if labels are not a subset, user cannot set use_default_model_labels to True
if values["use_default_model_labels"]:
raise ValueError(
Expand Down Expand Up @@ -677,7 +674,6 @@ def make_split(labels, values):
species_df = labels[labels[c] > 0]

if len(species_df):

# within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
labels.loc[species_df.index, "split"] = expected_splits + random.choices(
list(values["split_proportions"].keys()),
Expand Down
8 changes: 4 additions & 4 deletions zamba/models/depth_estimation/depth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def depth_transforms(size):

class DepthDataset(torch.utils.data.Dataset):
def __init__(self, filepaths):

# these are hardcoded because they depend on the trained model weights used for inference
self.height = 270
self.width = 480
Expand All @@ -55,7 +54,6 @@ def __init__(self, filepaths):

logger.info(f"Running object detection on {len(filepaths)} videos.")
for video_filepath in tqdm(filepaths):

# get video array at 1 fps, use full size for detecting objects
logger.debug(f"Loading video: {video_filepath}")
try:
Expand All @@ -73,7 +71,6 @@ def __init__(self, filepaths):

# iterate over frames
for frame_idx, (detections, scores) in enumerate(detections_per_frame):

# if anything is detected in the frame, save out relevant frames
if len(detections) > 0:
logger.debug(f"{len(detections)} detection(s) found at second {frame_idx}.")
Expand Down Expand Up @@ -234,7 +231,10 @@ def predict(self, filepaths):
for d, vid, t in zip(distance.cpu().numpy(), filepath, time):
predictions.append((vid, t, d))

predictions = pd.DataFrame(predictions, columns=["filepath", "time", "distance"],).round(
predictions = pd.DataFrame(
predictions,
columns=["filepath", "time", "distance"],
).round(
{"distance": 1}
) # round to useful number of decimal places

Expand Down
1 change: 0 additions & 1 deletion zamba/models/efficientnet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule):
def __init__(
self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
):

super().__init__(**kwargs)

if finetune_from is None:
Expand Down
2 changes: 0 additions & 2 deletions zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,6 @@ def predict_model(
}

if predict_config.save is not False:

config_path = predict_config.save_dir / "predict_configuration.yaml"
logger.info(f"Writing out full configuration to {config_path}.")
with config_path.open("w") as fp:
Expand All @@ -415,7 +414,6 @@ def predict_model(
df = df.round(5)

if predict_config.save is not False:

preds_path = predict_config.save_dir / "zamba_predictions.csv"
logger.info(f"Saving out predictions to {preds_path}.")
with preds_path.open("w") as fp:
Expand Down
9 changes: 9 additions & 0 deletions zamba/pytorch_lightning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule
from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score
import torch
Expand Down Expand Up @@ -273,9 +274,17 @@ def configure_optimizers(self):
}

def to_disk(self, path: os.PathLike):
"""Save out model weights to a checkpoint file on disk.
Note: this does not include callbacks, optimizer_states, or lr_schedulers.
To include those, use `Trainer.save_checkpoint()` instead.
"""

checkpoint = {
"state_dict": self.state_dict(),
"hyper_parameters": self.hparams,
"global_step": self.global_step,
"pytorch-lightning_version": pl.__version__,
}
torch.save(checkpoint, path)

Expand Down

0 comments on commit e9a28b9

Please sign in to comment.