dvc/external_repo.py
import logging
import os
import tempfile
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Tuple

from funcy import retry, wrap_with

from dvc.exceptions import (
    FileMissingError,
    NoOutputInExternalRepoError,
    NoRemoteInExternalRepoError,
    NotDvcRepoError,
    OutputNotFoundError,
    PathMissingError,
)
from dvc.repo import Repo
from dvc.scm import CloneError, map_scm_exception
from dvc.utils import relpath

if TYPE_CHECKING:
    from scmrepo import Git
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@contextmanager
@map_scm_exception()
def external_repo(
    url, rev=None, for_write=False, cache_dir=None, cache_types=None, **kwargs
):
    """Open a DVC ``Repo`` for the git repository at ``url``.

    The repository is obtained through ``_cached_clone``. In read-only mode
    the repo is accessed through a ``GitFileSystem`` at ``rev``. With
    ``for_write`` set, the directory returned by ``_cached_clone`` is used
    directly as the repo root and is removed when the context exits.

    Args:
        url: git URL or local path of the repository.
        rev: revision to read; defaults to ``refs/remotes/origin/HEAD``
            (the tip of the remote's default branch).
        for_write: use a private working copy instead of the shared clone.
        cache_dir: DVC cache directory; defaults to a per-url temporary dir.
        cache_types: value for the ``cache.type`` config option.
        **kwargs: extra keyword arguments forwarded to ``Repo``;
            ``subrepos`` and ``uninitialized`` default to True.

    Yields:
        The opened ``Repo``.

    Raises:
        NoRemoteInExternalRepoError: a ``NoRemoteError`` escaped the body.
        NoOutputInExternalRepoError: an ``OutputNotFoundError`` raised for
            this repo escaped the body (other repos' errors are re-raised).
        PathMissingError: a ``FileMissingError`` escaped the body.
    """
    from scmrepo.git import Git

    from dvc.config import NoRemoteError
    from dvc.fs import GitFileSystem

    logger.debug("Creating external repo %s@%s", url, rev)
    path = _cached_clone(url, rev, for_write=for_write)
    # Local HEAD points to the tip of whatever branch we first cloned from
    # (which may not be the default branch), use origin/HEAD here to get
    # the tip of the default branch
    rev = rev or "refs/remotes/origin/HEAD"

    scm = None
    try:
        cache_config = {
            "cache": {
                "dir": cache_dir or _get_cache_dir(url),
                "type": cache_types,
            }
        }

        config = _get_remote_config(url) if os.path.isdir(url) else {}
        config.update(cache_config)

        if for_write:
            root_dir = path
            fs = None
        else:
            scm = Git(path)
            fs = GitFileSystem(scm=scm, rev=rev)
            root_dir = "/"

        repo_kwargs = dict(
            root_dir=root_dir,
            url=url,
            fs=fs,
            config=config,
            repo_factory=erepo_factory(url, root_dir, cache_config),
            scm=scm,
            **kwargs,
        )
        repo_kwargs.setdefault("subrepos", True)
        repo_kwargs.setdefault("uninitialized", True)

        repo = Repo(**repo_kwargs)
    except BaseException:
        # The finally clause below only protects the yield. If setup fails,
        # release what was acquired so far: the git handle and, in write
        # mode, the private working copy.
        if scm is not None:
            scm.close()
        if for_write:
            _remove(path)
        raise

    try:
        yield repo
    except NoRemoteError as exc:
        raise NoRemoteInExternalRepoError(url) from exc
    except OutputNotFoundError as exc:
        if exc.repo is repo:
            raise NoOutputInExternalRepoError(
                exc.output, repo.root_dir, url
            ) from exc
        raise
    except FileMissingError as exc:
        raise PathMissingError(exc.path, url) from exc
    finally:
        repo.close()
        if for_write:
            _remove(path)
def erepo_factory(url, root_dir, cache_config):
    """Return a factory that opens (sub)repositories of an external repo.

    Each repo made by the factory gets a copy of ``cache_config``. When
    ``url`` is a local directory, the factory also merges in the remote
    config of the matching directory under ``url``. That directory is found
    by re-rooting ``path`` (relative to ``root_dir``) onto ``url``.
    """
    from dvc.fs import localfs

    def make_repo(path, fs=None, **_kwargs):
        repo_config = dict(cache_config)
        # Checked per call, not once at factory creation time.
        if os.path.isdir(url):
            fs = fs if fs else localfs
            rel_parts = fs.path.relparts(path, root_dir)
            local_repo_path = os.path.join(url, *rel_parts)
            repo_config.update(_get_remote_config(local_repo_path))
        return Repo(path, fs=fs, config=repo_config, **_kwargs)

    return make_repo
