Skip to content

Commit

Permalink
[FEAT] StatsForecast MSTL forecasting model (#3338)
Browse files Browse the repository at this point in the history
* feat: add mstl model

* feat: update mstl usage example

* SDK : Checkout files from main

* SDK : update forecasting map

* feat: reformat forecast controller black

* fix: add pylint exp to preserve informative warning

* fix: merge formatting

Co-authored-by: Jeroen Bouma <jer.bouma@gmail.com>
Co-authored-by: Chavithra PARANA <chavithra@gmail.com>
Co-authored-by: martinb-bb <105685594+martinb-bb@users.noreply.github.com>
  • Loading branch information
4 people authored Nov 14, 2022
1 parent d2b49b3 commit 37c943f
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 0 deletions.
59 changes: 59 additions & 0 deletions openbb_terminal/forecast/forecast_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
autoarima_view,
autoces_view,
autoets_view,
mstl_view,
rwd_view,
seasonalnaive_view,
expo_model,
Expand Down Expand Up @@ -119,6 +120,7 @@ class ForecastController(BaseController):
"autoarima",
"autoces",
"autoets",
"mstl",
"rwd",
"seasonalnaive",
"expo",
Expand Down Expand Up @@ -262,6 +264,7 @@ def update_runtime_choices(self):
"autoarima",
"autoces",
"autoets",
"mstl",
"rwd",
"seasonalnaive",
"expo",
Expand Down Expand Up @@ -346,6 +349,7 @@ def print_help(self):
mt.add_cmd("autoarima", self.files)
mt.add_cmd("autoces", self.files)
mt.add_cmd("autoets", self.files)
mt.add_cmd("mstl", self.files)
mt.add_cmd("rwd", self.files)
mt.add_cmd("seasonalnaive", self.files)
mt.add_cmd("expo", self.files)
Expand Down Expand Up @@ -1926,6 +1930,61 @@ def call_autoets(self, other_args: List[str]):
export_pred_raw=ns_parser.export_pred_raw,
)

# MSTL Model
@log_start_end(log=logger)
def call_mstl(self, other_args: List[str]):
"""Process mstl command"""
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
add_help=False,
prog="mstl",
description="""
Perform Multiple Seasonalities and Trend using Loess (MSTL) forecast:
https://nixtla.github.io/statsforecast/examples/multipleseasonalities.html
""",
)
if other_args and "-" not in other_args[0][0]:
other_args.insert(0, "--target-dataset")

ns_parser = self.parse_known_args_and_warn(
parser,
other_args,
export_allowed=EXPORT_ONLY_FIGURES_ALLOWED,
target_dataset=True,
target_column=True,
n_days=True,
seasonal="A",
periods=True,
window=True,
residuals=True,
forecast_only=True,
start=True,
end=True,
naive=True,
export_pred_raw=True,
)
# TODO Convert this to multi series
if ns_parser:
if not helpers.check_parser_input(ns_parser, self.datasets):
return

mstl_view.display_mstl_forecast(
data=self.datasets[ns_parser.target_dataset],
dataset_name=ns_parser.target_dataset,
n_predict=ns_parser.n_days,
target_column=ns_parser.target_column,
seasonal_periods=ns_parser.seasonal_periods,
start_window=ns_parser.start_window,
forecast_horizon=ns_parser.n_days,
export=ns_parser.export,
residuals=ns_parser.residuals,
forecast_only=ns_parser.forecast_only,
start_date=ns_parser.s_start_date,
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
)

# RWD Model
@log_start_end(log=logger)
def call_rwd(self, other_args: List[str]):
Expand Down
155 changes: 155 additions & 0 deletions openbb_terminal/forecast/mstl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# pylint: disable=too-many-arguments
"""Multiple Seasonalities and Trend using Loess (MSTL) Model"""
__docformat__ = "numpy"

import logging
from typing import Any, Union, Optional, List, Tuple

import warnings
import numpy as np
import pandas as pd
from statsforecast.core import StatsForecast

from openbb_terminal.decorators import log_start_end
from openbb_terminal.rich_config import console
from openbb_terminal.forecast import helpers


warnings.simplefilter("ignore")

logger = logging.getLogger(__name__)

# pylint: disable=E1123


