Skip to content

Commit

Permalink
Add OneSegmentTransform (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Sep 1, 2022
1 parent 4872c59 commit cd7ff57
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 25 deletions.
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from etna.transforms.base import IrreversibleTransform
from etna.transforms.base import NewPerSegmentWrapper
from etna.transforms.base import NewTransform
from etna.transforms.base import OneSegmentTransform
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import ReversiblePerSegmentWrapper
from etna.transforms.base import ReversibleTransform
Expand Down
98 changes: 84 additions & 14 deletions etna/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,33 +342,103 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
return df


class OneSegmentTransform(ABC, BaseMixin):
"""Base class to create one segment transforms to apply to data."""

@abstractmethod
def fit(self, df: pd.DataFrame):
"""Fit the transform.
Should be implemented by user.
Parameters
----------
df:
Dataframe in etna long format.
"""
pass

@abstractmethod
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Transform dataframe.
Should be implemented by user
Parameters
----------
df:
Dataframe in etna long format.
Returns
-------
:
Transformed Dataframe in etna long format.
"""
pass

def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Fit and transform Dataframe.
May be reimplemented. But it is not recommended.
Parameters
----------
df:
Dataframe in etna long format to transform.
Returns
-------
:
Transformed Dataframe.
"""
return self.fit(df=df).transform(df=df)

@abstractmethod
def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Inverse transform Dataframe.
Should be reimplemented in the subclasses where necessary.
Parameters
----------
df:
Dataframe in etna long format to be inverse transformed.
Returns
-------
:
Dataframe after applying inverse transformation.
"""
pass


class NewPerSegmentWrapper(NewTransform):
"""Class to apply transform in per segment manner."""

def __init__(self, transform: NewTransform):
def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
self._base_transform = transform
self.segment_transforms: Optional[Dict[str, NewTransform]] = None
super().__init__(required_features=transform.required_features)
self.segment_transforms: Optional[Dict[str, OneSegmentTransform]] = None
super().__init__(required_features=required_features)

def _fit(self, df: pd.DataFrame):
"""Fit transform on each segment."""
self.segment_transforms = {}
segments = df.columns.get_level_values("segment").unique()
for segment in segments:
self.segment_transforms[segment] = deepcopy(self._base_transform)
self.segment_transforms[segment]._fit(df[segment])
self.segment_transforms[segment].fit(df[segment])

def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Apply transform to each segment separately."""
if self.segment_transforms is None:
raise ValueError("Transform is not fitted!")

results = []
for key, value in self.segment_transforms.items():
seg_df = value._transform(df[key])
for segment, transform in self.segment_transforms.items():
seg_df = transform.transform(df[segment])

_idx = seg_df.columns.to_frame()
_idx.insert(0, "segment", key)
_idx.insert(0, "segment", segment)
seg_df.columns = pd.MultiIndex.from_frame(_idx)

results.append(seg_df)
Expand All @@ -382,27 +452,27 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
class IrreversiblePerSegmentWrapper(NewPerSegmentWrapper, IrreversibleTransform):
"""Class to apply irreversible transform in per segment manner."""

def __init__(self, transform: IrreversibleTransform):
super().__init__(transform=transform)
def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
super().__init__(transform=transform, required_features=required_features)


class ReversiblePerSegmentWrapper(NewPerSegmentWrapper, ReversibleTransform):
"""Class to apply reversible transform in per segment manner."""

def __init__(self, transform: ReversibleTransform):
super().__init__(transform=transform)
def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
super().__init__(transform=transform, required_features=required_features)

def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Apply inverse_transform to each segment."""
if self.segment_transforms is None:
raise ValueError("Transform is not fitted!")

results = []
for key, value in self.segment_transforms.items():
seg_df = value._inverse_transform(df[key]) # type: ignore
for segment, transform in self.segment_transforms.items():
seg_df = transform.inverse_transform(df[segment])

_idx = seg_df.columns.to_frame()
_idx.insert(0, "segment", key)
_idx.insert(0, "segment", segment)
seg_df.columns = pd.MultiIndex.from_frame(_idx)

results.append(seg_df)
Expand Down
35 changes: 28 additions & 7 deletions etna/transforms/missing_values/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import pandas as pd

from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.datasets import TSDataset
from etna.transforms.base import IrreversiblePerSegmentWrapper
from etna.transforms.base import OneSegmentTransform


class _OneSegmentResampleWithDistributionTransform(Transform):
class _OneSegmentResampleWithDistributionTransform(OneSegmentTransform):
"""_OneSegmentResampleWithDistributionTransform resamples the given column using the distribution of the other column."""

def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_column: Optional[str]):
def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_column: str):
"""
Init _OneSegmentResampleWithDistributionTransform.
Expand All @@ -34,7 +35,7 @@ def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_
self.distribution_column = distribution_column
self.inplace = inplace
self.out_column = out_column
self.distribution: pd.DataFrame = None
self.distribution: Optional[pd.DataFrame] = None

def _get_folds(self, df: pd.DataFrame) -> List[int]:
"""
Expand Down Expand Up @@ -101,8 +102,12 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
df = df.drop(["fold", "distribution"], axis=1)
return df

def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Inverse transform Dataframe."""
return df


class ResampleWithDistributionTransform(PerSegmentWrapper):
class ResampleWithDistributionTransform(IrreversiblePerSegmentWrapper):
"""ResampleWithDistributionTransform resamples the given column using the distribution of the other column.
Warning
Expand Down Expand Up @@ -136,13 +141,15 @@ def __init__(
self.distribution_column = distribution_column
self.inplace = inplace
self.out_column = self._get_out_column(out_column)
self.in_column_regressor: Optional[bool] = None
super().__init__(
transform=_OneSegmentResampleWithDistributionTransform(
in_column=in_column,
distribution_column=distribution_column,
inplace=inplace,
out_column=self.out_column,
)
),
required_features=[in_column, distribution_column],
)

def _get_out_column(self, out_column: Optional[str]) -> str:
Expand All @@ -154,3 +161,17 @@ def _get_out_column(self, out_column: Optional[str]) -> str:
if out_column:
return out_column
return self.__repr__()

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
if self.inplace:
return []
if self.in_column_regressor is None:
warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
return [self.out_column] if self.in_column_regressor else []

def fit(self, ts: TSDataset) -> "ResampleWithDistributionTransform":
"""Fit the transform."""
self.in_column_regressor = self.in_column in ts.regressors
super().fit(ts)
return self
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_fail_on_incompatible_freq(incompatible_freq_ts):
in_column="exog", inplace=True, distribution_column="target", out_column=None
)
with pytest.raises(ValueError, match="Can not infer in_column frequency!"):
_ = resampler.fit(incompatible_freq_ts.df)
_ = resampler.fit(incompatible_freq_ts)


@pytest.mark.parametrize(
Expand All @@ -27,7 +27,7 @@ def test_fit(ts, request):
resampler = ResampleWithDistributionTransform(
in_column="regressor_exog", inplace=True, distribution_column="target", out_column=None
)
resampler.fit(ts.df)
resampler.fit(ts)
segments = ts.df.columns.get_level_values("segment").unique()
for segment in segments:
assert (resampler.segment_transforms[segment].distribution == expected_distribution[segment]).all().all()
Expand All @@ -48,10 +48,11 @@ def test_transform(daily_exog_ts, inplace, out_column, expected_resampled_ts, re
resampler = ResampleWithDistributionTransform(
in_column="regressor_exog", inplace=inplace, distribution_column="target", out_column=out_column
)
resampled_df = resampler.fit_transform(daily_exog_ts.df)
resampled_df = resampler.fit_transform(daily_exog_ts).to_pandas()
assert resampled_df.equals(expected_resampled_df)


@pytest.mark.xfail(reason="TSDataset 2.0")
@pytest.mark.parametrize(
"inplace,out_column,expected_resampled_ts",
(
Expand All @@ -77,4 +78,29 @@ def test_fit_transform_with_nans(daily_exog_ts_diff_endings):
resampler = ResampleWithDistributionTransform(
in_column="regressor_exog", inplace=True, distribution_column="target"
)
daily_exog_ts_diff_endings.fit_transform([resampler])
_ = resampler.fit_transform(daily_exog_ts_diff_endings)


@pytest.mark.filterwarnings("ignore: Regressors info might be incorrect.")
@pytest.mark.parametrize(
"inplace, in_column_regressor, out_column, expected_regressors",
[
(True, False, None, []),
(False, False, "output_regressor", []),
(False, True, "output_regressor", ["output_regressor"]),
],
)
def test_get_regressors_info(
daily_exog_ts, inplace, in_column_regressor, out_column, expected_regressors, in_column="regressor_exog"
):
daily_exog_ts = daily_exog_ts["ts"]
if in_column_regressor:
daily_exog_ts._regressors.append(in_column)
else:
daily_exog_ts._regressors.remove(in_column)
resampler = ResampleWithDistributionTransform(
in_column=in_column, inplace=inplace, distribution_column="target", out_column=out_column
)
resampler.fit(daily_exog_ts)
regressors_info = resampler.get_regressors_info()
assert sorted(regressors_info) == sorted(expected_regressors)

0 comments on commit cd7ff57

Please sign in to comment.