From 177e4ab18a4e04c30a02e4dab4f7e55f432a83ca Mon Sep 17 00:00:00 2001
From: tongnie <1854361@tongji.edu.cn>
Date: Fri, 28 Jun 2024 21:00:07 +0800
Subject: [PATCH] Updated Imputeformer

---
 pypots/imputation/imputeformer/__init__.py  |  20 ++
 pypots/imputation/imputeformer/core.py      | 126 ++++++++
 pypots/imputation/imputeformer/data.py      |  22 ++
 pypots/imputation/imputeformer/model.py     | 326 ++++++++++++++++++++
 pypots/nn/modules/imputeformer/__init__.py  |  27 ++
 pypots/nn/modules/imputeformer/attention.py | 202 ++++++++++++
 pypots/nn/modules/imputeformer/mlp.py       |  59 ++++
 7 files changed, 782 insertions(+)
 create mode 100644 pypots/imputation/imputeformer/__init__.py
 create mode 100644 pypots/imputation/imputeformer/core.py
 create mode 100644 pypots/imputation/imputeformer/data.py
 create mode 100644 pypots/imputation/imputeformer/model.py
 create mode 100644 pypots/nn/modules/imputeformer/__init__.py
 create mode 100644 pypots/nn/modules/imputeformer/attention.py
 create mode 100644 pypots/nn/modules/imputeformer/mlp.py

diff --git a/pypots/imputation/imputeformer/__init__.py b/pypots/imputation/imputeformer/__init__.py
new file mode 100644
index 00000000..91eb89fa
--- /dev/null
+++ b/pypots/imputation/imputeformer/__init__.py
@@ -0,0 +1,20 @@
+"""
+The package of the partially-observed time-series imputation model ImputeFormer.
+
+Refer to the paper
+`Tong Nie, Guoyang Qin, Wei Ma, Yuewen Mei, Jian Sun.
+"ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation".
+KDD 2024.
+`_
+
+"""

+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import Imputeformer
+
+__all__ = [
+    "Imputeformer",
+]
diff --git a/pypots/imputation/imputeformer/core.py b/pypots/imputation/imputeformer/core.py
new file mode 100644
index 00000000..764cd68a
--- /dev/null
+++ b/pypots/imputation/imputeformer/core.py
@@ -0,0 +1,126 @@
+"""
+The core wrapper assembles the submodules of the ImputeFormer imputation model
+and takes over the forward progress of the algorithm.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+from ...nn.modules.imputeformer import EmbeddedAttentionLayer, ProjectedAttentionLayer, MLP
+from ...nn.modules.saits import SaitsLoss
+
+
+class _Imputeformer(nn.Module):
+    """Spatiotemporal Imputation Transformer induced by low-rank factorization, KDD'24.
+
+    Note:
+    This is a simplified implementation under the SAITS framework (ORT + MIT losses).
+    The timestamp encoding is also removed for ease of implementation.
+ """ + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + d_input_embed: int, + d_learnable_embed: int, + d_proj: int, + d_ffn: int, + num_temporal_heads: int, + dropout: float = 0., + input_dim: int = 1, + output_dim: int = 1, + ORT_weight: float = 1, + MIT_weight: float = 1, + ): + super().__init__() + + self.num_nodes = n_features + self.in_steps = n_steps + self.out_steps = n_steps + self.input_dim = input_dim + self.output_dim = output_dim + self.input_embedding_dim = d_input_embed + self.learnable_embedding_dim = d_learnable_embed + model_dim = d_input_embed + d_learnable_embed + self.model_dim = model_dim + + self.num_temporal_heads = num_temporal_heads + self.num_layers = n_layers + self.input_proj = nn.Linear(input_dim, self.input_embedding_dim) + self.dim_proj = d_proj + + self.learnable_embedding = nn.init.xavier_uniform_( + nn.Parameter(torch.empty(self.in_steps, self.num_nodes, self.learnable_embedding_dim))) + + self.readout = MLP(self.model_dim, self.model_dim, output_dim, n_layers=2) + + self.attn_layers_t = nn.ModuleList( + [ProjectedAttentionLayer(self.num_nodes, self.dim_proj, self.model_dim, num_temporal_heads, + self.model_dim, dropout) + for _ in range(self.num_layers)]) + + self.attn_layers_s = nn.ModuleList( + [EmbeddedAttentionLayer(self.model_dim, self.learnable_embedding_dim, d_ffn) + for _ in range(self.num_layers)]) + + # apply SAITS loss function to Transformer on the imputation task + self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) + + + def forward(self, inputs: dict, training: bool = True) -> dict: + x, missing_mask = inputs["X"], inputs["missing_mask"] + + # x: (batch_size, in_steps, num_nodes) + # Note that Imputeformer is designed for Spatial-Temporal data that has the format [B, S, N, C], + # where N is the number of nodes and C is an additional feature dimension, + # We simply add an extra axis here for implementation. 
+        x = x.unsqueeze(-1)  # [b s n c]
+        missing_mask = missing_mask.unsqueeze(-1)  # [b s n c]
+        batch_size = x.shape[0]
+        # whiten (zero out) the missing values
+        x = x * missing_mask
+        x = self.input_proj(x)  # (batch_size, in_steps, num_nodes, input_embedding_dim)
+
+        # learnable node embedding
+        node_emb = self.learnable_embedding.expand(batch_size, *self.learnable_embedding.shape)
+        x = torch.cat([x, node_emb], dim=-1)  # (batch_size, in_steps, num_nodes, model_dim)
+
+        # spatial and temporal processing with customized attention layers
+        x = x.permute(0, 2, 1, 3)  # [b n s c]
+        for att_t, att_s in zip(self.attn_layers_t, self.attn_layers_s):
+            x = att_t(x)
+            x = att_s(x, self.learnable_embedding, dim=1)
+
+        # readout
+        x = x.permute(0, 2, 1, 3)  # [b s n c]
+        reconstruction = self.readout(x)
+        reconstruction = reconstruction.squeeze(-1)  # [b s n]
+        missing_mask = missing_mask.squeeze(-1)  # [b s n]
+
+        # Below is the SAITS processing pipeline:
+        # replace the observed part with values from X
+        imputed_data = missing_mask * inputs["X"] + (1 - missing_mask) * reconstruction
+
+        # assemble the results as a dictionary for return
+        results = {
+            "imputed_data": imputed_data,
+        }
+
+        # if in training mode, return results with losses
+        if training:
+            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
+            loss, ORT_loss, MIT_loss = self.saits_loss_func(
+                reconstruction, X_ori, missing_mask, indicating_mask
+            )
+            results["ORT_loss"] = ORT_loss
+            results["MIT_loss"] = MIT_loss
+            # `loss` is always the item for backward propagation to update the model
+            results["loss"] = loss
+
+        return results
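For readers unfamiliar with the SAITS-style pipeline used in `forward` above, here is a minimal standalone sketch (toy tensors, not part of the patch) of how observed entries are preserved while only missing ones are filled:

import torch

X = torch.randn(2, 3, 4)                               # toy batch: 2 samples, 3 steps, 4 features
missing_mask = torch.randint(0, 2, (2, 3, 4)).float()  # 1 = observed, 0 = missing
reconstruction = torch.randn(2, 3, 4)                  # stand-in for the model output

# keep observed values from X; take reconstructed values only where data is missing
imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
assert torch.equal(imputed_data[missing_mask.bool()], X[missing_mask.bool()])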
diff --git a/pypots/imputation/imputeformer/data.py b/pypots/imputation/imputeformer/data.py
new file mode 100644
index 00000000..aa406c49
--- /dev/null
+++ b/pypots/imputation/imputeformer/data.py
@@ -0,0 +1,22 @@
+"""
+Dataset class for the imputation model ImputeFormer.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ..saits.data import DatasetForSAITS
+
+
+class DatasetForImputeformer(DatasetForSAITS):
+    def __init__(
+        self,
+        data: Union[dict, str],
+        return_X_ori: bool,
+        return_y: bool,
+        file_type: str = "hdf5",
+        rate: float = 0.2,
+    ):
+        super().__init__(data, return_X_ori, return_y, file_type, rate)
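DatasetForImputeformer simply reuses the SAITS dataset, so it consumes the standard PyPOTS input format. A minimal sketch of a valid training set (toy shapes chosen for illustration):

import numpy as np

n_samples, n_steps, n_features = 16, 24, 7
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan   # NaN marks missing values
train_set = {"X": X}   # a dict like this (or a path to an h5 file) is what fit() accepts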
+ """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + d_input_embed: int, + d_learnable_embed: int, + d_proj: int, + d_ffn: int, + num_temporal_heads: int, + dropout: float = 0., + input_dim: int = 1, + output_dim: int = 1, + ORT_weight: float = 1, + MIT_weight: float = 1, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_layers = n_layers + self.d_input_embed = d_input_embed + self.d_learnable_embed = d_learnable_embed + self.d_proj = d_proj + self.d_ffn = d_ffn + self.num_temporal_heads = num_temporal_heads + self.dropout = dropout + self.input_dim = input_dim + self.output_dim = output_dim + self.ORT_weight = ORT_weight + self.MIT_weight = MIT_weight + + # set up the model + self.model = _Imputeformer( + self.n_steps, + self.n_features, + self.n_layers, + self.d_input_embed, + self.d_learnable_embed, + self.d_proj, + self.d_ffn, + self.num_temporal_heads, + self.dropout, + self.input_dim, + self.output_dim, + self.ORT_weight, + self.MIT_weight, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForImputeformer( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForImputeformer( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. 
diff --git a/pypots/nn/modules/imputeformer/__init__.py b/pypots/nn/modules/imputeformer/__init__.py
new file mode 100644
index 00000000..33db69c4
--- /dev/null
+++ b/pypots/nn/modules/imputeformer/__init__.py
@@ -0,0 +1,27 @@
+"""
+The package including the modules of ImputeFormer.
+
+Refer to the paper
+`Tong Nie, Guoyang Qin, Wei Ma, Yuewen Mei, Jian Sun.
+ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation.
+KDD, 2024.
+`_
+
+Notes
+-----
+This implementation follows the SAITS training framework https://github.com/WenjieDu/SAITS
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .attention import EmbeddedAttentionLayer, ProjectedAttentionLayer
+from .mlp import MLP
+
+__all__ = [
+    "EmbeddedAttentionLayer",
+    "ProjectedAttentionLayer",
+    "MLP",
+]
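The attention.py file added below implements the paper's projected ("low-rank") temporal attention. The core idea, sketched standalone here with illustrative sizes: attention is routed through a small set of c learned proxy vectors, so the cost grows with s * c instead of s * s:

import torch

s, c, d = 24, 4, 32        # sequence length, projected dim, model dim
x = torch.randn(s, d)      # token features for one node
proxy = torch.randn(c, d)  # learned projector (an nn.Parameter in the real layer)

# proxies attend to the sequence, then the sequence attends back through the proxies
msg = torch.softmax(proxy @ x.T / d**0.5, dim=-1) @ x    # [c, d]
out = torch.softmax(x @ proxy.T / d**0.5, dim=-1) @ msg  # [s, d]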
+ + """ + + def __init__(self, model_dim, num_heads=8, mask=False): + super().__init__() + + self.model_dim = model_dim + self.num_heads = num_heads + self.mask = mask + self.head_dim = model_dim // num_heads + + self.FC_Q = nn.Linear(model_dim, model_dim) + self.FC_K = nn.Linear(model_dim, model_dim) + self.FC_V = nn.Linear(model_dim, model_dim) + + self.out_proj = nn.Linear(model_dim, model_dim) + + def forward(self, query, key, value): + # Q (batch_size, ..., tgt_length, model_dim) + # K, V (batch_size, ..., src_length, model_dim) + batch_size = query.shape[0] + tgt_length = query.shape[-2] + src_length = key.shape[-2] + + query = self.FC_Q(query) + key = self.FC_K(key) + value = self.FC_V(value) + + # Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim) + query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0) + key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0) + value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0) + + key = key.transpose( + -1, -2 + ) # (num_heads * batch_size, ..., head_dim, src_length) + + attn_score = ( + query @ key + ) / self.head_dim**0.5 # (num_heads * batch_size, ..., tgt_length, src_length) + + if self.mask: + mask = torch.ones( + tgt_length, src_length, dtype=torch.bool, device=query.device + ).tril() # lower triangular part of the matrix + attn_score.masked_fill_(~mask, -torch.inf) # fill in-place + + attn_score = torch.softmax(attn_score, dim=-1) + out = attn_score @ value # (num_heads * batch_size, ..., tgt_length, head_dim) + out = torch.cat( + torch.split(out, batch_size, dim=0), dim=-1 + ) # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim) + + out = self.out_proj(out) + + return out + + +class ProjectedAttentionLayer(nn.Module): + """ + Temporal projected attention layer. + A low-rank factorization is achieved in the temporal attention matrix. + """ + def __init__(self, + seq_len, + dim_proj, + d_model, + n_heads, + d_ff=None, + dropout=0.1): + super(ProjectedAttentionLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.out_attn = AttentionLayer(d_model, n_heads, mask=None) + self.in_attn = AttentionLayer(d_model, n_heads, mask=None) + self.projector = nn.Parameter(torch.randn(seq_len, dim_proj, d_model)) + # self.projector = nn.Parameter(torch.randn(dim_proj, d_model)) + + self.dropout = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.MLP = nn.Sequential(nn.Linear(d_model, d_ff), + nn.GELU(), + nn.Linear(d_ff, d_model)) + self.seq_len = seq_len + + def forward(self, x): + # x: [b s n d] + batch = x.shape[0] + projector = repeat(self.projector, 'seq_len dim_proj d_model -> repeat seq_len dim_proj d_model', repeat=batch) # [b, s, c, d] + # projector = repeat(self.projector, 'dim_proj d_model -> repeat seq_len dim_proj d_model', + # repeat=batch, seq_len=self.seq_len) # [b, s, c, d] + + message_out = self.out_attn(projector, x, x) # [b, s, c, d] <-> [b s n d] -> [b s c d] + message_in = self.in_attn(x, projector, message_out) # [b s n d] <-> [b, s, c, d] -> [b s n d] + message = x + self.dropout(message_in) + message = self.norm1(message) + message = message + self.dropout(self.MLP(message)) + message = self.norm2(message) + + return message + + +class EmbeddedAttention(nn.Module): + """ + Spatial embedded attention layer. + The node embedding serves as the query and key matrices for attentive aggregation on graphs. 
+ """ + def __init__(self, model_dim, node_embedding_dim): + super().__init__() + + self.model_dim = model_dim + self.FC_Q_K = nn.Linear(node_embedding_dim, model_dim) + self.FC_V = nn.Linear(model_dim, model_dim) + self.out_proj = nn.Linear(model_dim, model_dim) + + def forward(self, value, emb): + # V (batch_size, ..., seq_length, model_dim) + # emb (..., length, model_dim) + batch_size = value.shape[0] + query = self.FC_Q_K(emb) + key = self.FC_Q_K(emb) + value = self.FC_V(value) + + # Q, K (..., length, model_dim) + # V (batch_size, ..., length, model_dim) + key = key.transpose(-1, -2) # (..., model_dim, src_length) + # attn_score = query @ key # (..., tgt_length, src_length) + # attn_score = torch.softmax(attn_score, dim=-1) + # attn_score = repeat(attn_score, 'n s1 s2 -> b n s1 s2', b=batch_size) + + # re-normalization + query = torch.softmax(query, dim=-1) + key = torch.softmax(key, dim=-1) + query = repeat(query, 'n s1 s2 -> b n s1 s2', b=batch_size) + key = repeat(key, 'n s2 s1 -> b n s2 s1', b=batch_size) + + # out = attn_score @ value # (batch_size, ..., tgt_length, model_dim) + out = key @ value # (batch_size, ..., tgt_length, model_dim) + out = query @ out # (batch_size, ..., tgt_length, model_dim) + + return out + + +class EmbeddedAttentionLayer(nn.Module): + def __init__(self, + model_dim, + node_embedding_dim, + feed_forward_dim=2048, + dropout=0): + super(EmbeddedAttentionLayer, self).__init__() + + self.attn = EmbeddedAttention(model_dim, node_embedding_dim) + self.feed_forward = nn.Sequential( + nn.Linear(model_dim, feed_forward_dim), + nn.ReLU(inplace=True), + nn.Linear(feed_forward_dim, model_dim), + ) + self.ln1 = nn.LayerNorm(model_dim) + self.ln2 = nn.LayerNorm(model_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, x, emb, dim=-2): + x = x.transpose(dim, -2) + # x: (batch_size, ..., length, model_dim) + # emb: (..., length, model_dim) + residual = x + out = self.attn(x, emb) # (batch_size, ..., length, model_dim) + out = self.dropout1(out) + out = self.ln1(residual + out) + + residual = out + out = self.feed_forward(out) # (batch_size, ..., length, model_dim) + out = self.dropout2(out) + out = self.ln2(residual + out) + + out = out.transpose(dim, -2) + return out diff --git a/pypots/nn/modules/imputeformer/mlp.py b/pypots/nn/modules/imputeformer/mlp.py new file mode 100644 index 00000000..6eb6159b --- /dev/null +++ b/pypots/nn/modules/imputeformer/mlp.py @@ -0,0 +1,59 @@ +""" +The implementation of the MLPs for Imputeformer :cite:`nie2024imputeformer` +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn +from einops import repeat + + +class Dense(nn.Module): + r""" + A simple fully-connected layer. + """ + def __init__(self, input_size, output_size, dropout=0., bias=True): + super(Dense, self).__init__() + self.layer = nn.Sequential( + nn.Linear(input_size, output_size, bias=bias), + nn.ReLU(), + nn.Dropout(dropout) if dropout > 0. else nn.Identity() + ) + + def forward(self, x): + return self.layer(x) + + +class MLP(nn.Module): + r""" + Simple Multi-layer Perceptron encoder with optional linear readout. 
+ """ + def __init__(self, + input_size, + hidden_size, + output_size=None, + n_layers=1, + dropout=0.): + super(MLP, self).__init__() + + + layers = [ + Dense(input_size=input_size if i == 0 else hidden_size, + output_size=hidden_size, + dropout=dropout) for i in range(n_layers) + ] + self.mlp = nn.Sequential(*layers) + + if output_size is not None: + self.readout = nn.Linear(hidden_size, output_size) + else: + self.register_parameter('readout', None) + + def forward(self, x, u=None): + """""" + out = self.mlp(x) + if self.readout is not None: + return self.readout(out) + return out