Skip to content

Commit

Permalink
TSDataset 2.0 (#956)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Sep 26, 2022
1 parent 74d0167 commit c2edede
Show file tree
Hide file tree
Showing 38 changed files with 175 additions and 418 deletions.
1 change: 0 additions & 1 deletion etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,6 @@ def get_residuals(forecast_df: pd.DataFrame, ts: "TSDataset") -> "TSDataset":
new_ts = TSDataset(df=true_df, freq=ts.freq)
new_ts.known_future = ts.known_future
new_ts._regressors = ts.regressors
new_ts.transforms = ts.transforms
new_ts.df_exog = ts.df_exog
return new_ts

Expand Down
105 changes: 21 additions & 84 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -137,29 +136,19 @@ def __init__(
self.df_exog.index = pd.to_datetime(self.df_exog.index)
self.df = self._merge_exog(self.df)

self.transforms: Optional[Sequence["Transform"]] = None

def transform(self, transforms: Sequence["Transform"]):
"""Apply given transform to the data."""
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
for transform in transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
self._update_regressors(transform=transform, columns_before=columns_before, columns_after=columns_after)
transform.transform(self)

def fit_transform(self, transforms: Sequence["Transform"]):
"""Fit and apply given transforms to the data."""
self._check_endings(warning=True)
self.transforms = transforms
for transform in self.transforms:
for transform in transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
columns_before = set(self.columns.get_level_values("feature"))
self.df = transform.fit_transform(self.df)
columns_after = set(self.columns.get_level_values("feature"))
self._update_regressors(transform=transform, columns_before=columns_before, columns_after=columns_after)
transform.fit_transform(self)

@staticmethod
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -170,60 +159,6 @@ def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
df_copy.columns = pd.MultiIndex.from_frame(columns_frame)
return df_copy

def _update_regressors(self, transform: "Transform", columns_before: Set[str], columns_after: Set[str]):
from etna.transforms import OneHotEncoderTransform
from etna.transforms.base import FutureMixin

# intersect list of regressors with columns after the transform
self._regressors = list(set(self._regressors).intersection(columns_after))

unseen_columns = list(columns_after - columns_before)

if len(unseen_columns) == 0:
return

new_regressors = []

if isinstance(transform, FutureMixin):
# Every column from FutureMixin is regressor
out_columns = list(columns_after - columns_before)
new_regressors = out_columns
elif isinstance(transform, OneHotEncoderTransform):
# Only the columns created with OneHotEncoderTransform from regressor are regressors
in_column = transform.in_column
out_columns = list(columns_after - columns_before)
if in_column in self.regressors:
new_regressors = out_columns
elif hasattr(transform, "in_column"):
# Only the columns created with the other transforms from regressors are regressors
in_columns = transform.in_column if isinstance(transform.in_column, list) else [transform.in_column] # type: ignore
if hasattr(transform, "out_columns") and transform.out_columns is not None: # type: ignore
# User defined out_columns in sklearn
# TODO: remove this case after fixing the out_column attribute in SklearnTransform
out_columns = transform.out_columns # type: ignore
regressors_in_column_ids = [i for i, in_column in enumerate(in_columns) if in_column in self.regressors]
new_regressors = [out_columns[i] for i in regressors_in_column_ids]
elif hasattr(transform, "out_column") and transform.out_column is not None: # type: ignore
# User defined out_columns
out_columns = transform.out_column if isinstance(transform.out_column, list) else [transform.out_column] # type: ignore
regressors_in_column_ids = [i for i, in_column in enumerate(in_columns) if in_column in self.regressors]
new_regressors = [out_columns[i] for i in regressors_in_column_ids]
else:
# Default out_columns
out_columns = list(columns_after - columns_before)
regressors_in_column = [in_column for in_column in in_columns if in_column in self.regressors]
new_regressors = [
out_column
for out_column in out_columns
if np.any([regressor in out_column for regressor in regressors_in_column])
]

else:
raise ValueError("Transform is not FutureMixin and does not have in_column attribute!")

new_regressors = [regressor for regressor in new_regressors if regressor not in self.regressors]
self._regressors.extend(new_regressors)

def __repr__(self):
return self.df.__repr__()

Expand All @@ -243,13 +178,17 @@ def __getitem__(self, item):
df = df.loc[first_valid_idx:]
return df

def make_future(self, future_steps: int, tail_steps: int = 0) -> "TSDataset":
def make_future(
self, future_steps: int, transforms: Sequence["Transform"] = (), tail_steps: int = 0
) -> "TSDataset":
"""Return new TSDataset with future steps.
Parameters
----------
future_steps:
number of timestamp in the future to build features for.
transforms:
sequence of transforms to be applied.
tail_steps:
number of timestamp for context to build features for.
Expand Down Expand Up @@ -307,20 +246,21 @@ def make_future(self, future_steps: int, tail_steps: int = 0) -> "TSDataset":
f"NaN-s will be used for missing values"
)

if self.transforms is not None:
for transform in self.transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
df = transform.transform(df)
# Here only df is required, other metadata is not necessary to build the dataset
ts = TSDataset(df=df, freq=self.freq)
for transform in transforms:
tslogger.log(f"Transform {repr(transform)} is applied to dataset")
transform.transform(ts)
df = ts.to_pandas()

future_dataset = df.tail(future_steps + tail_steps).copy(deep=True)

future_dataset = future_dataset.sort_index(axis=1, level=(0, 1))
future_ts = TSDataset(df=future_dataset, freq=self.freq)

# can't put known_future into constructor, _check_known_future fails with df_exog=None
future_ts.known_future = self.known_future
future_ts._regressors = self.regressors
future_ts.transforms = self.transforms
future_ts.known_future = deepcopy(self.known_future)
future_ts._regressors = deepcopy(self.regressors)
future_ts.df_exog = self.df_exog
return future_ts

Expand All @@ -344,7 +284,6 @@ def tsdataset_idx_slice(self, start_idx: Optional[int] = None, end_idx: Optional
# can't put known_future into constructor, _check_known_future fails with df_exog=None
tsdataset_slice.known_future = deepcopy(self.known_future)
tsdataset_slice._regressors = deepcopy(self.regressors)
tsdataset_slice.transforms = deepcopy(self.transforms)
tsdataset_slice.df_exog = self.df_exog
return tsdataset_slice

Expand Down Expand Up @@ -425,16 +364,14 @@ def _check_endings(self, warning=False):
else:
raise ValueError("All segments should end at the same timestamp")

def inverse_transform(self):
def inverse_transform(self, transforms: Sequence["Transform"]):
"""Apply inverse transform method of transforms to the data.
Applied in reversed order.
"""
# TODO: return regressors after inverse_transform
if self.transforms is not None:
for transform in reversed(self.transforms):
tslogger.log(f"Inverse transform {repr(transform)} is applied to dataset")
self.df = transform.inverse_transform(self.df)
for transform in reversed(transforms):
tslogger.log(f"Inverse transform {repr(transform)} is applied to dataset")
transform.inverse_transform(self)

@property
def segments(self) -> List[str]:
Expand Down
5 changes: 0 additions & 5 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def forecast(self, ts: TSDataset) -> TSDataset:

df = TSDataset.to_dataset(df)
ts.df = df
ts.inverse_transform()
return ts


Expand Down Expand Up @@ -373,7 +372,6 @@ def forecast(

df = TSDataset.to_dataset(df)
ts.df = df
ts.inverse_transform()
return ts


Expand Down Expand Up @@ -429,7 +427,6 @@ def forecast(self, ts: TSDataset) -> TSDataset:
x = ts.to_pandas(flatten=True).drop(["segment"], axis=1)
y = self._base_model.predict(x).reshape(-1, horizon).T
ts.loc[:, pd.IndexSlice[:, "target"]] = y
ts.inverse_transform()
return ts

def get_model(self) -> Any:
Expand Down Expand Up @@ -805,8 +802,6 @@ def forecast(self, ts: "TSDataset", horizon: int) -> "TSDataset":
for (segment, feature_nm), value in predictions.items():
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:horizon, :]

future_ts.inverse_transform()

return future_ts

def get_model(self) -> "DeepBaseNet":
Expand Down
6 changes: 3 additions & 3 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
"""Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform):
return ts.transforms[-1]
# TODO: TSDataset does not have "transform" attribute anymore
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform): # type: ignore
return ts.transforms[-1] # type: ignore
else:
raise ValueError(
"Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
Expand Down Expand Up @@ -236,5 +237,4 @@ def forecast(
df = df.sort_index(axis=1)
ts.df = df

ts.inverse_transform()
return ts
6 changes: 3 additions & 3 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
@staticmethod
def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform:
"""Get PytorchForecastingTransform from ts.transforms or raise exception if not found."""
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform):
return ts.transforms[-1]
# TODO: TSDataset does not have "transform" attribute anymore
if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform): # type: ignore
return ts.transforms[-1] # type: ignore
else:
raise ValueError(
"Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms"
Expand Down Expand Up @@ -268,5 +269,4 @@ def forecast(
df = df.sort_index(axis=1)
ts.df = df

ts.inverse_transform()
return ts
8 changes: 3 additions & 5 deletions etna/pipeline/autoregressive_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def fit(self, ts: TSDataset) -> "AutoRegressivePipeline":
self.ts = ts
ts.fit_transform(self.transforms)
self.model.fit(ts)
self.ts.inverse_transform()
self.ts.inverse_transform(self.transforms)
return self

def _create_predictions_template(self) -> pd.DataFrame:
Expand Down Expand Up @@ -121,8 +121,6 @@ def _forecast(self) -> TSDataset:
df_exog=self.ts.df_exog,
known_future=self.ts.known_future,
)
# manually set transforms in current_ts, otherwise make_future won't know about them
current_ts.transforms = self.transforms
with warnings.catch_warnings():
warnings.filterwarnings(
message="TSDataset freq can't be inferred",
Expand All @@ -132,7 +130,7 @@ def _forecast(self) -> TSDataset:
message="You probably set wrong freq.",
action="ignore",
)
current_ts_forecast = current_ts.make_future(current_step)
current_ts_forecast = current_ts.make_future(current_step, transforms=self.transforms)
current_ts_future = self.model.forecast(current_ts_forecast)
prediction_df = prediction_df.combine_first(current_ts_future.to_pandas()[prediction_df.columns])

Expand All @@ -141,7 +139,7 @@ def _forecast(self) -> TSDataset:
df=prediction_df, freq=self.ts.freq, df_exog=self.ts.df_exog, known_future=self.ts.known_future
)
prediction_ts.transform(self.transforms)
prediction_ts.inverse_transform()
prediction_ts.inverse_transform(self.transforms)
# cut only last timestamps from result dataset
prediction_ts.df = prediction_ts.df.tail(self.horizon)
prediction_ts.raw_df = prediction_ts.raw_df.tail(self.horizon)
Expand Down
11 changes: 7 additions & 4 deletions etna/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def fit(self, ts: TSDataset) -> "Pipeline":
self.ts = ts
self.ts.fit_transform(self.transforms)
self.model.fit(self.ts)
self.ts.inverse_transform()
self.ts.inverse_transform(self.transforms)
return self

