pypots/nn/modules/timemixer/backbone.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import PastDecomposableMixing
from ..autoformer import SeriesDecompositionBlock
from ..revin import RevIN
from ..transformer.embedding import DataEmbedding
class BackboneTimeMixer(nn.Module):
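    """TimeMixer backbone.

    The input series is downsampled into a multi-scale pyramid and encoded by
    stacked Past-Decomposable-Mixing (PDM) blocks; task-specific heads on top
    handle forecasting, imputation, anomaly detection, and classification.
    """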
def __init__(
self,
task_name,
n_steps,
n_features,
n_pred_steps,
n_pred_features,
n_layers,
d_model,
d_ffn,
dropout,
channel_independence,
decomp_method,
top_k,
moving_avg,
downsampling_layers,
downsampling_window,
downsampling_method,
use_future_temporal_feature,
embed="fixed",
freq="h",
n_classes=None,
):
super().__init__()
self.task_name = task_name
self.n_steps = n_steps
self.n_features = n_features
self.n_pred_steps = n_pred_steps
self.n_pred_features = n_pred_features
self.n_layers = n_layers
self.channel_independence = channel_independence
self.downsampling_window = downsampling_window
self.downsampling_layers = downsampling_layers
self.downsampling_method = downsampling_method
self.use_future_temporal_feature = use_future_temporal_feature
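        # Past-Decomposable-Mixing (PDM) blocks; each block mixes the decomposed
        # seasonal and trend components across all scales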
self.pdm_blocks = nn.ModuleList(
[
PastDecomposableMixing(
n_steps,
n_pred_steps,
d_model,
d_ffn,
dropout,
channel_independence,
decomp_method,
top_k,
moving_avg,
downsampling_layers,
downsampling_window,
)
for _ in range(n_layers)
]
)
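        # moving-average series decomposition, used by pre_enc() to split inputs
        # into seasonal and trend parts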
self.preprocess = SeriesDecompositionBlock(moving_avg)
if self.channel_independence == 1:
self.enc_embedding = DataEmbedding(1, d_model, embed, freq, dropout, with_pos=False)
else:
self.enc_embedding = DataEmbedding(n_features, d_model, embed, freq, dropout, with_pos=False)
        self.normalize_layers = nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)])

        if task_name in ("long_term_forecast", "short_term_forecast"):
            self.predict_layers = nn.ModuleList(
                [
                    nn.Linear(
                        n_steps // (downsampling_window**i),
                        n_pred_steps,
                    )
                    for i in range(downsampling_layers + 1)
                ]
            )
            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
                self.out_res_layers = nn.ModuleList(
                    [
                        nn.Linear(
                            n_steps // (downsampling_window**i),
                            n_steps // (downsampling_window**i),
                        )
                        for i in range(downsampling_layers + 1)
                    ]
                )
                self.regression_layers = nn.ModuleList(
                    [
                        nn.Linear(
                            n_steps // (downsampling_window**i),
                            n_pred_steps,
                        )
                        for i in range(downsampling_layers + 1)
                    ]
                )
if task_name == "imputation" or task_name == "anomaly_detection":
if self.channel_independence == 1:
self.projection_layer = nn.Linear(d_model, 1, bias=True)
else:
self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
if task_name == "classification":
self.act = F.gelu
self.dropout = nn.Dropout(dropout)
self.projection = nn.Linear(d_model * n_steps, n_classes)
def out_projection(self, dec_out, i, out_res):
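        """Project ``dec_out`` to the output features and add the trend component
        ``out_res`` of scale ``i``, regressed onto the prediction horizon."""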
dec_out = self.projection_layer(dec_out)
out_res = out_res.permute(0, 2, 1)
out_res = self.out_res_layers[i](out_res)
out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
dec_out = dec_out + out_res
return dec_out
def pre_enc(self, x_list):
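        """Decompose every scale in ``x_list`` into (seasonal, trend) parts, unless
        channels are modeled independently, in which case inputs pass through."""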
if self.channel_independence == 1:
return x_list, None
else:
out1_list = []
out2_list = []
for x in x_list:
x_1, x_2 = self.preprocess(x)
out1_list.append(x_1)
out2_list.append(x_2)
return out1_list, out2_list
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
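        """Build the multi-scale pyramid: downsample ``x_enc`` (and stride
        ``x_mark_enc``) ``downsampling_layers`` times, returning lists ordered
        from the finest (original) scale to the coarsest."""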
if self.downsampling_method == "max":
down_pool = torch.nn.MaxPool1d(self.downsampling_window, return_indices=False)
elif self.downsampling_method == "avg":
down_pool = torch.nn.AvgPool1d(self.downsampling_window)
elif self.downsampling_method == "conv":
padding = 1 if torch.__version__ >= "1.5.0" else 2
down_pool = nn.Conv1d(
in_channels=self.enc_in,
out_channels=self.enc_in,
kernel_size=3,
padding=padding,
stride=self.downsampling_window,
padding_mode="circular",
bias=False,
)
else:
return x_enc, x_mark_enc
# B,T,C -> B,C,T
x_enc = x_enc.permute(0, 2, 1)
x_enc_ori = x_enc
x_mark_enc_mark_ori = x_mark_enc
x_enc_sampling_list = []
x_mark_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)
for i in range(self.downsampling_layers):
x_enc_sampling = down_pool(x_enc_ori)
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling
if x_mark_enc_mark_ori is not None:
x_mark_sampling_list.append(x_mark_enc_mark_ori[:, :: self.downsampling_window, :])
x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, :: self.downsampling_window, :]
        x_enc = x_enc_sampling_list
        if x_mark_enc is not None:
            x_mark_enc = x_mark_sampling_list
        return x_enc, x_mark_enc
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
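        """Forecast: normalize each scale with RevIN, embed, mix the past with the
        PDM blocks, predict the future per scale, sum the per-scale predictions,
        and denormalize."""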
if self.use_future_temporal_feature:
if self.channel_independence == 1:
B, T, N = x_enc.size()
x_mark_dec = x_mark_dec.repeat(N, 1, 1)
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
else:
self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
x_list = []
x_mark_list = []
if x_mark_enc is not None:
            for i, (x, x_mark) in enumerate(zip(x_enc, x_mark_enc)):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_mark = x_mark.repeat(N, 1, 1)
x_list.append(x)
x_mark_list.append(x_mark)
else:
            for i, x in enumerate(x_enc):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
x_list = self.pre_enc(x_list)
if x_mark_enc is not None:
            for x, x_mark in zip(x_list[0], x_mark_list):
enc_out = self.enc_embedding(x, x_mark) # [B,T,C]
enc_out_list.append(enc_out)
else:
            for x in x_list[0]:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
# Past Decomposable Mixing as encoder for past
        for i in range(self.n_layers):
enc_out_list = self.pdm_blocks[i](enc_out_list)
# Future Multipredictor Mixing as decoder for future
dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
dec_out = self.normalize_layers[0](dec_out, "denorm")
return dec_out
def future_multi_mixing(self, B, enc_out_list, x_list):
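        """Future-Multipredictor-Mixing: one linear predictor per scale maps that
        scale's length onto ``n_pred_steps``; the per-scale outputs are summed by
        the caller."""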
dec_out_list = []
if self.channel_independence == 1:
x_list = x_list[0]
            for i, enc_out in enumerate(enc_out_list):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension
if self.use_future_temporal_feature:
dec_out = dec_out + self.x_mark_dec
dec_out = self.projection_layer(dec_out)
else:
dec_out = self.projection_layer(dec_out)
                dec_out = dec_out.reshape(B, self.n_pred_features, self.n_pred_steps).permute(0, 2, 1).contiguous()
dec_out_list.append(dec_out)
else:
            for i, (enc_out, out_res) in enumerate(zip(enc_out_list, x_list[1])):
dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension
dec_out = self.out_projection(dec_out, i, out_res)
dec_out_list.append(dec_out)
return dec_out_list
def classification(self, x_enc, x_mark_enc):
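        """Classify: encode the multi-scale inputs with the PDM blocks, apply
        GELU and dropout to the finest scale, mask padded steps with
        ``x_mark_enc``, flatten, and project to class logits."""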
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
x_list = x_enc
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
        # Past Decomposable Mixing as encoder for past
        for i in range(self.n_layers):
enc_out_list = self.pdm_blocks[i](enc_out_list)
enc_out = enc_out_list[0]
# Output
# the output transformer encoder/decoder embeddings don't include non-linearity
output = self.act(enc_out)
output = self.dropout(output)
# zero-out padding embeddings
output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, n_steps * d_model)
output = output.reshape(output.shape[0], -1)
output = self.projection(output) # (batch_size, num_classes)
return output
def anomaly_detection(self, x_enc):
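        """Detect anomalies: normalize, embed, and mix the multi-scale inputs,
        then reconstruct the series from the finest scale and denormalize."""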
B, T, N = x_enc.size()
x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
x_list = []
        for i, x in enumerate(x_enc):
B, T, N = x.size()
x = self.normalize_layers[i](x, "norm")
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
        # Past Decomposable Mixing as encoder for past
        for i in range(self.n_layers):
enc_out_list = self.pdm_blocks[i](enc_out_list)
dec_out = self.projection_layer(enc_out_list[0])
        dec_out = dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous()
dec_out = self.normalize_layers[0](dec_out, "denorm")
return dec_out
def imputation(self, x_enc, x_mark_enc):
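        """Impute: embed the multi-scale inputs, mix them with the PDM blocks,
        and reconstruct the series from the finest scale."""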
B, T, N = x_enc.size()
x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
x_list = []
x_mark_list = []
if x_mark_enc is not None:
            for x, x_mark in zip(x_enc, x_mark_enc):
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
x_mark = x_mark.repeat(N, 1, 1)
x_mark_list.append(x_mark)
else:
            for x in x_enc:
B, T, N = x.size()
if self.channel_independence == 1:
x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
x_list.append(x)
# embedding
enc_out_list = []
for x in x_list:
enc_out = self.enc_embedding(x, None) # [B,T,C]
enc_out_list.append(enc_out)
        # Past Decomposable Mixing as encoder for past
for i in range(self.n_layers):
enc_out_list = self.pdm_blocks[i](enc_out_list)
dec_out = self.projection_layer(enc_out_list[0])
dec_out = dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous()
return dec_out