iterative/dvc

View on GitHub
dvc/dependency/param.py

Summary

Maintainability
A
2 hrs
Test Coverage
import os
import typing
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Optional

import dpath

from dvc.exceptions import DvcException
from dvc.log import logger
from dvc.utils.serialize import ParseError, load_path
from dvc_data.hashfile.hash_info import HashInfo

from .base import Dependency

if TYPE_CHECKING:
    from dvc.fs import FileSystem

# Rebind the logger imported from dvc.log to a child named after this module.
logger = logger.getChild(__name__)


class MissingParamsError(DvcException):
    """Raised when tracked parameters are not found in the parameters file."""


class MissingParamsFile(DvcException):
    """Raised when the parameters file does not exist."""


class ParamsIsADirectoryError(DvcException):
    """Raised when the parameters path is a directory instead of a file."""


class BadParamFileError(DvcException):
    """Raised when the parameters file cannot be parsed."""


def read_param_file(
    fs: "FileSystem",
    path: str,
    key_paths: Optional[list[str]] = None,
    flatten: bool = False,
    **load_kwargs,
) -> Any:
    """Load a parameters file and optionally select a subset of its keys.

    Args:
        fs: filesystem passed to ``load_path`` to read ``path``.
        path: location of the parameters file.
        key_paths: dot-separated key paths to select. When empty or ``None``
            the loaded content is returned as is.
        flatten: if true, return a flat dict mapping each requested key path
            to its value, skipping key paths that raise ``KeyError``. If
            false, return a nested dict merged from deep copies of what
            ``dpath.search`` finds for each distinct key path.
        **load_kwargs: forwarded to ``load_path``.

    Returns:
        The whole loaded content, or a dict limited to ``key_paths``.
    """
    loaded = load_path(path, fs, **load_kwargs)
    if not key_paths:
        return loaded

    if not flatten:
        from copy import deepcopy

        from dpath import merge
        from funcy import distinct

        nested: dict = {}
        for key_path in distinct(key_paths):
            matched = dpath.search(loaded, key_path, separator=".")
            merge(nested, deepcopy(matched), separator=".")
        return nested

    flat = {}
    for key_path in key_paths:
        try:
            value = dpath.get(loaded, key_path, separator=".")
        except KeyError:
            # Requested key is absent from the file: leave it out.
            continue
        flat[key_path] = value
    return flat


