iterative/dvc

View on GitHub
dvc/repo/experiments/collect.py

Summary

Maintainability
C
1 day
Test Coverage
import itertools
import os
from collections.abc import Collection, Iterable, Iterator
from dataclasses import fields
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Union

from funcy import first
from scmrepo.exceptions import SCMError as InnerSCMError

from dvc.log import logger
from dvc.scm import Git, SCMError, iter_revs

from .exceptions import InvalidExpRefError
from .refs import EXEC_BRANCH, ExpRefInfo
from .serialize import ExpRange, ExpState, SerializableError, SerializableExp
from .utils import describe

if TYPE_CHECKING:
    from dvc.repo import Repo

    from .cache import ExpCache

# Rebind the imported ``dvc.log`` logger to a child logger named after this
# module; every ``logger`` reference below uses the child.
logger = logger.getChild(__name__)


def collect_rev(
    repo: "Repo",
    rev: str,
    param_deps: bool = False,
    force: bool = False,
    cache: Optional["ExpCache"] = None,
    **kwargs,
) -> ExpState:
    """Build the experiment state for a single revision.

    A cached entry is returned when one exists, except when rev is
    'workspace' or when force or param_deps is set. Any exception raised
    while collecting is captured and returned as the state's error instead
    of propagating.

    Args:
        repo: Repo.
        rev: Revision to collect, or 'workspace'.
        param_deps: Forwarded to the collector; also disables cache reads
            and writes.
        force: When set, skip the cache lookup.
        cache: Cache to use; falls back to repo.experiments.cache.
        **kwargs: Forwarded to the collector.

    Returns:
        ExpState carrying either the collected data or a serialized error.
    """
    from dvc.fs import LocalFileSystem

    cache = cache or repo.experiments.cache
    assert cache
    is_workspace = rev == "workspace"

    # TODO: support filtering serialized exp when param_deps is set
    may_read_cache = not (is_workspace or force or param_deps)
    if may_read_cache:
        hit = cache.get(rev)
        if hit:
            if isinstance(hit, SerializableError):
                return ExpState(rev=rev, error=hit)
            return ExpState(rev=rev, data=hit)

    # The workspace on a local filesystem is collected from the repo root;
    # remember where we were so the finally clause can go back.
    previous_cwd: Optional[str] = None
    if is_workspace and isinstance(repo.fs, LocalFileSystem):
        previous_cwd = os.getcwd()
        os.chdir(repo.root_dir)

    try:
        exp = _collect_rev(repo, rev, param_deps=param_deps, force=force, **kwargs)
        may_write_cache = not (is_workspace or param_deps or exp.contains_error)
        if may_write_cache:
            cache.put(exp, force=True)
        return ExpState(rev=rev, data=exp)
    except Exception as exc:  # noqa: BLE001
        logger.debug("", exc_info=True)
        return ExpState(
            rev=rev,
            error=SerializableError(str(exc), type(exc).__name__),
        )
    finally:
        if previous_cwd:
            os.chdir(previous_cwd)


def _collect_rev(
    repo: "Repo",
    rev: str,
    param_deps: bool = False,
    **kwargs,
) -> SerializableExp:
    """Serialize the experiment found at rev.

    The repo is switched to rev for the duration of the call. For anything
    other than 'workspace' the commit time is attached as a (naive)
    timestamp; the workspace gets no timestamp. Extra keyword arguments are
    accepted but not used here.
    """
    with repo.switch(rev) as active_rev:
        # From here on use the value yielded by repo.switch().
        timestamp: Optional[datetime] = None
        if active_rev != "workspace":
            commit = repo.scm.resolve_commit(active_rev)
            timestamp = datetime.fromtimestamp(commit.commit_time)  # noqa: DTZ006

        return SerializableExp.from_repo(
            repo,
            rev=active_rev,
            param_deps=param_deps,
            timestamp=timestamp,
        )


def collect_branch(
    repo: "Repo",
    rev: str,
    end_rev: Optional[str] = None,
    **kwargs,
) -> Iterator["ExpState"]:
    """Yield the exp state of every commit in a Git branch.

    The branch is walked in reverse, beginning at rev. SCM errors raised
    during the walk end the iteration quietly.

    Args:
        rev: Branch tip (head).
        end_rev: Optional stop point; traversal ends when it is reached and
            end_rev itself is not collected.
        **kwargs: Forwarded to collect_rev.
    """
    try:
        commits = repo.scm.branch_revs(rev, end_rev)
        for commit_sha in commits:
            yield collect_rev(repo, commit_sha, **kwargs)
    except (SCMError, InnerSCMError):
        # Best effort: whatever was yielded so far is all the caller gets.
        return


