Skip to content

Commit

Permalink
MLPModel (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
DBcreator committed Aug 25, 2022
1 parent 74096ea commit be73043
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
-
- Function to transform etna objects to dict([#818](https://github.com/tinkoff-ai/etna/issues/818))
-
- `MLPModel`([#860](https://github.com/tinkoff-ai/etna/pull/860))
- `DeadlineMovingAverageModel` ([#827](https://github.com/tinkoff-ai/etna/pull/827))
- `DirectEnsemble` ([#824](https://github.com/tinkoff-ai/etna/pull/824))
- CICD: untaged docker image cleaner ([#856](https://github.com/tinkoff-ai/etna/pull/856))
Expand Down
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

if SETTINGS.torch_required:
from etna.models.nn.deepar import DeepARModel
from etna.models.nn.mlp import MLPModel
from etna.models.nn.rnn import RNNModel
from etna.models.nn.tft import TFTModel
220 changes: 220 additions & 0 deletions etna/models/nn/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional

import pandas as pd
from typing_extensions import TypedDict

from etna import SETTINGS

if SETTINGS.torch_required:
import torch
import torch.nn as nn

import numpy as np

from etna.models.base import DeepBaseModel
from etna.models.base import DeepBaseNet


class MLPBatch(TypedDict):
"""Batch specification for MLP."""

decoder_real: "torch.Tensor"
decoder_target: "torch.Tensor"
segment: "torch.Tensor"


class MLPNet(DeepBaseNet):
"""MLP model."""

def __init__(
self,
input_size: int,
hidden_size: List[int],
lr: float,
loss: "torch.nn.Module",
optimizer_params: Optional[dict],
) -> None:
"""Init MLP model.
Parameters
----------
input_size:
size of the input feature space: target plus extra features
hidden_size:
list of sizes of the hidden states
lr:
learning rate
loss:
loss function
optimizer_params:
parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
"""
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.lr = lr
self.loss = nn.MSELoss() if loss is None else loss
self.optimizer_params = {} if optimizer_params is None else optimizer_params
layers = [nn.Linear(in_features=input_size, out_features=hidden_size[0]), nn.ReLU()]
for i in range(1, len(hidden_size)):
layers.append(nn.Linear(in_features=hidden_size[i - 1], out_features=hidden_size[i]))
layers.append(nn.ReLU())
layers.append(nn.Linear(in_features=hidden_size[-1], out_features=1))
self.mlp = nn.Sequential(*layers)

def forward(self, batch: MLPBatch): # type: ignore
"""Forward pass.
Parameters
----------
batch:
batch of data
Returns
-------
:
forecast
"""
decoder_real = batch["decoder_real"].float()
return self.mlp(decoder_real)

def step(self, batch: MLPBatch, *args, **kwargs): # type: ignore
"""Step for loss computation for training or validation.
Parameters
----------
batch:
batch of data
Returns
-------
:
loss, true_target, prediction_target
"""
decoder_real = batch["decoder_real"].float()
decoder_target = batch["decoder_target"].float()

output = self.mlp(decoder_real)
loss = self.loss(output, decoder_target)
return loss, decoder_target, output

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterable[dict]:
"""Make samples from segment DataFrame."""

def _make(df: pd.DataFrame, start_idx: int, decoder_length: int) -> Optional[dict]:
sample: Dict[str, Any] = {"decoder_real": list(), "decoder_target": list(), "segment": None}
total_length = len(df["target"])
total_sample_length = decoder_length

if total_sample_length + start_idx > total_length:
return None

sample["decoder_real"] = (
df.select_dtypes(include=[np.number])
.pipe(lambda x: x[[i for i in x.columns if i != "target"]])
.values[start_idx : start_idx + decoder_length]
)

target = df["target"].values[start_idx : start_idx + decoder_length].reshape(-1, 1)
sample["decoder_target"] = target
sample["segment"] = df["segment"].values[0]
return sample

start_idx = 0
while True:
batch = _make(
df=df,
start_idx=start_idx,
decoder_length=decoder_length,
)
if batch is None:
break
yield batch
start_idx += decoder_length
if start_idx < len(df):
resid_length = len(df) - decoder_length
batch = _make(df=df, start_idx=resid_length, decoder_length=decoder_length)
if batch is not None:
yield batch

def configure_optimizers(self):
"""Optimizer configuration."""
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, **self.optimizer_params)
return optimizer


class MLPModel(DeepBaseModel):
"""MLPModel."""

def __init__(
self,
input_size: int,
decoder_length: int,
hidden_size: List,
encoder_length: int = 0,
lr: float = 1e-3,
loss: Optional["torch.nn.Module"] = None,
train_batch_size: int = 16,
test_batch_size: int = 16,
optimizer_params: Optional[dict] = None,
trainer_params: Optional[dict] = None,
train_dataloader_params: Optional[dict] = None,
test_dataloader_params: Optional[dict] = None,
val_dataloader_params: Optional[dict] = None,
split_params: Optional[dict] = None,
):
super().__init__(
net=MLPNet(
input_size=input_size,
hidden_size=hidden_size,
lr=lr,
loss=loss, # type: ignore
optimizer_params=optimizer_params,
),
encoder_length=encoder_length,
decoder_length=decoder_length,
train_batch_size=train_batch_size,
test_batch_size=test_batch_size,
train_dataloader_params=train_dataloader_params,
test_dataloader_params=test_dataloader_params,
val_dataloader_params=val_dataloader_params,
trainer_params=trainer_params,
split_params=split_params,
)
"""Init MLP model.
Parameters
----------
input_size:
size of the input feature space: target plus extra features
decoder_length:
decoder length
hidden_size:
List of sizes of the hidden states
encoder_length:
encoder length
lr:
learning rate
loss:
loss function, MSELoss by default
train_batch_size:
batch size for training
test_batch_size:
batch size for testing
optimizer_params:
parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
trainer_params:
Pytorch ligthning trainer parameters (api reference :py:class:`pytorch_lightning.trainer.trainer.Trainer`)
train_dataloader_params:
parameters for train dataloader like sampler for example (api reference :py:class:`torch.utils.data.DataLoader`)
test_dataloader_params:
parameters for test dataloader
val_dataloader_params:
parameters for validation dataloader
split_params:
dictionary with parameters for :py:func:`torch.utils.data.random_split` for train-test splitting
* **train_size**: (*float*) value from 0 to 1 - fraction of samples to use for training
* **generator**: (*Optional[torch.Generator]*) - generator for reproducibile train-test splitting
* **torch_dataset_size**: (*Optional[int]*) - number of samples in dataset, in case of dataset not implementing ``__len__``
"""
96 changes: 96 additions & 0 deletions tests/test_models/nn/test_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
from torch import nn

from etna.datasets.tsdataset import TSDataset
from etna.metrics import MAE
from etna.models.nn import MLPModel
from etna.models.nn.mlp import MLPNet
from etna.transforms import FourierTransform
from etna.transforms import LagTransform
from etna.transforms import StandardScalerTransform


@pytest.mark.parametrize("horizon", [8, 13])
def test_mlp_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_function_with_horizon, horizon):

ts_train, ts_test = ts_dataset_weekly_function_with_horizon(horizon)
lag = LagTransform(in_column="target", lags=list(range(horizon, horizon + 4)))
fourier = FourierTransform(period=7, order=3)
std = StandardScalerTransform(in_column="target")
ts_train.fit_transform([std, lag, fourier])

decoder_length = 14
model = MLPModel(
input_size=10,
hidden_size=[10, 10, 10, 10, 10],
lr=1e-1,
decoder_length=decoder_length,
trainer_params=dict(max_epochs=100),
)
future = ts_train.make_future(decoder_length)
model.fit(ts_train)
future = model.forecast(future, horizon=horizon)

mae = MAE("macro")
assert mae(ts_test, future) < 0.05


def test_mlp_make_samples(simple_df_relevance):
mlp_module = MagicMock()
df, df_exog = simple_df_relevance

ts = TSDataset(df=df, df_exog=df_exog, freq="D")
df = ts.to_flatten(ts.df)
encoder_length = 0
decoder_length = 5
ts_samples = list(
MLPNet.make_samples(
mlp_module, df=df[df.segment == "1"], encoder_length=encoder_length, decoder_length=decoder_length
)
)
first_sample = ts_samples[0]
second_sample = ts_samples[1]
last_sample = ts_samples[-1]
expected = {
"decoder_real": np.array([[58.0, 0], [59.0, 0], [60.0, 0], [61.0, 0], [62.0, 0]]),
"decoder_target": np.array([[27.0], [28.0], [29.0], [30.0], [31.0]]),
"segment": "1",
}

assert first_sample["segment"] == "1"
assert first_sample["decoder_real"].shape == (decoder_length, 2)
assert first_sample["decoder_target"].shape == (decoder_length, 1)
assert len(ts_samples) == 7
assert np.all(last_sample["decoder_target"] == expected["decoder_target"])
assert np.all(last_sample["decoder_real"] == expected["decoder_real"])
assert last_sample["segment"] == expected["segment"]
np.testing.assert_equal(df[["target"]].iloc[:decoder_length], first_sample["decoder_target"])
np.testing.assert_equal(df[["target"]].iloc[decoder_length : 2 * decoder_length], second_sample["decoder_target"])


def test_mlp_step():

batch = {
"decoder_real": torch.Tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]),
"decoder_target": torch.Tensor([[1], [2], [3]]),
"segment": "A",
}
model = MLPNet(input_size=3, hidden_size=[1], lr=1e-2, loss=nn.MSELoss(), optimizer_params=None)
loss, decoder_target, output = model.step(batch)
assert type(loss) == torch.Tensor
assert type(decoder_target) == torch.Tensor
assert torch.all(decoder_target == batch["decoder_target"])
assert type(output) == torch.Tensor
assert output.shape == torch.Size([3, 1])


def test_mlp_layers():
model = MLPNet(input_size=3, hidden_size=[10], lr=1e-2, loss=None, optimizer_params=None)
model_ = nn.Sequential(
nn.Linear(in_features=3, out_features=10), nn.ReLU(), nn.Linear(in_features=10, out_features=1)
)
assert repr(model_) == repr(model.mlp)

1 comment on commit be73043

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.