@log_start_end(log=logger)
def get_mstl_data(
data: Union[pd.Series, pd.DataFrame],
target_column: str = "close",
seasonal_periods: int = 7,
n_predict: int = 5,
start_window: float = 0.85,
forecast_horizon: int = 5,
) -> Tuple[list[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[float], Any]:

"""Performs MSTL forecasting
This is a wrapper around StatsForecast MSTL;
we refer to this link for the original and more complete documentation of the parameters.
https://nixtla.github.io/statsforecast/models.html#mstl
Parameters
----------
data : Union[pd.Series, np.ndarray]
Input data.
target_column (str, optional):
Target column to forecast. Defaults to "close".
seasonal_periods: int
Number of seasonal periods in a year (7 for daily data)
If not set, inferred from frequency of the series.
n_predict: int
Number of days to forecast
start_window: float
Size of sliding window from start of timeseries and onwards
forecast_horizon: int
Number of days to forecast when backtesting and retraining historical
Returns
-------
list[float]
Adjusted Data series
list[float]
List of historical fcast values
list[float]
List of predicted fcast values
Optional[float]
precision
Any
Fit MSTL model object.
"""

use_scalers = False
# statsforecast preprocessing
# when including more time series
# the preprocessing is similar
_, ticker_series = helpers.get_series(data, target_column, is_scaler=use_scalers)
freq = ticker_series.freq_str
ticker_series = ticker_series.pd_dataframe().reset_index()
ticker_series.columns = ["ds", "y"]
ticker_series.insert(0, "unique_id", target_column)
# check MSLT availability
try:
from statsforecast.models import MSTL # pylint: disable=import-outside-toplevel
except Exception as e:
error = str(e)
if "cannot import name" in error:
console.print(
"[red]Please update statsforecast to version 1.2.0 or higher.[/red]"
)
else:
console.print(f"[red]{error}[/red]")
return [], [], [], None, None

try:
# Model Init
model = MSTL(
season_length=int(seasonal_periods),
)
fcst = StatsForecast(df=ticker_series, models=[model], freq=freq, verbose=True)
except Exception as e: # noqa
error = str(e)
if "got an unexpected keyword argument" in error:
console.print(
"[red]Please update statsforecast to version 1.1.3 or higher.[/red]"
)
else:
console.print(f"[red]{error}[/red]")
return [], [], [], None, None

# Historical backtesting
last_training_point = int((len(ticker_series) - 1) * start_window)
historical_fcast = fcst.cross_validation(
h=int(forecast_horizon),
test_size=len(ticker_series) - last_training_point,
n_windows=None,
input_size=min(10 * forecast_horizon, len(ticker_series)),
)

# train new model on entire timeseries to provide best current forecast
# we have the historical fcast, now lets predict.
forecast = fcst.forecast(int(n_predict))
y_true = historical_fcast["y"].values
y_hat = historical_fcast["MSTL"].values
precision = helpers.mean_absolute_percentage_error(y_true, y_hat)
console.print(f"MSTL obtains MAPE: {precision:.2f}% \n")

# transform outputs to make them compatible with
# plots
use_scalers = False
_, ticker_series = helpers.get_series(
ticker_series.rename(columns={"y": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)
_, forecast = helpers.get_series(
forecast.rename(columns={"MSTL": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)
_, historical_fcast = helpers.get_series(
historical_fcast.groupby("ds").head(1).rename(columns={"MSTL": target_column}),
target_column,
is_scaler=use_scalers,
time_col="ds",
)

return (
ticker_series,
historical_fcast,
forecast,
precision,
fcst,
)
117 changes: 117 additions & 0 deletions openbb_terminal/forecast/mstl_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Multiple Seasonalities and Trend using Loess (MSTL) View"""
__docformat__ = "numpy"

import logging
from typing import Union, Optional, List
from datetime import datetime

import pandas as pd
import matplotlib.pyplot as plt

from openbb_terminal.forecast import mstl_model
from openbb_terminal.decorators import log_start_end
from openbb_terminal.forecast import helpers

logger = logging.getLogger(__name__)
# pylint: disable=too-many-arguments


@log_start_end(log=logger)
def display_mstl_forecast(
data: Union[pd.DataFrame, pd.Series],
target_column: str = "close",
dataset_name: str = "",
seasonal_periods: int = 7,
n_predict: int = 5,
start_window: float = 0.85,
forecast_horizon: int = 5,
export: str = "",
residuals: bool = False,
forecast_only: bool = False,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
naive: bool = False,
export_pred_raw: bool = False,
external_axes: Optional[List[plt.axes]] = None,
):
"""Display MSTL Model
Parameters
----------
data : Union[pd.Series, np.array]
Data to forecast
dataset_name str
The name of the ticker to be predicted
target_column (str, optional):
Target column to forecast. Defaults to "close".
seasonal_periods: int
Number of seasonal periods in a year
If not set, inferred from frequency of the series.
n_predict: int
Number of days to forecast
start_window: float
Size of sliding window from start of timeseries and onwards
forecast_horizon: int
Number of days to forecast when backtesting and retraining historical
export: str
Format to export data
residuals: bool
Whether to show residuals for the model. Defaults to False.
forecast_only: bool
Whether to only show dates in the forecasting range. Defaults to False.
start_date: Optional[datetime]
The starting date to perform analysis, data before this is trimmed. Defaults to None.
end_date: Optional[datetime]
The ending date to perform analysis, data after this is trimmed. Defaults to None.
naive: bool
Whether to show the naive baseline. This just assumes the closing price will be the same
as the previous day's closing price. Defaults to False.
external_axes:Optional[List[plt.axes]]
External axes to plot on
"""
data = helpers.clean_data(data, start_date, end_date, target_column, None)
if not helpers.check_data(data, target_column, None):
return

(
ticker_series,
historical_fcast,
predicted_values,
precision,
_model,
) = mstl_model.get_mstl_data(
data=data,
target_column=target_column,
seasonal_periods=seasonal_periods,
n_predict=n_predict,
start_window=start_window,
forecast_horizon=forecast_horizon,
)

if ticker_series == []:
return

probabilistic = False
helpers.plot_forecast(
name="MSTL",
target_col=target_column,
historical_fcast=historical_fcast,
predicted_values=predicted_values,
ticker_series=ticker_series,
ticker_name=dataset_name,
data=data,
n_predict=n_predict,
forecast_horizon=forecast_horizon,
past_covariates=None,
precision=precision,
probabilistic=probabilistic,
export=export,
forecast_only=forecast_only,
naive=naive,
export_pred_raw=export_pred_raw,
external_axes=external_axes,
)
if residuals:
helpers.plot_residuals(
_model, None, ticker_series, forecast_horizon=forecast_horizon
)
1 change: 1 addition & 0 deletions openbb_terminal/miscellaneous/i18n/en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ en:
forecast/autoselect: Select best statistical model from AutoARIMA, AutoETS, AutoCES, MSTL, etc.
forecast/autoces: Automatic Complex Exponential Smoothing Model
forecast/autoets: Automatic ETS (Error, Trend, Seasonality) Model
forecast/mstl: Multiple Seasonalities and Trend using Loess (MSTL) Model
forecast/rwd: Random Walk with Drift Model
forecast/arima: Arima (Non-darts)
forecast/expo: Probabilistic Exponential Smoothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ forecast.tcn,openbb_terminal.forecast.tcn_model.get_tcn_data,openbb_terminal.for
forecast.trans,openbb_terminal.forecast.trans_model.get_trans_data,openbb_terminal.forecast.trans_view.display_trans_forecast
forecast.tft,openbb_terminal.forecast.tft_model.get_tft_data,openbb_terminal.forecast.tft_view.display_tft_forecast
forecast.nhits,openbb_terminal.forecast.nhits_model.get_nhits_data,openbb_terminal.forecast.nhits_view.display_nhits_forecast
forecast.mstl,openbb_terminal.forecast.mstl_model.get_mstl_data,openbb_terminal.forecast.mstl_view.display_mstl_forecast
forecast.autoarima,openbb_terminal.forecast.autoarima_model.get_autoarima_data,openbb_terminal.forecast.autoarima_view.display_autoarima_forecast
forecast.rwd,openbb_terminal.forecast.rwd_model.get_rwd_data,openbb_terminal.forecast.rwd_view.display_rwd_forecast
forecast.autoselect,openbb_terminal.forecast.autoselect_model.get_autoselect_data,openbb_terminal.forecast.autoselect_view.display_autoselect_forecast
11 changes: 11 additions & 0 deletions tests/openbb_terminal/forecast/test_mstl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from tests.openbb_terminal.forecast import conftest

try:
from openbb_terminal.forecast import mstl_model
except ImportError:
pytest.skip(allow_module_level=True)


def test_get_mstl_model(tsla_csv):
conftest.test_model(mstl_model.get_mstl_data, tsla_csv)
Loading

0 comments on commit 37c943f

Please sign in to comment.