iterative/dvc

View on GitHub
dvc/repo/experiments/__init__.py

Summary

Maintainability
C
1 day
Test Coverage
import os
import re
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional

from funcy import chain, first

from dvc.log import logger
from dvc.ui import ui
from dvc.utils import relpath
from dvc.utils.objects import cached_property

from .cache import ExpCache
from .exceptions import (
    BaselineMismatchError,
    ExperimentExistsError,
    InvalidExpRefError,
    MultipleBranchError,
)
from .refs import (
    APPLY_STASH,
    CELERY_FAILED_STASH,
    CELERY_STASH,
    EXEC_APPLY,
    EXEC_NAMESPACE,
    EXPS_NAMESPACE,
    WORKSPACE_STASH,
    ExpRefInfo,
)
from .stash import ApplyStash
from .utils import check_ref_format, exp_refs_by_rev, unlocked_repo

if TYPE_CHECKING:
    from .queue.base import BaseStashQueue, QueueEntry
    from .queue.celery import LocalCeleryQueue
    from .queue.tempdir import TempDirQueue
    from .queue.workspace import WorkspaceQueue
    from .stash import ExpStashEntry

# Rebind the imported dvc logger to a child logger named after this module.
logger = logger.getChild(__name__)


class Experiments:
    """Class that manages experiments in a DVC repo.

    Args:
        repo (dvc.repo.Repo): repo instance that these experiments belong to.
    """

    BRANCH_RE = re.compile(r"^(?P<baseline_rev>[a-f0-9]{7})-(?P<exp_sha>[a-f0-9]+)")

    def __init__(self, repo):
        """Bind this manager to ``repo``.

        Raises:
            NoSCMError: if the repo config has a truthy ``core.no_scm`` value.
        """
        from dvc.scm import NoSCMError

        no_scm = repo.config["core"].get("no_scm", False)
        if no_scm:
            raise NoSCMError

        self.repo = repo

    @property
    def scm(self):
        """SCM object of the underlying repo.

        Raises:
            SCMError: if the Git repo does not contain any commits yet.
        """
        from dvc.scm import SCMError

        if not self.repo.scm.no_commits:
            return self.repo.scm

        raise SCMError("Empty Git repo. Add a commit to use experiments.")

    @cached_property
    def dvc_dir(self) -> str:
        """``repo.dvc_dir`` passed through ``relpath`` with the SCM root as start."""
        repo = self.repo
        return relpath(repo.dvc_dir, repo.scm.root_dir)

    @cached_property
    def args_file(self) -> str:
        """Path of the executor's packed-args file inside ``repo.tmp_dir``."""
        from .executor.base import BaseExecutor

        tmp_dir = self.repo.tmp_dir
        return os.path.join(tmp_dir, BaseExecutor.PACKED_ARGS_FILE)

    @cached_property
    def workspace_queue(self) -> "WorkspaceQueue":
        """``WorkspaceQueue`` for this repo, using the ``WORKSPACE_STASH`` ref."""
        from .queue.workspace import WorkspaceQueue as _Queue

        return _Queue(self.repo, WORKSPACE_STASH)

    @cached_property
    def tempdir_queue(self) -> "TempDirQueue":
        """``TempDirQueue`` for this repo, using the ``WORKSPACE_STASH`` ref."""
        from .queue.tempdir import TempDirQueue as _Queue

        # NOTE: the workspace stash ref is reused on purpose. Both the tempdir
        # and workspace implementations push and then immediately pop, so the
        # shared queue only ever holds zero or one entry.
        return _Queue(self.repo, WORKSPACE_STASH)

    @cached_property
    def celery_queue(self) -> "LocalCeleryQueue":
        """``LocalCeleryQueue`` for this repo.

        Built with ``CELERY_STASH`` and ``CELERY_FAILED_STASH`` as its refs.
        """
        from .queue.celery import LocalCeleryQueue as _Queue

        return _Queue(self.repo, CELERY_STASH, CELERY_FAILED_STASH)

    @cached_property
    def apply_stash(self) -> ApplyStash:
        """``ApplyStash`` over the ``APPLY_STASH`` ref.

        Goes through ``self.scm``, so it raises ``SCMError`` on a Git repo
        without commits.
        """
        scm = self.scm
        return ApplyStash(scm, APPLY_STASH)

    @cached_property
    def cache(self) -> ExpCache:
        """``ExpCache`` instance created for this repo."""
        repo = self.repo
        return ExpCache(repo)

    @property
    def stash_revs(self) -> dict[str, "ExpStashEntry"]:
        revs = {}
        for queue in (self.workspace_queue, self.celery_queue):
            revs.update(queue.stash.stash_revs)
        return revs

    def reproduce_one(
        self,
        tmp_dir: bool = False,
        copy_paths: Optional[list[str]] = None,
        message: Optional[str] = None,
        **kwargs,
    ):
        """Queue, run and report one standalone experiment.

        Args:
            tmp_dir: Use the tempdir queue instead of the workspace queue.
            copy_paths: Forwarded to ``_reproduce_queue``.
            message: Forwarded to ``_reproduce_queue``.
            **kwargs: Forwarded to ``queue_one``.

        Returns:
            The results dict returned by ``_reproduce_queue``.
        """
        if tmp_dir:
            exp_queue: "BaseStashQueue" = self.tempdir_queue
        else:
            exp_queue = self.workspace_queue

        self.queue_one(exp_queue, **kwargs)
        results = self._reproduce_queue(
            exp_queue, copy_paths=copy_paths, message=message
        )
        # Only report when at least one experiment rev came back.
        if first(results) is not None:
            self._log_reproduced(results, tmp_dir=tmp_dir)
        return results

    def queue_one(self, queue: "BaseStashQueue", **kwargs) -> "QueueEntry":
        """Queue a single experiment."""
        return self.new(queue, **kwargs)

    def reproduce_celery(
        self, entries: Optional[Iterable["QueueEntry"]] = None, **kwargs
    ) -> dict[str, str]:
        """Start celery workers and follow the given queue entries to completion.

        Args:
            entries: Queue entries to wait on. When ``None``, all active and
                queued entries of ``self.celery_queue`` are used. Any iterable
                is accepted, including a one-shot iterator.
            **kwargs: Only ``jobs`` is read; it is passed to ``start_workers``
                as the worker count (default 1).

        Returns:
            dict mapping the rev of each produced experiment ref to the
            ``exp_hash`` of its result. Empty when there is nothing to run.
        """
        results: dict[str, str] = {}
        if entries is None:
            entries = chain(
                self.celery_queue.iter_active(), self.celery_queue.iter_queued()
            )
        # Materialize once: the entries are truth-tested and then iterated, and
        # a one-shot iterator would always look non-empty in the check below.
        entries = list(entries)

        logger.debug("reproduce all these entries '%s'", entries)

        if not entries:
            return results

        self.celery_queue.start_workers(count=kwargs.get("jobs", 1))
        failed = []
        try:
            ui.write(
                "Following logs for all queued experiments. Use Ctrl+C to "
                "stop following logs (experiment execution will continue).\n"
            )
            for entry in entries:
                # wait for task execution to start
                self.celery_queue.wait_for_start(entry, sleep_interval=1)
                self.celery_queue.follow(entry)
                # wait for task collection to complete
                try:
                    result = self.celery_queue.get_result(entry)
                except FileNotFoundError:
                    result = None
                if result is None or result.exp_hash is None:
                    # No usable result: report the entry as failed.
                    name = entry.name or entry.stash_rev[:7]
                    failed.append(name)
                elif result.ref_info:
                    exp_rev = self.scm.get_ref(str(result.ref_info))
                    results[exp_rev] = result.exp_hash
        except KeyboardInterrupt:
            # Ctrl+C only stops log following; whatever was collected so far is
            # still reported and returned below.
            ui.write(
                "Experiment(s) are still executing in the background. To "
                "abort execution use 'dvc queue kill' or 'dvc queue stop'."
            )
        if failed:
            names = ", ".join(failed)
            ui.error(f"Failed to reproduce experiment(s) '{names}'")
        if results:
            # Pass a real list: the logger walks the revs more than once, which
            # a generator cannot support.
            self._log_reproduced(list(results), True)
        return results

    def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False):
        """Print the names of the experiments that were run.

        Args:
            revs: Revs of the experiments that ran. Any iterable is accepted,
                including a one-shot generator.
            tmp_dir: When true, print how to apply the results; otherwise
                state that they were already applied to the workspace.
        """
        # ``revs`` is walked more than once (name lookup, then formatting), so
        # snapshot it first. A generator would otherwise be exhausted by the
        # lookup and no names would be printed.
        revs = list(revs)
        rev_names = self.get_exact_name(revs)
        # Fall back to the abbreviated rev for unnamed experiments.
        names = [rev_names[rev] or rev[:7] for rev in revs]
        ui.write("\nRan experiment(s): {}".format(", ".join(names)))
        if tmp_dir:
            ui.write(
                "To apply the results of an experiment to your workspace "
                "run:\n\n"
                "\tdvc exp apply <exp>"
            )
        else:
            ui.write("Experiment results have been applied to your workspace.")

    def new(self, queue: "BaseStashQueue", *args, **kwargs) -> "QueueEntry":
        """Create and enqueue a new experiment.

        Experiment will be derived from the current workspace.
        """

        name = kwargs.get("name", None)
        baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev()

        if name:
            exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name)
            check_ref_format(self.scm, exp_ref)
            force = kwargs.get("force", False)
            if self.scm.get_ref(str(exp_ref)) and not force:
                raise ExperimentExistsError(exp_ref.name)

        return queue.put(*args, **kwargs)

    def _get_last_applied(self) -> Optional[str]:
        """Return the rev stored in ``EXEC_APPLY`` if its baseline still matches.

        When ``check_baseline`` rejects the stored rev, the ``EXEC_APPLY`` ref
        is removed and ``None`` is returned.
        """
        try:
            applied_rev = self.scm.get_ref(EXEC_APPLY)
            if applied_rev:
                self.check_baseline(applied_rev)
        except BaselineMismatchError:
            # If HEAD has moved since the last applied experiment,
            # the applied experiment is no longer relevant
            self.scm.remove_ref(EXEC_APPLY)
            return None
        else:
            return applied_rev

    @unlocked_repo
    def _reproduce_queue(
        self,
        queue: "BaseStashQueue",
        copy_paths: Optional[list[str]] = None,
        message: Optional[str] = None,
        **kwargs,
    ) -> dict[str, str]:
        """Run everything in ``queue`` and flatten the per-entry results.

        Arguments:
            queue: Experiment queue.
            copy_paths: Forwarded to ``queue.reproduce``.
            message: Forwarded to ``queue.reproduce``.
            **kwargs: Accepted but not used by this method.

        Returns:
            dict mapping successfully reproduced experiment revs to their
            results.
        """
        merged: dict[str, str] = {}
        per_entry = queue.reproduce(copy_paths=copy_paths, message=message)
        # Keys of the outer mapping are not needed; only merge the inner dicts.
        for exp_result in per_entry.values():
            merged.update(exp_result)
        return merged

    def check_baseline(self, exp_rev):
        baseline_sha = self.repo.scm.get_rev()
        if exp_rev == baseline_sha:
            return exp_rev

        exp_baseline = self._get_baseline(exp_rev)
        if exp_baseline is None:
            # if we can't tell from branch name, fall back to parent commit
            exp_commit = self.scm.resolve_commit(exp_rev)
            if exp_commit:
                exp_baseline = first(exp_commit.parents)
        if exp_baseline == baseline_sha:
            return exp_baseline
        raise BaselineMismatchError(exp_baseline, baseline_sha)

    def get_baseline(self, rev):
        """Return the baseline rev for an experiment rev."""
        return self._get_baseline(rev)

    def _get_baseline(self, rev):
        """Resolve ``rev`` and return the baseline it was derived from.

        Stash entries are consulted first; otherwise the baseline comes from
        the first ref yielded by ``exp_refs_by_rev`` for the resolved rev.

        Returns:
            The baseline rev, or ``None`` when it cannot be determined.
        """
        from dvc.scm import resolve_rev

        rev = resolve_rev(self.scm, rev)

        # ``stash_revs`` is an uncached property that rebuilds its dict on
        # every access, so read it once instead of once per lookup.
        stash_revs = self.stash_revs
        if rev in stash_revs:
            entry = stash_revs[rev]
            return entry.baseline_rev if entry else None

        ref_info = first(exp_refs_by_rev(self.scm, rev))
        if ref_info:
            return ref_info.baseline_sha
        return None

    def get_branch_by_rev(
        self, rev: str, allow_multiple: bool = False
    ) -> Optional[str]:
        """Returns full refname for the experiment branch containing rev.

        With several candidate refs and ``allow_multiple`` unset, the ref that
        points exactly at ``rev`` is chosen.

        Raises:
            MultipleBranchError: if several refs match, ``allow_multiple`` is
                unset, and none of them points exactly at ``rev``.
        """
        candidates = list(exp_refs_by_rev(self.scm, rev))
        if not candidates:
            return None

        if len(candidates) == 1 or allow_multiple:
            return str(candidates[0])

        for candidate in candidates:
            refname = str(candidate)
            if self.scm.get_ref(refname) == rev:
                return refname
        raise MultipleBranchError(rev, candidates)

    def get_exact_name(self, revs: Iterable[str]) -> dict[str, Optional[str]]:
        """Return the experiment name, if any, for each of the given revisions.

        Names are looked up with ``scm.describe`` under ``EXPS_NAMESPACE``
        (excluding ``EXEC_NAMESPACE`` refs). A rev without a valid experiment
        ref falls back to the name of its stash entry: first in
        ``self.stash_revs``, then in the celery queue's failed stash.

        Args:
            revs: Revisions to look up. Any iterable is accepted, including a
                one-shot generator.

        Returns:
            dict mapping every rev to its name, or to ``None`` when no name
            could be determined.
        """
        # Snapshot the input: it is handed to ``describe`` and then iterated
        # again below, which would silently yield nothing for a generator.
        revs = list(revs)
        result: dict[str, Optional[str]] = {}
        exclude = f"{EXEC_NAMESPACE}/*"
        ref_dict = self.scm.describe(revs, base=EXPS_NAMESPACE, exclude=exclude)
        # ``self.stash_revs`` rebuilds its dict on every access, so fetch it at
        # most once, and only if some rev actually needs the fallback.
        stash_revs = None
        for rev in revs:
            name: Optional[str] = None
            ref = ref_dict[rev]
            if ref:
                try:
                    name = ExpRefInfo.from_ref(ref).name
                except InvalidExpRefError:
                    # Not an experiment ref; use the stash lookups below.
                    pass
            if not name:
                if stash_revs is None:
                    stash_revs = self.stash_revs
                if rev in stash_revs:
                    name = stash_revs[rev].name
                else:
                    failed_stash = self.celery_queue.failed_stash
                    failed_revs = failed_stash.stash_revs if failed_stash else {}
                    if rev in failed_revs:
                        name = failed_revs[rev].name
            result[rev] = name
        return result

    def apply(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.apply.apply`` for this repo."""
        from dvc.repo.experiments.apply import apply as _apply

        return _apply(self.repo, *args, **kwargs)

    def branch(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.branch.branch`` for this repo."""
        from dvc.repo.experiments.branch import branch as _branch

        return _branch(self.repo, *args, **kwargs)

    def diff(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.diff.diff`` for this repo."""
        from dvc.repo.experiments.diff import diff as _diff

        return _diff(self.repo, *args, **kwargs)

    def show(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.show.show`` for this repo."""
        from dvc.repo.experiments.show import show as _show

        return _show(self.repo, *args, **kwargs)

    def run(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.run.run`` for this repo."""
        from dvc.repo.experiments.run import run as _run

        return _run(self.repo, *args, **kwargs)

    def save(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.save.save`` for this repo."""
        from dvc.repo.experiments.save import save as _save

        return _save(self.repo, *args, **kwargs)

    def push(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.push.push`` for this repo."""
        from dvc.repo.experiments.push import push as _push

        return _push(self.repo, *args, **kwargs)

    def pull(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.pull.pull`` for this repo."""
        from dvc.repo.experiments.pull import pull as _pull

        return _pull(self.repo, *args, **kwargs)

    def ls(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.ls.ls`` for this repo."""
        from dvc.repo.experiments.ls import ls as _ls

        return _ls(self.repo, *args, **kwargs)

    def remove(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.remove.remove`` for this repo."""
        from dvc.repo.experiments.remove import remove as _remove

        return _remove(self.repo, *args, **kwargs)

    def rename(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.rename.rename`` for this repo."""
        from dvc.repo.experiments.rename import rename as _rename

        return _rename(self.repo, *args, **kwargs)

    def clean(self, *args, **kwargs):
        """Delegate to ``dvc.repo.experiments.clean.clean`` for this repo."""
        from dvc.repo.experiments.clean import clean as _clean

        return _clean(self.repo, *args, **kwargs)