Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PathLike objects to model save() and load() #1754

Merged
merged 10 commits into from
May 21, 2023
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
## [Unreleased](https://github.com/unit8co/darts/tree/master)
[Full Changelog](https://github.com/unit8co/darts/compare/0.24.0...master)

### For users of the library:

**Improved**
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:

Expand Down
31 changes: 25 additions & 6 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import copy
import datetime
import inspect
import io
import os
import pickle
import time
Expand Down Expand Up @@ -1781,7 +1782,9 @@ def model_params(self) -> dict:
def _default_save_path(cls) -> str:
return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"

def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> None:
def save(
self, path: Optional[Union[str, os.PathLike, BinaryIO]] = None, **pkl_kwargs
) -> None:
"""
Saves the model under a given path or file handle.

Expand Down Expand Up @@ -1812,16 +1815,24 @@ def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> Non
# default path
path = self._default_save_path() + ".pkl"

if isinstance(path, str):
if isinstance(path, (str, os.PathLike)):
# save the whole object using pickle
with open(path, "wb") as handle:
pickle.dump(obj=self, file=handle, **pkl_kwargs)
else:
elif isinstance(path, io.BufferedWriter):
# save the whole object using pickle
pickle.dump(obj=self, file=path, **pkl_kwargs)
else:
raise_log(
ValueError(
"Argument 'path' has to be either 'str' or 'PathLike' (for a filepath) "
f"or 'BufferedWriter' (for an already opened file), but was '{path.__class__}'."
),
logger=logger,
)

@staticmethod
def load(path: Union[str, BinaryIO]) -> "ForecastingModel":
def load(path: Union[str, os.PathLike, BinaryIO]) -> "ForecastingModel":
"""
Loads the model from a given path or file handle.

Expand All @@ -1831,7 +1842,7 @@ def load(path: Union[str, BinaryIO]) -> "ForecastingModel":
Path or file handle from which to load the model.
"""

if isinstance(path, str):
if isinstance(path, (str, os.PathLike)):
raise_if_not(
os.path.exists(path),
f"The file {path} doesn't exist",
Expand All @@ -1840,8 +1851,16 @@ def load(path: Union[str, BinaryIO]) -> "ForecastingModel":

with open(path, "rb") as handle:
model = pickle.load(file=handle)
else:
elif isinstance(path, io.BufferedReader):
model = pickle.load(file=path)
else:
raise_log(
ValueError(
"Argument 'path' has to be either 'str' or 'PathLike' (for a filepath) "
f"or 'BufferedReader' (for an already opened file), but was '{path.__class__}'."
),
logger=logger,
)

return model

Expand Down
42 changes: 31 additions & 11 deletions darts/tests/models/forecasting/test_local_forecasting_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
import pathlib
import shutil
import tempfile
from typing import Callable
Expand Down Expand Up @@ -107,7 +108,6 @@


class LocalForecastingModelsTestCase(DartsBaseTestClass):

# forecasting horizon used in runnability tests
forecasting_horizon = 5

Expand Down Expand Up @@ -154,8 +154,9 @@ def test_save_load_model(self):

for model in [ARIMA(1, 1, 1), LinearRegressionModel(lags=12)]:
model_path_str = type(model).__name__
model_path_file = model_path_str + "_file"
model_paths = [model_path_str, model_path_file]
model_path_pathlike = pathlib.Path(model_path_str + "_pathlike")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extending this test to cover invalid read/write paths would be nice as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I put it in a separate test, though, to avoid inflating the original test too much.

model_path_binary = model_path_str + "_binary"
model_paths = [model_path_str, model_path_pathlike, model_path_binary]
full_model_paths = [
os.path.join(self.temp_work_dir, p) for p in model_paths
]
Expand All @@ -166,7 +167,8 @@ def test_save_load_model(self):
# test save
model.save()
model.save(model_path_str)
with open(model_path_file, "wb") as f:
model.save(model_path_pathlike)
with open(model_path_binary, "wb") as f:
model.save(f)

for p in full_model_paths:
Expand All @@ -180,13 +182,19 @@ def test_save_load_model(self):
if p.startswith(type(model).__name__)
]
)
== 3
== len(full_model_paths) + 1
)

# test load
loaded_model_str = type(model).load(model_path_str)
loaded_model_file = type(model).load(model_path_file)
loaded_models = [loaded_model_str, loaded_model_file]
loaded_model_pathlike = type(model).load(model_path_pathlike)
with open(model_path_binary, "rb") as f:
loaded_model_binary = type(model).load(f)
loaded_models = [
loaded_model_str,
loaded_model_pathlike,
loaded_model_binary,
]

for loaded_model in loaded_models:
self.assertEqual(
Expand All @@ -195,6 +203,22 @@ def test_save_load_model(self):

os.chdir(cwd)

def test_save_load_model_invalid_path(self):
# save() and load() must raise an error when given a path of an unsupported type
model = ARIMA(1, 1, 1)
model.fit(self.ts_gaussian)

# Use a byte string as the path: bytes is neither a str nor an os.PathLike,
# nor an already opened binary file handle, so it is not supported
model_path_invalid = b"invalid_path"

# saving to the unsupported path type is expected to raise a ValueError
with pytest.raises(ValueError):
model.save(model_path_invalid)

# loading from the unsupported path type is expected to raise a ValueError
with pytest.raises(ValueError):
type(model).load(model_path_invalid)

def test_models_runnability(self):
for model, _ in models:
if not isinstance(model, RegressionModel):
Expand Down Expand Up @@ -366,7 +390,6 @@ def test_forecast_time_index(self):

@pytest.mark.slow
def test_statsmodels_future_models(self):

# same tests, but VARIMA requires to work on a multivariate target series
UNIVARIATE = "univariate"
MULTIVARIATE = "multivariate"
Expand Down Expand Up @@ -545,7 +568,6 @@ def test_backtest_retrain(
]

for model_cls, retrainable, multivariate, retrain, model_type in params:

if (
not isinstance(retrain, (int, bool, Callable))
or (isinstance(retrain, int) and retrain < 0)
Expand All @@ -556,7 +578,6 @@ def test_backtest_retrain(
_ = model_cls.historical_forecasts(series, retrain=retrain)

else:

if isinstance(retrain, Mock):
# resets patch_retrain_func call_count to 0
retrain.call_count = 0
Expand All @@ -569,7 +590,6 @@ def test_backtest_retrain(
with patch(
predict_method_to_patch, side_effect=series
) as patch_predict_method:

# Set _fit_called attribute to True, otherwise retrain function is never called
model_cls._fit_called = True

Expand Down