src/graphnet/models/easy_model.py
File `easy_model.py` has 424 lines of code (exceeds 250 allowed). Consider refactoring."""Suggested Model subclass that enables simple user syntax.""" from collections import OrderedDictfrom typing import Any, Dict, List, Optional, Union, Type import numpy as npimport torchfrom pytorch_lightning import Callback, Trainerfrom pytorch_lightning.callbacks import EarlyStopping, ModelCheckpointfrom torch import Tensorfrom torch.nn import ModuleListfrom torch.optim import Adamfrom torch.utils.data import DataLoader, SequentialSamplerfrom torch_geometric.data import Dataimport pandas as pdfrom pytorch_lightning.loggers import Logger as LightningLogger from graphnet.training.callbacks import ProgressBarfrom graphnet.models.model import Modelfrom graphnet.models.task import StandardLearnedTask class EasySyntax(Model): """A suggested Model class that comes with simple user syntax. This class delivers simple user syntax for training and prediction, while imposing minimal constraints on structure. """ Function `__init__` has 7 arguments (exceeds 4 allowed). Consider refactoring. def __init__( self, *, tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, scheduler_class: Optional[type] = None, scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, ) -> None: """Construct `StandardModel`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) # Check(s) if not isinstance(tasks, (list, tuple)): tasks = [tasks] # Member variable(s) self._tasks = ModuleList(tasks) self._optimizer_class = optimizer_class self._optimizer_kwargs = optimizer_kwargs or dict() self._scheduler_class = scheduler_class self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() self.validate_tasks() def compute_loss( self, preds: Tensor, data: List[Data], verbose: bool = False ) -> Tensor: """Compute and sum losses across tasks.""" raise NotImplementedError def forward( self, data: Union[Data, List[Data]] ) -> List[Union[Tensor, Data]]: """Forward pass, chaining model components.""" raise NotImplementedError def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: """Perform shared step. Applies the forward pass and the following loss calculation, shared between the training and validation step. """ raise NotImplementedError def validate_tasks(self) -> None: """Verify that self._tasks contain compatible elements.""" raise NotImplementedError @staticmethodFunction `_construct_trainer` has 7 arguments (exceeds 4 allowed). Consider refactoring. def _construct_trainer( max_epochs: int = 10, gpus: Optional[Union[List[int], int]] = None, callbacks: Optional[List[Callback]] = None, logger: Optional[LightningLogger] = None, log_every_n_steps: int = 1, gradient_clip_val: Optional[float] = None, distribution_strategy: Optional[str] = "ddp", **trainer_kwargs: Any, ) -> Trainer: if gpus: accelerator = "gpu" devices = gpus else: accelerator = "cpu" devices = 1 trainer = Trainer( accelerator=accelerator, devices=devices, max_epochs=max_epochs, callbacks=callbacks, log_every_n_steps=log_every_n_steps, logger=logger, gradient_clip_val=gradient_clip_val, strategy=distribution_strategy, **trainer_kwargs, ) return trainer Function `fit` has 13 arguments (exceeds 4 allowed). Consider refactoring.
Function `fit` has a Cognitive Complexity of 10 (exceeds 5 allowed). Consider refactoring. def fit( self, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None, *, max_epochs: int = 10, early_stopping_patience: int = 5, gpus: Optional[Union[List[int], int]] = None, callbacks: Optional[List[Callback]] = None, ckpt_path: Optional[str] = None, logger: Optional[LightningLogger] = None, log_every_n_steps: int = 1, gradient_clip_val: Optional[float] = None, distribution_strategy: Optional[str] = "ddp", **trainer_kwargs: Any, ) -> None: """Fit `StandardModel` using `pytorch_lightning.Trainer`.""" # Checks if callbacks is None: # We create the bare-minimum callbacks for you. callbacks = self._create_default_callbacks( val_dataloader=val_dataloader, early_stopping_patience=early_stopping_patience, ) self.debug("No Callbacks specified. Default callbacks added.") else: # You are on your own! self.debug("Initializing training with user-provided callbacks.") pass self._print_callbacks(callbacks) has_early_stopping = self._contains_callback(callbacks, EarlyStopping) has_model_checkpoint = self._contains_callback( callbacks, ModelCheckpoint ) if (has_early_stopping) & (has_model_checkpoint is False): self.warning( "No ModelCheckpoint found in callbacks. Best-fit model will" " not automatically be loaded after training!" "" ) self.train(mode=True) trainer = self._construct_trainer( max_epochs=max_epochs, gpus=gpus, callbacks=callbacks, logger=logger, log_every_n_steps=log_every_n_steps, gradient_clip_val=gradient_clip_val, distribution_strategy=distribution_strategy, **trainer_kwargs, ) try: trainer.fit( self, train_dataloader, val_dataloader, ckpt_path=ckpt_path ) except KeyboardInterrupt: self.warning("[ctrl+c] Exiting gracefully.") pass # Load weights from best-fit model after training if possible if has_early_stopping & has_model_checkpoint: for callback in callbacks: if isinstance(callback, ModelCheckpoint): checkpoint_callback = callback self.load_state_dict( torch.load(checkpoint_callback.best_model_path)["state_dict"] ) self.info("Best-fit weights from EarlyStopping loaded.") def _print_callbacks(self, callbacks: List[Callback]) -> None: callback_names = [] for cbck in callbacks: callback_names.append(cbck.__class__.__name__) self.info( f"Training initiated with callbacks: {', '.join(callback_names)}" ) def _contains_callback( self, callbacks: List[Callback], callback: Callback ) -> bool: """Check if `callback` is in `callbacks`.""" for cbck in callbacks: if isinstance(cbck, callback): return True return False @property def target_labels(self) -> List[str]: """Return target label.""" return [label for task in self._tasks for label in task._target_labels] @property def prediction_labels(self) -> List[str]: """Return prediction labels.""" return [ label for task in self._tasks for label in task._prediction_labels ] def configure_optimizers(self) -> Dict[str, Any]: """Configure the model's optimizer(s).""" optimizer = self._optimizer_class( self.parameters(), **self._optimizer_kwargs ) config = { "optimizer": optimizer, } if self._scheduler_class is not None: scheduler = self._scheduler_class( optimizer, **self._scheduler_kwargs ) config.update( { "lr_scheduler": { "scheduler": scheduler, **self._scheduler_config, }, } ) return config def training_step( self, train_batch: Union[Data, List[Data]], batch_idx: int ) -> Tensor: """Perform training step.""" if isinstance(train_batch, Data): train_batch = [train_batch] loss = self.shared_step(train_batch, batch_idx) self.log( "train_loss", loss, batch_size=self._get_batch_size(train_batch), prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, ) current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] self.log("lr", current_lr, prog_bar=True, on_step=True) return loss def validation_step( self, val_batch: Union[Data, List[Data]], batch_idx: int ) -> Tensor: """Perform validation step.""" if isinstance(val_batch, Data): val_batch = [val_batch] loss = self.shared_step(val_batch, batch_idx) self.log( "val_loss", loss, batch_size=self._get_batch_size(val_batch), prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, ) return loss def inference(self) -> None: """Activate inference mode.""" for task in self._tasks: task.inference() def train(self, mode: bool = True) -> "Model": """Deactivate inference mode.""" super().train(mode) if mode: for task in self._tasks: task.train_eval() return self def predict( self, dataloader: DataLoader, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", **trainer_kwargs: Any, ) -> List[Tensor]: """Return predictions for `dataloader`.""" self.inference() self.train(mode=False) callbacks = self._create_default_callbacks( val_dataloader=None, ) inference_trainer = self._construct_trainer( gpus=gpus, distribution_strategy=distribution_strategy, callbacks=callbacks, **trainer_kwargs, ) predictions_list = inference_trainer.predict(self, dataloader) assert len(predictions_list), "Got no predictions" nb_outputs = len(predictions_list[0]) predictions: List[Tensor] = [ torch.cat([preds[ix] for preds in predictions_list], dim=0) for ix in range(nb_outputs) ] return predictions Function `predict_as_dataframe` has 7 arguments (exceeds 4 allowed). Consider refactoring. def predict_as_dataframe( self, dataloader: DataLoader, prediction_columns: Optional[List[str]] = None, *, additional_attributes: Optional[List[str]] = None, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", **trainer_kwargs: Any, ) -> pd.DataFrame: """Return predictions for `dataloader` as a DataFrame. Include `additional_attributes` as additional columns in the output DataFrame. """ if prediction_columns is None: prediction_columns = self.prediction_labels if additional_attributes is None: additional_attributes = [] assert isinstance(additional_attributes, list) if ( not isinstance(dataloader.sampler, SequentialSampler) and additional_attributes ): print(dataloader.sampler) raise UserWarning( "DataLoader has a `sampler` that is not `SequentialSampler`, " "indicating that shuffling is enabled. Using " "`predict_as_dataframe` with `additional_attributes` assumes " "that the sequence of batches in `dataloader` are " "deterministic. Either call this method a `dataloader` which " "doesn't resample batches; or do not request " "`additional_attributes`." ) self.info(f"Column names for predictions are: \n {prediction_columns}") predictions_torch = self.predict( dataloader=dataloader, gpus=gpus, distribution_strategy=distribution_strategy, **trainer_kwargs, ) predictions = ( torch.cat(predictions_torch, dim=1).detach().cpu().numpy() ) assert len(prediction_columns) == predictions.shape[1], ( f"Number of provided column names ({len(prediction_columns)}) and " f"number of output columns ({predictions.shape[1]}) don't match." ) # Check if predictions are on event- or pulse-level pulse_level_predictions = len(predictions) > len(dataloader.dataset) # Get additional attributes attributes: Dict[str, List[np.ndarray]] = OrderedDict( [(attr, []) for attr in additional_attributes] ) for batch in dataloader: for attr in attributes: attribute = batch[attr] if isinstance(attribute, torch.Tensor): attribute = attribute.detach().cpu().numpy() # Check if node level predictions # If true, additional attributes are repeated # to make dimensions fit if pulse_level_predictions: if len(attribute) < np.sum( batch.n_pulses.detach().cpu().numpy() ): attribute = np.repeat( attribute, batch.n_pulses.detach().cpu().numpy() ) attributes[attr].extend(attribute) # Confirm that attributes match length of predictions skip_attributes = [] for attr in attributes.keys(): try: assert len(attributes[attr]) == len(predictions) except AssertionError: self.warning_once( "Could not automatically adjust length" f" of additional attribute '{attr}' to match length of" f" predictions.This error can be caused by heavy" " disagreement between number of examples in the" " dataset vs. actual events in the dataloader, e.g. " " heavy filtering of events in `collate_fn` passed to" " `dataloader`. This can also be caused by requesting" " pulse-level attributes for `Task`s that produce" " event-level predictions. Attribute skipped." ) skip_attributes.append(attr) # Remove bad attributes for attr in skip_attributes: attributes.pop(attr) additional_attributes.remove(attr) data = np.concatenate( [predictions] + [ np.asarray(values)[:, np.newaxis] for values in attributes.values() ], axis=1, ) results = pd.DataFrame( data, columns=prediction_columns + additional_attributes ) return results def _create_default_callbacks( self, val_dataloader: DataLoader, early_stopping_patience: Optional[int] = None, ) -> List: """Create default callbacks. Used in cases where no callbacks are specified by the user in .fit """ callbacks = [ProgressBar()] if val_dataloader is not None: assert early_stopping_patience is not None # Add Early Stopping callbacks.append( EarlyStopping( monitor="val_loss", patience=early_stopping_patience, ) ) # Add Model Check Point callbacks.append( ModelCheckpoint( save_top_k=1, monitor="val_loss", mode="min", filename=f"{self.backbone.__class__.__name__}" + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}", ) ) self.info( "EarlyStopping has been added" f" with a patience of {early_stopping_patience}." ) return callbacks def _add_early_stopping( self, val_dataloader: DataLoader, callbacks: List ) -> List: if val_dataloader is None: return callbacks has_early_stopping = False assert isinstance(callbacks, list) for callback in callbacks: if isinstance(callback, EarlyStopping): has_early_stopping = True if not has_early_stopping: callbacks.append( EarlyStopping( monitor="val_loss", patience=5, ) ) self.warning_once( "Got validation dataloader but no EarlyStopping callback. An " "EarlyStopping callback has been added automatically with " "patience=5 and monitor = 'val_loss'." ) return callbacks