Skip to content

Commit

Permalink
Fix WindowStatisticsTransform (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Mar 2, 2023
1 parent 663f5bb commit 4b9111d
Show file tree
Hide file tree
Showing 17 changed files with 105 additions and 20 deletions.
5 changes: 2 additions & 3 deletions etna/transforms/encoders/categorical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from enum import Enum
from typing import List
from typing import Optional
Expand Down Expand Up @@ -78,7 +77,7 @@ def __init__(self, in_column: str, out_column: Optional[str] = None, strategy: s
def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
return [self._get_column_name()] if self.in_column_regressor else []

def _fit(self, df: pd.DataFrame) -> "LabelEncoderTransform":
Expand Down Expand Up @@ -159,7 +158,7 @@ def __init__(self, in_column: str, out_column: Optional[str] = None):
def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
return self._get_out_column_names() if self.in_column_regressor else []

def _fit(self, df: pd.DataFrame) -> "OneHotEncoderTransform":
Expand Down
2 changes: 1 addition & 1 deletion etna/transforms/math/add_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
return [self._get_column_name()] if self.in_column_regressor and not self.inplace else []


Expand Down
2 changes: 1 addition & 1 deletion etna/transforms/math/apply_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,5 @@ def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
return [self.change_column] if self.in_column_regressor and not self.inplace else []
9 changes: 4 additions & 5 deletions etna/transforms/math/differencing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Dict
from typing import List
from typing import Optional
Expand Down Expand Up @@ -81,10 +80,10 @@ def _get_column_name(self) -> str:

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
raise ValueError("Fit the transform to get the correct regressors info!")
if self.inplace:
return []
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
return [self._get_column_name()] if self.in_column_regressor else []

def fit(self, ts: TSDataset) -> "_SingleDifferencingTransform":
Expand Down Expand Up @@ -359,10 +358,10 @@ def _get_column_name(self) -> str:

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
raise ValueError("Fit the transform to get the correct regressors info!")
if self.inplace:
return []
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
return [self._get_column_name()] if self.in_column_regressor else []

def fit(self, ts: TSDataset) -> "DifferencingTransform":
Expand Down
2 changes: 1 addition & 1 deletion etna/transforms/math/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
return [self._get_column_name()] if self.in_column_regressor and not self.inplace else []


Expand Down
5 changes: 2 additions & 3 deletions etna/transforms/math/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,8 @@ def _inverse_reshape(self, df: pd.DataFrame, transformed: np.ndarray) -> np.ndar

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.inplace:
return []
if self.out_column_regressors is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
raise ValueError("Fit the transform to get the correct regressors info!")
if self.inplace:
return []
return self.out_column_regressors
12 changes: 11 additions & 1 deletion etna/transforms/math/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd

from etna.datasets import TSDataset
from etna.transforms.base import IrreversibleTransform


Expand Down Expand Up @@ -49,6 +50,13 @@ def __init__(
self.min_periods = min_periods
self.fillna = fillna
self.kwargs = kwargs
self.in_column_regressor: Optional[bool] = None

def fit(self, ts: TSDataset) -> "WindowStatisticsTransform":
"""Fit the transform."""
self.in_column_regressor = self.in_column in ts.regressors
super().fit(ts)
return self

def _fit(self, df: pd.DataFrame) -> "WindowStatisticsTransform":
"""Fits transform."""
Expand Down Expand Up @@ -100,7 +108,9 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
return [self.out_column_name]
if self.in_column_regressor is None:
raise ValueError("Fit the transform to get the correct regressors info!")
return [self.out_column_name] if self.in_column_regressor else []


class MeanTransform(WindowStatisticsTransform):
Expand Down
4 changes: 2 additions & 2 deletions etna/transforms/missing_values/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def _get_out_column(self, out_column: Optional[str]) -> str:

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.in_column_regressor is None:
raise ValueError("Fit the transform to get the correct regressors info!")
if self.inplace:
return []
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
return [self.out_column] if self.in_column_regressor else []

def fit(self, ts: TSDataset) -> "ResampleWithDistributionTransform":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,11 @@ def test_save_load_ohe(dtype):
for i in range(3):
transform = OneHotEncoderTransform(in_column=f"regressor_{i}", out_column="test")
assert_transformation_equals_loaded_original(transform=transform, ts=ts)


@pytest.mark.parametrize(
"transform", (OneHotEncoderTransform(in_column="regressor_0"), LabelEncoderTransform(in_column="regressor_0"))
)
def test_get_regressors_info_not_fitted(transform):
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ def test_right_number_features_with_integer_division(ts_with_exog_galeshapley):
top_k = len(ts_with_exog_galeshapley.segments)
transform = GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=top_k)

transform.fit(ts_with_exog_galeshapley.to_pandas())
df = transform.transform(ts_with_exog_galeshapley.to_pandas())
transform.fit(ts_with_exog_galeshapley)
ts = transform.transform(ts_with_exog_galeshapley)

remaining_columns = df.columns.get_level_values("feature").unique().tolist()
remaining_columns = ts.columns.get_level_values("feature").unique().tolist()
assert len(remaining_columns) == top_k + 1
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,9 @@ def test_fit_transform_with_nans(ts_diff_endings):
def test_save_load(inplace, example_tsds):
transform = AddConstTransform(in_column="target", value=10, inplace=inplace)
assert_transformation_equals_loaded_original(transform=transform, ts=example_tsds)


def test_get_regressors_info_not_fitted():
transform = AddConstTransform(in_column="target", value=1, out_column="out_column")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,9 @@ def test_save_load(inplace, ts_nans):
ts = ts_nans
transform = DifferencingTransform(in_column="target", inplace=inplace)
assert_transformation_equals_loaded_original(transform=transform, ts=ts)


def test_get_regressors_info_not_fitted():
transform = DifferencingTransform(in_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
6 changes: 6 additions & 0 deletions tests/test_transforms/test_math/test_lambda_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,9 @@ def test_save_load(inplace, ts_range_const):
inverse_transform_func=example_inverse_transform_func,
)
assert_transformation_equals_loaded_original(transform=transform, ts=ts_range_const)


def test_get_regressors_info_not_fitted():
transform = LambdaTransform(in_column="target", inplace=False, transform_func=lambda x: x)
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
6 changes: 6 additions & 0 deletions tests/test_transforms/test_math/test_log_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ def test_save_load(inplace, positive_ts_):
ts = positive_ts_
transform = LogTransform(in_column="target", inplace=inplace)
assert_transformation_equals_loaded_original(transform=transform, ts=ts)


def test_get_regressors_info_not_fitted():
transform = LogTransform(in_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,23 @@ def test_ordering(transform_constructor, in_column, mode, multicolumn_ts):
df_multi = transformed_df.loc[:, pd.IndexSlice[segments, column_multi]]
df_single = transformed_dfs_one_column[i].loc[:, pd.IndexSlice[segments, column_single]]
assert np.all(df_multi.values == df_single.values)


@pytest.mark.parametrize(
"transform_constructor",
[
BoxCoxTransform,
YeoJohnsonTransform,
StandardScalerTransform,
RobustScalerTransform,
MinMaxScalerTransform,
MaxAbsScalerTransform,
StandardScalerTransform,
RobustScalerTransform,
MinMaxScalerTransform,
],
)
def test_get_regressors_info_not_fitted(transform_constructor):
transform = transform_constructor(in_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()
20 changes: 20 additions & 0 deletions tests/test_transforms/test_math/test_statistics_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from etna.transforms.math import QuantileTransform
from etna.transforms.math import StdTransform
from etna.transforms.math import SumTransform
from etna.transforms.math import WindowStatisticsTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original


Expand Down Expand Up @@ -50,6 +51,25 @@ def ts_for_agg_with_nan() -> TSDataset:
return ts


class DummyWindowStatisticsTransform(WindowStatisticsTransform):
def _aggregate(self, series: np.ndarray):
return None


@pytest.mark.parametrize("in_column,expected_regressors", (("target", []), ("regressor_exog_weekend", ["out_column"])))
def test_get_regressors_info(example_reg_tsds, in_column, expected_regressors):
transform = DummyWindowStatisticsTransform(in_column=in_column, out_column="out_column", window=1)
transform.fit(ts=example_reg_tsds)
out_regressors = transform.get_regressors_info()
assert out_regressors == expected_regressors


def test_get_regressors_info_not_fitted():
transform = DummyWindowStatisticsTransform(in_column="target", out_column="out_column", window=1)
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()


@pytest.mark.parametrize(
"class_name,out_column",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,9 @@ def test_save_load(inplace, out_column, daily_exog_ts):
in_column="regressor_exog", inplace=inplace, distribution_column="target", out_column=out_column
)
assert_transformation_equals_loaded_original(transform=transform, ts=daily_exog_ts)


def test_get_regressors_info_not_fitted():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()

0 comments on commit 4b9111d

Please sign in to comment.