Unbabel/OpenKiwi

View on GitHub
kiwi/loggers.py

Summary

Maintainability
A
45 mins
Test Coverage
B
81%
#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging
import re
import threading
from typing import Any, Dict, Optional

import torch
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.utilities import rank_zero_only

from kiwi.utils.io import generate_slug

logger = logging.getLogger(__name__)

_INVALID_PARAM_AND_METRIC_CHARACTERS = re.compile(r'[^/\w.\- ]')


def normalize_metric_key(key):
    """Normalize key name for MLflow.

    mlflow.exceptions.MlflowException: Invalid metric name: 'WMT19_F1_MULT+PEARSON'.
    Names may only contain alphanumerics, underscores (_), dashes (-), periods (.),
    spaces ( ), and slashes (/).

    This is raised by matching against ``r"^[/\\w.\\- ]*$"``.
    """
    return _INVALID_PARAM_AND_METRIC_CHARACTERS.sub('-', key)


def validate_metric_value(value):
    if torch.is_tensor(value):
        return value.mean().item()
    else:
        try:
            return float(value)
        except TypeError:
            return None


class MLFlowTrackingLogger(MLFlowLogger):
    def __init__(
        self,
        experiment_name: str = 'default',
        run_id: Optional[str] = None,
        tracking_uri: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = None,
        always_log_artifacts: bool = False,
    ):
        super().__init__(
            experiment_name=experiment_name,
            tracking_uri=tracking_uri,
            tags=tags,
            save_dir=save_dir,
        )
        self.always_log_artifacts = always_log_artifacts

        if run_id:
            run = self._mlflow_client.get_run(run_id)
            self._run_id = run.info.run_id
            self._experiment_id = run.info.experiment_id
        else:
            # Force creation of experiment and run
            _ = self.run_id

    @property
    def tracking_uri(self):
        from mlflow import get_tracking_uri

        return get_tracking_uri()

    @rank_zero_only
    def log_param(self, key, value):
        self.experiment.log_param(self.run_id, key, value)

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        params = self._convert_params(params)
        params = self._flatten_dict(params, delimiter='.')
        for k, v in params.items():
            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(
        self, metrics: Dict[str, float], step: Optional[int] = None, prefix=''
    ) -> None:
        # FIXME: do this in the background (non-blocking way)
        normalized_metrics = {}
        for key, value in metrics.items():
            if prefix:
                key = '{}_{}'.format(prefix, key)
            key = normalize_metric_key(key)
            value = validate_metric_value(value)
            normalized_metrics[key] = value
        super().log_metrics(metrics=normalized_metrics, step=step)

    @rank_zero_only
    def log_artifact(self, local_path, artifact_path=None):
        t = threading.Thread(
            target=self.experiment.log_artifact,
            args=(self.run_id, local_path),
            kwargs={'artifact_path': artifact_path},
            daemon=True,
        )
        t.start()
        return t

    @rank_zero_only
    def log_artifacts(self, local_dir, artifact_path=None):
        def send(e, run_id, dpath, path):
            self.experiment.log_artifacts(run_id, dpath, artifact_path=path)
            e.set()

        event = threading.Event()
        t = threading.Thread(
            target=send,
            args=(event, self.run_id, local_dir, artifact_path),
            daemon=True,
        )
        t.start()
        return event

    def get_artifact_uri(self):
        run = self.experiment.get_run(self.run_id)
        return run.info.artifact_uri

    @rank_zero_only
    def log_model(self, local_file, name=None):
        from mlflow.exceptions import MlflowException

        if not name:
            name = generate_slug(self._experiment_name)

        try:
            self._mlflow_client.create_model_version(
                name=name, source=local_file, run_id=self.run_id
            )
        except MlflowException as e:
            logger.warning(e.message)
            logger.info('Logging model as artifact instead')
            thread = self.log_artifact(local_file, artifact_path='model')
            thread.join()

    @rank_zero_only
    def log_tag(self, name: str, value: str):
        self._mlflow_client.set_tag(self.run_id, name, value)

    @rank_zero_only
    def log_tags(self, tags: Dict[str, str]):
        for name, value in tags.items():
            self.log_tag(name, value)