iterative/dvc

View on GitHub
dvc/scm.py

Summary

Maintainability
B
6 hrs
Test Coverage
"""Manages source control systems (e.g. Git)."""

import os
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Literal, Optional, Union, overload

from funcy import group_by
from scmrepo.base import Base  # noqa: F401
from scmrepo.git import Git
from scmrepo.noscm import NoSCM

from dvc.exceptions import DvcException
from dvc.progress import Tqdm

if TYPE_CHECKING:
    from scmrepo.progress import GitProgressEvent

    from dvc.fs import FileSystem


class SCMError(DvcException):
    """Base class for source control management errors."""


class CloneError(SCMError):
    pass


class RevError(SCMError):
    pass


class NoSCMError(SCMError):
    def __init__(self):
        msg = (
            "Only supported for Git repositories. If you're "
            "seeing this error in a Git repo, try updating the DVC "
            "configuration with `dvc config core.no_scm false`."
        )
        super().__init__(msg)


class InvalidRemoteSCMRepo(SCMError):
    pass


class GitAuthError(SCMError):
    def __init__(self, reason: str) -> None:
        doc = "See https://dvc.org/doc/user-guide/troubleshooting#git-auth"
        super().__init__(f"{reason}\n{doc}")


@contextmanager
def map_scm_exception(with_cause: bool = False) -> Iterator[None]:
    from scmrepo.exceptions import SCMError as InternalSCMError

    try:
        yield
    except InternalSCMError as exc:
        into = SCMError(str(exc))
        if with_cause:
            raise into from exc
        raise into  # noqa: B904


@overload
def SCM(
    root_dir: str,
    *,
    search_parent_directories: bool = ...,
    no_scm: Literal[False] = ...,
) -> "Git": ...


@overload
def SCM(
    root_dir: str,
    *,
    search_parent_directories: bool = ...,
    no_scm: Literal[True],
) -> "NoSCM": ...


@overload
def SCM(
    root_dir: str,
    *,
    search_parent_directories: bool = ...,
    no_scm: bool = ...,
) -> Union["Git", "NoSCM"]: ...


def SCM(root_dir, *, search_parent_directories=True, no_scm=False):
    """Returns SCM instance that corresponds to a repo at the specified
    path.

    Args:
        root_dir (str): path to a root directory of the repo.
        search_parent_directories (bool): whether to look for repo root in
        parent directories.
        no_scm (bool): return NoSCM if True.

    Returns:
        dvc.scm.base.Base: SCM instance.
    """
    with map_scm_exception():
        if no_scm:
            return NoSCM(root_dir, _raise_not_implemented_as=NoSCMError)
        return Git(root_dir, search_parent_directories=search_parent_directories)


class TqdmGit(Tqdm):
    BAR_FMT = (
        "{desc}|{bar}|{postfix[info]}{n_fmt}/{total_fmt} [{elapsed}, {rate_fmt:>11}]"
    )

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("unit", "obj")
        kwargs.setdefault("bar_format", self.BAR_FMT)
        super().__init__(*args, **kwargs)
        self._last_phase = None

    def update_git(self, event: "GitProgressEvent") -> None:
        phase, completed, total, message, *_ = event
        if phase:
            message = (phase + " | " + message) if message else phase
        if message:
            self.set_msg(message)
        force_refresh = (  # force-refresh progress bar when:
            (total and completed and completed >= total)  # the task completes
            or total != self.total  # the total changes
            or phase != self._last_phase  # or, the phase changes
        )
        if completed is not None:
            self.update_to(completed, total)
        if force_refresh:
            self.refresh()
        self._last_phase = phase


