diff --git a/tests/conftest.py b/tests/conftest.py
index 0cac803e..01ec6b5f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -45,7 +45,7 @@ def __init__(
             backbone = torch.nn.Linear(num_frames, num_hidden)
             torch.nn.init.ones_(backbone.weight)
         else:
-            backbone = self.load_from_checkpoint(finetune_from).backbone
+            backbone = self.from_disk(finetune_from).backbone
 
         for param in backbone.parameters():
             param.requires_grad = False
diff --git a/zamba/models/efficientnet_models.py b/zamba/models/efficientnet_models.py
index f36c2021..17314bbd 100644
--- a/zamba/models/efficientnet_models.py
+++ b/zamba/models/efficientnet_models.py
@@ -17,7 +17,10 @@ class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule):
     )
 
     def __init__(
-        self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
+        self,
+        num_frames=16,
+        finetune_from: Optional[Union[os.PathLike, str]] = None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
 
@@ -25,7 +28,7 @@ def __init__(
             efficientnet = timm.create_model("efficientnetv2_rw_m", pretrained=True)
             efficientnet.classifier = nn.Identity()
         else:
-            efficientnet = self.load_from_checkpoint(finetune_from).base.module
+            efficientnet = self.from_disk(finetune_from).base.module
 
         # freeze base layers
         for param in efficientnet.parameters():
diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py
index a49387f7..7bf79d48 100644
--- a/zamba/models/model_manager.py
+++ b/zamba/models/model_manager.py
@@ -62,7 +62,7 @@ def instantiate_model(
             Only used if labels is not None.
         model_name (ModelEnum, optional): Model name used to look up default hparams used for that model.
            Only relevant if training from scratch.
-        use_default_model_labels(bool, optional): Whether to output the full set of default model labels rather than
+        use_default_model_labels (bool, optional): Whether to output the full set of default model labels rather than
            just the species in the labels file. Only used if labels is not None.
 
     Returns:
@@ -78,9 +78,8 @@ def instantiate_model(
 
     # predicting
     if labels is None:
-        # predict; load from checkpoint uses associated hparams
         logger.info("Loading from checkpoint.")
-        model = model_class.load_from_checkpoint(checkpoint_path=checkpoint)
+        model = model_class.from_disk(path=checkpoint, **hparams)
         return model
 
     # get species from labels file
@@ -110,10 +109,8 @@ def instantiate_model(
         return resume_training(
             scheduler_config=scheduler_config,
             hparams=hparams,
-            species=species,
             model_class=model_class,
             checkpoint=checkpoint,
-            labels=labels,
         )
 
     else:
@@ -157,10 +154,8 @@ def replace_head(scheduler_config, hparams, species, model_class, checkpoint):
 def resume_training(
     scheduler_config,
     hparams,
-    species,
     model_class,
     checkpoint,
-    labels,
 ):
     # resume training; add additional species columns to labels file if needed
     logger.info(
@@ -170,7 +165,7 @@
 
     if scheduler_config != "default":
         hparams.update(scheduler_config.dict())
 
-    model = model_class.load_from_checkpoint(checkpoint_path=checkpoint, **hparams)
+    model = model_class.from_disk(path=checkpoint, **hparams)
     log_schedulers(model)
     return model
diff --git a/zamba/models/slowfast_models.py b/zamba/models/slowfast_models.py
index fed19fa1..58e12501 100644
--- a/zamba/models/slowfast_models.py
+++ b/zamba/models/slowfast_models.py
@@ -55,7 +55,7 @@ def __init__(
         if finetune_from is None:
             self.initialize_from_torchub()
         else:
-            model = self.load_from_checkpoint(finetune_from)
+            model = self.from_disk(finetune_from)
             self._backbone_output_dim = model.head.proj.in_features
             self.backbone = model.backbone
             self.base = model.base
diff --git a/zamba/pytorch_lightning/utils.py b/zamba/pytorch_lightning/utils.py
index d42a444c..041f63e9 100644
--- a/zamba/pytorch_lightning/utils.py
+++ b/zamba/pytorch_lightning/utils.py
@@ -303,5 +303,6 @@ def to_disk(self, path: os.PathLike):
         torch.save(checkpoint, path)
 
     @classmethod
-    def from_disk(cls, path: os.PathLike):
-        return cls.load_from_checkpoint(path)
+    def from_disk(cls, path: os.PathLike, **kwargs):
+        # note: we always load models onto CPU; moving to GPU is handled by `devices` in pl.Trainer
+        return cls.load_from_checkpoint(path, map_location="cpu", **kwargs)
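For reviewers, a minimal usage sketch of the consolidated from_disk entry point introduced above; the class name, checkpoint path, and the lr override are illustrative assumptions, not values taken from this diff:

    # hypothetical usage of the new from_disk signature (sketch, not part of this PR)
    from zamba.models.slowfast_models import SlowFast

    # the checkpoint is loaded onto CPU regardless of where it was saved;
    # moving to GPU is handled by `devices` in pl.Trainer
    model = SlowFast.from_disk("path/to/checkpoint.ckpt")

    # extra keyword arguments are forwarded to load_from_checkpoint,
    # e.g. to override hyperparameters stored with the checkpoint
    model = SlowFast.from_disk("path/to/checkpoint.ckpt", lr=1e-4)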