-
-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #371 from WenjieDu/(feat)add_itransformer
Implement iTransformer as an imputation model
- Loading branch information
Showing
8 changed files
with
617 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
The package of the partially-observed time-series imputation model iTransformer. | ||
Refer to the papers | ||
`Liu, Yong, Tengge Hu, Haoran Zhang, Haixu Wu, Shiyu Wang, Lintao Ma, and Mingsheng Long. | ||
"iTransformer: Inverted transformers are effective for time series forecasting." | ||
ICLR 2024. | ||
<https://openreview.net/pdf?id=JePfAI8fah>`_ | ||
Notes | ||
----- | ||
Partial implementation uses code from https://github.com/thuml/iTransformer | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
|
||
from .model import iTransformer | ||
|
||
__all__ = [ | ||
"iTransformer", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
""" | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from ...nn.modules.saits import SaitsLoss | ||
from ...nn.modules.transformer import TransformerEncoder | ||
|
||
|
||
class _iTransformer(nn.Module):
    """Core iTransformer network for imputation, trained with the SAITS loss.

    Follows the "inverted" Transformer layout: each feature channel (and its
    missing-mask channel) is embedded as one token whose length is ``n_steps``,
    so attention runs across variables rather than across time steps.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_model: int,
        d_ffn: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        dropout: float,
        attn_dropout: float,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_features = n_features
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        # Inverted embedding: maps each length-n_steps series to a d_model token.
        self.embedding = nn.Linear(n_steps, d_model)
        self.dropout = nn.Dropout(dropout)
        self.encoder = TransformerEncoder(
            n_layers,
            d_model,
            d_ffn,
            n_heads,
            d_k,
            d_v,
            dropout,
            attn_dropout,
        )
        # Projects each token back from d_model to a length-n_steps series.
        self.output_projection = nn.Linear(d_model, n_steps)

        # apply SAITS loss function to Transformer on the imputation task
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # SAITS embedding strategy: pair X with its missing mask, then invert
        # the layout so every feature/mask channel becomes one token of
        # length n_steps (channel order: all X features, then all mask channels).
        tokens = torch.cat([X, missing_mask], dim=2).permute(0, 2, 1)

        # Encode the variable tokens.
        enc_input = self.dropout(self.embedding(tokens))
        enc_output, _ = self.encoder(enc_input)

        # Project back to the time axis and drop the mask channels appended
        # during embedding, keeping only the first n_features channels.
        reconstruction = self.output_projection(enc_output)
        reconstruction = reconstruction.permute(0, 2, 1)[:, :, : self.n_features]

        # Keep observed values; fill missing positions with the reconstruction.
        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction

        results = {
            "imputed_data": imputed_data,
        }

        if training:
            # In training mode, also compute the SAITS losses.
            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
            loss, ORT_loss, MIT_loss = self.saits_loss_func(
                reconstruction, X_ori, missing_mask, indicating_mask
            )
            results["ORT_loss"] = ORT_loss
            results["MIT_loss"] = MIT_loss
            # `loss` is always the item for backward propagating to update the model
            results["loss"] = loss

        return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Dataset class for self-attention models trained with MIT (masked imputation task) task. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
from typing import Union | ||
|
||
from ..saits.data import DatasetForSAITS | ||
|
||
|
||
class DatasetForiTransformer(DatasetForSAITS):
    """Dataset class for iTransformer.

    A thin alias over ``DatasetForSAITS`` — iTransformer is trained with the
    same MIT (masked imputation task) data preparation. ``rate`` is forwarded
    to the parent; presumably it is the fraction of observed values artificially
    masked for MIT — confirm against ``DatasetForSAITS``.
    """

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_y: bool,
        file_type: str = "hdf5",
        rate: float = 0.2,
    ):
        # Pass everything through unchanged; no iTransformer-specific behavior.
        super().__init__(data, return_X_ori, return_y, file_type, rate)
Oops, something went wrong.