Skip to content

Commit

Permalink
Merge pull request #4178 from OpenBB-finance/feature/forecast-metrics
Browse files Browse the repository at this point in the history
[FEAT] Additional metrics for evaluating forecasting (RMSE+ MSE)
  • Loading branch information
jmaslek authored Feb 9, 2023
2 parents 6e484e8 + 9895cb2 commit 583bb1d
Show file tree
Hide file tree
Showing 26 changed files with 214 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ darts_logs/
# User data
custom_imports/*.csv
custom_imports/*/*.csv

# lightning logs
lightning_logs/
2 changes: 1 addition & 1 deletion build/conda/conda-3-9-env-full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ dependencies:
- python=3.9.6
- pip
- poetry=1.1.13
- lightgbm=3.3.3
- lightgbm=3.3.5
4 changes: 4 additions & 0 deletions openbb_terminal/forecast/brnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_brnn_data(
model_save_name: str = "brnn_model",
force_reset: bool = True,
save_checkpoints: bool = True,
metric: str = "mape",
) -> Tuple[
Optional[List[TimeSeries]],
Optional[List[TimeSeries]],
Expand Down Expand Up @@ -84,6 +85,8 @@ def get_brnn_data(
discarded). Defaults to True.
save_checkpoints: bool
Whether or not to automatically save the untrained model and checkpoints from training. Defaults to True.
metric: str
Metric to use for model selection. Defaults to "mape".
Returns
-------
Expand Down Expand Up @@ -174,4 +177,5 @@ def get_brnn_data(
train_split,
forecast_horizon,
n_predict,
metric,
)
7 changes: 7 additions & 0 deletions openbb_terminal/forecast/brnn_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def display_brnn_forecast(
end_date: Optional[datetime] = None,
naive: bool = False,
export_pred_raw: bool = False,
metric: str = "mape",
external_axes: Optional[List[plt.axes]] = None,
):
"""Display BRNN forecast
Expand Down Expand Up @@ -102,6 +103,10 @@ def display_brnn_forecast(
naive: bool
Whether to show the naive baseline. This just assumes the closing price will be the same
as the previous day's closing price. Defaults to False.
export_pred_raw: bool
Whether to export the raw predicted values. Defaults to False.
metric: str
The metric to use for the model. Defaults to "mape".
external_axes: Optional[List[plt.axes]]
External axes to plot on
"""
Expand Down Expand Up @@ -137,6 +142,7 @@ def display_brnn_forecast(
model_save_name=model_save_name,
force_reset=force_reset,
save_checkpoints=save_checkpoints,
metric=metric,
)
if ticker_series == []:
return
Expand All @@ -160,6 +166,7 @@ def display_brnn_forecast(
forecast_only=forecast_only,
naive=naive,
export_pred_raw=export_pred_raw,
metric=metric,
external_axes=external_axes,
)
if residuals:
Expand Down
20 changes: 17 additions & 3 deletions openbb_terminal/forecast/expo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pandas as pd
from darts import TimeSeries
from darts.metrics import mape
from darts.metrics import mape, mse, rmse, smape
from darts.models import ExponentialSmoothing
from darts.utils.utils import ModelMode, SeasonalityMode
from numpy import ndarray
Expand Down Expand Up @@ -40,6 +40,7 @@ def get_expo_data(
n_predict: int = 5,
start_window: float = 0.85,
forecast_horizon: int = 5,
metric: str = "mape",
) -> Tuple[
List[TimeSeries],
List[TimeSeries],
Expand Down Expand Up @@ -76,6 +77,8 @@ def get_expo_data(
Size of sliding window from start of timeseries and onwards
forecast_horizon: int
Number of days to forecast when backtesting and retraining historical
metric: str
Metric to use for backtesting. Defaults to MAPE.
Returns
-------
Expand Down Expand Up @@ -149,8 +152,19 @@ def get_expo_data(
# we have the historical fcast, now lets train on entire set and predict.
best_model.fit(ticker_series)
probabilistic_forecast = best_model.predict(int(n_predict), num_samples=500)
precision = mape(actual_series=ticker_series, pred_series=historical_fcast_es)
console.print(f"Exponential smoothing obtains MAPE: {precision:.2f}% \n")

if metric == "rmse":
precision = rmse(actual_series=ticker_series, pred_series=historical_fcast_es)
elif metric == "mse":
precision = mse(actual_series=ticker_series, pred_series=historical_fcast_es)
elif metric == "mape":
precision = mape(actual_series=ticker_series, pred_series=historical_fcast_es)
elif metric == "smape":
precision = smape(actual_series=ticker_series, pred_series=historical_fcast_es)

console.print(
f"Exponential smoothing obtains {metric.upper()}: {precision:.2f}% \n"
)

return (
ticker_series,
Expand Down
7 changes: 7 additions & 0 deletions openbb_terminal/forecast/expo_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def display_expo_forecast(
end_date: Optional[datetime] = None,
naive: bool = False,
export_pred_raw: bool = False,
metric: str = "mape",
external_axes: Optional[List[plt.axes]] = None,
):
"""Display Probabilistic Exponential Smoothing forecast
Expand Down Expand Up @@ -80,6 +81,10 @@ def display_expo_forecast(
naive: bool
Whether to show the naive baseline. This just assumes the closing price will be the same
as the previous day's closing price. Defaults to False.
export_pred_raw: bool
Whether to export the raw predicted values. Defaults to False.
metric: str
The metric to use when backtesting. Defaults to "mape".
external_axes: Optional[List[plt.axes]]
External axes to plot on
"""
Expand All @@ -103,6 +108,7 @@ def display_expo_forecast(
n_predict=n_predict,
start_window=start_window,
forecast_horizon=forecast_horizon,
metric=metric,
)

if ticker_series == []:
Expand All @@ -127,6 +133,7 @@ def display_expo_forecast(
forecast_only=forecast_only,
naive=naive,
export_pred_raw=export_pred_raw,
metric=metric,
external_axes=external_axes,
)
if residuals:
Expand Down
36 changes: 35 additions & 1 deletion openbb_terminal/forecast/forecast_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def add_standard_args(
naive: bool = False,
explainability_raw: bool = False,
export_pred_raw: bool = False,
metric: bool = False,
):
if hidden_size:
parser.add_argument(
Expand Down Expand Up @@ -531,7 +532,7 @@ def add_standard_args(
action="store",
dest="model_type",
default="LSTM",
help='Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU")',
help='Enter a string specifying the RNN module type ("RNN", "LSTM" or "GRU")',
)
if dropout is not None:
parser.add_argument(
Expand Down Expand Up @@ -631,6 +632,17 @@ def add_standard_args(
help="Export predictions to a csv file.",
)

if metric:
parser.add_argument(
"--metric",
type=str,
action="store",
dest="metric",
default="mape",
choices=["rmse", "mse", "mape", "smape"],
help="Calculate precision based on a specific metric (rmse, mse, mape)",
)

# if user does not put in --dataset
return parser

Expand Down Expand Up @@ -2114,6 +2126,7 @@ def call_expo(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2143,6 +2156,7 @@ def call_expo(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2175,6 +2189,7 @@ def call_theta(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2202,6 +2217,7 @@ def call_theta(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2261,6 +2277,7 @@ def call_rnn(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2297,6 +2314,7 @@ def call_rnn(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2373,6 +2391,7 @@ def call_nbeats(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2415,6 +2434,7 @@ def call_nbeats(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2482,6 +2502,7 @@ def call_tcn(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2525,6 +2546,7 @@ def call_tcn(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2561,6 +2583,7 @@ def call_regr(self, other_args: List[str]):
naive=True,
explainability_raw=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2595,6 +2618,7 @@ def call_regr(self, other_args: List[str]):
naive=ns_parser.naive,
explainability_raw=ns_parser.explainability_raw,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2630,6 +2654,7 @@ def call_linregr(self, other_args: List[str]):
naive=True,
explainability_raw=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2663,6 +2688,7 @@ def call_linregr(self, other_args: List[str]):
naive=ns_parser.naive,
explainability_raw=ns_parser.explainability_raw,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2714,6 +2740,7 @@ def call_brnn(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2755,6 +2782,7 @@ def call_brnn(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2842,6 +2870,7 @@ def call_trans(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -2885,6 +2914,7 @@ def call_trans(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -2958,6 +2988,7 @@ def call_tft(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -3001,6 +3032,7 @@ def call_tft(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down Expand Up @@ -3092,6 +3124,7 @@ def call_nhits(self, other_args: List[str]):
end=True,
naive=True,
export_pred_raw=True,
metric=True,
)
ns_parser = self.parse_known_args_and_warn(
parser,
Expand Down Expand Up @@ -3136,6 +3169,7 @@ def call_nhits(self, other_args: List[str]):
end_date=ns_parser.s_end_date,
naive=ns_parser.naive,
export_pred_raw=ns_parser.export_pred_raw,
metric=ns_parser.metric,
)

@log_start_end(log=logger)
Expand Down
Loading

0 comments on commit 583bb1d

Please sign in to comment.