def collect_exec_branch(
    repo: "Repo",
    baseline_rev: str,
    **kwargs,
) -> Iterator["ExpState"]:
    """Iterate over active experiment branch for the current executor.

    The workspace state is always yielded first. If the current revision
    differs from baseline_rev, the states of the branch from the current
    revision back to baseline_rev are yielded afterwards.

    Args:
        repo: Repo.
        baseline_rev: Baseline revision passed to collect_branch as the
            stop point.
        **kwargs: Forwarded to collect_rev / collect_branch.
    """
    # The branch tip is whatever repo.scm.get_rev() reports.
    # NOTE(review): the EXEC_BRANCH ref is not consulted here -- confirm
    # whether the executor branch ref should take precedence over get_rev().
    last_rev = repo.scm.get_rev()
    yield collect_rev(repo, "workspace", **kwargs)
    if last_rev != baseline_rev:
        yield from collect_branch(repo, last_rev, baseline_rev, **kwargs)


def collect_queued(
    repo: "Repo",
    baseline_revs: Collection[str],
    **kwargs,
) -> dict[str, list["ExpRange"]]:
    """Collect queued experiments derived from the specified revisions.

    Every exp state that carries data has that data rebuilt as a new
    SerializableExp without the "metrics" attribute.

    Args:
        repo: Repo.
        baseline_revs: Resolved baseline Git SHAs.
        **kwargs: Forwarded to the celery queue's collect_queued_data.

    Returns:
        Dict mapping baseline revision to list of queued experiments.
    """
    if not baseline_revs:
        return {}

    # The field list does not change between exp states, so build it once.
    # "metrics" is left out, so the rebuilt object gets whatever
    # SerializableExp uses when metrics is not supplied.
    kept_attrs = [
        field.name for field in fields(SerializableExp) if field.name != "metrics"
    ]

    queued_data = {}
    collected = repo.experiments.celery_queue.collect_queued_data(
        baseline_revs, **kwargs
    )
    for rev, ranges in collected.items():
        for exp_range in ranges:
            for exp_state in exp_range.revs:
                if exp_state.data:
                    exp_state.data = SerializableExp(
                        **{attr: getattr(exp_state.data, attr) for attr in kept_attrs}
                    )
        queued_data[rev] = ranges
    return queued_data


def collect_active(
    repo: "Repo",
    baseline_revs: Collection[str],
    **kwargs,
) -> dict[str, list["ExpRange"]]:
    """Gather running experiments for the given baselines from every queue.

    The workspace, tempdir and celery queues are queried in that order and
    their per-baseline results are concatenated.

    Args:
        repo: Repo.
        baseline_revs: Resolved baseline Git SHAs.
        **kwargs: Forwarded to each queue's collect_active_data.

    Returns:
        Dict mapping baseline revision to list of active experiments.
    """
    if not baseline_revs:
        return {}

    experiments = repo.experiments
    sources = (
        experiments.workspace_queue,
        experiments.tempdir_queue,
        experiments.celery_queue,
    )
    merged: dict[str, list["ExpRange"]] = {}
    for source in sources:
        per_baseline = source.collect_active_data(baseline_revs, **kwargs)
        for base_sha, ranges in per_baseline.items():
            # A fresh list is created on first sight, so queue-owned lists
            # are never mutated.
            merged.setdefault(base_sha, []).extend(ranges)
    return merged


def collect_failed(
    repo: "Repo",
    baseline_revs: Collection[str],
    **kwargs,
) -> dict[str, list["ExpRange"]]:
    """Gather failed experiments for the given baselines.

    Only the celery queue is consulted.

    Args:
        repo: Repo.
        baseline_revs: Resolved baseline Git SHAs.
        **kwargs: Forwarded to the celery queue's collect_failed_data.

    Returns:
        Dict mapping baseline revision to list of failed experiments; an
        empty dict when no baselines are given.
    """
    if baseline_revs:
        celery_queue = repo.experiments.celery_queue
        return celery_queue.collect_failed_data(baseline_revs, **kwargs)
    return {}


def collect_successful(
    repo: "Repo",
    baseline_revs: Collection[str],
    **kwargs,
) -> dict[str, list["ExpRange"]]:
    """Gather successful experiments for the given baselines.

    Args:
        repo: Repo.
        baseline_revs: Resolved baseline Git SHAs.
        **kwargs: Forwarded to _collect_baseline.

    Returns:
        Dict with one entry per baseline revision, each holding the list of
        experiment ranges found for it (possibly empty).
    """
    return {
        base_sha: list(_collect_baseline(repo, base_sha, **kwargs))
        for base_sha in baseline_revs
    }


