# ptp/workers/offline_trainer.py
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__author__ = "Tomasz Kornuta, Vincent Marois"
import torch
import numpy as np
from ptp.workers.trainer import Trainer
import ptp.configuration.config_parsing as config_parsing
from ptp.configuration.configuration_error import ConfigurationError
from ptp.utils.termination_condition import TerminationCondition
class OfflineTrainer(Trainer):
    """
    Implementation of the epoch-based ``OfflineTrainer``.

    .. note::

        The ``OfflineTrainer`` is based on epochs. An epoch is defined as passing through all \
        samples of a finite-size dataset. The trainer loops over all samples of the training set \
        many times, i.e. in many epochs. When an epoch finishes, it validates the pipeline on the \
        whole validation set, collects the statistics and saves the pipeline.
    """

    # Status set when the (optional) episode limit interrupts the training in the middle of an epoch.
    # Kept in one place, as it is both assigned in the training loop and tested in the final handler.
    _EPISODE_LIMIT_STATUS = "Not converged: Episode Limit reached"

    def __init__(self):
        """
        Constructor. It only calls the ``Trainer`` constructor, as the initialization phase is
        identical to the one from ``Trainer``.
        """
        # Call base constructor to set up app state, registry and add default config.
        super(OfflineTrainer, self).__init__("OfflineTrainer", OfflineTrainer)

    def setup_experiment(self):
        """
        Sets up the experiment for the epoch-based trainer:

            - Calls base class ``setup_experiment`` to parse the command line arguments,
            - Reads the partial validation interval (a non-positive value deactivates it),
            - Sets up the terminal conditions: loss threshold (I), early stopping (II), \
              epoch limit (III, mandatory) and episode limit (IV, optional),
            - Displays and exports the resulting configuration.

        :raises SystemExit: (exit code -5) when the configured epoch limit is not a positive number.
        """
        # Call base method to parse all command line arguments, load configuration, create tasks and model etc.
        super(OfflineTrainer, self).setup_experiment()

        # Partial validation (on a single batch, every N episodes) is optional in this trainer:
        # a non-positive interval deactivates it.
        self.partial_validation_interval = self.config['validation']['partial_validation_interval']
        if self.partial_validation_interval <= 0:
            self.logger.info("Partial Validation deactivated")
        else:
            self.logger.info("Partial Validation activated with interval equal to {} episodes".format(self.partial_validation_interval))

        ################# TERMINAL CONDITIONS #################
        terminal_conditions = self.config_training['terminal_conditions']
        log_lines = ['Terminal conditions:\n' + '='*80 + "\n"]

        # Terminal condition I: loss.
        self.loss_stop_threshold = terminal_conditions['loss_stop_threshold']
        log_lines.append(" I: Setting Loss Stop Threshold to {}\n".format(self.loss_stop_threshold))

        # Terminal condition II: early stopping.
        self.early_stop_validations = terminal_conditions['early_stop_validations']
        if self.early_stop_validations <= 0:
            log_lines.append(" II: Termination based on Early Stopping is disabled\n")
            # Set to infinity, so the ">=" test in the training loop can never succeed.
            # (np.inf, as the np.Inf alias was removed in NumPy 2.0.)
            self.early_stop_validations = np.inf
        else:
            log_lines.append(" II: Setting the Number of Validations in Early Stopping to: {}\n".format(self.early_stop_validations))

        # Terminal condition III: max epochs. Mandatory.
        self.epoch_limit = terminal_conditions["epoch_limit"]
        if self.epoch_limit <= 0:
            self.logger.error("OffLine Trainer relies on epochs, thus Epoch Limit must be a positive number!")
            # Raise SystemExit directly instead of calling the site-provided exit() helper,
            # which is not guaranteed to exist (e.g. "python -S", embedded interpreters).
            raise SystemExit(-5)
        log_lines.append(" III: Setting the Epoch Limit to: {}\n".format(self.epoch_limit))

        # Log the epoch size in terms of episodes.
        self.epoch_size = self.training.get_epoch_size()
        log_lines.append(" Epoch size in terms of training episodes: {}\n".format(self.epoch_size))

        # Terminal condition IV: max episodes. Optional (only a negative value disables it).
        self.episode_limit = terminal_conditions['episode_limit']
        if self.episode_limit < 0:
            log_lines.append(" IV: Termination based on Episode Limit is disabled\n")
            # Set to infinity, so the ">=" test in the training loop can never succeed.
            self.episode_limit = np.inf
        else:
            log_lines.append(" IV: Setting the Episode Limit to: {}\n".format(self.episode_limit))

        # Ok, finally print it.
        log_lines.append('='*80)
        self.logger.info("".join(log_lines))

        # Export and log configuration, optionally asking the user for confirmation.
        config_parsing.display_parsing_results(self.logger, self.app_state.args, self.unparsed)
        config_parsing.display_globals(self.logger, self.app_state.globalitems())
        config_parsing.export_experiment_configuration_to_yml(self.logger, self.app_state.log_dir, "training_configuration.yml", self.config, self.app_state.args.confirm)

    def run_experiment(self):
        """
        Main function of the ``OfflineTrainer``, runs the experiment.

        Iterates over epochs; within each epoch iterates over the training DataLoader
        (one iteration = one episode).

        .. note::

            The terminal conditions (numbered as in the log produced by ``setup_experiment``) are:

                - I. The loss on the full validation set is below the specified threshold \
                  (checked after each epoch, only once curriculum learning is finished if it must be),
                - II. Early stopping is set and the full validation loss did not go down \
                  for the indicated number of validation steps (checked after each epoch),
                - III. The maximum number of epochs has been met (checked after each epoch),
                - IV. The maximum number of episodes has been met (OPTIONAL, checked after each episode).

        At the beginning of each epoch the function informs the task managers and applies
        curriculum learning. Then, for each episode, it:

            - Resets the gradients,
            - Performs the forward pass of the pipeline,
            - Collects statistics,
            - Computes gradients, optionally clips them, and updates the weights,
            - Exports statistics to csv and, at logging interval, to TensorBoard (if set) and the logger,
            - Validates the pipeline on a single batch according to the partial validation interval,
            - Checks terminal condition IV.

        After each epoch it aggregates the training statistics, validates the pipeline on the whole
        validation set, saves the pipeline and checks terminal conditions I-III.
        """
        # Initialize TensorBoard and statistics collection.
        self.initialize_statistics_collection()
        self.initialize_tensorboard()

        try:
            # Main training and validation loop.

            # Reset the counters.
            self.app_state.episode = -1
            self.app_state.epoch = -1
            # Set initial status.
            training_status = "Not Converged"

            ################################################################################################
            # Beginning of external "epic loop".
            ################################################################################################
            while True:
                self.app_state.epoch += 1
                self.logger.info('Starting next epoch: {}\n{}'.format(self.app_state.epoch, '='*80))

                # Inform the task managers that epoch has started.
                self.training.initialize_epoch()
                self.validation.initialize_epoch()

                # Apply curriculum learning - change Task parameters.
                # The episode counter starts at -1, hence it is clamped to 0 for the first epoch.
                self.curric_done = self.training.task.curriculum_learning_update_params(
                    max(0, self.app_state.episode),
                    self.app_state.epoch)

                # Empty the statistics collector.
                self.training_stat_col.empty()

                ############################################################################################
                # Beginning of internal "episodic loop".
                ############################################################################################
                for training_batch in self.training.dataloader:
                    # Next episode.
                    self.app_state.episode += 1

                    # Reset all gradients.
                    self.optimizer.zero_grad()

                    # Turn on training mode for the model.
                    self.pipeline.train()

                    # 1. Perform forward step.
                    self.pipeline.forward(training_batch)

                    # 2. Calculate statistics.
                    self.collect_all_statistics(self.training, self.pipeline, training_batch, self.training_stat_col)

                    # 3. Backward gradient flow (with optional gradient clipping).
                    self.pipeline.backward(training_batch)
                    self._clip_gradients()

                    # 4. Perform optimization.
                    self.optimizer.step()

                    # 5. Log collected statistics.
                    self._log_training_statistics()

                    # 6. Validate on a single batch.
                    if self.partial_validation_interval > 0 and (self.app_state.episode % self.partial_validation_interval) == 0:
                        # Clear the validation batch from all items aside of the ones originally returned by the task.
                        self.validation.batch.reinitialize(self.validation.task.output_data_definitions())
                        # Perform validation.
                        self.validate_on_batch(self.validation.batch)
                        # Do not save the model: OfflineTrainer uses the full set to determine whether to save or not.

                    # IV. The episodes number limit has been reached.
                    if self.app_state.episode+1 >= self.episode_limit:  # = np.inf when inactive.
                        # If we reach this condition, then it is possible that the model didn't converge correctly
                        # but it currently might get better since last validation.
                        training_status = self._EPISODE_LIMIT_STATUS
                        raise TerminationCondition(training_status)

                ############################################################################################
                # End of internal "episodic loop".
                ############################################################################################

                # Epoch just ended!
                self.logger.info('End of epoch: {}\n{}'.format(self.app_state.epoch, '='*80))

                # Aggregate training statistics for the epoch.
                self.aggregate_all_statistics(self.training, self.pipeline, self.training_stat_col, self.training_stat_agg)
                self.export_all_statistics( self.training_stat_agg, '[Full Training]')

                # Inform the training task manager that the epoch has ended.
                self.training.finalize_epoch()

                # Validate over the entire validation set and save the pipeline.
                validation_set_loss = self._validate_on_set_and_save(training_status)

                # Inform the validation task manager that the epoch has ended.
                self.validation.finalize_epoch()

                # Terminal conditions.
                # I. The loss is < threshold (only when curriculum learning is finished if set).
                if self.curric_done or not self.must_finish_curriculum:
                    # Check the loss on the full validation set.
                    if validation_set_loss < self.loss_stop_threshold:
                        # Change the status.
                        training_status = "Converged (Full Validation Loss went below " \
                            "Loss Stop threshold of {})".format(self.loss_stop_threshold)
                        # Save the pipeline (update its statistics).
                        self.pipeline.save(self.checkpoint_dir, training_status, validation_set_loss)
                        # And leave both loops.
                        raise TerminationCondition(training_status)

                # II. Early stopping is set and loss hasn't improved in n validations.
                if self.pipeline.validation_loss_down_counter >= self.early_stop_validations:
                    training_status = "Not converged: reached limit of validations without improvement (Early Stopping)"
                    raise TerminationCondition(training_status)

                # III. Epoch limit has been reached.
                if self.app_state.epoch+1 >= self.epoch_limit:
                    training_status = "Not converged: Epoch Limit reached"
                    # "Finish" the training.
                    raise TerminationCondition(training_status)

            ################################################################################################
            # End of external "epic loop".
            ################################################################################################

        except TerminationCondition:
            # End of main training and validation loop.
            self.logger.info('\n' + '='*80)
            self.logger.info('Training finished because {}'.format(training_status))

            # If episode limit was reached, the training stopped in the middle of an epoch,
            # hence perform the last validation on the full set (and save the pipeline).
            if training_status == self._EPISODE_LIMIT_STATUS:
                self._validate_on_set_and_save(training_status)

            self.logger.info('Experiment finished!')

        except SystemExit as e:
            # The training did not end properly.
            self.logger.error('Experiment interrupted because {}'.format(e))
        except ConfigurationError as e:
            # The training did not end properly.
            self.logger.error('Experiment interrupted because {}'.format(e))
        except KeyboardInterrupt:
            # The training did not end properly.
            self.logger.error('Experiment interrupted!')
        finally:
            # Finalize statistics collection.
            self.finalize_statistics_collection()
            self.finalize_tensorboard()
            self.logger.info("Experiment logged to: {}".format(self.app_state.log_dir))

    def _clip_gradients(self):
        """
        Clips gradients of the pipeline parameters to the range (-gradient_clipping, gradient_clipping),
        provided that the 'gradient_clipping' parameter is present in the training configuration.
        """
        # Only the lookup is guarded: a missing key means "no clipping", whereas a KeyError
        # raised by the clipping itself must not be silently swallowed.
        try:
            clip_value = self.config_training['gradient_clipping']
        except KeyError:
            return
        torch.nn.utils.clip_grad_value_(self.pipeline.parameters(), clip_value)

    def _log_training_statistics(self):
        """
        Logs the statistics collected during the current training episode:

            - exports them to csv at every episode,
            - exports them (and histograms) to TensorBoard at logging interval, if the writer is set,
            - logs them to the logger at logging interval.
        """
        # Export to csv - at every step.
        self.training_stat_col.export_to_csv()

        # Everything else happens only at logging frequency.
        if self.app_state.episode % self.app_state.args.logging_interval != 0:
            return

        # Export data to TensorBoard.
        if self.training_batch_writer is not None:
            self.training_stat_col.export_to_tensorboard()
            self._export_histograms()

        # Log to logger.
        self.logger.info(self.training_stat_col.export_to_string())

    def _export_histograms(self):
        """
        Exports histograms of the pipeline parameters to TensorBoard: values of parameters when the
        ``tensorboard`` argument is >= 1, additionally their gradients when it is >= 2.

        The export is best-effort: a failure for a given parameter is logged as an error and
        does not interrupt the training.
        """
        tensorboard_level = self.app_state.args.tensorboard

        # Export histograms of parameter values.
        if tensorboard_level >= 1:
            for name, param in self.pipeline.named_parameters():
                try:
                    self.training_batch_writer.add_histogram(
                        name, param.data.cpu().numpy(), self.app_state.episode, bins='doane')
                except Exception as e:
                    self.logger.error(" {} :: data :: {}".format(name, e))

        # Export histograms of gradients.
        if tensorboard_level >= 2:
            for name, param in self.pipeline.named_parameters():
                try:
                    self.training_batch_writer.add_histogram(
                        name + '/grad', param.grad.data.cpu().numpy(), self.app_state.episode, bins='doane')
                except Exception as e:
                    self.logger.error(" {} :: grad :: {}".format(name, e))

    def _validate_on_set_and_save(self, training_status):
        """
        Validates the pipeline over the entire validation set and saves the pipeline
        using the resulting validation statistics.

        :param training_status: Current training status, passed to the pipeline ``save`` method.

        :return: Loss on the validation set, as returned by the pipeline ``return_loss_on_set`` method.
        """
        # Validate over the entire validation set.
        self.validate_on_set()
        # Get loss.
        validation_set_loss = self.pipeline.return_loss_on_set(self.validation_stat_agg)
        # Save the pipeline using the latest validation statistics.
        self.pipeline.save(self.checkpoint_dir, training_status, validation_set_loss)
        return validation_set_loss
def main():
    """
    Entry point of the ``OfflineTrainer``: builds the trainer, sets the experiment up and runs it.
    """
    offline_trainer = OfflineTrainer()

    # Command line parsing, configuration loading and creation of all required objects.
    offline_trainer.setup_experiment()

    # Launch the actual training.
    offline_trainer.run_experiment()


# Run the trainer only when the module is executed as a script.
if __name__ == '__main__':
    main()