pypots/nn/modules/fedformer/layers.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import math
from functools import partial
from typing import List, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol, chebyshevt
from torch import Tensor
from torch import nn
from ..autoformer.layers import MovingAvgBlock
from ..transformer.attention import AttentionOperator
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial ``P_k`` at ``x``.

    Uses the identity ``P'_k(x) = sum_{i = k-1, k-3, ...} (2 i + 1) P_i(x)``.
    ``x`` may be a scalar or an array (evaluated element-wise by ``eval_legendre``).
    Returns ``0`` when the sum is empty, i.e. for ``k <= 0``.
    """
    return sum((2 * i + 1) * eval_legendre(i, x) for i in np.arange(k - 1, -1, -2))
def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate a polynomial on ``[lb, ub]`` and return 0 outside that interval.

    ``phi_c`` holds the coefficients in ascending order (constant term first), as
    expected by ``np.polynomial.polynomial.Polynomial``. Both interval ends are inclusive.
    """
    keep = 1 - (np.logical_or(x < lb, x > ub) * 1.0)
    return np.polynomial.polynomial.Polynomial(phi_c)(x) * keep
def _half_interval_integral(a, b, upper=False):
    """Integrate the product of two polynomials over [0, 0.5] (default) or [0.5, 1] (``upper=True``).

    ``a`` and ``b`` are ascending-order coefficient arrays. Their product is formed by
    convolution, coefficients with magnitude below 1e-8 are zeroed, and each monomial
    ``x**n`` is integrated exactly: ``0.5**(n+1) / (n+1)`` on the lower half,
    ``(1 - 0.5**(n+1)) / (n+1)`` on the upper half.
    """
    prod_ = np.convolve(a, b)
    prod_[np.abs(prod_) < 1e-8] = 0
    powers = np.arange(len(prod_))
    lower_tail = np.power(0.5, 1 + powers)
    weight = (1 - lower_tail) if upper else lower_tail
    return (prod_ * 1 / (powers + 1) * weight).sum()


def get_phi_psi(k, base):
    """Build the ``k`` scaling functions and the two halves of the ``k`` multiwavelets.

    Parameters
    ----------
    k : int
        Number of scaling functions / wavelets (multiwavelet order).
    base : str
        ``"legendre"`` or ``"chebyshev"``.

    Returns
    -------
    phi, psi1, psi2 : list of callables, each of length ``k``
        ``phi[i]`` is the i-th scaling function on [0, 1]. ``psi1[i]`` / ``psi2[i]`` are the
        pieces of the i-th wavelet used on [0, 0.5] and (0.5, 1]. For Legendre they are
        ``np.poly1d`` objects (not masked outside their half); for Chebyshev they are
        ``partial(phi_, ...)`` callables that evaluate to 0 outside their interval.

    Raises
    ------
    ValueError
        If ``base`` is not supported (previously this surfaced as an UnboundLocalError).
    """
    if base not in ("legendre", "chebyshev"):
        raise ValueError(f"Base not supported: {base}")

    x = Symbol("x")
    # Coefficient rows (ascending order) of phi_i(x) and sqrt(2) * phi_i(2x).
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    if base == "legendre":
        # Normalized shifted Legendre polynomials sqrt(2i+1) * P_i(2x - 1) and their dilations.
        for ki in range(k):
            coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, : ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, : ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            # Gram-Schmidt: start from sqrt(2) * phi_ki(2x) on [0, 0.5] and remove its projections
            # onto every scaling function and onto the wavelets already built.
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = _half_interval_integral(phi_2x_coeff[ki, : ki + 1], phi_coeff[i, : i + 1])
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            for j in range(ki):
                proj_ = _half_interval_integral(phi_2x_coeff[ki, : ki + 1], psi1_coeff[j, :])
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Normalize with the squared L2 norm of psi1 on [0, 0.5] plus psi2 on [0.5, 1].
            norm1 = _half_interval_integral(psi1_coeff[ki, :], psi1_coeff[ki, :])
            norm2 = _half_interval_integral(psi2_coeff[ki, :], psi2_coeff[ki, :], upper=True)
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    else:  # base == "chebyshev"
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi)
                phi_2x_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
                phi_coeff[ki, : ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
                phi_2x_coeff[ki, : ki + 1] = np.flip(
                    np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)
                )

        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        # Inner products use Chebyshev-Gauss quadrature on the roots of T_{2k}(2x - 1) with the
        # constant weight pi / (4k). Since 2k is even, T_{2k} has no root at 0, so no node lands
        # exactly on x = 0.5 (which would belong to both halves).
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Temporary masked callables, only used to evaluate the norm by quadrature.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            # Final callables; the 1e-16 nudge makes x = 0.5 belong to psi1 only.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)

    return phi, psi1, psi2
