diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index fd9a35d6db..c44cd74e2a 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -202,10 +202,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/dlinear.py b/darts/models/forecasting/dlinear.py index c2add55dbe..47a27d3e9d 100644 --- a/darts/models/forecasting/dlinear.py +++ b/darts/models/forecasting/dlinear.py @@ -289,10 +289,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. 
work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index 67ae8d58d9..04b684cfc6 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -1532,7 +1532,7 @@ def model_params(self) -> dict: @classmethod def _default_save_path(cls) -> str: - return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" + return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}" def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> None: """ @@ -1555,8 +1555,8 @@ def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> Non ---------- path Path or file handle under which to save the model at its current state. If no path is specified, the model - is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pkl"``. - E.g., ``"RegressionModel_2020-01-01_12:00:00.pkl"``. + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pkl"``. + E.g., ``"RegressionModel_2020-01-01_12_00_00.pkl"``. pkl_kwargs Keyword arguments passed to `pickle.dump()` """ diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py index ecf65c3035..c31648cbc8 100644 --- a/darts/models/forecasting/nbeats.py +++ b/darts/models/forecasting/nbeats.py @@ -622,10 +622,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. 
If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index a46dc94e55..fa5a438fc2 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -558,10 +558,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/nlinear.py b/darts/models/forecasting/nlinear.py index c759ba3649..0dbec335d5 100644 --- a/darts/models/forecasting/nlinear.py +++ b/darts/models/forecasting/nlinear.py @@ -248,10 +248,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. 
model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index d2ca6afc62..0eef89aca4 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -30,6 +30,7 @@ def __init__( self, input_chunk_length: int, output_chunk_length: int, + train_sample_shape: Optional[Tuple] = None, loss_fn: nn.modules.loss._Loss = nn.MSELoss(), torch_metrics: Optional[ Union[torchmetrics.Metric, torchmetrics.MetricCollection] @@ -59,6 +60,9 @@ def __init__( Number of input past time steps per chunk. output_chunk_length Number of output time steps per chunk. + train_sample_shape + Shape of the model's input, used to instantiate model without calling ``fit_from_dataset`` and + perform sanity check on new training/inference datasets used for re-training or prediction. loss_fn PyTorch loss function used for training. This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified. @@ -101,6 +105,9 @@ def __init__( # by default models are deterministic (i.e. 
not probabilistic) self.likelihood = likelihood + # saved in checkpoint to be able to instantiate a model without calling fit_from_dataset + self.train_sample_shape = train_sample_shape + # persist optimiser and LR scheduler parameters self.optimizer_cls = optimizer_cls self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs @@ -383,11 +390,17 @@ def _produce_predict_output(self, x: Tuple): def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # we must save the dtype for correct parameter precision at loading time checkpoint["model_dtype"] = self.dtype + # we must save the shape of the input to be able to instantiate the model without calling fit_from_dataset + checkpoint["train_sample_shape"] = self.train_sample_shape def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # by default our models are initialized as float32. For other dtypes, we need to cast to the correct precision # before parameters are loaded by PyTorch-Lightning dtype = checkpoint["model_dtype"] + self.to_dtype(dtype) + + def to_dtype(self, dtype): + """Cast module precision (float32 by default) to another precision.""" if dtype == torch.float16: self.half() if dtype == torch.float32: diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index 794a6c82ba..5b3b3e39bb 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -281,10 +281,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. 
If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index 15b0fcfdb8..cee1a87c15 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -323,10 +323,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. 
diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index 2e8df7f390..faca1bfdd7 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -763,10 +763,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index dd8a957e7d..b361d4f998 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -154,10 +154,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. 
+ ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. @@ -271,7 +271,7 @@ def __init__( # get model name and work dir if model_name is None: - current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S.%f") + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S") model_name = current_time + "_torch_model_run_" + str(os.getpid()) self.model_name = model_name @@ -415,6 +415,10 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None: "calling `super.__init__(...)`. Do this with `self._extract_pl_module_params(**self.model_params).`", ) + self.pl_module_params["train_sample_shape"] = [ + variate.shape if variate is not None else None + for variate in self.train_sample + ] # the tensors have shape (chunk_length, nr_dimensions) self.model = self._create_model(self.train_sample) self._module_name = self.model.__class__.__name__ @@ -860,7 +864,7 @@ def fit_from_dataset( same_dims, "The dimensionality of the series in the training set do not match the dimensionality" " of the series the model has previously been trained on. " - "Model input/output dimensions = {}, provided input/ouptput dimensions = {}".format( + "Model input/output dimensions = {}, provided input/output dimensions = {}".format( tuple( s.shape[1] if s is not None else None for s in self.train_sample ), @@ -1261,8 +1265,10 @@ def save(self, path: Optional[str] = None) -> None: Parameters ---------- path - Path under which to save the model at its current state. If no path is specified, the model is automatically - saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pt"``. E.g., ``"RNNModel_2020-01-01_12:00:00.pt"``. + Path under which to save the model at its current state. Please avoid paths starting with "last-" or + "best-" to avoid collision with PyTorch Lightning checkpoints. 
If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. + E.g., ``"RNNModel_2020-01-01_12_00_00.pt"``. """ if path is None: # default path @@ -1273,9 +1279,21 @@ def save(self, path: Optional[str] = None) -> None: torch.save(self, f_out) # save the LightningModule checkpoint + path_ptl_ckpt = path + ".ckpt" if self.trainer is not None: - path_ptl_ckpt = path + ".ckpt" self.trainer.save_checkpoint(path_ptl_ckpt) + # TODO: keep track of PyTorch Lightning to see if they implement model checkpoint saving + # without having to call fit/predict/validate/test before + # try to recover original automatic PL checkpoint + elif self.load_ckpt_path: + if os.path.exists(self.load_ckpt_path): + shutil.copy(self.load_ckpt_path, path_ptl_ckpt) + else: + logger.warning( + f"Model was not trained since the last loading and attempt to retrieve PyTorch " + f"Lightning checkpoint {self.load_ckpt_path} was unsuccessful: model was saved " + f"without its weights." + ) @staticmethod def load(path: str, **kwargs) -> "TorchForecastingModel": @@ -1325,6 +1343,12 @@ def load(path: str, **kwargs) -> "TorchForecastingModel": path_ptl_ckpt = path + ".ckpt" if os.path.exists(path_ptl_ckpt): model.model = model._load_from_checkpoint(path_ptl_ckpt, **kwargs) + else: + model._fit_called = False + logger.warning( + f"Model was loaded without weights since no PyTorch LightningModule checkpoint ('.ckpt') could be " + f"found at {path_ptl_ckpt}. Please call `fit()` before calling `predict()`." + ) return model @staticmethod @@ -1372,7 +1396,7 @@ def load_from_checkpoint( Parameters ---------- model_name - The name of the model (used to retrieve the checkpoints folder's name). + The name of the model, used to retrieve the checkpoints folder's name. work_dir Working directory (containing the checkpoints folder). Defaults to current working directory. 
file_name @@ -1409,7 +1433,7 @@ def load_from_checkpoint( model = TorchForecastingModel.load(base_model_path, **kwargs) # load PyTorch LightningModule from checkpoint - # if file_name is None, find most recent file in savepath that is a checkpoint + # if file_name is None, find the path of the best or most recent checkpoint in savepath if file_name is None: file_name = _get_checkpoint_fname(work_dir, model_name, best=best) @@ -1417,6 +1441,8 @@ def load_from_checkpoint( logger.info(f"loading {file_name}") model.model = model._load_from_checkpoint(file_path, **kwargs) + # restore _fit_called attribute, set to False in load() if no .ckpt is found/provided + model._fit_called = True model.load_ckpt_path = file_path return model @@ -1429,6 +1455,152 @@ def _load_from_checkpoint(self, file_path, **kwargs): pl_module_cls = getattr(sys.modules[self._module_path], self._module_name) return pl_module_cls.load_from_checkpoint(file_path, **kwargs) + def load_weights_from_checkpoint( + self, + model_name: str = None, + work_dir: str = None, + file_name: str = None, + best: bool = True, + strict: bool = True, + **kwargs, + ): + """ + Load only the weights from automatically saved checkpoints under '{work_dir}/darts_logs/{model_name}/ + checkpoints/'. This method is used for models that were created with ``save_checkpoints=True`` and + that need to be re-trained or fine-tuned with different optimizer or learning rate scheduler. However, + it can also be used to load weights for inference. + + To resume an interrupted training, please consider using :meth:`load_from_checkpoint() + <TorchForecastingModel.load_from_checkpoint>` which also reloads the trainer, optimizer and + learning rate scheduler states. + + For manually saved models, consider using :meth:`load() <TorchForecastingModel.load>` or + :meth:`load_weights() <TorchForecastingModel.load_weights>` instead. + + Parameters + ---------- + model_name + The name of the model, used to retrieve the checkpoints folder's name. Default: ``self.model_name``. + work_dir + Working directory (containing the checkpoints folder). 
Defaults to current working directory. + file_name + The name of the checkpoint file. If not specified, the checkpoint is selected according to ``best``. + best + If set, will retrieve the best model (according to validation loss) instead of the most recent one. This argument + is ignored when ``file_name`` is given. Default: ``True``. + strict + If set, strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict(). + Default: ``True``. + For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__. + **kwargs + Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a + different device than the one from which it was saved. + For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/torch.load.html>`__. + """ + raise_if( + "weights_only" in kwargs.keys() and kwargs["weights_only"], + "Passing `weights_only=True` to `torch.load` will disrupt this" + " method sanity checks.", + logger, + ) + + # use the name of the model being loaded with the saved weights + if model_name is None: + model_name = self.model_name + + if work_dir is None: + work_dir = os.path.join(os.getcwd(), DEFAULT_DARTS_FOLDER) + + # load PyTorch LightningModule from checkpoint + # if file_name is None, find the path of the best or most recent checkpoint in savepath + if file_name is None: + file_name = _get_checkpoint_fname(work_dir, model_name, best=best) + + # checkpoints generated by PL, prefix is defined in TorchForecastingModel __init__() + if file_name[:5] == "last-" or file_name[:5] == "best-": + checkpoint_dir = _get_checkpoint_folder(work_dir, model_name) + # manual save + else: + checkpoint_dir = os.getcwd() + + ckpt_path = os.path.join(checkpoint_dir, file_name) + ckpt = torch.load(ckpt_path, **kwargs) + ckpt_hyper_params = ckpt["hyper_parameters"] + + # verify that the arguments passed to the constructor match those of the checkpoint + for param_key, param_value in self.model_params.items(): + # TODO: there are discrepancies between the param names, for ex 
num_layer/n_rnn_layers + if param_key in ckpt_hyper_params.keys() and param_value is not None: + # some parameters must be converted + if isinstance(ckpt_hyper_params[param_key], list) and not isinstance( + param_value, list + ): + param_value = [param_value] * len(ckpt_hyper_params[param_key]) + + raise_if( + param_value != ckpt_hyper_params[param_key], + f"The values of the hyper parameter {param_key} should be identical between " + f"the instantiated model ({param_value}) and the loaded checkpoint " + f"({ckpt_hyper_params[param_key]}). Please adjust the model accordingly.", + logger, + ) + + # indicate to the user that checkpoints generated with darts <= 0.23.1 are not supported + raise_if_not( + "train_sample_shape" in ckpt.keys(), + "The provided checkpoint was generated with darts release <= 0.23.1" + " and it is missing the 'train_sample_shape' key. This value must" + " be computed from the `model.train_sample` attribute and manually" + " added to the checkpoint prior to loading.", + logger, + ) + + # pl_forecasting module saves the train_sample shape, must recreate one + mock_train_sample = [ + np.zeros(sample_shape) if sample_shape else None + for sample_shape in ckpt["train_sample_shape"] + ] + self.train_sample = tuple(mock_train_sample) + + # instantiate the model without having to call `fit_from_dataset` + self._init_model() + # cast model precision to correct type + self.model.to_dtype(ckpt["model_dtype"]) + # load only the weights from the state dict + self.model.load_state_dict(ckpt["state_dict"], strict=strict) + # update the fit_called attribute to allow for direct inference + self._fit_called = True + + def load_weights(self, path: str, **kwargs): + """ + Loads the weights from a manually saved model (saved with :meth:`save() <TorchForecastingModel.save>`). + + Parameters + ---------- + path + Path from which to load the model's weights. If no path was specified when saving the model, the + automatically generated path ending with ".pt" has to be provided. 
+ **kwargs + Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a + different device than the one from which it was saved. + For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/torch.load.html>`__. + + """ + path_ptl_ckpt = path + ".ckpt" + raise_if_not( + os.path.exists(path_ptl_ckpt), + f"Could not find PyTorch LightningModule checkpoint {path_ptl_ckpt}.", + logger, + ) + + self.load_weights_from_checkpoint( + file_name=path_ptl_ckpt, + **kwargs, + ) + def to_cpu(self): """Updates the PyTorch Lightning Trainer parameters to move the model to CPU the next time :fun:`fit()` or :func:`predict()` is called. diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 7bf9065a29..29addb790e 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -411,10 +411,10 @@ def __init__( Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, - defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., - ``"2021-06-14_09:53:32_torch_model_run_44607"``. + ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. 
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index a2867e7ec1..7f29c5f683 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -7,6 +7,7 @@ from darts import TimeSeries from darts.logging import get_logger +from darts.metrics import mape from darts.tests.base_test_class import DartsBaseTestClass logger = get_logger(__name__) @@ -87,6 +88,7 @@ def test_suppress_automatic_save(self, patch_save_model): def test_manual_save_and_load(self): """validate manual save with automatic save files by comparing output between the two""" + model_dir = os.path.join(self.temp_work_dir) manual_name = "test_save_manual" auto_name = "test_save_automatic" model_manual_save = RNNModel( @@ -110,11 +112,27 @@ def test_manual_save_and_load(self): random_state=42, ) + # save model without training + no_training_ckpt = "no_training.pth.tar" + no_training_ckpt_path = os.path.join(model_dir, no_training_ckpt) + model_manual_save.save(no_training_ckpt_path) + # check that model object file was created + self.assertTrue(os.path.exists(no_training_ckpt_path)) + # check that the PyTorch Lightning ckpt does not exist + self.assertFalse(os.path.exists(no_training_ckpt_path + ".ckpt")) + # informative exception about `fit()` not called + with self.assertRaises( + ValueError, + msg="The model must be fit before calling predict(). 
" + "For global models, if predict() is called without specifying a series, " + "the model must have been fit on a single training series.", + ): + no_train_model = RNNModel.load(no_training_ckpt_path) + no_train_model.predict(n=4) + model_manual_save.fit(self.series, epochs=1) model_auto_save.fit(self.series, epochs=1) - model_dir = os.path.join(self.temp_work_dir) - # check that file was not created with manual save self.assertFalse( os.path.exists(os.path.join(model_dir, manual_name, "checkpoints")) @@ -164,6 +182,37 @@ def test_manual_save_and_load(self): model_manual_save.predict(n=4), model_auto_save1.predict(n=4) ) + # save() model directly after load_from_checkpoint() + checkpoint_file_name_2 = "checkpoint_1.pth.tar" + checkpoint_file_name_cpkt_2 = checkpoint_file_name_2 + ".ckpt" + + model_path_manual_2 = os.path.join( + checkpoint_path_manual, checkpoint_file_name_2 + ) + model_path_manual_ckpt_2 = os.path.join( + checkpoint_path_manual, checkpoint_file_name_cpkt_2 + ) + model_auto_save2 = RNNModel.load_from_checkpoint( + model_name=auto_name, + work_dir=self.temp_work_dir, + best=False, + map_location="cpu", + ) + # save model directly after loading, model has no trainer + model_auto_save2.save(model_path_manual_2) + + # assert original .ckpt checkpoint was correctly copied + self.assertTrue(os.path.exists(model_path_manual_ckpt_2)) + + model_chained_load_save = RNNModel.load( + model_path_manual_2, map_location="cpu" + ) + + # compare chained load_from_checkpoint() save() with manual save + self.assertEqual( + model_chained_load_save.predict(n=4), model_manual_save.predict(n=4) + ) + def test_create_instance_new_model_no_name_set(self): RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir) # no exception is raised @@ -290,6 +339,135 @@ def test_train_from_10_n_epochs_20_fit_15_epochs(self): model1.fit(self.series, epochs=15) self.assertEqual(15, model1.epochs_trained) + def test_load_weights_from_checkpoint(self): + ts_training = self.series[:90] + 
ts_test = self.series[90:] + original_model_name = "original" + retrained_model_name = "retrained" + # original model, checkpoints are saved + model = RNNModel( + 12, + "RNN", + 5, + 1, + n_epochs=5, + work_dir=self.temp_work_dir, + save_checkpoints=True, + model_name=original_model_name, + random_state=1, + ) + model.fit(ts_training) + original_preds = model.predict(10) + original_mape = mape(original_preds, ts_test) + + # load last checkpoint of original model, train it for 2 additional epochs + model_rt = RNNModel( + 12, + "RNN", + 5, + 1, + n_epochs=5, + work_dir=self.temp_work_dir, + model_name=retrained_model_name, + random_state=1, + ) + model_rt.load_weights_from_checkpoint( + model_name=original_model_name, work_dir=self.temp_work_dir, best=False + ) + + # must indicate series otherwise self.training_series must be saved in checkpoint + loaded_preds = model_rt.predict(10, ts_training) + # save/load checkpoint should produce identical predictions + self.assertEqual(original_preds, loaded_preds) + + model_rt.fit(ts_training) + retrained_preds = model_rt.predict(10) + retrained_mape = mape(retrained_preds, ts_test) + self.assertTrue( + retrained_mape < original_mape, + f"Retrained model has a greater error (mape) than the original model, " + f"respectively {retrained_mape} and {original_mape}", + ) + + # raise Exception when trying to load ckpt weights in different architecture + with self.assertRaises(ValueError): + model_rt = RNNModel( + 12, + "RNN", + 10, # loaded model has only 5 hidden_layers + 5, + ) + model_rt.load_weights_from_checkpoint( + model_name=original_model_name, + work_dir=self.temp_work_dir, + best=False, + ) + + # raise Exception when trying to pass `weights_only`=True to `torch.load()` + with self.assertRaises(ValueError): + model_rt = RNNModel( + 12, + "RNN", + 5, + 5, + ) + model_rt.load_weights_from_checkpoint( + model_name=original_model_name, + work_dir=self.temp_work_dir, + best=False, + weights_only=True, + ) + + def 
test_load_weights(self): + ts_training = self.series[:90] + ts_test = self.series[90:] + original_model_name = "original" + retrained_model_name = "retrained" + # original model, automatic checkpoints are disabled (the model is saved manually below) + model = RNNModel( + 12, + "RNN", + 5, + 1, + n_epochs=5, + work_dir=self.temp_work_dir, + save_checkpoints=False, + model_name=original_model_name, + random_state=1, + ) + model.fit(ts_training) + path_manual_save = os.path.join(self.temp_work_dir, "RNN_manual_save.pt") + model.save(path_manual_save) + original_preds = model.predict(10) + original_mape = mape(original_preds, ts_test) + + # load the weights of the manually saved model into a new model, then continue training + model_rt = RNNModel( + 12, + "RNN", + 5, + 1, + n_epochs=5, + work_dir=self.temp_work_dir, + model_name=retrained_model_name, + random_state=1, + ) + model_rt.load_weights(path=path_manual_save) + + # must indicate series otherwise self.training_series must be saved in checkpoint + loaded_preds = model_rt.predict(10, ts_training) + # save/load checkpoint should produce identical predictions + self.assertEqual(original_preds, loaded_preds) + + model_rt.fit(ts_training) + retrained_preds = model_rt.predict(10) + retrained_mape = mape(retrained_preds, ts_test) + self.assertTrue( + retrained_mape < original_mape, + f"Retrained model has a greater mape error than the original model, " + f"respectively {retrained_mape} and {original_mape}", + ) + def test_optimizers(self): optimizers = [