iterative/dvc

View on GitHub
dvc/repo/experiments/show.py

Summary

Maintainability
D
1 day
Test Coverage
import logging
from collections import OrderedDict, defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from dvc.repo.experiments.base import ExpRefInfo
from dvc.repo.metrics.show import _gather_metrics
from dvc.repo.params.show import _gather_params
from dvc.scm import iter_revs
from dvc.utils import error_handler, onerror_collect

if TYPE_CHECKING:
    from dvc.repo import Repo

logger = logging.getLogger(__name__)


@error_handler
def _collect_experiment_commit(
    repo,
    exp_rev,
    stash=False,
    sha_only=True,
    param_deps=False,
    running=None,
    onerror: Optional[Callable] = None,
    is_baseline: bool = False,
):
    """Gather status, params, deps, outs and metrics of a single revision.

    ``repo.brancher`` is walked for ``exp_rev``; its "workspace" step is
    ignored unless ``exp_rev`` itself is "workspace".

    Keys written, in this order: ``timestamp`` (``None`` for the
    workspace), ``params`` (only when non-empty), ``deps`` and ``outs``
    (hash/size/nfiles per path, leaving out params/repo dependencies and
    metric/plot outputs), ``queued`` (the ``stash`` flag), ``running`` and
    ``executor`` (looked up in ``running``), ``metrics`` (skipped when
    ``stash`` is set) and, unless ``sha_only`` or in the workspace, a
    display ``name``.

    The function is wrapped by ``error_handler``; callers in this module
    read the collected mapping from the wrapper result's ``"data"`` key.
    """
    from dvc.dependency import ParamsDependency, RepoDependency

    def _entry_info(entry):
        # Same description shape for dependencies and outputs.
        return {
            "hash": entry.hash_info.value,
            "size": entry.meta.size,
            "nfiles": entry.meta.nfiles,
        }

    info: Dict[str, Optional[Any]] = defaultdict(dict)
    for branch_rev in repo.brancher(revs=[exp_rev]):
        in_workspace = branch_rev == "workspace"
        if in_workspace and exp_rev != "workspace":
            continue

        if in_workspace:
            info["timestamp"] = None
        else:
            commit_time = repo.scm.resolve_commit(branch_rev).commit_time
            info["timestamp"] = datetime.fromtimestamp(commit_time)

        params = _gather_params(
            repo,
            rev=branch_rev,
            targets=None,
            deps=param_deps,
            onerror=onerror,
        )
        if params:
            info["params"] = params

        skipped_dep_types = (ParamsDependency, RepoDependency)
        info["deps"] = {
            dep.def_path: _entry_info(dep)
            for dep in repo.index.deps
            if not isinstance(dep, skipped_dep_types)
        }
        info["outs"] = {
            out.def_path: _entry_info(out)
            for out in repo.index.outs
            if not out.is_metric and not out.is_plot
        }

        info["queued"] = stash
        is_running = running is not None and exp_rev in running
        info["running"] = is_running
        info["executor"] = (
            running[exp_rev].get("location") if is_running else None
        )

        if not stash:
            info["metrics"] = _gather_metrics(
                repo,
                targets=None,
                rev=branch_rev,
                recursive=False,
                onerror=onerror,
            )

        if sha_only or in_workspace:
            continue

        name: Optional[str] = None
        if is_baseline:
            # Baselines prefer a tag name, then a branch name.
            for refspec in ("refs/tags", "refs/heads"):
                name = repo.scm.describe(branch_rev, base=refspec)
                if name:
                    break
        name = name or repo.experiments.get_exact_name(branch_rev)
        if name:
            # Keep only the last path component of the resolved name.
            info["name"] = name.rsplit("/")[-1]

    return info