def get_filter(base, k):
    """Compute the multiwavelet two-scale filter matrices for order ``k``.

    Parameters
    ----------
    base : str
        ``"legendre"`` or ``"chebyshev"``.
    k : int
        Number of scaling functions (multiwavelet order).

    Returns
    -------
    H0, H1, G0, G1, PHI0, PHI1 : np.ndarray
        ``(k, k)`` arrays evaluated by quadrature:
        ``H0[i, j] = <phi_i(x / 2), phi_j(x)> / sqrt(2)`` and ``H1`` the same at ``(x + 1) / 2``;
        ``G0`` / ``G1`` are the analogues with the wavelets ``psi_i`` in place of ``phi_i``.
        ``PHI0`` / ``PHI1`` are the identity for Legendre and, for Chebyshev, the quadrature
        Gram matrices of the scaling functions at ``2x`` and ``2x - 1``.
        Entries with absolute value below 1e-8 are set to 0.

    Raises
    ------
    ValueError
        If ``base`` is not supported. (A subclass of ``Exception``, so callers catching the
        previously raised bare ``Exception`` keep working.)
    """

    def psi(psi1, psi2, i, inp):
        # Piecewise wavelet i: the psi1 piece where inp <= 0.5 and the psi2 piece elsewhere.
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ["legendre", "chebyshev"]:
        raise ValueError("Base not supported")

    x = Symbol("x")
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    PHI0 = np.zeros((k, k))
    PHI1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)
    if base == "legendre":
        # Gauss-Legendre nodes on [0, 1] (roots of P_k(2x - 1)) and the matching weights
        # 1 / (k * P'_k(t) * P_{k-1}(t)) with t = 2x - 1.
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)
        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
        # The normalized shifted Legendre scaling functions are orthonormal on [0, 1].
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)
    elif base == "chebyshev":
        # Chebyshev-Gauss nodes (roots of T_{2k}(2x - 1)) with constant weight pi / (4k).
        # Since 2k is even, T_{2k} has no root at 0, so no node falls exactly on x = 0.5.
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2
        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
                PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2

    # Drop numerical noise from the quadrature.
    PHI0[np.abs(PHI0) < 1e-8] = 0
    PHI1[np.abs(PHI1) < 1e-8] = 0
    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0
    return H0, H1, G0, G1, PHI0, PHI1
class sparseKernelFT1d(nn.Module):
    """Learnable spectral kernel over the sequence axis of a ``(B, N, c, k)`` tensor.

    The ``c * k`` channels are mixed per Fourier mode by the complex weight
    ``weights1 + i * weights2``; only the lowest ``min(alpha, N // 2 + 1)`` modes are kept
    and all higher modes are zeroed before the inverse real FFT.

    Parameters
    ----------
    k : int
        Size of the last input dimension.
    alpha : int
        Number of Fourier modes with learnable weights.
    c : int
        Size of the third input dimension.
    nl, initializer, **kwargs :
        Accepted for API compatibility; not used.
    """

    def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs):
        super().__init__()
        self.modes1 = alpha
        n_channels = c * k
        self.scale = 1 / (n_channels * n_channels)
        # Real (weights1) then imaginary (weights2) parts; created in this order so the RNG stream
        # is consumed identically. nn.Parameter already sets requires_grad=True.
        self.weights1 = nn.Parameter(self.scale * torch.rand(n_channels, n_channels, self.modes1, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(n_channels, n_channels, self.modes1, dtype=torch.float))
        self.k = k

    def compl_mul1d(self, order, x, weights):
        """Complex ``einsum(order, x, weights)`` expressed through four real einsums.

        Real operands are promoted to complex with a zero imaginary part. If neither operand was
        complex, the plain real einsum is returned.
        """
        x_was_complex = torch.is_complex(x)
        w_was_complex = torch.is_complex(weights)
        if not x_was_complex:
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not w_was_complex:
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if not (x_was_complex or w_was_complex):
            return torch.einsum(order, x.real, weights.real)
        real_part = torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag)
        imag_part = torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)
        return torch.complex(real_part, imag_part)

    def forward(self, x):
        B, N, c, k = x.shape
        n_freq = N // 2 + 1
        kept = min(self.modes1, n_freq)
        # (B, N, c, k) -> (B, c*k, N) so the FFT runs over the sequence axis.
        signal = x.view(B, N, -1).permute(0, 2, 1)
        spectrum = torch.fft.rfft(signal)
        out_ft = torch.zeros(B, c * k, n_freq, device=signal.device, dtype=torch.cfloat)
        kernel = torch.complex(self.weights1, self.weights2)[:, :, :kept]
        out_ft[:, :, :kept] = self.compl_mul1d("bix,iox->box", spectrum[:, :, :kept], kernel)
        restored = torch.fft.irfft(out_ft, n=N)
        return restored.permute(0, 2, 1).view(B, N, c, k)
