# iterative/dvc
#
# View on GitHub:
# dvc/repo/experiments/executor/local.py
#
# Summary
#
# Maintainability: A (2 hrs)
# Test Coverage
import os
from contextlib import ExitStack
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional, Union

from configobj import ConfigObj
from funcy import retry
from shortuuid import uuid

from dvc.lock import LockError
from dvc.log import logger
from dvc.repo.experiments.refs import (
    EXEC_BASELINE,
    EXEC_BRANCH,
    EXEC_HEAD,
    EXEC_MERGE,
    EXEC_NAMESPACE,
    TEMP_NAMESPACE,
)
from dvc.repo.experiments.utils import EXEC_TMP_DIR, get_exp_rwlock
from dvc.scm import SCM, Git
from dvc.utils.fs import remove
from dvc.utils.objects import cached_property

from .base import BaseExecutor, TaskStatus

if TYPE_CHECKING:
    from dvc.repo import Repo
    from dvc.repo.experiments.refs import ExpRefInfo
    from dvc.repo.experiments.stash import ExpStashEntry
    from dvc.scm import NoSCM

# Rebind the imported ``logger`` name to a child logger named after this
# module, so the records emitted below carry this module's name.
logger = logger.getChild(__name__)


class BaseLocalExecutor(BaseExecutor):
    """Executor base class for experiments that run on the local machine.

    On top of ``BaseExecutor`` this provides a ``file://`` Git URL for the
    executor's ``root_dir`` and an SCM handle opened on that directory.
    """

    @property
    def git_url(self) -> str:
        """Return a ``file://`` URL built from the absolute ``root_dir``."""
        abs_root = os.path.abspath(self.root_dir)
        # On Windows (``os.name == "nt"``) the native path separator is
        # replaced with "/" before the path is embedded in the URL.
        url_path = abs_root.replace(os.sep, "/") if os.name == "nt" else abs_root
        return f"file://{url_path}"

    @cached_property
    def scm(self) -> Union["Git", "NoSCM"]:
        """SCM object for ``root_dir``, exposed through ``cached_property``."""
        return SCM(self.root_dir)

    def cleanup(self, infofile: Optional[str] = None):
        """Close the SCM handle, then run the parent class cleanup.

        The ``scm`` attribute is deleted after closing, so the closed object
        is no longer held on the instance.

        Args:
            infofile: Optional path forwarded unchanged to the parent cleanup.
        """
        self.scm.close()
        del self.scm
        super().cleanup(infofile)

    def collect_cache(
        self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True
    ):
        """Collect DVC cache.

        This implementation has no body beyond the docstring, so calling it
        does nothing and returns ``None``.
        """