def _collect_experiment_branch(
    res, repo, branch, baseline, onerror: Optional[Callable] = None, **kwargs
):
    """Collect every commit of one experiment branch into ``res``.

    ``res`` is the ordered mapping of entries for a single baseline (``show``
    passes ``res[baseline_rev]``); it is updated in place and also returned.

    The commits come from ``repo.scm.branch_revs(exp_rev, baseline)``, where
    ``exp_rev`` is ``branch`` resolved to a revision. When more than one
    commit is listed, every entry gets a ``checkpoint_tip`` (``exp_rev``) and
    a ``checkpoint_parent`` (the next commit in the listed order, or
    ``baseline`` for the last one).

    Commits already present in ``res`` (shared with a branch collected
    earlier) are not collected again. In the multi-commit case they are
    re-tagged with this branch's tip and moved to the end of ``res``.

    Args:
        res: ordered mapping (``move_to_end`` is used) of rev -> entry.
        repo: the DVC repo.
        branch: experiment ref to resolve and walk.
        baseline: baseline revision the experiment was derived from.
        onerror: error callback forwarded to the commit collector.
        **kwargs: forwarded to ``_collect_experiment_commit``.

    Returns:
        ``res``.
    """
    from dvc.scm import resolve_rev

    exp_rev = resolve_rev(repo.scm, branch)
    revs = list(repo.scm.branch_revs(exp_rev, baseline))
    is_checkpoint = len(revs) > 1
    prev: Optional[str] = None
    for rev in revs:
        if is_checkpoint and prev is not None:
            # Link the previously visited commit to this one.
            res[prev]["data"]["checkpoint_parent"] = rev
        if rev in res:
            # Already collected through another branch: skip the costly
            # re-collection (its result used to be discarded anyway).
            if is_checkpoint:
                res[rev]["data"]["checkpoint_tip"] = exp_rev
                res.move_to_end(rev)
        else:
            collected = _collect_experiment_commit(
                repo, rev, onerror=onerror, **kwargs
            )
            # A result without "data" used to crash with KeyError; keep the
            # collector's other keys alongside "data", as ``show`` does for
            # baseline and queued entries.
            data = collected.get("data", {})
            if is_checkpoint:
                data = {"checkpoint_tip": exp_rev, **data}
            entry = dict(collected)
            entry["data"] = data
            res[rev] = entry
        prev = rev
    if is_checkpoint:
        res[prev]["data"]["checkpoint_parent"] = baseline
    return res


def show(
    repo: "Repo",
    all_branches=False,
    all_tags=False,
    revs: Union[List[str], str, None] = None,
    all_commits=False,
    sha_only=False,
    num=1,
    param_deps=False,
    onerror: Optional[Callable] = None,
    fetch_running: bool = True,
):
    """Collect experiments for the workspace and the selected baselines.

    Args:
        repo: the DVC repo.
        all_branches, all_tags, all_commits, num: revision selection,
            forwarded to ``iter_revs``.
        revs: baseline revision(s); a single string is accepted. Defaults
            to ``["HEAD"]`` when no other selection flag is set.
        sha_only: when false, a display ``name`` is looked up per commit.
        param_deps: forwarded to the params collector as ``deps``.
        onerror: error callback; defaults to ``onerror_collect``.
        fetch_running: forwarded to ``get_running_exps`` as ``fetch_refs``.

    Returns:
        A mapping of baseline rev ("workspace" first) to an ordered mapping
        holding a ``"baseline"`` entry, then the entries collected from
        that baseline's experiment refs (refs visited newest commit first),
        then that baseline's queued experiments.
    """
    if onerror is None:
        onerror = onerror_collect

    res: Dict[str, Dict] = defaultdict(OrderedDict)

    if not any([revs, all_branches, all_tags, all_commits]):
        revs = ["HEAD"]
    if isinstance(revs, str):
        revs = [revs]

    # The workspace is always reported, and always first.
    found_revs: Dict[str, List[str]] = {"workspace": []}
    found_revs.update(
        iter_revs(repo.scm, revs, num, all_branches, all_tags, all_commits)
    )

    running = repo.experiments.get_running_exps(fetch_refs=fetch_running)

    for rev in found_revs:
        res[rev]["baseline"] = _collect_experiment_commit(
            repo,
            rev,
            sha_only=sha_only,
            param_deps=param_deps,
            running=running,
            onerror=onerror,
            is_baseline=True,
        )

        if rev == "workspace":
            continue

        # Committed experiments: every exp ref derived from this baseline,
        # visited in order of descending commit time.
        base_ref_info = ExpRefInfo(baseline_sha=rev)
        commits = [
            (ref, repo.scm.resolve_commit(ref))
            for ref in repo.scm.iter_refs(base=str(base_ref_info))
        ]
        for exp_ref, _ in sorted(
            commits, key=lambda x: x[1].commit_time, reverse=True
        ):
            ref_info = ExpRefInfo.from_ref(exp_ref)
            assert ref_info.baseline_sha == rev
            _collect_experiment_branch(
                res[rev],
                repo,
                exp_ref,
                rev,
                sha_only=sha_only,
                param_deps=param_deps,
                running=running,
                onerror=onerror,
            )

        # Queued (not yet reproduced) experiments of *this* baseline only.
        # Matching against every found rev here would re-collect each
        # queued experiment once per baseline, and would insert it into a
        # later baseline's mapping before that baseline's "baseline" entry.
        for stash_rev, entry in repo.experiments.stash_revs.items():
            if entry.baseline_rev != rev:
                continue
            # Skip stash entries whose running record has a truthy "last".
            if stash_rev in running and running[stash_rev].get("last"):
                continue
            res[rev][stash_rev] = _collect_experiment_commit(
                repo,
                stash_rev,
                sha_only=sha_only,
                stash=stash_rev not in running,
                param_deps=param_deps,
                running=running,
                onerror=onerror,
            )
    return res