-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] StatsForecast MSTL forecasting model (#3338)
* feat: add mstl model * feat: update mstl usage example * SDK : Checkout files from main * SDK : update forecasting map * feat: reformat forecast controller black * fix: add pylint exp to preserve informative warning * fix: merge formatting Co-authored-by: Jeroen Bouma <jer.bouma@gmail.com> Co-authored-by: Chavithra PARANA <chavithra@gmail.com> Co-authored-by: martinb-bb <105685594+martinb-bb@users.noreply.github.com>
- Loading branch information
1 parent
d2b49b3
commit 37c943f
Showing
9 changed files
with
425 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# pylint: disable=too-many-arguments | ||
"""Multiple Seasonalities and Trend using Loess (MSTL) Model""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
from typing import Any, Union, Optional, List, Tuple | ||
|
||
import warnings | ||
import numpy as np | ||
import pandas as pd | ||
from statsforecast.core import StatsForecast | ||
|
||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.rich_config import console | ||
from openbb_terminal.forecast import helpers | ||
|
||
|
||
warnings.simplefilter("ignore") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# pylint: disable=E1123 | ||
|
||
|
||
@log_start_end(log=logger)
def get_mstl_data(
    data: Union[pd.Series, pd.DataFrame],
    target_column: str = "close",
    seasonal_periods: int = 7,
    n_predict: int = 5,
    start_window: float = 0.85,
    forecast_horizon: int = 5,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[float], Any]:
    """Performs MSTL forecasting

    This is a wrapper around StatsForecast MSTL;
    we refer to this link for the original and more complete documentation of the parameters.

        https://nixtla.github.io/statsforecast/models.html#mstl

    Parameters
    ----------
    data : Union[pd.Series, pd.DataFrame]
        Input data.
    target_column : str, optional
        Target column to forecast. Defaults to "close".
    seasonal_periods : int
        Season length handed to MSTL as ``season_length`` (cast to int).
        Defaults to 7.
    n_predict : int
        Number of periods to forecast past the end of the series.
    start_window : float
        Fraction of the series (from its start) kept as training data before
        the historical backtest begins.
    forecast_horizon : int
        Forecast horizon used for every window of the historical backtest.

    Returns
    -------
    ticker_series
        The input series as returned by ``helpers.get_series``
        (presumably a darts ``TimeSeries`` -- confirm against helpers).
    historical_fcast
        Backtested forecast values, one row per timestamp.
    forecast
        Forecast values for the next ``n_predict`` periods.
    Optional[float]
        MAPE between the backtest predictions and the observed values.
    Any
        The StatsForecast object wrapping the MSTL model.

    If MSTL cannot be imported or the model cannot be initialised, an error is
    printed to the console and ``([], [], [], None, None)`` is returned.
    """

    use_scalers = False
    # Cast once so that every use of the horizon below is an integer.
    horizon = int(forecast_horizon)

    # statsforecast preprocessing
    # when including more time series
    # the preprocessing is similar
    _, ticker_series = helpers.get_series(data, target_column, is_scaler=use_scalers)
    freq = ticker_series.freq_str
    # statsforecast expects a long-format frame: unique_id / ds / y
    ticker_series = ticker_series.pd_dataframe().reset_index()
    ticker_series.columns = ["ds", "y"]
    ticker_series.insert(0, "unique_id", target_column)

    # check MSTL availability
    try:
        from statsforecast.models import MSTL  # pylint: disable=import-outside-toplevel
    except Exception as e:
        error = str(e)
        if "cannot import name" in error:
            console.print(
                "[red]Please update statsforecast to version 1.2.0 or higher.[/red]"
            )
        else:
            console.print(f"[red]{error}[/red]")
        return [], [], [], None, None

    try:
        # Model Init
        model = MSTL(
            season_length=int(seasonal_periods),
        )
        fcst = StatsForecast(df=ticker_series, models=[model], freq=freq, verbose=True)
    except Exception as e:  # noqa
        error = str(e)
        if "got an unexpected keyword argument" in error:
            console.print(
                "[red]Please update statsforecast to version 1.1.3 or higher.[/red]"
            )
        else:
            console.print(f"[red]{error}[/red]")
        return [], [], [], None, None

    # Historical backtesting
    n_obs = len(ticker_series)
    last_training_point = int((n_obs - 1) * start_window)
    historical_fcast = fcst.cross_validation(
        h=horizon,
        test_size=n_obs - last_training_point,
        n_windows=None,
        input_size=min(10 * horizon, n_obs),
    )

    # train new model on entire timeseries to provide best current forecast
    # we have the historical fcast, now lets predict.
    forecast = fcst.forecast(int(n_predict))
    y_true = historical_fcast["y"].values
    y_hat = historical_fcast["MSTL"].values
    precision = helpers.mean_absolute_percentage_error(y_true, y_hat)
    console.print(f"MSTL obtains MAPE: {precision:.2f}% \n")

    # transform outputs to make them compatible with
    # plots
    _, ticker_series = helpers.get_series(
        ticker_series.rename(columns={"y": target_column}),
        target_column,
        is_scaler=use_scalers,
        time_col="ds",
    )
    _, forecast = helpers.get_series(
        forecast.rename(columns={"MSTL": target_column}),
        target_column,
        is_scaler=use_scalers,
        time_col="ds",
    )
    # overlapping backtest windows predict the same timestamp several times;
    # keep only the first prediction for each timestamp
    _, historical_fcast = helpers.get_series(
        historical_fcast.groupby("ds").head(1).rename(columns={"MSTL": target_column}),
        target_column,
        is_scaler=use_scalers,
        time_col="ds",
    )

    return (
        ticker_series,
        historical_fcast,
        forecast,
        precision,
        fcst,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Multiple Seasonalities and Trend using Loess (MSTL) View""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
from typing import Union, Optional, List | ||
from datetime import datetime | ||
|
||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
|
||
from openbb_terminal.forecast import mstl_model | ||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.forecast import helpers | ||
|
||
logger = logging.getLogger(__name__) | ||
# pylint: disable=too-many-arguments | ||
|
||
|
||
@log_start_end(log=logger)
def display_mstl_forecast(
    data: Union[pd.DataFrame, pd.Series],
    target_column: str = "close",
    dataset_name: str = "",
    seasonal_periods: int = 7,
    n_predict: int = 5,
    start_window: float = 0.85,
    forecast_horizon: int = 5,
    export: str = "",
    residuals: bool = False,
    forecast_only: bool = False,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    naive: bool = False,
    export_pred_raw: bool = False,
    external_axes: Optional[List[plt.axes]] = None,
):
    """Fit the MSTL model on the given data and plot its forecast

    Parameters
    ----------
    data : Union[pd.DataFrame, pd.Series]
        Data to forecast
    target_column : str, optional
        Target column to forecast. Defaults to "close".
    dataset_name : str
        The name of the ticker to be predicted
    seasonal_periods : int
        Seasonal period forwarded to the MSTL model. Defaults to 7.
    n_predict : int
        Number of days to forecast
    start_window : float
        Size of sliding window from start of timeseries and onwards
    forecast_horizon : int
        Number of days to forecast when backtesting and retraining historical
    export : str
        Format to export data
    residuals : bool
        Whether to show residuals for the model. Defaults to False.
    forecast_only : bool
        Whether to only show dates in the forecasting range. Defaults to False.
    start_date : Optional[datetime]
        The starting date to perform analysis, data before this is trimmed. Defaults to None.
    end_date : Optional[datetime]
        The ending date to perform analysis, data after this is trimmed. Defaults to None.
    naive : bool
        Whether to show the naive baseline. This just assumes the closing price will be the same
        as the previous day's closing price. Defaults to False.
    export_pred_raw : bool
        Forwarded unchanged to ``helpers.plot_forecast``
        (presumably exports the raw predictions -- confirm against helpers).
        Defaults to False.
    external_axes : Optional[List[plt.axes]]
        External axes to plot on
    """
    data = helpers.clean_data(data, start_date, end_date, target_column, None)
    data_is_usable = helpers.check_data(data, target_column, None)
    if not data_is_usable:
        return

    model_output = mstl_model.get_mstl_data(
        data=data,
        target_column=target_column,
        seasonal_periods=seasonal_periods,
        n_predict=n_predict,
        start_window=start_window,
        forecast_horizon=forecast_horizon,
    )
    series, backtest, prediction, mape, fitted_model = model_output

    # the model layer signals failure by returning empty lists
    if series == []:
        return

    helpers.plot_forecast(
        name="MSTL",
        target_col=target_column,
        historical_fcast=backtest,
        predicted_values=prediction,
        ticker_series=series,
        ticker_name=dataset_name,
        data=data,
        n_predict=n_predict,
        forecast_horizon=forecast_horizon,
        past_covariates=None,
        precision=mape,
        probabilistic=False,
        export=export,
        forecast_only=forecast_only,
        naive=naive,
        export_pred_raw=export_pred_raw,
        external_axes=external_axes,
    )

    if not residuals:
        return
    helpers.plot_residuals(
        fitted_model, None, series, forecast_horizon=forecast_horizon
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import pytest | ||
from tests.openbb_terminal.forecast import conftest | ||
|
||
try: | ||
from openbb_terminal.forecast import mstl_model | ||
except ImportError: | ||
pytest.skip(allow_module_level=True) | ||
|
||
|
||
def test_get_mstl_model(tsla_csv):
    """Run the forecast conftest's ``test_model`` helper on the MSTL model."""
    model_under_test = mstl_model.get_mstl_data
    conftest.test_model(model_under_test, tsla_csv)
Oops, something went wrong.