iterative/dvc

dvc/stage/cache.py

Summary: Maintainability B (4 hrs)

import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

from funcy import first

from dvc import fs
from dvc.config import RemoteConfigError
from dvc.exceptions import CollectCacheError, DvcException
from dvc.log import logger
from dvc.utils import dict_sha256, relpath

if TYPE_CHECKING:
    from dvc_objects.db import ObjectDB

logger = logger.getChild(__name__)


class RunCacheNotFoundError(DvcException):
    def __init__(self, stage):
        super().__init__(f"No run-cache for {stage.addressing}")


class RunCacheNotSupported(DvcException):
    pass


def _get_cache_hash(cache, key=False):
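    """Hash a lockfile-shaped dict, ignoring size/nfiles metadata.

    With ``key=True`` the outs are reduced to their paths, so the digest
    identifies the stage (cmd, deps, out paths) rather than one particular
    result of running it.
    """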
    from dvc_data.hashfile.meta import Meta

    if key:
        cache["outs"] = [out["path"] for out in cache.get("outs", [])]
    return dict_sha256(cache, exclude=[Meta.PARAM_SIZE, Meta.PARAM_NFILES])


def _can_hash(stage):
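    """Return True if the stage is deterministic enough to be run-cached.

    Callback and always-changed stages are excluded; the stage must have a
    cmd, deps and outs, every dep must be a hashed local file, and every out
    must be a non-persistent local path inside the repo.
    """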
    if stage.is_callback or stage.always_changed:
        return False

    if not all([stage.cmd, stage.deps, stage.outs]):
        return False

    for dep in stage.deps:
        if not (dep.protocol == "local" and dep.def_path and dep.get_hash()):
            return False

    for out in stage.outs:
        if (
            out.protocol != "local"
            or not out.def_path
            or out.persist
            or not out.is_in_repo
        ):
            return False

    return True


def _get_stage_hash(stage):
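    """Return the run-cache key for a cacheable stage."""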
    from .serialize import to_single_stage_lockfile

    assert _can_hash(stage)
    return _get_cache_hash(to_single_stage_lockfile(stage), key=True)


class StageCache:
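    """Run-cache for stages.

    Entries are stored as YAML lockfile snippets under
    ``<legacy cache>/runs/<key[:2]>/<key>/<value>``, where ``key`` identifies
    the stage (cmd, deps, out paths) and ``value`` hashes a concrete result
    (the full lockfile, including out checksums).
    """
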
    def __init__(self, repo):
        self.repo = repo
        self.cache_dir = os.path.join(self.repo.cache.legacy.path, "runs")

    def _get_cache_dir(self, key):
        return os.path.join(self.cache_dir, key[:2], key)

    def _get_cache_path(self, key, value):
        return os.path.join(self._get_cache_dir(key), value)

    def _load_cache(self, key, value):
        from voluptuous import Invalid

        from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA
        from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml

        path = self._get_cache_path(key, value)

        try:
            return COMPILED_LOCK_FILE_STAGE_SCHEMA(load_yaml(path))
        except FileNotFoundError:
            return None
        except (YAMLFileCorruptedError, Invalid):
            logger.warning("corrupted cache file '%s'.", relpath(path))
            os.unlink(path)
            return None

    def _load(self, stage):
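        """Return the newest cached lockfile entry matching ``stage``, if any."""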
        key = _get_stage_hash(stage)
        if not key:
            return None

        cache_dir = self._get_cache_dir(key)
        if not os.path.exists(cache_dir):
            return None

        newest_entry = first(
            sorted(
                os.listdir(cache_dir),
                key=lambda f: os.path.getmtime(os.path.join(cache_dir, f)),
                reverse=True,
            )
        )
        cache = self._load_cache(key, newest_entry)
        if cache:
            return cache

        return None

    def _create_stage(self, cache, wdir=None):
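        """Reconstruct a PipelineStage from a cached lockfile entry."""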
        from . import PipelineStage, create_stage
        from .loader import StageLoader

        stage = create_stage(
            PipelineStage,
            repo=self.repo,
            path="dvc.yaml",
            cmd=cache["cmd"],
            wdir=wdir,
            outs=[out["path"] for out in cache["outs"]],
        )
        StageLoader.fill_from_lock(stage, cache)
        return stage

    @contextmanager
    def _cache_type_copy(self):
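        """Temporarily force the 'copy' link type for the local and legacy caches."""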
        cache_types = self.repo.cache.local.cache_types
        legacy_cache_types = self.repo.cache.legacy.cache_types
        self.repo.cache.local.cache_types = ["copy"]
        self.repo.cache.legacy.cache_types = ["copy"]
        try:
            yield
        finally:
            self.repo.cache.local.cache_types = cache_types
            self.repo.cache.legacy.cache_types = legacy_cache_types

    def _uncached_outs(self, stage, cache):
        # NOTE: using temporary stage to avoid accidentally modifying original
        # stage and to work around `commit/checkout` not working for uncached
        # outputs.
        cached_stage = self._create_stage(cache, wdir=stage.wdir)

        outs_no_cache = [out.def_path for out in stage.outs if not out.use_cache]

        # NOTE: using copy link to make it look like a git-tracked file
        with self._cache_type_copy():
            for out in cached_stage.outs:
                if out.def_path in outs_no_cache and out.is_in_repo:
                    yield out

    def save(self, stage):
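        """Write the stage's lockfile entry into the run-cache.

        Outs marked ``cache: false`` are committed through a temporary stage
        (see ``_uncached_outs``) so they can still be checked out on restore.
        """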
        from .serialize import to_single_stage_lockfile

        if not _can_hash(stage):
            return

        cache_key = _get_stage_hash(stage)
        cache = to_single_stage_lockfile(stage)
        cache_value = _get_cache_hash(cache)

        existing_cache = self._load_cache(cache_key, cache_value)
        cache = existing_cache or cache

        for out in self._uncached_outs(stage, cache):
            out.commit()

        if existing_cache:
            return

        from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA
        from dvc.utils.serialize import dump_yaml

        # sanity check
        COMPILED_LOCK_FILE_STAGE_SCHEMA(cache)

        path = self._get_cache_path(cache_key, cache_value)
        local_fs = self.repo.cache.legacy.fs
        parent = local_fs.parent(path)
        self.repo.cache.legacy.makedirs(parent)
        tmp = local_fs.join(parent, fs.utils.tmp_fname())
        assert os.path.exists(parent)
        assert os.path.isdir(parent)
        dump_yaml(tmp, cache)
        self.repo.cache.legacy.move(tmp, path)

    def restore(self, stage, run_cache=True, pull=False, dry=False):  # noqa: C901
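        """Restore the stage's outputs from the run-cache instead of re-running it.

        Raises RunCacheNotFoundError if the stage is not cacheable or no
        matching entry with cached outputs can be found (or pulled).
        """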
        from .serialize import to_single_stage_lockfile

        if not _can_hash(stage):
            raise RunCacheNotFoundError(stage)

        if (
            not stage.changed_stage()
            and stage.deps_cached()
            and all(bool(out.hash_info) for out in stage.outs)
        ):
            cache = to_single_stage_lockfile(stage)
        else:
            if not run_cache:  # backward compatibility
                raise RunCacheNotFoundError(stage)
            if not dry:
                stage.save_deps()
            cache = self._load(stage)
            if not cache:
                raise RunCacheNotFoundError(stage)

        cached_stage = self._create_stage(cache, wdir=stage.wdir)

        if pull and not dry:
            try:
                for objs in cached_stage.get_used_objs().values():
                    self.repo.cloud.pull(objs)
            except CollectCacheError as exc:
                raise RunCacheNotFoundError(stage) from exc

        if not cached_stage.outs_cached():
            raise RunCacheNotFoundError(stage)

        logger.info(
            "Stage '%s' is cached - skipping run, checking out outputs",
            stage.addressing,
        )
        if not dry:
            cached_stage.checkout()

    def transfer(self, from_odb, to_odb, force=True):
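        """Copy run-cache files from ``from_odb`` to ``to_odb``.

        Returns a list of ``(key, value)`` pairs that were transferred.
        HTTP-backed ODBs are rejected with RunCacheNotSupported.
        """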
        from dvc.fs import HTTPFileSystem, LocalFileSystem
        from dvc.fs.callbacks import TqdmCallback

        from_fs = from_odb.fs
        to_fs = to_odb.fs
        func = fs.generic.log_exceptions(fs.generic.copy)
        runs = from_fs.join(from_odb.path, "runs")

        http_odb = next(
            (odb for odb in (from_odb, to_odb) if isinstance(odb.fs, HTTPFileSystem)),
            None,
        )
        if http_odb:
            path = http_odb.path
            message = f"run-cache is not supported for http filesystem: {path}"
            raise RunCacheNotSupported(message)

        ret: list[tuple[str, str]] = []
        if not from_fs.exists(runs):
            return ret

        for src in from_fs.find(runs):
            rel = from_fs.relpath(src, from_odb.path)
            if not isinstance(to_fs, LocalFileSystem):
                rel = from_fs.as_posix(rel)

            dst = to_fs.join(to_odb.path, rel)
            key = to_fs.parent(dst)

            # check if any build cache already exists for this key
            # TODO: check if MaxKeys=1 or something like that applies
            # or otherwise this will take a lot of time!
            if not force and to_fs.exists(key) and first(to_fs.find(key)):
                continue

            src_name = from_fs.name(src)
            parent_name = from_fs.name(from_fs.parent(src))
            with TqdmCallback(desc=src_name, bytes=True) as cb:
                func(from_fs, src, to_fs, dst, callback=cb)
            ret.append((parent_name, src_name))
        return ret

    def push(self, remote: Optional[str], odb: Optional["ObjectDB"] = None):
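        """Upload the local run-cache to ``remote`` (or an explicit ODB)."""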
        try:
            dest_odb = odb or self.repo.cloud.get_remote_odb(
                remote, "push --run-cache", hash_name="md5-dos2unix"
            )
        except RemoteConfigError as e:
            raise RunCacheNotSupported(e) from e
        return self.transfer(self.repo.cache.legacy, dest_odb)

    def pull(self, remote: Optional[str], odb: Optional["ObjectDB"] = None):
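        """Download the run-cache from ``remote`` (or an explicit ODB)."""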
        try:
            odb = odb or self.repo.cloud.get_remote_odb(
                remote, "fetch --run-cache", hash_name="md5-dos2unix"
            )
        except RemoteConfigError as e:
            raise RunCacheNotSupported(e) from e
        return self.transfer(odb, self.repo.cache.legacy)

    def get_used_objs(self, used_run_cache, *args, **kwargs):
        """Return used cache for the specified run-cached stages."""
        from collections import defaultdict

        used_objs = defaultdict(set)
        for key, value in used_run_cache:
            entry = self._load_cache(key, value)
            if not entry:
                continue
            stage = self._create_stage(entry)
            for odb, objs in stage.get_used_objs(*args, **kwargs).items():
                used_objs[odb].update(objs)
        return used_objs