iterative/dvc

View on GitHub
dvc/fs/repo.py

Summary

Maintainability: D (1 day)
Test Coverage
import logging
import os
import threading
from contextlib import suppress
from itertools import takewhile
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Type, Union

from fsspec.spec import AbstractFileSystem
from funcy import cached_property, wrap_prop, wrap_with

from ._callback import DEFAULT_CALLBACK
from .base import FileSystem
from .dvc import DvcFileSystem
from .path import Path

if TYPE_CHECKING:
    from dvc.repo import Repo

logger = logging.getLogger(__name__)

# Anything that can build a Repo: the ``Repo`` class itself or a factory
# callable. The first positional argument is the repo root path, but
# ``_RepoFileSystem._update`` also passes ``fs=`` and ``repo_factory=``
# keyword arguments, so the callable signature is left open (``...``)
# instead of the narrower ``[str]``.
RepoFactory = Union[Callable[..., "Repo"], Type["Repo"]]


def _wrap_walk(dvc_fs, *args, **kwargs):
    for root, dnames, fnames in dvc_fs.walk(*args, **kwargs):
        yield dvc_fs.path.join(dvc_fs.repo.root_dir, root), dnames, fnames


def _ls(fs, path):
    dnames = []
    fnames = []

    for entry in fs.ls(path, detail=True):
        name = fs.path.name(entry["name"])
        if entry["type"] == "directory":
            dnames.append(name)
        else:
            fnames.append(name)

    return dnames, fnames


def _merge_info(repo, fs_info, dvc_info):
    """Combine a repo-fs ``info()`` entry with a DVC-fs ``info()`` entry.

    Args:
        repo: repo object stored under the ``"repo"`` key.
        fs_info: info dict from the repo's own filesystem, or a falsy value.
        dvc_info: info dict from the ``DvcFileSystem``, or a falsy value.

    Returns:
        A new dict. ``type``/``size`` come from ``fs_info`` when present,
        otherwise from ``dvc_info``. ``md5`` is copied only when there is
        no ``fs_info``. ``isexec`` is set only when ``fs_info`` is present.
    """
    from dvc.utils import is_exec

    merged = {"repo": repo}

    if dvc_info:
        merged["dvc_info"] = dvc_info
        merged["type"] = dvc_info["type"]
        merged["size"] = dvc_info["size"]
        # An md5 from DVC only describes the entry when nothing on the
        # repo fs overrides it.
        if "md5" in dvc_info and not fs_info:
            merged["md5"] = dvc_info["md5"]

    if not fs_info:
        return merged

    kind = fs_info["type"]
    merged["type"] = kind
    merged["size"] = fs_info["size"]
    # Only regular files carry a meaningful executable bit.
    merged["isexec"] = is_exec(fs_info["mode"]) if kind == "file" else False
    return merged


