PiePline/piepline

View on GitHub
piepline/monitoring/hub.py

Summary

Maintainability
A
0 mins
Test Coverage
from piepline.monitoring.monitors import AbstractMonitor
from piepline import events_container
from piepline.train import Trainer
from piepline.train_config.metrics_processor import MetricsProcessor

__all__ = ['MonitorHub']


class MonitorHub:
    """
    Aggregator of monitors. This class collects monitors and provides a unified interface to them:
    every call on the hub is forwarded to each connected monitor in the order the monitors were added.
    """

    def __init__(self, trainer: Trainer):
        """
        :param trainer: trainer whose ``EPOCH_START`` event is used to propagate the current epoch number
            to all connected monitors
        """
        self.monitors = []
        # On every EPOCH_START event, push the current epoch id to all monitors. The callback argument is
        # expected to expose ``cur_epoch_id()`` (presumably the trainer itself — confirm in events_container)
        events_container.event(trainer, 'EPOCH_START').add_callback(lambda t: self.set_epoch_num(t.cur_epoch_id()))

    def subscribe2metrics_processor(self, metrics_processor: MetricsProcessor) -> 'MonitorHub':
        """
        Forward the metrics of ``metrics_processor`` to all monitors on its ``BEFORE_METRICS_RESET`` event

        :param metrics_processor: metrics processor to subscribe to
        :return: self, to allow call chaining
        """
        events_container.event(metrics_processor, "BEFORE_METRICS_RESET").add_callback(lambda mp: self.update_metrics(mp.get_metrics()))
        return self

    def set_epoch_num(self, epoch_num: int) -> None:
        """
        Set current epoch num in all monitors

        :param epoch_num: num of current epoch
        """
        for m in self.monitors:
            m.set_epoch_num(epoch_num)

    def add_monitor(self, monitor: AbstractMonitor) -> 'MonitorHub':
        """
        Connect monitor to hub

        :param monitor: :class:`AbstractMonitor` object
        :return: self, to allow call chaining
        """
        self.monitors.append(monitor)
        return self

    def update_metrics(self, metrics: dict) -> None:
        """
        Update metrics in all monitors. Monitors that do not define ``update_metrics`` are skipped.

        :param metrics: metrics dict with keys 'metrics' and 'groups'
        """
        for m in self.monitors:
            if hasattr(m, 'update_metrics'):
                m.update_metrics(metrics)

    def update_losses(self, losses: dict) -> None:
        """
        Update losses in all monitors. Monitors that do not define ``update_losses`` are skipped.

        :param losses: losses values with keys 'train' and 'validation'
        """
        for m in self.monitors:
            if hasattr(m, 'update_losses'):
                m.update_losses(losses)

    def register_event(self, text: str) -> None:
        """
        Register a text event in all monitors

        :param text: event text passed unchanged to every monitor's ``register_event``
        """
        for m in self.monitors:
            m.register_event(text)

    def __enter__(self) -> 'MonitorHub':
        """
        Enter the hub context. Monitors' own ``__enter__`` is not called here.

        :return: self
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Exit the context of every connected monitor.

        All monitors are exited even if some of them raise. In that case the first raised exception is
        re-raised after the loop. Exceptions from the ``with`` body are never suppressed (returns ``None``).
        """
        first_error = None
        for m in self.monitors:
            try:
                m.__exit__(exc_type, exc_val, exc_tb)
            except Exception as err:
                # Remember the failure but keep closing the remaining monitors so none of them leak
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error