iterative/dvc

View on GitHub
dvc/repo/add.py

Summary

Maintainability
A
1 hr
Test Coverage
import os
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, NamedTuple, Optional, Union

from dvc.exceptions import (
    CacheLinkError,
    DvcException,
    OutputDuplicationError,
    OutputNotFoundError,
    OverlappingOutputPathsError,
)
from dvc.repo.scm_context import scm_context
from dvc.ui import ui
from dvc.utils import glob_targets, resolve_output, resolve_paths

from . import locked

if TYPE_CHECKING:
    from dvc.repo import Repo
    from dvc.stage import Stage
    from dvc.types import StrOrBytesPath


class StageInfo(NamedTuple):
    stage: "Stage"
    output_exists: bool


def find_targets(
    targets: Union["StrOrBytesPath", Iterator["StrOrBytesPath"]], glob: bool = False
) -> list[str]:
    if isinstance(targets, (str, bytes, os.PathLike)):
        targets_list = [os.fsdecode(targets)]
    else:
        targets_list = [os.fsdecode(target) for target in targets]
    return glob_targets(targets_list, glob=glob)


PIPELINE_TRACKED_UPDATE_FMT = (
    "cannot update {out!r}: overlaps with an output of {stage} in '{path}'.\n"
    "Run the pipeline or use 'dvc commit' to force update it."
)


def get_or_create_stage(
    repo: "Repo",
    target: str,
    out: Optional[str] = None,
    to_remote: bool = False,
    force: bool = False,
) -> StageInfo:
    if out:
        target = resolve_output(target, out, force=force)
    path, wdir, out = resolve_paths(repo, target, always_local=to_remote and not out)

    try:
        (out_obj,) = repo.find_outs_by_path(target, strict=False)
        stage = out_obj.stage
        if not stage.is_data_source:
            msg = PIPELINE_TRACKED_UPDATE_FMT.format(
                out=out, stage=stage, path=stage.relpath
            )
            raise DvcException(msg)
        return StageInfo(stage, output_exists=True)
    except OutputNotFoundError:
        stage = repo.stage.create(
            single_stage=True,
            validate=False,
            fname=path,
            wdir=wdir,
            outs=[out],
            force=force,
        )
        return StageInfo(stage, output_exists=False)


OVERLAPPING_CHILD_FMT = (
    "Cannot add '{out}', because it is overlapping with other "
    "DVC tracked output: '{parent}'.\n"
    "To include '{out}' in '{parent}', run "
    "'dvc commit {parent_stage}'"
)

OVERLAPPING_PARENT_FMT = (
    "Cannot add '{parent}', because it is overlapping with other "
    "DVC tracked output: '{out}'.\n"
    "To include '{out}' in '{parent}', run "
    "'dvc remove {out_stage}' and then 'dvc add {parent}'"
)


@contextmanager
def translate_graph_error(stages: list["Stage"]) -> Iterator[None]:
    try:
        yield
    except OverlappingOutputPathsError as exc:
        if exc.parent in [o for s in stages for o in s.outs]:
            msg = OVERLAPPING_PARENT_FMT.format(
                out=exc.overlapping_out,
                parent=exc.parent,
                out_stage=exc.overlapping_out.stage.addressing,
            )
        else:
            msg = OVERLAPPING_CHILD_FMT.format(
                out=exc.overlapping_out,
                parent=exc.parent,
                parent_stage=exc.parent.stage.addressing,
            )
        raise OverlappingOutputPathsError(  # noqa: B904
            exc.parent, exc.overlapping_out, msg
        )
    except OutputDuplicationError as exc:
        raise OutputDuplicationError(  # noqa: B904
            exc.output, set(exc.stages) - set(stages)
        )


def progress_iter(stages: dict[str, StageInfo]) -> Iterator[tuple[str, StageInfo]]:
    total = len(stages)
    desc = "Adding..."
    with ui.progress(
        stages.items(), total=total, desc=desc, unit="file", leave=True
    ) as pbar:
        if total == 1:
            pbar.bar_format = desc
            pbar.refresh()

        for item, stage_info in pbar:
            if total > 1:
                pbar.set_msg(str(stage_info.stage.outs[0]))
                pbar.refresh()
            yield item, stage_info
            if total == 1:  # restore bar format for stats
                pbar.bar_format = pbar.BAR_FMT_DEFAULT


LINK_FAILURE_MESSAGE = (
    "\nSome targets could not be linked from cache to workspace.\n{}\n"
    "To re-link these targets, reconfigure cache types and then run:\n"
    "\n\tdvc checkout {}"
)


@contextmanager
def warn_link_failures() -> Iterator[list[str]]:
    link_failures: list[str] = []
    try:
        yield link_failures
    finally:
        if link_failures:
            msg = LINK_FAILURE_MESSAGE.format(
                CacheLinkError.SUPPORT_LINK,
                " ".join(link_failures),
            )
            ui.error_write(msg)


def _add_transfer(
    stage: "Stage",
    source: str,
    remote: Optional[str] = None,
    to_remote: bool = False,
    jobs: Optional[int] = None,
    force: bool = False,
) -> None:
    odb = None
    if to_remote:
        odb = stage.repo.cloud.get_remote_odb(remote, "add")
    stage.transfer(source, odb=odb, to_remote=to_remote, jobs=jobs, force=force)
    stage.dump()


def _add(stage: "Stage", source: Optional[str] = None, no_commit: bool = False) -> None:
    out = stage.outs[0]
    path = out.fs.abspath(source) if source else None
    try:
        stage.add_outs(path, no_commit=no_commit)
    except CacheLinkError:
        stage.dump()
        raise
    stage.dump()


@locked
@scm_context
def add(
    repo: "Repo",
    targets: Union["StrOrBytesPath", Iterator["StrOrBytesPath"]],
    no_commit: bool = False,
    glob: bool = False,
    out: Optional[str] = None,
    remote: Optional[str] = None,
    to_remote: bool = False,
    remote_jobs: Optional[int] = None,
    force: bool = False,
) -> list["Stage"]:
    add_targets = find_targets(targets, glob=glob)
    if not add_targets:
        return []

    stages_with_targets = {
        target: get_or_create_stage(
            repo,
            target,
            out=out,
            to_remote=to_remote,
            force=force,
        )
        for target in add_targets
    }

    stages = [stage for stage, _ in stages_with_targets.values()]
    msg = "Collecting stages from the workspace"
    with translate_graph_error(stages), ui.status(msg) as st:
        repo.check_graph(stages=stages, callback=lambda: st.update("Checking graph"))

    if to_remote or out:
        assert len(stages_with_targets) == 1, "multiple targets are unsupported"
        (source, (stage, _)) = next(iter(stages_with_targets.items()))
        _add_transfer(stage, source, remote, to_remote, jobs=remote_jobs, force=force)
        return [stage]

    with warn_link_failures() as link_failures:
        for source, (stage, output_exists) in progress_iter(stages_with_targets):
            try:
                _add(stage, source if output_exists else None, no_commit=no_commit)
            except CacheLinkError:
                link_failures.append(stage.relpath)
    return stages