pypots/nn/modules/transformer/embedding.py
"""
Embedding methods for Transformer models are put here.
This implementation is inspired by the official one https://github.com/zhouhaoyi/Informer2020/blob/main/models/embed.py
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math

import torch
import torch.fft
import torch.nn as nn


class PositionalEncoding(nn.Module):
"""The original positional-encoding module for Transformer.
Parameters
----------
d_hid:
The dimension of the hidden layer.
n_positions:
The max number of positions.
"""
def __init__(self, d_hid: int, n_positions: int = 1000):
super().__init__()
pe = torch.zeros(n_positions, d_hid, requires_grad=False).float()
position = torch.arange(0, n_positions).float().unsqueeze(1)
        div_term = (torch.arange(0, d_hid, 2).float() * -(math.log(10000.0) / d_hid)).exp()
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pos_table", pe)
def forward(self, x: torch.Tensor, return_only_pos: bool = False) -> torch.Tensor:
"""Forward processing of the positional encoding module.
Parameters
----------
x:
Input tensor.
return_only_pos:
Whether to return only the positional encoding.
Returns
-------
If return_only_pos is True:
pos_enc:
The positional encoding.
else:
x_with_pos:
Output tensor, the input tensor with the positional encoding added.
"""
pos_enc = self.pos_table[:, : x.size(1)].clone().detach()
if return_only_pos:
return pos_enc
x_with_pos = x + pos_enc
return x_with_pos
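
# A minimal usage sketch (illustrative only, not part of the module API): it shows how
# PositionalEncoding adds the sinusoidal table to an input batch, or returns the table slice
# alone. All sizes below are arbitrary assumptions for the demo.
#
#     pe_layer = PositionalEncoding(d_hid=64, n_positions=512)
#     x = torch.randn(8, 24, 64)                 # [batch, n_steps, d_hid]
#     x_with_pos = pe_layer(x)                   # same shape, positional encoding added
#     pos_only = pe_layer(x, return_only_pos=True)  # [1, 24, 64], only the encoding slice
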

class TokenEmbedding(nn.Module):
    """Value embedding module: project the feature dimension to d_model with a circular Conv1d."""

def __init__(self, c_in, d_model):
super().__init__()
        # Conv1d with circular padding behaves differently before PyTorch 1.5.0, hence padding=2 there.
        # Compare parsed version numbers rather than raw strings so that e.g. "1.10" is not treated as older than "1.5".
        torch_version = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
        padding = 1 if torch_version >= (1, 5) else 2
self.tokenConv = nn.Conv1d(
in_channels=c_in,
out_channels=d_model,
kernel_size=3,
padding=padding,
padding_mode="circular",
bias=False,
)
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu")
def forward(self, x):
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
return x
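
# A minimal usage sketch (illustrative only): TokenEmbedding maps every time step's feature
# vector to d_model via the circular Conv1d, preserving the sequence length. The sizes here
# are arbitrary assumptions for the demo.
#
#     token_emb = TokenEmbedding(c_in=7, d_model=64)
#     x = torch.randn(8, 24, 7)     # [batch, n_steps, n_features]
#     out = token_emb(x)            # [8, 24, 64]
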

class FixedEmbedding(nn.Module):
    """An nn.Embedding lookup whose weight is a frozen (non-trainable) sinusoidal table."""

    def __init__(self, c_in, d_model):
        super().__init__()
        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
def forward(self, x):
return self.emb(x).detach()
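
# A minimal usage sketch (illustrative only): FixedEmbedding behaves like nn.Embedding but its
# sinusoidal weight table is frozen, so it adds no trainable parameters. The sizes are
# arbitrary assumptions for the demo.
#
#     fixed_emb = FixedEmbedding(c_in=24, d_model=64)
#     hour_idx = torch.randint(0, 24, (8, 24))   # integer hour-of-day indices
#     out = fixed_emb(hour_idx)                  # [8, 24, 64], detached from the graph
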

class TemporalEmbedding(nn.Module):
    """Embed discrete datetime features (month, day, weekday, hour, and optionally minute) and sum the results."""

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super().__init__()
        minute_size = 4  # minute indices are expected to be pre-binned, e.g. into 15-minute buckets
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13
Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
if freq == "t":
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        # x holds integer datetime features ordered as [month, day, weekday, hour, minute]
        # along the last dimension; the minute column is only consumed when freq == "t"
        x = x.long()
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
hour_x = self.hour_embed(x[:, :, 3])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 1])
month_x = self.month_embed(x[:, :, 0])
return hour_x + weekday_x + day_x + month_x + minute_x
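
# A minimal usage sketch (illustrative only): the last dimension of the input is expected to
# carry integer datetime features ordered as [month, day, weekday, hour, minute]. The sizes
# and zero-valued indices below are arbitrary assumptions for the demo.
#
#     temp_emb = TemporalEmbedding(d_model=64, embed_type="fixed", freq="h")
#     x_time = torch.zeros(8, 24, 5, dtype=torch.long)  # [batch, n_steps, 5 datetime fields]
#     out = temp_emb(x_time)                            # [8, 24, 64]
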

class TimeFeatureEmbedding(nn.Module):
    """Embed real-valued time features with a single linear projection to d_model."""

    def __init__(self, d_model, freq="h"):
        super().__init__()
        # number of input time features expected for each sampling frequency
        freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
        d_inp = freq_map[freq]
self.embed = nn.Linear(d_inp, d_model, bias=False)
def forward(self, x):
return self.embed(x)
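
# A minimal usage sketch (illustrative only): TimeFeatureEmbedding expects real-valued
# (already normalized) time features, e.g. 4 of them for freq="h", and projects them linearly
# to d_model. The sizes are arbitrary assumptions for the demo.
#
#     timef_emb = TimeFeatureEmbedding(d_model=64, freq="h")
#     x_time = torch.randn(8, 24, 4)   # [batch, n_steps, 4 time features]
#     out = timef_emb(x_time)          # [8, 24, 64]
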

class DataEmbedding(nn.Module):
    """Combine value, positional, and temporal embeddings into the model input, followed by dropout."""

def __init__(
self,
c_in,
d_model,
embed_type="fixed",
freq="h",
dropout=0.1,
with_pos=True,
n_max_steps=1000,
):
super().__init__()
self.with_pos = with_pos
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
if with_pos:
self.position_embedding = PositionalEncoding(d_hid=d_model, n_positions=n_max_steps)
self.temporal_embedding = (
TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
if embed_type != "timeF"
else TimeFeatureEmbedding(d_model=d_model, freq=freq)
)
self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_timestamp=None):
        x = self.value_embedding(x)
        if x_timestamp is not None:
            x = x + self.temporal_embedding(x_timestamp)
        if self.with_pos:
            x = x + self.position_embedding(x, return_only_pos=True)
        return self.dropout(x)
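

# A minimal usage sketch (illustrative only): DataEmbedding wires the pieces together, with or
# without a timestamp tensor. All sizes are arbitrary assumptions for the demo; the timestamp
# layout follows TemporalEmbedding above when embed_type != "timeF".
#
#     data_emb = DataEmbedding(c_in=7, d_model=64, embed_type="fixed", freq="h", dropout=0.1)
#     x = torch.randn(8, 24, 7)                          # [batch, n_steps, n_features]
#     x_time = torch.zeros(8, 24, 5, dtype=torch.long)   # [month, day, weekday, hour, minute]
#     out = data_emb(x)                                  # value + positional embedding only
#     out = data_emb(x, x_timestamp=x_time)              # value + temporal + positional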