pypots/imputation/itransformer/core.py
"""
The core wrapper assembles the submodules of the iTransformer imputation model
and takes over the forward pass of the algorithm.
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch
import torch.nn as nn
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
from ...nn.modules.transformer import TransformerEncoder
class _iTransformer(nn.Module):
    """Core network of the iTransformer imputation model.

    Wires together a SAITS embedding (built without positional encoding), a
    Transformer encoder, a linear output projection and the SAITS loss, and
    runs them in ``forward``.

    Parameters
    ----------
    n_steps :
        Input size handed to ``SaitsEmbedding`` and output size of the final
        linear projection (presumably the number of time steps per sample).

    n_features :
        How many leading tokens of the projected output are kept as the
        reconstruction (presumably the number of variables per sample).

    n_layers, d_model, n_heads, d_k, d_v, d_ffn, dropout, attn_dropout :
        Forwarded positionally, in this order, to ``TransformerEncoder``.
        ``d_model`` is also the output size of the embedding and the input
        size of the output projection; ``dropout`` is also given to the
        embedding.

    ORT_weight, MIT_weight :
        Forwarded to ``SaitsLoss``.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_model: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        d_ffn: int,
        dropout: float,
        attn_dropout: float,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()

        # plain hyper-parameters kept on the instance
        self.n_layers = n_layers
        self.n_features = n_features
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        # submodules, created in the order: embedding -> encoder -> projection -> loss
        self.saits_embedding = SaitsEmbedding(
            n_steps,
            d_model,
            with_pos=False,
            dropout=dropout,
        )
        encoder_args = (n_layers, d_model, n_heads, d_k, d_v, d_ffn, dropout, attn_dropout)
        self.encoder = TransformerEncoder(*encoder_args)
        self.output_projection = nn.Linear(d_model, n_steps)

        # the SAITS loss is reused here for the imputation task
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Impute the missing part of ``inputs["X"]``.

        Parameters
        ----------
        inputs :
            Must contain ``"X"`` and ``"missing_mask"``; when ``training`` is
            true it must also contain ``"X_ori"`` and ``"indicating_mask"``.
            The tensors are 3-dimensional (they are permuted with three axes),
            presumably shaped (batch, n_steps, n_features) -- TODO confirm.

        training :
            When true, the SAITS loss terms are computed and added to the result.

        Returns
        -------
        dict
            Always holds ``"imputed_data"``. In training mode it additionally
            holds ``"ORT_loss"``, ``"MIT_loss"`` and ``"loss"``.
        """
        X = inputs["X"]
        missing_mask = inputs["missing_mask"]

        # iTransformer itself takes no missing mask as input, so the network would not know
        # which entries are missing. To give it that information, the last two axes of the
        # data and of the mask are swapped and both are stacked along dim 1, then the stack
        # is projected into the hidden space by the SAITS embedding.
        stacked_tokens = torch.cat(
            [X.permute(0, 2, 1), missing_mask.permute(0, 2, 1)],
            dim=1,
        )
        embedded_tokens = self.saits_embedding(stacked_tokens)

        # the encoder returns a pair; only its first element is used here
        enc_output, _ = self.encoder(embedded_tokens)

        # map the hidden representation back, undo the axis swap, and keep only the
        # first `n_features` tokens (the ones stacked from X rather than from the mask)
        projected = self.output_projection(enc_output)
        reconstruction = projected.permute(0, 2, 1)[:, :, : self.n_features]

        # observed entries come from X, missing entries from the reconstruction
        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction

        results = {"imputed_data": imputed_data}
        if not training:
            return results

        # training mode: attach the loss terms
        loss, ORT_loss, MIT_loss = self.saits_loss_func(
            reconstruction,
            inputs["X_ori"],
            missing_mask,
            inputs["indicating_mask"],
        )
        # `loss` is the item used for back-propagation to update the model
        results.update(ORT_loss=ORT_loss, MIT_loss=MIT_loss, loss=loss)
        return results