Skip to content

Fix problem with swapped forecast methods in HierarchicalPipeline #1259

Merged
merged 3 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `tsfresh` into optional dependencies, remove instruction about `pip install tsfresh` ([#1246](https://github.com/tinkoff-ai/etna/pull/1246))
- Fix `DeepARModel` and `TFTModel` to work with changed `prediction_size` ([#1251](https://github.com/tinkoff-ai/etna/pull/1251))
- Fix problems with flake8 B023 ([#1252](https://github.com/tinkoff-ai/etna/pull/1252))
- Fix problem with swapped forecast methods in HierarchicalPipeline ([#1259](https://github.com/tinkoff-ai/etna/pull/1259))

## [2.0.0] - 2023-04-11
### Added
Expand Down
25 changes: 12 additions & 13 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,25 +312,24 @@ def _forecast_prediction_interval(
self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
) -> TSDataset:
"""Add prediction intervals to the forecasts."""
# TODO: fix this: what if during backtest KeyboardInterrupt is raised
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

if self.ts is None:
raise ValueError("Pipeline is not fitted! Fit the Pipeline before calling forecast method.")

# TODO: rework intervals estimation for `BottomUpReconciliator`

with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
try:
# TODO: rework intervals estimation for `BottomUpReconciliator`

source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)

self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
return predictions

return predictions
finally:
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

def save(self, path: pathlib.Path):
"""Save the object.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_pipeline/test_hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,33 @@ def test_backtest_w_exog(product_level_constant_hierarchical_ts_with_exog, recon
np.testing.assert_allclose(metrics["MAE"], 0)


@pytest.mark.parametrize(
"reconciliator",
(
TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"),
TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"),
BottomUpReconciliator(target_level="total", source_level="market"),
),
)
def test_private_forecast_prediction_interval_no_swap_after_error(
product_level_constant_hierarchical_ts_with_exog, reconciliator
):
ts = product_level_constant_hierarchical_ts_with_exog
model = LinearPerSegmentModel()
pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1)
pipeline.backtest = Mock(side_effect=ValueError("Some error"))
forecast_method = pipeline.forecast
raw_forecast_method = pipeline.raw_forecast

pipeline.fit(ts=ts)
with pytest.raises(ValueError, match="Some error"):
_ = pipeline.forecast(prediction_interval=True, n_folds=1, quantiles=[0.025, 0.5, 0.975])

# check that methods aren't swapped
assert pipeline.forecast == forecast_method
assert pipeline.raw_forecast == raw_forecast_method


@pytest.mark.parametrize(
"reconciliator",
(
Expand Down