class MWT_CZ1d(nn.Module):
    """One multiwavelet transform stage on inputs of shape ``(B, N, c, k)``.

    The sequence is extended to the next power of two by appending its first ``nl - N`` steps,
    decomposed ``ns - L`` times with the wavelet filters from ``get_filter``. At every level the
    details ``d`` and smooth part ``s`` are processed by the spectral kernels ``A(d) + B(s)`` and
    ``C(d)``; the coarsest smooth part goes through ``T0``. The result is reconstructed with
    ``evenOdd`` and truncated back to ``N`` steps.
    """

    def __init__(self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs):
        super().__init__()
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3
        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)
        self.T0 = nn.Linear(k, k)
        # Decomposition (ec_*) and reconstruction (rc_*) filters, each of shape (2k, k).
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Pad to the power-of-two length nl by wrapping the leading steps around.
        extra_x = x[:, 0 : nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        for _ in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]
        return x

    def wavelet_transform(self, x):
        """Split ``x`` (B, N, c, k) into details and smooth part, each of shape (B, N // 2, c, k)."""
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Reconstruct (B, N, c, 2k) coefficients into a (B, 2N, c, k) sequence by interleaving."""
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)
        # Allocate with the matmul result's dtype: the default (float32) silently downcast
        # double-precision computations when the module was converted with .double().
        out = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        out[..., ::2, :, :] = x_e
        out[..., 1::2, :, :] = x_o
        return out
class MultiWaveletTransform(AttentionOperator):
    """
    1D multiwavelet block.

    The values are zero-padded or truncated to the query length ``L``, flattened per step,
    projected from ``ich`` to ``c * k`` channels, passed through ``nCZ`` stacked ``MWT_CZ1d``
    stages (ReLU between consecutive stages) and projected back to ``ich`` channels.
    Queries only set the output length; keys are unused. ``attention_dropout`` is accepted
    for API compatibility but not used.
    """

    def __init__(
        self,
        ich=1,
        k=8,
        alpha=16,
        c=128,
        nCZ=1,
        L=0,
        base="legendre",
        attention_dropout=0.1,
    ):
        super().__init__()
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for _ in range(nCZ))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v
        B, L, H, E = q.shape
        _, S, _, D = v.shape
        if L > S:
            # Pad from v itself: padding shaped like q broke torch.cat when E != D, and the
            # forced .float() broke non-float32 values.
            zeros = v.new_zeros(B, L - S, v.shape[2], D)
            v = torch.cat([v, zeros], dim=1)
        else:
            v = v[:, :L, :, :]
        v = v.reshape(B, L, -1)
        V = self.Lk0(v).view(B, L, self.c, -1)
        for i in range(self.nCZ):
            V = self.MWT_CZ[i](V)
            if i < self.nCZ - 1:
                V = F.relu(V)
        V = self.Lk1(V.view(B, L, -1))
        V = V.view(B, L, -1, D)
        return V.contiguous(), None
class FourierCrossAttentionW(nn.Module):
    """Fourier cross attention between wavelet coefficients (used by MultiWaveletCross).

    ``q`` and ``k`` of shape ``(B, L, E, H)`` are moved to the frequency domain and only the
    lowest ``min(L // 2, modes)`` query modes and ``min(len_v // 2, modes)`` key modes are
    kept. Attention scores between those modes (activated with ``tanh`` or ``softmax``) are
    applied to the key spectrum; ``v`` only determines how many key modes are kept. The
    output has shape ``(B, L, E, H)`` and is scaled by ``1 / (in_channels * out_channels)``.
    ``seq_len_q``, ``seq_len_kv`` and ``mode_select_method`` are accepted but unused.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=16,
        activation="tanh",
        mode_select_method="random",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        """Complex einsum via real einsums; real operands get a zero imaginary part."""
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        B, L, E, H = q.shape
        xq = q.permute(0, 3, 2, 1)  # [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        # The kept modes are always the contiguous prefix 0..n-1 of the spectrum.
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))
        n_q = len(self.index_q)
        n_kv = len(self.index_k_v)

        # Compute Fourier coefficients of the kept modes (complex64, like the output buffer).
        xq_ft_ = torch.fft.rfft(xq, dim=-1)[..., :n_q].to(torch.cfloat)
        xk_ft_ = torch.fft.rfft(xk, dim=-1)[..., :n_kv].to(torch.cfloat)

        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise ValueError("{} activation function is not implemented".format(self.activation))
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        out_ft[..., :n_q] = xqkv_ft
        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
