From 5f02bded31cbb83701dfe9a639d1cdb8194bc00d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 11:23:08 +0800 Subject: [PATCH 1/3] feat: add ETSformer as an imputation model; --- pypots/imputation/__init__.py | 2 + pypots/imputation/etsformer/__init__.py | 17 + pypots/imputation/etsformer/data.py | 24 ++ pypots/imputation/etsformer/model.py | 324 ++++++++++++++++ .../imputation/etsformer/modules/__init__.py | 6 + pypots/imputation/etsformer/modules/core.py | 101 +++++ .../etsformer/modules/submodules.py | 354 ++++++++++++++++++ tests/imputation/etsformer.py | 130 +++++++ 8 files changed, 958 insertions(+) create mode 100644 pypots/imputation/etsformer/__init__.py create mode 100644 pypots/imputation/etsformer/data.py create mode 100644 pypots/imputation/etsformer/model.py create mode 100644 pypots/imputation/etsformer/modules/__init__.py create mode 100644 pypots/imputation/etsformer/modules/core.py create mode 100644 pypots/imputation/etsformer/modules/submodules.py create mode 100644 tests/imputation/etsformer.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 2d408d58..f1c4d381 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -13,6 +13,7 @@ from .saits import SAITS from .transformer import Transformer from .timesnet import TimesNet +from .etsformer import ETSformer from .autoformer import Autoformer from .dlinear import DLinear from .patchtst import PatchTST @@ -27,6 +28,7 @@ # neural network imputation methods "SAITS", "Transformer", + "ETSformer", "TimesNet", "PatchTST", "DLinear", diff --git a/pypots/imputation/etsformer/__init__.py b/pypots/imputation/etsformer/__init__.py new file mode 100644 index 00000000..1e5c8417 --- /dev/null +++ b/pypots/imputation/etsformer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model ETSformer. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +ETSformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import ETSformer + +__all__ = [ + "ETSformer", +] diff --git a/pypots/imputation/etsformer/data.py b/pypots/imputation/etsformer/data.py new file mode 100644 index 00000000..f03a4e61 --- /dev/null +++ b/pypots/imputation/etsformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for ETSformer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForETSformer(DatasetForSAITS): + """Actually ETSformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py new file mode 100644 index 00000000..50281ba2 --- /dev/null +++ b/pypots/imputation/etsformer/model.py @@ -0,0 +1,324 @@ +""" +The implementation of ETSformer for the partially-observed time-series imputation task. + +Refer to the paper "Woo, G., Liu, C., Sahoo, D., Kumar, A., & Hoi, S. (2023). +ETSformer: Exponential Smoothing Transformers for Time-series Forecasting. ICLR 2023.". 
+ +Notes +----- +Partial implementation uses code from https://github.com/salesforce/ETSformer + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForETSformer +from .modules.core import _ETSformer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class ETSformer(BaseNNImputer): + """The PyTorch implementation of the ETSformer model. + ETSformer is originally proposed by Woo et al. in :cite:`woo2023etsformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_e_layers : + The number of layers in the ETSformer encoder. + + n_d_layers : + The number of layers in the ETSformer decoder. + + n_heads : + The number of heads in each layer of ETSformer. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + + dropout : + The dropout rate for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + References + ---------- + .. [1] `Woo, Gerald, Chenghao Liu, Doyen Sahoo, Akshat Kumar, and Steven Hoi. + "ETSformer: Exponential Smoothing Transformers for Time-series Forecasting ". + ICLR 2023. 
+ `_ + + """ + + def __init__( + self, + n_steps, + n_features, + n_e_layers, + n_d_layers, + n_heads, + d_model, + d_ffn, + top_k, + dropout: float = 0, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_e_layers = n_e_layers + self.n_d_layers = n_d_layers + self.d_model = d_model + self.d_ffn = d_ffn + self.dropout = dropout + self.top_k = top_k + + # set up the model + self.model = _ETSformer( + self.n_steps, + self.n_features, + self.n_e_layers, + self.n_d_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.dropout, + self.top_k, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForETSformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForETSformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). 
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+            which is the time-series data to be imputed and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+        file_type : str
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict : dict,
+            The dictionary containing the imputation results.
+
+        """
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        self.model.eval()  # set the model as eval status to freeze it.
+        test_set = BaseDataset(
+            test_set, return_X_ori=False, return_labels=False, file_type=file_type
+        )
+        test_loader = DataLoader(
+            test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+        imputation_collector = []
+
+        # Step 2: process the data with the model
+        with torch.no_grad():
+            for idx, data in enumerate(test_loader):
+                inputs = self._assemble_input_for_testing(data)
+                results = self.model.forward(inputs, training=False)
+                imputation_collector.append(results["imputed_data"])
+
+        # Step 3: output collection and return
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        result_dict = {
+            "imputation": imputation,
+        }
+        return result_dict
+
+    def impute(
+        self,
+        X: Union[dict, str],
+        file_type="h5py",
+    ) -> np.ndarray:
+        """Impute missing values in the given data with the trained model.
+
+        Warnings
+        --------
+        The method impute is deprecated. Please use `predict()` instead.
+
+        Parameters
+        ----------
+        X :
+            The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if X is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (time steps), n_features],
+            Imputed data.
+        """
+        logger.warning(
+            "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
+ ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/etsformer/modules/__init__.py b/pypots/imputation/etsformer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/etsformer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py new file mode 100644 index 00000000..13f692fc --- /dev/null +++ b/pypots/imputation/etsformer/modules/core.py @@ -0,0 +1,101 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from .submodules import ( + Transform, + ETSformerEncoderLayer, + ETSformerEncoder, + ETSformerDecoderLayer, + ETSformerDecoder, +) +from ...timesnet.modules.embedding import DataEmbedding +from ....utils.metrics import calc_mse + + +class _ETSformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_e_layers, + n_d_layers, + n_heads, + d_model, + d_ffn, + dropout, + top_k, + activation="sigmoid", + ): + super().__init__() + + self.n_steps = n_steps + + self.enc_embedding = DataEmbedding( + n_features, + d_model, + dropout=dropout, + ) + + # Encoder + self.encoder = ETSformerEncoder( + [ + ETSformerEncoderLayer( + d_model, + n_heads, + n_features, + n_steps, + n_steps, + top_k, + dim_feedforward=d_ffn, + dropout=dropout, + activation=activation, + ) + for _ in range(n_e_layers) + ] + ) + # Decoder + self.decoder = ETSformerDecoder( + [ + ETSformerDecoderLayer( + d_model, + n_heads, + n_features, + n_steps, + dropout=dropout, + ) + for _ in range(n_d_layers) + ], + ) + self.transform = Transform(sigma=0.2) + + # for the imputation task, the output dim is the same as input dim + self.projection = nn.Linear(d_model, n_features) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + res = self.enc_embedding(X) + + # ETSformer encoder processing + level, growths, seasons = self.encoder(res, X, attn_mask=None) + growth, season = self.decoder(growths, seasons) + output = level[:, -1:] + growth + season + + imputed_data = masks * X + (1 - masks) * output + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/etsformer/modules/submodules.py b/pypots/imputation/etsformer/modules/submodules.py new file mode 100644 index 00000000..d1a1c7bb --- /dev/null +++ b/pypots/imputation/etsformer/modules/submodules.py @@ -0,0 +1,354 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import math + +import torch +import torch.fft as fft +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from scipy.fftpack import next_fast_len + + +class Transform: + def __init__(self, sigma): + self.sigma = sigma + + @torch.no_grad() + def transform(self, x): + return self.jitter(self.shift(self.scale(x))) + + def jitter(self, x): + return x + (torch.randn(x.shape).to(x.device) * self.sigma) + + def scale(self, x): + return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1) + + def shift(self, x): + return x + (torch.randn(x.size(-1)).to(x.device) * self.sigma) + + +def conv1d_fft(f, g, dim=-1): + N = f.size(dim) + M = 
g.size(dim) + + fast_len = next_fast_len(N + M - 1) + + F_f = fft.rfft(f, fast_len, dim=dim) + F_g = fft.rfft(g, fast_len, dim=dim) + + F_fg = F_f * F_g.conj() + out = fft.irfft(F_fg, fast_len, dim=dim) + out = out.roll((-1,), dims=(dim,)) + idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) + out = out.index_select(dim, idx) + + return out + + +class ExponentialSmoothing(nn.Module): + def __init__(self, dim, nhead, dropout=0.1, aux=False): + super().__init__() + self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1)) + self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim)) + self.dropout = nn.Dropout(dropout) + if aux: + self.aux_dropout = nn.Dropout(dropout) + + def forward(self, values, aux_values=None): + b, t, h, d = values.shape + + init_weight, weight = self.get_exponential_weight(t) + output = conv1d_fft(self.dropout(values), weight, dim=1) + output = init_weight * self.v0 + output + + if aux_values is not None: + aux_weight = weight / (1 - self.weight) * self.weight + aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight) + output = output + aux_output + + return output + + def get_exponential_weight(self, T): + # Generate array [0, 1, ..., T-1] + powers = torch.arange(T, dtype=torch.float, device=self.weight.device) + + # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0] + weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,))) + + # \alpha^t for all t = 1, 2, ..., T + init_weight = self.weight ** (powers + 1) + + return rearrange(init_weight, "h t -> 1 t h 1"), rearrange( + weight, "h t -> 1 t h 1" + ) + + @property + def weight(self): + return torch.sigmoid(self._smoothing_weight) + + +class Feedforward(nn.Module): + def __init__(self, d_model, dim_feedforward, dropout=0.1, activation="sigmoid"): + # Implementation of Feedforward model + super().__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) + self.dropout1 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) + self.dropout2 = nn.Dropout(dropout) + self.activation = getattr(F, activation) + + def forward(self, x): + x = self.linear2(self.dropout1(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class GrowthLayer(nn.Module): + def __init__(self, d_model, nhead, d_head=None, dropout=0.1): + super().__init__() + self.d_head = d_head or (d_model // nhead) + self.d_model = d_model + self.nhead = nhead + + self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) + self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) + self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout) + self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model) + + assert ( + self.d_head * self.nhead == self.d_model + ), "d_model must be divisible by nhead" + + def forward(self, inputs): + """ + :param inputs: shape: (batch, seq_len, dim) + :return: shape: (batch, seq_len, dim) + """ + b, t, d = inputs.shape + values = self.in_proj(inputs).view(b, t, self.nhead, -1) + values = torch.cat([repeat(self.z0, "h d -> b 1 h d", b=b), values], dim=1) + values = values[:, 1:] - values[:, :-1] + out = self.es(values) + out = torch.cat([repeat(self.es.v0, "1 1 h d -> b 1 h d", b=b), out], dim=1) + out = rearrange(out, "b t h d -> b t (h d)") + return self.out_proj(out) + + +class FourierLayer(nn.Module): + def __init__(self, d_model, pred_len, k=None, low_freq=1): + super().__init__() + self.d_model = d_model + self.pred_len = pred_len + self.k = k + self.low_freq = low_freq + + def forward(self, 
x): + """x: (b, t, d)""" + b, t, d = x.shape + x_freq = fft.rfft(x, dim=1) + + if t % 2 == 0: + x_freq = x_freq[:, self.low_freq : -1] + f = fft.rfftfreq(t)[self.low_freq : -1] + else: + x_freq = x_freq[:, self.low_freq :] + f = fft.rfftfreq(t)[self.low_freq :] + + x_freq, index_tuple = self.topk_freq(x_freq) + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)) + f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) + + return self.extrapolate(x_freq, f, t) + + def extrapolate(self, x_freq, f, t): + x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) + f = torch.cat([f, -f], dim=1) + t_val = rearrange( + torch.arange(t + self.pred_len, dtype=torch.float), "t -> () () t ()" + ).to(x_freq.device) + + amp = rearrange(x_freq.abs() / t, "b f d -> b f () d") + phase = rearrange(x_freq.angle(), "b f d -> b f () d") + + x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) + + return reduce(x_time, "b f t d -> b t d", "sum") + + def topk_freq(self, x_freq): + values, indices = torch.topk( + x_freq.abs(), self.k, dim=1, largest=True, sorted=True + ) + mesh_a, mesh_b = torch.meshgrid( + torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)) + ) + index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) + x_freq = x_freq[index_tuple] + + return x_freq, index_tuple + + +class LevelLayer(nn.Module): + def __init__(self, d_model, c_out, dropout=0.1): + super().__init__() + self.d_model = d_model + self.c_out = c_out + + self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True) + self.growth_pred = nn.Linear(self.d_model, self.c_out) + self.season_pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, level, growth, season): + b, t, _ = level.shape + growth = self.growth_pred(growth).view(b, t, self.c_out, 1) + season = self.season_pred(season).view(b, t, self.c_out, 1) + growth = growth.view(b, t, self.c_out, 1) + season = season.view(b, t, self.c_out, 1) + level = level.view(b, t, self.c_out, 1) + out = self.es(level - season, aux_values=growth) + out = rearrange(out, "b t h d -> b t (h d)") + return out + + +class ETSformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + n_heads, + c_out, + seq_len, + pred_len, + k, + dim_feedforward=None, + dropout=0.1, + activation="sigmoid", + layer_norm_eps=1e-5, + ): + super().__init__() + self.d_model = d_model + self.nhead = n_heads + self.c_out = c_out + self.seq_len = seq_len + self.pred_len = pred_len + dim_feedforward = dim_feedforward or 4 * d_model + self.dim_feedforward = dim_feedforward + + self.growth_layer = GrowthLayer(d_model, n_heads, dropout=dropout) + self.seasonal_layer = FourierLayer(d_model, pred_len, k=k) + self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) + + # Implementation of Feedforward model + self.ff = Feedforward( + d_model, dim_feedforward, dropout=dropout, activation=activation + ) + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, res, level, attn_mask=None): + season = self._season_block(res) + res = res - season[:, : -self.pred_len] + growth = self._growth_block(res) + res = self.norm1(res - growth[:, 1:]) + res = self.norm2(res + self.ff(res)) + + level = self.level_layer(level, growth[:, :-1], season[:, : -self.pred_len]) + return res, level, growth, season + + def _growth_block(self, x): + x = self.growth_layer(x) + return self.dropout1(x) + + def _season_block(self, x): + 
x = self.seasonal_layer(x) + return self.dropout2(x) + + +class ETSformerEncoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.layers = nn.ModuleList(layers) + + def forward(self, res, level, attn_mask=None): + growths = [] + seasons = [] + for layer in self.layers: + res, level, growth, season = layer(res, level, attn_mask=None) + growths.append(growth) + seasons.append(season) + + return level, growths, seasons + + +class DampingLayer(nn.Module): + def __init__(self, pred_len, nhead, dropout=0.1): + super().__init__() + self.pred_len = pred_len + self.nhead = nhead + self._damping_factor = nn.Parameter(torch.randn(1, nhead)) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = repeat(x, "b 1 d -> b t d", t=self.pred_len) + b, t, d = x.shape + + powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1 + powers = powers.view(self.pred_len, 1) + damping_factors = self.damping_factor**powers + damping_factors = damping_factors.cumsum(dim=0) + x = x.view(b, t, self.nhead, -1) + x = self.dropout(x) * damping_factors.unsqueeze(-1) + return x.view(b, t, d) + + @property + def damping_factor(self): + return torch.sigmoid(self._damping_factor) + + +class ETSformerDecoderLayer(nn.Module): + def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.c_out = c_out + self.pred_len = pred_len + + self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + + def forward(self, growth, season): + growth_horizon = self.growth_damping(growth[:, -1:]) + growth_horizon = self.dropout1(growth_horizon) + + seasonal_horizon = season[:, -self.pred_len :] + return growth_horizon, seasonal_horizon + + +class ETSformerDecoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.d_model = layers[0].d_model + self.c_out = layers[0].c_out + self.pred_len = layers[0].pred_len + self.nhead = layers[0].nhead + + self.layers = nn.ModuleList(layers) + self.pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, growths, seasons): + growth_repr = [] + season_repr = [] + + for idx, layer in enumerate(self.layers): + growth_horizon, season_horizon = layer(growths[idx], seasons[idx]) + growth_repr.append(growth_horizon) + season_repr.append(season_horizon) + growth_repr = sum(growth_repr) + season_repr = sum(season_repr) + return self.pred(growth_repr), self.pred(season_repr) diff --git a/tests/imputation/etsformer.py b/tests/imputation/etsformer.py new file mode 100644 index 00000000..c098b79f --- /dev/null +++ b/tests/imputation/etsformer.py @@ -0,0 +1,130 @@ +""" +Test cases for ETSformer imputation model. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import ETSformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestETSformer(unittest.TestCase): + logger.info("Running tests for an imputation model ETSformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "ETSformer") + model_save_name = "saved_etsformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a ETSformer model + etsformer = ETSformer( + DATA["n_steps"], + DATA["n_features"], + n_e_layers=2, + n_d_layers=2, + n_heads=2, + d_model=128, + d_ffn=256, + top_k=3, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_0_fit(self): + self.etsformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_1_impute(self): + imputation_results = self.etsformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"ETSformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_2_parameters(self): + assert hasattr(self.etsformer, "model") and self.etsformer.model is not None + + assert ( + hasattr(self.etsformer, "optimizer") + and self.etsformer.optimizer is not None + ) + + assert hasattr(self.etsformer, "best_loss") + self.assertNotEqual(self.etsformer.best_loss, float("inf")) + + assert ( + hasattr(self.etsformer, "best_model_dict") + and self.etsformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.etsformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.etsformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.etsformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_4_lazy_loading(self): + self.etsformer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.etsformer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
+ + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading ETSformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 5e98d3b2f8cb6ebf7da8b017c51f689d85f9afac Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 13:56:56 +0800 Subject: [PATCH 2/3] refactor: remove unused modules in ETSformer; --- pypots/imputation/etsformer/model.py | 2 ++ pypots/imputation/etsformer/modules/core.py | 6 +----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index 50281ba2..2e50fb79 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -56,6 +56,8 @@ class ETSformer(BaseNNImputer): d_ffn : The dimension of the feed-forward network. + top_k : + Top-K Fourier bases. dropout : The dropout rate for the model. diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py index 13f692fc..81009dd2 100644 --- a/pypots/imputation/etsformer/modules/core.py +++ b/pypots/imputation/etsformer/modules/core.py @@ -8,7 +8,6 @@ import torch.nn as nn from .submodules import ( - Transform, ETSformerEncoderLayer, ETSformerEncoder, ETSformerDecoderLayer, @@ -72,10 +71,7 @@ def __init__( for _ in range(n_d_layers) ], ) - self.transform = Transform(sigma=0.2) - - # for the imputation task, the output dim is the same as input dim - self.projection = nn.Linear(d_model, n_features) + # self.transform = Transform(sigma=0.2) # for forecasting def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] From b26d4074c9d559637386bd4ac8188dd6bfd1ef61 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 13:59:48 +0800 Subject: [PATCH 3/3] depen: add einops as an dependency; --- environment-dev.yml | 1 + requirements.txt | 1 + setup.cfg | 3 ++- setup.py | 1 + tests/environment_for_conda_test.yml | 3 ++- 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index 2b3bea45..d3a47627 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -13,6 +13,7 @@ dependencies: - conda-forge::numpy - conda-forge::scipy - conda-forge::python + - conda-forge::einops - conda-forge::pandas - conda-forge::matplotlib - conda-forge::tensorboard diff --git a/requirements.txt b/requirements.txt index 13a82508..cfbab72f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ h5py numpy scipy +einops pandas matplotlib tensorboard diff --git a/setup.cfg b/setup.cfg index bbc3f761..a36169c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,8 @@ basic = numpy scikit-learn matplotlib - pandas<2.0.0 + einops + pandas torch>=1.10.0 tensorboard scipy diff --git a/setup.py b/setup.py index 4630e69c..92011eeb 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "h5py", "numpy", "scipy", + "einops", "pandas", "matplotlib", "tensorboard", diff --git a/tests/environment_for_conda_test.yml b/tests/environment_for_conda_test.yml index 2cb128d9..54f07925 100644 --- a/tests/environment_for_conda_test.yml +++ b/tests/environment_for_conda_test.yml @@ -13,7 +13,8 @@ dependencies: - conda-forge::scipy - conda-forge::numpy - conda-forge::scikit-learn - - conda-forge::pandas <2.0.0 + - conda-forge::einops + - conda-forge::pandas - conda-forge::h5py - conda-forge::tensorboard - conda-forge::pygrinder >=0.4
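
A minimal usage sketch of the ETSformer imputer added by this patch series, distilled from the test case in tests/imputation/etsformer.py. The toy data and hyperparameter values below are illustrative only, and it is assumed (per PyPOTS convention) that missing values are marked as NaN in the input arrays:

    import numpy as np
    from pypots.imputation import ETSformer
    from pypots.optim import Adam

    # Toy partially-observed data: 48 samples, 24 time steps, 5 features, ~20% missing.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((48, 24, 5)).astype(np.float32)
    X[rng.random(X.shape) < 0.2] = np.nan

    model = ETSformer(
        n_steps=24,
        n_features=5,
        n_e_layers=2,
        n_d_layers=2,
        n_heads=2,
        d_model=128,
        d_ffn=256,
        top_k=3,          # number of Fourier bases kept by the seasonal block
        dropout=0,
        epochs=10,
        optimizer=Adam(lr=1e-3),
    )
    model.fit({"X": X})                  # MIT-style training on artificially-masked values
    results = model.predict({"X": X})    # returns a dict with key "imputation"
    imputed = results["imputation"]      # shape: [n_samples, n_steps, n_features]
    assert not np.isnan(imputed).any()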