kiwi/lib/utils.py
# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2020 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import json
import logging
import os.path
from pathlib import Path
from time import gmtime
from typing import Dict, List, Union
import hydra.experimental
import hydra.utils
import yaml
from hydra._internal.hydra import Hydra
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from kiwi.utils.io import BaseConfig
def configure_seed(seed: int):
    """Seed the random number generators of all relevant packages.

    The work is delegated to ``seed_everything`` from PyTorch Lightning, which
    covers: random, numpy, torch, torch.cuda and PYTHONHASHSEED.

    Arguments:
        seed: the random seed to be set.
    """
    seed_everything(seed)
def configure_logging(
    output_dir: Union[str, Path] = None, verbose: bool = False, quiet: bool = False
):
    """Configure the output logger.

    Set up the log format, logging level, and output directory of logging.

    Handlers already attached to the root logger are re-formatted; when there
    are none, ``logging.basicConfig`` installs a default one. If ``output_dir``
    is given, an ``output.log`` file handler using the same format and level is
    added to the root logger.

    Arguments:
        output_dir: the directory where log output will be stored (string or
            ``Path``); defaults to None, in which case no log file is written.
        verbose: change logging level to debug.
        quiet: change logging level to warning to suppress info logs; takes
            precedence over ``verbose`` when both are set.
    """
    # Render timestamps with gmtime (UTC), consistent with the literal 'Z'
    # suffix in the date format. Note this is set on the Formatter class itself.
    logging.Formatter.converter = gmtime
    date_format = '%Y-%m-%dT%H:%M:%SZ'
    log_format = '%(asctime)s %(levelname)-8s %(name)24.24s:%(lineno)3.3s: %(message)s'
    log_formatter = logging.Formatter(log_format, datefmt=date_format)

    root_logger = logging.getLogger()
    if root_logger.handlers:
        # basicConfig would be a no-op here, so update existing handlers directly.
        for handler in root_logger.handlers:
            handler.setFormatter(log_formatter)
    else:
        logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

    log_level = logging.INFO
    if verbose:
        log_level = logging.DEBUG
    if quiet:
        # Checked last so that `quiet` wins over `verbose`.
        log_level = logging.WARNING
    root_logger.setLevel(log_level)

    if output_dir is not None:
        # Path(...) lets callers pass a plain string as well as a Path.
        fh = logging.FileHandler(str(Path(output_dir) / 'output.log'))
        fh.setLevel(log_level)
        # Without an explicit formatter the file would only contain bare messages.
        fh.setFormatter(log_formatter)
        root_logger.addHandler(fh)

    # Silence urllib3
    logging.getLogger('urllib3').setLevel(logging.CRITICAL)
def save_config_to_file(config: BaseConfig, file_name: Union[str, Path]):
    """Save a configuration object to a YAML file.

    The suffix of ``file_name`` is replaced by ``.yaml`` (via
    ``Path.with_suffix``), so the file actually written may be named
    differently from ``file_name``.

    Arguments:
        config: a pydantic configuration object.
        file_name: where to save the configuration.
    """
    path = Path(file_name).with_suffix('.yaml')
    # The config's JSON output is parsed back into plain Python containers,
    # which is what gets dumped as YAML.
    plain_config = json.loads(config.json())
    # Use a context manager so the file is flushed and closed deterministically.
    with path.open('w') as config_file:
        yaml.dump(plain_config, config_file)
    logging.debug(f'Saved current options to config file: {path}')
def setup_output_directory(
    output_dir, run_uuid=None, experiment_id=None, create=True
) -> str:
    """Resolve the output directory and make sure it is usable.

    When no ``output_dir`` is provided, the path ``runs/<experiment_id>/<run_uuid>``
    is used instead. The directory is then either created (including parents)
    or checked for existence, depending on ``create``.

    Arguments:
        output_dir: the target output directory.
        run_uuid : the hash of the current run.
        experiment_id: the id of the current experiment.
        create: whether to create the directory.

    Return:
        the path to the resolved output directory.

    Raises:
        ValueError: if ``output_dir`` is empty and either ``experiment_id`` or
            ``run_uuid`` is missing.
        FileNotFoundError: if ``create`` is false and the directory does not exist.
    """
    if not output_dir:
        ids_available = experiment_id is not None and run_uuid is not None
        if not ids_available:
            raise ValueError('No output directory or run_uuid have been specified.')
        # Fall back to a directory derived from the experiment and run ids.
        output_dir = str(Path('runs', str(experiment_id), str(run_uuid)))

    target = Path(output_dir)
    if create:
        target.mkdir(parents=True, exist_ok=True)
        return output_dir

    if target.exists():
        return output_dir

    raise FileNotFoundError(f'Output directory does not exist: {output_dir}')
def file_to_configuration(config_file: Union[str, Path]) -> Dict:
    """Load a configuration file into a dictionary using hydra composition.

    Wraps the path in an argument dictionary (under the ``CONFIG_FILE`` key)
    and hands it to `arguments_to_configuration`.

    Arguments:
        config_file: path to a configuration file.

    Return:
        Dictionary of the configuration imported from config file.
    """
    _reset_hydra()
    arguments = {'CONFIG_FILE': config_file}
    return arguments_to_configuration(arguments)
def arguments_to_configuration(arguments: Dict) -> Dict:
    """Turn command line arguments into a configuration dictionary.

    The configuration file named by ``arguments['CONFIG_FILE']`` is composed
    with hydra, applying the overrides found under ``'OVERWRITES'`` (if any).
    The ``'--verbose'`` and ``'--quiet'`` flags are copied into the result as
    ``verbose`` and ``quiet``, defaulting to False.

    Arguments:
        arguments: dictionary with all the cmd_line arguments passed to kiwi.

    Returns:
        Dictionary of the config imported from the config file.
    """
    _reset_hydra()

    config_path = Path(arguments['CONFIG_FILE'])
    overrides = arguments.get('OVERWRITES', [])

    # Using Hydra: the config directory is expressed relative to this module's
    # directory, and this module is registered as the calling file.
    module_dir = Path(__file__).parent
    config_dir = Path(os.path.relpath(config_path.resolve().parent, start=module_dir))
    Hydra.create_main_hydra_file_or_module(
        calling_file=__file__,
        calling_module=None,
        config_dir=str(config_dir),
        strict=False,
    )
    composed = hydra.experimental.compose(
        config_file=config_path.name, overrides=overrides
    )

    # Back to a dictionary
    result = OmegaConf.to_container(composed)
    result['verbose'] = arguments.get('--verbose', False)
    result['quiet'] = arguments.get('--quiet', False)
    return result
def _reset_hydra():
    """Clear hydra's global state so a new configuration can be composed."""
    # TODO remove me when upgrading hydra version
    global_hydra = hydra._internal.hydra.GlobalHydra()
    global_hydra.clear()
def sort_by_second_value(tuples: List[tuple]) -> List[tuple]:
"""Sort a list of tuples by the second value in descending order."""
if not tuples:
return []
else:
return list(tuple(zip(*sorted(tuples, key=lambda t: t[1], reverse=True)))[0])