Skip to content

Commit

Permalink
Fix inference tests on new segments for DeepARModel and TFTModel (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman committed Feb 14, 2023
1 parent 38bf668 commit ccfff31
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Fixed
-
- Fix inference tests on new segments for `DeepARModel` and `TFTModel` ([#1109](https://github.com/tinkoff-ai/etna/pull/1109))
- Fix `MeanSegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1104](https://github.com/tinkoff-ai/etna/pull/1104))
-
- Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
Expand Down
14 changes: 4 additions & 10 deletions tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from pandas.util.testing import assert_frame_equal
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data import NaNLabelEncoder
from typing_extensions import get_args

from etna.datasets import TSDataset
Expand Down Expand Up @@ -879,15 +880,6 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
),
],
)
def test_forecast_new_segments(self, model, transforms, example_tsds):
self._test_forecast_new_segments(example_tsds, model, transforms, train_segments=["segment_1"])

@to_be_fixed(raises=KeyError, match="Unknown category")
@pytest.mark.parametrize(
"model, transforms",
[
(
DeepARModel(max_epochs=1, learning_rate=[0.01]),
[
Expand All @@ -896,6 +888,7 @@ def test_forecast_new_segments(self, model, transforms, example_tsds):
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)},
target_normalizer=GroupNormalizer(groups=["segment"]),
)
],
Expand All @@ -909,14 +902,15 @@ def test_forecast_new_segments(self, model, transforms, example_tsds):
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)},
static_categoricals=["segment"],
target_normalizer=None,
)
],
),
],
)
def test_forecast_new_segments_failed_encoding_error(self, model, transforms, example_tsds):
def test_forecast_new_segments(self, model, transforms, example_tsds):
self._test_forecast_new_segments(example_tsds, model, transforms, train_segments=["segment_1"])

@to_be_fixed(raises=NotImplementedError, match="Per-segment models can't make predictions on new segments")
Expand Down
14 changes: 4 additions & 10 deletions tests/test_models/test_inference/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from pandas.util.testing import assert_frame_equal
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data import NaNLabelEncoder

from etna.datasets import TSDataset
from etna.models import AutoARIMAModel
Expand Down Expand Up @@ -782,7 +783,7 @@ def _test_predict_new_segments(self, ts, model, transforms, train_segments, num_
def test_predict_new_segments(self, model, transforms, example_tsds):
self._test_predict_new_segments(example_tsds, model, transforms, train_segments=["segment_1"])

@to_be_fixed(raises=KeyError, match="Unknown category")
@to_be_fixed(raises=NotImplementedError, match="Method predict isn't currently implemented")
@pytest.mark.parametrize(
"model, transforms",
[
Expand All @@ -794,6 +795,7 @@ def test_predict_new_segments(self, model, transforms, example_tsds):
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)},
target_normalizer=GroupNormalizer(groups=["segment"]),
)
],
Expand All @@ -807,20 +809,12 @@ def test_predict_new_segments(self, model, transforms, example_tsds):
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
categorical_encoders={"segment": NaNLabelEncoder(add_nan=True, warn=False)},
static_categoricals=["segment"],
target_normalizer=None,
)
],
),
],
)
def test_predict_new_segments_failed_encoding_error(self, model, transforms, example_tsds):
self._test_predict_new_segments(example_tsds, model, transforms, train_segments=["segment_1"])

@to_be_fixed(raises=NotImplementedError, match="Method predict isn't currently implemented")
@pytest.mark.parametrize(
"model, transforms",
[
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
Expand Down

0 comments on commit ccfff31

Please sign in to comment.