Skip to content

Fixes from inference track #1096

Merged
merged 9 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from etna.models.nn.mlp import MLPModel
from etna.models.nn.rnn import RNNModel
from etna.models.nn.tft import TFTModel
from etna.models.nn.utils import PytorchForecastingDatasetBuilder
14 changes: 10 additions & 4 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ def __init__(
if loss is None:
loss = NormalDistributionLoss()

if (encoder_length is None or decoder_length is None) and dataset_builder is not None:

if dataset_builder is not None:
self.encoder_length = dataset_builder.max_encoder_length
self.decoder_length = dataset_builder.max_prediction_length
self.dataset_builder = dataset_builder
elif (encoder_length is not None and decoder_length is not None) and dataset_builder is None:
elif encoder_length is not None and decoder_length is not None:
self.encoder_length = encoder_length
self.decoder_length = decoder_length
self.dataset_builder = PytorchForecastingDatasetBuilder(
Expand Down Expand Up @@ -199,7 +198,11 @@ def forecast(

@log_decorator
def predict(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
"""Make predictions.

Expand All @@ -210,6 +213,9 @@ def predict(
----------
ts:
Dataset with features
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Expand Down
15 changes: 11 additions & 4 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ def __init__(
super().__init__()
if loss is None:
loss = QuantileLoss()
if (encoder_length is None or decoder_length is None) and dataset_builder is not None:

if dataset_builder is not None:
self.encoder_length = dataset_builder.max_encoder_length
self.decoder_length = dataset_builder.max_prediction_length
self.dataset_builder = dataset_builder
elif (encoder_length is not None and decoder_length is not None) and dataset_builder is None:
elif encoder_length is not None and decoder_length is not None:
self.encoder_length = encoder_length
self.decoder_length = decoder_length
self.dataset_builder = PytorchForecastingDatasetBuilder(
Expand All @@ -107,6 +106,7 @@ def __init__(
)
else:
raise ValueError("You should provide either dataset_builder or encoder_length and decoder_length")

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.lr = lr
Expand Down Expand Up @@ -227,7 +227,11 @@ def forecast(

@log_decorator
def predict(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
"""Make predictions.

Expand All @@ -238,6 +242,9 @@ def predict(
----------
ts:
Dataset with features
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Expand Down
28 changes: 27 additions & 1 deletion etna/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_train_dataset(self, ts: TSDataset) -> TimeSeriesDataSet:
return pf_dataset

def create_inference_dataset(self, ts: TSDataset) -> TimeSeriesDataSet:
"""Create train dataset.
"""Create inference dataset.

Parameters
----------
Expand Down Expand Up @@ -236,7 +236,33 @@ def fit(self, ts: TSDataset):
raise ValueError("Trainer or model is None")
return self

def _get_first_prediction_timestamp(self, ts: TSDataset, horizon: int) -> pd.Timestamp:
    """Get the timestamp located ``horizon`` positions from the end of ``ts.index``.

    Parameters
    ----------
    ts:
        Dataset whose index is inspected.
    horizon:
        Offset from the end of the index. It is applied as a negative index,
        so ``horizon=0`` selects the first timestamp of the index.

    Returns
    -------
    :
        Timestamp found at position ``-horizon`` of ``ts.index``.
    """
    timestamps = ts.index
    return timestamps[-horizon]

def _is_in_sample_prediction(self, ts: TSDataset, horizon: int) -> bool:
    """Check whether the prediction starts no later than the last train timestamp.

    Parameters
    ----------
    ts:
        Dataset to make the prediction on.
    horizon:
        Value passed to ``_get_first_prediction_timestamp`` to locate the prediction start.

    Returns
    -------
    :
        True if the first prediction timestamp is less than or equal to
        ``self._last_train_timestamp``.
    """
    prediction_start = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
    is_in_sample = prediction_start <= self._last_train_timestamp
    return is_in_sample

def _is_prediction_with_gap(self, ts: TSDataset, horizon: int) -> bool:
    """Check whether the prediction starts later than the first timestamp after the train data.

    Parameters
    ----------
    ts:
        Dataset to make the prediction on.
    horizon:
        Value passed to ``_get_first_prediction_timestamp`` to locate the prediction start.

    Returns
    -------
    :
        True if the first prediction timestamp is strictly greater than the timestamp
        that follows ``self._last_train_timestamp`` at frequency ``self._freq``.
    """
    prediction_start = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
    # A two-element range starting at the last train timestamp: its second element
    # is the first timestamp after the train data.
    train_end_and_next = pd.date_range(start=self._last_train_timestamp, periods=2, freq=self._freq)
    next_after_train = train_end_and_next[1]
    return prediction_start > next_after_train

def _make_target_prediction(self, ts: TSDataset, horizon: int) -> Tuple[TSDataset, DataLoader]:
if self._is_in_sample_prediction(ts=ts, horizon=horizon):
raise NotImplementedError(
"It is not possible to make in-sample predictions with DeepAR model! "
"In-sample predictions aren't supported by current implementation."
)
elif self._is_prediction_with_gap(ts=ts, horizon=horizon):
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
first_prediction_timestamp = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
raise NotImplementedError(
"You can only forecast from the next point after the last one in the training dataset: "
f"last train timestamp: {self._last_train_timestamp}, first prediction timestamp is {first_prediction_timestamp}"
)
else:
pass

if len(ts.df) != horizon + self.encoder_length:
raise ValueError("Length of dataset must be equal to horizon + max_encoder_length")

Expand Down
44 changes: 44 additions & 0 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pathlib
from copy import deepcopy
from typing import Dict
from typing import List
Expand Down Expand Up @@ -154,3 +155,46 @@ def _forecast_prediction_interval(
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

return predictions

def save(self, path: pathlib.Path):
"""Save the object.

Parameters
----------
path:
Path to save object to.
"""
fit_ts = self._fit_ts

try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you do it via try/except?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What alternative do you suggest? I wanted the action in the finally block to run even if an exception is raised. That is why I did it this way.

# extract attributes we can't easily save
delattr(self, "_fit_ts")

# save the remaining part
super().save(path=path)
finally:
self._fit_ts = fit_ts

@classmethod
def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> "HierarchicalPipeline":
    """Load an object.

    A deep copy of ``ts`` is stored in ``_fit_ts``. The ``ts`` attribute of the
    loaded pipeline is set to the result of ``reconciliator.aggregate`` applied
    to ``ts``, or to ``None`` when no dataset is given.

    Parameters
    ----------
    path:
        Path to load object from.
    ts:
        TSDataset to set into loaded pipeline.

    Returns
    -------
    :
        Loaded object.
    """
    pipeline = super().load(path=path)
    pipeline._fit_ts = deepcopy(ts)
    pipeline.ts = None if ts is None else pipeline.reconciliator.aggregate(ts=ts)
    return pipeline
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset

collect_ignore = ["test_models/test_inference/"]


@pytest.fixture(autouse=True)
def random_seed():
Expand Down
Empty file.
4 changes: 2 additions & 2 deletions tests/test_models/test_inference/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _test_prediction_in_sample_full(ts, model, transforms, method_name):

# forecasting
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.transform(transforms)
prediction_size = len(forecast_ts.index)
forecast_ts = make_prediction(model=model, ts=forecast_ts, prediction_size=prediction_size, method_name=method_name)

Expand All @@ -56,7 +56,7 @@ def _test_prediction_in_sample_suffix(ts, model, transforms, method_name, num_sk

# forecasting
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.transform(transforms)
forecast_ts.df = forecast_ts.df.iloc[(num_skip_points - model.context_size) :]
prediction_size = len(forecast_ts.index) - num_skip_points
forecast_ts = make_prediction(model=model, ts=forecast_ts, prediction_size=prediction_size, method_name=method_name)
Expand Down
Loading