# pypots/nn/modules/grud/backbone.py
"""
The backbone (core recurrent architecture) of GRU-D for multivariate time series with missing values.
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
from typing import Tuple
import torch
import torch.nn as nn
from .layers import TemporalDecay
class BackboneGRUD(nn.Module):
    """The GRU-D backbone: a GRU cell augmented with temporal-decay mechanisms
    that handle missing values via decayed last observations and empirical means.

    Parameters
    ----------
    n_steps :
        Number of time steps in each input sequence.

    n_features :
        Number of features (variables) per time step.

    rnn_hidden_size :
        Size of the GRU hidden state.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        rnn_hidden_size: int,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_features = n_features
        self.rnn_hidden_size = rnn_hidden_size

        # create models
        # GRU cell input = [imputed values, decayed hidden state, missing mask]
        # -> n_features * 2 + rnn_hidden_size
        self.rnn_cell = nn.GRUCell(self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size)
        # decay factor applied to the hidden state according to the time gaps
        self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False)
        # per-feature (diagonal) decay factor blending last observation with the empirical mean
        self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True)

    def forward(self, X, missing_mask, deltas, empirical_mean, X_filledLOCF) -> Tuple[torch.Tensor, ...]:
        """Forward processing of GRU-D.

        Parameters
        ----------
        X :
            Input tensor of observed values, shape [batch, n_steps, n_features].

        missing_mask :
            Binary mask (1 = observed, 0 = missing), same shape as ``X``.

        deltas :
            Time gaps since the last observation of each feature, same shape as ``X``.

        empirical_mean :
            Feature-wise empirical means used as a fallback for imputation,
            shape [batch, n_features] (broadcastable against one time step).

        X_filledLOCF :
            ``X`` with missing values filled by last-observation-carried-forward,
            same shape as ``X``.

        Returns
        -------
        representation_collector :
            Decayed hidden states collected at every time step,
            shape [batch, n_steps, rnn_hidden_size].

        hidden_state :
            The final hidden state after the last GRU step,
            shape [batch, rnn_hidden_size].
        """
        hidden_state = torch.zeros((X.size(0), self.rnn_hidden_size), device=X.device)

        representation_collector = []
        for t in range(self.n_steps):
            # for data, [batch, time, features]
            x = X[:, t, :]  # values
            m = missing_mask[:, t, :]  # mask
            d = deltas[:, t, :]  # delta, time gap
            x_filledLOCF = X_filledLOCF[:, t, :]

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            # decay the hidden state towards zero for long gaps, then record it
            hidden_state = hidden_state * gamma_h
            representation_collector.append(hidden_state)

            # impute: blend the last observation with the empirical mean by gamma_x,
            # then keep true values where observed
            x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean
            x_replaced = m * x + (1 - m) * x_h

            data_input = torch.cat([x_replaced, hidden_state, m], dim=1)
            hidden_state = self.rnn_cell(data_input, hidden_state)

        representation_collector = torch.stack(representation_collector, dim=1)

        return representation_collector, hidden_state