# url -> (path of the cached clone, whether that clone is shallow)
CLONES: Dict[str, Tuple[str, bool]] = {}
# url -> temporary DVC cache directory used for that external repo
CACHE_DIRS: Dict[str, str] = {}
@wrap_with(threading.Lock())
def _get_cache_dir(url):
    """Return the temporary DVC cache dir for ``url``, creating it once.

    The directory is remembered in ``CACHE_DIRS``, so later calls with the
    same ``url`` reuse it. Calls are wrapped in a lock by ``wrap_with``.
    """
    if url not in CACHE_DIRS:
        CACHE_DIRS[url] = tempfile.mkdtemp("dvc-cache")
    return CACHE_DIRS[url]
def clean_repos():
    """Forget all cached clones and cache dirs, then delete them from disk."""
    # Snapshot and empty the registries before deleting anything, so that
    # outside code never sees an entry whose directory is being removed.
    doomed = [clone_dir for clone_dir, _shallow in CLONES.values()]
    doomed.extend(CACHE_DIRS.values())
    CLONES.clear()
    CACHE_DIRS.clear()

    for stale_dir in doomed:
        _remove(stale_dir)
def _get_remote_config(url):
    """Build config pointing at the default remote of the DVC repo at ``url``.

    Returns ``{}`` when ``url`` is not a DVC repo. If the repo sets
    ``core.remote``, that remote's config section is copied verbatim.
    Otherwise a remote named ``auto-generated-upstream`` is synthesized,
    pointing at the repo's local cache directory and set as
    ``core.remote``. The opened repo is always closed before returning.
    """
    try:
        repo = Repo(url)
    except NotDvcRepoError:
        return {}

    try:
        default_remote = repo.config["core"].get("remote")
        if default_remote:
            # Use original remote to make sure that we are using correct
            # url, credential paths, etc if they are relative to the config
            # location.
            return {
                "remote": {default_remote: repo.config["remote"][default_remote]}
            }

        # Fill the empty upstream entry with a new remote pointing to the
        # original repo's cache location.
        upstream = "auto-generated-upstream"
        return {
            "core": {"remote": upstream},
            "remote": {upstream: {"url": repo.odb.local.cache_dir}},
        }
    finally:
        repo.close()
def _cached_clone(url, rev, for_write=False):
    """Clone an external git repo and return a local path to it.

    The shared clone of ``url`` is first created or refreshed by
    ``_clone_default_branch``.

    For read access (``for_write`` false) the path of the shared cached
    clone is returned. Its working tree is NOT checked out at ``rev``.

    With ``for_write`` set, the clone is copied to a fresh temporary
    directory, ``rev`` is checked out there, and that private directory is
    returned. It is not registered in ``CLONES``, so it is never reused. If
    the copy or checkout fails, the temporary directory is removed.
    """
    from shutil import copytree

    # even if we have already cloned this repo, we may need to
    # fetch/fast-forward to get specified rev
    clone_path, shallow = _clone_default_branch(url, rev, for_write=for_write)

    if not for_write and url in CLONES:
        return CLONES[url][0]

    # Copy to a new dir to keep the clone clean. mkdtemp() has already
    # created repo_path, so copytree must accept an existing destination
    # (it raises FileExistsError otherwise).
    repo_path = tempfile.mkdtemp("dvc-erepo")
    logger.debug("erepo: making a copy of %s clone", url)
    try:
        copytree(clone_path, repo_path, dirs_exist_ok=True)
        # Check out the specified revision
        if for_write:
            _git_checkout(repo_path, rev)
    except BaseException:
        # Don't leak the half-made copy.
        _remove(repo_path)
        raise

    if not for_write:
        # Only reachable if CLONES lost the entry that
        # _clone_default_branch just registered (e.g. cleared concurrently).
        CLONES[url] = (repo_path, shallow)
    return repo_path