class ParamsDependency(Dependency):
    """Stage dependency on values stored in a parameters file.

    ``self.params`` lists the tracked parameter keys; when it is empty the
    whole file is tracked. The recorded values live in ``self.hash_info``
    under the ``"params"`` hash name.
    """

    PARAM_PARAMS = "params"
    DEFAULT_PARAMS_FILE = "params.yaml"

    def __init__(self, stage, path, params=None, repo=None):
        """Set up the dependency.

        Args:
            stage: owning stage; ``stage.repo`` is used when ``repo`` is falsy.
            path: parameters file path; when falsy, ``params.yaml`` under
                ``repo.root_dir`` is used.
            params: parameter keys to track. A dict is also accepted, in
                which case its keys are tracked and the dict itself becomes
                the recorded value in ``hash_info``.
            repo: repository object; defaults to ``stage.repo``.
        """
        self.params = list(params) if params else []
        hash_info = HashInfo()
        if isinstance(params, dict):
            hash_info = HashInfo(self.PARAM_PARAMS, params)  # type: ignore[arg-type]
        repo = repo or stage.repo
        path = path or os.path.join(repo.root_dir, self.DEFAULT_PARAMS_FILE)
        super().__init__(stage, path, repo=repo)
        # Assigned after the base initializer has run.
        self.hash_name = self.PARAM_PARAMS
        self.hash_info = hash_info

    def dumpd(self, **kwargs):
        """Return the serialized form of this dependency.

        ``kwargs`` are accepted but not forwarded to the base implementation.
        When no hash info is set, the tracked keys (or ``{}`` when there are
        none) are stored under the ``"params"`` key.
        """
        ret = super().dumpd()
        if not self.hash_info:
            ret[self.PARAM_PARAMS] = self.params or {}
        return ret

    def fill_values(self, values=None):
        """Load params values dynamically.

        Does nothing when ``values`` is ``None``. Otherwise ``hash_info`` is
        replaced with the subset of ``values`` for the tracked keys, or with
        all of ``values`` when no keys are tracked explicitly.
        """
        if values is None:
            return

        info = {}
        if not self.params:
            info.update(values)
        for param in self.params:
            if param in values:
                info[param] = values[param]
        self.hash_info = HashInfo(self.PARAM_PARAMS, info)  # type: ignore[arg-type]

    def read_params(
        self, flatten: bool = True, **kwargs: typing.Any
    ) -> dict[str, typing.Any]:
        """Read the tracked parameters from the parameters file.

        Args:
            flatten: forwarded to ``read_param_file``.
            **kwargs: accepted but not used by this method.

        Returns:
            The values read from the file, or ``{}`` if the file is missing.

        Raises:
            ParamsIsADirectoryError: if the path is a directory.
            BadParamFileError: if the file cannot be parsed.
        """
        try:
            self.validate_filepath()
        except MissingParamsFile:
            return {}

        try:
            return read_param_file(
                self.repo.fs,
                self.fs_path,
                list(self.params) if self.params else None,
                flatten=flatten,
            )
        except ParseError as exc:
            raise BadParamFileError(f"Unable to read parameters from '{self}'") from exc

    def workspace_status(self):
        """Compare the parameters file with the recorded values.

        Returns:
            ``{str(self): "deleted"}`` if the file does not exist,
            ``{str(self): "new"}`` if no values are recorded, otherwise a
            mapping ``{str(self): {param: state}}`` where ``state`` is
            ``"deleted"``, ``"new"`` or ``"modified"``. The mapping is empty
            when nothing changed.
        """
        if not self.exists:
            return {str(self): "deleted"}
        if self.hash_info.value is None:
            return {str(self): "new"}

        from funcy import ldistinct

        status: dict[str, Any] = defaultdict(dict)
        info = self.hash_info.value if self.hash_info else {}
        assert isinstance(info, dict)
        actual = self.read_params()

        # NOTE: we want to preserve the order of params as specified in the
        # status. In case of tracking the whole file, the order is top-level
        # keys in the file and then the keys in the `info` from `dvc.lock`
        # (which are alphabetically sorted).
        params = self.params or ldistinct([*actual.keys(), *info.keys()])
        for param in params:
            if param not in actual:
                st = "deleted"
            elif param not in info:
                st = "new"
            elif actual[param] != info[param]:
                # A tuple read from the file that equals the recorded list
                # element-wise is not reported as a modification.
                if (
                    isinstance(actual[param], tuple)
                    and list(actual[param]) == info[param]
                ):
                    continue
                st = "modified"
            else:
                continue

            status[str(self)][param] = st

        return status

    def status(self):
        """Return the same result as ``workspace_status``."""
        return self.workspace_status()

    def validate_filepath(self):
        """Check that the parameters path exists and is not a directory.

        Raises:
            MissingParamsFile: if the path does not exist.
            ParamsIsADirectoryError: if the path is a directory.
        """
        if not self.exists:
            raise MissingParamsFile(f"Parameters file '{self}' does not exist")
        if self.isdir():
            raise ParamsIsADirectoryError(
                f"'{self}' is a directory, expected a parameters file"
            )

    def get_hash(self):
        """Build the ``HashInfo`` for the current values in the file.

        Raises:
            MissingParamsError: if any tracked key is absent from the values
                read. A missing file reads as ``{}``, so in that case every
                tracked key is reported.
        """
        info = self.read_params()

        missing_params = set(self.params) - set(info.keys())
        if missing_params:
            # Sort so the message does not depend on set iteration order.
            missing = ", ".join(sorted(missing_params))
            raise MissingParamsError(
                f"Parameters '{missing}' are missing from '{self}'."
            )

        return HashInfo(self.PARAM_PARAMS, info)  # type: ignore[arg-type]

    def save(self):
        """Record the current parameter values in ``hash_info``.

        Raises:
            DoesNotExistError: if the path does not exist.
            IsNotFileOrDirError: if the path is neither a file nor a directory.
        """
        if not self.exists:
            raise self.DoesNotExistError(self)

        if not self.isfile() and not self.isdir():
            raise self.IsNotFileOrDirError(self)

        self.ignore()
        self.hash_info = self.get_hash()