From 11de006b091a7b05a88642960b57c0cd1961674c Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Tue, 29 Nov 2022 12:18:49 +0300 Subject: [PATCH 1/5] SumTransform implementation --- etna/transforms/__init__.py | 1 + etna/transforms/math/__init__.py | 1 + etna/transforms/math/statistics.py | 53 ++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/etna/transforms/__init__.py b/etna/transforms/__init__.py index c9e261b4b..61712d65d 100644 --- a/etna/transforms/__init__.py +++ b/etna/transforms/__init__.py @@ -32,6 +32,7 @@ from etna.transforms.math import RobustScalerTransform from etna.transforms.math import StandardScalerTransform from etna.transforms.math import StdTransform +from etna.transforms.math import SumTransform from etna.transforms.math import YeoJohnsonTransform from etna.transforms.missing_values import ResampleWithDistributionTransform from etna.transforms.missing_values import TimeSeriesImputerTransform diff --git a/etna/transforms/math/__init__.py b/etna/transforms/math/__init__.py index ebdc08425..298ae96f8 100644 --- a/etna/transforms/math/__init__.py +++ b/etna/transforms/math/__init__.py @@ -19,3 +19,4 @@ from etna.transforms.math.statistics import QuantileTransform from etna.transforms.math.statistics import StdTransform from etna.transforms.math.statistics import WindowStatisticsTransform +from etna.transforms.math.statistics import SumTransform diff --git a/etna/transforms/math/statistics.py b/etna/transforms/math/statistics.py index c0371ea24..0066ed990 100644 --- a/etna/transforms/math/statistics.py +++ b/etna/transforms/math/statistics.py @@ -557,6 +557,58 @@ def _aggregate(self, series: np.ndarray) -> np.ndarray: return result +class SumTransform(WindowStatisticsTransform): + """SumTransform computes sum of values over given window.""" + def __init__( + self, + in_column: str, + window: int, + seasonality: int = 1, + min_periods: int = 1, + fillna: float = 0, + out_column: Optional[str] = None, + ): + """Init SumTransform. + + Parameters + ---------- + in_column: + name of processed column + window: + size of window to aggregate, if window == -1 compute rolling sum all over the given series + seasonality: + seasonality of lags to compute window's aggregation with + min_periods: + min number of targets in window to compute aggregation; + if there is less than ``min_periods`` number of targets return None + fillna: + value to fill results NaNs with + out_column: + result column name. If not given use ``self.__repr__()`` + """ + + self.in_column = in_column + self.window = window + self.seasonality = seasonality + self.min_periods = min_periods + self.fillna = fillna + self.out_column = out_column + + super().__init__( + in_column=in_column, + out_column=self.out_column if self.out_column is not None else self.__repr__(), + window=window, + seasonality=seasonality, + min_periods=min_periods, + fillna=fillna, + ) + + def _aggregate(self, series: np.ndarray) -> np.ndarray: + """Compute sum over the series.""" + series = bn.nansum(series, axis=2) + return series + + __all__ = [ "MedianTransform", "MaxTransform", @@ -567,4 +619,5 @@ def _aggregate(self, series: np.ndarray) -> np.ndarray: "WindowStatisticsTransform", "MADTransform", "MinMaxDifferenceTransform", + "SumTransform", ] From 10af571ab4e0cc1295497f87ded1fd02b461dd9b Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Tue, 29 Nov 2022 12:20:55 +0300 Subject: [PATCH 2/5] adding tests --- .../test_math/test_statistics_transform.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_transforms/test_math/test_statistics_transform.py b/tests/test_transforms/test_math/test_statistics_transform.py index dd73b4706..9967fed44 100644 --- a/tests/test_transforms/test_math/test_statistics_transform.py +++ b/tests/test_transforms/test_math/test_statistics_transform.py @@ -13,6 +13,7 @@ from etna.transforms.math import MinTransform from etna.transforms.math import QuantileTransform from etna.transforms.math import StdTransform +from etna.transforms.math import SumTransform @pytest.fixture @@ -62,6 +63,8 @@ def df_for_agg_with_nan() -> pd.DataFrame: (MADTransform, "test_mad"), (MinMaxDifferenceTransform, None), (MinMaxDifferenceTransform, "test_min_max_diff"), + (SumTransform, None), + (SumTransform, "test_sum"), ), ) def test_interface_simple(simple_df_for_agg: pd.DataFrame, class_name: Any, out_column: str): @@ -279,6 +282,56 @@ def test_min_max_diff_feature( assert (res["expected"] == res["segment_1"]["result"]).all() +@pytest.mark.parametrize( + "window,periods,fill_na,expected", + ((10, 1, 0, np.array([-1, 0, 3, 3, 7, 16, 24, 29, 35, 35])),), +) +def test_sum_feature_with_nan( + df_for_agg_with_nan: pd.DataFrame, + window: int, + periods: int, + fill_na: float, + expected: np.ndarray, +): + transform = SumTransform( + window=window, + min_periods=periods, + fillna=fill_na, + in_column="target", + out_column="result", + ) + res = transform.fit_transform(df_for_agg_with_nan) + np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"]) + + +@pytest.mark.parametrize( + "window,periods,fill_na,expected", + ( + (10, 1, 0, np.array([0, 1, 3, 6, 10, 15, 21, 28, 36, 45])), + (-1, 1, 0, np.array([0, 1, 3, 6, 10, 15, 21, 28, 36, 45])), + (3, 1, -17, np.array([0, 1, 3, 6, 9, 12, 15, 18, 21, 24])), + (1, 1, -17, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])), + ), +) +def test_sum_feature( + simple_df_for_agg: pd.DataFrame, + window: int, + periods: int, + fill_na: float, + expected: np.array, +): + transform = SumTransform( + window=window, + min_periods=periods, + fillna=fill_na, + in_column="target", + out_column="result", + ) + + res = transform.fit_transform(simple_df_for_agg) + np.testing.assert_array_almost_equal(expected, res["segment_1"]["result"]) + + @pytest.mark.parametrize( "transform", ( @@ -289,6 +342,7 @@ def test_min_max_diff_feature( StdTransform(in_column="target", window=5), MADTransform(in_column="target", window=5), MinMaxDifferenceTransform(in_column="target", window=5), + SumTransform(in_column="target", window=5), ), ) def test_fit_transform_with_nans(transform, ts_diff_endings): From d2c399afe9ecc96e23f25debe25e21925334ef99 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Tue, 29 Nov 2022 13:16:30 +0300 Subject: [PATCH 3/5] fix formatting --- etna/transforms/math/__init__.py | 2 +- etna/transforms/math/statistics.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/etna/transforms/math/__init__.py b/etna/transforms/math/__init__.py index 298ae96f8..71877775d 100644 --- a/etna/transforms/math/__init__.py +++ b/etna/transforms/math/__init__.py @@ -18,5 +18,5 @@ from etna.transforms.math.statistics import MinTransform from etna.transforms.math.statistics import QuantileTransform from etna.transforms.math.statistics import StdTransform -from etna.transforms.math.statistics import WindowStatisticsTransform from etna.transforms.math.statistics import SumTransform +from etna.transforms.math.statistics import WindowStatisticsTransform diff --git a/etna/transforms/math/statistics.py b/etna/transforms/math/statistics.py index 0066ed990..7b27925a4 100644 --- a/etna/transforms/math/statistics.py +++ b/etna/transforms/math/statistics.py @@ -559,6 +559,7 @@ def _aggregate(self, series: np.ndarray) -> np.ndarray: class SumTransform(WindowStatisticsTransform): """SumTransform computes sum of values over given window.""" + def __init__( self, in_column: str, @@ -586,7 +587,6 @@ def __init__( out_column: result column name. If not given use ``self.__repr__()`` """ - self.in_column = in_column self.window = window self.seasonality = seasonality From 0485a0d29d69e16a0de9e8a2743eb5e8eacb8da2 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 30 Nov 2022 12:05:19 +0300 Subject: [PATCH 4/5] added testcase --- tests/test_transforms/test_math/test_statistics_transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_transforms/test_math/test_statistics_transform.py b/tests/test_transforms/test_math/test_statistics_transform.py index 9967fed44..5a90e4f9a 100644 --- a/tests/test_transforms/test_math/test_statistics_transform.py +++ b/tests/test_transforms/test_math/test_statistics_transform.py @@ -311,6 +311,7 @@ def test_sum_feature_with_nan( (-1, 1, 0, np.array([0, 1, 3, 6, 10, 15, 21, 28, 36, 45])), (3, 1, -17, np.array([0, 1, 3, 6, 9, 12, 15, 18, 21, 24])), (1, 1, -17, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])), + (3, 3, -17, np.array([-17, -17, 3, 6, 9, 12, 15, 18, 21, 24])), ), ) def test_sum_feature( From a32c6809f021eecfe66a424be8e9f304aa97a426 Mon Sep 17 00:00:00 2001 From: brsnw250 Date: Wed, 30 Nov 2022 12:05:42 +0300 Subject: [PATCH 5/5] updated changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32510d18b..0178b043f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - - Add python 3.10 support ([#1005](https://github.com/tinkoff-ai/etna/pull/1005)) -- +- Add `SumTranform`([#1021](https://github.com/tinkoff-ai/etna/pull/1021)) - Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988)) - Add `experimental` module with `TimeSeriesBinaryClassifier` and `PredictabilityAnalyzer` ([#985](https://github.com/tinkoff-ai/etna/pull/985)) - Inference track results: add `predict` method to pipelines, teach some models to work with context, change hierarchy of base models, update notebook examples ([#979](https://github.com/tinkoff-ai/etna/pull/979))