class _RepoFileSystem(AbstractFileSystem):  # pylint:disable=abstract-method
    """DVC + git-tracked files fs.

    Every path is resolved against the root repo and, when ``subrepos`` is
    enabled, against any DVC subrepo it falls into. Lookups consult both
    the owning repo's own filesystem (``repo.fs``) and, when the repo has
    one, a ``DvcFileSystem`` for DVC-tracked outputs. ``open``,
    ``get_file`` and ``checksum`` try ``repo.fs`` first and fall back to
    the DVC fs on ``FileNotFoundError``.

    Args:
        repo: DVC or git repo. When omitted, one is opened from ``kwargs``
            via ``_repo_from_fs_config``.
        subrepos: traverse to subrepos (by default, it ignores subrepos)
        repo_factory: A function to initialize subrepo with, default is Repo.
        kwargs: Filesystem config (``repo_url``, ``repo_root``, ``rev``,
            ``cache_dir``, ``cache_types``). It is used only to open the
            repo when ``repo`` is not given.
    """

    PARAM_REPO_URL = "repo_url"
    PARAM_REPO_ROOT = "repo_root"
    PARAM_REV = "rev"
    PARAM_CACHE_DIR = "cache_dir"
    PARAM_CACHE_TYPES = "cache_types"
    PARAM_SUBREPOS = "subrepos"

    def __init__(
        self,
        repo: Optional["Repo"] = None,
        subrepos=False,
        repo_factory: Optional[RepoFactory] = None,
        **kwargs,
    ):
        super().__init__()

        from pygtrie import Trie

        if repo is None:
            repo, repo_factory = self._repo_from_fs_config(
                subrepos=subrepos, **kwargs
            )

        if not repo_factory:
            from dvc.repo import Repo

            self.repo_factory: RepoFactory = Repo
        else:
            self.repo_factory = repo_factory

        self.path = Path(self.sep)
        self.repo = repo
        self.hash_jobs = repo.fs.hash_jobs
        self._traverse_subrepos = subrepos

        # Maps repo-relative path parts (tuples) to the repo that owns that
        # path. The root repo lives under the empty key ``()``.
        self._subrepos_trie = Trie()

        key = self._get_key(self.repo.root_dir)
        self._subrepos_trie[key] = repo

        # One DvcFileSystem per repo key. The root repo gets one only if it
        # exposes ``dvc_dir``; subrepos get one when ``_update`` detects them.
        self._dvcfss = {}

        if hasattr(repo, "dvc_dir"):
            self._dvcfss[key] = DvcFileSystem(repo=repo)

    def _get_key(self, path):
        """Return *path*'s parts relative to the root repo (``()`` for root)."""
        parts = self.repo.fs.path.relparts(path, self.repo.root_dir)
        if parts == (".",):
            parts = ()
        return parts

    @property
    def repo_url(self):
        """URL of the root repo, or None if there is no repo."""
        if self.repo is None:
            return None
        return self.repo.url

    @property
    def config(self):
        """Config dict that ``_repo_from_fs_config`` can reopen the repo from."""
        return {
            self.PARAM_REPO_URL: self.repo_url,
            self.PARAM_REPO_ROOT: self.repo.root_dir,
            self.PARAM_REV: getattr(self.repo.fs, "rev", None),
            self.PARAM_CACHE_DIR: os.path.abspath(
                self.repo.odb.local.cache_dir
            ),
            self.PARAM_CACHE_TYPES: self.repo.odb.local.cache_types,
            self.PARAM_SUBREPOS: self._traverse_subrepos,
        }

    @classmethod
    def _repo_from_fs_config(
        cls, **config
    ) -> Tuple["Repo", Optional["RepoFactory"]]:
        """Open the repo described by an fs ``config`` dict.

        Either ``repo_url`` or ``repo_root`` must be set. Returns
        ``(repo, factory)``. ``factory`` is an erepo factory when a URL is
        configured, and None otherwise (the caller then uses ``Repo``).
        """
        from dvc.external_repo import erepo_factory, external_repo
        from dvc.repo import Repo

        url = config.get(cls.PARAM_REPO_URL)
        root = config.get(cls.PARAM_REPO_ROOT)
        assert url or root

        def _open(*args, **kwargs):
            # NOTE: if original repo was an erepo (and has a URL),
            # we cannot use Repo.open() since it will skip erepo
            # cache/remote setup for local URLs
            if url is None:
                return Repo.open(*args, **kwargs)
            return external_repo(*args, **kwargs)

        cache_dir = config.get(cls.PARAM_CACHE_DIR)
        cache_config = (
            {}
            if not cache_dir
            else {
                "cache": {
                    "dir": cache_dir,
                    "type": config.get(cls.PARAM_CACHE_TYPES),
                }
            }
        )
        repo_kwargs: dict = {
            "rev": config.get(cls.PARAM_REV),
            "subrepos": config.get(cls.PARAM_SUBREPOS, False),
            "uninitialized": True,
        }
        factory: Optional["RepoFactory"] = None
        if url is None:
            repo_kwargs["config"] = cache_config
        else:
            repo_kwargs["cache_dir"] = cache_dir
            factory = erepo_factory(url, cache_config)

        # NOTE(review): the repo is returned from inside the ``with`` block,
        # so its context manager has already exited by the time the caller
        # uses it. Confirm this is intended.
        with _open(
            url if url else root,
            **repo_kwargs,
        ) as repo:
            return repo, factory

    def _get_repo(self, path: str) -> "Repo":
        """Returns repo that the path falls in, using prefix.

        If the path is already tracked/collected, it just returns the repo.

        Otherwise, it collects the repos that might be in the path's parents
        and then returns the appropriate one.
        """
        if not self.repo.fs.path.isin_or_eq(path, self.repo.root_dir):
            # outside of repo
            return self.repo

        key = self._get_key(path)
        repo = self._subrepos_trie.get(key)
        if repo:
            return repo

        prefix_key, repo = self._subrepos_trie.longest_prefix(key)
        prefix = self.repo.fs.path.join(
            self.repo.root_dir,
            *prefix_key,  # pylint: disable=not-an-iterable
        )

        # Walk up from ``path`` until the closest already-known ancestor,
        # then scan those directories top-down for subrepos.
        parents = self.repo.fs.path.parents(path)
        dirs = [path] + list(takewhile(lambda p: p != prefix, parents))
        dirs.reverse()
        self._update(dirs, starting_repo=repo)
        return self._subrepos_trie.get(key) or self.repo

    # The lock is created once when the class body runs, so it is shared
    # by every instance of this class.
    @wrap_with(threading.Lock())
    def _update(self, dirs, starting_repo):
        """Checks for subrepo in directories and updates them.

        *dirs* must be ordered top-down. Each directory is recorded in the
        trie under the innermost repo found so far.
        """
        repo = starting_repo
        for d in dirs:
            key = self._get_key(d)
            if self._is_dvc_repo(d):
                repo = self.repo_factory(
                    d,
                    fs=self.repo.fs,
                    repo_factory=self.repo_factory,
                )
                self._dvcfss[key] = DvcFileSystem(repo=repo)
            self._subrepos_trie[key] = repo

    def _is_dvc_repo(self, dir_path):
        """Check if the directory is a dvc repo.

        Always False when subrepo traversal is disabled.
        """
        if not self._traverse_subrepos:
            return False

        from dvc.repo import Repo

        repo_path = os.path.join(dir_path, Repo.DVC_DIR)
        return self.repo.fs.isdir(repo_path)

    def _get_fs_pair(
        self, path
    ) -> Tuple[
        Optional[FileSystem],
        Optional[str],
        Optional[DvcFileSystem],
        Optional[str],
    ]:
        """Resolve *path* to ``(fs, fs_path, dvc_fs, dvc_path)``.

        For a path inside the root repo:

        - ``fs`` is the owning repo's own filesystem.
        - ``fs_path`` is the path joined onto the root repo's ``root_dir``.
        - ``dvc_fs``/``dvc_path`` address the same entry relative to the
          owning repo. ``dvc_path`` is ``""`` for that repo's root.
        - Both ``dvc_fs`` and ``dvc_path`` are None when the owning repo has
          no ``DvcFileSystem``.

        For an absolute path outside the root repo, the result is
        ``(None, None, self.repo.dvcfs, path)``.
        """
        from dvc.utils import as_posix

        if os.path.isabs(path):
            if self.repo.fs.path.isin_or_eq(path, self.repo.root_dir):
                path = self.repo.fs.path.relpath(path, self.repo.root_dir)
            else:
                return None, None, self.repo.dvcfs, path

        path = as_posix(path)

        parts = self.path.parts(path)
        if parts and parts[0] == os.curdir:
            parts = parts[1:]

        fs_path = self.repo.fs.path.join(self.repo.root_dir, *parts)
        repo = self._get_repo(fs_path)
        fs = repo.fs

        repo_parts = fs.path.relparts(repo.root_dir, self.repo.root_dir)
        if repo_parts[0] == os.curdir:
            repo_parts = repo_parts[1:]

        # Strip the owning repo's prefix so the DVC path is repo-relative.
        dvc_parts = parts[len(repo_parts) :]
        if dvc_parts and dvc_parts[0] == os.curdir:
            dvc_parts = dvc_parts[1:]

        key = self._get_key(repo.root_dir)
        dvc_fs = self._dvcfss.get(key)
        if dvc_fs:
            dvc_path = dvc_fs.path.join(*dvc_parts) if dvc_parts else ""
        else:
            dvc_path = None

        return fs, fs_path, dvc_fs, dvc_path

    def open(
        self, path, mode="r", encoding="utf-8", **kwargs
    ):  # pylint: disable=arguments-renamed, arguments-differ
        """Open *path*, preferring the repo fs and falling back to DVC.

        *encoding* is forced to None for binary modes. Extra ``kwargs`` are
        forwarded only to the DVC fs.
        """
        if "b" in mode:
            encoding = None

        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        # ``fs`` is None for paths outside the root repo; go straight to DVC.
        if fs:
            try:
                return fs.open(fs_path, mode=mode, encoding=encoding)
            except FileNotFoundError:
                if not dvc_fs:
                    raise

        return dvc_fs.open(dvc_path, mode=mode, encoding=encoding, **kwargs)

    def isdvc(self, path, **kwargs):
        """Return True if *path* is reported as DVC-tracked by the DVC fs."""
        _, _, dvc_fs, dvc_path = self._get_fs_pair(path)
        return dvc_fs is not None and dvc_fs.isdvc(dvc_path, **kwargs)

    def ls(self, path, detail=True, **kwargs):
        """List *path*, merging DVC-tracked and repo-fs entries.

        Repo-fs entries are filtered through ``.dvcignore``. DVC metafiles
        and the ``.dvcignore`` file itself are hidden unless
        ``dvcfiles=True``. Entries whose ``info`` raises
        ``FileNotFoundError`` are skipped. Returns info dicts when *detail*
        is true, else paths.
        """
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        repo = dvc_fs.repo if dvc_fs else self.repo
        dvcignore = repo.dvcignore
        ignore_subrepos = kwargs.get("ignore_subrepos", True)

        names = set()
        if dvc_fs:
            with suppress(FileNotFoundError):
                for entry in dvc_fs.ls(dvc_path, detail=False):
                    names.add(dvc_fs.path.name(entry))

        if fs:
            with suppress(FileNotFoundError, NotADirectoryError):
                for entry in dvcignore.ls(
                    fs, fs_path, detail=False, ignore_subrepos=ignore_subrepos
                ):
                    names.add(fs.path.name(entry))

        if not kwargs.get("dvcfiles", False):
            # Imported once per call rather than once per entry.
            from dvc.dvcfile import is_valid_filename
            from dvc.ignore import DvcIgnore

            names = [
                name
                for name in names
                if not (
                    is_valid_filename(name)
                    or name == DvcIgnore.DVCIGNORE_FILE
                )
            ]

        infos = []
        paths = []
        for name in names:
            entry_path = self.path.join(path, name)
            try:
                info = self.info(entry_path, ignore_subrepos=ignore_subrepos)
            except FileNotFoundError:
                continue
            infos.append(info)
            paths.append(entry_path)

        return infos if detail else paths

    def get_file(self, rpath, lpath, callback=DEFAULT_CALLBACK, **kwargs):
        """Download *rpath* to *lpath*, preferring the repo fs over DVC."""
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(rpath)

        if fs:
            try:
                fs.get_file(fs_path, lpath, callback=callback, **kwargs)
                return
            except FileNotFoundError:
                if not dvc_fs:
                    raise

        dvc_fs.get_file(dvc_path, lpath, callback=callback, **kwargs)

    def info(self, path, **kwargs):
        """Return merged repo-fs/DVC info for *path* (see ``_merge_info``).

        Raises:
            FileNotFoundError: if neither filesystem knows *path*, or a repo-fs
                entry exists only as ignored by ``.dvcignore`` and DVC has
                nothing.
        """
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        # Plain git repos have no DvcFileSystem; fall back to the root repo.
        repo = dvc_fs.repo if dvc_fs else self.repo
        dvcignore = repo.dvcignore
        ignore_subrepos = kwargs.get("ignore_subrepos", True)

        dvc_info = None
        if dvc_fs:
            with suppress(FileNotFoundError):
                dvc_info = dvc_fs.info(dvc_path)

        fs_info = None
        if fs:
            try:
                fs_info = fs.info(fs_path)
                if dvcignore.is_ignored(
                    fs, fs_path, ignore_subrepos=ignore_subrepos
                ):
                    fs_info = None
            except (FileNotFoundError, NotADirectoryError):
                if not dvc_info:
                    raise

        # NOTE: if some parent in fs_path turns out to be a file, it means
        # that the whole repofs branch doesn't exist.
        if fs and not fs_info and dvc_info:
            for parent in fs.path.parents(fs_path):
                try:
                    if fs.info(parent)["type"] != "directory":
                        dvc_info = None
                        break
                except FileNotFoundError:
                    continue

        if not dvc_info and not fs_info:
            raise FileNotFoundError

        # Use ``repo`` (not ``dvc_fs.repo``): ``dvc_fs`` may be None here.
        info = _merge_info(repo, fs_info, dvc_info)
        info["name"] = path
        return info

    def checksum(self, path):
        """Return the checksum of *path*, preferring the repo fs over DVC."""
        fs, fs_path, dvc_fs, dvc_path = self._get_fs_pair(path)

        if fs:
            try:
                return fs.checksum(fs_path)
            except FileNotFoundError:
                if not dvc_fs:
                    raise

        return dvc_fs.checksum(dvc_path)


class RepoFileSystem(FileSystem):
    """``FileSystem`` adapter around a lazily built ``_RepoFileSystem``.

    The wrapped filesystem is constructed from ``self.fs_args`` the first
    time ``fs`` is accessed and cached afterwards. Its creation is wrapped
    with a ``threading.Lock`` through funcy's ``wrap_prop``. Every other
    member simply delegates to it.
    """

    scheme = "local"
    PARAM_CHECKSUM = "md5"

    def _prepare_credentials(self, **config):
        # Nothing to resolve: hand the config straight back.
        return config

    @wrap_prop(threading.Lock())
    @cached_property
    def fs(self):
        init_kwargs = self.fs_args
        return _RepoFileSystem(**init_kwargs)

    @property
    def repo(self):
        """Root repo of the wrapped filesystem."""
        return self.fs.repo

    @property
    def repo_url(self):
        """URL of the wrapped filesystem's root repo."""
        return self.fs.repo_url

    @property
    def config(self):
        """Config dict describing the wrapped filesystem."""
        return self.fs.config

    def isdvc(self, path, **kwargs):
        """Delegate DVC-tracking checks to the wrapped filesystem."""
        return self.fs.isdvc(path, **kwargs)