IBM/pytorchpipe

View on GitHub
ptp/workers/trainer.py

Summary

Maintainability
F
5 days
Test Coverage
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Vincent Marois, Tomasz Kornuta"

from os import path,makedirs
import yaml
import torch
from time import sleep
from datetime import datetime

import ptp.configuration.config_parsing as config_parse
import ptp.utils.logger as logging

from ptp.workers.worker import Worker

from ptp.application.task_manager import TaskManager
from ptp.application.pipeline_manager import PipelineManager

from ptp.utils.statistics_collector import StatisticsCollector
from ptp.utils.statistics_aggregator import StatisticsAggregator


class Trainer(Worker):
    """
    Base class for the trainers.

    Iterates over epochs on the dataset.

    All other types of trainers (e.g. ``OnlineTrainer`` & ``OfflineTrainer``) should subclass it.

    """

    def __init__(self, name, class_type):
        """
        Base constructor for all trainers:

            - Adds default trainer command line arguments

        :param name: Name of the worker
        :type name: str

        :param class_type: Class type of the component.

        """ 
        # Call base constructor to set up app state, registry and add default arguments.
        super(Trainer, self).__init__(name, class_type)

        # Add arguments to the specific parser.
        # These arguments will be shared by all basic trainers.
        self.parser.add_argument(
            '--tensorboard',
            action='store',
            dest='tensorboard', choices=[0, 1, 2],
            type=int,
            help="If present, enable logging to TensorBoard. Available log levels:\n"
                "0: Log the collected statistics.\n"
                "1: Add the histograms of the model's biases & weights (Warning: Slow).\n"
                "2: Add the histograms of the model's biases & weights gradients "
                "(Warning: Even slower).")

        self.parser.add_argument(
            '--saveall',
            dest='save_intermediate',
            action='store_true',
            help='Setting to true results in saving intermediate models during training (DEFAULT: False)')

        self.parser.add_argument(
            '--training',
            dest='training_section_name',
            type=str,
            default="training",
            help='Name of the section defining the training procedure (DEFAULT: training)')

        self.parser.add_argument(
            '--validation',
            dest='validation_section_name',
            type=str,
            default="validation",
            help='Name of the section defining the validation procedure (DEFAULT: validation)')


    def setup_experiment(self):
        """
        Sets up experiment of all trainers:

        - Calls base class setup_experiment to parse the command line arguments,

        - Loads the config file(s)

        - Set up the log directory path

        - Add a ``FileHandler`` to the logger

        - Set random seeds

        - Creates the pipeline consisting of many components

        - Creates training task manager

        - Handles curriculum learning if indicated

        - Creates validation task manager

        - Set optimizer

        - Performs testing of compatibility of both training and validation tasks and created pipeline.

        """
        # Call base method to parse all command line arguments and add default sections.
        super(Trainer, self).setup_experiment()

        # "Pass" configuration parameters from the "default_training" section to training section indicated by the section_name.
        self.config.add_default_params({ self.app_state.args.training_section_name :  self.config['default_training'].to_dict()} )
        self.config.del_default_params('default_training')
        
        # "Pass" configuration parameters from the "default_validation" section to validation section indicated by the section_name.
        self.config.add_default_params({ self.app_state.args.validation_section_name:  self.config['default_validation'].to_dict()} )
        self.config.del_default_params('default_validation')


        # Check the presence of the CUDA-compatible devices.
        if self.app_state.args.use_gpu and (torch.cuda.device_count() == 0):
            self.logger.error("Cannot use GPU as there are no CUDA-compatible devices present in the system!")
            exit(-1)

        # Check if config file was selected.
        if self.app_state.args.config == '':
            print('Please pass configuration file(s) as --c parameter')
            exit(-2)

        # Split and make them absolute.
        root_configs = self.app_state.args.config.replace(" ", "").split(',')
        # If there are - expand them to absolute paths.
        abs_root_configs = [path.expanduser(config) for config in root_configs]
        
        # Get the list of configurations which need to be loaded.
        configs_to_load = config_parse.recurrent_config_parse(abs_root_configs, [], self.app_state.absolute_config_path)

        # Read the YAML files one by one - but in reverse order -> overwrite the first indicated config(s)
        config_parse.reverse_order_config_load(self.config, configs_to_load)

        # -> At this point, the Param Registry contains the configuration loaded (and overwritten) from several files.
        # Log the resulting training configuration.
        conf_str = 'Loaded (initial) configuration:\n'
        conf_str += '='*80 + '\n'
        conf_str += yaml.safe_dump(self.config.to_dict(), default_flow_style=False)
        conf_str += '='*80 + '\n'
        print(conf_str)

        # Get training section.
        try:
            tsn = self.app_state.args.training_section_name
            self.config_training = self.config[tsn]
            # We must additionally check if it is None - weird behvaiour when using default value.
            if self.config_training is None:
                raise KeyError()
        except KeyError:
            print("Error: Couldn't retrieve the training section '{}' from the loaded configuration".format(tsn))
            exit(-1)

        # Get training task type.
        try:
            training_task_type = self.config_training['task']['type']
        except KeyError:
            print("Error: Couldn't retrieve the task 'type' from the training section '{}' in the loaded configuration".format(tsn))
            exit(-1)

        # Get validation section.
        try:
            vsn = self.app_state.args.validation_section_name
            self.config_validation = self.config[vsn]
            if self.config_validation is None:
                raise KeyError()
        except KeyError:
            print("Error: Couldn't retrieve the validation section '{}' from the loaded configuration".format(vsn))
            exit(-1)

        # Get validation task type.
        try:
            _ = self.config_validation['task']['type']
        except KeyError:
            print("Error: Couldn't retrieve the task 'type' from the validation section '{}' in the loaded configuration".format(vsn))
            exit(-1)

        # Get pipeline section.
        try:
            psn = self.app_state.args.pipeline_section_name
            self.config_pipeline = self.config[psn]
            if self.config_pipeline is None:
                raise KeyError()
        except KeyError:
            print("Error: Couldn't retrieve the pipeline section '{}' from the loaded configuration".format(psn))
            exit(-1)

        # Get pipeline name.
        try:
            pipeline_name = self.config_pipeline['name']
        except KeyError:
            # Using name of the first configuration file from command line.
            basename = path.basename(root_configs[0])
            # Take config filename without extension.
            pipeline_name = path.splitext(basename)[0] 
            # Set pipeline name, so processor can use it afterwards.
            self.config_pipeline.add_config_params({'name': pipeline_name})

        # Prepare the output path for logging
        while True:  # Dirty fix: if log_dir already exists, wait for 1 second and try again
            try:
                time_str = '{0:%Y%m%d_%H%M%S}'.format(datetime.now())
                if self.app_state.args.exptag != '':
                    time_str = time_str + "_" + self.app_state.args.exptag
                self.app_state.log_dir = path.expanduser(self.app_state.args.expdir) + '/' + training_task_type + '/' + pipeline_name + '/' + time_str + '/'
                # Lowercase dir.
                self.app_state.log_dir = self.app_state.log_dir.lower()
                makedirs(self.app_state.log_dir, exist_ok=False)
            except FileExistsError:
                sleep(1)
            else:
                break

        # Set log dir.
        self.app_state.log_file = self.app_state.log_dir + 'trainer.log'
        # Initialize logger in app state.
        self.app_state.logger = logging.initialize_logger("AppState")
        # Add handlers for the logfile to worker logger.
        logging.add_file_handler_to_logger(self.logger)
        self.logger.info("Logger directory set to: {}".format(self.app_state.log_dir))

        # Set cpu/gpu types.
        self.app_state.set_types()

        # Models dir.
        self.checkpoint_dir = self.app_state.log_dir + 'checkpoints/'
        makedirs(self.checkpoint_dir, exist_ok=False)

        # Set random seeds in the training section.
        self.set_random_seeds('training', self.config_training)

        # Total number of detected errors.
        errors =0

        ################# TRAINING PROBLEM ################# 

        # Build training task manager.
        self.training = TaskManager('training', self.config_training) 
        errors += self.training.build()
        
        # parse the curriculum learning section in the loaded configuration.
        if 'curriculum_learning' in self.config_training:

            # Initialize curriculum learning - with values from loaded configuration.
            self.training.task.curriculum_learning_initialize(self.config_training['curriculum_learning'])

            # If the 'must_finish' key is not present in config then then it will be finished by default
            self.config_training['curriculum_learning'].add_default_params({'must_finish': True})

            self.must_finish_curriculum = self.config_training['curriculum_learning']['must_finish']
            self.logger.info("Curriculum Learning activated")

        else:
            # If not using curriculum learning then it does not have to be finished.
            self.must_finish_curriculum = False
            self.curric_done = True

        ################# VALIDATION PROBLEM ################# 
        
        # Build validation task manager.
        self.validation = TaskManager('validation', self.config_validation)
        errors += self.validation.build()

        ###################### PIPELINE ######################
        
        # Build the pipeline using the loaded configuration.
        self.pipeline = PipelineManager(pipeline_name, self.config_pipeline)
        errors += self.pipeline.build()

        # Check errors.
        if errors > 0:
            self.logger.error('Found {} errors, terminating execution'.format(errors))
            exit(-2)

        # Show pipeline.
        summary_str = self.pipeline.summarize_all_components_header()
        summary_str += self.training.task.summarize_io("training")
        summary_str += self.validation.task.summarize_io("validation")
        summary_str += self.pipeline.summarize_all_components()
        self.logger.info(summary_str)
        
        # Handshake definitions.
        self.logger.info("Handshaking training pipeline")
        defs_training = self.training.task.output_data_definitions()
        errors += self.pipeline.handshake(defs_training)

        self.logger.info("Handshaking validation pipeline")
        defs_valid = self.validation.task.output_data_definitions()
        errors += self.pipeline.handshake(defs_valid)

        # Check errors.
        if errors > 0:
            self.logger.error('Found {} errors, terminating execution'.format(errors))
            exit(-2)

        ################## MODEL LOAD/FREEZE #################

        # Load the pretrained models params from checkpoint.
        try: 
            # Check command line arguments, then check load option in config.
            if self.app_state.args.load_checkpoint != "":
                pipeline_name = self.app_state.args.load_checkpoint
                msg = "command line (--load)"
            elif "load" in self.config_pipeline:
                pipeline_name = self.config_pipeline['load']
                msg = "'pipeline' section of the configuration file"
            else:
                pipeline_name = ""
            # Try to load the model.
            if pipeline_name != "":
                if path.isfile(pipeline_name):
                    # Load parameters from checkpoint.
                    self.pipeline.load(pipeline_name)
                else:
                    raise Exception("Couldn't load the checkpoint {} indicated in the {}: file does not exist".format(pipeline_name, msg))
                # If we succeeded, we do not want to load the models from the file anymore!
            else:
                # Try to load the models parameters - one by one, if set so in the configuration file.
                self.pipeline.load_models()

        except KeyError:
            self.logger.error("File {} indicated in the {} seems not to be a valid model checkpoint".format(pipeline_name, msg))
            exit(-5)
        except Exception as e:
            self.logger.error(e)
            # Exit by following the logic: if user wanted to load the model but failed, then continuing the experiment makes no sense.
            exit(-6)

        # Finally, freeze the models (that the user wants to freeze).
        self.pipeline.freeze_models()

        # Log the model summaries.
        summary_str = self.pipeline.summarize_models_header()
        summary_str += self.pipeline.summarize_models()
        self.logger.info(summary_str)

        # Move the models in the pipeline to GPU.
        if self.app_state.args.use_gpu:
            self.pipeline.cuda()        

        ################# OPTIMIZER ################# 

        # Set the optimizer.
        optimizer_conf = dict(self.config_training['optimizer'])
        optimizer_type = optimizer_conf['type']
        del optimizer_conf['type']

        # Check if there are any models in the pipeline.
        if len(list(filter(lambda p: p.requires_grad, self.pipeline.parameters()))) == 0:
            self.logger.error('Cannot proceed with training, as there are no trainable models in the pipeline (or all models are frozen)')
            exit(-7)

        # Instantiate the optimizer and filter the model parameters based on if they require gradients.
        self.optimizer = getattr(torch.optim, optimizer_type)(
            filter(lambda p: p.requires_grad, self.pipeline.parameters()), **optimizer_conf)

        log_str = 'Optimizer:\n' + '='*80 + "\n"
        log_str += "  Type: " + optimizer_type + "\n"
        log_str += "  Params: {}".format(optimizer_conf)

        self.logger.info(log_str)

    def add_statistics(self, stat_col):
        """
        Calls base method and adds epoch statistics to ``StatisticsCollector``.

        :param stat_col: ``StatisticsCollector``.

        """
        # Add loss and episode.
        super(Trainer, self).add_statistics(stat_col)

        # Add default statistics with formatting.
        stat_col.add_statistics('epoch', '{:02d}')


    def add_aggregators(self, stat_agg):
        """
        Adds basic aggregators to to ``StatisticsAggregator`` and extends them with: epoch.

        :param stat_agg: ``StatisticsAggregator``.

        """
        # Add basic aggregators.
        super(Trainer, self).add_aggregators(stat_agg)

        # add 'aggregators' for the epoch.
        stat_agg.add_aggregator('epoch', '{:02d}')


    def initialize_statistics_collection(self):
        """
        - Initializes all ``StatisticsCollectors`` and ``StatisticsAggregators`` used by a given worker: \

            - For training statistics (adds the statistics of the model & task),
            - For validation statistics (adds the statistics of the model & task).

        - Creates the output files (csv).

        """
        # TRAINING.
        # Create statistics collector for training.
        self.training_stat_col = StatisticsCollector()
        self.add_statistics(self.training_stat_col)
        self.training.task.add_statistics(self.training_stat_col)
        self.pipeline.add_statistics(self.training_stat_col)
        # Create the csv file to store the training statistics.
        self.training_batch_stats_file = self.training_stat_col.initialize_csv_file(self.app_state.log_dir, 'training_statistics.csv')

        # Create statistics aggregator for training.
        self.training_stat_agg = StatisticsAggregator()
        self.add_aggregators(self.training_stat_agg)
        self.training.task.add_aggregators(self.training_stat_agg)
        self.pipeline.add_aggregators(self.training_stat_agg)
        # Create the csv file to store the training statistic aggregations.
        self.training_set_stats_file = self.training_stat_agg.initialize_csv_file(self.app_state.log_dir, 'training_set_agg_statistics.csv')

        # VALIDATION.
        # Create statistics collector for validation.
        self.validation_stat_col = StatisticsCollector()
        self.add_statistics(self.validation_stat_col)
        self.validation.task.add_statistics(self.validation_stat_col)
        self.pipeline.add_statistics(self.validation_stat_col)
        # Create the csv file to store the validation statistics.
        self.validation_batch_stats_file = self.validation_stat_col.initialize_csv_file(self.app_state.log_dir, 'validation_statistics.csv')

        # Create statistics aggregator for validation.
        self.validation_stat_agg = StatisticsAggregator()
        self.add_aggregators(self.validation_stat_agg)
        self.validation.task.add_aggregators(self.validation_stat_agg)
        self.pipeline.add_aggregators(self.validation_stat_agg)
        # Create the csv file to store the validation statistic aggregations.
        self.validation_set_stats_file = self.validation_stat_agg.initialize_csv_file(self.app_state.log_dir, 'validation_set_agg_statistics.csv')


    def finalize_statistics_collection(self):
        """
        Finalizes the statistics collection by closing the csv files.

        """
        # Close all files.
        self.training_batch_stats_file.close()
        self.training_set_stats_file.close()
        self.validation_batch_stats_file.close()
        self.validation_set_stats_file.close()


    def initialize_tensorboard(self):
        """
        Initializes the TensorBoard writers, and log directories.

        """
        # Create TensorBoard outputs - if TensorBoard is supposed to be used.
        if self.app_state.args.tensorboard is not None:
            from tensorboardX import SummaryWriter
            self.training_batch_writer = SummaryWriter(self.app_state.log_dir + '/training')
            self.training_stat_col.initialize_tensorboard(self.training_batch_writer)

            self.training_set_writer = SummaryWriter(self.app_state.log_dir + '/training_set_agg')
            self.training_stat_agg.initialize_tensorboard(self.training_set_writer)
            
            self.validation_batch_writer = SummaryWriter(self.app_state.log_dir + '/validation')
            self.validation_stat_col.initialize_tensorboard(self.validation_batch_writer)

            self.validation_set_writer = SummaryWriter(self.app_state.log_dir + '/validation_set_agg')
            self.validation_stat_agg.initialize_tensorboard(self.validation_set_writer)
        else:
            self.training_batch_writer = None
            self.training_set_writer = None
            self.validation_batch_writer = None
            self.validation_set_writer = None

    def finalize_tensorboard(self):
        """ 
        Finalizes the operation of TensorBoard writers by closing them.
        """
        # Close the TensorBoard writers.
        if self.training_batch_writer is not None:
            self.training_batch_writer.close()
        if self.training_set_writer is not None:
            self.training_set_writer.close()
        if self.validation_batch_writer is not None:
            self.validation_batch_writer.close()
        if self.validation_set_writer is not None:
            self.validation_set_writer.close()

    def validate_on_batch(self, valid_batch):
        """
        Performs a validation of the model using the provided batch.

        Additionally logs results (to files, TensorBoard) and handles visualization.

        :param valid_batch: data batch generated by the task and used as input to the model.
        :type valid_batch: ``DataStreams``

        :return: Validation loss.

        """
        # Turn on evaluation mode.
        self.pipeline.eval()
        # Empty the statistics collector.
        self.validation_stat_col.empty()

        # Compute the validation loss using the provided data batch.
        with torch.no_grad():
            # Forward pass.
            self.pipeline.forward(valid_batch)
            # Collect the statistics.
            self.collect_all_statistics(self.validation, self.pipeline, valid_batch, self.validation_stat_col)

        # Export collected statistics.
        self.export_all_statistics(self.validation_stat_col, '[Partial Validation]')

    def validate_on_set(self):
        """
        Performs a validation of the model on the whole validation set, using the validation ``DataLoader``.

        Iterates over the entire validation set (through the `DataLoader``), aggregates the collected statistics \
        and logs that to the console, csv and TensorBoard (if set).

        """
        # Get number of samples.
        num_samples = len(self.validation)
        
        self.logger.info('Validating over the entire validation set ({} samples in {} episodes)'.format(
            num_samples, len(self.validation.dataloader)))

        # Turn on evaluation mode.
        self.pipeline.eval()

        # Reset the statistics.
        self.validation_stat_col.empty()

        # Remember global episode number.
        old_episode = self.app_state.episode

        with torch.no_grad():
            for ep, valid_batch in enumerate(self.validation.dataloader):

                self.app_state.episode = ep
                # Forward pass.
                self.pipeline.forward(valid_batch)
                # Collect the statistics.
                self.collect_all_statistics(self.validation, self.pipeline, valid_batch,
                        self.validation_stat_col)

        # Revert to global episode number.
        self.app_state.episode = old_episode

        # Aggregate statistics for the whole set.
        self.aggregate_all_statistics(self.validation, self.pipeline,
            self.validation_stat_col, self.validation_stat_agg)

        # Export aggregated statistics.
        self.export_all_statistics(self.validation_stat_agg, '[Full Validation]')


if __name__ == '__main__':
    print("The trainer.py file contains only an abstract base class. Please try to use the \
online_trainer (mip-online-trainer) or  offline_trainer (mip-offline-trainer) instead.")