WenjieDu/PyPOTS

View on GitHub
pypots/nn/modules/saits/loss.py

Summary

Maintainability
A
0 mins
Test Coverage
"""
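Loss function for SAITS: the weighted sum of the observed reconstruction task (ORT)
loss and the masked imputation task (MIT) loss.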
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from typing import Callable

import torch.nn as nn

from ....utils.metrics import calc_mae


class SaitsLoss(nn.Module):
    """Weighted two-term training objective for SAITS.

    The objective combines two terms that are computed with the same loss
    function but with different masks:

    * ORT (observed reconstruction task), evaluated with ``missing_mask``;
    * MIT (masked imputation task), evaluated with ``indicating_mask``.

    Each term is scaled by its own weight before the two are added together.

    Parameters
    ----------
    ORT_weight :
        Multiplier applied to the ORT term.

    MIT_weight :
        Multiplier applied to the MIT term.

    loss_calc_func :
        Callable invoked as ``loss_calc_func(reconstruction, X_ori, mask)``.
        Defaults to ``calc_mae``.
        NOTE(review): presumably a masked error metric that only scores
        positions selected by ``mask`` -- confirm against its definition.
    """

    def __init__(
        self,
        ORT_weight,
        MIT_weight,
        loss_calc_func: Callable = calc_mae,
    ):
        super().__init__()
        self.ORT_weight, self.MIT_weight = ORT_weight, MIT_weight
        self.loss_calc_func = loss_calc_func

    def forward(self, reconstruction, X_ori, missing_mask, indicating_mask):
        """Compute the combined loss together with its two weighted parts.

        Parameters
        ----------
        reconstruction :
            Model output, passed to the loss function as its first argument.

        X_ori :
            Reference values, passed to the loss function as its second
            argument.

        missing_mask :
            Mask handed to the loss function for the ORT term.

        indicating_mask :
            Mask handed to the loss function for the MIT term.

        Returns
        -------
        tuple
            ``(loss, ORT_loss, MIT_loss)`` where ``ORT_loss`` and ``MIT_loss``
            are already multiplied by their weights and ``loss`` is their sum,
            i.e. the value to back-propagate.

        Notes
        -----
        All four arguments are presumably tensors of matching shape; verify
        against the caller and the chosen ``loss_calc_func``.
        """
        criterion = self.loss_calc_func

        # ORT term: evaluated with the missing mask, then weighted.
        weighted_ort = self.ORT_weight * criterion(reconstruction, X_ori, missing_mask)
        # MIT term: evaluated with the indicating mask, then weighted.
        weighted_mit = self.MIT_weight * criterion(reconstruction, X_ori, indicating_mask)

        # The first element is the total used for the parameter update.
        return weighted_ort + weighted_mit, weighted_ort, weighted_mit