iterative/dvc

View on GitHub
dvc/repo/open_repo.py

Summary

Maintainability
A
3 hrs
Test Coverage
import os
import tempfile
import threading
from typing import TYPE_CHECKING, Optional

from funcy import retry, wrap_with

from dvc.exceptions import NotDvcRepoError
from dvc.log import logger
from dvc.repo import Repo
from dvc.scm import CloneError, map_scm_exception
from dvc.utils import relpath

if TYPE_CHECKING:
    from dvc.scm import Git

logger = logger.getChild(__name__)


@map_scm_exception()
def _external_repo(url, rev: Optional[str] = None, **kwargs) -> "Repo":
    logger.debug("Creating external repo %s@%s", url, rev)
    path = _cached_clone(url, rev)
    # Local HEAD points to the tip of whatever branch we first cloned from
    # (which may not be the default branch), use origin/HEAD here to get
    # the tip of the default branch
    rev = rev or "refs/remotes/origin/HEAD"

    config = _get_remote_config(url) if os.path.isdir(url) else {}
    config.update({"cache": {"dir": _get_cache_dir(url)}})
    config.update(kwargs.pop("config", None) or {})

    main_root = "/"
    repo_kwargs = dict(
        root_dir=path,
        url=url,
        config=config,
        repo_factory=erepo_factory(url, main_root, {"cache": config["cache"]}),
        rev=rev,
        **kwargs,
    )

    return Repo(**repo_kwargs)


def open_repo(url, *args, **kwargs):
    if url is None:
        url = os.getcwd()

    if os.path.exists(url):
        url = os.path.abspath(url)
        try:
            config = _get_remote_config(url)
            config.update(kwargs.get("config") or {})
            kwargs["config"] = config
            return Repo(url, *args, **kwargs)
        except NotDvcRepoError:
            pass  # fallthrough to _external_repo

    return _external_repo(url, *args, **kwargs)


def erepo_factory(url, root_dir, cache_config):
    from dvc.fs import localfs

    def make_repo(path, fs=None, **_kwargs):
        _config = cache_config.copy()
        if os.path.isdir(url):
            fs = fs or localfs
            repo_path = os.path.join(url, *fs.relparts(path, root_dir))
            _config.update(_get_remote_config(repo_path))
        return Repo(path, fs=fs, config=_config, **_kwargs)

    return make_repo


CLONES: dict[str, tuple[str, bool]] = {}
CACHE_DIRS: dict[str, str] = {}


@wrap_with(threading.Lock())
def _get_cache_dir(url):
    try:
        cache_dir = CACHE_DIRS[url]
    except KeyError:
        cache_dir = CACHE_DIRS[url] = tempfile.mkdtemp("dvc-cache")
    return cache_dir


def clean_repos():
    # Outside code should not see cache while we are removing
    paths = [path for path, _ in CLONES.values()] + list(CACHE_DIRS.values())
    CLONES.clear()
    CACHE_DIRS.clear()

    for path in paths:
        _remove(path)


def _get_remote_config(url):
    try:
        repo = Repo(url)
    except NotDvcRepoError:
        return {}

    try:
        name = repo.config["core"].get("remote")
        if not name:
            # Fill the empty upstream entry with a new remote pointing to the
            # original repo's cache location.
            name = "auto-generated-upstream"
            return {
                "core": {"remote": name},
                "remote": {name: {"url": repo.cache.local_cache_dir}},
            }

        # Use original remote to make sure that we are using correct url,
        # credential paths, etc if they are relative to the config location.
        return {"remote": {name: repo.config["remote"][name]}}
    finally:
        repo.close()


def _cached_clone(url, rev):
    """Clone an external git repo to a temporary directory.

    Returns the path to a local temporary directory with the specified
    revision checked out.
    """
    from shutil import copytree

    # even if we have already cloned this repo, we may need to
    # fetch/fast-forward to get specified rev
    clone_path, shallow = _clone_default_branch(url, rev)

    if url in CLONES:
        return CLONES[url][0]

    # Copy to a new dir to keep the clone clean
    repo_path = tempfile.mkdtemp("dvc-erepo")
    logger.debug("erepo: making a copy of %s clone", url)
    copytree(clone_path, repo_path)

    CLONES[url] = (repo_path, shallow)
    return repo_path


@wrap_with(threading.Lock())
def _clone_default_branch(url, rev):  # noqa: PLR0912
    """Get or create a clean clone of the url.

    The cloned is reactualized with git pull unless rev is a known sha.
    """
    from dvc.scm import Git

    clone_path, shallow = CLONES.get(url) or (None, False)

    git = None
    try:
        if clone_path:
            git = Git(clone_path)
            # Do not pull for known shas, branches and tags might move
            if not Git.is_sha(rev) or not git.has_rev(rev):
                if shallow:
                    # If we are missing a rev in a shallow clone, fallback to
                    # a full (unshallowed) clone. Since fetching specific rev
                    # SHAs is only available in certain git versions, if we
                    # have need to reference multiple specific revs for a
                    # given repo URL it is easier/safer for us to work with
                    # full clones in this case.
                    logger.debug("erepo: unshallowing clone for '%s'", url)
                    _pull(git, unshallow=True)
                    shallow = False
                    CLONES[url] = (clone_path, shallow)
                else:
                    logger.debug("erepo: git pull '%s'", url)
                    _pull(git)
        else:
            from dvc.scm import clone

            logger.debug("erepo: git clone '%s' to a temporary dir", url)
            clone_path = tempfile.mkdtemp("dvc-clone")
            if rev and not Git.is_sha(rev):
                # If rev is a tag or branch name try shallow clone first

                try:
                    git = clone(url, clone_path, shallow_branch=rev)
                    shallow = os.path.exists(
                        os.path.join(clone_path, Git.GIT_DIR, "shallow")
                    )
                    if shallow:
                        logger.debug("erepo: using shallow clone for branch '%s'", rev)
                except CloneError:
                    git_dir = os.path.join(clone_path, ".git")
                    if os.path.exists(git_dir):
                        _remove(git_dir)
            if not git:
                git = clone(url, clone_path)
                shallow = False
            CLONES[url] = (clone_path, shallow)
    finally:
        if git:
            git.close()

    return clone_path, shallow


def _pull(git: "Git", unshallow: bool = False):
    from dvc.repo.experiments.utils import fetch_all_exps

    git.fetch(unshallow=unshallow)
    _merge_upstream(git)
    fetch_all_exps(git, "origin")


def _merge_upstream(git: "Git"):
    from scmrepo.exceptions import SCMError

    try:
        branch = git.active_branch()
        upstream = f"refs/remotes/origin/{branch}"
        if git.get_ref(upstream):
            git.merge(upstream)
    except SCMError:
        pass


def _remove(path):
    from dvc.utils.fs import remove

    if os.name == "nt":
        # git.exe may hang for a while not permitting to remove temp dir
        os_retry = retry(5, errors=OSError, timeout=0.1)
        try:
            os_retry(remove)(path)
        except PermissionError:
            logger.warning("Failed to remove '%s'", relpath(path), exc_info=True)
    else:
        remove(path)