pypots/nn/modules/mrnn/layers.py
"""
The layers of the M-RNN model.
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class MrnnFcnRegression(nn.Module):
    """M-RNN fully connected regression layer.

    Estimates each feature from the *other* features. The weight matrices
    ``U`` (applied to ``x``) and ``V1`` (applied to ``target``) are multiplied
    element-wise by ``m``, an all-ones matrix with a zero diagonal, before use.
    As a result, feature ``i`` of ``x`` and of ``target`` never contributes to
    hidden unit ``i``. ``V2`` (applied to the missing mask) is used unmasked.

    Parameters
    ----------
    feature_num :
        Number of features, i.e. the size of the last dimension of every input
        passed to :meth:`forward`.
    """

    def __init__(self, feature_num: int):
        super().__init__()
        # Uninitialised storage; the values are drawn in reset_parameters().
        self.U = Parameter(torch.empty(feature_num, feature_num))
        self.V1 = Parameter(torch.empty(feature_num, feature_num))
        self.V2 = Parameter(torch.empty(feature_num, feature_num))
        self.beta = Parameter(torch.empty(feature_num))  # bias beta
        self.final_linear = nn.Linear(feature_num, feature_num)

        # Zero-diagonal mask. It is registered as a buffer so that it moves
        # with the module across devices and is stored in the state dict.
        m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num)
        self.register_buffer("m", m)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw U, V1, V2 and beta uniformly from [-1/sqrt(n), 1/sqrt(n)].

        Here ``n`` is ``feature_num``. ``final_linear`` is not re-initialised;
        it keeps whatever ``nn.Linear`` drew when it was constructed.
        """
        stdv = 1.0 / math.sqrt(self.U.size(0))
        # Same draw order as before (U, V1, V2, beta), so seeded runs are unchanged.
        for param in (self.U, self.V1, self.V2, self.beta):
            nn.init.uniform_(param, -stdv, stdv)

    def forward(
        self,
        x: torch.Tensor,
        missing_mask: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Regress every feature from the remaining features.

        Parameters
        ----------
        x :
            Input values; the last dimension must equal ``feature_num``.
        missing_mask :
            Missingness mask with the same last dimension as ``x``. It goes
            through ``V2`` without the zero-diagonal mask.
        target :
            Second value tensor with the same last dimension. Presumably this
            is the current (e.g. RNN) estimate of ``x`` — confirm against the
            caller.

        Returns
        -------
        torch.Tensor
            Estimates with the broadcast shape of the inputs. Values lie in
            (0, 1) because the final layer is passed through a sigmoid.
        """
        h_t = torch.sigmoid(
            F.linear(x, self.U * self.m)
            + F.linear(target, self.V1 * self.m)
            + F.linear(missing_mask, self.V2)
            + self.beta
        )
        x_hat_t = torch.sigmoid(self.final_linear(h_t))
        return x_hat_t