pypots/nn/modules/tcn/layers.py
"""
"""
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    """Crop trailing time steps so a padded 1D convolution keeps the input length and stays causal."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # Drop the last `chomp_size` steps added by the convolution's padding,
        # so the output at time t depends only on inputs at times <= t.
        return x[:, :, : -self.chomp_size].contiguous()
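
# Illustrative usage sketch for Chomp1d (assumed shapes, not part of the
# original module):
#   chomp = Chomp1d(chomp_size=2)
#   x = torch.randn(8, 16, 26)   # [batch, channels, steps + chomp_size]
#   chomp(x).shape               # -> torch.Size([8, 16, 24])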


class TemporalBlock(nn.Module):
    """A TCN residual block: two weight-normalized dilated causal convolutions,
    each followed by chomping, ReLU, and dropout, plus a residual connection."""

    def __init__(
        self,
        n_inputs,
        n_outputs,
        kernel_size,
        stride,
        dilation,
        padding,
        dropout=0.2,
    ):
        super().__init__()
        # First dilated causal convolution, with weight normalization.
        self.conv1 = weight_norm(
            nn.Conv1d(
                n_inputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        )
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second dilated causal convolution, mapping n_outputs -> n_outputs.
        self.conv2 = weight_norm(
            nn.Conv1d(
                n_outputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        )
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1,
            self.chomp1,
            self.relu1,
            self.dropout1,
            self.conv2,
            self.chomp2,
            self.relu2,
            self.dropout2,
        )
        # 1x1 convolution to match channel dims for the residual connection.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        # Initialize conv weights from N(0, 0.01), as in the original TCN implementation.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        # Residual connection; downsample the input channels if they differ from the output's.
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
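

# A minimal usage sketch (not part of the original module). It builds one
# TemporalBlock with the standard causal-padding convention
# padding = (kernel_size - 1) * dilation and checks that the sequence length
# is preserved. All shapes and hyperparameters below are illustrative assumptions.
if __name__ == "__main__":
    import torch

    batch_size, n_steps, n_inputs, n_outputs = 8, 24, 3, 16
    kernel_size, dilation = 3, 2
    padding = (kernel_size - 1) * dilation  # makes the block causal after chomping

    block = TemporalBlock(
        n_inputs,
        n_outputs,
        kernel_size,
        stride=1,
        dilation=dilation,
        padding=padding,
        dropout=0.2,
    )
    x = torch.randn(batch_size, n_inputs, n_steps)  # Conv1d expects [batch, channels, steps]
    y = block(x)
    assert y.shape == (batch_size, n_outputs, n_steps)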