iterative/dvc

View on GitHub
dvc/repo/experiments/branch.py

Summary

Maintainability
A
1 hr
Test Coverage
from dvc.exceptions import InvalidArgumentError
from dvc.log import logger
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import RevError

from .exceptions import InvalidExpRevError
from .utils import exp_refs_by_rev

# Module-level logger: a child of the logger exported by ``dvc.log``, named
# after this module so records from here can be filtered/identified.
logger = logger.getChild(__name__)


@locked
@scm_context
def branch(repo, exp_rev, branch_name=None, **kwargs):
    """Create a new Git branch pointing at the commit of an experiment.

    Args:
        repo: DVC repo whose SCM holds the experiment refs.
        exp_rev: Experiment name or revision to create the branch from.
        branch_name: Name of the branch to create. When falsy, defaults to
            ``"<experiment name>-branch"``.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        InvalidArgumentError: ``exp_rev`` cannot be resolved to a revision,
            it ambiguously matches several experiments, or a branch named
            ``branch_name`` already exists.
        InvalidExpRevError: no experiment ref matches ``exp_rev``.
    """
    from dvc.scm import resolve_rev

    scm = repo.scm
    try:
        rev = resolve_rev(scm, exp_rev)
    except RevError:
        raise InvalidArgumentError(exp_rev)  # noqa: B904

    ref_info = _select_exp_ref(scm, exp_rev, rev)
    if not ref_info:
        raise InvalidExpRevError(exp_rev)

    exp_name = ref_info.name
    if not branch_name:
        branch_name = f"{exp_name}-branch"

    new_ref = f"refs/heads/{branch_name}"
    # Refuse to overwrite an existing branch.
    if scm.get_ref(new_ref):
        raise InvalidArgumentError(f"Git branch '{branch_name}' already exists.")

    # Point the new branch at the same commit as the experiment ref.
    scm.set_ref(
        new_ref,
        scm.get_ref(str(ref_info)),
        message=f"dvc: Created from experiment '{exp_name}'",
    )
    logger.info(
        "Git branch '%s' has been created from experiment '%s'.\n"
        "To switch to the new branch run:\n\n"
        "\tgit checkout %s",
        branch_name,
        exp_name,
        branch_name,
    )


def _select_exp_ref(scm, exp_rev, rev):
    """Return the experiment ref identified by ``exp_rev`` (resolved to ``rev``).

    With exactly one matching experiment ref, that ref is returned. With
    several, the one whose ``baseline_sha`` equals the current ``scm`` revision
    is chosen. With none, ``None`` is returned.

    Raises:
        InvalidArgumentError: several refs match and none has the current
            revision as its baseline; the message lists every candidate.
    """
    candidates = list(exp_refs_by_rev(scm, rev))
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    # Several experiments share this rev: disambiguate by the checked-out baseline.
    head = scm.get_rev()
    chosen = next(
        (cand for cand in candidates if cand.baseline_sha == head),
        None,
    )
    if chosen:
        return chosen

    header = (
        f"Ambiguous experiment name '{exp_rev}' can refer to "
        "multiple experiments. To create a branch use a full "
        "experiment ref:"
    )
    raise InvalidArgumentError("\n".join([header, "", *map(str, candidates)]))