Skip to content

Add SaveEnsembleMixin #1046

Merged
merged 6 commits into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
-
- Add `SaveModelPipelineMixin`, add `load`, add saving for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036))
- Add `SaveModelPipelineMixin`, add `load`, add saving and loading for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036))
- Add `SaveMixin` to models and transforms ([#1007](https://github.com/tinkoff-ai/etna/pull/1007))
- Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988))
- Add `experimental` module with `TimeSeriesBinaryClassifier` and `PredictabilityAnalyzer` ([#985](https://github.com/tinkoff-ai/etna/pull/985))
- Inference track results: add `predict` method to pipelines, teach some models to work with context, change hierarchy of base models, update notebook examples ([#979](https://github.com/tinkoff-ai/etna/pull/979))
- Add `get_ruptures_regularization` into `experimental` module ([#1001](https://github.com/tinkoff-ai/etna/pull/1001))
-
- Add `SaveEnsembleMixin`, add saving and loading for `VotingEnsemble`, `StackingEnsemble` and `DirectEnsemble` ([#1046](https://github.com/tinkoff-ai/etna/pull/1046))
### Changed
-
- Change returned model in get_model of BATSModel, TBATSModel ([#987](https://github.com/tinkoff-ai/etna/pull/987))
Expand Down
6 changes: 4 additions & 2 deletions etna/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from hydra_slayer import get_factory


def load(path: pathlib.Path) -> Any:
def load(path: pathlib.Path, **kwargs: Any) -> Any:
"""Load saved object by path.

Parameters
----------
path:
Path to load object from.
kwargs:
Parameters for loading specific for the loaded object.

Returns
-------
Expand All @@ -33,7 +35,7 @@ def load(path: pathlib.Path) -> Any:

# create object for that class
object_class = get_factory(object_class_name)
loaded_object = object_class.load(path=path)
loaded_object = object_class.load(path=path, **kwargs)

return loaded_object

Expand Down
2 changes: 1 addition & 1 deletion etna/ensembles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from etna.ensembles.base import EnsembleMixin
from etna.ensembles.direct_ensemble import DirectEnsemble
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.stacking_ensemble import StackingEnsemble
from etna.ensembles.voting_ensemble import VotingEnsemble
5 changes: 3 additions & 2 deletions etna/ensembles/direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from joblib import delayed

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.pipeline.base import BasePipeline


class DirectEnsemble(EnsembleMixin, BasePipeline):
class DirectEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""DirectEnsemble is a pipeline that forecasts future values merging the forecasts of base pipelines.

Ensemble expects several pipelines during init. These pipelines are expected to have different forecasting horizons.
Expand Down
76 changes: 73 additions & 3 deletions etna/ensembles/base.py → etna/ensembles/mixins.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import pathlib
import tempfile
import zipfile
from copy import deepcopy
from typing import Any
from typing import List
from typing import Optional

import pandas as pd

from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.loggers import tslogger
from etna.pipeline.base import BasePipeline
Expand Down Expand Up @@ -56,6 +61,22 @@ def _predict_pipeline(
tslogger.log(msg=f"Prediction is done with {pipeline}.")
return prediction


class SaveEnsembleMixin(SaveMixin):
    """Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines.

    It saves object to the zip archive with 3 entities:

    * metadata.json: contains library version and class name.

    * object.pkl: pickled without pipelines and ts.

    * pipelines: folder with saved pipelines.
    """

    pipelines: List[BasePipeline]
    ts: Optional[TSDataset]

    def save(self, path: pathlib.Path):
        """Save the object.

        Parameters
        ----------
        path:
            Path to save object to.
        """
        pipelines = self.pipelines
        ts = self.ts
        try:
            # extract attributes we can't easily save
            delattr(self, "pipelines")
            delattr(self, "ts")

            # save the remaining part
            super().save(path=path)
        finally:
            # always restore the detached attributes, even if saving failed
            self.pipelines = pipelines
            self.ts = ts

        with zipfile.ZipFile(path, "a") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)

                # save pipelines separately
                pipelines_dir = temp_dir / "pipelines"
                pipelines_dir.mkdir()
                # fixed-width, zero-padded names keep lexicographic order equal to
                # the original pipeline order; ``load`` relies on it via ``sorted``
                num_digits = 8
                for i, pipeline in enumerate(pipelines):
                    save_name = f"{i:0{num_digits}d}.zip"
                    pipeline_save_path = pipelines_dir / save_name
                    pipeline.save(pipeline_save_path)
                    archive.write(pipeline_save_path, f"pipelines/{save_name}")

    @classmethod
    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:
        """Load an object.

        Parameters
        ----------
        path:
            Path to load object from.
        ts:
            TSDataset to set into loaded pipeline.

        Returns
        -------
        :
            Loaded object.
        """
        obj = super().load(path=path)
        obj.ts = deepcopy(ts)

        with zipfile.ZipFile(path, "r") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)

                archive.extractall(temp_dir)

                # load pipelines in the order they were saved; a dedicated loop
                # name avoids shadowing the ``path`` argument of this method
                pipelines_dir = temp_dir / "pipelines"
                pipelines = []
                for pipeline_path in sorted(pipelines_dir.iterdir()):
                    pipelines.append(load(pipeline_path, ts=ts))

                obj.pipelines = pipelines

        return obj
5 changes: 3 additions & 2 deletions etna/ensembles/stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from typing_extensions import Literal

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.pipeline.base import BasePipeline


class StackingEnsemble(EnsembleMixin, BasePipeline):
class StackingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""StackingEnsemble is a pipeline that forecasts the future using a metamodel to combine the forecasts of the base models.

Examples
Expand Down
5 changes: 3 additions & 2 deletions etna/ensembles/voting_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

from etna.analysis.feature_relevance.relevance_table import TreeBasedRegressor
from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.pipeline.base import BasePipeline


class VotingEnsemble(EnsembleMixin, BasePipeline):
class VotingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""VotingEnsemble is a pipeline that forecasts future values with weighted averaging of its pipelines' forecasts.

Examples
Expand Down
9 changes: 7 additions & 2 deletions etna/pipeline/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _predict(
class SaveModelPipelineMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` abstract class for pipelines with model inside.

It saves object to the zip archive with 2 files:
It saves object to the zip archive with 4 entities:

* metadata.json: contains library version and class name.

Expand All @@ -118,6 +118,7 @@ def save(self, path: pathlib.Path):
model = self.model
transforms = self.transforms
ts = self.ts

try:
# extract attributes we can't easily save
delattr(self, "model")
Expand All @@ -143,7 +144,7 @@ def save(self, path: pathlib.Path):
# save transforms separately
transforms_dir = temp_dir / "transforms"
transforms_dir.mkdir()
num_digits = len(str(len(transforms) - 1))
num_digits = 8
for i, transform in enumerate(transforms):
save_name = f"{i:0{num_digits}d}.zip"
transform_save_path = transforms_dir / save_name
Expand Down Expand Up @@ -189,4 +190,8 @@ def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:

obj.transforms = transforms

# set transforms in ts
if obj.ts is not None:
obj.ts.transforms = transforms

return obj
19 changes: 19 additions & 0 deletions tests/test_core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pathlib
import tempfile

import pandas as pd
import pytest

from etna.core import load
from etna.models import NaiveModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform


Expand All @@ -21,6 +24,22 @@ def test_load_ok():
transform.save(save_path)

new_transform = load(save_path)

assert type(new_transform) == type(transform)
for attribute in ["in_column", "value", "inplace"]:
assert getattr(new_transform, attribute) == getattr(transform, attribute)


def test_load_ok_with_params(example_tsds):
    """Check that keyword arguments given to ``load`` are forwarded to the loaded object.

    A fitted ``Pipeline`` is saved to a temporary zip archive and loaded back
    with ``ts=example_tsds``; the loaded pipeline must be of the same class and
    carry a dataset equal to the one passed in.
    """
    pipeline = Pipeline(model=NaiveModel(), horizon=7)
    with tempfile.TemporaryDirectory() as _temp_dir:
        temp_dir = pathlib.Path(_temp_dir)
        save_path = temp_dir / "pipeline.zip"
        pipeline.fit(ts=example_tsds)
        pipeline.save(save_path)

        new_pipeline = load(save_path, ts=example_tsds)

        assert new_pipeline.ts is not None
        # exact-class check: identity comparison is the idiomatic form
        assert type(new_pipeline) is type(pipeline)
        pd.testing.assert_frame_equal(new_pipeline.ts.to_pandas(), example_tsds.to_pandas())
40 changes: 22 additions & 18 deletions tests/test_ensembles/test_direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
from etna.ensembles import DirectEnsemble
from etna.models import NaiveModel
from etna.pipeline import Pipeline
from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original


@pytest.fixture
def direct_ensemble_pipeline() -> DirectEnsemble:
    """Build an unfitted ``DirectEnsemble`` of two naive pipelines with horizons 1 and 2."""
    short_pipeline = Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1)
    long_pipeline = Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2)
    return DirectEnsemble(pipelines=[short_pipeline, long_pipeline])


@pytest.fixture
Expand Down Expand Up @@ -36,32 +48,24 @@ def test_get_horizon_raise_error_on_same_horizons():
_ = DirectEnsemble(pipelines=[Mock(horizon=1), Mock(horizon=1)])


def test_forecast(simple_ts_train, simple_ts_forecast):
ensemble = DirectEnsemble(
pipelines=[
Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1),
Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2),
]
)
ensemble.fit(simple_ts_train)
forecast = ensemble.forecast()
def test_forecast(direct_ensemble_pipeline, simple_ts_train, simple_ts_forecast):
direct_ensemble_pipeline.fit(simple_ts_train)
forecast = direct_ensemble_pipeline.forecast()
pd.testing.assert_frame_equal(forecast.to_pandas(), simple_ts_forecast.to_pandas())


def test_predict(simple_ts_train):
ensemble = DirectEnsemble(
pipelines=[
Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1),
Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2),
]
)
def test_predict(direct_ensemble_pipeline, simple_ts_train):
smallest_pipeline = Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1)
ensemble.fit(simple_ts_train)
direct_ensemble_pipeline.fit(simple_ts_train)
smallest_pipeline.fit(simple_ts_train)
prediction = ensemble.predict(
prediction = direct_ensemble_pipeline.predict(
ts=simple_ts_train, start_timestamp=simple_ts_train.index[1], end_timestamp=simple_ts_train.index[2]
)
expected_prediction = smallest_pipeline.predict(
ts=simple_ts_train, start_timestamp=simple_ts_train.index[1], end_timestamp=simple_ts_train.index[2]
)
pd.testing.assert_frame_equal(prediction.to_pandas(), expected_prediction.to_pandas())


def test_save_load(direct_ensemble_pipeline, example_tsds):
    """Run the shared ``assert_pipeline_equals_loaded_original`` check on the direct ensemble."""
    assert_pipeline_equals_loaded_original(
        ts=example_tsds,
        pipeline=direct_ensemble_pipeline,
    )
24 changes: 0 additions & 24 deletions tests/test_ensembles/test_ensemble_mixin.py

This file was deleted.

Loading