# pypots/nn/modules/tcn/layers.py
""" """ # Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clause import torch.nn as nnfrom torch.nn.utils import weight_norm class Chomp1d(nn.Module): def __init__(self, chomp_size): super().__init__() self.chomp_size = chomp_size def forward(self, x): return x[:, :, : -self.chomp_size].contiguous() class TemporalBlock(nn.Module):Function `__init__` has 7 arguments (exceeds 4 allowed). Consider refactoring. def __init__( self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2, ): super().__init__() self.conv1 = weight_norm( nn.Conv1d( n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation, ) ) self.chomp1 = Chomp1d(padding) self.relu1 = nn.ReLU() self.dropout1 = nn.Dropout(dropout) self.conv2 = weight_norm( nn.Conv1d( n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation, ) ) self.chomp2 = Chomp1d(padding) self.relu2 = nn.ReLU() self.dropout2 = nn.Dropout(dropout) self.net = nn.Sequential( self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2, ) self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None self.relu = nn.ReLU() self.init_weights() def init_weights(self): self.conv1.weight.data.normal_(0, 0.01) self.conv2.weight.data.normal_(0, 0.01) if self.downsample is not None: self.downsample.weight.data.normal_(0, 0.01) def forward(self, x): out = self.net(x) res = x if self.downsample is None else self.downsample(x) return self.relu(out + res)