Skip to content

Forecast start_timestamp for CLI command #1265

Merged
merged 10 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Notebook `forecast_interpretation.ipynb` with forecast decomposition ([#1220](https://github.com/tinkoff-ai/etna/pull/1220))
- Exogenous variables shift transform `ExogShiftTransform`([#1254](https://github.com/tinkoff-ai/etna/pull/1254))
-
- Parameter `start_timestamp` to forecast CLI command ([#1265](https://github.com/tinkoff-ai/etna/pull/1265))
-
### Changed
- Set the default value of `final_model` to `LinearRegression(positive=True)` in the constructor of `StackingEnsemble` ([#1238](https://github.com/tinkoff-ai/etna/pull/1238))
Expand Down
19 changes: 19 additions & 0 deletions docs/source/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ Basic ``forecast`` usage:
[RAW_OUTPUT] by default we return only forecast without features [default: False]
[KNOWN_FUTURE] list of all known_future columns (regressor columns). If not specified then all exog_columns considered known_future [default: None]

**Forecast config parameters**

* :code:`prediction_interval` - whether to estimate prediction interval for forecast.
* :code:`quantiles` - levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval.
* :code:`n_folds` - number of folds to use in the backtest for prediction interval estimation. By default equals to 3.
* :code:`return_components` - whether to estimate forecast components
* :code:`start_timestamp` - timestamp with the starting point of forecast.
Mr-Geekman marked this conversation as resolved.
Show resolved Hide resolved

Setting these parameters is optional.
Further information on arguments could be found in the documentation of :meth:`~etna.pipeline.pipeline.Pipeline.forecast` method.

**How to create config?**

Example of pipeline's config:
Expand All @@ -44,6 +55,14 @@ Example of forecast params config:
quantiles: [0.025, 0.975]
n_folds: 3

Parameter :code:`start_timestamp` could be set similarly:

.. code-block:: yaml

prediction_interval: true
quantiles: [0.025, 0.975]
start_timestamp: "2020-01-12"
Mr-Geekman marked this conversation as resolved.
Show resolved Hide resolved

**How to prepare data?**

Example of dataset with data to forecast:
Expand Down
47 changes: 46 additions & 1 deletion etna/commands/forecast_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,45 @@
from typing_extensions import Literal

from etna.datasets import TSDataset
from etna.models.utils import determine_num_steps
from etna.pipeline import Pipeline

ADDITIONAL_FORECAST_PARAMETERS = {"start_timestamp"}


def get_forecast_call_params(forecast_params: Dict[str, Any]) -> Dict[str, Any]:
"""Select `forecast` arguments from params."""
return {k: v for k, v in forecast_params.items() if k not in ADDITIONAL_FORECAST_PARAMETERS}


def compute_horizon(horizon: int, forecast_params: Dict[str, Any], tsdataset: TSDataset) -> int:
"""Compute new pipeline horizon if `start_timestamp` presented in `forecast_params`."""
if "start_timestamp" in forecast_params:
freq = tsdataset.freq

forecast_start_timestamp = pd.Timestamp(forecast_params["start_timestamp"], freq=freq)
train_end_timestamp = tsdataset.index.max()

if forecast_start_timestamp <= train_end_timestamp:
raise ValueError("Parameter `start_timestamp` should greater than end of training dataset!")

delta = determine_num_steps(
start_timestamp=train_end_timestamp, end_timestamp=forecast_start_timestamp, freq=freq
)

horizon += delta - 1

return horizon


def filter_forecast(forecast_ts: TSDataset, forecast_params: Dict[str, Any]) -> TSDataset:
"""Filter out forecasts before `start_timestamp` if `start_timestamp` presented in `forecast_params`.."""
if "start_timestamp" in forecast_params:
forecast_start_timestamp = pd.Timestamp(forecast_params["start_timestamp"], freq=forecast_ts.freq)
forecast_ts.df = forecast_ts.df.loc[forecast_start_timestamp:, :]

return forecast_ts


def forecast(
config_path: Path = typer.Argument(..., help="path to yaml config with desired pipeline"),
Expand Down Expand Up @@ -83,9 +120,17 @@ def forecast(

tsdataset = TSDataset(df=df_timeseries, freq=freq, df_exog=df_exog, known_future=k_f)

horizon: int = pipeline_configs["horizon"] # type: ignore
horizon = compute_horizon(horizon=horizon, forecast_params=forecast_params, tsdataset=tsdataset)
pipeline_configs["horizon"] = horizon # type: ignore

forecast_call_args = get_forecast_call_params(forecast_params)

pipeline: Pipeline = hydra_slayer.get_from_params(**pipeline_configs)
pipeline.fit(tsdataset)
forecast = pipeline.forecast(**forecast_params)
forecast = pipeline.forecast(**forecast_call_args)

forecast = filter_forecast(forecast_ts=forecast, forecast_params=forecast_params)

flatten = forecast.to_pandas(flatten=True)
if raw_output:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_commands/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,19 @@ def base_forecast_omegaconf_path():
tmp.flush()
yield Path(tmp.name)
tmp.close()


@pytest.fixture
def start_timestamp_forecast_omegaconf_path():
tmp = NamedTemporaryFile("w")
tmp.write(
"""
prediction_interval: true
quantiles: [0.025, 0.975]
n_folds: 3
start_timestamp: "2021-09-10"
"""
)
tmp.flush()
yield Path(tmp.name)
tmp.close()
115 changes: 115 additions & 0 deletions tests/test_commands/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
from subprocess import run
from tempfile import NamedTemporaryFile

import numpy as np
import pandas as pd
import pytest

from etna.commands.forecast_command import compute_horizon
from etna.commands.forecast_command import filter_forecast
from etna.commands.forecast_command import get_forecast_call_params
from etna.datasets import TSDataset


def test_dummy_run_with_exog(base_pipeline_yaml_path, base_timeseries_path, base_timeseries_exog_path):
tmp_output = NamedTemporaryFile("w")
Expand Down Expand Up @@ -100,3 +106,112 @@ def test_forecast_use_exog_correct(
pd.testing.assert_series_equal(
df_output["target"], pd.Series(data=[3.0, 3.0, 3.0], name="target"), check_less_precise=1
)


@pytest.fixture
def ms_tsds():
df = pd.DataFrame(
{
"timestamp": list(pd.date_range("2023-01-01", periods=4, freq="MS")) * 2,
"segment": ["A"] * 4 + ["B"] * 4,
"target": list(3 * np.arange(1, 5)) * 2,
}
)

df = TSDataset.to_dataset(df=df)
ts = TSDataset(df=df, freq="MS")
return ts


@pytest.fixture
def pipeline_dummy_config():
return {"horizon": 3}


@pytest.mark.parametrize(
"params,expected",
(
({"start_timestamp": "2021-09-10"}, {}),
(
{"prediction_interval": True, "n_folds": 3, "start_timestamp": "2021-09-10"},
{"prediction_interval": True, "n_folds": 3},
),
(
{"prediction_interval": True, "n_folds": 3, "quantiles": [0.025, 0.975]},
{"prediction_interval": True, "n_folds": 3, "quantiles": [0.025, 0.975]},
),
),
)
def test_get_forecast_call_params(params, expected):
result = get_forecast_call_params(forecast_params=params)
assert result == expected


@pytest.mark.parametrize("forecast_params", ({"start_timestamp": "2020-04-09"}, {"start_timestamp": "2019-04-10"}))
def test_compute_horizon_error(example_tsds, forecast_params, pipeline_dummy_config):
with pytest.raises(ValueError, match="Parameter `start_timestamp` should greater than end of training dataset!"):
compute_horizon(
horizon=pipeline_dummy_config["horizon"], forecast_params=forecast_params, tsdataset=example_tsds
)


@pytest.mark.parametrize(
"forecast_params,tsdataset_name,expected",
(
({"start_timestamp": "2020-04-10"}, "example_tsds", 3),
({"start_timestamp": "2020-04-12"}, "example_tsds", 5),
({"start_timestamp": "2020-02-01 02:00:00"}, "example_tsdf", 4),
({"start_timestamp": "2023-06-01"}, "ms_tsds", 4),
),
)
def test_compute_horizon(forecast_params, tsdataset_name, expected, request, pipeline_dummy_config):
tsdataset = request.getfixturevalue(tsdataset_name)
result = compute_horizon(
horizon=pipeline_dummy_config["horizon"], forecast_params=forecast_params, tsdataset=tsdataset
)
assert result == expected


@pytest.mark.parametrize(
"forecast_params,expected",
(
({"start_timestamp": "2020-04-06"}, "2020-04-06"),
({}, "2020-01-01"),
),
)
def test_filter_forecast(forecast_params, expected, example_tsds):
result = filter_forecast(forecast_ts=example_tsds, forecast_params=forecast_params)
assert result.df.index.min() == pd.Timestamp(expected)


@pytest.mark.parametrize(
"model_pipeline",
[
"elementary_linear_model_pipeline",
"elementary_boosting_model_pipeline",
],
)
def test_forecast_start_timestamp(
model_pipeline, base_timeseries_path, base_timeseries_exog_path, start_timestamp_forecast_omegaconf_path, request
):
tmp_output = NamedTemporaryFile("w")
tmp_output_path = Path(tmp_output.name)
model_pipeline = request.getfixturevalue(model_pipeline)

run(
[
"etna",
"forecast",
str(model_pipeline),
str(base_timeseries_path),
"D",
str(tmp_output_path),
str(base_timeseries_exog_path),
str(start_timestamp_forecast_omegaconf_path),
]
)
df_output = pd.read_csv(tmp_output_path)

assert len(df_output) == 3 * 2 # 3 predictions for 2 segments
assert df_output["timestamp"].min() == "2021-09-10" # start_timestamp
assert not np.any(df_output.isna().values)