# iterative/dvc — dvc/repo/experiments/queue/base.py (view on GitHub)
#
# Summary
#   Maintainability: D (2 days)
#   Test Coverage: (not reported)

import os
from abc import ABC, abstractmethod
from collections.abc import Collection, Generator, Iterable, Mapping
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

from funcy import retry

from dvc.dependency import ParamsDependency
from dvc.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
from dvc.lock import LockError
from dvc.log import logger
from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.executor.base import BaseExecutor
from dvc.repo.experiments.executor.local import WorkspaceExecutor
from dvc.repo.experiments.refs import ExpRefInfo
from dvc.repo.experiments.stash import ExpStash, ExpStashEntry
from dvc.repo.experiments.utils import (
    EXEC_PID_DIR,
    EXEC_TMP_DIR,
    get_exp_rwlock,
    get_random_exp_name,
)
from dvc.utils.objects import cached_property
from dvc.utils.studio import config_to_env
from dvc_studio_client.post_live_metrics import get_studio_config

from .utils import get_remote_executor_refs

if TYPE_CHECKING:
    from dvc.repo import Repo
    from dvc.repo.experiments import Experiments
    from dvc.repo.experiments.executor.base import ExecutorResult
    from dvc.repo.experiments.serialize import ExpRange
    from dvc.scm import Git

# Rebind the logger imported from dvc.log to a child logger named after this
# module, so records emitted from this file carry the module's name.
logger = logger.getChild(__name__)


