Merge pull request #371 from WenjieDu/(feat)add_itransformer
Implement iTransformer as an imputation model
WenjieDu authored Apr 29, 2024
2 parents ae0ba76 + 115b10b commit a939f2a
Showing 8 changed files with 617 additions and 1 deletion.
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
@@ -19,6 +19,15 @@ pypots.imputation.transformer
:show-inheritance:
:inherited-members:

pypots.imputation.itransformer
------------------------------------

.. automodule:: pypots.imputation.itransformer
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.imputation.frets
------------------------------

10 changes: 9 additions & 1 deletion docs/references.bib
@@ -550,7 +550,7 @@ @inproceedings{zhou2022film
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {12677--12690},
publisher = {Curran Associates, Inc.},
- title = {FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting},
+ title = {{FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting}},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/524ef58c2bd075775861234266e5e020-Paper-Conference.pdf},
volume = {35},
year = {2022}
@@ -566,4 +566,12 @@ @inproceedings{yi2023frets
url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf},
volume = {36},
year = {2023}
}

@inproceedings{liu2024itransformer,
title={{iTransformer: Inverted Transformers Are Effective for Time Series Forecasting}},
author={Yong Liu and Tengge Hu and Haoran Zhang and Haixu Wu and Shiyu Wang and Lintao Ma and Mingsheng Long},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=JePfAI8fah}
}
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -12,6 +12,7 @@
from .mrnn import MRNN
from .saits import SAITS
from .transformer import Transformer
from .itransformer import iTransformer
from .timesnet import TimesNet
from .etsformer import ETSformer
from .fedformer import FEDformer
@@ -33,6 +34,7 @@
# neural network imputation methods
"SAITS",
"Transformer",
"iTransformer",
"ETSformer",
"FEDformer",
"FiLM",
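
With the export above in place, the new model becomes importable from the package's top level. A minimal usage sketch follows; the hyperparameter names mirror the internal `_iTransformer` signature shown in core.py below, while the public `iTransformer` class itself lives in model.py, which is not part of the excerpt shown here, so treat the exact constructor as an assumption with illustrative values only.

from pypots.imputation import iTransformer

# Hypothetical instantiation; argument names follow the internal
# _iTransformer signature in core.py, values are illustrative only.
model = iTransformer(
    n_steps=24,
    n_features=5,
    n_layers=2,
    d_model=64,
    d_ffn=128,
    n_heads=4,
    d_k=16,
    d_v=16,
    dropout=0.1,
    attn_dropout=0.1,
)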
24 changes: 24 additions & 0 deletions pypots/imputation/itransformer/__init__.py
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model iTransformer.
Refer to the papers
`Liu, Yong, Tengge Hu, Haoran Zhang, Haixu Wu, Shiyu Wang, Lintao Ma, and Mingsheng Long.
"iTransformer: Inverted transformers are effective for time series forecasting."
ICLR 2024.
<https://openreview.net/pdf?id=JePfAI8fah>`_
Notes
-----
Partial implementation uses code from https://github.com/thuml/iTransformer
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import iTransformer

__all__ = [
"iTransformer",
]
87 changes: 87 additions & 0 deletions pypots/imputation/itransformer/core.py
@@ -0,0 +1,87 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.saits import SaitsLoss
from ...nn.modules.transformer import TransformerEncoder


class _iTransformer(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
n_layers: int,
d_model: int,
d_ffn: int,
n_heads: int,
d_k: int,
d_v: int,
dropout: float,
attn_dropout: float,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()
self.n_layers = n_layers
self.n_features = n_features
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

self.embedding = nn.Linear(n_steps, d_model)
self.dropout = nn.Dropout(dropout)
self.encoder = TransformerEncoder(
n_layers,
d_model,
d_ffn,
n_heads,
d_k,
d_v,
dropout,
attn_dropout,
)
self.output_projection = nn.Linear(d_model, n_steps)

# apply the SAITS loss function to iTransformer on the imputation task: a weighted sum of
# the ORT term (reconstruction error on observed values) and the MIT term (error on
# artificially-masked values)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# apply the SAITS embedding strategy with iTransformer's inverted layout: permute to
# (batch, n_features, n_steps) so each variate's whole series becomes one token, then
# concatenate the missing mask as extra variate tokens
input_X = torch.cat([X.permute(0, 2, 1), missing_mask.permute(0, 2, 1)], dim=1)

# Transformer encoder processing
input_X = self.embedding(input_X)
input_X = self.dropout(input_X)
enc_output, _ = self.encoder(input_X)
# project the representation from the d_model-dimensional space back to the original
# data space, then drop the tokens that correspond to the missing mask
reconstruction = self.output_projection(enc_output)
reconstruction = reconstruction.permute(0, 2, 1)[:, :, : self.n_features]

# replace the observed part with values from X
imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction

# assemble the results as a dictionary for return
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
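
A minimal smoke test for the forward pass above, assuming PyPOTS and its dependencies are installed; the tensor shapes follow the permutation and slicing logic in `forward`, and with `training=False` only `X` and `missing_mask` are required (a mask value of 1 marks an observed entry, 0 a missing one):

import torch

from pypots.imputation.itransformer.core import _iTransformer

batch_size, n_steps, n_features = 8, 24, 5
model = _iTransformer(
    n_steps=n_steps, n_features=n_features, n_layers=2, d_model=64,
    d_ffn=128, n_heads=4, d_k=16, d_v=16, dropout=0.1, attn_dropout=0.1,
)
inputs = {
    "X": torch.randn(batch_size, n_steps, n_features),
    "missing_mask": torch.ones(batch_size, n_steps, n_features),
}
results = model(inputs, training=False)
print(results["imputed_data"].shape)  # torch.Size([8, 24, 5])

With `training=True`, the inputs dict must additionally carry `X_ori` and `indicating_mask`, and the returned dict gains the `loss`, `ORT_loss`, and `MIT_loss` entries.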
22 changes: 22 additions & 0 deletions pypots/imputation/itransformer/data.py
@@ -0,0 +1,22 @@
"""
Dataset class for self-attention models trained with MIT (masked imputation task) task.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForiTransformer(DatasetForSAITS):
def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
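
Since `DatasetForiTransformer` simply delegates to `DatasetForSAITS`, constructing it follows the parent class's convention. A hedged sketch, assuming the usual PyPOTS dict layout with an "X" array of shape (n_samples, n_steps, n_features) where missing values are NaN; `rate` is the fraction of observed values artificially masked for the MIT objective:

import numpy as np

from pypots.imputation.itransformer.data import DatasetForiTransformer

# Toy data with simulated missingness (assumed layout: "X" holds the series).
X = np.random.randn(100, 24, 5)
X[X < -1.5] = np.nan

train_set = DatasetForiTransformer(
    data={"X": X},
    return_X_ori=False,
    return_y=False,
    file_type="hdf5",  # only consulted when `data` is a file path
    rate=0.2,          # 20% of observed values are masked for MIT
)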