Add SaveNNMixin for fixing saving/loading of NNs
d.a.bunin committed Nov 24, 2022
1 parent 40efb0c commit 7d2cf1b
Showing 14 changed files with 217 additions and 111 deletions.
85 changes: 52 additions & 33 deletions etna/core/mixins.py
@@ -4,6 +4,7 @@
import pickle
import sys
import warnings
import zipfile
from enum import Enum
from typing import Any
from typing import Callable
@@ -118,7 +119,7 @@ def get_etna_version() -> Tuple[int, int, int]:


class SaveMixin:
"""Basic implementation of AbstractSaveable abstract class.
"""Basic implementation of ``AbstractSaveable`` abstract class.
It saves the object to a zip archive with 2 files:
@@ -127,6 +128,21 @@ class SaveMixin:
* object.pkl: pickled object.
"""

def _save_metadata(self, archive: zipfile.ZipFile):
full_class_name = f"{inspect.getmodule(self).__name__}.{self.__class__.__name__}" # type: ignore
metadata = {
"etna_version": get_etna_version(),
"class": full_class_name,
}
metadata_str = json.dumps(metadata, indent=2, sort_keys=True)
metadata_bytes = metadata_str.encode("utf-8")
with archive.open("metadata.json", "w") as output_file:
output_file.write(metadata_bytes)

def _save_state(self, archive: zipfile.ZipFile):
with archive.open("object.pkl", "w") as output_file:
pickle.dump(self, output_file)

def save(self, path: pathlib.Path):
"""Save the object.
@@ -135,19 +151,36 @@ def save(self, path: pathlib.Path):
path:
Path to save object to.
"""
with ZipFile(path, "w") as zip_file:
full_class_name = f"{inspect.getmodule(self).__name__}.{self.__class__.__name__}" # type: ignore
metadata = {
"etna_version": get_etna_version(),
"class": full_class_name,
}
metadata_str = json.dumps(metadata, indent=2, sort_keys=True)
metadata_bytes = metadata_str.encode("utf-8")
with zip_file.open("metadata.json", "w") as output_file:
output_file.write(metadata_bytes)

with zip_file.open("object.pkl", "w") as output_file:
pickle.dump(self, output_file)
with ZipFile(path, "w") as archive:
self._save_metadata(archive)
self._save_state(archive)

@classmethod
def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]:
with archive.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
return metadata

@classmethod
def _validate_metadata(cls, metadata: Dict[str, Any]):
current_etna_version = get_etna_version()
saved_etna_version = tuple(metadata["etna_version"])

# warn if the major versions differ or the saved version is newer than the running one
if current_etna_version[0] != saved_etna_version[0] or current_etna_version[:2] < saved_etna_version[:2]:
current_etna_version_str = ".".join([str(x) for x in current_etna_version])
saved_etna_version_str = ".".join([str(x) for x in saved_etna_version])
warnings.warn(
f"The object was saved under etna version {saved_etna_version_str} "
f"but running version is {current_etna_version_str}, this can cause problems with compatibility!"
)

@classmethod
def _load_state(cls, archive: zipfile.ZipFile) -> Any:
with archive.open("object.pkl", "r") as input_file:
return pickle.load(input_file)

@classmethod
def load(cls, path: pathlib.Path) -> Any:
@@ -158,22 +191,8 @@ def load(cls, path: pathlib.Path) -> Any:
path:
Path to load object from.
"""
with ZipFile(path, "r") as zip_file:
with zip_file.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
current_etna_version = get_etna_version()
saved_etna_version = tuple(metadata["etna_version"])

# if major version is different give a warning
if current_etna_version[0] != saved_etna_version[0] or current_etna_version[:2] < saved_etna_version[:2]:
current_etna_version_str = ".".join([str(x) for x in current_etna_version])
saved_etna_version_str = ".".join([str(x) for x in saved_etna_version])
warnings.warn(
f"The object was saved under etna version {saved_etna_version_str} "
f"but running version is {current_etna_version_str}, this can cause problems with compatibility!"
)

with zip_file.open("object.pkl", "r") as input_file:
return pickle.load(input_file)
with ZipFile(path, "r") as archive:
metadata = cls._load_metadata(archive)
cls._validate_metadata(metadata)
obj = cls._load_state(archive)
return obj
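
The main change in ``SaveMixin`` is that ``save``/``load`` are now split into overridable hooks (``_save_metadata``, ``_save_state``, ``_load_metadata``, ``_validate_metadata``, ``_load_state``), so a subclass can change how the object state is serialized without re-implementing the metadata and version-check logic. As a minimal sketch (not part of the commit), a hypothetical subclass could keep metadata.json as-is and store its state as JSON instead of pickle:

import json
import zipfile
from typing import Any

from etna.core.mixins import SaveMixin


class JsonStateMixin(SaveMixin):
    """Hypothetical mixin: keep the metadata handling, store state as JSON."""

    def _save_state(self, archive: zipfile.ZipFile):
        # works only for JSON-serializable attributes; illustration only
        state_bytes = json.dumps(self.__dict__, sort_keys=True).encode("utf-8")
        with archive.open("object.json", "w") as output_file:
            output_file.write(state_bytes)

    @classmethod
    def _load_state(cls, archive: zipfile.ZipFile) -> Any:
        with archive.open("object.json", "r") as input_file:
            state = json.loads(input_file.read().decode("utf-8"))
        obj = cls.__new__(cls)  # bypass __init__, restore attributes directly
        obj.__dict__.update(state)
        return obj
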
4 changes: 3 additions & 1 deletion etna/models/base.py
@@ -20,6 +20,7 @@
from etna.datasets.tsdataset import TSDataset
from etna.loggers import tslogger
from etna.models.decorators import log_decorator
from etna.models.mixins import SaveNNMixin

if SETTINGS.torch_required:
import torch
Expand All @@ -32,6 +33,7 @@
from unittest.mock import Mock

LightningModule = Mock # type: ignore
SaveNNMixin = Mock # type: ignore


class AbstractModel(SaveMixin, AbstractSaveable, ABC, BaseMixin):
@@ -430,7 +432,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore
return loss


class DeepBaseModel(DeepBaseAbstractModel, NonPredictionIntervalContextRequiredAbstractModel):
class DeepBaseModel(DeepBaseAbstractModel, SaveNNMixin, NonPredictionIntervalContextRequiredAbstractModel):
"""Class for partially implemented interfaces for holding deep models."""

def __init__(
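
``DeepBaseModel`` now lists ``SaveNNMixin`` ahead of ``NonPredictionIntervalContextRequiredAbstractModel``, whose ancestry reaches ``SaveMixin`` through ``AbstractModel``, so the MRO resolves the serialization hooks to the torch-based implementations added in etna/models/mixins.py below. A quick sanity check, assuming torch is installed so the real classes (not the Mock fallbacks above) are imported:

from etna.models.base import DeepBaseModel

# Both hooks should come from SaveNNMixin rather than SaveMixin:
print(DeepBaseModel._save_state.__qualname__)  # expected: 'SaveNNMixin._save_state'
print(DeepBaseModel._load_state.__qualname__)  # expected: 'SaveNNMixin._load_state'
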
27 changes: 27 additions & 0 deletions etna/models/mixins.py
@@ -1,3 +1,4 @@
import zipfile
from abc import ABC
from abc import abstractmethod
from copy import deepcopy
@@ -7,9 +8,11 @@
from typing import Optional
from typing import Sequence

import dill
import numpy as np
import pandas as pd

from etna.core.mixins import SaveMixin
from etna.datasets.tsdataset import TSDataset
from etna.models.decorators import log_decorator

@@ -441,3 +444,27 @@ def get_model(self) -> Any:
if not hasattr(self._base_model, "get_model"):
raise NotImplementedError(f"get_model method is not implemented for {self._base_model.__class__.__name__}")
return self._base_model.get_model()


class SaveNNMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` torch related classes.
It saves object to the zip archive with 2 files:
* metadata.json: contains library version and class name.
* object.pt: object saved by ``torch.save``.
"""

def _save_state(self, archive: zipfile.ZipFile):
import torch

with archive.open("object.pt", "w") as output_file:
torch.save(self, output_file, pickle_module=dill)

@classmethod
def _load_state(cls, archive: zipfile.ZipFile) -> Any:
import torch

with archive.open("object.pt", "r") as input_file:
return torch.load(input_file, pickle_module=dill)
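
``SaveNNMixin`` reuses the metadata handling inherited from ``SaveMixin`` and overrides only the state hooks, so the archive layout changes just in the state file name (object.pt instead of object.pkl). Passing ``pickle_module=dill`` to ``torch.save``/``torch.load`` presumably lets the serializer handle objects that plain pickle rejects, such as lambdas in loss or optimizer configuration. A hedged sketch of inspecting an archive written this way ("model.zip" is an illustrative path, not something produced by the commit itself):

import zipfile

with zipfile.ZipFile("model.zip", "r") as archive:
    print(sorted(archive.namelist()))  # expected: ['metadata.json', 'object.pt']
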
3 changes: 2 additions & 1 deletion etna/models/nn/deepar.py
@@ -12,6 +12,7 @@
from etna.loggers import tslogger
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import log_decorator
from etna.models.mixins import SaveNNMixin
from etna.models.nn.utils import _DeepCopyMixin
from etna.transforms import PytorchForecastingTransform

@@ -24,7 +25,7 @@
from pytorch_lightning import LightningModule


class DeepARModel(_DeepCopyMixin, PredictionIntervalContextIgnorantAbstractModel):
class DeepARModel(_DeepCopyMixin, SaveNNMixin, PredictionIntervalContextIgnorantAbstractModel):
"""Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`.
Notes
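
``DeepARModel`` (and ``TFTModel`` below) wrap pytorch_forecasting models and do not inherit ``DeepBaseModel``, so ``SaveNNMixin`` is mixed in here directly. A hedged usage sketch: the constructor arguments are copied from the updated test_save_load in test_deepar.py further down, the file path is illustrative, and fitting is omitted.

import pathlib

from etna.models.nn.deepar import DeepARModel

model = DeepARModel(max_epochs=2, learning_rate=[0.01], gpus=0, batch_size=64)
# ... model.fit(ts) on a TSDataset prepared with PytorchForecastingTransform would go here ...
path = pathlib.Path("deepar.zip")
model.save(path)                       # writes metadata.json + object.pt via SaveNNMixin
loaded_model = DeepARModel.load(path)  # restores the object via torch.load + dill
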
9 changes: 1 addition & 8 deletions etna/models/nn/mlp.py
@@ -146,14 +146,7 @@ def configure_optimizers(self):


class MLPModel(DeepBaseModel):
"""MLPModel.
Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably, because there is a native method :py:meth:`pytorch_lightning.Trainer.save_checkpoint`
that solves problems with using multiple devices.
"""
"""MLPModel."""

def __init__(
self,
9 changes: 1 addition & 8 deletions etna/models/nn/rnn.py
@@ -193,14 +193,7 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":


class RNNModel(DeepBaseModel):
"""RNN based model on LSTM cell.
Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably, because there is a native method :py:meth:`pytorch_lightning.Trainer.save_checkpoint`
that solves problems with using multiple devices.
"""
"""RNN based model on LSTM cell."""

def __init__(
self,
3 changes: 2 additions & 1 deletion etna/models/nn/tft.py
@@ -13,6 +13,7 @@
from etna.loggers import tslogger
from etna.models.base import PredictionIntervalContextIgnorantAbstractModel
from etna.models.base import log_decorator
from etna.models.mixins import SaveNNMixin
from etna.models.nn.utils import _DeepCopyMixin
from etna.transforms import PytorchForecastingTransform

@@ -25,7 +26,7 @@
from pytorch_lightning import LightningModule


class TFTModel(_DeepCopyMixin, PredictionIntervalContextIgnorantAbstractModel):
class TFTModel(_DeepCopyMixin, SaveNNMixin, PredictionIntervalContextIgnorantAbstractModel):
"""Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`.
Notes
90 changes: 43 additions & 47 deletions tests/test_core/test_mixins.py
@@ -1,7 +1,6 @@
import json
import pathlib
import pickle
import tempfile
from unittest.mock import patch
from zipfile import ZipFile

@@ -22,63 +21,60 @@ def test_get_etna_version():
assert len(version) == 3


def test_save_mixin_save():
with tempfile.TemporaryDirectory() as dir_path_str:
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(dir_path_str)
path = dir_path.joinpath("dummy.zip")
def test_save_mixin_save(tmp_path):
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(tmp_path)
path = dir_path.joinpath("dummy.zip")

dummy.save(path)
dummy.save(path)

with ZipFile(path, "r") as zip_file:
files = zip_file.namelist()
assert sorted(files) == ["metadata.json", "object.pkl"]
with ZipFile(path, "r") as zip_file:
files = zip_file.namelist()
assert sorted(files) == ["metadata.json", "object.pkl"]

with zip_file.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
assert sorted(metadata.keys()) == ["class", "etna_version"]
assert metadata["class"] == "tests.test_core.test_mixins.Dummy"
with zip_file.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
assert sorted(metadata.keys()) == ["class", "etna_version"]
assert metadata["class"] == "tests.test_core.test_mixins.Dummy"

with zip_file.open("object.pkl", "r") as input_file:
loaded_dummy = pickle.load(input_file)
assert loaded_dummy.a == dummy.a
assert loaded_dummy.b == dummy.b
with zip_file.open("object.pkl", "r") as input_file:
loaded_dummy = pickle.load(input_file)
assert loaded_dummy.a == dummy.a
assert loaded_dummy.b == dummy.b


def test_save_mixin_load_ok(recwarn):
with tempfile.TemporaryDirectory() as dir_path_str:
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(dir_path_str)
path = dir_path.joinpath("dummy.zip")
def test_save_mixin_load_ok(recwarn, tmp_path):
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(tmp_path)
path = dir_path.joinpath("dummy.zip")

dummy.save(path)
loaded_dummy = Dummy.load(path)
dummy.save(path)
loaded_dummy = Dummy.load(path)

assert loaded_dummy.a == dummy.a
assert loaded_dummy.b == dummy.b
assert len(recwarn) == 0
assert loaded_dummy.a == dummy.a
assert loaded_dummy.b == dummy.b
assert len(recwarn) == 0


@pytest.mark.parametrize(
"save_version, load_version", [((1, 5, 0), (2, 5, 0)), ((2, 5, 0), (1, 5, 0)), ((1, 5, 0), (1, 3, 0))]
)
@patch("etna.core.mixins.get_etna_version")
def test_save_mixin_load_warning(get_version_mock, save_version, load_version):
with tempfile.TemporaryDirectory() as dir_path_str:
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(dir_path_str)
path = dir_path.joinpath("dummy.zip")

get_version_mock.return_value = save_version
dummy.save(path)

save_version_str = ".".join([str(x) for x in save_version])
load_version_str = ".".join([str(x) for x in load_version])
with pytest.warns(
UserWarning,
match=f"The object was saved under etna version {save_version_str} but running version is {load_version_str}",
):
get_version_mock.return_value = load_version
_ = Dummy.load(path)
def test_save_mixin_load_warning(get_version_mock, save_version, load_version, tmp_path):
dummy = Dummy(a=1, b=2)
dir_path = pathlib.Path(tmp_path)
path = dir_path.joinpath("dummy.zip")

get_version_mock.return_value = save_version
dummy.save(path)

save_version_str = ".".join([str(x) for x in save_version])
load_version_str = ".".join([str(x) for x in load_version])
with pytest.warns(
UserWarning,
match=f"The object was saved under etna version {save_version_str} but running version is {load_version_str}",
):
get_version_mock.return_value = load_version
_ = Dummy.load(path)
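
The tests drop tempfile.TemporaryDirectory in favour of pytest's built-in ``tmp_path`` fixture, which injects a unique per-test directory and removes one level of indentation; note that ``tmp_path`` is already a pathlib.Path, so the extra ``pathlib.Path(tmp_path)`` wrapping above is redundant but harmless. A minimal illustration of the fixture (hypothetical test, not from this commit):

def test_writes_file(tmp_path):
    path = tmp_path / "dummy.zip"  # tmp_path is already a pathlib.Path
    path.write_bytes(b"payload")
    assert path.exists()
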
1 change: 0 additions & 1 deletion tests/test_models/nn/test_deepar.py
@@ -172,7 +172,6 @@ def test_prediction_interval_run_infuture(example_tsds):
assert (segment_slice["target_0.975"] - segment_slice["target"] >= 0).all()


@pytest.mark.xfail(reason="Should be fixed in inference-v2.0", strict=True)
def test_save_load(example_tsds):
horizon = 3
model = DeepARModel(max_epochs=2, learning_rate=[0.01], gpus=0, batch_size=64)
1 change: 0 additions & 1 deletion tests/test_models/nn/test_mlp.py
@@ -104,7 +104,6 @@ def test_mlp_layers():
assert repr(model_) == repr(model.mlp)


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
def test_save_load(example_tsds):
horizon = 3
model = MLPModel(
1 change: 0 additions & 1 deletion tests/test_models/nn/test_rnn.py
@@ -75,7 +75,6 @@ def test_context_size(encoder_length):
assert model.context_size == encoder_length


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
def test_save_load(example_tsds):
model = RNNModel(input_size=1, encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=2))
assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)
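
With working serialization in place, the xfail markers on these save/load tests are removed. The shared helper ``assert_model_equals_loaded_original`` is not part of this diff; a hedged sketch of the kind of round-trip it exercises, using the RNNModel constructor from the test above (fitting and forecast comparison omitted):

import pathlib
import tempfile

from etna.models.nn.rnn import RNNModel


def save_load_roundtrip(model):
    # save to a temporary zip archive and load it back with the model's own class
    with tempfile.TemporaryDirectory() as dir_path:
        path = pathlib.Path(dir_path) / "model.zip"
        model.save(path)
        return model.__class__.load(path)


model = RNNModel(input_size=1, encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=2))
loaded_model = save_load_roundtrip(model)
assert type(loaded_model) is RNNModel
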