# piepline/train_config/stages.py
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional

import numpy as np
from torch.utils.data.dataloader import DataLoader

from piepline import events_container
from piepline.data_producer.data_producer import DataProducer
from piepline.utils.events_system import Event
__all__ = ['AbstractStage', 'TrainStage', 'ValidationStage']
class AbstractStage(metaclass=ABCMeta):
    """
    Stage of training process. For example there may be 2 stages: train and validation.
    Every epoch in the train loop is an iteration over stages.

    Subclasses implement :meth:`_run`; callers use :meth:`run`, which wraps it with the
    stage-end notification and the :meth:`_after_epoch_end` hook.

    :param name: name of stage
    """

    def __init__(self, name: str):
        self._name = name
        # Whatever ``add_event`` returns is kept and called (without arguments) by :meth:`run`
        # once :meth:`_run` has finished.
        self._stage_end_event = events_container.add_event('STAGE_END', Event(self))

    def name(self) -> str:
        """
        Get name of stage

        :return: name
        """
        return self._name

    @abstractmethod
    def _run(self, data_processor: 'DataProcessor') -> None:
        """
        Internal method with stage run implementation. This method is called from :meth:`run`

        :param data_processor: data processor object passed through from :meth:`run`
        """

    def run(self, data_processor: 'DataProcessor') -> None:
        """
        Run stage

        The order is fixed: :meth:`_run`, then the stage-end event, then :meth:`_after_epoch_end`.

        :param data_processor: :class:`DataProcessor` object
        """
        self._run(data_processor)
        self._stage_end_event()
        self._after_epoch_end()

    def _after_epoch_end(self):
        """
        Hook called by :meth:`run` after the stage-end event. Does nothing by default.
        """
        pass

    def get_losses(self) -> Optional[np.ndarray]:
        """
        Get losses from this stage

        :return: array of losses or None if this stage doesn't need losses
        """
        return None

    def on_epoch_end(self) -> None:
        """
        Callback for train epoch end. Does nothing by default.
        """
        pass
class StandardStage(AbstractStage):
    """
    Standard stage for train process.

    When :meth:`run` is called, the stage iterates over the data loader and passes every batch
    to :meth:`process_batch` of the data processor. The loss returned for each batch is
    accumulated and exposed through :meth:`get_losses`.

    :param stage_name: name of stage
    :param is_train: value forwarded to ``process_batch`` as the ``is_train`` flag
    :param data_producer: :class:`DataProducer` object
    """

    def __init__(self, stage_name: str, is_train: bool, data_producer: DataProducer):
        super().__init__(name=stage_name)
        self.data_loader = None  # obtained from the data producer on the first run
        self.data_producer = data_producer
        self._losses = None  # losses collected during the run currently in progress
        self._last_losses = None  # losses of the most recently finished run
        self._is_train = is_train
        self._last_result = None  # output/target of the last processed batch
        self._epoch_end_event = events_container.add_event('EPOCH_END', Event(self))
        self._epoch_start_event = events_container.add_event('EPOCH_START', Event(self))
        self._batch_processed = events_container.add_event('BATCH_PROCESSED', Event(self))

    def _run(self, data_processor: 'TrainDataProcessor') -> None:
        """
        Run stage. This iterates over the loader obtained from the data producer
        and fires the epoch-end event afterwards.

        :param data_processor: :class:`DataProcessor` object
        """
        if self.data_loader is None:
            self.data_loader = self.data_producer.get_loader()

        self._run_internal(self.data_loader, name=self.name(), data_processor=data_processor)
        self._epoch_end_event()

    def _run_internal(self, data_loader: DataLoader, name: str, data_processor: 'TrainDataProcessor'):
        """
        Fire the epoch-start event, then process every batch of ``data_loader``.

        :param data_loader: loader to iterate over
        :param name: not used by this implementation; kept so existing callers keep working
        :param data_processor: object whose ``process_batch`` handles each batch
        """
        self._epoch_start_event()
        self._losses = None  # start accumulating from scratch
        for batch in data_loader:
            self._process_batch(batch, data_processor)
            self._batch_processed()

    def _after_epoch_end(self):
        """
        Move the collected losses to the "last run" slot and drop the last batch result.
        """
        self._last_losses = self._losses
        self._losses = None
        self._last_result = None

    def _process_batch(self, batch, data_processor: 'TrainDataProcessor'):
        """
        Process one batch and append its loss to the accumulated losses.

        :param batch: batch taken from the data loader
        :param data_processor: object whose ``process_batch`` returns ``(loss, predict, target)``
        """
        cur_loss, cur_predict, cur_target = data_processor.process_batch(batch, is_train=self._is_train)
        self._last_result = {'output': cur_predict, 'target': cur_target}

        # ``np.append`` without an axis flattens its inputs, so flatten the first batch as well:
        # the accumulated losses are then 1-D no matter how many batches were processed.
        cur_loss = np.ravel(cur_loss.detach().cpu().numpy())
        if self._losses is None:
            self._losses = cur_loss
        else:
            self._losses = np.append(self._losses, cur_loss)

    def get_losses(self) -> Optional[np.ndarray]:
        """
        Get losses from this stage

        :return: losses of the run in progress if there are any, otherwise losses of the last
            finished run; None if nothing has been collected yet
        """
        return self._losses if self._losses is not None else self._last_losses

    def get_last_result(self) -> Optional[Dict[str, Any]]:
        """
        Get the result of the last processed batch

        :return: dict with ``'output'`` and ``'target'`` keys, or None when no batch has been
            processed since the stage was created or last finished
        """
        return self._last_result