class TempDirExecutor(BaseLocalExecutor):
    """Experiment executor that runs inside a temporary directory."""

    # Temp dir executors should warn if untracked files exist (to help with
    # debugging user code), and suppress other DVC hints (like `git add`
    # suggestions) that are not applicable outside of workspace runs
    WARN_UNTRACKED = True
    DEFAULT_LOCATION = "tempdir"

    @retry(180, errors=LockError, timeout=1)
    def init_git(
        self,
        repo: "Repo",
        scm: "Git",
        stash_rev: str,
        entry: "ExpStashEntry",
        infofile: Optional[str],
        branch: Optional[str] = None,
    ):
        """Create the executor's Git repo and populate it from ``scm``.

        The whole method is wrapped in ``retry`` for ``LockError``.

        Args:
            repo: Repo whose experiment lock, local DVC config and Git
                config are read.
            scm: Git object the refs are pushed from.
            stash_rev: Revision stored under the temporary merge ref.
            entry: Stash entry supplying ``head_rev`` and ``baseline_rev``.
            infofile: If truthy, executor info is dumped to this path as
                JSON once the status has been set to ``PREPARING``.
            branch: Optional branch to push as well; when given,
                ``EXEC_BRANCH`` becomes a symbolic ref to it.
        """
        from dulwich.repo import Repo as DulwichRepo

        from dvc.repo.experiments.utils import push_refspec

        DulwichRepo.init(os.fspath(self.root_dir))

        self.status = TaskStatus.PREPARING
        if infofile:
            self.info.dump_json(infofile)

        # Uniquely named temporary refs, one per revision to be pushed.
        head_ref = f"{TEMP_NAMESPACE}/head-{uuid()}"
        merge_ref = f"{TEMP_NAMESPACE}/merge-{uuid()}"
        baseline_ref = f"{TEMP_NAMESPACE}/baseline-{uuid()}"

        temp_refs = {
            head_ref: entry.head_rev,
            merge_ref: stash_rev,
            baseline_ref: entry.baseline_rev,
        }
        with get_exp_rwlock(repo, writes=list(temp_refs)):
            with self.set_temp_refs(scm, temp_refs):
                # Executor will be initialized with an empty git repo that
                # we populate by pushing:
                #   EXEC_HEAD - the base commit for this experiment
                #   EXEC_MERGE - the unmerged changes (from our stash)
                #       to be reproduced
                #   EXEC_BASELINE - the baseline commit for this experiment
                ref_pairs = [
                    (head_ref, EXEC_HEAD),
                    (merge_ref, EXEC_MERGE),
                    (baseline_ref, EXEC_BASELINE),
                ]

                if branch:
                    ref_pairs.append((branch, branch))
                    # The branch is only read, so it gets its own read lock
                    # for the duration of the push.
                    with get_exp_rwlock(repo, reads=[branch]):
                        push_refspec(scm, self.git_url, ref_pairs)
                    self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
                else:
                    push_refspec(scm, self.git_url, ref_pairs)
                    # No branch requested: drop any EXEC_BRANCH that exists.
                    if self.scm.get_ref(EXEC_BRANCH):
                        self.scm.remove_ref(EXEC_BRANCH)

        # checkout EXEC_HEAD and apply EXEC_MERGE on top of it without
        # committing
        assert isinstance(self.scm, Git)
        if branch:
            checkout_target = EXEC_BRANCH
        else:
            checkout_target = EXEC_HEAD
        self.scm.checkout(checkout_target, detach=True)
        stashed_rev = self.scm.get_ref(EXEC_MERGE)

        self.scm.stash.apply(stashed_rev)
        # Carry the source repo's local DVC config and Git config over to
        # the executor's own config files.
        self._update_config(repo.config.read("local"))
        source_git_config = os.path.join(repo.scm.root_dir, ".git", "config")
        self._update_git_config(ConfigObj(source_git_config, list_values=False))

    def _update_config(self, update):
        """Merge ``update`` into the executor's ``config.local`` file.

        An existing file is loaded and merged with ``update``; otherwise the
        config is built from ``update`` alone. Nothing is written when the
        resulting config is empty.
        """
        config_path = os.path.join(self.root_dir, self.dvc_dir, "config.local")
        logger.debug("Writing experiments local config '%s'", config_path)
        already_exists = os.path.exists(config_path)
        merged = ConfigObj(config_path if already_exists else update)
        if already_exists:
            merged.merge(update)
        if not merged:
            return
        with open(config_path, "wb") as fobj:
            merged.write(fobj)

    def _update_git_config(self, update):
        """Merge ``update`` into the executor repo's ``.git/config`` file.

        Same flow as ``_update_config``, except that ``ConfigObj`` is always
        constructed with ``list_values=False``.
        """
        config_path = os.path.join(self.scm.root_dir, ".git", "config")
        logger.debug("Writing experiments local Git config '%s'", config_path)
        already_exists = os.path.exists(config_path)
        merged = ConfigObj(
            config_path if already_exists else update, list_values=False
        )
        if already_exists:
            merged.merge(update)
        if not merged:
            return
        with open(config_path, "wb") as fobj:
            merged.write(fobj)

    def init_cache(
        self,
        repo: "Repo",
        rev: str,  # noqa: ARG002
        run_cache: bool = True,  # noqa: ARG002
    ):
        """Initialize DVC cache.

        Writes ``repo.cache.local_cache_dir`` as ``cache.dir`` into the
        executor's local config. ``rev`` and ``run_cache`` are unused here.
        """
        cache_section = {"dir": repo.cache.local_cache_dir}
        self._update_config({"cache": cache_section})

    def cleanup(self, infofile: Optional[str] = None):
        """Run the parent cleanup, then remove the executor's ``root_dir``."""
        super().cleanup(infofile)
        logger.debug("Removing tmpdir '%s'", self.root_dir)
        remove(self.root_dir)

    @classmethod
    def from_stash_entry(
        cls,
        repo: "Repo",
        entry: "ExpStashEntry",
        wdir: Optional[str] = None,
        **kwargs,
    ):
        """Build an executor rooted in a freshly created temp directory.

        The directory is created with ``mkdtemp`` under ``wdir`` when that is
        truthy, otherwise under ``EXEC_TMP_DIR`` inside ``repo.tmp_dir``. If
        construction raises an ``Exception``, the directory is removed and
        the exception is re-raised.
        """
        assert repo.tmp_dir
        if wdir:
            base_dir: str = wdir
        else:
            base_dir = os.path.join(repo.tmp_dir, EXEC_TMP_DIR)
        os.makedirs(base_dir, exist_ok=True)
        exec_dir = mkdtemp(dir=base_dir)
        try:
            new_executor = cls._from_stash_entry(repo, entry, exec_dir, **kwargs)
            logger.debug("Init temp dir executor in '%s'", exec_dir)
            return new_executor
        except Exception:
            # Do not leave the just-created directory behind on failure.
            remove(exec_dir)
            raise