def _forecast(self) -> TSDataset:
Expand All @@ -55,10 +55,12 @@ def _forecast(self) -> TSDataset:
raise ValueError("Something went wrong, ts is None!")

if isinstance(self.model, DeepBaseModel):
future = self.ts.make_future(future_steps=self.model.decoder_length, tail_steps=self.model.encoder_length)
future = self.ts.make_future(
future_steps=self.model.decoder_length, transforms=self.transforms, tail_steps=self.model.encoder_length
)
predictions = self.model.forecast(ts=future, horizon=self.horizon)
else:
future = self.ts.make_future(self.horizon)
future = self.ts.make_future(self.horizon, transforms=self.transforms)
predictions = self.model.forecast(ts=future)
return predictions

Expand Down Expand Up @@ -90,10 +92,11 @@ def forecast(
self._validate_backtest_n_folds(n_folds=n_folds)

if prediction_interval and isinstance(self.model, PredictIntervalAbstractModel):
future = self.ts.make_future(self.horizon)
future = self.ts.make_future(self.horizon, transforms=self.transforms)
predictions = self.model.forecast(ts=future, prediction_interval=prediction_interval, quantiles=quantiles)
else:
predictions = super().forecast(
prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds
)
predictions.inverse_transform(self.transforms)
return predictions
2 changes: 0 additions & 2 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from etna.transforms.base import IrreversiblePerSegmentWrapper
from etna.transforms.base import IrreversibleTransform
from etna.transforms.base import NewPerSegmentWrapper
from etna.transforms.base import NewTransform
from etna.transforms.base import OneSegmentTransform
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import ReversiblePerSegmentWrapper
Expand Down
Loading

0 comments on commit c2edede

Please sign in to comment.