pypots/nn/modules/timemixer/backbone.py
File `backbone.py` has 304 lines of code (exceeds 250 allowed). Consider refactoring.""" """ # Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clause import torchimport torch.nn as nnimport torch.nn.functional as F from .layers import PastDecomposableMixingfrom ..autoformer import SeriesDecompositionBlockfrom ..revin import RevINfrom ..transformer.embedding import DataEmbedding class BackboneTimeMixer(nn.Module):Function `__init__` has 21 arguments (exceeds 4 allowed). Consider refactoring.
Function `__init__` has a Cognitive Complexity of 15 (exceeds 5 allowed). Consider refactoring. def __init__( self, task_name: str, n_steps: int, n_features: int, n_pred_steps: int, n_pred_features: int, n_layers: int, d_model: int, d_ffn: int, dropout: float, top_k: int, channel_independence: bool, decomp_method: str, moving_avg: int, downsampling_layers: int, downsampling_window: int, downsampling_method: str, use_future_temporal_feature: bool, use_norm: bool = False, embed="fixed", freq="h", n_classes=None, ): super().__init__() self.task_name = task_name self.n_steps = n_steps self.n_features = n_features self.n_pred_steps = n_pred_steps self.n_pred_features = n_pred_features self.n_layers = n_layers self.channel_independence = channel_independence self.downsampling_window = downsampling_window self.downsampling_layers = downsampling_layers self.downsampling_method = downsampling_method self.use_norm = use_norm self.use_future_temporal_feature = use_future_temporal_feature assert downsampling_method in ["max", "avg", "conv"], "downsampling_method must be in ['max', 'avg', 'conv']" self.pdm_blocks = nn.ModuleList( [ PastDecomposableMixing( n_steps, n_pred_steps, d_model, d_ffn, dropout, channel_independence, decomp_method, top_k, moving_avg, downsampling_layers, downsampling_window, ) for _ in range(n_layers) ] ) self.preprocess = SeriesDecompositionBlock(moving_avg) if self.channel_independence: self.enc_embedding = DataEmbedding(1, d_model, embed, freq, dropout, with_pos=False) else: self.enc_embedding = DataEmbedding(n_features, d_model, embed, freq, dropout, with_pos=False) if self.use_norm: self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)]) if task_name == "long_term_forecast" or task_name == "short_term_forecast": self.predict_layers = torch.nn.ModuleList( [ torch.nn.Linear( n_steps // (downsampling_window**i), n_pred_steps, ) for i in range(downsampling_layers + 1) ] ) if self.channel_independence: self.projection_layer = nn.Linear(d_model, 1, bias=True) else: self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True) self.out_res_layers = torch.nn.ModuleList( [ torch.nn.Linear( n_steps // (downsampling_window**i), n_steps // (downsampling_window**i), ) for i in range(downsampling_layers + 1) ] ) self.regression_layers = torch.nn.ModuleList( [ torch.nn.Linear( n_steps // (downsampling_window**i), n_pred_steps, ) for i in range(downsampling_layers + 1) ] ) elif task_name == "imputation" or task_name == "anomaly_detection": if self.channel_independence: self.projection_layer = nn.Linear(d_model, 1, bias=True) else: self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True) elif task_name == "classification": self.act = F.gelu self.dropout = nn.Dropout(dropout) self.projection = nn.Linear(d_model * n_steps, n_classes) else: raise NotImplementedError("Task not supported") def out_projection(self, dec_out, i, out_res): dec_out = self.projection_layer(dec_out) out_res = out_res.permute(0, 2, 1) out_res = self.out_res_layers[i](out_res) out_res = self.regression_layers[i](out_res).permute(0, 2, 1) dec_out = dec_out + out_res return dec_out def pre_enc(self, x_list): if self.channel_independence: return x_list, None else: out1_list = [] out2_list = [] for x in x_list: x_1, x_2 = self.preprocess(x) out1_list.append(x_1) out2_list.append(x_2) return out1_list, out2_list Function `__multi_scale_process_inputs` has a Cognitive Complexity of 11 (exceeds 5 allowed). Consider refactoring. def __multi_scale_process_inputs(self, x_enc, x_mark_enc): if self.downsampling_method == "max": down_pool = torch.nn.MaxPool1d(self.downsampling_window, return_indices=False) elif self.downsampling_method == "avg": down_pool = torch.nn.AvgPool1d(self.downsampling_window) elif self.downsampling_method == "conv": padding = 1 if torch.__version__ >= "1.5.0" else 2 down_pool = nn.Conv1d( in_channels=self.enc_in, out_channels=self.enc_in, kernel_size=3, padding=padding, stride=self.downsampling_window, padding_mode="circular", bias=False, ) else: return x_enc, x_mark_enc # B,T,C -> B,C,T x_enc = x_enc.permute(0, 2, 1) x_enc_ori = x_enc x_mark_enc_mark_ori = x_mark_enc x_enc_sampling_list = [] x_mark_sampling_list = [] x_enc_sampling_list.append(x_enc.permute(0, 2, 1)) x_mark_sampling_list.append(x_mark_enc) for i in range(self.downsampling_layers): x_enc_sampling = down_pool(x_enc_ori) x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1)) x_enc_ori = x_enc_sampling if x_mark_enc_mark_ori is not None: x_mark_sampling_list.append(x_mark_enc_mark_ori[:, :: self.downsampling_window, :]) x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, :: self.downsampling_window, :] x_enc = x_enc_sampling_list if x_mark_enc_mark_ori is not None: x_mark_enc = x_mark_sampling_list else: x_mark_enc = x_mark_enc return x_enc, x_mark_enc Function `forecast` has a Cognitive Complexity of 30 (exceeds 5 allowed). Consider refactoring. def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None): if self.use_future_temporal_feature: if self.channel_independence: B, T, N = x_enc.size() x_mark_dec = x_mark_dec.repeat(N, 1, 1) self.x_mark_dec = self.enc_embedding(None, x_mark_dec) else: self.x_mark_dec = self.enc_embedding(None, x_mark_dec) x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc) x_list = [] x_mark_list = [] if x_mark_enc is not None: for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc): B, T, N = x.size() x = self.normalize_layers[i](x, x_mark, mode="norm") if self.use_norm else x if self.channel_independence: x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) x_mark = x_mark.repeat(N, 1, 1) x_list.append(x) x_mark_list.append(x_mark) else: for i, x in zip( range(len(x_enc)), x_enc, ): B, T, N = x.size() x = self.normalize_layers[i](x, mode="norm") if self.use_norm else x if self.channel_independence: x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) x_list.append(x) # embedding enc_out_list = [] x_list = self.pre_enc(x_list) if x_mark_enc is not None: for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list): enc_out = self.enc_embedding(x, x_mark) # [B,T,C] enc_out_list.append(enc_out) else: for i, x in zip(range(len(x_list[0])), x_list[0]): enc_out = self.enc_embedding(x, None) # [B,T,C] enc_out_list.append(enc_out) # Past Decomposable Mixing as encoder for past for i in range(self.n_layers): enc_out_list = self.pdm_blocks[i](enc_out_list) # Future Multipredictor Mixing as decoder for future dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list) dec_out = torch.stack(dec_out_list, dim=-1).sum(-1) dec_out = self.normalize_layers[0](dec_out, mode="denorm") if self.use_norm else dec_out return dec_out Function `future_multi_mixing` has a Cognitive Complexity of 10 (exceeds 5 allowed). Consider refactoring. def future_multi_mixing(self, B, enc_out_list, x_list): dec_out_list = [] if self.channel_independence: x_list = x_list[0] for i, enc_out in zip(range(len(x_list)), enc_out_list): dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension if self.use_future_temporal_feature: dec_out = dec_out + self.x_mark_dec dec_out = self.projection_layer(dec_out) else: dec_out = self.projection_layer(dec_out) dec_out = dec_out.reshape(B, self.c_out, self.n_pred_steps).permute(0, 2, 1).contiguous() dec_out_list.append(dec_out) else: for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]): dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension dec_out = self.out_projection(dec_out, i, out_res) dec_out_list.append(dec_out) return dec_out_list def classification(self, x_enc, x_mark_enc): x_enc, _ = self.__multi_scale_process_inputs(x_enc, None) x_list = x_enc # embedding enc_out_list = [] for x in x_list: enc_out = self.enc_embedding(x, None) # [B,T,C] enc_out_list.append(enc_out) # MultiScale-CrissCrossAttention as encoder for past for i in range(self.n_layer): enc_out_list = self.pdm_blocks[i](enc_out_list) enc_out = enc_out_list[0] # Output # the output transformer encoder/decoder embeddings don't include non-linearity output = self.act(enc_out) output = self.dropout(output) # zero-out padding embeddings output = output * x_mark_enc.unsqueeze(-1) # (batch_size, n_stepsgth * d_model) output = output.reshape(output.shape[0], -1) output = self.projection(output) # (batch_size, num_classes) return output Function `anomaly_detection` has a Cognitive Complexity of 8 (exceeds 5 allowed). Consider refactoring. def anomaly_detection(self, x_enc): B, T, N = x_enc.size() x_enc, _ = self.__multi_scale_process_inputs(x_enc, None) x_list = [] for i, x in zip( range(len(x_enc)), x_enc, ): B, T, N = x.size() x = self.normalize_layers[i](x, "norm") if self.use_norm else x if self.channel_independence: x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) x_list.append(x) # embedding enc_out_list = [] for x in x_list: enc_out = self.enc_embedding(x, None) # [B,T,C] enc_out_list.append(enc_out) # MultiScale-CrissCrossAttention as encoder for past for i in range(self.n_layers): enc_out_list = self.pdm_blocks[i](enc_out_list) dec_out = self.projection_layer(enc_out_list[0]) dec_out = dec_out.reshape(B, self.c_out, -1).permute(0, 2, 1).contiguous() dec_out = self.normalize_layers[0](dec_out, "denorm") if self.use_norm else dec_out return dec_out Function `imputation` has a Cognitive Complexity of 14 (exceeds 5 allowed). Consider refactoring. def imputation(self, x_enc, x_mark_enc): B, T, N = x_enc.size() x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc) x_list = [] x_mark_list = [] if x_mark_enc is not None: for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc): B, T, N = x.size() if self.channel_independence: x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) x_list.append(x) x_mark = x_mark.repeat(N, 1, 1) x_mark_list.append(x_mark) else: for i, x in zip(range(len(x_enc)), x_enc): B, T, N = x.size() if self.channel_independence: x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) x_list.append(x) # embedding enc_out_list = [] for x in x_list: enc_out = self.enc_embedding(x, None) # [B,T,C] enc_out_list.append(enc_out) # MultiScale-CrissCrossAttention as encoder for past for i in range(self.n_layers): enc_out_list = self.pdm_blocks[i](enc_out_list) dec_out = self.projection_layer(enc_out_list[0]) dec_out = dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous() return dec_out