View on GitHub


1 day
Test Coverage
import itertools
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from copy import copy
from functools import partial
from typing import TYPE_CHECKING, Optional

from dvc.objects.errors import ObjectDBPermissionError, ObjectFormatError
from dvc.objects.file import HashFile

if TYPE_CHECKING:
    from typing import Tuple

    from dvc.fs.base import AnyFSPath, FileSystem
    from dvc.hash_info import HashInfo

logger = logging.getLogger(__name__)

def noop(*args, **kwargs):
    """Do nothing; default progress callback accepting any arguments."""
def wrap_iter(iterable, callback):
    """Yield items from *iterable*, calling *callback* with the 1-based
    count of items yielded so far after each item (used for progress
    reporting).
    """
    for index, item in enumerate(iterable, start=1):
        yield item
        # Report progress only after the item has been consumed downstream.
        callback(index)

class ObjectDB:

    # Verify objects on `add` unless overridden per-instance (see `__init__`).
    # Referenced as `self.DEFAULT_VERIFY` in `__init__` but was undefined here.
    DEFAULT_VERIFY = False
    # Preferred cache link types, in order of preference.
    DEFAULT_CACHE_TYPES = ["copy"]
    # File mode applied to cache files; None means leave mode untouched.
    CACHE_MODE: Optional[int] = None

    def __init__(self, fs: "FileSystem", path: str, **config):
        """Create an object database rooted at *path* on filesystem *fs*.

        Recognized *config* keys: ``state``, ``verify``, ``type``,
        ``slow_link_warning``, ``tmp_dir``, ``read_only``.
        """
        from dvc.state import StateNoop

        opt = config.get
        self.fs = fs
        self.fs_path = path
        self.state = opt("state", StateNoop())
        self.verify = opt("verify", self.DEFAULT_VERIFY)
        self.cache_types = opt("type") or copy(self.DEFAULT_CACHE_TYPES)
        self.slow_link_warning = opt("slow_link_warning", True)
        self.tmp_dir = opt("tmp_dir")
        self.read_only = opt("read_only", False)

    def config(self):
        """Return a dict of the options this ODB was constructed with.

        Suitable for re-passing as ``**config`` to `__init__`.
        """
        return {
            "state": self.state,
            "verify": self.verify,
            "type": self.cache_types,
            "slow_link_warning": self.slow_link_warning,
            "tmp_dir": self.tmp_dir,
            "read_only": self.read_only,
        }

    def __eq__(self, other):
        """Two ODBs are equal when they share filesystem, root path and
        read-only mode."""
        return (
            self.fs == other.fs
            and self.fs_path == other.fs_path
            and self.read_only == other.read_only
        )

    def __hash__(self):
        """Hash derived from the filesystem scheme and the root path."""
        key = (self.fs.scheme, self.fs_path)
        return hash(key)

    def exists(self, hash_info: "HashInfo"):
        """Return True if the object for *hash_info* is present in this ODB."""
        fs_path = self.hash_to_path(hash_info.value)
        return self.fs.exists(fs_path)

    def move(self, from_info, to_info):
        """Move a path within the underlying filesystem (thin delegation)."""
        self.fs.move(from_info, to_info)

    def makedirs(self, fs_path):
        """Create directory *fs_path* (and parents) on the underlying fs."""
        self.fs.makedirs(fs_path)

    def get(self, hash_info: "HashInfo"):
        """get raw object"""
        # Wrap the computed cache path into a HashFile handle; no I/O here.
        return HashFile(
            self.hash_to_path(hash_info.value),
            self.fs,
            hash_info,
        )

    def _add_file(
        self,
        from_fs: "FileSystem",
        from_info: "AnyFSPath",
        to_info: "AnyFSPath",
        _hash_info: "HashInfo",
        hardlink: bool = False,
    ):
        """Transfer a single file from *from_fs* into this ODB at *to_info*.

        *_hash_info* is accepted for interface compatibility with subclass
        overrides; the base implementation does not use it.
        """
        from dvc import fs

        # Ensure the two-char prefix subdirectory exists before transfer.
        self.makedirs(self.fs.path.parent(to_info))
        return fs.utils.transfer(
            from_fs,
            from_info,
            self.fs,
            to_info,
            hardlink=hardlink,
        )

    def add(
        self,
        fs_path: "AnyFSPath",
        fs: "FileSystem",
        hash_info: "HashInfo",
        hardlink: bool = False,
        verify: Optional[bool] = None,
    ):
        """Add the file at (*fs*, *fs_path*) into this ODB under *hash_info*.

        Raises ObjectDBPermissionError if the ODB is read-only. When
        *verify* is None, the instance-level `self.verify` setting is used.
        """
        if self.read_only:
            raise ObjectDBPermissionError("Cannot add to read-only ODB")

        if verify is None:
            verify = self.verify
        try:
            self.check(hash_info, check_hash=verify)
            # Object already cached (and valid, if verifying) — nothing to do.
            return
        except (ObjectFormatError, FileNotFoundError):
            pass

        cache_fs_path = self.hash_to_path(hash_info.value)
        self._add_file(
            fs, fs_path, cache_fs_path, hash_info, hardlink=hardlink
        )

        try:
            if verify:
                self.check(hash_info, check_hash=True)
            self.protect(cache_fs_path)
            self.state.save(cache_fs_path, self.fs, hash_info)
        except (ObjectFormatError, FileNotFoundError):
            # Best-effort post-processing: a failed verify/save leaves the
            # object to be re-checked on next access.
            pass

    def hash_to_path(self, hash_):
        """Map a hash value to its cache path: ``<root>/<first 2 chars>/<rest>``."""
        prefix, suffix = hash_[:2], hash_[2:]
        return self.fs.path.join(self.fs_path, prefix, suffix)

    def protect(self, fs_path):  # pylint: disable=unused-argument
        """Mark a cache file as protected; no-op in the base class."""

    def is_protected(self, fs_path):  # pylint: disable=unused-argument
        """Whether *fs_path* is protected; the base class never protects."""
        return False

    def unprotect(self, fs_path):  # pylint: disable=unused-argument
        """Remove protection from a cache file; no-op in the base class."""

    def set_exec(self, fs_path):  # pylint: disable=unused-argument
        """Mark a cache file as executable; no-op in the base class."""

    def check(
        self,
        hash_info: "HashInfo",
        check_hash: bool = True,
    ):
        """Compare the given hash with the (corresponding) actual one if
        check_hash is specified, or just verify the existence of the cache
        files on the filesystem.

        - Use `State` as a cache for computed hashes
            + The entries are invalidated by taking into account the following:
                * mtime
                * inode
                * size
                * hash

        - Remove the file from cache if it doesn't match the actual hash

        Raises ObjectFormatError (re-raised) on a corrupted cache file and
        FileNotFoundError when the object is missing.
        """
        obj = self.get(hash_info)
        if self.is_protected(obj.fs_path):
            logger.trace(  # type: ignore[attr-defined]
                "Assuming '%s' is unchanged since it is read-only",
                obj.fs_path,
            )
            return

        try:
            obj.check(self, check_hash=check_hash)
        except ObjectFormatError:
            logger.warning("corrupted cache file '%s'.", obj.fs_path)
            # Best-effort removal of the corrupted file before re-raising.
            with suppress(FileNotFoundError):
                self.fs.remove(obj.fs_path)
            raise

        if check_hash:
            # making cache file read-only so we don't need to check it
            # next time
            self.protect(obj.fs_path)

    def _list_paths(self, prefix: Optional[str] = None):
        """Yield file paths under this ODB, optionally narrowed by a hash
        *prefix* (at least the first two characters select a subdirectory).
        """
        prefix = prefix or ""
        parts: "Tuple[str, ...]" = (self.fs_path,)
        if prefix:
            # First two hash characters name the subdirectory.
            parts = *parts, prefix[:2]
        if len(prefix) > 2:
            # Remaining characters form a filename prefix inside it.
            parts = *parts, prefix[2:]
        yield from self.fs.find(self.fs.path.join(*parts), prefix=bool(prefix))

    def _path_to_hash(self, path):
        """Recover the hash value from a cache path ``.../<xx>/<rest>``.

        Raises ValueError when *path* does not match the expected layout.
        """
        # The hash is split across the last two path components.
        parts = self.fs.path.parts(path)[-2:]

        if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2):
            raise ValueError(f"Bad cache file path '{path}'")

        return "".join(parts)

    def _list_hashes(self, prefix=None):
        """Iterate over hashes in this fs.

        If `prefix` is specified, only hashes which begin with `prefix`
        will be returned.
        """
        for path in self._list_paths(prefix):
            try:
                yield self._path_to_hash(path)
            except ValueError:
                logger.debug(
                    "'%s' doesn't look like a cache file, skipping", path
                )

    def _hashes_with_limit(self, limit, prefix=None):
        """Yield at most *limit* hashes (optionally filtered by *prefix*),
        then stop, logging that remaining results were skipped.
        """
        count = 0
        for hash_ in self._list_hashes(prefix):
            yield hash_
            count += 1
            if count > limit:
                logger.debug(
                    "`_list_hashes()` returned max '{}' hashes, "
                    "skipping remaining results".format(limit)
                )
                return

    def _max_estimation_size(self, hashes):
        """Return the largest remote size for which the traverse method is
        still expected to be worthwhile for this many *hashes*.
        """
        # Max remote size allowed for us to use traverse method
        return max(
            self.fs.TRAVERSE_THRESHOLD_SIZE,
            len(hashes)
            / self.fs.TRAVERSE_WEIGHT_MULTIPLIER
            * self.fs.LIST_OBJECT_PAGE_SIZE,
        )

    def _estimate_remote_size(self, hashes=None, progress=noop):
        """Estimate fs size based on number of entries beginning with
        "00..." prefix.

        Takes a progress callback that returns current_estimated_size.

        Returns a tuple of (estimated_size, sampled_remote_hashes).
        """
        prefix = "0" * self.fs.TRAVERSE_PREFIX_LEN
        total_prefixes = pow(16, self.fs.TRAVERSE_PREFIX_LEN)
        if hashes:
            # Stop sampling once traverse can no longer win for this query.
            max_hashes = self._max_estimation_size(hashes)
        else:
            max_hashes = None

        def iter_with_pbar(hashes):
            total = 0
            for hash_ in hashes:
                # Each sampled hash stands for `total_prefixes` hashes in the
                # full remote, assuming an even hash distribution.
                total += total_prefixes
                progress(total)
                yield hash_

        if max_hashes:
            hashes = self._hashes_with_limit(
                max_hashes / total_prefixes, prefix
            )
        else:
            hashes = self._list_hashes(prefix)

        remote_hashes = set(iter_with_pbar(hashes))
        if remote_hashes:
            remote_size = total_prefixes * len(remote_hashes)
        else:
            remote_size = total_prefixes
        logger.debug(f"Estimated remote size: {remote_size} files")
        return remote_size, remote_hashes

    def _list_hashes_traverse(self, remote_size, remote_hashes, jobs=None):
        """Iterate over all hashes found in this fs.
        Hashes are fetched in parallel according to prefix, except in
        cases where the remote size is very small.

        All hashes from the remote (including any from the size
        estimation step passed via the `remote_hashes` argument) will be
        returned.

        NOTE: For large remotes the list of hashes will be very
        big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list)
        and we don't really need all of it at the same time, so it makes
        sense to use a generator to gradually iterate over it, without
        keeping all of it in memory.
        """
        from funcy import collecting

        num_pages = remote_size / self.fs.LIST_OBJECT_PAGE_SIZE
        if num_pages < 256 / self.fs.jobs:
            # Fetching prefixes in parallel requires at least 255 more
            # requests, for small enough remotes it will be faster to fetch
            # entire cache without splitting it into prefixes.
            # NOTE: this ends up re-fetching hashes that were already
            # fetched during remote size estimation
            traverse_prefixes = [None]
        else:
            # Hashes sampled during estimation (the "00..." prefix) are
            # already known — yield them and skip that prefix below.
            yield from remote_hashes
            traverse_prefixes = [f"{i:02x}" for i in range(1, 256)]
            if self.fs.TRAVERSE_PREFIX_LEN > 2:
                traverse_prefixes += [
                    "{0:0{1}x}".format(i, self.fs.TRAVERSE_PREFIX_LEN)
                    for i in range(1, pow(16, self.fs.TRAVERSE_PREFIX_LEN - 2))
                ]

        # `collecting` makes each per-prefix listing return a list, so the
        # executor can hand back completed chunks as they finish.
        list_hashes = collecting(self._list_hashes)
        with ThreadPoolExecutor(max_workers=jobs or self.fs.jobs) as executor:
            in_remote = executor.map(list_hashes, traverse_prefixes)
            yield from itertools.chain.from_iterable(in_remote)

    def all(self, jobs=None):
        """Iterate over all hashes in this fs.

        Hashes will be fetched in parallel threads according to prefix
        (except for small remotes) and a progress bar will be displayed.
        """
        if not self.fs.CAN_TRAVERSE:
            # Filesystem cannot be listed by prefix — fall back to a plain walk.
            return self._list_hashes()

        remote_size, remote_hashes = self._estimate_remote_size()
        return self._list_hashes_traverse(
            remote_size, remote_hashes, jobs=jobs
        )

    def _remove_unpacked_dir(self, hash_):
        """Remove any unpacked-dir state for *hash_*; no-op in the base class."""

    def list_hashes_exists(self, hashes, jobs=None):
        """Return list of the specified hashes which exist in this fs.
        Hashes will be queried individually.
        """
        logger.debug(f"Querying {len(hashes)} hashes via object_exists")
        with ThreadPoolExecutor(max_workers=jobs or self.fs.jobs) as executor:
            fs_paths = map(self.hash_to_path, hashes)
            in_remote = executor.map(self.fs.exists, fs_paths)
            # Keep only the hashes whose corresponding existence check is True.
            yield from itertools.compress(hashes, in_remote)

    def hashes_exist(self, hashes, jobs=None, progress=noop):
        """Check if the given hashes are stored in the remote.

        There are two ways of performing this check:

        - Traverse method: Get a list of all the files in the remote
            (traversing the cache directory) and compare it with
            the given hashes. Cache entries will be retrieved in parallel
            threads according to prefix (i.e. entries starting with, "00...",
            "01...", and so on) and a progress bar will be displayed.

        - Exists method: For each given hash, run the `exists`
            method and filter the hashes that aren't on the remote.
            This is done in parallel threads.
            It also shows a progress bar when performing the check.

        The reason for such an odd logic is that most of the remotes
        take much shorter time to just retrieve everything they have under
        a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can
        check if particular file exists much quicker, use their own
        implementation of hashes_exist (see ssh, local).

        Which method to use will be automatically determined after estimating
        the size of the remote cache, and comparing the estimated size with
        len(hashes). To estimate the size of the remote cache, we fetch
        a small subset of cache entries (i.e. entries starting with "00...").
        Based on the number of entries in that subset, the size of the full
        cache can be estimated, since the cache is evenly distributed according
        to hash.

        Takes a callback that returns value in the format of:
        (phase, total, current). The phase can be {"estimating", "querying"}.

        Returns:
            A list with hashes that were found in the remote
        """
        # Remotes which do not use traverse prefix should override
        # hashes_exist() (see ssh, local)
        assert self.fs.TRAVERSE_PREFIX_LEN >= 2

        # During the tests, for ensuring that the traverse behavior
        # is working we turn on this option. It will ensure the
        # _list_hashes_traverse() is called.
        always_traverse = getattr(self.fs, "_ALWAYS_TRAVERSE", False)

        hashes = set(hashes)
        if (
            len(hashes) == 1 or not self.fs.CAN_TRAVERSE
        ) and not always_traverse:
            remote_hashes = self.list_hashes_exists(hashes, jobs)
            callback = partial(progress, "querying", len(hashes))
            return list(wrap_iter(remote_hashes, callback))

        # Max remote size allowed for us to use traverse method
        estimator_cb = partial(progress, "estimating", None)
        remote_size, remote_hashes = self._estimate_remote_size(
            hashes, progress=estimator_cb
        )

        traverse_pages = remote_size / self.fs.LIST_OBJECT_PAGE_SIZE
        # For sufficiently large remotes, traverse must be weighted to account
        # for performance overhead from large lists/sets.
        # From testing with S3, for remotes with 1M+ files, object_exists is
        # faster until len(hashes) is at least 10k~100k
        if remote_size > self.fs.TRAVERSE_THRESHOLD_SIZE:
            traverse_weight = (
                traverse_pages * self.fs.TRAVERSE_WEIGHT_MULTIPLIER
            )
        else:
            traverse_weight = traverse_pages
        if len(hashes) < traverse_weight and not always_traverse:
            logger.debug(
                "Large remote ('{}' hashes < '{}' traverse weight), "
                "using object_exists for remaining hashes".format(
                    len(hashes), traverse_weight
                )
            )
            # Hashes already seen during estimation are known to exist;
            # only the remainder needs individual existence checks.
            remaining_hashes = hashes - remote_hashes
            ret = list(hashes & remote_hashes)
            callback = partial(progress, "querying", len(remaining_hashes))
            ret.extend(
                wrap_iter(
                    self.list_hashes_exists(remaining_hashes, jobs), callback
                )
            )
            return ret

        logger.debug(f"Querying '{len(hashes)}' hashes via traverse")
        remote_hashes = self._list_hashes_traverse(
            remote_size, remote_hashes, jobs=jobs
        )
        callback = partial(progress, "querying", remote_size)
        return list(hashes & set(wrap_iter(remote_hashes, callback)))