piepline/train.py

"""
The main module for the training process
"""
import math

import torch

from piepline import events_container
from piepline.train_config.train_config import BaseTrainConfig
from piepline.data_processor.data_processor import TrainDataProcessor
from piepline.utils.events_system import Event
from piepline.utils.messages_system import MessageReceiver

__all__ = ['Trainer']


class LearningRate:
    """
    Basic learning rate class
    """

    def __init__(self, value: float):
        self._value = value

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: current value
        """
        return self._value

    def set_value(self, value) -> None:
        """
        Set lr value

        :param value: lr value
        """
        self._value = value


class DecayingLR(LearningRate):
    """
    This class provides lr decaying driven by a metric value (returned by :arg:`target_value_clbk`).
    If the metric value doesn't update its minimum within a defined number of steps (:arg:`patience`),
    the lr is decayed by a defined coefficient (:arg:`decay_coefficient`).

    :param start_value: start value
    :param decay_coefficient: coefficient of decaying
    :param patience: steps before decay
    :param target_value_clbk: callable that returns the target value for lr decaying
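
    Example (a minimal sketch; ``get_val_loss`` is a hypothetical callback that returns
    the latest validation loss, or ``None`` before the first validation pass):

    .. code-block:: python

        lr = DecayingLR(start_value=1e-3, decay_coefficient=0.5, patience=5,
                        target_value_clbk=get_val_loss)
        trainer.set_lr_scheduler(lr)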
    """

    def __init__(self, start_value: float, decay_coefficient: float, patience: int, target_value_clbk: callable):
        super().__init__(start_value)

        self._decay_coefficient = decay_coefficient
        self._patience = patience
        self._cur_step = 1
        self._target_value_clbk = target_value_clbk
        self._cur_min_target_val = None

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: learning rate value
        """
        metric_val = self._target_value_clbk()
        if metric_val is None:
            # the metric is not available yet - keep the current lr
            return self._value

        if self._cur_min_target_val is None:
            self._cur_min_target_val = metric_val

        if metric_val < self._cur_min_target_val:
            # new minimum reached - reset the patience counter
            self._cur_step = 1
            self._cur_min_target_val = metric_val
            return self._value

        if self._cur_step > 0 and (self._cur_step % self._patience) == 0:
            # patience exhausted - decay the lr and restart minimum tracking
            self._value *= self._decay_coefficient
            self._cur_min_target_val = None
            self._cur_step = 1
            return self._value

        self._cur_step += 1
        return self._value

    def set_value(self, value):
        self._value = value
        self._cur_step = 0
        self._cur_min_target_val = None


class CosineAnnealingLR(LearningRate):
    """
    Cosine annealing learning rate schedule. The lr anneals from :arg:`start_value`
    down to :arg:`eta_min` over :arg:`T_max` iterations following a cosine curve,
    then restarts from :arg:`start_value`.

    :param start_value: start (maximum) learning rate value
    :param T_max: Maximum number of iterations.
    :param eta_min: Minimum learning rate. Default: 0.
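
    Example (a minimal sketch of the schedule produced by :meth:`value`):

    .. code-block:: python

        lr = CosineAnnealingLR(start_value=0.1, T_max=4, eta_min=0.01)
        schedule = [lr.value() for _ in range(8)]
        # values anneal from near 0.1 down to 0.01 over 4 calls, then restart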
    """

    def __init__(self, start_value: float, T_max: int, eta_min: float = 0):
        super().__init__(start_value)
        self._start_value = start_value
        self._T_max = T_max
        self._eta_min = eta_min
        self._cur_step = 1

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: learning rate value
        """
        if self._cur_step > self._T_max:
            self._cur_step = 1
        self._value = self._eta_min + (self._start_value - self._eta_min) \
                      * (1 + math.cos(math.pi * self._cur_step / self._T_max)) / 2
        self._cur_step += 1

        return self._value

    def set_value(self, value):
        self._value = value
        self._cur_step = 0


class StepLR(LearningRate):
    """
    Step learning rate schedule. The lr is multiplied by :arg:`gamma` every
    :arg:`step_size` steps.

    :param start_value: start learning rate value
    :param step_size: Period of learning rate decay.
    :param gamma: Multiplicative factor of learning rate decay. Default: 0.1.
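
    Example (a minimal sketch of the schedule produced by :meth:`value`):

    .. code-block:: python

        lr = StepLR(start_value=0.1, step_size=3, gamma=0.1)
        schedule = [lr.value() for _ in range(6)]
        # schedule == [0.1, 0.1, 0.01, 0.01, 0.01, 0.001] (up to float rounding)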
    """

    def __init__(self, start_value: float, step_size: int, gamma: float = 0.1):
        super().__init__(start_value)
        self._start_value = start_value
        self._step_size = step_size
        self._gamma = gamma
        self._cur_step = 1

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: learning rate value
        """
        if self._cur_step % self._step_size == 0:
            self._value *= self._gamma
        self._cur_step += 1

        return self._value

    def set_value(self, value):
        self._value = value
        self._cur_step = 0


class Trainer(MessageReceiver):
    """
    Class that runs the training process.

    Trainer gets a list of training stages and loops over it every epoch.

    Training process looks like:

    .. highlight:: python
    .. code-block:: python

        for epoch in epochs_num:
            for stage in training_stages:
                stage.run()
                monitor_hub.update_metrics(stage.metrics_processor().get_metrics())
            save_state()
            on_epoch_end_callback()

    :param train_config: :class:`BaseTrainConfig` object
    :param device: device for training process
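
    Example (a minimal sketch; ``train_config`` construction and the ``get_val_loss``
    callback are assumed to be defined elsewhere):

    .. code-block:: python

        trainer = Trainer(train_config, device=torch.device('cuda'))
        trainer.set_epoch_num(30)
        trainer.enable_lr_decaying(coeff=0.5, patience=5, target_val_clbk=get_val_loss)
        trainer.train()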
    """

    class TrainerException(Exception):
        def __init__(self, msg):
            super().__init__()
            self._msg = msg

        def __str__(self):
            return self._msg

    def __init__(self, train_config: BaseTrainConfig, device: torch.device = None):
        MessageReceiver.__init__(self)

        self.__epoch_num, self._cur_epoch_id = 100, 0

        self._train_config = train_config
        self._data_processor = TrainDataProcessor(self._train_config, device)
        self._lr = LearningRate(self._data_processor.get_lr())

        self._epoch_end_event = events_container.add_event('EPOCH_END', Event(self))
        self._epoch_start_event = events_container.add_event('EPOCH_START', Event(self))
        self._train_done_event = events_container.add_event('TRAIN_DONE', Event(self))

        self._add_message('NEED_STOP')

    def set_epoch_num(self, epoch_number: int) -> 'Trainer':
        """
        Define the number of epochs for training. One epoch is one iteration over all train stages

        :param epoch_number: number of training epoch
        :return: self object
        """
        self.__epoch_num = epoch_number
        return self

    def set_lr_scheduler(self, lr_scheduler):
        """
        Set a learning rate scheduler object
        """
        self._lr = lr_scheduler

    def enable_lr_decaying(self, coeff: float, patience: int, target_val_clbk: callable) -> 'Trainer':
        """
        Enable learning rate decaying. The learning rate decays when the value returned
        by `target_val_clbk` doesn't update its minimum for `patience` steps

        :param coeff: lr decay coefficient
        :param patience: number of steps
        :param target_val_clbk: callback which returns the value that is used for lr decaying
        :return: self object
        """
        self._lr = DecayingLR(self._data_processor.get_lr(), coeff, patience, target_val_clbk)
        return self

    def cur_epoch_id(self) -> int:
        """
        Get current epoch index
        """
        return self._cur_epoch_id

    def set_cur_epoch(self, idx: int) -> 'Trainer':
        """
        Set current epoch index
        """
        self._cur_epoch_id = idx
        return self

    def train(self) -> None:
        """
        Run training process
        """
        if len(self._train_config.stages()) < 1:
            raise self.TrainerException("There are no stages for training")

        start_epoch_idx = self._cur_epoch_id

        self._connect_stages_to_events()

        for epoch_idx in range(start_epoch_idx, self.__epoch_num + start_epoch_idx):
            if True in self.message('NEED_STOP').read():
                break

            self._cur_epoch_id = epoch_idx
            self._epoch_start_event()

            for stage in self._train_config.stages():
                stage.run(self._data_processor)

            self._data_processor.update_lr(self._lr.value())
            self._epoch_end_event()

        self._train_done_event()

    def _update_losses(self) -> None:
        """
        Update losses procedure
        """
        losses = {}
        for stage in self._train_config.stages():
            if stage.get_losses() is not None:
                losses[stage.name()] = stage.get_losses()
        self.monitor_hub.update_losses(losses)

    def data_processor(self) -> TrainDataProcessor:
        """
        Get data processor object

        :return: data processor
        """
        return self._data_processor

    def train_config(self) -> BaseTrainConfig:
        """
        Get train config

        :return: TrainConfig object
        """
        return self._train_config

    def _connect_stages_to_events(self):
        for stage in self._train_config.stages():
            # bind `stage` as a default argument so each callback keeps its own stage
            # (a bare lambda would late-bind to the last stage of the loop)
            self._epoch_end_event.add_callback(lambda x, stage=stage: stage.on_epoch_end())