iterative/dvc

View on GitHub
dvc/dependency/__init__.py

Summary

Maintainability
A
2 hrs
Test Coverage
from collections import defaultdict
from typing import Any, Mapping

from dvc.output import ARTIFACT_SCHEMA, Output

from .base import Dependency
from .param import ParamsDependency
from .repo import RepoDependency

# NOTE: schema for dependencies is basically the same as for outputs, but
# without output-specific entries like 'cache' (whether or not output is
# cached, see -o and -O flags for `dvc run`) and 'metric' (whether or not
# output is a metrics file and how to parse it, see `-M` flag for `dvc run`).
SCHEMA: Mapping[str, Any] = {
    **ARTIFACT_SCHEMA,
    **RepoDependency.REPO_SCHEMA,
    **ParamsDependency.PARAM_SCHEMA,
}


def _get(stage, p, info):
    if info and info.get(RepoDependency.PARAM_REPO):
        repo = info.pop(RepoDependency.PARAM_REPO)
        return RepoDependency(repo, stage, p, info)

    if info and info.get(ParamsDependency.PARAM_PARAMS):
        params = info.pop(ParamsDependency.PARAM_PARAMS)
        return ParamsDependency(stage, p, params)

    return Dependency(stage, p, info)


def loadd_from(stage, d_list):
    ret = []
    for d in d_list:
        p = d.pop(Output.PARAM_PATH, None)
        ret.append(_get(stage, p, d))
    return ret


def loads_from(stage, s_list, erepo=None):
    assert isinstance(s_list, list)
    info = {RepoDependency.PARAM_REPO: erepo} if erepo else {}
    return [_get(stage, s, info.copy()) for s in s_list]


def _merge_params(s_list):
    d = defaultdict(list)
    default_file = ParamsDependency.DEFAULT_PARAMS_FILE

    # figure out completely tracked params file, and ignore specific keys
    wholly_tracked = set()
    for key in s_list:
        if not isinstance(key, dict):
            continue
        wholly_tracked.update(k for k, params in key.items() if not params)

    for key in s_list:
        if isinstance(key, str):
            if default_file not in wholly_tracked:
                d[default_file].append(key)
            continue

        if not isinstance(key, dict):
            msg = "Only list of str/dict is supported. Got: "
            msg += f"'{type(key).__name__}'."
            raise ValueError(msg)

        for k, params in key.items():
            if k in wholly_tracked:
                d[k] = []
                continue
            if not isinstance(params, list):
                msg = "Expected list of params for custom params file "
                msg += f"'{k}', got '{type(params).__name__}'."
                raise ValueError(msg)
            d[k].extend(params)
    return d


def loads_params(stage, s_list):
    d = _merge_params(s_list)
    return [
        ParamsDependency(stage, path, params) for path, params in d.items()
    ]