pypots/nn/modules/autoformer/layers.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import math
from typing import Tuple, Optional
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
from ..transformer.attention import AttentionOperator, MultiHeadAttention
class AutoCorrelation(AttentionOperator):
"""
AutoCorrelation Mechanism with the following two phases:
(1) period-based dependencies discovery
(2) time delay aggregation
This block can replace the self-attention family mechanism seamlessly.
"""
def __init__(
self,
factor=1,
attention_dropout=0.1,
):
super().__init__()
self.factor = factor
self.dropout = nn.Dropout(attention_dropout)
def time_delay_agg_training(self, values, corr):
"""
SpeedUp version of Autocorrelation (a batch-normalization style design)
This is for the training phase.
"""
head = values.shape[1]
channel = values.shape[2]
length = values.shape[3]
# find top k
top_k = int(self.factor * math.log(length))
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
# update corr
tmp_corr = torch.softmax(weights, dim=-1)
# aggregation
tmp_values = values
delays_agg = torch.zeros_like(values).float()
for i in range(top_k):
pattern = torch.roll(tmp_values, -int(index[i]), -1)
delays_agg = delays_agg + pattern * (
tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
)
return delays_agg
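# Note: the training version above picks one shared set of top-k delays from the
# batch-averaged correlations (only the mixing weights stay per-sample), whereas
# the inference version below keeps per-sample delays and the full version keeps
# per-head/per-channel delays, both realised with torch.gather on doubled values.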
def time_delay_agg_inference(self, values, corr):
"""
SpeedUp version of Autocorrelation (a batch-normalization style design)
This is for the inference phase.
"""
batch = values.shape[0]
head = values.shape[1]
channel = values.shape[2]
length = values.shape[3]
# index init
init_index = (
torch.arange(length)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0)
.repeat(batch, head, channel, 1)
.to(values.device)
)
# find top k
top_k = int(self.factor * math.log(length))
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
weights, delay = torch.topk(mean_value, top_k, dim=-1)
# update corr
tmp_corr = torch.softmax(weights, dim=-1)
# aggregation
tmp_values = values.repeat(1, 1, 1, 2)
delays_agg = torch.zeros_like(values).float()
for i in range(top_k):
tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
delays_agg = delays_agg + pattern * (
tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
)
return delays_agg
def time_delay_agg_full(self, values, corr):
"""
Standard version of Autocorrelation
"""
batch = values.shape[0]
head = values.shape[1]
channel = values.shape[2]
length = values.shape[3]
# index init
init_index = (
torch.arange(length)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0)
.repeat(batch, head, channel, 1)
.to(values.device)
)
# find top k
top_k = int(self.factor * math.log(length))
weights, delay = torch.topk(corr, top_k, dim=-1)
# update corr
tmp_corr = torch.softmax(weights, dim=-1)
# aggregation
tmp_values = values.repeat(1, 1, 1, 2)
delays_agg = torch.zeros_like(values).float()
for i in range(top_k):
tmp_delay = init_index + delay[..., i].unsqueeze(-1)
pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
return delays_agg
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
# d_tensor could be d_q, d_k, d_v
B, L, H, E = q.shape
_, S, _, D = v.shape
if L > S:
zeros = torch.zeros_like(q[:, : (L - S), :]).float()
v = torch.cat([v, zeros], dim=1)
k = torch.cat([k, zeros], dim=1)
else:
v = v[:, :L, :, :]
k = k[:, :L, :, :]
# period-based dependencies
q_fft = torch.fft.rfft(q.permute(0, 2, 3, 1).contiguous(), dim=-1)
k_fft = torch.fft.rfft(k.permute(0, 2, 3, 1).contiguous(), dim=-1)
res = q_fft * torch.conj(k_fft)
corr = torch.fft.irfft(res, dim=-1)
# time delay agg
if self.training:
V = self.time_delay_agg_training(v.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
else:
V = self.time_delay_agg_inference(v.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
attn = corr.permute(0, 3, 1, 2)
output = V.contiguous()
return output, attn
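# A minimal usage sketch of the operator on its own (shapes are illustrative only;
# within this package it is normally wrapped by MultiHeadAttention, see below):
#     attn_op = AutoCorrelation(factor=1, attention_dropout=0.1)
#     q = k = v = torch.randn(8, 96, 4, 16)  # [batch_size, n_steps, n_heads, d_tensor]
#     output, attn = attn_op(q, k, v)        # output: [8, 96, 4, 16], attn: [8, 96, 4, 16]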
class SeasonalLayerNorm(nn.Module):
"""A special designed layer normalization for the seasonal part."""
def __init__(self, n_channels):
super().__init__()
self.layer_norm = nn.LayerNorm(n_channels)
def forward(self, x):
x_hat = self.layer_norm(x)
bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
return x_hat - bias
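# Effectively, the normalized output is re-centred to have zero mean along the
# time dimension, e.g. (illustrative shapes only):
#     norm = SeasonalLayerNorm(n_channels=16)
#     out = norm(torch.randn(8, 96, 16))  # out.mean(dim=1) is ~0 everywhere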
class MovingAvgBlock(nn.Module):
"""
The moving average block to highlight the trend of time series.
"""
def __init__(self, kernel_size, stride):
super().__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
def forward(self, x):
# pad both ends of the time series by replicating the boundary steps
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)
x = self.avg(x.permute(0, 2, 1))
x = x.permute(0, 2, 1)
return x
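# With an odd kernel_size k and stride 1, the (k - 1) // 2 replicated steps on
# each side keep the output length equal to the input length:
#     L_out = (L + 2 * ((k - 1) // 2)) - k + 1 = L    # e.g. k = 25 pads 12 + 12
# An even kernel_size would shorten the sequence by one step under this padding.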
class SeriesDecompositionBlock(nn.Module):
"""
Series decomposition block
"""
def __init__(self, kernel_size):
super().__init__()
self.moving_avg = MovingAvgBlock(kernel_size, stride=1)
def forward(self, x):
moving_mean = self.moving_avg(x)
res = x - moving_mean
return res, moving_mean
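# A minimal usage sketch (shapes are illustrative only):
#     decomp = SeriesDecompositionBlock(kernel_size=25)
#     seasonal, trend = decomp(torch.randn(8, 96, 16))  # both: [8, 96, 16]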
class AutoformerEncoderLayer(nn.Module):
"""Autoformer encoder layer with the progressive decomposition architecture."""
def __init__(
self,
attn_opt: AttentionOperator,
d_model: int,
n_heads: int,
d_ffn: int,
moving_avg: int = 25,
dropout: float = 0.1,
activation="relu",
):
super().__init__()
d_ffn = d_ffn or 4 * d_model
self.attention = MultiHeadAttention(
attn_opt,
d_model,
n_heads,
d_model // n_heads,
d_model // n_heads,
)
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ffn, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ffn, out_channels=d_model, kernel_size=1, bias=False)
self.series_decomp1 = SeriesDecompositionBlock(moving_avg)
self.series_decomp2 = SeriesDecompositionBlock(moving_avg)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, attn_mask=None):
new_x, attn = self.attention(x, x, x, attn_mask=attn_mask)
x = x + self.dropout(new_x)
x, _ = self.series_decomp1(x)
y = x
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
res, _ = self.series_decomp2(x + y)
return res, attn
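# A minimal usage sketch (hyper-parameters are illustrative only):
#     layer = AutoformerEncoderLayer(AutoCorrelation(factor=1, attention_dropout=0.1),
#                                    d_model=64, n_heads=4, d_ffn=256, moving_avg=25)
#     out, attn = layer(torch.randn(8, 96, 64))  # out: [8, 96, 64]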
class AutoformerDecoderLayer(nn.Module):
"""
Autoformer decoder layer with the progressive decomposition architecture
"""
def __init__(
self,
self_attn_opt,
cross_attn_opt,
d_model,
n_heads,
d_out,
d_ff=None,
moving_avg=25,
dropout=0.1,
activation="relu",
):
super().__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = MultiHeadAttention(
self_attn_opt,
d_model,
n_heads,
d_model // n_heads,
d_model // n_heads,
)
self.cross_attention = MultiHeadAttention(
cross_attn_opt,
d_model,
n_heads,
d_model // n_heads,
d_model // n_heads,
)
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
self.series_decomp1 = SeriesDecompositionBlock(moving_avg)
self.series_decomp2 = SeriesDecompositionBlock(moving_avg)
self.series_decomp3 = SeriesDecompositionBlock(moving_avg)
self.dropout = nn.Dropout(dropout)
self.projection = nn.Conv1d(
in_channels=d_model,
out_channels=d_out,
kernel_size=3,
stride=1,
padding=1,
padding_mode="circular",
bias=False,
)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None):
x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
x, trend1 = self.series_decomp1(x)
x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0])
x, trend2 = self.series_decomp2(x)
y = x
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
x, trend3 = self.series_decomp3(x + y)
residual_trend = trend1 + trend2 + trend3
residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
return x, residual_trend
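# A minimal usage sketch (hyper-parameters are illustrative only):
#     layer = AutoformerDecoderLayer(AutoCorrelation(), AutoCorrelation(),
#                                    d_model=64, n_heads=4, d_out=16)
#     x = torch.randn(8, 48, 64)          # decoder input
#     cross = torch.randn(8, 96, 64)      # encoder output
#     seasonal, trend = layer(x, cross)   # seasonal: [8, 48, 64], trend: [8, 48, 16]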