Implement iTransformer as an imputation model #371

Merged 1 commit on Apr 29, 2024
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
@@ -19,6 +19,15 @@ pypots.imputation.transformer
    :show-inheritance:
    :inherited-members:

pypots.imputation.itransformer
------------------------------

.. automodule:: pypots.imputation.itransformer
    :members:
    :undoc-members:
    :show-inheritance:
    :inherited-members:

pypots.imputation.frets
------------------------------

10 changes: 9 additions & 1 deletion docs/references.bib
@@ -550,7 +550,7 @@ @inproceedings{zhou2022film
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {12677--12690},
publisher = {Curran Associates, Inc.},
title = {FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting},
title = {{FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting}},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/524ef58c2bd075775861234266e5e020-Paper-Conference.pdf},
volume = {35},
year = {2022}
@@ -566,4 +566,12 @@ @inproceedings{yi2023frets
url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf},
volume = {36},
year = {2023}
}

@inproceedings{liu2024itransformer,
author = {Yong Liu and Tengge Hu and Haoran Zhang and Haixu Wu and Shiyu Wang and Lintao Ma and Mingsheng Long},
booktitle = {The Twelfth International Conference on Learning Representations},
title = {{iTransformer: Inverted Transformers Are Effective for Time Series Forecasting}},
url = {https://openreview.net/forum?id=JePfAI8fah},
year = {2024}
}
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -12,6 +12,7 @@
from .mrnn import MRNN
from .saits import SAITS
from .transformer import Transformer
from .itransformer import iTransformer
from .timesnet import TimesNet
from .etsformer import ETSformer
from .fedformer import FEDformer
@@ -33,6 +34,7 @@
    # neural network imputation methods
    "SAITS",
    "Transformer",
    "iTransformer",
    "ETSformer",
    "FEDformer",
    "FiLM",
24 changes: 24 additions & 0 deletions pypots/imputation/itransformer/__init__.py
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model iTransformer.
Refer to the papers
`Liu, Yong, Tengge Hu, Haoran Zhang, Haixu Wu, Shiyu Wang, Lintao Ma, and Mingsheng Long.
"iTransformer: Inverted transformers are effective for time series forecasting."
ICLR 2024.
<https://openreview.net/pdf?id=JePfAI8fah>`_
Notes
-----
Partial implementation uses code from https://github.com/thuml/iTransformer
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import iTransformer

__all__ = [
    "iTransformer",
]
87 changes: 87 additions & 0 deletions pypots/imputation/itransformer/core.py
@@ -0,0 +1,87 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.saits import SaitsLoss
from ...nn.modules.transformer import TransformerEncoder


class _iTransformer(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_model: int,
        d_ffn: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        dropout: float,
        attn_dropout: float,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_features = n_features
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        self.embedding = nn.Linear(n_steps, d_model)
        self.dropout = nn.Dropout(dropout)
        self.encoder = TransformerEncoder(
            n_layers,
            d_model,
            d_ffn,
            n_heads,
            d_k,
            d_v,
            dropout,
            attn_dropout,
        )
        self.output_projection = nn.Linear(d_model, n_steps)

        # apply the SAITS loss function to the Transformer on the imputation task:
        # total loss = ORT_weight * ORT_loss + MIT_weight * MIT_loss
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # apply the SAITS embedding strategy: concatenate X and the missing mask for input
        input_X = torch.cat([X.permute(0, 2, 1), missing_mask.permute(0, 2, 1)], dim=1)

        # Transformer encoder processing
        input_X = self.embedding(input_X)
        input_X = self.dropout(input_X)
        enc_output, _ = self.encoder(input_X)
        # project the representation from the d_model-dimensional space back to the original data space
        reconstruction = self.output_projection(enc_output)
        reconstruction = reconstruction.permute(0, 2, 1)[:, :, : self.n_features]

        # replace the observed part with values from X
        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction

        # assemble the results as a dictionary for return
        results = {
            "imputed_data": imputed_data,
        }

        # if in training mode, return results with losses
        if training:
            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
            loss, ORT_loss, MIT_loss = self.saits_loss_func(
                reconstruction, X_ori, missing_mask, indicating_mask
            )
            results["ORT_loss"] = ORT_loss
            results["MIT_loss"] = MIT_loss
            # `loss` is always the item for backward propagation to update the model
            results["loss"] = loss

        return results
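The inversion that gives iTransformer its name is visible in the embedding above: instead of embedding each time step as a token, each variate's entire series (and its mask row) becomes one token, so attention in the encoder mixes information across variates rather than across time. A toy shape walkthrough of that embedding, as a sketch with made-up sizes rather than library code:

# Toy shape walkthrough of the inverted embedding (sketch, not library code).
import torch
import torch.nn as nn

batch, n_steps, n_features, d_model = 2, 4, 3, 8
X = torch.randn(batch, n_steps, n_features)
mask = torch.ones(batch, n_steps, n_features)

# concatenate along the variate axis: each row is one whole series
inp = torch.cat([X.permute(0, 2, 1), mask.permute(0, 2, 1)], dim=1)
print(inp.shape)  # torch.Size([2, 6, 4]) -> 2 * n_features variate tokens

emb = nn.Linear(n_steps, d_model)(inp)  # embed each whole series as a single token
print(emb.shape)  # torch.Size([2, 6, 8])

# After the encoder, nn.Linear(d_model, n_steps) maps each token back to a
# series, and slicing [:, :, :n_features] drops the mask tokens' outputs.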
22 changes: 22 additions & 0 deletions pypots/imputation/itransformer/data.py
@@ -0,0 +1,22 @@
"""
Dataset class for self-attention models trained with MIT (masked imputation task) task.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForiTransformer(DatasetForSAITS):
    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_y: bool,
        file_type: str = "hdf5",
        rate: float = 0.2,
    ):
        super().__init__(data, return_X_ori, return_y, file_type, rate)
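DatasetForSAITS implements the masked imputation task: it artificially masks out a fraction (rate, here 0.2) of the observed values so the MIT loss in core.py can score the model on reconstructing them. A conceptual sketch of that masking, not the actual DatasetForSAITS implementation:

# Conceptual MCAR-style masking sketch (the real logic lives in DatasetForSAITS).
import numpy as np

def mit_mask(X: np.ndarray, rate: float = 0.2):
    observed = ~np.isnan(X)
    # randomly hold out `rate` of the observed entries
    holdout = observed & (np.random.rand(*X.shape) < rate)
    X_masked = X.copy()
    X_masked[holdout] = np.nan
    # `holdout` plays the role of `indicating_mask` in core.py's loss
    return X_masked, holdout

X_ori = np.array([[1.0, 2.0, np.nan, 4.0]])
X, indicating_mask = mit_mask(X_ori, rate=0.5)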