ptp/workers/online_trainer.py
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__author__ = "Tomasz Kornuta, Vincent Marois"
import sys

import numpy as np
import torch

from ptp.workers.trainer import Trainer
import ptp.configuration.config_parsing as config_parsing
from ptp.configuration.configuration_error import ConfigurationError
from ptp.utils.termination_condition import TerminationCondition
class OnlineTrainer(Trainer):
    """
    Implementation for the episode-based ``OnlineTrainer``.

    .. note::

        The ``OfflineTrainer`` is based on epochs. While an epoch can be defined for all finite-size datasets,
        it makes less sense for tasks which have a very large, almost infinite, dataset (like algorithmic
        tasks, which generate random data on-the-fly).
        This is why this OnlineTrainer was implemented. Despite the fact it has the notion of epoch, it is more
        flexible and operates on episodes (we call an iteration on a single batch an episode).
    """

    def __init__(self):
        """
        Constructor. It only calls the ``Trainer`` constructor, as the initialization phase is identical
        to the one from ``Trainer``.
        """
        # Call base constructor to set up app state, registry and add default config.
        super(OnlineTrainer, self).__init__("OnlineTrainer", OnlineTrainer)

    def setup_experiment(self):
        """
        Sets up experiment for episode trainer:

            - Calls base class setup_experiment to parse the command line arguments,
            - Sets up the terminal conditions (loss threshold, early stopping, episodes & epochs (optional) limits).

        Disabled optional limits (early stopping, epoch limit) are stored as ``np.inf``.

        Exits the process (``sys.exit``) with code -4 when ``partial_validation_interval`` is not positive,
        and with code -5 when ``episode_limit`` is not positive.
        """
        # Call base method to parse all command line arguments, load configuration, create tasks and model etc.
        super(OnlineTrainer, self).setup_experiment()

        # In this trainer Partial Validation is mandatory, hence interval must be > 0.
        self.partial_validation_interval = self.config['validation']['partial_validation_interval']
        if self.partial_validation_interval <= 0:
            self.logger.error("Online Trainer relies on Partial Validation, thus 'partial_validation_interval' must be a positive number!")
            # NOTE: builtin exit() is provided by the site module for interactive use only.
            sys.exit(-4)
        else:
            self.logger.info("Partial Validation activated with interval equal to {} episodes\n".format(self.partial_validation_interval))

        ################# TERMINAL CONDITIONS #################
        log_str = 'Terminal conditions:\n' + '='*80 + "\n"

        # Terminal condition I: loss.
        self.loss_stop_threshold = self.config_training['terminal_conditions']['loss_stop_threshold']
        log_str += "  I: Setting Loss Stop Threshold to {}\n".format(self.loss_stop_threshold)

        # Terminal condition II: early stopping.
        self.early_stop_validations = self.config_training['terminal_conditions']['early_stop_validations']
        if self.early_stop_validations <= 0:
            log_str += "  II: Termination based on Early Stopping is disabled\n"
            # Set to infinity, so the ">=" test in run_experiment never fires.
            # (np.inf - the np.Inf alias was removed in NumPy 2.0.)
            self.early_stop_validations = np.inf
        else:
            log_str += "  II: Setting the Number of Validations in Early Stopping to: {}\n".format(self.early_stop_validations)

        # Terminal condition III: max epochs (Optional for this trainer)
        self.epoch_limit = self.config_training["terminal_conditions"]["epoch_limit"]
        if self.epoch_limit <= 0:
            log_str += "  III: Termination based on Epoch Limit is disabled\n"
            # Set to infinity.
            self.epoch_limit = np.inf
        else:
            log_str += "  III: Setting the Epoch Limit to: {}\n".format(self.epoch_limit)

        # Log the epoch size in terms of episodes.
        self.epoch_size = self.training.get_epoch_size()
        log_str += "  Epoch size in terms of training episodes: {}\n".format(self.epoch_size)

        # Terminal condition IV: max episodes. Mandatory.
        self.episode_limit = self.config_training['terminal_conditions']['episode_limit']
        if self.episode_limit <= 0:
            self.logger.error("OnLine Trainer relies on episodes, thus 'episode_limit' must be a positive number!")
            sys.exit(-5)
        else:
            log_str += "  IV: Setting the Episode Limit to: {}\n".format(self.episode_limit)

        # Ok, finally print it.
        log_str += '='*80
        self.logger.info(log_str)

        # Export and log configuration, optionally asking the user for confirmation.
        config_parsing.display_parsing_results(self.logger, self.app_state.args, self.unparsed)
        config_parsing.display_globals(self.logger, self.app_state.globalitems())
        config_parsing.export_experiment_configuration_to_yml(self.logger, self.app_state.log_dir, "training_configuration.yml", self.config, self.app_state.args.confirm)

    def _validate_on_partial_batch(self, training_status):
        """
        Runs Partial Validation on ``self.validation.batch`` and tries to save the pipeline.

        :param training_status: Status string passed to ``pipeline.save`` together with the validation loss.

        :return: Loss on the validation batch, as returned by ``pipeline.return_loss_on_batch``.
        """
        # Clear the validation batch from all items aside of the ones originally returned by the task.
        self.validation.batch.reinitialize(self.validation.task.output_data_definitions())
        # Perform validation.
        self.validate_on_batch(self.validation.batch)
        # Get loss.
        validation_batch_loss = self.pipeline.return_loss_on_batch(self.validation_stat_col)
        # Save the pipeline using the latest validation statistics.
        self.pipeline.save(self.checkpoint_dir, training_status, validation_batch_loss)
        return validation_batch_loss

    def _export_histograms(self):
        """
        Exports histograms of the pipeline parameters (``--tensorboard`` >= 1) and of their
        gradients (``--tensorboard`` >= 2) to the training TensorBoard writer.

        Export errors are logged, not raised.
        """
        episode = self.app_state.episode

        # Export histograms.
        if self.app_state.args.tensorboard >= 1:
            for name, param in self.pipeline.named_parameters():
                try:
                    self.training_batch_writer.add_histogram(name,
                        param.data.cpu().numpy(), episode, bins='doane')
                except Exception as e:
                    self.logger.error("  {} :: data :: {}".format(name, e))

        # Export gradients.
        if self.app_state.args.tensorboard >= 2:
            for name, param in self.pipeline.named_parameters():
                # Parameters that received no gradient (e.g. unused/frozen) have nothing to export.
                if param.grad is None:
                    continue
                try:
                    self.training_batch_writer.add_histogram(name + '/grad',
                        param.grad.data.cpu().numpy(), episode, bins='doane')
                except Exception as e:
                    self.logger.error("  {} :: grad :: {}".format(name, e))

    def run_experiment(self):
        """
        Main function of the ``OnlineTrainer``, runs the experiment.

        Iterates over the (cycled) DataLoader (one iteration = one episode).

        .. note::

            The test for terminal conditions (e.g. convergence) is done at the end of each episode.
            The terminal conditions are as follows:

                - I. The loss is below the specified threshold (using the partial validation loss),
                - II. Early stopping is set and the partial validation loss did not improve
                  for the indicated number of validation steps,
                - III. The maximum number of epochs has been met (OPTIONAL, checked at the end of an epoch),
                - IV. The maximum number of episodes has been met.

            Additionally, the experiment can be interrupted by the user with a ``KeyboardInterrupt`` (Ctrl+C).

        The function does the following for each episode:

            - Resets the gradients,
            - Forwards pass of the model,
            - Collects statistics,
            - Computes gradients, (optionally) clips them and updates weights,
            - Logs statistics and exports to TensorBoard (if set),
            - Validates the model on a batch according to the partial validation interval,
            - Checks the above terminal conditions.

        Curriculum learning parameters are updated at the beginning of each epoch.
        """
        # Initialize TensorBoard and statistics collection.
        self.initialize_statistics_collection()
        self.initialize_tensorboard()

        try:
            # Main training and validation loop.

            # Reset the counters.
            self.app_state.episode = -1
            self.app_state.epoch = -1

            # Set initial status.
            training_status = "Not Converged"

            # Optional gradient clipping - read once, as the configuration does not change during training.
            # Kept outside of the training step so that errors raised by the clipping itself are not swallowed.
            try:
                gradient_clipping = self.config_training['gradient_clipping']
            except KeyError:
                gradient_clipping = None

            ################################################################################################
            # Beginning of external "epoch loop".
            ################################################################################################
            while True:
                self.app_state.epoch += 1
                self.logger.info('Starting next epoch: {}\n{}'.format(self.app_state.epoch, '='*80))

                # Inform the task managers that epoch has started.
                self.training.initialize_epoch()
                self.validation.initialize_epoch()

                # Apply curriculum learning - change Task parameters.
                self.curric_done = self.training.task.curriculum_learning_update_params(
                    0 if self.app_state.episode < 0 else self.app_state.episode,
                    self.app_state.epoch)

                # Empty the statistics collector.
                self.training_stat_col.empty()

                ############################################################################################
                # Beginning of internal "episodic loop".
                ############################################################################################
                for training_batch in self.training.dataloader:
                    # Next episode.
                    self.app_state.episode += 1

                    # Reset all gradients.
                    self.optimizer.zero_grad()

                    # Turn on training mode for the model.
                    self.pipeline.train()

                    # 1. Perform forward step.
                    self.pipeline.forward(training_batch)

                    # 2. Calculate statistics.
                    self.collect_all_statistics(self.training, self.pipeline, training_batch, self.training_stat_col)

                    # 3. Backward gradient flow.
                    self.pipeline.backward(training_batch)

                    # If set - clip gradients to a range (-gradient_clipping, gradient_clipping).
                    if gradient_clipping is not None:
                        torch.nn.utils.clip_grad_value_(self.pipeline.parameters(), gradient_clipping)

                    # 4. Perform optimization.
                    self.optimizer.step()

                    # 5. Log collected statistics.
                    # 5.1. Export to csv - at every step.
                    self.training_stat_col.export_to_csv()

                    is_logging_episode = (self.app_state.episode % self.app_state.args.logging_interval == 0)

                    # 5.2. Export data to TensorBoard - at logging frequency.
                    if (self.training_batch_writer is not None) and is_logging_episode:
                        self.training_stat_col.export_to_tensorboard()
                        self._export_histograms()

                    # 5.3. Log to logger - at logging frequency.
                    if is_logging_episode:
                        self.logger.info(self.training_stat_col.export_to_string())

                    # 6. Validate and (optionally) save the model.
                    if (self.app_state.episode % self.partial_validation_interval) == 0:
                        validation_batch_loss = self._validate_on_partial_batch(training_status)

                        # Terminal conditions.
                        # I. The loss is < threshold (only when curriculum learning is finished if set).
                        # We check that condition only in validation step!
                        if self.curric_done or not self.must_finish_curriculum:
                            # Check the Partial Validation loss.
                            if validation_batch_loss < self.loss_stop_threshold:
                                # Change the status.
                                training_status = "Converged (Partial Validation Loss went below " \
                                    "Loss Stop threshold {})".format(self.loss_stop_threshold)
                                # Save the pipeline (update its statistics).
                                self.pipeline.save(self.checkpoint_dir, training_status, validation_batch_loss)
                                # And leave both loops.
                                raise TerminationCondition(training_status)

                        # II. Early stopping is set and the loss hasn't improved for the given number of validations
                        # (the limit is np.inf when early stopping is disabled).
                        if self.pipeline.validation_loss_down_counter >= self.early_stop_validations:
                            training_status = "Not converged: reached limit of validations without improvement (Early Stopping)"
                            raise TerminationCondition(training_status)

                    # IV. The episodes number limit has been reached.
                    if self.app_state.episode + 1 >= self.episode_limit:
                        # If we reach this condition, then it is possible that the model didn't converge correctly
                        # but it currently might get better since last validation.
                        training_status = "Not converged: Episode Limit reached"
                        raise TerminationCondition(training_status)

                ############################################################################################
                # End of internal "episodic loop".
                ############################################################################################

                # Epoch just ended!
                self.logger.info('End of epoch: {}\n{}'.format(self.app_state.epoch, '='*80))

                # Inform the task managers that the epoch has ended.
                self.training.finalize_epoch()
                self.validation.finalize_epoch()

                # Aggregate training statistics for the epoch.
                self.aggregate_all_statistics(self.training, self.pipeline, self.training_stat_col, self.training_stat_agg)
                self.export_all_statistics(self.training_stat_agg, '[Full Training]')

                # III. Epoch limit has been reached.
                if self.app_state.epoch + 1 >= self.epoch_limit:  # = np.inf when inactive.
                    training_status = "Not converged: Epoch Limit reached"
                    # "Finish" the training.
                    raise TerminationCondition(training_status)

            ################################################################################################
            # End of external "epoch loop".
            ################################################################################################

        except TerminationCondition:
            # End of main training and validation loop.
            # Eventually perform "last" validation on batch, unless it was already done in the last episode,
            # as the model may have performed better during this episode.
            # Guard against an empty history (no partial validation was ever performed).
            validated_episodes = self.validation_stat_col["episode"]
            if len(validated_episodes) == 0 or validated_episodes[-1] != self.app_state.episode:
                self._validate_on_partial_batch(training_status)

            self.logger.info('\n' + '='*80)
            self.logger.info('Training finished because {}'.format(training_status))

            # Validate over the entire validation set.
            self.validate_on_set()

            # Do not save the model, as we tried it already on "last" validation batch.
            self.logger.info('Experiment finished!')

        except SystemExit as e:
            # the training did not end properly
            self.logger.error('Experiment interrupted because {}'.format(e))
        except ConfigurationError as e:
            # the training did not end properly
            self.logger.error('Experiment interrupted because {}'.format(e))
        except KeyboardInterrupt:
            # the training did not end properly
            self.logger.error('Experiment interrupted!')
        finally:
            # Finalize statistics collection.
            self.finalize_statistics_collection()
            self.finalize_tensorboard()
            self.logger.info("Experiment logged to: {}".format(self.app_state.log_dir))
def main():
    """
    Command-line entry point of the ``OnlineTrainer``.

    Builds the trainer, prepares the experiment (arguments, configuration, tasks, pipeline)
    and then runs the training loop.
    """
    online_trainer = OnlineTrainer()
    # Parse the command line arguments, load the configuration and build all required objects.
    online_trainer.setup_experiment()
    # Start the training/validation loop.
    online_trainer.run_experiment()


if __name__ == '__main__':
    # Executed only when the module is run as a script, not when it is imported.
    main()