def clone(url: str, to_path: str, **kwargs):
    from scmrepo.exceptions import CloneError as InternalCloneError

    from dvc.repo.experiments.utils import fetch_all_exps

    with TqdmGit(desc=f"Cloning {os.path.basename(url)}") as pbar:
        try:
            git = Git.clone(url, to_path, progress=pbar.update_git, **kwargs)
            if "shallow_branch" not in kwargs:
                fetch_all_exps(git, url, progress=pbar.update_git)
            return git
        except InternalCloneError as exc:
            raise CloneError("SCM error") from exc


def resolve_rev(scm: Union["Git", "NoSCM"], rev: str) -> str:
    from scmrepo.exceptions import RevError as InternalRevError

    from dvc.repo.experiments.utils import fix_exp_head

    try:
        return scm.resolve_rev(fix_exp_head(scm, rev))
    except InternalRevError as exc:
        assert isinstance(scm, Git)
        # `scm` will only resolve git branch and tag names,
        # if rev is not a sha it may be an abbreviated experiment name
        if not (rev == "HEAD" or rev.startswith("refs/")):
            from dvc.repo.experiments.utils import AmbiguousExpRefInfo, resolve_name

            try:
                ref_infos = resolve_name(scm, rev).get(rev)
            except AmbiguousExpRefInfo:
                raise RevError(f"ambiguous Git revision '{rev}'")  # noqa: B904
            if ref_infos:
                return scm.get_ref(str(ref_infos))

        raise RevError(str(exc))  # noqa: B904


def _get_n_commits(scm: "Git", revs: list[str], num: int) -> list[str]:
    results = []
    for rev in revs:
        if num == 0:
            continue
        results.append(rev)
        n = 1
        while True:
            if num == n:
                break
            try:
                head = f"{rev}~{n}"
                results.append(resolve_rev(scm, head))
            except RevError:
                break
            n += 1
    return results


def iter_revs(
    scm: "Git",
    revs: Optional[list[str]] = None,
    num: int = 1,
    all_branches: bool = False,
    all_tags: bool = False,
    all_commits: bool = False,
    all_experiments: bool = False,
    commit_date: Optional[str] = None,
) -> Mapping[str, list[str]]:
    from scmrepo.exceptions import SCMError as _SCMError

    from dvc.repo.experiments.utils import exp_commits

    if not any(
        [
            revs,
            all_branches,
            all_tags,
            all_commits,
            all_experiments,
            commit_date,
        ]
    ):
        return {}

    revs = revs or []
    results: list[str] = _get_n_commits(scm, revs, num)

    if all_commits:
        results.extend(scm.list_all_commits())
    else:
        if all_branches:
            results.extend(scm.list_branches())

        if all_tags:
            results.extend(scm.list_tags())

        if commit_date:
            from datetime import datetime

            commit_datestamp = (
                datetime.strptime(commit_date, "%Y-%m-%d").timestamp()  # noqa: DTZ007
            )

            def _time_filter(rev):
                try:
                    return scm.resolve_commit(rev).commit_time >= commit_datestamp
                except _SCMError:
                    return True

            results.extend(filter(_time_filter, scm.list_all_commits()))

    if all_experiments:
        results.extend(exp_commits(scm))

    rev_resolver = partial(resolve_rev, scm)
    return group_by(rev_resolver, results)


def lfs_prefetch(fs: "FileSystem", paths: list[str]):
    from scmrepo.git.lfs import fetch as _lfs_fetch

    from dvc.fs.dvc import DVCFileSystem
    from dvc.fs.git import GitFileSystem

    if isinstance(fs, DVCFileSystem) and isinstance(fs.repo.fs, GitFileSystem):
        git_fs = fs.repo.fs
        scm = fs.repo.scm
        assert isinstance(scm, Git)
    else:
        return

    try:
        if "filter=lfs" not in git_fs.open(".gitattributes").read():
            return
    except OSError:
        return
    with TqdmGit(desc="Checking for Git-LFS objects") as pbar:
        _lfs_fetch(
            scm,
            [git_fs.rev],
            include=[(path if path.startswith("/") else f"/{path}") for path in paths],
            progress=pbar.update_git,
        )