diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py index 634e488a6d52..4fc793150f97 100644 --- a/openbb_terminal/forecast/forecast_controller.py +++ b/openbb_terminal/forecast/forecast_controller.py @@ -59,6 +59,7 @@ autoarima_view, autoces_view, autoets_view, + mstl_view, rwd_view, seasonalnaive_view, expo_model, @@ -119,6 +120,7 @@ class ForecastController(BaseController): "autoarima", "autoces", "autoets", + "mstl", "rwd", "seasonalnaive", "expo", @@ -262,6 +264,7 @@ def update_runtime_choices(self): "autoarima", "autoces", "autoets", + "mstl", "rwd", "seasonalnaive", "expo", @@ -346,6 +349,7 @@ def print_help(self): mt.add_cmd("autoarima", self.files) mt.add_cmd("autoces", self.files) mt.add_cmd("autoets", self.files) + mt.add_cmd("mstl", self.files) mt.add_cmd("rwd", self.files) mt.add_cmd("seasonalnaive", self.files) mt.add_cmd("expo", self.files) @@ -1926,6 +1930,61 @@ def call_autoets(self, other_args: List[str]): export_pred_raw=ns_parser.export_pred_raw, ) + # MSTL Model + @log_start_end(log=logger) + def call_mstl(self, other_args: List[str]): + """Process mstl command""" + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + add_help=False, + prog="mstl", + description=""" + Perform Multiple Seasonalities and Trend using Loess (MSTL) forecast: + https://nixtla.github.io/statsforecast/examples/multipleseasonalities.html + """, + ) + if other_args and "-" not in other_args[0][0]: + other_args.insert(0, "--target-dataset") + + ns_parser = self.parse_known_args_and_warn( + parser, + other_args, + export_allowed=EXPORT_ONLY_FIGURES_ALLOWED, + target_dataset=True, + target_column=True, + n_days=True, + seasonal="A", + periods=True, + window=True, + residuals=True, + forecast_only=True, + start=True, + end=True, + naive=True, + export_pred_raw=True, + ) + # TODO Convert this to multi series + if ns_parser: + if not helpers.check_parser_input(ns_parser, self.datasets): + return + + mstl_view.display_mstl_forecast( + data=self.datasets[ns_parser.target_dataset], + dataset_name=ns_parser.target_dataset, + n_predict=ns_parser.n_days, + target_column=ns_parser.target_column, + seasonal_periods=ns_parser.seasonal_periods, + start_window=ns_parser.start_window, + forecast_horizon=ns_parser.n_days, + export=ns_parser.export, + residuals=ns_parser.residuals, + forecast_only=ns_parser.forecast_only, + start_date=ns_parser.s_start_date, + end_date=ns_parser.s_end_date, + naive=ns_parser.naive, + export_pred_raw=ns_parser.export_pred_raw, + ) + # RWD Model @log_start_end(log=logger) def call_rwd(self, other_args: List[str]): diff --git a/openbb_terminal/forecast/mstl_model.py b/openbb_terminal/forecast/mstl_model.py new file mode 100644 index 000000000000..fb9c6af22221 --- /dev/null +++ b/openbb_terminal/forecast/mstl_model.py @@ -0,0 +1,155 @@ +# pylint: disable=too-many-arguments +"""Multiple Seasonalities and Trend using Loess (MSTL) Model""" +__docformat__ = "numpy" + +import logging +from typing import Any, Union, Optional, List, Tuple + +import warnings +import numpy as np +import pandas as pd +from statsforecast.core import StatsForecast + +from openbb_terminal.decorators import log_start_end +from openbb_terminal.rich_config import console +from openbb_terminal.forecast import helpers + + +warnings.simplefilter("ignore") + +logger = logging.getLogger(__name__) + +# pylint: disable=E1123 + + +@log_start_end(log=logger) +def get_mstl_data( + data: Union[pd.Series, pd.DataFrame], + target_column: str = "close", + seasonal_periods: int = 7, + n_predict: int = 5, + start_window: float = 0.85, + forecast_horizon: int = 5, +) -> Tuple[list[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[float], Any]: + + """Performs MSTL forecasting + This is a wrapper around StatsForecast MSTL; + we refer to this link for the original and more complete documentation of the parameters. + + + https://nixtla.github.io/statsforecast/models.html#mstl + + Parameters + ---------- + data : Union[pd.Series, np.ndarray] + Input data. + target_column (str, optional): + Target column to forecast. Defaults to "close". + seasonal_periods: int + Number of seasonal periods in a year (7 for daily data) + If not set, inferred from frequency of the series. + n_predict: int + Number of days to forecast + start_window: float + Size of sliding window from start of timeseries and onwards + forecast_horizon: int + Number of days to forecast when backtesting and retraining historical + + Returns + ------- + list[float] + Adjusted Data series + list[float] + List of historical fcast values + list[float] + List of predicted fcast values + Optional[float] + precision + Any + Fit MSTL model object. + """ + + use_scalers = False + # statsforecast preprocessing + # when including more time series + # the preprocessing is similar + _, ticker_series = helpers.get_series(data, target_column, is_scaler=use_scalers) + freq = ticker_series.freq_str + ticker_series = ticker_series.pd_dataframe().reset_index() + ticker_series.columns = ["ds", "y"] + ticker_series.insert(0, "unique_id", target_column) + # check MSLT availability + try: + from statsforecast.models import MSTL # pylint: disable=import-outside-toplevel + except Exception as e: + error = str(e) + if "cannot import name" in error: + console.print( + "[red]Please update statsforecast to version 1.2.0 or higher.[/red]" + ) + else: + console.print(f"[red]{error}[/red]") + return [], [], [], None, None + + try: + # Model Init + model = MSTL( + season_length=int(seasonal_periods), + ) + fcst = StatsForecast(df=ticker_series, models=[model], freq=freq, verbose=True) + except Exception as e: # noqa + error = str(e) + if "got an unexpected keyword argument" in error: + console.print( + "[red]Please update statsforecast to version 1.1.3 or higher.[/red]" + ) + else: + console.print(f"[red]{error}[/red]") + return [], [], [], None, None + + # Historical backtesting + last_training_point = int((len(ticker_series) - 1) * start_window) + historical_fcast = fcst.cross_validation( + h=int(forecast_horizon), + test_size=len(ticker_series) - last_training_point, + n_windows=None, + input_size=min(10 * forecast_horizon, len(ticker_series)), + ) + + # train new model on entire timeseries to provide best current forecast + # we have the historical fcast, now lets predict. + forecast = fcst.forecast(int(n_predict)) + y_true = historical_fcast["y"].values + y_hat = historical_fcast["MSTL"].values + precision = helpers.mean_absolute_percentage_error(y_true, y_hat) + console.print(f"MSTL obtains MAPE: {precision:.2f}% \n") + + # transform outputs to make them compatible with + # plots + use_scalers = False + _, ticker_series = helpers.get_series( + ticker_series.rename(columns={"y": target_column}), + target_column, + is_scaler=use_scalers, + time_col="ds", + ) + _, forecast = helpers.get_series( + forecast.rename(columns={"MSTL": target_column}), + target_column, + is_scaler=use_scalers, + time_col="ds", + ) + _, historical_fcast = helpers.get_series( + historical_fcast.groupby("ds").head(1).rename(columns={"MSTL": target_column}), + target_column, + is_scaler=use_scalers, + time_col="ds", + ) + + return ( + ticker_series, + historical_fcast, + forecast, + precision, + fcst, + ) diff --git a/openbb_terminal/forecast/mstl_view.py b/openbb_terminal/forecast/mstl_view.py new file mode 100644 index 000000000000..cbf1fd4b14c7 --- /dev/null +++ b/openbb_terminal/forecast/mstl_view.py @@ -0,0 +1,117 @@ +"""Multiple Seasonalities and Trend using Loess (MSTL) View""" +__docformat__ = "numpy" + +import logging +from typing import Union, Optional, List +from datetime import datetime + +import pandas as pd +import matplotlib.pyplot as plt + +from openbb_terminal.forecast import mstl_model +from openbb_terminal.decorators import log_start_end +from openbb_terminal.forecast import helpers + +logger = logging.getLogger(__name__) +# pylint: disable=too-many-arguments + + +@log_start_end(log=logger) +def display_mstl_forecast( + data: Union[pd.DataFrame, pd.Series], + target_column: str = "close", + dataset_name: str = "", + seasonal_periods: int = 7, + n_predict: int = 5, + start_window: float = 0.85, + forecast_horizon: int = 5, + export: str = "", + residuals: bool = False, + forecast_only: bool = False, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + naive: bool = False, + export_pred_raw: bool = False, + external_axes: Optional[List[plt.axes]] = None, +): + """Display MSTL Model + + Parameters + ---------- + data : Union[pd.Series, np.array] + Data to forecast + dataset_name str + The name of the ticker to be predicted + target_column (str, optional): + Target column to forecast. Defaults to "close". + seasonal_periods: int + Number of seasonal periods in a year + If not set, inferred from frequency of the series. + n_predict: int + Number of days to forecast + start_window: float + Size of sliding window from start of timeseries and onwards + forecast_horizon: int + Number of days to forecast when backtesting and retraining historical + export: str + Format to export data + residuals: bool + Whether to show residuals for the model. Defaults to False. + forecast_only: bool + Whether to only show dates in the forecasting range. Defaults to False. + start_date: Optional[datetime] + The starting date to perform analysis, data before this is trimmed. Defaults to None. + end_date: Optional[datetime] + The ending date to perform analysis, data after this is trimmed. Defaults to None. + naive: bool + Whether to show the naive baseline. This just assumes the closing price will be the same + as the previous day's closing price. Defaults to False. + external_axes:Optional[List[plt.axes]] + External axes to plot on + """ + data = helpers.clean_data(data, start_date, end_date, target_column, None) + if not helpers.check_data(data, target_column, None): + return + + ( + ticker_series, + historical_fcast, + predicted_values, + precision, + _model, + ) = mstl_model.get_mstl_data( + data=data, + target_column=target_column, + seasonal_periods=seasonal_periods, + n_predict=n_predict, + start_window=start_window, + forecast_horizon=forecast_horizon, + ) + + if ticker_series == []: + return + + probabilistic = False + helpers.plot_forecast( + name="MSTL", + target_col=target_column, + historical_fcast=historical_fcast, + predicted_values=predicted_values, + ticker_series=ticker_series, + ticker_name=dataset_name, + data=data, + n_predict=n_predict, + forecast_horizon=forecast_horizon, + past_covariates=None, + precision=precision, + probabilistic=probabilistic, + export=export, + forecast_only=forecast_only, + naive=naive, + export_pred_raw=export_pred_raw, + external_axes=external_axes, + ) + if residuals: + helpers.plot_residuals( + _model, None, ticker_series, forecast_horizon=forecast_horizon + ) diff --git a/openbb_terminal/miscellaneous/i18n/en.yml b/openbb_terminal/miscellaneous/i18n/en.yml index e74fbe0601a3..315d1a34792b 100644 --- a/openbb_terminal/miscellaneous/i18n/en.yml +++ b/openbb_terminal/miscellaneous/i18n/en.yml @@ -1046,6 +1046,7 @@ en: forecast/autoselect: Select best statistical model from AutoARIMA, AutoETS, AutoCES, MSTL, etc. forecast/autoces: Automatic Complex Exponential Smoothing Model forecast/autoets: Automatic ETS (Error, Trend, Seasonality) Model + forecast/mstl: Multiple Seasonalities and Trend using Loess (MSTL) Model forecast/rwd: Random Walk with Drift Model forecast/arima: Arima (Non-darts) forecast/expo: Probabilistic Exponential Smoothing diff --git a/openbb_terminal/miscellaneous/library/trail_map_forecasting.csv b/openbb_terminal/miscellaneous/library/trail_map_forecasting.csv index 2e98e07fc62a..f4fd12884abd 100644 --- a/openbb_terminal/miscellaneous/library/trail_map_forecasting.csv +++ b/openbb_terminal/miscellaneous/library/trail_map_forecasting.csv @@ -32,6 +32,7 @@ forecast.tcn,openbb_terminal.forecast.tcn_model.get_tcn_data,openbb_terminal.for forecast.trans,openbb_terminal.forecast.trans_model.get_trans_data,openbb_terminal.forecast.trans_view.display_trans_forecast forecast.tft,openbb_terminal.forecast.tft_model.get_tft_data,openbb_terminal.forecast.tft_view.display_tft_forecast forecast.nhits,openbb_terminal.forecast.nhits_model.get_nhits_data,openbb_terminal.forecast.nhits_view.display_nhits_forecast +forecast.mstl,openbb_terminal.forecast.mstl_model.get_mstl_data,openbb_terminal.forecast.mstl_view.display_mstl_forecast forecast.autoarima,openbb_terminal.forecast.autoarima_model.get_autoarima_data,openbb_terminal.forecast.autoarima_view.display_autoarima_forecast forecast.rwd,openbb_terminal.forecast.rwd_model.get_rwd_data,openbb_terminal.forecast.rwd_view.display_rwd_forecast forecast.autoselect,openbb_terminal.forecast.autoselect_model.get_autoselect_data,openbb_terminal.forecast.autoselect_view.display_autoselect_forecast \ No newline at end of file diff --git a/tests/openbb_terminal/forecast/test_mstl_model.py b/tests/openbb_terminal/forecast/test_mstl_model.py new file mode 100644 index 000000000000..c0ad009f121f --- /dev/null +++ b/tests/openbb_terminal/forecast/test_mstl_model.py @@ -0,0 +1,11 @@ +import pytest +from tests.openbb_terminal.forecast import conftest + +try: + from openbb_terminal.forecast import mstl_model +except ImportError: + pytest.skip(allow_module_level=True) + + +def test_get_mstl_model(tsla_csv): + conftest.test_model(mstl_model.get_mstl_data, tsla_csv) diff --git a/tests/openbb_terminal/forecast/test_mstl_view.py b/tests/openbb_terminal/forecast/test_mstl_view.py new file mode 100644 index 000000000000..c9e99cb1c456 --- /dev/null +++ b/tests/openbb_terminal/forecast/test_mstl_view.py @@ -0,0 +1,18 @@ +import pytest + +try: + from openbb_terminal.forecast import mstl_view +except ImportError: + pytest.skip(allow_module_level=True) + + +def test_display_mstl_forecast(tsla_csv): + with pytest.raises(AttributeError): + mstl_view.display_mstl_forecast( + tsla_csv, + target_column="close", + seasonal_periods=3, + n_predict=1, + start_window=0.5, + forecast_horizon=1, + ) diff --git a/website/content/terminal/forecast/mstl/_index.md b/website/content/terminal/forecast/mstl/_index.md new file mode 100644 index 000000000000..9038210eb853 --- /dev/null +++ b/website/content/terminal/forecast/mstl/_index.md @@ -0,0 +1,61 @@ +``` +usage: mstl [--naive] [-d {AAPL}] [-c TARGET_COLUMN] [-n N_DAYS] [-s {N,A,M}] [-p SEASONAL_PERIODS] [-w START_WINDOW] [--end S_END_DATE] [--start S_START_DATE] [--residuals] [--forecast-only] + [--export-pred-raw] [-h] [--export EXPORT] + +``` + +Perform Multiple Seasonalities and Trend using Loess (MSTL) forecast: https://nixtla.github.io/statsforecast/examples/multipleseasonalities.html + +``` +optional arguments: + --naive Show the naive baseline for a model. (default: False) + -d {AAPL}, --target-dataset {AAPL} + The name of the dataset you want to select (default: None) + -c TARGET_COLUMN, --target-column TARGET_COLUMN + The name of the specific column you want to use (default: close) + -n N_DAYS, --n-days N_DAYS + prediction days. (default: 5) + -s {N,A,M}, --seasonal {N,A,M} + Seasonality: N: None, A: Additive, M: Multiplicative. (default: A) + -p SEASONAL_PERIODS, --periods SEASONAL_PERIODS + Seasonal periods: 4: Quarterly, 7: Daily (default: 7) + -w START_WINDOW, --window START_WINDOW + Start point for rolling training and forecast window. 0.0-1.0 (default: 0.85) + --end S_END_DATE The end date (format YYYY-MM-DD) to select for testing (default: None) + --start S_START_DATE The start date (format YYYY-MM-DD) to select for testing (default: None) + --residuals Show the residuals for the model. (default: False) + --forecast-only Do not plot the hisotorical data without forecasts. (default: False) + --export-pred-raw Export predictions to a csv file. (default: False) + -h, --help show this help message (default: False) + --export EXPORT Export figure into png, jpg, pdf, svg (default: ) + +For more information and examples, use 'about mstl' to access the related guide. +``` + +Example: +``` + +2022 Nov 07, 18:16 (๐Ÿฆ‹) /forecast/ $ mstl AAPL + +Cross Validation Time Series 1: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 115/115 [00:01<00:00, 103.78it/s] +Forecast: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1/1 [00:00<00:00, 19.19it/s] +MSTL obtains MAPE: 3.37% + + + Actual price: 138.38 +โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ +โ”ƒ Datetime โ”ƒ Prediction โ”ƒ +โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ +โ”‚ 2022-11-07 โ”‚ 137.45 โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ 2022-11-08 โ”‚ 142.27 โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ 2022-11-09 โ”‚ 140.00 โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ 2022-11-10 โ”‚ 141.32 โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ 2022-11-11 โ”‚ 141.36 โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +image diff --git a/website/data/menu/main.yml b/website/data/menu/main.yml index d40981fed842..a0d1d56ed78f 100644 --- a/website/data/menu/main.yml +++ b/website/data/menu/main.yml @@ -779,6 +779,8 @@ main: ref: terminal/forecast/load - name: mom ref: terminal/forecast/mom + - name: mstl + ref: terminal/forecast/mstl - name: nbeats ref: terminal/forecast/nbeats - name: nhits