Skip to content

Commit

Permalink
Fix bug with hardcoded frequency in PytorchForecastingTransform (#107)
Browse files Browse the repository at this point in the history
* Fix bug with hardcoded frequency in PytorchForecastingTransform, add notes about position sensitivity of transform, write tests
  • Loading branch information
Mr-Geekman authored Sep 28, 2021
1 parent 1073870 commit 262fa68
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102))
- Fix bug with hardcoded frequency in PytorchForecastingTransform ([#107](https://github.com/tinkoff-ai/etna-ts/pull/107))

## [1.1.1] - 2021-09-23
### Fixed
Expand Down
19 changes: 16 additions & 3 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import log_decorator
from etna.transforms import PytorchForecastingTransform


class DeepARModel(Model):
Expand Down Expand Up @@ -89,6 +90,16 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> DeepAR:
dropout=self.dropout,
)

@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
    """Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
    # The transform must sit at the very end of the pipeline: it is the one
    # that builds the pytorch-forecasting datasets the model consumes.
    transforms = ts.transforms
    if transforms is None or not isinstance(transforms[-1], PytorchForecastingTransform):
        raise ValueError(
            "Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
        )
    return transforms[-1]

@log_decorator
def fit(self, ts: TSDataset) -> "DeepARModel":
"""
Expand All @@ -103,7 +114,8 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
-------
DeepARModel
"""
self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train)
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

self.trainer = pl.Trainer(
logger=tslogger.pl_loggers,
Expand All @@ -113,7 +125,7 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
gradient_clip_val=self.gradient_clip_val,
)

train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)

Expand All @@ -134,7 +146,8 @@ def forecast(self, ts: TSDataset) -> TSDataset:
TSDataset
TSDataset with predictions.
"""
prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader(
pf_transform = self._get_pf_transform(ts)
prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader(
train=False, batch_size=self.batch_size * 2
)

Expand Down
21 changes: 17 additions & 4 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import log_decorator
from etna.transforms import PytorchForecastingTransform


class TFTModel(Model):
Expand Down Expand Up @@ -97,6 +98,16 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> TemporalFusionTransfor
hidden_continuous_size=self.hidden_continuous_size,
)

@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
    """Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
    # The transform must be the last one applied: the model reads the
    # pytorch-forecasting datasets it prepares.
    transforms = ts.transforms
    if transforms is None or not isinstance(transforms[-1], PytorchForecastingTransform):
        raise ValueError(
            "Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
        )
    return transforms[-1]

@log_decorator
def fit(self, ts: TSDataset) -> "TFTModel":
"""
Expand All @@ -111,7 +122,8 @@ def fit(self, ts: TSDataset) -> "TFTModel":
-------
TFTModel
"""
self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train)
pf_transform = self._get_pf_transform(ts)
self.model = self._from_dataset(pf_transform.pf_dataset_train)

self.trainer = pl.Trainer(
logger=tslogger.pl_loggers,
Expand All @@ -121,14 +133,14 @@ def fit(self, ts: TSDataset) -> "TFTModel":
gradient_clip_val=self.gradient_clip_val,
)

train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)

return self

@log_decorator
def forecast(self, ts: TSDataset) -> pd.DataFrame:
def forecast(self, ts: TSDataset) -> TSDataset:
"""
Predict future.
Expand All @@ -142,7 +154,8 @@ def forecast(self, ts: TSDataset) -> pd.DataFrame:
TSDataset
TSDataset with predictions.
"""
prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader(
pf_transform = self._get_pf_transform(ts)
prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader(
train=False, batch_size=self.batch_size * 2
)

Expand Down
55 changes: 34 additions & 21 deletions etna/transforms/pytorch_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(
):
"""Parameters for TimeSeriesDataSet object.
Notes
-----
This transform should be added at the very end of `transforms` parameter.
Reference
---------
https://github.com/jdb78/pytorch-forecasting/blob/v0.8.5/pytorch_forecasting/data/timeseries.py#L117
Expand Down Expand Up @@ -76,6 +80,14 @@ def __init__(
self.lags = lags
self.scalers = scalers

@staticmethod
def _calculate_freq_unit(freq: str) -> pd.Timedelta:
    """Calculate frequency unit by its string representation."""
    # Frequencies like "2D" already embed a multiplier and parse directly;
    # bare aliases like "D" need an explicit count of one.
    return pd.Timedelta(freq) if freq[0].isdigit() else pd.Timedelta(1, unit=freq)

def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform":
"""
Fit TimeSeriesDataSet.
Expand All @@ -89,22 +101,22 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform":
-------
PytorchForecastingTransform
"""
ts = TSDataset(df, "1d")
self.freq = ts.freq
ts = ts.to_pandas(flatten=True)
ts = ts.dropna()
self.min_timestamp = ts.timestamp.min()
self.freq = pd.infer_freq(df.index)
ts = TSDataset(df, self.freq)
df_flat = ts.to_pandas(flatten=True)
df_flat = df_flat.dropna()
self.min_timestamp = df_flat.timestamp.min()

if self.time_varying_known_categoricals:
for feature_name in self.time_varying_known_categoricals:
ts[feature_name] = ts[feature_name].astype(str)
df_flat[feature_name] = df_flat[feature_name].astype(str)

ts["time_idx"] = ts["timestamp"] - self.min_timestamp
ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq)
ts["time_idx"] = ts["time_idx"].astype(int)
freq_unit = self._calculate_freq_unit(self.freq)
df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit
df_flat["time_idx"] = df_flat["time_idx"].astype(int)

pf_dataset = TimeSeriesDataSet(
ts,
df_flat,
time_idx="time_idx",
target="target",
group_ids=["segment"],
Expand Down Expand Up @@ -151,24 +163,25 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
We save TimeSeriesDataSet in instance to use it in the model.
It`s not right pattern of using Transforms and TSDataset.
"""
ts = TSDataset(df, "1d")
ts = ts.to_pandas(flatten=True)
ts = ts[ts.timestamp >= self.min_timestamp]
ts = ts.fillna(0)

ts["time_idx"] = ts["timestamp"] - self.min_timestamp
ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq)
ts["time_idx"] = ts["time_idx"].astype(int)
ts = TSDataset(df, self.freq)
df_flat = ts.to_pandas(flatten=True)
df_flat = df_flat[df_flat.timestamp >= self.min_timestamp]
df_flat = df_flat.fillna(0)

freq_unit = self._calculate_freq_unit(self.freq)
df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit
df_flat["time_idx"] = df_flat["time_idx"].astype(int)

if self.time_varying_known_categoricals:
for feature_name in self.time_varying_known_categoricals:
ts[feature_name] = ts[feature_name].astype(str)
df_flat[feature_name] = df_flat[feature_name].astype(str)

if inspect.stack()[1].function == "make_future":
pf_dataset_predict = TimeSeriesDataSet.from_parameters(
self.pf_dataset_params, ts, predict=True, stop_randomization=True
self.pf_dataset_params, df_flat, predict=True, stop_randomization=True
)
self.pf_dataset_predict = pf_dataset_predict
else:
pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, ts)
pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, df_flat)
self.pf_dataset_train = pf_dataset_train
return df
29 changes: 24 additions & 5 deletions tests/test_models/nn/test_deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,29 @@
from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import DeepARModel
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform


def test_fit_wrong_order_transform(weekly_period_df):
    """Fitting DeepARModel must fail when PytorchForecastingTransform is not the last transform."""
    ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
    pft = PytorchForecastingTransform(
        max_encoder_length=21,
        max_prediction_length=8,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        target_normalizer=GroupNormalizer(groups=["segment"]),
    )
    add_const = AddConstTransform(in_column="target", value=1.0)

    # PytorchForecastingTransform deliberately placed NOT last.
    ts.fit_transform([pft, add_const])

    model = DeepARModel(max_epochs=300, learning_rate=[0.1])
    with pytest.raises(ValueError, match="add PytorchForecastingTransform"):
        model.fit(ts)


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8, 21])
def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):
Expand All @@ -31,8 +50,8 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "1d")
ts_test = TSDataset(TSDataset.to_dataset(test), "1d")
ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False)
pft = PytorchForecastingTransform(
max_encoder_length=21,
Expand All @@ -45,10 +64,10 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):

ts_train.fit_transform([dft, pft])

tftmodel = DeepARModel(max_epochs=300, learning_rate=[0.1])
model = DeepARModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
tftmodel.fit(ts_train)
ts_pred = tftmodel.forecast(ts_pred)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")

Expand Down
31 changes: 26 additions & 5 deletions tests/test_models/nn/test_tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,31 @@
from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import TFTModel
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform


def test_fit_wrong_order_transform(weekly_period_df):
    """Fitting TFTModel must fail when PytorchForecastingTransform is not the last transform."""
    ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
    pft = PytorchForecastingTransform(
        max_encoder_length=21,
        min_encoder_length=21,
        max_prediction_length=8,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["segment"],
        target_normalizer=None,
    )
    add_const = AddConstTransform(in_column="target", value=1.0)

    # PytorchForecastingTransform deliberately placed NOT last.
    ts.fit_transform([pft, add_const])

    model = TFTModel(max_epochs=300, learning_rate=[0.1])
    with pytest.raises(ValueError, match="add PytorchForecastingTransform"):
        model.fit(ts)


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8, 21])
def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):
Expand All @@ -31,8 +52,8 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "1d")
ts_test = TSDataset(TSDataset.to_dataset(test), "1d")
ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False)
pft = PytorchForecastingTransform(
max_encoder_length=21,
Expand All @@ -47,10 +68,10 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):

ts_train.fit_transform([dft, pft])

tftmodel = TFTModel(max_epochs=300, learning_rate=[0.1])
model = TFTModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
tftmodel.fit(ts_train)
ts_pred = tftmodel.forecast(ts_pred)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")
assert mae(ts_test, ts_pred) < 0.24
25 changes: 25 additions & 0 deletions tests/test_transforms/test_pytorch_forecasting_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest

from etna.transforms import PytorchForecastingTransform


@pytest.mark.parametrize("days_offset", [1, 2, 5, 10])
def test_time_idx(days_offset, example_tsds):
    """Check that PytorchForecastingTransform works with different frequencies correctly."""
    df = example_tsds.to_pandas()
    subsampled = df.loc[df.index[::days_offset]]

    transform = PytorchForecastingTransform(
        max_encoder_length=3,
        min_encoder_length=3,
        max_prediction_length=3,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["segment"],
    )
    transform.fit_transform(subsampled)

    # Each segment should get consecutive indices 0..len-1 regardless of the
    # subsampled frequency (this is what the hardcoded-frequency bug broke).
    n_points = subsampled.shape[0]
    expected = list(range(n_points)) * len(example_tsds.segments)
    assert transform.pf_dataset_train.data["time"].tolist() == expected

0 comments on commit 262fa68

Please sign in to comment.