class WorkspaceExecutor(BaseLocalExecutor):
    """Executor that is rooted at ``repo.scm.root_dir`` itself.

    HEAD is detached in ``init_git`` through a context manager kept on an
    ``ExitStack``; ``cleanup`` closes that stack.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and create the exit stack."""
        super().__init__(*args, **kwargs)
        # Holds the detach_head() context entered in init_git() until
        # cleanup() closes it.
        self._detach_stack = ExitStack()

    @classmethod
    def from_stash_entry(cls, repo: "Repo", entry: "ExpStashEntry", **kwargs):
        """Build an executor whose root is ``repo.scm.root_dir``."""
        workspace_root = repo.scm.root_dir
        new_executor: "WorkspaceExecutor" = cls._from_stash_entry(
            repo, entry, workspace_root, **kwargs
        )
        logger.debug("Init workspace executor in '%s'", workspace_root)
        return new_executor

    @retry(180, errors=LockError, timeout=1)
    def init_git(
        self,
        repo: "Repo",
        scm: "Git",
        stash_rev: str,
        entry: "ExpStashEntry",
        infofile: Optional[str],
        branch: Optional[str] = None,
    ):
        """Set the executor refs, detach HEAD and apply the stashed changes.

        The whole method is wrapped in ``retry`` for ``LockError``.

        Args:
            repo: Repo used to take the experiment lock.
            scm: Git object on which the EXEC_* refs are set.
            stash_rev: Revision stored as ``EXEC_MERGE`` and then applied.
            entry: Stash entry supplying ``head_rev`` and ``baseline_rev``.
            infofile: If truthy, executor info is dumped to this path as
                JSON once the status has been set to ``PREPARING``.
            branch: When given, ``EXEC_BRANCH`` becomes a symbolic ref to
                it; otherwise an existing ``EXEC_BRANCH`` is removed.
        """
        self.status = TaskStatus.PREPARING
        if infofile:
            self.info.dump_json(infofile)

        assert isinstance(self.scm, Git)

        with get_exp_rwlock(repo, writes=[EXEC_NAMESPACE]):
            exec_refs = (
                (EXEC_HEAD, entry.head_rev),
                (EXEC_MERGE, stash_rev),
                (EXEC_BASELINE, entry.baseline_rev),
            )
            for ref_name, rev in exec_refs:
                scm.set_ref(ref_name, rev)

            # Enter the detached-HEAD context now; it stays open on the
            # stack until cleanup() closes it.
            detached = self.scm.detach_head(
                self.scm.get_ref(EXEC_HEAD),
                force=True,
                client="dvc",
            )
            self._detach_stack.enter_context(detached)

            stashed_rev = self.scm.get_ref(EXEC_MERGE)
            self.scm.stash.apply(stashed_rev)

            if branch:
                self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
            elif scm.get_ref(EXEC_BRANCH):
                self.scm.remove_ref(EXEC_BRANCH)

    def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True):
        """Do nothing; this executor performs no cache initialization."""

    def cleanup(self, infofile: Optional[str] = None):
        """Run the parent cleanup, then remove executor refs and state.

        When ``infofile`` is truthy its containing directory is removed.
        ``EXEC_BASELINE`` and ``EXEC_MERGE`` are removed, plus ``EXEC_BRANCH``
        if it exists, inside the detach stack so the stack is closed when the
        block ends.
        """
        super().cleanup(infofile)
        if infofile:
            remove(os.path.dirname(infofile))
        with self._detach_stack:
            for ref_name in (EXEC_BASELINE, EXEC_MERGE):
                self.scm.remove_ref(ref_name)
            if self.scm.get_ref(EXEC_BRANCH):
                self.scm.remove_ref(EXEC_BRANCH)