Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ETSformer as an imputation model #328

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies:
- conda-forge::numpy
- conda-forge::scipy
- conda-forge::python
- conda-forge::einops
- conda-forge::pandas
- conda-forge::matplotlib
- conda-forge::tensorboard
Expand Down
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .saits import SAITS
from .transformer import Transformer
from .timesnet import TimesNet
from .etsformer import ETSformer
from .autoformer import Autoformer
from .dlinear import DLinear
from .patchtst import PatchTST
Expand All @@ -27,6 +28,7 @@
# neural network imputation methods
"SAITS",
"Transformer",
"ETSformer",
"TimesNet",
"PatchTST",
"DLinear",
Expand Down
17 changes: 17 additions & 0 deletions pypots/imputation/etsformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
The package of the partially-observed time-series imputation model ETSformer.

Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
ETSformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import ETSformer

__all__ = [
"ETSformer",
]
24 changes: 24 additions & 0 deletions pypots/imputation/etsformer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for ETSformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForETSformer(DatasetForSAITS):
"""Actually ETSformer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_labels, file_type, rate)
326 changes: 326 additions & 0 deletions pypots/imputation/etsformer/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
"""
The implementation of ETSformer for the partially-observed time-series imputation task.

Refer to the paper "Woo, G., Liu, C., Sahoo, D., Kumar, A., & Hoi, S. (2023).
ETSformer: Exponential Smoothing Transformers for Time-series Forecasting. ICLR 2023.".

Notes
-----
Partial implementation uses code from https://github.com/salesforce/ETSformer

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .data import DatasetForETSformer
from .modules.core import _ETSformer
from ..base import BaseNNImputer
from ...data.base import BaseDataset
from ...data.checking import check_X_ori_in_val_set
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class ETSformer(BaseNNImputer):
"""The PyTorch implementation of the ETSformer model.
ETSformer is originally proposed by Woo et al. in :cite:`woo2023etsformer`.

Parameters
----------
n_steps :
The number of time steps in the time-series data sample.

n_features :
The number of features in the time-series data sample.

n_e_layers :
The number of layers in the ETSformer encoder.

n_d_layers :
The number of layers in the ETSformer decoder.

n_heads :
The number of heads in each layer of ETSformer.

d_model :
The dimension of the model.

d_ffn :
The dimension of the feed-forward network.

top_k :
Top-K Fourier bases.

dropout :
The dropout rate for the model.

batch_size :
The batch size for training and evaluating the model.

epochs :
The number of epochs for training the model.

patience :
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.

optimizer :
The optimizer for model training.
If not given, will use a default Adam optimizer.

num_workers :
The number of subprocesses to use for data loading.
`0` means data loading will be in the main process, i.e. there won't be subprocesses.

device :
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

saving_path :
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
training into a tensorboard file). Will not save if not given.

model_saving_strategy :
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
No model will be saved when it is set as None.
The "best" strategy will only automatically save the best model after the training finished.
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
The "all" strategy will save every model after each epoch training.

References
----------
.. [1] `Woo, Gerald, Chenghao Liu, Doyen Sahoo, Akshat Kumar, and Steven Hoi.
"ETSformer: Exponential Smoothing Transformers for Time-series Forecasting ".
ICLR 2023.
<https://openreview.net/pdf?id=5m_3whfo483>`_

"""

def __init__(
self,
n_steps,
n_features,
n_e_layers,
n_d_layers,
n_heads,
d_model,
d_ffn,
top_k,
dropout: float = 0,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
):
super().__init__(
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
)

self.n_steps = n_steps
self.n_features = n_features
# model hype-parameters
self.n_heads = n_heads
self.n_e_layers = n_e_layers
self.n_d_layers = n_d_layers
self.d_model = d_model
self.d_ffn = d_ffn
self.dropout = dropout
self.top_k = top_k

# set up the model
self.model = _ETSformer(
self.n_steps,
self.n_features,
self.n_e_layers,
self.n_d_layers,
self.n_heads,
self.d_model,
self.d_ffn,
self.dropout,
self.top_k,
)
self._send_model_to_given_device()
self._print_model_size()

# set up the optimizer
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: list) -> dict:
(
indices,
X,
missing_mask,
X_ori,
indicating_mask,
) = self._send_data_to_given_device(data)

inputs = {
"X": X,
"missing_mask": missing_mask,
"X_ori": X_ori,
"indicating_mask": indicating_mask,
}

return inputs

def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: list) -> dict:
indices, X, missing_mask = self._send_data_to_given_device(data)

inputs = {
"X": X,
"missing_mask": missing_mask,
}

return inputs

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
training_set = DatasetForETSformer(
train_set, return_X_ori=False, return_labels=False, file_type=file_type
)
training_loader = DataLoader(
training_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
val_loader = None
if val_set is not None:
if not check_X_ori_in_val_set(val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForETSformer(
val_set, return_X_ori=True, return_labels=False, file_type=file_type
)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
"""Make predictions for the input data with the trained model.

Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

file_type : str
The type of the given file if test_set is a path string.

Returns
-------
result_dict : dict,
The dictionary containing the clustering results and latent variables if necessary.

"""
# Step 1: wrap the input data with classes Dataset and DataLoader
self.model.eval() # set the model as eval status to freeze it.
test_set = BaseDataset(
test_set, return_X_ori=False, return_labels=False, file_type=file_type
)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
imputation_collector = []

# Step 2: process the data with the model
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
imputation = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation,
}
return result_dict

def impute(
self,
X: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.

Warnings
--------
The method impute is deprecated. Please use `predict()` instead.

Parameters
----------
X :
The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
The type of the given file if X is a path string.

Returns
-------
array-like, shape [n_samples, sequence length (time steps), n_features],
Imputed data.
"""
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)

results_dict = self.predict(X, file_type=file_type)
return results_dict["imputation"]
Loading
Loading