pypots/nn/modules/brits/layers.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
class FeatureRegression(nn.Module):
    """The module used to capture the correlation between features for imputation in BRITS.

    Each output feature is a linear combination of the *other* input features plus a
    bias. The weight matrix is multiplied element-wise by a zero-diagonal mask, so a
    feature never contributes to its own estimate.

    Attributes
    ----------
    W : tensor
        The weights (parameters) of the module, of shape ``[input_size, input_size]``.

    b : tensor
        The bias of the module, of shape ``[input_size]``.

    m (buffer) : tensor
        The mask matrix. It is a square matrix whose diagonal entries are all zeros and
        whose off-diagonal entries are all ones. It is applied to the weight matrix to
        mask out the estimation contributions from features themselves, which helps
        enhance the imputation performance of the network.

    Parameters
    ----------
    input_size :
        The feature dimension of the input.

    """

    def __init__(self, input_size: int):
        super().__init__()
        # Allocated uninitialized; the values are filled in by ``_reset_parameters``.
        self.W = Parameter(torch.empty(input_size, input_size))
        self.b = Parameter(torch.empty(input_size))
        # Ones everywhere except the diagonal. It is registered as a buffer rather than
        # a Parameter, so it follows the module across devices but is never trained.
        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        self.register_buffer("m", m)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize ``W`` and ``b`` from U(-1/sqrt(input_size), 1/sqrt(input_size))."""
        bound = 1.0 / math.sqrt(self.W.size(0))
        nn.init.uniform_(self.W, -bound, bound)
        # Defensive check, kept in case a subclass disables the bias by setting it to None.
        if self.b is not None:
            nn.init.uniform_(self.b, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward processing of the NN module.

        Parameters
        ----------
        x : tensor,
            The input for processing. Its last dimension must equal ``input_size``
            (required by ``F.linear`` with the square weight ``W``).

        Returns
        -------
        output: tensor,
            The processed result containing imputation from feature regression. It has
            the same shape as ``x``.

        """
        # Masking W zeroes its diagonal. Each feature's estimate therefore excludes
        # the feature itself, and the diagonal weights also receive zero gradient.
        # (The former ``Variable`` wrapper was a no-op since PyTorch 0.4.)
        output = F.linear(x, self.W * self.m, self.b)
        return output