iterative/dvc

View on GitHub
dvc/repo/metrics/show.py

Summary

Maintainability
A
3 hrs
Test Coverage
import logging
import os
from typing import List

from scmrepo.exceptions import SCMError

from dvc.fs.dvc import DVCFileSystem
from dvc.output import Output
from dvc.repo import locked
from dvc.repo.collect import StrPaths, collect
from dvc.scm import NoSCMError
from dvc.utils import (
    as_posix,
    error_handler,
    errored_revisions,
    onerror_collect,
)
from dvc.utils.collections import ensure_list
from dvc.utils.serialize import load_path

logger = logging.getLogger(__name__)


def _is_metric(out: Output) -> bool:
    return bool(out.metric)


def _to_fs_paths(metrics: List[Output]) -> StrPaths:
    result = []
    for out in metrics:
        if out.metric:
            result.append(out.repo.dvcfs.from_os_path(out.fs_path))
    return result


def _collect_top_level_metrics(repo):
    """Yield normalized paths for the entries of ``repo.index._metrics``.

    ``repo.index._metrics`` is iterated as a mapping from a declaring file
    to the metric files it lists (presumably a ``dvc.yaml`` and its
    top-level ``metrics`` section -- confirm against the index). Each listed
    file is converted with ``as_posix``, joined onto the declaring file's
    parent directory expressed relative to ``repo.root_dir``, and normalized.
    """
    declared = repo.index._metrics  # pylint: disable=protected-access
    for dvcfile, files in declared.items():
        path_ops = repo.fs.path
        # Directory of the declaring file, relative to the repo root.
        base_dir = path_ops.relpath(path_ops.parent(dvcfile), repo.root_dir)
        for entry in files:
            joined = path_ops.join(base_dir, as_posix(entry))
            yield path_ops.normpath(joined)


def _collect_metrics(repo, targets, recursive):
    """Collect metric paths for ``targets`` as a single list.

    Calls ``collect`` with ``_is_metric`` as the output filter. It returns a
    pair: matching outputs and additional filesystem paths. The outputs are
    converted with ``_to_fs_paths``, and the additional paths are appended
    after them.
    """
    outs, extra_paths = collect(
        repo,
        targets=targets,
        output_filter=_is_metric,
        recursive=recursive,
    )
    collected = _to_fs_paths(outs)
    collected.extend(extra_paths)
    return collected


def _extract_metrics(metrics, path, rev):
    if isinstance(metrics, (int, float)):
        return metrics

    if not isinstance(metrics, dict):
        return None

    ret = {}
    for key, val in metrics.items():
        m = _extract_metrics(val, path, rev)
        if m not in (None, {}):
            ret[key] = m
        else:
            logger.debug(
                "Could not parse '%s' metric from '%s' at '%s' "
                "due to its unsupported type: '%s'",
                key,
                path,
                rev,
                type(val).__name__,
            )

    return ret


@error_handler
def _read_metric(path, fs, rev, **kwargs):
    """Load ``path`` from ``fs`` and return its extracted metrics.

    The loaded content is filtered with ``_extract_metrics``. A falsy result
    (``None``, ``0``, ``{}``, ...) is replaced by an empty dict.

    ``kwargs`` is not used in the body; presumably it carries options such
    as ``onerror`` for the ``error_handler`` wrapper -- confirm there.
    """
    loaded = load_path(path, fs)
    extracted = _extract_metrics(loaded, path, rev)
    if extracted:
        return extracted
    return {}


def _read_metrics(repo, metrics, rev, onerror=None):
    """Read every metric file in ``metrics`` through a ``DVCFileSystem``.

    Results are keyed by the metric path prefixed with the repo root
    expressed relative to the current working directory. The prefix is
    empty when the cwd is the repo root.

    A metric is read from its own path if that is a file in the filesystem,
    otherwise from the prefixed path if that is a file. Metrics matching
    neither are skipped.

    Args:
        repo: repository the filesystem is built for.
        metrics: iterable of metric paths.
        rev: revision label passed on to ``_read_metric``.
        onerror: passed on to ``_read_metric`` as a keyword argument.

    Returns:
        dict mapping the prefixed path to the value of ``_read_metric``.
    """
    fs = DVCFileSystem(repo=repo)

    cwd = repo.fs.path.getcwd()
    if repo.root_dir == cwd:
        prefix = ""
    else:
        prefix = repo.fs.path.relpath(repo.root_dir, cwd)

    collected = {}
    for metric in metrics:
        key = os.path.join(prefix, *fs.path.parts(metric))
        if fs.isfile(metric):
            source = metric
        elif fs.isfile(key):
            source = key
        else:
            continue

        collected[key] = _read_metric(source, fs, rev, onerror=onerror)

    return collected


def _gather_metrics(repo, targets, rev, recursive, onerror=None):
    """Collect all metric paths for ``repo`` and read them for ``rev``.

    Paths from ``_collect_metrics`` come first, followed by those yielded by
    ``_collect_top_level_metrics``. The combined list is handed to
    ``_read_metrics`` together with ``rev`` and ``onerror``.
    """
    paths = _collect_metrics(repo, targets, recursive)
    for top_level_path in _collect_top_level_metrics(repo):
        paths.append(top_level_path)
    return _read_metrics(repo, paths, rev, onerror=onerror)


def _hide_unchanged_workspace(repo, res):
    """Remove ``res["workspace"]`` when it equals the active branch's entry.

    ``res`` is modified in place. Nothing is removed when the active branch
    cannot be determined.
    """
    try:
        branch = repo.scm.active_branch()
    except (SCMError, NoSCMError):
        # SCMError - detached head
        # NoSCMError - no repo case
        return

    if res.get("workspace") == res.get(branch):
        res.pop("workspace", None)


@locked
def show(
    repo,
    targets=None,
    all_branches=False,
    all_tags=False,
    recursive=False,
    revs=None,
    all_commits=False,
    onerror=None,
    hide_workspace=True,
):
    """Gather metrics for every revision yielded by ``repo.brancher``.

    Args:
        repo: repository to read metrics from.
        targets: optional target or list of targets; each is converted with
            ``repo.dvcfs.from_os_path`` before collection.
        all_branches, all_tags, revs, all_commits: passed to
            ``repo.brancher`` to select the revisions to visit.
        recursive: passed on to metric collection.
        onerror: error callback; defaults to ``onerror_collect``.
        hide_workspace: when true, the ``"workspace"`` entry is dropped if it
            equals the entry of the active branch.

    Returns:
        dict mapping each revision to the result of the ``error_handler``
        wrapped ``_gather_metrics`` call for it.

    If ``errored_revisions`` reports any revisions for the result, a message
    listing them is written through ``ui.error_write``.
    """
    onerror = onerror_collect if onerror is None else onerror

    fs_targets = [
        repo.dvcfs.from_os_path(target) for target in ensure_list(targets)
    ]

    gather = error_handler(_gather_metrics)
    revisions = repo.brancher(
        revs=revs,
        all_branches=all_branches,
        all_tags=all_tags,
        all_commits=all_commits,
    )
    res = {
        rev: gather(repo, fs_targets, rev, recursive, onerror=onerror)
        for rev in revisions
    }

    if hide_workspace:
        # Hide workspace metrics if they are the same as in the active branch
        _hide_unchanged_workspace(repo, res)

    errored = errored_revisions(res)
    if errored:
        from dvc.ui import ui

        joined = ", ".join(errored)
        ui.error_write(
            "DVC failed to load some metrics for following revisions:"
            f" '{joined}'."
        )

    return res