dvc/repo/experiments/show.py
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator, Mapping
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union
from dvc.exceptions import InvalidArgumentError
from dvc.log import logger
from dvc.scm import Git
from dvc.ui import ui
from dvc.utils.flatten import flatten
from .collect import collect
if TYPE_CHECKING:
from dvc.compare import TabularData
from dvc.repo import Repo
from dvc.ui.table import CellT
from .serialize import ExpRange, ExpState
logger = logger.getChild(__name__)
def show(
repo: "Repo",
revs: Union[list[str], str, None] = None,
all_branches: bool = False,
all_tags: bool = False,
all_commits: bool = False,
num: int = 1,
hide_queued: bool = False,
hide_failed: bool = False,
sha_only: bool = False,
**kwargs,
) -> list["ExpState"]:
return collect(
repo,
revs=revs,
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
num=num,
hide_queued=hide_queued,
hide_failed=hide_failed,
sha_only=sha_only,
**kwargs,
)
def tabulate(
baseline_states: Iterable["ExpState"],
fill_value: Optional[str] = "-",
error_value: str = "!",
**kwargs,
) -> tuple["TabularData", dict[str, Iterable[str]]]:
"""Return table data for experiments.
Returns:
Tuple of (table_data, data_headers)
"""
from funcy import lconcat
from funcy.seqs import flatten as flatten_list
from dvc.compare import TabularData
data_names = _collect_names(baseline_states)
metrics_names = data_names.metrics
params_names = data_names.params
deps_names = data_names.sorted_deps
headers = [
"Experiment",
"rev",
"typ",
"Created",
"parent",
"State",
"Executor",
]
names = metrics_names | params_names
counter = Counter(flatten_list([list(a.keys()) for a in names.values()]))
counter.update(headers)
metrics_headers = _normalize_headers(metrics_names, counter)
params_headers = _normalize_headers(params_names, counter)
all_headers = lconcat(headers, metrics_headers, params_headers, deps_names)
td = TabularData(all_headers, fill_value=fill_value)
td.extend(
_build_rows(
baseline_states,
all_headers=all_headers,
metrics_headers=metrics_headers,
params_headers=params_headers,
metrics_names=metrics_names,
params_names=params_names,
deps_names=deps_names,
fill_value=fill_value,
error_value=error_value,
**kwargs,
)
)
data_headers: dict[str, Iterable[str]] = {
"metrics": metrics_headers,
"params": params_headers,
"deps": deps_names,
}
return td, data_headers
def _build_rows(
baseline_states: Iterable["ExpState"],
*,
all_headers: Iterable[str],
fill_value: Optional[str],
sort_by: Optional[str] = None,
sort_order: Optional[Literal["asc", "desc"]] = None,
**kwargs,
) -> Iterator[tuple["CellT", ...]]:
for baseline in baseline_states:
row: dict[str, "CellT"] = dict.fromkeys(all_headers, fill_value)
row["Experiment"] = ""
if baseline.name:
row["rev"] = baseline.name
elif Git.is_sha(baseline.rev):
row["rev"] = baseline.rev[:7]
else:
row["rev"] = baseline.rev
row["typ"] = "baseline"
row["parent"] = ""
if baseline.data:
row["Created"] = format_time(
baseline.data.timestamp, fill_value=fill_value, **kwargs
)
row.update(_data_cells(baseline, fill_value=fill_value, **kwargs))
yield tuple(row.values())
if baseline.experiments:
if sort_by:
metrics_names: Mapping[str, Iterable[str]] = kwargs.get(
"metrics_names", {}
)
params_names: Mapping[str, Iterable[str]] = kwargs.get(
"params_names", {}
)
sort_path, sort_name, sort_type = _sort_column(
sort_by, metrics_names, params_names
)
reverse = sort_order == "desc"
experiments = _sort_exp(
baseline.experiments, sort_path, sort_name, sort_type, reverse
)
else:
experiments = baseline.experiments
for i, child in enumerate(experiments):
yield from _exp_range_rows(
child,
all_headers=all_headers,
fill_value=fill_value,
is_base=i == len(baseline.experiments) - 1,
**kwargs,
)
def _sort_column( # noqa: C901
sort_by: str,
metric_names: Mapping[str, Iterable[str]],
param_names: Mapping[str, Iterable[str]],
) -> tuple[str, str, str]:
sep = ":"
parts = sort_by.split(sep)
matches: set[tuple[str, str, str]] = set()
for split_num in range(len(parts)):
path = sep.join(parts[:split_num])
sort_name = sep.join(parts[split_num:])
if not path: # handles ':metric_name' case
sort_by = sort_name
if path in metric_names and sort_name in metric_names[path]:
matches.add((path, sort_name, "metrics"))
if path in param_names and sort_name in param_names[path]:
matches.add((path, sort_name, "params"))
if not matches:
for path in metric_names:
if sort_by in metric_names[path]:
matches.add((path, sort_by, "metrics"))
for path in param_names:
if sort_by in param_names[path]:
matches.add((path, sort_by, "params"))
if len(matches) == 1:
return matches.pop()
if len(matches) > 1:
raise InvalidArgumentError(
"Ambiguous sort column '{}' matched '{}'".format(
sort_by,
", ".join([f"{path}:{name}" for path, name, _ in matches]),
)
)
raise InvalidArgumentError(f"Unknown sort column '{sort_by}'")
def _sort_exp(
experiments: Iterable["ExpRange"],
sort_path: str,
sort_name: str,
typ: str,
reverse: bool,
) -> list["ExpRange"]:
from funcy import first
def _sort(exp_range: "ExpRange"):
exp = first(exp_range.revs)
if not exp:
return True
data = exp.data.dumpd().get(typ, {}).get(sort_path, {}).get("data", {})
val = flatten(data).get(sort_name)
return val is None, val
return sorted(experiments, key=_sort, reverse=reverse)
def _exp_range_rows(
exp_range: "ExpRange",
*,
all_headers: Iterable[str],
fill_value: Optional[str],
is_base: bool = False,
**kwargs,
) -> Iterator[tuple["CellT", ...]]:
from funcy import first
if len(exp_range.revs) > 1:
logger.debug("Returning tip commit for legacy checkpoint exp")
exp = first(exp_range.revs)
if exp:
row: dict[str, "CellT"] = dict.fromkeys(all_headers, fill_value)
row["Experiment"] = exp.name or ""
row["rev"] = exp.rev[:7] if Git.is_sha(exp.rev) else exp.rev
row["typ"] = "branch_base" if is_base else "branch_commit"
row["parent"] = ""
if exp_range.executor:
row["State"] = exp_range.executor.state.capitalize()
if exp_range.executor.name:
row["Executor"] = exp_range.executor.name.capitalize()
if exp.data:
row["Created"] = format_time(
exp.data.timestamp, fill_value=fill_value, **kwargs
)
row.update(_data_cells(exp, fill_value=fill_value, **kwargs))
yield tuple(row.values())
def _data_cells(
exp: "ExpState",
*,
metrics_headers: Iterable[str],
params_headers: Iterable[str],
metrics_names: Mapping[str, Iterable[str]],
params_names: Mapping[str, Iterable[str]],
deps_names: Iterable[str],
fill_value: Optional[str] = "-",
error_value: str = "!",
precision: Optional[int] = None,
**kwargs,
) -> Iterator[tuple[str, "CellT"]]:
def _d_cells(
d: Mapping[str, Any],
names: Mapping[str, Iterable[str]],
headers: Iterable[str],
) -> Iterator[tuple[str, "CellT"]]:
from dvc.compare import _format_field, with_value
for fname, data in d.items():
item = data.get("data", {})
item = flatten(item) if isinstance(item, dict) else {fname: item}
for name in names[fname]:
value = with_value(
item.get(name),
error_value if data.get("error") else fill_value,
)
# wrap field data in ui.rich_text, otherwise rich may
# interpret unescaped braces from list/dict types as rich
# markup tags
value = ui.rich_text(str(_format_field(value, precision)))
if name in headers:
yield name, value
else:
yield f"{fname}:{name}", value
if not exp.data:
return
yield from _d_cells(exp.data.metrics, metrics_names, metrics_headers)
yield from _d_cells(exp.data.params, params_names, params_headers)
for name in deps_names:
dep = exp.data.deps.get(name)
if dep:
yield name, dep.hash or fill_value
def format_time(
timestamp: Optional[datetime],
fill_value: Optional[str] = "-",
iso: bool = False,
**kwargs,
) -> Optional[str]:
if not timestamp:
return fill_value
if iso:
return timestamp.isoformat()
if timestamp.date() == date.today(): # noqa: DTZ011
fmt = "%I:%M %p"
else:
fmt = "%b %d, %Y"
return timestamp.strftime(fmt)
class _DataNames(NamedTuple):
# NOTE: we use nested dict instead of set for metrics/params names to
# preserve key ordering
metrics: dict[str, dict[str, Any]]
params: dict[str, dict[str, Any]]
deps: set[str]
@property
def sorted_deps(self):
return sorted(self.deps)
def update(self, other: "_DataNames"):
def _update_d(
d: dict[str, dict[str, Any]], other_d: Mapping[str, Mapping[str, Any]]
):
for k, v in other_d.items():
if k in d:
d[k].update(v)
else:
d[k] = dict(v)
_update_d(self.metrics, other.metrics)
_update_d(self.params, other.params)
self.deps.update(other.deps)
def _collect_names(exp_states: Iterable["ExpState"]) -> _DataNames:
result = _DataNames(defaultdict(dict), defaultdict(dict), set())
def _collect_d(result_d: dict[str, dict[str, Any]], data_d: dict[str, Any]):
for path, item in data_d.items():
item = item.get("data", {})
if isinstance(item, dict):
item = flatten(item)
result_d[path].update((key, None) for key in item)
for exp in exp_states:
if exp.data:
_collect_d(result.metrics, exp.data.metrics)
_collect_d(result.params, exp.data.params)
result.deps.update(exp.data.deps)
if exp.experiments:
for child in exp.experiments:
result.update(_collect_names(child.revs))
return result
def _normalize_headers(
names: Mapping[str, Mapping[str, Any]], count: Mapping[str, int]
) -> list[str]:
return [
name if count[name] == 1 else f"{path}:{name}"
for path in names
for name in names[path]
]