class MultiWaveletCross(AttentionOperator):
    """
    1D Multiwavelet Cross Attention layer.

    ``q`` (B, N, H, E) and ``k``/``v`` (B, S, H, E) are flattened per step, projected from
    ``ich`` to ``c * k`` channels and reshaped to ``(.., c, k)``. Keys/values are zero-padded
    or truncated to ``N`` steps, all three are extended to a power-of-two length by appending
    their leading steps, decomposed with the multiwavelet filters, combined level by level with
    four ``FourierCrossAttentionW`` modules, reconstructed, truncated to ``N`` and projected back
    to ``ich`` channels.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes,
        c=64,
        k=8,
        ich=512,
        L=0,
        base="legendre",
        mode_select_method="random",
        initializer=None,
        activation="tanh",
        **kwargs,
    ):
        super().__init__()
        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        def make_attention():
            # The four attention modules share the same configuration.
            return FourierCrossAttentionW(
                in_channels=in_channels,
                out_channels=out_channels,
                seq_len_q=seq_len_q,
                seq_len_kv=seq_len_kv,
                modes=modes,
                activation=activation,
                mode_select_method=mode_select_method,
            )

        self.attn1 = make_attention()
        self.attn2 = make_attention()
        self.attn3 = make_attention()
        self.attn4 = make_attention()
        self.T0 = nn.Linear(k, k)
        # Decomposition (ec_*) and reconstruction (rc_*) filters, each of shape (2k, k).
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))
        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v
        B, N, H, E = q.shape  # (B, N, H, E)
        _, S, _, _ = k.shape  # (B, S, H, E)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)
        v = v.view(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        if N > S:
            # zeros_like keeps q's dtype/device; the former .float() broke non-float32 inputs.
            zeros = torch.zeros_like(q[:, : (N - S), :])
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Extend to the power-of-two length nl by appending the leading nl - N steps.
        extra_q = q[:, 0 : nl - N, :, :]
        extra_k = k[:, 0 : nl - N, :, :]
        extra_v = v[:, 0 : nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        Ud_q = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        Ud_k = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        Ud_v = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        Us_q = torch.jit.annotate(List[Tensor], [])
        Us_k = torch.jit.annotate(List[Tensor], [])
        Us_v = torch.jit.annotate(List[Tensor], [])
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose: Ud_* keeps (detail, smooth) per level, Us_* the detail alone.
        # NOTE(review): attn3 below therefore attends over detail coefficients, not the smooth
        # part — behavior kept unchanged; confirm this is intended.
        for _ in range(ns - self.L):
            d, q = self.wavelet_transform(q)
            Ud_q += [tuple([d, q])]
            Us_q += [d]
        for _ in range(ns - self.L):
            d, k = self.wavelet_transform(k)
            Ud_k += [tuple([d, k])]
            Us_k += [d]
        for _ in range(ns - self.L):
            d, v = self.wavelet_transform(v)
            Ud_v += [tuple([d, v])]
            Us_v += [d]
        for i in range(ns - self.L):
            dk, sk = Ud_k[i], Us_k[i]
            dq, sq = Ud_q[i], Us_q[i]
            dv, sv = Ud_v[i], Us_v[i]
            Ud += [self.attn1(dq[0], dk[0], dv[0], attn_mask)[0] + self.attn2(dq[1], dk[1], dv[1], attn_mask)[0]]
            Us += [self.attn3(sq, sk, sv, attn_mask)[0]]
        v = self.attn4(q, k, v, attn_mask)[0]

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return v.contiguous(), None

    def wavelet_transform(self, x):
        """Split ``x`` (B, N, c, k) into details and smooth part, each of shape (B, N // 2, c, k)."""
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Reconstruct (B, N, c, 2k) coefficients into a (B, 2N, c, k) sequence by interleaving."""
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)
        # Allocate with the matmul result's dtype so non-float32 computations are not downcast.
        out = torch.zeros(B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype)
        out[..., ::2, :, :] = x_e
        out[..., 1::2, :, :] = x_o
        return out
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
    """
    Choose which frequency indices (out of ``0 .. seq_len // 2 - 1``) to keep.

    At most ``min(modes, seq_len // 2)`` indices are returned, in ascending order:
    'random' samples them with ``np.random.shuffle`` (global NumPy RNG);
    any other value keeps the lowest indices.
    """
    n_freq = seq_len // 2
    n_modes = min(modes, n_freq)
    if mode_select_method != "random":
        return list(range(n_modes))
    candidates = list(range(n_freq))
    np.random.shuffle(candidates)
    return sorted(candidates[:n_modes])
# ########## fourier layer #############
class FourierBlock(AttentionOperator):
    """
    1D Fourier block. It performs representation learning on frequency domain,
    it does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels, out_channels : int
        Total channel counts; each is split evenly across ``num_heads`` heads.
    seq_len : int
        Sequence length used to choose the frequency modes.
    modes : int
        Number of frequency modes to keep (capped at ``seq_len // 2`` by ``get_frequency_modes``).
    mode_select_method : str
        'random' samples the modes randomly; any other value keeps the lowest ones.
    num_heads : int
        Number of heads the complex weights are split into; must match the head dimension
        of the input. Defaults to 8, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random", num_heads=8):
        super().__init__()
        # get modes on frequency domain
        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index),
                dtype=torch.cfloat,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, head, in_channel), (head, in_channel, out_channel) -> (batch, head, out_channel)
        return torch.einsum("bhi,hio->bho", input, weights)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v; only q is used.
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)  # [B, H, E, L]
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            # NOTE(review): the result for selected frequency i is written to bin wi (its position
            # in self.index), not to bin i, whereas FourierCrossAttention writes back to the
            # original bin. Kept as-is for compatibility with trained weights — confirm intended.
            out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi])
        # Return to time domain; NOTE(review): the result keeps the [B, H, E, L] layout.
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x, None
# ########## Fourier Cross Former ####################
class FourierCrossAttention(AttentionOperator):
    """
    1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.

    The frequency indices chosen by ``get_frequency_modes`` are gathered from the query and key
    spectra (indices beyond the actual spectrum length are skipped and stay zero). Attention
    scores between query and key modes, activated with ``tanh`` or ``softmax``, are applied to
    the key spectrum, mixed per head by the complex weights ``weights1 + i * weights2`` and
    scattered back to their frequency bins before the inverse FFT. ``v`` and ``policy`` are
    unused. The output keeps the ``[B, H, E, L]`` layout produced by the internal permute.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=64,
        mode_select_method="random",
        activation="tanh",
        policy=0,
        num_heads=8,
    ):
        super().__init__()
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # get modes for queries and keys (& values) on frequency domain
        self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
        self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)

        self.scale = 1 / (in_channels * out_channels)
        # Real (weights1) and imaginary (weights2) parts of the per-head complex mixing weights.
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )
        self.weights2 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, order, x, weights):
        """Complex einsum via real einsums; real operands get a zero imaginary part."""
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)

        # Compute Fourier coefficients of the selected modes
        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            if j >= xq_ft.shape[3]:
                continue
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
        xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_kv):
            if j >= xk_ft.shape[3]:
                continue
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # perform attention mechanism on frequency domain
        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            # ValueError subclasses Exception, so existing handlers still catch it.
            raise ValueError("{} activation function is not implemented".format(self.activation))
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        # Return to time domain
        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
        return out, None
class SeriesDecompositionMultiBlock(nn.Module):
    """
    Series decomposition block from FEDfromer,
    i.e. series_decomp_multi from https://github.com/MAZiqing/FEDformer

    One moving average is computed per kernel size; they are mixed with per-element softmax
    weights produced by ``Linear(1, len(kernel_size))`` applied to ``x``. The mixture is
    returned as the trend (``moving_mean``) together with the residual ``x - moving_mean``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # nn.ModuleList (instead of a plain list) registers the blocks as submodules so that
        # .to(), .train()/.eval() and module traversal reach them.
        self.moving_avg = nn.ModuleList([MovingAvgBlock(kernel, stride=1) for kernel in kernel_size])
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        # Stack the moving averages along a new trailing axis, one slot per kernel size.
        moving_mean = torch.cat([func(x).unsqueeze(-1) for func in self.moving_avg], dim=-1)
        weights = torch.softmax(self.layer(x.unsqueeze(-1)), dim=-1)
        moving_mean = torch.sum(moving_mean * weights, dim=-1)
        res = x - moving_mean
        return res, moving_mean