pypots/nn/modules/stemgnn/backbone.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import StockBlockLayer
class BackboneStemGNN(nn.Module):
def __init__(
self,
units,
stack_cnt,
time_step,
multi_layer,
horizon=1,
dropout_rate=0.5,
leaky_rate=0.2,
):
super().__init__()
        self.unit = units
        self.stack_cnt = stack_cnt
self.alpha = leaky_rate
self.time_step = time_step
self.horizon = horizon
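        # key/query projection vectors of the self-attention that learns the latent correlation graph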
self.weight_key = nn.Parameter(torch.zeros(size=(self.unit, 1)))
nn.init.xavier_uniform_(self.weight_key.data, gain=1.414)
self.weight_query = nn.Parameter(torch.zeros(size=(self.unit, 1)))
nn.init.xavier_uniform_(self.weight_query.data, gain=1.414)
self.GRU = nn.GRU(self.time_step, self.unit)
self.multi_layer = multi_layer
self.stock_block = nn.ModuleList()
self.stock_block.extend(
[StockBlockLayer(self.time_step, self.unit, self.multi_layer, stack_cnt=i) for i in range(self.stack_cnt)]
)
self.fc = nn.Sequential(
nn.Linear(int(self.time_step), int(self.time_step)),
nn.LeakyReLU(),
nn.Linear(int(self.time_step), self.horizon),
)
self.leakyrelu = nn.LeakyReLU(self.alpha)
self.dropout = nn.Dropout(p=dropout_rate)
@staticmethod
def get_laplacian(graph, normalize):
"""
return the laplacian of the graph.
:param graph: the graph structure without self loop, [N, N].
:param normalize: whether to used the normalized laplacian.
:return: graph laplacian.
"""
if normalize:
D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)
else:
D = torch.diag(torch.sum(graph, dim=-1))
L = D - graph
return L
@staticmethod
def cheb_polynomial(laplacian):
"""
Compute the Chebyshev Polynomial, according to the graph laplacian.
:param laplacian: the graph laplacian, [N, N].
:return: the multi order Chebyshev laplacian, [K, N, N].
"""
N = laplacian.size(0) # [N, N]
laplacian = laplacian.unsqueeze(0)
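        # Chebyshev recursion: T_k(L) = 2 * L * T_{k-1}(L) - T_{k-2}(L).
        # Note: the 0th-order term below is kept as all zeros, whereas the
        # textbook recursion would start from the identity matrix.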
first_laplacian = torch.zeros([1, N, N], device=laplacian.device, dtype=torch.float)
second_laplacian = laplacian
third_laplacian = (2 * torch.matmul(laplacian, second_laplacian)) - first_laplacian
forth_laplacian = 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian
multi_order_laplacian = torch.cat([first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0)
return multi_order_laplacian
    def latent_correlation_layer(self, x):
        # the GRU treats the variable axis as its sequence axis,
        # x: [batch, time_step, n_nodes] -> GRU input: [n_nodes, batch, time_step]
        gru_output, _ = self.GRU(x.permute(2, 0, 1).contiguous())
        gru_output = gru_output.permute(1, 0, 2).contiguous()
        attention = self.self_graph_attention(gru_output)
        attention = torch.mean(attention, dim=0)
        degree = torch.sum(attention, dim=1)
        # symmetrize the attention matrix so that the resulting Laplacian is symmetric
        attention = 0.5 * (attention + attention.T)
        degree_l = torch.diag(degree)
        diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7))
        # normalized Laplacian built from the learned attention graph
        laplacian = torch.matmul(diagonal_degree_hat, torch.matmul(degree_l - attention, diagonal_degree_hat))
        mul_L = self.cheb_polynomial(laplacian)
        return mul_L, attention
    def self_graph_attention(self, x):
        # x: [batch, n_nodes, hidden], where the hidden size equals the node count
        x = x.permute(0, 2, 1).contiguous()
        bat, N, fea = x.size()
        key = torch.matmul(x, self.weight_key)
        query = torch.matmul(x, self.weight_query)
        # sum every key with every query to obtain pairwise scores over all node pairs
        data = key.repeat(1, 1, N).view(bat, N * N, 1) + query.repeat(1, N, 1)
        data = data.squeeze(2)
        data = data.view(bat, N, -1)
        data = self.leakyrelu(data)
        attention = F.softmax(data, dim=2)
        attention = self.dropout(attention)
        return attention
@staticmethod
def graph_fft(input, eigenvectors):
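        # project the input signal with the given graph eigenvector matrix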
return torch.matmul(eigenvectors, input)
def forward(self, x):
mul_L, attention = self.latent_correlation_layer(x)
        # X: [batch, 1, n_nodes, time_step]
        X = x.unsqueeze(1).permute(0, 1, 3, 2).contiguous()
        result = None
        for stack_i in range(self.stack_cnt):
            forecast, X = self.stock_block[stack_i](X, mul_L)
            if stack_i == 0:
                result = forecast
            else:
                result += forecast  # sum the forecasts produced by the stacked blocks
        forecast_result = self.fc(result)
        if forecast_result.size()[-1] == 1:
            # horizon == 1: output shape [batch, 1, n_nodes]
            return forecast_result.unsqueeze(1).squeeze(-1), attention
        else:
            # horizon > 1: output shape [batch, horizon, n_nodes]
            return forecast_result.permute(0, 2, 1).contiguous(), attention
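

# A minimal usage sketch, assuming toy shapes for illustration (8 samples,
# 12 time steps, 7 variables). ``units`` must equal the number of variables,
# because the GRU hidden size also serves as the node count of the learned
# correlation graph.
#
#     model = BackboneStemGNN(units=7, stack_cnt=2, time_step=12, multi_layer=5, horizon=3)
#     dummy_input = torch.randn(8, 12, 7)  # [batch, time_step, n_variables]
#     forecast, attention = model(dummy_input)
#     # forecast: [8, 3, 7], attention: [7, 7]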