diff --git a/darts/tests/utils/test_statistics.py b/darts/tests/utils/test_statistics.py index 6297399334..fa9a79ec63 100644 --- a/darts/tests/utils/test_statistics.py +++ b/darts/tests/utils/test_statistics.py @@ -5,12 +5,20 @@ from darts.tests.base_test_class import DartsBaseTestClass from darts.utils.statistics import ( check_seasonality, + extract_trend_and_seasonality, granger_causality_tests, + remove_seasonality, + remove_trend, stationarity_test_adf, stationarity_test_kpss, stationarity_tests, ) -from darts.utils.timeseries_generation import constant_timeseries, gaussian_timeseries +from darts.utils.timeseries_generation import ( + constant_timeseries, + gaussian_timeseries, + linear_timeseries, +) +from darts.utils.utils import ModelMode, SeasonalityMode class TimeSeriesTestCase(DartsBaseTestClass): @@ -93,3 +101,95 @@ def test_stationarity_tests(self): self.assertTrue(stationarity_test_kpss(series_3)[1] > 0.05) self.assertTrue(stationarity_test_adf(series_3)[1] < 0.05) self.assertTrue(stationarity_tests) + + +class SeasonalDecomposeTestCase(DartsBaseTestClass): + pd_series = pd.Series(range(50), index=pd.date_range("20130101", "20130219")) + pd_series = pd_series.map(lambda x: np.sin(x * np.pi / 3 + np.pi / 2)) + season = TimeSeries.from_series(pd_series) + trend = linear_timeseries( + start_value=1, end_value=10, start=season.start_time(), end=season.end_time() + ) + ts = trend + season + + def test_extract(self): + # test default (naive) method + calc_trend, _ = extract_trend_and_seasonality(self.ts, freq=6) + diff = self.trend - calc_trend + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # test default (naive) method additive + calc_trend, _ = extract_trend_and_seasonality( + self.ts, freq=6, model=ModelMode.ADDITIVE + ) + diff = self.trend - calc_trend + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # test STL method + calc_trend, _ = extract_trend_and_seasonality( + self.ts, freq=6, method="STL", model=ModelMode.ADDITIVE + ) + diff = self.trend - calc_trend + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # check if error is raised + with self.assertRaises(ValueError): + calc_trend, _ = extract_trend_and_seasonality( + self.ts, freq=6, method="STL", model=ModelMode.MULTIPLICATIVE + ) + + def test_remove_seasonality(self): + # test default (naive) method + calc_trend = remove_seasonality(self.ts, freq=6) + diff = self.trend - calc_trend + self.assertTrue(np.mean(diff.values() ** 2).item() < 0.5) + + # test default (naive) method additive + calc_trend = remove_seasonality(self.ts, freq=6, model=SeasonalityMode.ADDITIVE) + diff = self.trend - calc_trend + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # test STL method + calc_trend = remove_seasonality( + self.ts, + freq=6, + method="STL", + model=SeasonalityMode.ADDITIVE, + low_pass=9, + ) + diff = self.trend - calc_trend + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # check if error is raised + with self.assertRaises(ValueError): + calc_trend = remove_seasonality( + self.ts, freq=6, method="STL", model=SeasonalityMode.MULTIPLICATIVE + ) + + def test_remove_trend(self): + # test naive method + calc_season = remove_trend(self.ts, freq=6) + diff = self.season - calc_season + self.assertTrue(np.mean(diff.values() ** 2).item() < 1.5) + + # test naive method additive + calc_season = remove_trend(self.ts, freq=6, model=ModelMode.ADDITIVE) + diff = self.season - calc_season + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # test STL method + calc_season = remove_trend( + self.ts, + freq=6, + method="STL", + model=ModelMode.ADDITIVE, + low_pass=9, + ) + diff = self.season - calc_season + self.assertTrue(np.isclose(np.mean(diff.values() ** 2), 0.0)) + + # check if error is raised + with self.assertRaises(ValueError): + calc_season = remove_trend( + self.ts, freq=6, method="STL", model=ModelMode.MULTIPLICATIVE + ) diff --git a/darts/utils/statistics.py b/darts/utils/statistics.py index 193a2de7ef..0db4014443 100644 --- a/darts/utils/statistics.py +++ b/darts/utils/statistics.py @@ -10,7 +10,7 @@ import numpy as np from scipy.signal import argrelmax from scipy.stats import norm -from statsmodels.tsa.seasonal import seasonal_decompose +from statsmodels.tsa.seasonal import STL, seasonal_decompose from statsmodels.tsa.stattools import acf, adfuller, grangercausalitytests, kpss, pacf from darts import TimeSeries @@ -125,9 +125,11 @@ def extract_trend_and_seasonality( ts: TimeSeries, freq: int = None, model: Union[SeasonalityMode, ModelMode] = ModelMode.MULTIPLICATIVE, + method: str = "naive", + **kwargs, ) -> Tuple[TimeSeries, TimeSeries]: """ - Extracts trend and seasonality from a TimeSeries instance using `statsmodels.seasonal_decompose`. + Extracts trend and seasonality from a TimeSeries instance using `statsmodels.tsa`. Parameters ---------- @@ -140,11 +142,21 @@ def extract_trend_and_seasonality( Must be ``from darts import ModelMode, SeasonalityMode`` Enum member. Either ``MULTIPLICATIVE`` or ``ADDITIVE``. Defaults ``ModelMode.MULTIPLICATIVE``. - + method + The method to be used to decompose the series. + - "naive" : Seasonal decomposition using moving averages [1]_. + - "STL" : Season-Trend decomposition using LOESS [2]_. Only compatible with ``ADDITIVE`` model type. + kwargs + Other keyword arguments are passed down to the decomposition method. Returns ------- Tuple[TimeSeries, TimeSeries] A tuple of (trend, seasonal) time series. + + References + ------- + .. [1] https://www.statsmodels.org/devel/generated/statsmodels.tsa.seasonal.seasonal_decompose.html + .. [2] https://www.statsmodels.org/devel/generated/statsmodels.tsa.seasonal.STL.html """ ts._assert_univariate() @@ -158,9 +170,27 @@ def extract_trend_and_seasonality( "The model must be either MULTIPLICATIVE or ADDITIVE.", ) - decomp = seasonal_decompose( - ts.pd_series(), period=freq, model=model.value, extrapolate_trend="freq" - ) + if method == "naive": + + decomp = seasonal_decompose( + ts.pd_series(), period=freq, model=model.value, extrapolate_trend="freq" + ) + + elif method == "STL": + raise_if_not( + model in [SeasonalityMode.ADDITIVE, ModelMode.ADDITIVE], + f"Only ADDITIVE model is compatible with the STL method. Current model is {model}.", + logger, + ) + + decomp = STL( + endog=ts.pd_series(), + period=freq, + **kwargs, + ).fit() + + else: + raise_log(ValueError(f"Unknown value for method: {method}"), logger) season = TimeSeries.from_times_and_values(ts.time_index, decomp.seasonal) trend = TimeSeries.from_times_and_values(ts.time_index, decomp.trend) @@ -185,7 +215,6 @@ def remove_from_series( The type of model considered. Must be `from darts import ModelMode, SeasonalityMode` Enums member. Either MULTIPLICATIVE or ADDITIVE. - Returns ------- TimeSeries @@ -218,6 +247,8 @@ def remove_seasonality( ts: TimeSeries, freq: int = None, model: SeasonalityMode = SeasonalityMode.MULTIPLICATIVE, + method: str = "naive", + **kwargs, ) -> TimeSeries: """ Adjusts the TimeSeries `ts` for a seasonality of order `frequency` using the `model` decomposition. @@ -233,25 +264,43 @@ def remove_seasonality( Must be a `from darts import SeasonalityMode` Enum member. Either SeasonalityMode.MULTIPLICATIVE or SeasonalityMode.ADDITIVE. Defaults SeasonalityMode.MULTIPLICATIVE. - Returns + method + The method to be used to decompose the series. + - "naive" : Seasonal decomposition using moving averages [1]_. + - "STL" : Season-Trend decomposition using LOESS [2]_. Only compatible with ``ADDITIVE`` model type. + Defaults to "naive" + kwargs + Other keyword arguments are passed down to the decomposition method. + Returns ------- TimeSeries A new TimeSeries instance that corresponds to the seasonality-adjusted 'ts'. + References + ------- + .. [1] https://www.statsmodels.org/devel/generated/statsmodels.tsa.seasonal.seasonal_decompose.html + .. [2] https://www.statsmodels.org/devel/generated/statsmodels.tsa.seasonal.STL.html """ - ts._assert_univariate() raise_if_not( model is not SeasonalityMode.NONE, "The model must be either MULTIPLICATIVE or ADDITIVE.", ) + raise_if( + model not in [SeasonalityMode.ADDITIVE, ModelMode.ADDITIVE] and method == "STL", + f"Only ADDITIVE seasonality is compatible with the STL method. Current model is {model}.", + logger, + ) - _, seasonality = extract_trend_and_seasonality(ts, freq, model) + _, seasonality = extract_trend_and_seasonality(ts, freq, model, method, **kwargs) new_ts = remove_from_series(ts, seasonality, model) return new_ts def remove_trend( - ts: TimeSeries, model: ModelMode = ModelMode.MULTIPLICATIVE + ts: TimeSeries, + model: ModelMode = ModelMode.MULTIPLICATIVE, + method: str = "naive", + **kwargs, ) -> TimeSeries: """ Adjusts the TimeSeries `ts` for a trend using the `model` decomposition. @@ -262,9 +311,16 @@ def remove_trend( The TimeSeries to adjust. model The type of decomposition to use. - Must be `from darts import ModelMode` Enum member. + Must be a `from darts import ModelMode` Enum member. Either ModelMode.MULTIPLICATIVE or ModelMode.ADDITIVE. - Defaults to modelMode.MULTIPLICATIVE. + Defaults ModelMode.MULTIPLICATIVE. + method + The method to be used to decompose the series. + - "naive" : Seasonal decomposition using moving averages [1]_. + - "STL" : Season-Trend decomposition using LOESS [2]_. Only compatible with ``ADDITIVE`` model type. + Defaults to "naive" + kwargs + Other keyword arguments are passed down to the decomposition method. Returns ------- TimeSeries @@ -273,7 +329,12 @@ def remove_trend( ts._assert_univariate() - trend, _ = extract_trend_and_seasonality(ts, model=model) + raise_if( + model not in [SeasonalityMode.ADDITIVE, ModelMode.ADDITIVE] and method == "STL", + f"Only ADDITIVE seasonality is compatible with the STL method. Current model is {model}.", + logger, + ) + trend, _ = extract_trend_and_seasonality(ts, model=model, method=method, **kwargs) new_ts = remove_from_series(ts, trend, model) return new_ts