class TrainStage(StandardStage):
    """
    Standard training stage

    When :meth:`run` is called, the stage iterates over the data loader and passes every batch to
    :meth:`process_batch` of the data processor with ``is_train=True`` flag, accumulating the
    returned losses. With hard negative mining enabled, the samples with the largest losses are
    processed once more right after the regular pass.

    :param data_producer: :class:`DataProducer` object
    :param name: name of stage. By default 'train'
    """

    class _HardNegativesTrainStage(StandardStage):
        """
        Helper stage that re-runs training on the samples with the largest losses.

        :param stage_name: name of stage
        :param data_producer: :class:`DataProducer` object used to build a loader for chosen indices
        :param part: fraction of the losses whose samples are repeated
        """

        def __init__(self, stage_name: str, data_producer: DataProducer, part: float):
            super().__init__(stage_name, True, data_producer)
            self._part = part

        def exec(self, data_processor: 'TrainDataProcessor', losses: np.ndarray, indices: list) -> None:
            """
            Repeat training on the ``part`` of entries with the largest losses.

            :param data_processor: :class:`TrainDataProcessor` object
            :param losses: losses collected by the parent stage (None if it processed no batches)
            :param indices: data indices collected by the parent stage, addressed by loss position
            """
            if losses is None:
                return

            num_losses = int(losses.size * self._part)
            if num_losses < 1:
                # ``-0 == 0``: without this guard ``argpartition(...)[-0:]`` would select every
                # entry and replay the whole dataset instead of nothing.
                return

            idxs = np.argpartition(losses, -num_losses)[-num_losses:]
            # NOTE(review): this assumes ``indices`` has one entry per loss value - confirm
            # against what ``process_batch`` returns and what the producer puts in 'data_idx'.
            self._run_internal(self.data_producer.get_loader([indices[i] for i in idxs]), self.name(), data_processor)
            self._losses = None

    def __init__(self, data_producer: DataProducer, name: str = 'train'):
        super().__init__(name, True, data_producer)
        self.hnm = None  # hard negatives helper stage; None while mining is disabled
        self.hn_indices = []  # 'data_idx' values of the batches processed in the current run
        self._dp_pass_indices_earlier = False  # producer's indices flag before mining was enabled

    def enable_hard_negative_mining(self, part: float) -> 'TrainStage':
        """
        Enable hard negative mining. Hard negatives are taken by losses values

        :param part: part of data that repeat after train stage
        :return: self object
        :raises ValueError: if ``part`` is not in the open range (0, 1)
        """
        if not 0 < part < 1:
            raise ValueError('Value of part for hard negative mining is out of range (0, 1)')

        if self.hnm is None:
            # Remember the producer state only on the first enable; a repeated call would
            # otherwise record the ``True`` set below and the original state would be lost.
            self._dp_pass_indices_earlier = self.data_producer._is_passed_indices()

        self.hnm = self._HardNegativesTrainStage(self.name() + '_hnm', self.data_producer, part)
        self.data_producer.pass_indices(True)
        return self

    def disable_hard_negative_mining(self) -> 'TrainStage':
        """
        Disable hard negative mining and restore the data producer's indices passing flag.

        :return: self object
        """
        if self.hnm is None:
            # Mining was never enabled: leave the producer's configuration untouched.
            return self

        self.hnm = None
        if not self._dp_pass_indices_earlier:
            self.data_producer.pass_indices(False)
        return self

    def _run(self, data_processor: 'TrainDataProcessor') -> None:
        """
        Run stage

        :param data_processor: :class:`TrainDataProcessor` object
        """
        super()._run(data_processor)
        if self.hnm is not None:
            self.hnm.exec(data_processor, self._losses, self.hn_indices)
            self.hn_indices = []

    def _process_batch(self, batch, data_processor: 'TrainDataProcessor') -> None:
        """
        Internal method for process one batch

        :param batch: batch
        :param data_processor: :class:`TrainDataProcessor` instance
        """
        if self.hnm is not None:
            self.hn_indices.append(batch['data_idx'])
        super()._process_batch(batch, data_processor)

    def on_epoch_end(self):
        """
        Method, that calls after every epoch
        """
        super().on_epoch_end()
        if self.hnm is not None:
            self.hnm.on_epoch_end()
class ValidationStage(StandardStage):
    """
    Standard validation stage.

    Running the stage feeds every batch of the data loader to :meth:`process_batch` of the
    data processor with ``is_train=False`` flag and accumulates the returned losses.

    :param data_producer: :class:`DataProducer` object
    :param name: name of stage. By default 'validation'
    """

    def __init__(self, data_producer: DataProducer, name: str = 'validation'):
        super().__init__(stage_name=name, is_train=False, data_producer=data_producer)