# pypots/nn/modules/grud/layers.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
class TemporalDecay(nn.Module):
    """The module used to generate the temporal decay factor gamma in the GRU-D model.

    Computes ``gamma = exp(-max(0, W @ delta + b))``, so every entry of the
    output lies in (0, 1] and decays towards 0 as the time gap grows.
    Please refer to the original paper :cite:`che2018GRUD` for more details.

    Attributes
    ----------
    W: tensor,
        The weights (parameters) of the module.

    b: tensor,
        The bias of the module.

    Parameters
    ----------
    input_size : int,
        the feature dimension of the input

    output_size : int,
        the feature dimension of the output

    diag : bool,
        whether to product the weight with an identity matrix before forward processing

    References
    ----------
    .. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
        "Recurrent neural networks for multivariate time series with missing values."
        Scientific reports 8, no. 1 (2018): 6085.
        <https://www.nature.com/articles/s41598-018-24271-9.pdf>`_

    """

    def __init__(self, input_size: int, output_size: int, diag: bool = False):
        super().__init__()
        self.diag = diag
        # torch.empty allocates uninitialized storage (modern replacement for
        # the legacy torch.Tensor(sizes) constructor); values are filled in by
        # _reset_parameters() below.
        self.W = Parameter(torch.empty(output_size, input_size))
        self.b = Parameter(torch.empty(output_size))

        if self.diag:
            # Diagonal mode ties each output feature to its own input feature,
            # so the weight matrix must be square.
            assert input_size == output_size, "input_size must equal output_size when diag=True"
            # Identity mask used to zero out off-diagonal weights in forward();
            # registered as a buffer so it moves with the module (.to/.cuda)
            # and is saved in state_dict without being a trainable parameter.
            m = torch.eye(input_size, input_size)
            self.register_buffer("m", m)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize W and b uniformly in [-1/sqrt(output_size), 1/sqrt(output_size)]."""
        std_dev = 1.0 / math.sqrt(self.W.size(0))
        # In-place init under no_grad is the recommended pattern; mutating
        # `.data` directly bypasses autograd's version tracking.
        with torch.no_grad():
            self.W.uniform_(-std_dev, std_dev)
            if self.b is not None:
                self.b.uniform_(-std_dev, std_dev)

    def forward(self, delta: torch.Tensor) -> torch.Tensor:
        """Forward processing of this NN module.

        Parameters
        ----------
        delta : tensor, shape [n_samples, n_steps, n_features]
            The time gaps.

        Returns
        -------
        gamma : tensor, of the same shape with parameter `delta`, values in (0,1]
            The temporal decay factor.
        """
        if self.diag:
            # Mask W with the identity buffer so only diagonal weights
            # contribute. NOTE: the former torch.autograd.Variable wrapper was
            # removed — it has been a deprecated no-op since PyTorch 0.4.
            weight = self.W * self.m
        else:
            weight = self.W
        gamma = F.relu(F.linear(delta, weight, self.b))
        gamma = torch.exp(-gamma)
        return gamma