def _collect_baseline(
    repo: "Repo",
    baseline_rev: str,
    **kwargs,
) -> Iterator["ExpRange"]:
    """Iterate over experiments derived from a baseline revision.

    Args:
        repo: Repo.
        baseline_rev: Resolved baseline Git SHA.
        **kwargs: Forwarded to collect_branch. When a non-empty "refs"
            entry is present, it is used as the pool of candidate ref names
            instead of asking the SCM for them.

    Yields:
        One ExpRange per experiment ref that resolves to a revision and has
        at least one collected exp state.
    """
    # Compute the baseline prefix exactly once. The filter below is a lazy
    # generator, so it must not read a name that the loop rebinds; otherwise
    # every ref after the first would be matched against the first
    # experiment's full ref name instead of the baseline prefix.
    baseline_prefix = str(ExpRefInfo(baseline_sha=baseline_rev))
    refs: Optional[Iterable[str]] = kwargs.get("refs")
    if refs:
        ref_it = (ref for ref in iter(refs) if ref.startswith(baseline_prefix))
    else:
        ref_it = repo.scm.iter_refs(base=baseline_prefix)
    executors = repo.experiments.celery_queue.collect_success_executors([baseline_rev])
    for ref in ref_it:
        try:
            ref_info = ExpRefInfo.from_ref(ref)
            exp_rev = repo.scm.get_ref(ref)
            if not exp_rev:
                continue
        except (InvalidExpRefError, SCMError, InnerSCMError):
            # Skip refs that are not valid exp refs or cannot be resolved.
            continue
        exps = list(collect_branch(repo, exp_rev, baseline_rev, **kwargs))
        if exps:
            # The first collected state carries the experiment's name.
            exps[0].name = ref_info.name
            yield ExpRange(
                exps,
                name=ref_info.name,
                executor=executors.get(str(ref_info)),
            )


def collect(
    repo: "Repo",
    revs: Union[list[str], str, None] = None,
    all_branches: bool = False,
    all_tags: bool = False,
    all_commits: bool = False,
    num: int = 1,
    hide_queued: bool = False,
    hide_failed: bool = False,
    sha_only: bool = False,
    **kwargs,
) -> list["ExpState"]:
    """Collect baseline revisions and derived experiments.

    Args:
        repo: Repo; its SCM must be Git.
        revs: Baseline revision or list of revisions. Defaults to HEAD when
            neither revs nor any of the all_* flags is given.
        all_branches: Forwarded to iter_revs.
        all_tags: Forwarded to iter_revs.
        all_commits: Forwarded to iter_revs.
        num: Forwarded to iter_revs.
        hide_queued: Do not collect queued experiments.
        hide_failed: Do not collect failed experiments.
        sha_only: Skip name lookup, leaving every baseline unnamed.
        **kwargs: Forwarded to every collect_* call, including the
            per-baseline collect_rev.

    Returns:
        An empty list when the repo has no commits. Otherwise a list whose
        first item is the workspace state, followed by one ExpState per
        baseline. Each baseline's experiments are grouped as active,
        successful, queued, failed, with every group ordered by
        _sorted_ranges.
    """
    assert isinstance(repo.scm, Git)
    if repo.scm.no_commits:
        return []
    if not any([revs, all_branches, all_tags, all_commits]):
        revs = ["HEAD"]
    if isinstance(revs, str):
        revs = [revs]
    cached_refs = list(repo.scm.iter_refs())
    baseline_revs = list(
        iter_revs(
            repo.scm,
            revs=revs,
            num=num,
            all_branches=all_branches,
            all_tags=all_tags,
            all_commits=all_commits,
        )
    )
    if sha_only:
        baseline_names: dict[str, Optional[str]] = {}
    else:
        baseline_names = describe(
            repo.scm, baseline_revs, refs=cached_refs, logger=logger
        )

    workspace_data = collect_rev(repo, "workspace", **kwargs)
    result: list["ExpState"] = [workspace_data]
    queued = collect_queued(repo, baseline_revs, **kwargs) if not hide_queued else {}
    active = collect_active(repo, baseline_revs, **kwargs)
    failed = collect_failed(repo, baseline_revs, **kwargs) if not hide_failed else {}
    successful = collect_successful(repo, baseline_revs, **kwargs)

    for baseline_rev in baseline_revs:
        # Forward kwargs here as well so baselines honour the same options
        # (param_deps, force, cache, ...) as the workspace and experiments.
        baseline_data = collect_rev(repo, baseline_rev, **kwargs)
        experiments = list(
            itertools.chain.from_iterable(
                _sorted_ranges(collected.get(baseline_rev, []))
                for collected in (active, successful, queued, failed)
            )
        )
        result.append(
            ExpState(
                rev=baseline_rev,
                name=baseline_names.get(baseline_rev),
                data=baseline_data.data,
                error=baseline_data.error,
                experiments=experiments if experiments else None,
            )
        )
    return result


def _sorted_ranges(exp_ranges: Iterable["ExpRange"]) -> list["ExpRange"]:
    """Return list of ExpRange sorted by (timestamp, rev)."""

    def _head_timestamp(exp_range: "ExpRange") -> tuple[datetime, str]:
        head_exp = first(exp_range.revs)
        if head_exp and head_exp.data and head_exp.data.timestamp:
            return head_exp.data.timestamp, head_exp.rev

        return datetime.fromtimestamp(0), ""  # noqa: DTZ006

    return sorted(exp_ranges, key=_head_timestamp, reverse=True)