# dvc/repo/experiments/pull.py
import logging
from typing import Iterable, List, Mapping, Optional, Set, Union
from funcy import group_by
from scmrepo.git.backend.base import SyncStatus
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit, iter_revs
from dvc.ui import ui
from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name
logger = logging.getLogger(__name__)
@locked
@scm_context
def pull(
    repo,
    git_remote: str,
    exp_names: Union[Iterable[str], str],
    all_commits=False,
    rev: Optional[str] = None,
    num=1,
    force: bool = False,
    pull_cache: bool = False,
    **kwargs,
) -> Iterable[str]:
    """Fetch experiment refs from ``git_remote`` into the local repo.

    The set of refs to fetch is assembled as follows:

    * if ``all_commits`` is truthy, every ref yielded by
      ``exp_refs(repo.scm, git_remote)`` is selected and ``exp_names`` /
      ``rev`` are not consulted;
    * otherwise the refs resolved from ``exp_names`` (if given) are combined
      with the refs that ``exp_refs_by_baseline`` reports for the revisions
      produced by ``iter_revs(repo.scm, [rev], num)`` (if ``rev`` is given).

    Args:
        repo: Repo object; its ``scm`` attribute is used for all Git access.
        git_remote: Git remote passed through to the ref lookups and fetch.
        exp_names: A single experiment name or an iterable of names.
        all_commits: Select every experiment ref found on the remote.
        rev: Starting revision handed to ``iter_revs``.
        num: Forwarded to ``iter_revs`` together with ``rev``
            (presumably the number of commits to walk -- confirm there).
        force: Forwarded to the underlying ref fetch.
        pull_cache: When truthy, also run ``_pull_cache`` for the refs that
            were reported as up to date or successfully fetched.
        **kwargs: Forwarded verbatim to ``_pull_cache``.

    Returns:
        Names of the refs whose fetch status is ``SyncStatus.SUCCESS``.

    Raises:
        UnresolvedExpNamesError: If ``resolve_name`` maps any requested name
            to ``None``.
    """
    targets: Set["ExpRefInfo"] = set()

    if all_commits:
        targets.update(exp_refs(repo.scm, git_remote))
    else:
        if exp_names:
            names = [exp_names] if isinstance(exp_names, str) else exp_names
            resolved = resolve_name(repo.scm, names, git_remote)

            # Report every unresolved name at once rather than only the first.
            missing = [
                name for name, ref_info in resolved.items() if ref_info is None
            ]
            if missing:
                raise UnresolvedExpNamesError(missing)
            targets.update(resolved.values())

        if rev:
            baselines = set(iter_revs(repo.scm, [rev], num).keys())
            refs_by_baseline = exp_refs_by_baseline(
                repo.scm, baselines, git_remote
            )
            for baseline_refs in refs_by_baseline.values():
                targets.update(baseline_refs)

    pull_result = _pull(repo, git_remote, targets, force)

    diverged = pull_result[SyncStatus.DIVERGED]
    if diverged:
        diverged_refs = [ref_info.name for ref_info in diverged]
        ui.warn(
            f"Local experiment '{diverged_refs}' has diverged from remote "
            "experiment with the same name. To override the local experiment "
            "re-run with '--force'."
        )

    if pull_cache:
        # Refs that were already up to date still get their cache fetched.
        cache_refs = (
            pull_result[SyncStatus.UP_TO_DATE]
            + pull_result[SyncStatus.SUCCESS]
        )
        _pull_cache(repo, cache_refs, **kwargs)

    return [ref_info.name for ref_info in pull_result[SyncStatus.SUCCESS]]
def _pull(
    repo,
    git_remote: str,
    refs: Iterable["ExpRefInfo"],
    force: bool,
) -> Mapping[SyncStatus, List["ExpRefInfo"]]:
    """Fetch ``refs`` from ``git_remote`` and bucket them by fetch status.

    Each ref is fetched onto the identically named local ref
    (``"<ref>:<ref>"``). ``refs`` is iterated twice, so it must be a
    re-iterable collection rather than a one-shot iterator.

    Args:
        repo: Repo object whose ``scm.fetch_refspecs`` performs the fetch.
        git_remote: Git remote to fetch from.
        refs: Experiment refs to fetch.
        force: Forwarded to ``fetch_refspecs``.

    Returns:
        The ``group_by`` result mapping each ``SyncStatus`` to the list of
        refs that ended up with that status.
    """
    refspecs = [f"{ref_info}:{ref_info}" for ref_info in refs]
    logger.debug(f"git pull experiment '{git_remote}' -> '{refspecs}'")

    with TqdmGit(desc="Fetching git refs") as progress_bar:
        statuses: Mapping[str, SyncStatus] = repo.scm.fetch_refspecs(
            git_remote,
            refspecs,
            force=force,
            progress=progress_bar.update_git,
        )

    # Statuses are looked up by the string form of each ref.
    grouped: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by(
        lambda ref_info: statuses[str(ref_info)], refs
    )
    return grouped
def _pull_cache(
    repo,
    refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]],
    dvc_remote=None,
    jobs=None,
    run_cache=False,
    odb=None,
):
    """Run ``repo.fetch`` for the revisions behind the given experiment refs.

    Args:
        repo: Repo object providing ``scm`` and ``fetch``.
        refs: A single ``ExpRefInfo`` or an iterable of them; a single ref
            is wrapped in a list before use.
        dvc_remote: Passed to ``repo.fetch`` as ``remote``.
        jobs: Passed to ``repo.fetch`` as ``jobs``.
        run_cache: Passed to ``repo.fetch`` as ``run_cache``.
        odb: Passed to ``repo.fetch`` as ``odb``.
    """
    ref_infos = [refs] if isinstance(refs, ExpRefInfo) else refs

    # Materialise the revisions before logging and fetching.
    commit_revs = list(exp_commits(repo.scm, ref_infos))
    logger.debug(f"dvc fetch experiment '{ref_infos}'")

    repo.fetch(
        revs=commit_revs,
        remote=dvc_remote,
        jobs=jobs,
        run_cache=run_cache,
        odb=odb,
    )