WenjieDu/PyPOTS

View on GitHub
pypots/imputation/timemixer/core.py

Summary

Maintainability
A
1 hr
Test Coverage
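"""
The core wrapper that assembles the TimeMixer backbone into an imputation model
and implements its forward pass and training loss.
"""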

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import (
    nonstationary_norm,
    nonstationary_denorm,
)
from ...nn.modules.timemixer import BackboneTimeMixer
from ...utils.metrics import calc_mse


class _TimeMixer(nn.Module):
    """Core TimeMixer imputation model built on ``BackboneTimeMixer``.

    All hyper-parameters except ``apply_nonstationary_norm`` are forwarded
    unchanged to ``BackboneTimeMixer``. The backbone is configured for the
    "imputation" task, with ``n_pred_features`` set to ``n_features``,
    average down-sampling, and no future temporal features.

    Parameters
    ----------
    n_layers, n_steps, n_features, d_model, d_ffn, dropout, top_k,
    channel_independence, decomp_method, moving_avg, downsampling_layers,
    downsampling_window :
        Forwarded as-is to ``BackboneTimeMixer``.
    apply_nonstationary_norm :
        If True, normalize the input with the Non-stationary Transformer
        normalization before the backbone, and de-normalize the backbone
        output afterwards.
    """

    def __init__(
        self,
        n_layers,
        n_steps,
        n_features,
        d_model,
        d_ffn,
        dropout,
        top_k,
        channel_independence,
        decomp_method,
        moving_avg,
        downsampling_layers,
        downsampling_window,
        apply_nonstationary_norm: bool = False,
    ):
        super().__init__()

        self.apply_nonstationary_norm = apply_nonstationary_norm

        self.model = BackboneTimeMixer(
            task_name="imputation",
            n_steps=n_steps,
            n_features=n_features,
            n_pred_steps=None,
            n_pred_features=n_features,
            n_layers=n_layers,
            d_model=d_model,
            d_ffn=d_ffn,
            dropout=dropout,
            channel_independence=channel_independence,
            decomp_method=decomp_method,
            top_k=top_k,
            moving_avg=moving_avg,
            downsampling_layers=downsampling_layers,
            downsampling_window=downsampling_window,
            downsampling_method="avg",
            use_future_temporal_feature=False,
        )

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Run imputation and, when training, compute the loss.

        Parameters
        ----------
        inputs :
            Must contain ``"X"`` and ``"missing_mask"``. When ``training`` is
            True it must also contain ``"X_ori"`` and ``"indicating_mask"``.
            NOTE(review): tensors are presumably (batch, n_steps, n_features)
            with missing entries already filled; confirm against the caller.
        training :
            Whether to compute and return the training loss.

        Returns
        -------
        dict
            ``"imputed_data"``: the observed values of the input, with
            missing positions replaced by the model reconstruction.
            ``"loss"`` (only when ``training``): MSE between the
            reconstruction and ``inputs["X_ori"]`` under
            ``inputs["indicating_mask"]``.
        """
        X, missing_mask = inputs["X"], inputs["missing_mask"]
        # Keep the caller-scale input. The observed entries of
        # ``imputed_data`` must be on the same scale as ``dec_out``, which is
        # de-normalized below. The normalized ``X`` must not be used there.
        X_input = X

        if self.apply_nonstationary_norm:
            # Normalization from Non-stationary Transformer
            X, means, stdev = nonstationary_norm(X, missing_mask)

        # TimeMixer processing
        dec_out = self.model.imputation(X, None)

        if self.apply_nonstationary_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = nonstationary_denorm(dec_out, means, stdev)

        # Keep observed values, fill only the missing positions from the model.
        imputed_data = missing_mask * X_input + (1 - missing_mask) * dec_out
        results = {
            "imputed_data": imputed_data,
        }

        if training:
            # `loss` is always the item for backward propagating to update the model
            loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
            results["loss"] = loss

        return results