# WenjieDu/PyPOTS
#
# View on GitHub
# pypots/nn/modules/timesnet/backbone.py
#
# Summary
#
# Maintainability
# A
# 50 mins
# Test Coverage
"""

"""
import torch
import torch.nn as nn

from .layers import TimesBlock


# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


class BackboneTimesNet(nn.Module):
    """Backbone of TimesNet: ``n_layers`` TimesBlock layers applied in sequence.

    Every layer's output is passed through one LayerNorm (shared by all
    layers) and then fed into the next layer.

    Parameters
    ----------
    n_layers :
        Number of TimesBlock layers to stack.
    n_steps :
        Number of input time steps; forwarded to every TimesBlock and stored
        as ``self.seq_len``.
    n_pred_steps :
        Number of prediction steps; forwarded to every TimesBlock and stored
        as ``self.n_pred_steps``.
    top_k :
        Forwarded to TimesBlock (presumably the number of dominant periods
        each block selects -- confirm in ``.layers``).
    d_model :
        Feature dimension; the size of the last axis normalized by the
        LayerNorm.
    d_ffn :
        Forwarded to TimesBlock (presumably its inner hidden width --
        confirm in ``.layers``).
    n_kernels :
        Forwarded to TimesBlock (presumably the number of convolution kernels
        in its inception block -- confirm in ``.layers``).
    """

    def __init__(
        self,
        n_layers,
        n_steps,
        n_pred_steps,
        top_k,
        d_model,
        d_ffn,
        n_kernels,
    ):
        super().__init__()

        self.seq_len = n_steps
        self.n_layers = n_layers

        self.n_pred_steps = n_pred_steps
        self.model = nn.ModuleList(
            [
                TimesBlock(n_steps, n_pred_steps, top_k, d_model, d_ffn, n_kernels)
                for _ in range(n_layers)
            ]
        )
        # A single LayerNorm instance is reused after every block.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, X) -> torch.Tensor:
        """Run ``X`` through all TimesBlock layers in sequence.

        Parameters
        ----------
        X :
            Input tensor. Each TimesBlock's output must have last dimension
            ``d_model`` for the LayerNorm; presumably ``X`` is shaped
            (batch, n_steps + n_pred_steps, d_model) -- confirm against
            TimesBlock.

        Returns
        -------
        torch.Tensor
            The last layer's output after LayerNorm. With no layers, ``X`` is
            returned unchanged.
        """
        enc_out = X
        for layer in self.model:
            # Feed each block the previous block's normalized output so the
            # layers compose (as in the reference TimesNet design) instead of
            # every block seeing only the raw input.
            enc_out = self.layer_norm(layer(enc_out))

        return enc_out