iterative/dvc

View on GitHub
dvc/repo/experiments/run.py

Summary

Maintainability
B
5 hrs
Test Coverage
from collections.abc import Iterable
from typing import Optional

from dvc.dependency.param import ParamsDependency
from dvc.exceptions import InvalidArgumentError
from dvc.log import logger
from dvc.repo import locked
from dvc.ui import ui
from dvc.utils.cli_parse import to_path_overrides

# Module logger: a child (named after this module) of the `logger` imported
# from `dvc.log`. This rebinding intentionally shadows the imported name.
logger = logger.getChild(__name__)


@locked
def run(
    repo,
    targets: Optional[Iterable[str]] = None,
    params: Optional[Iterable[str]] = None,
    run_all: bool = False,
    jobs: int = 1,
    tmp_dir: bool = False,
    queue: bool = False,
    copy_paths: Optional[Iterable[str]] = None,
    message: Optional[str] = None,
    **kwargs,
) -> dict[str, str]:
    """Reproduce the specified targets as an experiment.

    Accepts the same additional kwargs as Repo.reproduce.

    Args:
        repo: The DVC repo to run experiments in.
        targets: Stage targets to reproduce.
        params: Parameter override strings; parsed into per-file overrides
            with `to_path_overrides`.
        run_all: If set, delegate to `repo.experiments.reproduce_celery`
            with `jobs` and ignore every other argument.
        jobs: Passed to `reproduce_celery` when `run_all` is set.
        tmp_dir: Run in a temporary directory. Forced on when the `dry`
            kwarg is truthy.
        queue: Enqueue the experiment(s) on the celery queue instead of
            running immediately.
        copy_paths: Forwarded to `reproduce_one` / `queue_one`.
        message: Forwarded to `reproduce_one` / `queue_one`.
        **kwargs: Forwarded to `reproduce_one` / `queue_one`. When queueing
            a Hydra sweep, a given `name` is used as a prefix and each
            queued experiment is named `<name>-1`, `<name>-2`, ...

    Returns:
        A dict mapping new experiment SHAs to the results of `repro` for
        that experiment. Queued runs return an empty dict, since nothing is
        executed yet.

    Raises:
        InvalidArgumentError: If `params` contains Hydra sweep overrides but
            `queue` is not set.
    """
    if kwargs.get("dry"):
        # A dry run is always forced into a temporary directory.
        tmp_dir = True

    if run_all:
        return repo.experiments.reproduce_celery(jobs=jobs)

    path_overrides, hydra_sweep = _parse_param_overrides(
        repo, params, stage_untracked=tmp_dir or queue
    )
    if hydra_sweep and not queue:
        raise InvalidArgumentError(
            "Sweep overrides can't be used without `--queue`"
        )

    hydra_enabled = repo.config.get("hydra", {}).get("enabled", False)
    hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE
    if hydra_enabled and hydra_output_file not in path_overrides:
        # Force `_update_params` even if `--set-param` was not used
        path_overrides[hydra_output_file] = []

    if not queue:
        return repo.experiments.reproduce_one(
            targets=targets,
            params=path_overrides,
            tmp_dir=tmp_dir,
            copy_paths=copy_paths,
            message=message,
            **kwargs,
        )

    _queue_experiments(
        repo,
        path_overrides,
        hydra_sweep,
        targets=targets,
        copy_paths=copy_paths,
        message=message,
        **kwargs,
    )
    return {}


def _parse_param_overrides(repo, params, stage_untracked):
    """Split `params` into per-file overrides and detect Hydra sweeps.

    Returns a `(path_overrides, hydra_sweep)` tuple. A new empty dict and
    `False` are returned when `params` is empty or None. When
    `stage_untracked` is truthy, override target files that are untracked
    in git are added to git first.
    """
    if not params:
        return {}, False

    # Imported lazily, mirroring the rest of this module's hydra usage.
    from dvc.utils.hydra import to_hydra_overrides

    path_overrides = to_path_overrides(params)

    if stage_untracked:
        _stage_untracked_param_files(repo, path_overrides)

    hydra_sweep = any(
        override.is_sweep_override()
        for overrides in path_overrides.values()
        for override in to_hydra_overrides(overrides)
    )
    return path_overrides, hydra_sweep


def _stage_untracked_param_files(repo, path_overrides) -> None:
    """Add to git every override target file that git does not track yet.

    NOTE(review): only called for tmp_dir/queue runs, presumably because
    those runs do not see untracked workspace files — confirm.
    """
    # Build the lookup once instead of scanning a list per override path.
    untracked = set(repo.scm.untracked_files())
    for path in path_overrides:
        if path in untracked:
            logger.debug(
                "'%s' is currently untracked but will be modified by DVC. "
                "Adding it to git.",
                path,
            )
            repo.scm.add([path])


def _queue_experiments(repo, path_overrides, hydra_sweep, /, **kwargs) -> None:
    """Enqueue one celery experiment per sweep combination.

    Without a Hydra sweep, a single experiment is queued with
    `path_overrides`. Remaining kwargs (targets, copy_paths, message, ...)
    are forwarded to `queue_one`. The arguments before `/` are
    positional-only, so caller kwargs can never collide with them.
    """
    if hydra_sweep:
        from dvc.utils.hydra import get_hydra_sweeps

        sweeps = get_hydra_sweeps(path_overrides)
        # A user-supplied name becomes a prefix so each sweep run is unique.
        name_prefix = kwargs.get("name")
    else:
        sweeps = [path_overrides]
        name_prefix = None

    for idx, sweep_overrides in enumerate(sweeps, start=1):
        if name_prefix is not None:
            kwargs["name"] = f"{name_prefix}-{idx}"
        queue_entry = repo.experiments.queue_one(
            repo.experiments.celery_queue,
            params=sweep_overrides,
            **kwargs,
        )
        if sweep_overrides:
            ui.write(f"Queueing with overrides '{sweep_overrides}'.")
        # Fall back to the short stash revision when the entry has no name.
        name = queue_entry.name or queue_entry.stash_rev[:7]
        ui.write(f"Queued experiment '{name}' for future execution.")