@dataclass(frozen=True)
class QueueEntry:
    """A single experiment tracked by a stash-based experiment queue.

    Identity is determined only by ``dvc_root``, ``scm_root``, ``stash_ref``
    and ``stash_rev``. The remaining fields are descriptive metadata and take
    no part in equality or hashing.
    """

    dvc_root: str
    scm_root: str
    stash_ref: str
    stash_rev: str
    baseline_rev: str
    branch: Optional[str]
    name: Optional[str]
    head_rev: Optional[str] = None

    def _identity(self) -> tuple[str, str, str, str]:
        """Return the fields that define equality (and therefore hashing)."""
        return (self.dvc_root, self.scm_root, self.stash_ref, self.stash_rev)

    def __eq__(self, other: object):
        if not isinstance(other, QueueEntry):
            return NotImplemented
        return self._identity() == other._identity()

    def __hash__(self) -> int:
        # NOTE: this must agree with __eq__. Without an explicit __hash__,
        # @dataclass(frozen=True) would generate one over *all* fields, so
        # entries that compare equal (e.g. differing only in name/head_rev)
        # could hash differently and misbehave in sets/dicts.
        return hash(self._identity())

    def asdict(self) -> dict[str, Any]:
        """Return all fields (including metadata) as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "QueueEntry":
        """Build an entry from a dict as produced by :meth:`asdict`."""
        return cls(**d)


class QueueGetResult(NamedTuple):
    """Value returned by ``BaseStashQueue.get``.

    Pairs the dequeued experiment with the executor associated with it.
    """

    # The queue entry that was popped.
    entry: QueueEntry
    # Executor paired with ``entry``.
    executor: BaseExecutor


class QueueDoneResult(NamedTuple):
    """A processed queue entry together with its executor result.

    ``result`` may be ``None`` when no executor result is available for the
    entry.
    """

    # The queue entry that was processed.
    entry: QueueEntry
    # Result reported by the executor, if any.
    result: Optional["ExecutorResult"]


class ExpRefAndQueueEntry(NamedTuple):
    """An experiment ref paired with its queue entry; either side may be ``None``."""

    # Experiment ref info, when the experiment has a ref.
    exp_ref_info: Optional["ExpRefInfo"]
    # Matching queue entry, when the experiment is (or was) queued.
    queue_entry: Optional["QueueEntry"]


class BaseStashQueue(ABC):
    """Naive Git-stash based experiment queue.

    Maps queued experiments to (Git) stash reflog entries.

    Concrete queues implement the abstract scheduling/inspection API
    (``put``, ``get``, ``iter_*``, ``reproduce``, ``kill``, ...). This base
    class provides the shared plumbing: stashing the workspace as an
    experiment, removing queued entries and collecting executor results.
    """

    def __init__(self, repo: "Repo", ref: str, failed_ref: Optional[str] = None):
        """Construct a queue.

        Arguments:
            repo: DVC repo this queue belongs to. Must have a ``tmp_dir``.
            ref: Git stash ref for this queue.
            failed_ref: Failed run Git stash ref for this queue. If not set,
                failed entries are not stashed (see ``stash_failed``).
        """
        self.repo = repo
        assert self.repo.tmp_dir
        self.ref = ref
        self.failed_ref = failed_ref

    @property
    def scm(self) -> "Git":
        """Git SCM of the underlying repo (asserted to be a ``Git`` instance)."""
        from dvc.scm import Git

        assert isinstance(self.repo.scm, Git)
        return self.repo.scm

    @cached_property
    def stash(self) -> ExpStash:
        """Stash holding this queue's experiments (``self.ref``)."""
        return ExpStash(self.scm, self.ref)

    @cached_property
    def failed_stash(self) -> Optional[ExpStash]:
        """Stash for failed experiments, or ``None`` if no ``failed_ref``."""
        return ExpStash(self.scm, self.failed_ref) if self.failed_ref else None

    @cached_property
    def pid_dir(self) -> str:
        """Directory under the repo tmp dir holding per-experiment info files."""
        assert self.repo.tmp_dir is not None
        return os.path.join(self.repo.tmp_dir, EXEC_TMP_DIR, EXEC_PID_DIR)

    @cached_property
    def args_file(self) -> str:
        """Path of the packed repro-arguments file inside the repo tmp dir."""
        assert self.repo.tmp_dir is not None
        return os.path.join(self.repo.tmp_dir, BaseExecutor.PACKED_ARGS_FILE)

    @abstractmethod
    def put(self, *args, **kwargs) -> QueueEntry:
        """Stash an experiment and add it to the queue."""

    @abstractmethod
    def get(self) -> QueueGetResult:
        """Pop and return the first item in the queue."""

    def remove(
        self,
        revs: Collection[str],
        all_: bool = False,
        queued: bool = False,
        **kwargs,
    ) -> list[str]:
        """Remove the specified entries from the queue.

        Arguments:
            revs: Stash revisions or queued exp names to be removed.
            all_: Remove all tasks.
            queued: Remove all queued tasks.

        Returns:
            Revisions (or names) which were removed. Items in ``revs`` that
            do not match a queued experiment are ignored.
        """

        if all_ or queued:
            return self.clear()

        name_to_remove: list[str] = []
        entry_to_remove: list[ExpStashEntry] = []
        queue_entries = self.match_queue_entry_by_name(revs, self.iter_queued())
        matched = [(name, entry) for name, entry in queue_entries.items() if entry]
        # Read the rev -> stash entry mapping once rather than once per match;
        # the stash is not modified until ``remove_revs`` below.
        stash_revs = self.stash.stash_revs if matched else {}
        for name, entry in matched:
            entry_to_remove.append(stash_revs[entry.stash_rev])
            name_to_remove.append(name)

        self.stash.remove_revs(entry_to_remove)
        return name_to_remove

    def clear(self, **kwargs) -> list[str]:
        """Remove all entries from the queue.

        Returns:
            Stash revisions of the removed entries.
        """
        stash_revs = self.stash.stash_revs
        name_to_remove = list(stash_revs)
        self.stash.remove_revs(list(stash_revs.values()))

        return name_to_remove

    def status(self) -> list[dict[str, Any]]:
        """Show the status of exp tasks in queue.

        Returns:
            One dict per task with ``rev``, ``name``, ``timestamp`` (naive
            local ``datetime`` of the stash commit) and ``status`` keys.
            Tasks are listed as Running, then Queued, Failed and Success.
        """
        from datetime import datetime

        result: list[dict[str, Any]] = []

        def _get_timestamp(rev: str) -> datetime:
            commit = self.scm.resolve_commit(rev)
            return datetime.fromtimestamp(commit.commit_time)  # noqa: DTZ006

        def _format_entry(
            entry: QueueEntry,
            exp_result: Optional["ExecutorResult"] = None,
            status: str = "Unknown",
        ) -> dict[str, Any]:
            name = entry.name
            if not name and exp_result and exp_result.ref_info:
                name = exp_result.ref_info.name
            # NOTE: We fallback to Unknown status for experiments
            # generated in prior (incompatible) DVC versions
            return {
                "rev": entry.stash_rev,
                "name": name,
                "timestamp": _get_timestamp(entry.stash_rev),
                "status": status,
            }

        result.extend(
            _format_entry(queue_entry, status="Running")
            for queue_entry in self.iter_active()
        )
        result.extend(
            _format_entry(queue_entry, status="Queued")
            for queue_entry in self.iter_queued()
        )
        result.extend(
            _format_entry(queue_entry, status="Failed")
            for queue_entry, _ in self.iter_failed()
        )
        result.extend(
            _format_entry(queue_entry, exp_result=exp_result, status="Success")
            for queue_entry, exp_result in self.iter_success()
        )
        return result

    @abstractmethod
    def iter_queued(self) -> Generator[QueueEntry, None, None]:
        """Iterate over items in the queue."""

    @abstractmethod
    def iter_active(self) -> Generator[QueueEntry, None, None]:
        """Iterate over items which are being actively processed."""

    @abstractmethod
    def iter_done(self) -> Generator[QueueDoneResult, None, None]:
        """Iterate over items which have been processed."""

    @abstractmethod
    def iter_success(self) -> Generator[QueueDoneResult, None, None]:
        """Iterate over items which have succeeded."""

    @abstractmethod
    def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
        """Iterate over items which have failed."""

    @abstractmethod
    def reproduce(
        self, copy_paths: Optional[list[str]] = None, message: Optional[str] = None
    ) -> Mapping[str, Mapping[str, str]]:
        """Reproduce queued experiments sequentially."""

    @abstractmethod
    def get_result(self, entry: QueueEntry) -> Optional["ExecutorResult"]:
        """Return result of the specified item.

        This method blocks until the specified item has been collected.
        """

    @abstractmethod
    def kill(self, revs: str) -> None:
        """Kill the specified running entries in the queue.

        Arguments:
            revs: Stash revs or running exp name to be killed.
        """

    @abstractmethod
    def shutdown(self, kill: bool = False):
        """Shutdown the queue worker.

        Arguments:
            kill: If True, any active experiments will be killed and the
                worker will shutdown immediately. If False, the worker will
                finish any active experiments before shutting down.
        """

    @abstractmethod
    def logs(self, rev: str, encoding: Optional[str] = None, follow: bool = False):
        """Print redirected output logs for an exp process.

        Args:
            rev: Stash rev or exp name.
            encoding: Text encoding for redirected output. Defaults to
                `locale.getpreferredencoding()`.
            follow: Attach to running exp process and follow additional
                output.
        """

    def _stash_exp(
        self,
        *args,
        params: Optional[dict[str, list[str]]] = None,
        baseline_rev: Optional[str] = None,
        branch: Optional[str] = None,
        name: Optional[str] = None,
        **kwargs,
    ) -> QueueEntry:
        """Stash changes from the workspace as an experiment.

        Args:
            params: Dict mapping paths to `Hydra Override`_ patterns,
                provided via `exp run --set-param`.
            baseline_rev: Optional baseline rev for this experiment, defaults
                to the current SCM rev.
            branch: Optional experiment branch name. If specified, the
                experiment will be added to `branch` instead of creating
                a new branch.
            name: Optional experiment name. If specified this will be used as
                the human-readable name in the experiment branch ref. Has no
                effect if branch is specified. If not specified, a random
                name is generated.

        Returns:
            QueueEntry describing the newly stashed experiment.

        .. _Hydra Override:
            https://hydra.cc/docs/next/advanced/override_grammar/basic/
        """
        with self.scm.stash_workspace(reinstate_index=True) as workspace:
            with self.scm.detach_head(client="dvc") as orig_head:
                stash_head = orig_head
                if baseline_rev is None:
                    baseline_rev = orig_head

                try:
                    if workspace:
                        self.stash.apply(workspace)

                    # update experiment params from command line
                    if params:
                        self._update_params(params)

                    # DVC commit data deps to preserve state across workspace
                    # & tempdir runs
                    self._stash_commit_deps(*args, **kwargs)

                    # save additional repro command line arguments
                    run_env = {DVC_EXP_BASELINE_REV: baseline_rev}
                    if not name:
                        name = get_random_exp_name(self.scm, baseline_rev)
                    run_env[DVC_EXP_NAME] = name
                    # Override DVC_ROOT env var to point to the parent DVC repo
                    # root (and not an executor tempdir root)
                    run_env[DVC_ROOT] = self.repo.root_dir

                    # save studio config to read later by dvc and dvclive
                    studio_config = get_studio_config(
                        dvc_studio_config=self.repo.config.get("studio")
                    )
                    # Explicit run_env values take precedence over studio ones.
                    run_env = config_to_env(studio_config) | run_env
                    self._pack_args(*args, run_env=run_env, **kwargs)
                    # save experiment as a stash commit
                    msg = self._stash_msg(
                        stash_head,
                        baseline_rev=baseline_rev,
                        branch=branch,
                        name=name,
                    )
                    stash_rev = self.stash.push(message=msg)
                    assert stash_rev
                    logger.debug(
                        (
                            "Stashed experiment '%s' with baseline '%s' "
                            "for future execution."
                        ),
                        stash_rev[:7],
                        baseline_rev[:7],
                    )
                finally:
                    # Revert any of our changes before prior unstashing
                    self.scm.reset(hard=True)

        return QueueEntry(
            self.repo.root_dir,
            self.scm.root_dir,
            self.ref,
            stash_rev,
            baseline_rev,
            branch,
            name,
            stash_head,
        )

    def _stash_commit_deps(self, *args, **kwargs):
        """DVC-commit data deps for the repro targets (data only, no relink).

        Targets come from the first positional arg or the ``targets`` kwarg;
        a single string is treated as one target and an empty/missing value
        commits with ``None`` as the target.
        """
        if len(args):
            targets = args[0]
        else:
            targets = kwargs.get("targets")
        if isinstance(targets, str):
            targets = [targets]
        elif not targets:
            targets = [None]
        for target in targets:
            self.repo.commit(
                target,
                with_deps=True,
                recursive=kwargs.get("recursive", False),
                force=True,
                allow_missing=True,
                data_only=True,
                relink=False,
            )

    @staticmethod
    def _stash_msg(
        rev: str,
        baseline_rev: str,
        branch: Optional[str] = None,
        name: Optional[str] = None,
    ) -> str:
        """Build the stash commit message for an experiment.

        ``baseline_rev`` falls back to ``rev`` when empty; ``branch``, if set,
        is appended as a ``:branch`` suffix.
        """
        if not baseline_rev:
            baseline_rev = rev
        msg = ExpStash.format_message(rev, baseline_rev, name)
        if branch:
            return f"{msg}:{branch}"
        return msg

    def _pack_args(self, *args, **kwargs) -> None:
        """Pack repro arguments into ``args_file`` and stage it in Git.

        If an args file is already tracked by Git (likely committed by
        mistake), warn and bump its ``extra`` counter before re-packing —
        presumably so the new file differs from the committed one.
        """
        import pickle

        if os.path.exists(self.args_file) and self.scm.is_tracked(self.args_file):
            logger.warning(
                (
                    "Temporary DVC file '.dvc/tmp/%s' exists and was "
                    "likely committed to Git by mistake. It should be removed "
                    "with:\n"
                    "\tgit rm .dvc/tmp/%s"
                ),
                BaseExecutor.PACKED_ARGS_FILE,
                BaseExecutor.PACKED_ARGS_FILE,
            )
            # SECURITY NOTE(review): this file is tracked in Git and may come
            # from an untrusted commit; unpickling it can execute arbitrary
            # code. Only the ``extra`` counter is needed from it.
            with open(self.args_file, "rb") as fobj:
                try:
                    data = pickle.load(fobj)  # noqa: S301
                except Exception:  # noqa: BLE001
                    data = {}
            # Best-effort, like the load above: treat an unexpected payload
            # or counter value as missing instead of crashing.
            if not isinstance(data, dict):
                data = {}
            try:
                extra = int(data.get("extra", 0)) + 1
            except (TypeError, ValueError):
                extra = 1
        else:
            extra = None
        BaseExecutor.pack_repro_args(self.args_file, *args, extra=extra, **kwargs)
        self.scm.add(self.args_file, force=True)

    @staticmethod
    def _format_new_params_msg(new_params, config_path):
        """Format an error message for when new parameters are identified.

        Args:
            new_params: Names of the parameters missing from the config.
            config_path: Path of the config file they are missing from.
        """
        new_param_count = len(new_params)
        pluralise = "s are" if new_param_count > 1 else " is"
        param_list = ", ".join(new_params)
        return (
            f"{new_param_count} parameter{pluralise} missing "
            f"from '{config_path}': {param_list}"
        )

    def _update_params(self, params: dict[str, list[str]]):
        """Update param files with the provided `Hydra Override`_ patterns.

        When Hydra is enabled in the repo config, overrides targeting the
        default params file are composed from the Hydra config; all other
        paths get the overrides applied directly. Touched files are staged.

        Args:
            params: Dict mapping paths to `Hydra Override`_ patterns,
                provided via `exp run --set-param`.

        .. _Hydra Override:
            https://hydra.cc/docs/advanced/override_grammar/basic/
        """
        from dvc.utils.hydra import apply_overrides, compose_and_dump

        logger.debug("Using experiment params '%s'", params)

        hydra_config = self.repo.config.get("hydra", {})
        hydra_enabled = hydra_config.get("enabled", False)
        hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE
        for path, overrides in params.items():
            if hydra_enabled and path == hydra_output_file:
                if (config_module := hydra_config.get("config_module")) is None:
                    config_dir = os.path.join(
                        self.repo.root_dir, hydra_config.get("config_dir", "conf")
                    )
                else:
                    config_dir = None
                config_name = hydra_config.get("config_name", "config")
                plugins_path = os.path.join(
                    self.repo.root_dir, hydra_config.get("plugins_path", "")
                )
                compose_and_dump(
                    path,
                    config_dir,
                    config_module,
                    config_name,
                    plugins_path,
                    overrides,
                )
            else:
                apply_overrides(path, overrides)

        # Force params file changes to be staged in git
        # Otherwise in certain situations the changes to params file may be
        # ignored when we `git stash` them since mtime is used to determine
        # whether the file is dirty
        self.scm.add(list(params.keys()))

    @staticmethod
    @retry(180, errors=LockError, timeout=1)
    def get_stash_entry(exp: "Experiments", queue_entry: QueueEntry) -> "ExpStashEntry":
        """Look up and drop the stash entry for ``queue_entry``.

        If the rev is not in the stash, a placeholder entry (``stash_index``
        of ``None``) is returned and nothing is dropped. The stash ref is
        write-locked while reading/dropping; ``LockError`` triggers a retry
        via ``funcy.retry``.
        """
        stash = ExpStash(exp.scm, queue_entry.stash_ref)
        stash_rev = queue_entry.stash_rev
        with get_exp_rwlock(exp.repo, writes=[queue_entry.stash_ref]):
            stash_entry = stash.stash_revs.get(
                stash_rev,
                ExpStashEntry(None, stash_rev, stash_rev, None, None),
            )
            if stash_entry.stash_index is not None:
                stash.drop(stash_entry.stash_index)
        return stash_entry

    @classmethod
    def init_executor(
        cls,
        exp: "Experiments",
        queue_entry: QueueEntry,
        executor_cls: type[BaseExecutor] = WorkspaceExecutor,
        **kwargs,
    ) -> BaseExecutor:
        """Create an executor for ``queue_entry`` and initialize its Git/cache.

        The entry's stash commit is removed from the queue stash (see
        ``get_stash_entry``) before the executor is built from it.
        """
        stash_entry = cls.get_stash_entry(exp, queue_entry)

        executor = executor_cls.from_stash_entry(exp.repo, stash_entry, **kwargs)

        stash_rev = queue_entry.stash_rev
        infofile = exp.celery_queue.get_infofile_path(stash_rev)
        executor.init_git(
            exp.repo,
            exp.repo.scm,
            stash_rev,
            stash_entry,
            infofile,
            branch=stash_entry.branch,
        )

        executor.init_cache(exp.repo, stash_rev)

        return executor

    def get_infofile_path(self, name: str) -> str:
        """Return ``<pid_dir>/<name>/<name><INFOFILE_EXT>``."""
        return os.path.join(
            self.pid_dir,
            name,
            f"{name}{BaseExecutor.INFOFILE_EXT}",
        )

    @staticmethod
    @retry(180, errors=LockError, timeout=1)
    def collect_git(
        exp: "Experiments",
        executor: BaseExecutor,
        exec_result: "ExecutorResult",
    ) -> dict[str, str]:
        """Fetch experiment refs from the executor's Git remote into the repo.

        Returns:
            Mapping of collected experiment rev to ``exec_result.exp_hash``.

        Raises:
            ExperimentExistsError: when ``fetch_exps`` reports a diverged ref.
        """
        results = {}

        def on_diverged(ref: str):
            ref_info = ExpRefInfo.from_ref(ref)
            raise ExperimentExistsError(ref_info.name)

        refs = get_remote_executor_refs(exp.scm, executor.git_url)

        with get_exp_rwlock(exp.repo, writes=refs):
            for ref in executor.fetch_exps(
                exp.scm,
                refs,
                force=exec_result.force,
                on_diverged=on_diverged,
            ):
                exp_rev = exp.scm.get_ref(ref)
                if exp_rev:
                    assert exec_result.exp_hash
                    logger.debug("Collected experiment '%s'.", exp_rev[:7])
                    results[exp_rev] = exec_result.exp_hash

        return results

    @classmethod
    def collect_executor(
        cls,
        exp: "Experiments",
        executor: BaseExecutor,
        exec_result: "ExecutorResult",
    ) -> dict[str, str]:
        """Collect Git refs and (when a ref was produced) cache from an executor."""
        results = cls.collect_git(exp, executor, exec_result)

        if exec_result.ref_info is not None:
            executor.collect_cache(exp.repo, exec_result.ref_info)

        return results

    def match_queue_entry_by_name(
        self,
        exp_names: Collection[str],
        *entries: Iterable[Union[QueueEntry, QueueDoneResult]],
    ) -> dict[str, Optional[QueueEntry]]:
        """Map each item of ``exp_names`` to a matching queue entry or ``None``.

        An item matches an entry's experiment name first (for done results,
        the result's ref name takes precedence over the entry name); if it
        looks like a SHA it is then matched as a prefix of a stash rev.
        """
        from funcy import concat

        entry_name_dict: dict[str, QueueEntry] = {}
        entry_rev_dict: dict[str, QueueEntry] = {}
        for entry in concat(*entries):
            if isinstance(entry, QueueDoneResult):
                queue_entry: QueueEntry = entry.entry
                if entry.result is not None and entry.result.ref_info is not None:
                    name: Optional[str] = entry.result.ref_info.name
                else:
                    name = queue_entry.name
            else:
                queue_entry = entry
                name = queue_entry.name
            if name:
                entry_name_dict[name] = queue_entry
            entry_rev_dict[queue_entry.stash_rev] = queue_entry

        result: dict[str, Optional[QueueEntry]] = {}
        for exp_name in exp_names:
            result[exp_name] = None
            if exp_name in entry_name_dict:
                result[exp_name] = entry_name_dict[exp_name]
                continue
            if self.scm.is_sha(exp_name):
                for rev in entry_rev_dict:
                    if rev.startswith(exp_name.lower()):
                        result[exp_name] = entry_rev_dict[rev]
                        break

        return result

    def stash_failed(self, entry: QueueEntry) -> None:
        """Add an entry to the failed exp stash.

        Does nothing when this queue has no failed stash.

        Arguments:
            entry: Failed queue entry to add. ``entry.stash_rev`` must be a
                valid Git stash commit.
        """
        if self.failed_stash is not None:
            assert entry.head_rev
            logger.debug("Stashing failed exp '%s'", entry.stash_rev[:7])
            msg = self.failed_stash.format_message(
                entry.head_rev,
                baseline_rev=entry.baseline_rev,
                name=entry.name,
                branch=entry.branch,
            )
            self.scm.set_ref(
                self.failed_stash.ref,
                entry.stash_rev,
                message=f"commit: {msg}",
            )

    @abstractmethod
    def collect_active_data(
        self,
        baseline_revs: Optional[Collection[str]],
        fetch_refs: bool = False,
        **kwargs,
    ) -> dict[str, list["ExpRange"]]:
        """Collect data for active (running) experiments.

        Args:
            baseline_revs: Optional resolved baseline Git SHAs. If set, only experiments
                derived from the specified revisions will be collected. Defaults to
                collecting all experiments.
            fetch_refs: Whether or not to fetch completed checkpoint commits from Git
                remote.

        Returns:
            Dict mapping baseline revision to list of active experiments.
        """

    @abstractmethod
    def collect_queued_data(
        self,
        baseline_revs: Optional[Collection[str]],
        **kwargs,
    ) -> dict[str, list["ExpRange"]]:
        """Collect data for queued experiments.

        Args:
            baseline_revs: Optional resolved baseline Git SHAs. If set, only experiments
                derived from the specified revisions will be collected. Defaults to
                collecting all experiments.

        Returns:
            Dict mapping baseline revision to list of queued experiments.
        """

    @abstractmethod
    def collect_failed_data(
        self,
        baseline_revs: Optional[Collection[str]],
        **kwargs,
    ) -> dict[str, list["ExpRange"]]:
        """Collect data for failed experiments.

        Args:
            baseline_revs: Optional resolved baseline Git SHAs. If set, only experiments
                derived from the specified revisions will be collected. Defaults to
                collecting all experiments.

        Returns:
            Dict mapping baseline revision to list of failed experiments.
        """

    def active_repo(self, name: str) -> "Repo":
        """Return a Repo for the specified active experiment if it exists.

        Raises:
            ExpNotStartedError: the experiment is active but not yet running.
            InvalidExpRevError: no active experiment has this name, or its
                repo could not be opened.
        """
        from dvc.exceptions import DvcException
        from dvc.repo import Repo
        from dvc.repo.experiments.exceptions import (
            ExpNotStartedError,
            InvalidExpRevError,
        )
        from dvc.repo.experiments.executor.base import ExecutorInfo, TaskStatus

        for entry in self.iter_active():
            if entry.name != name:
                continue
            infofile = self.get_infofile_path(entry.stash_rev)
            executor_info = ExecutorInfo.load_json(infofile)
            if executor_info.status < TaskStatus.RUNNING:
                raise ExpNotStartedError(name)
            dvc_root = os.path.join(executor_info.root_dir, executor_info.dvc_dir)
            try:
                return Repo(dvc_root)
            except (FileNotFoundError, DvcException) as exc:
                raise InvalidExpRevError(name) from exc
        raise InvalidExpRevError(name)