Skip to content

Commit

Permalink
Merge pull request #522 from WenjieDu/(feat)add_fits
Browse files Browse the repository at this point in the history
Add FITS
  • Loading branch information
WenjieDu authored Sep 26, 2024
2 parents e2cd6cc + 145069a commit 3c01e29
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 0 deletions.
24 changes: 24 additions & 0 deletions pypots/imputation/fits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package including the modules of FITS.
Refer to the paper
`Zhijian Xu, Ailing Zeng, and Qiang Xu.
FITS: Modeling Time Series with 10k parameters.
In The Twelfth International Conference on Learning Representations, 2024.
<https://openreview.net/pdf?id=bWcnvZ3qMb>`_
Notes
-----
This implementation is inspired by the official one https://github.com/VEWOXIC/FITS
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import FITS

__all__ = [
"FITS",
]
86 changes: 86 additions & 0 deletions pypots/imputation/fits/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
The core wrapper assembles the submodules of FITS imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.modules.fits import BackboneFITS
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding


class _FITS(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
cut_freq: int,
individual: bool,
ORT_weight: float = 1,
MIT_weight: float = 1,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.n_steps = n_steps
self.apply_nonstationary_norm = apply_nonstationary_norm

self.saits_embedding = SaitsEmbedding(
n_features * 2,
n_features,
with_pos=False,
)
self.backbone = BackboneFITS(
n_steps,
n_features,
0, # n_pred_steps is not used in the imputation task
cut_freq,
individual,
)

# for the imputation task, the output dim is the same as input dim
self.output_projection = nn.Linear(n_features, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

# WDU: the original FITS paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
enc_out = self.saits_embedding(X, missing_mask)

# FITS encoder processing
enc_out = self.backbone(enc_out)
if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
enc_out = nonstationary_denorm(enc_out, means, stdev)

# project back the original data space
reconstruction = self.output_projection(enc_out)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/fits/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for FITS.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForFITS(DatasetForSAITS):
"""Actually FITS uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit 3c01e29

Please sign in to comment.