@wrap_with(threading.Lock())
def _clone_default_branch(url, rev, for_write=False):
    """Get or create a clean clone of ``url`` and register it in ``CLONES``.

    If a clone of ``url`` is already cached, it is refreshed with ``_pull``.
    The refresh is skipped when ``rev`` is a SHA the clone already contains,
    because branches and tags may move but SHAs do not. A cached shallow
    clone that needs refreshing is unshallowed instead of plainly pulled.

    If no clone is cached, a new one is made in a fresh temporary
    directory. A shallow clone of ``rev`` is tried first when ``rev`` is a
    non-SHA name and ``for_write`` is not set. Otherwise, or if the shallow
    attempt fails, a full clone is made.

    Calls are wrapped in a ``threading.Lock`` by ``wrap_with``.

    Args:
        url: git URL or path to clone.
        rev: revision the caller needs; may be falsy (None).
        for_write: when set, never attempt a shallow clone.

    Returns:
        ``(clone_path, shallow)``: the clone directory and whether it is a
        shallow clone.
    """
    from scmrepo.git import Git

    clone_path, shallow = CLONES.get(url, (None, False))

    # Only set once a Git handle exists; closed in the finally clause.
    git = None
    try:
        if clone_path:
            # Reuse the cached clone, refreshing it if needed.
            git = Git(clone_path)
            # Do not pull for known shas, branches and tags might move
            if not Git.is_sha(rev) or not git.has_rev(rev):
                if shallow:
                    # If we are missing a rev in a shallow clone, fallback to
                    # a full (unshallowed) clone. Since fetching specific rev
                    # SHAs is only available in certain git versions, if we
                    # have need to reference multiple specific revs for a
                    # given repo URL it is easier/safer for us to work with
                    # full clones in this case.
                    logger.debug("erepo: unshallowing clone for '%s'", url)
                    _pull(git, unshallow=True)
                    shallow = False
                    CLONES[url] = (clone_path, shallow)
                else:
                    logger.debug("erepo: git pull '%s'", url)
                    _pull(git)
        else:
            from dvc.scm import clone

            logger.debug("erepo: git clone '%s' to a temporary dir", url)
            clone_path = tempfile.mkdtemp("dvc-clone")
            if not for_write and rev and not Git.is_sha(rev):
                # If rev is a tag or branch name try shallow clone first
                try:
                    git = clone(url, clone_path, shallow_branch=rev)
                    # Treat the clone as shallow only if git left a
                    # `shallow` marker file in its git dir.
                    shallow = os.path.exists(
                        os.path.join(clone_path, Git.GIT_DIR, "shallow")
                    )
                    if shallow:
                        logger.debug(
                            "erepo: using shallow clone for branch '%s'", rev
                        )
                except CloneError:
                    # Remove any .git left by the failed attempt so the full
                    # clone below can reuse clone_path.
                    git_dir = os.path.join(clone_path, ".git")
                    if os.path.exists(git_dir):
                        _remove(git_dir)
            # Full clone when no shallow clone was attempted or it failed.
            if not git:
                git = clone(url, clone_path)
                shallow = False
            CLONES[url] = (clone_path, shallow)
    finally:
        if git:
            git.close()

    return clone_path, shallow
def _pull(git: "Git", unshallow: bool = False):
    """Refresh ``git`` from its remote.

    Fetches (forwarding ``unshallow`` to ``git.fetch``), merges the active
    branch's ``origin`` upstream via ``_merge_upstream``, then fetches all
    experiment refs from ``origin``.
    """
    from dvc.repo.experiments.utils import fetch_all_exps

    git.fetch(unshallow=unshallow)
    _merge_upstream(git)
    remote = "origin"
    fetch_all_exps(git, remote)
def _merge_upstream(git: "Git"):
    """Merge ``refs/remotes/origin/<active branch>`` into the active branch.

    Does nothing if that upstream ref does not exist. This is best-effort:
    any ``SCMError`` is logged at debug level and otherwise ignored, so the
    caller keeps using the clone as-is.
    """
    from scmrepo.exceptions import SCMError

    try:
        branch = git.active_branch()
        upstream = f"refs/remotes/origin/{branch}"
        if git.get_ref(upstream):
            git.merge(upstream)
    except SCMError:
        # Deliberately non-fatal, but leave a trace instead of failing
        # silently.
        logger.debug("erepo: failed to merge upstream", exc_info=True)
def _git_checkout(repo_path, rev):
    """Check out ``rev`` in the git repository at ``repo_path``.

    The git handle is always closed, even if the checkout fails.
    """
    from scmrepo.git import Git

    logger.debug("erepo: git checkout %s@%s", repo_path, rev)
    scm = Git(repo_path)
    try:
        scm.checkout(rev)
    finally:
        scm.close()
def _remove(path):
    """Delete ``path``.

    On Windows the removal is retried up to 5 times on ``OSError``. A final
    ``PermissionError`` is logged as a warning instead of raised. On other
    platforms errors propagate.
    """
    from dvc.utils.fs import remove

    if os.name != "nt":
        remove(path)
        return

    # git.exe may hang for a while not permitting to remove temp dir
    retrying_remove = retry(5, errors=OSError, timeout=0.1)(remove)
    try:
        retrying_remove(path)
    except PermissionError:
        logger.warning("Failed to remove '%s'", relpath(path), exc_info=True)