iterative/dvc
dvc/render/converter/vega.py

import os
from collections.abc import Iterable
from typing import Any, Optional, Union

from funcy import first, last

from dvc.exceptions import DvcException
from dvc.render import FIELD, FILENAME, INDEX, REVISION

from . import Converter


class FieldNotFoundError(DvcException):
    def __init__(self, expected_field, found_fields):
        found_str = ", ".join(found_fields)
        super().__init__(
            f"Could not find provided field ('{expected_field}') "
            f"in data fields ('{found_str}')."
        )


def _lists(blob: Union[dict, list]) -> Iterable[list]:
    """Recursively yield every list nested inside the blob."""
    if isinstance(blob, list):
        yield blob
    else:
        for _, value in blob.items():
            if isinstance(value, dict):
                yield from _lists(value)
            elif isinstance(value, list):
                yield value
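
# A small illustration of _lists (the blob below is made up for this sketch):
#   list(_lists({"metrics": {"train": [1, 2]}, "tags": ["a", "b"]}))
#   -> [[1, 2], ["a", "b"]]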


def _file_field(*args):
    """
    Yield (file, field) pairs from axis definitions of the form
    {file: field} or {file: [fields]}.
    """
    for axis_def in args:
        if axis_def is not None:
            for file, val in axis_def.items():
                if isinstance(val, str):
                    yield file, val
                elif isinstance(val, list):
                    for field in val:
                        yield file, field
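
# For example (filenames and fields here are illustrative assumptions), an axis
# definition is expanded into one (file, field) pair per field:
#   list(_file_field({"metrics.json": ["loss", "acc"], "params.json": "step"}))
#   -> [("metrics.json", "loss"), ("metrics.json", "acc"), ("params.json", "step")]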


def _find(filename: str, field: str, data_series: list[tuple[str, str, Any]]):
    for data_file, data_field, data in data_series:
        if data_file == filename and data_field == field:
            return data_file, data_field, data
    return None


def _verify_field(file2datapoints: dict[str, list], filename: str, field: str):
    if filename in file2datapoints:
        datapoint = first(file2datapoints[filename])
        if field not in datapoint:
            raise FieldNotFoundError(field, datapoint.keys())


def _get_xs(properties: dict, file2datapoints: dict[str, list[dict]]):
    x = properties.get("x")
    if x is not None and isinstance(x, dict):
        for filename, field in _file_field(x):
            _verify_field(file2datapoints, filename, field)
            yield filename, field


def _get_ys(properties: dict, file2datapoints: dict[str, list[dict]]):
    y = properties.get("y", None)
    if y is not None:
        for filename, field in _file_field(y):
            _verify_field(file2datapoints, filename, field)
            yield filename, field


def _is_datapoints(lst: list[dict]):
    """
    check if dict keys match, datapoints with different keys mgiht lead
    to unexpected behavior
    """

    return all(isinstance(item, dict) for item in lst) and set(first(lst).keys()) == {
        key for keys in lst for key in keys
    }


def get_datapoints(file_content: dict):
    result: list[dict[str, Any]] = []
    for lst in _lists(file_content):
        if _is_datapoints(lst):
            for index, datapoint in enumerate(lst):
                if len(result) <= index:
                    result.append({})
                result[index].update(datapoint)
    return result
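
# A hypothetical sketch of how get_datapoints merges parallel nested lists
# index-wise into one datapoint per index (keys and values are made up):
#   get_datapoints({"loss": [{"loss": 0.9}, {"loss": 0.5}],
#                   "acc": [{"acc": 0.1}, {"acc": 0.4}]})
#   -> [{"loss": 0.9, "acc": 0.1}, {"loss": 0.5, "acc": 0.4}]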


class VegaConverter(Converter):
    """
    Class that takes care of converting unspecified data blob
    (Dict or List[Dict]) into datapoints (List[Dict]).
    If some properties that are required by Template class are missing
    ('x', 'y') it will attempt to fill in the blanks.
    """

    def __init__(
        self,
        plot_id: str,
        data: Optional[dict] = None,
        properties: Optional[dict] = None,
    ):
        super().__init__(plot_id, data, properties)
        self.plot_id = plot_id

    def _infer_y_from_data(self):
        if self.plot_id in self.data:
            for lst in _lists(self.data[self.plot_id]):
                if all(isinstance(item, dict) for item in lst):
                    datapoint = first(lst)
                    field = last(datapoint.keys())
                    return {self.plot_id: field}
        return None
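
    # When no 'y' is configured, the last field of the first datapoint found in
    # this plot's data is assumed to be the y value. A sketch with made-up data:
    #   data = {"loss.json": {"metrics": [{"step": 0, "loss": 0.9}]}}
    #   _infer_y_from_data() -> {"loss.json": "loss"}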

    def _infer_x_y(self):
        x = self.properties.get("x", None)
        y = self.properties.get("y", None)

        inferred_properties: dict = {}

        # Infer x.
        if isinstance(x, str):
            inferred_properties["x"] = {}
            # If multiple y files, duplicate x for each file.
            if isinstance(y, dict):
                for file, fields in y.items():
                    # Duplicate x for each y.
                    if isinstance(fields, list):
                        inferred_properties["x"][file] = [x] * len(fields)
                    else:
                        inferred_properties["x"][file] = x
            # Otherwise use plot ID as file.
            else:
                inferred_properties["x"][self.plot_id] = x

        # Infer y.
        if y is None:
            inferred_properties["y"] = self._infer_y_from_data()
        # If y files not provided, use plot ID as file.
        elif not isinstance(y, dict):
            inferred_properties["y"] = {self.plot_id: y}

        return inferred_properties
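
    # A sketch of how a shared string 'x' is duplicated per y field (property
    # values below are assumptions for illustration):
    #   properties = {"x": "step", "y": {"train.json": ["loss", "acc"]}}
    #   _infer_x_y() -> {"x": {"train.json": ["step", "step"]}}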

    def _find_datapoints(self):
        result = {}
        for file, content in self.data.items():
            result[file] = get_datapoints(content)

        return result

    @staticmethod
    def infer_y_label(properties):
        y_label = properties.get("y_label", None)
        if y_label is not None:
            return y_label
        y = properties.get("y", None)
        if isinstance(y, str):
            return y
        if isinstance(y, list):
            return "y"
        if not isinstance(y, dict):
            return

        fields = {field for _, field in _file_field(y)}
        if len(fields) == 1:
            return first(fields)
        return "y"

    @staticmethod
    def infer_x_label(properties):
        x_label = properties.get("x_label", None)
        if x_label is not None:
            return x_label

        x = properties.get("x", None)
        if not isinstance(x, dict):
            return INDEX

        fields = {field for _, field in _file_field(x)}
        if len(fields) == 1:
            return first(fields)
        return "x"

    def flat_datapoints(self, revision):  # noqa: C901, PLR0912
        """
        Flatten all data sources into a single list of datapoints for the
        given revision and compute the property updates needed to render
        them together.
        """
        file2datapoints, properties = self.convert()

        props_update: dict[str, Union[str, list[dict[str, str]]]] = {}

        xs = list(_get_xs(properties, file2datapoints))

        # assign "step" if no x provided
        if not xs:
            x_file, x_field = None, INDEX
        else:
            x_file, x_field = xs[0]

        num_xs = len(xs)
        multiple_x_fields = num_xs > 1 and len({x[1] for x in xs}) > 1
        props_update["x"] = "dvc_inferred_x_value" if multiple_x_fields else x_field

        ys = list(_get_ys(properties, file2datapoints))

        num_ys = len(ys)
        if num_xs > 1 and num_xs != num_ys:
            raise DvcException(
                "Cannot have different number of x and y data sources. Found "
                f"{num_xs} x and {num_ys} y data sources."
            )

        all_datapoints = []
        if ys:
            _all_y_files, _all_y_fields = list(zip(*ys))
            all_y_fields = set(_all_y_fields)
            all_y_files = set(_all_y_files)
        else:
            all_y_files = set()
            all_y_fields = set()

        # override to unified y field name if there are different y fields
        if len(all_y_fields) > 1:
            props_update["y"] = "dvc_inferred_y_value"
        else:
            props_update["y"] = first(all_y_fields)

        # get common prefix to drop from file names
        if len(all_y_files) > 1:
            common_prefix_len = len(os.path.commonpath(list(all_y_files)))
        else:
            common_prefix_len = 0

        props_update["anchors_y_definitions"] = [
            {FILENAME: _get_short_y_file(y_file, common_prefix_len), FIELD: y_field}
            for y_file, y_field in ys
        ]

        for i, (y_file, y_field) in enumerate(ys):
            if num_xs > 1:
                x_file, x_field = xs[i]
            datapoints = [{**d} for d in file2datapoints.get(y_file, [])]

            if props_update.get("y") == "dvc_inferred_y_value":
                _update_from_field(
                    datapoints,
                    field="dvc_inferred_y_value",
                    source_field=y_field,
                )

            if x_field == INDEX and x_file is None:
                _update_from_index(datapoints, INDEX)
            else:
                x_datapoints = file2datapoints.get(x_file, [])
                try:
                    _update_from_field(
                        datapoints,
                        field="dvc_inferred_x_value" if multiple_x_fields else x_field,
                        source_datapoints=x_datapoints,
                        source_field=x_field,
                    )
                except IndexError:
                    raise DvcException(  # noqa: B904
                        f"Cannot join '{x_field}' from '{x_file}' and "
                        f"'{y_field}' from '{y_file}'. "
                        "They have to have same length."
                    )

            _update_all(
                datapoints,
                update_dict={
                    REVISION: revision,
                    FILENAME: _get_short_y_file(y_file, common_prefix_len),
                    FIELD: y_field,
                },
            )

            all_datapoints.extend(datapoints)

        if not all_datapoints:
            return [], {}

        properties = properties | props_update

        return all_datapoints, properties

    def convert(self):
        """
        Convert the data. Fill necessary fields ('x', 'y') and return both
        generated datapoints and updated properties. `x`, `y` values and labels
        are inferred and always provided.
        """
        inferred_properties = self._infer_x_y()

        datapoints = self._find_datapoints()
        properties = self.properties | inferred_properties

        properties["y_label"] = self.infer_y_label(properties)
        properties["x_label"] = self.infer_x_label(properties)

        return datapoints, properties


def _get_short_y_file(y_file, common_prefix_len):
    return y_file[common_prefix_len:].strip("/\\")
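
# For example, with y files {"logs/train.json", "logs/valid.json"} the common
# prefix "logs" is dropped, so FILENAME becomes "train.json" and "valid.json"
# (paths here are made up for illustration).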


def _update_from_field(
    target_datapoints: list[dict],
    field: str,
    source_datapoints: Optional[list[dict]] = None,
    source_field: Optional[str] = None,
):
    if source_datapoints is None:
        source_datapoints = target_datapoints
    if source_field is None:
        source_field = field

    if len(source_datapoints) != len(target_datapoints):
        raise IndexError("Source and target datapoints must have the same length")

    for index, datapoint in enumerate(target_datapoints):
        source_datapoint = source_datapoints[index]
        if source_field in source_datapoint:
            datapoint[field] = source_datapoint[source_field]
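
# For instance, joining x values from one file into y datapoints from another
# (names and values are illustrative only):
#   y_points = [{"loss": 0.9}, {"loss": 0.5}]
#   x_points = [{"step": 0}, {"step": 1}]
#   _update_from_field(y_points, "step", source_datapoints=x_points)
#   # y_points -> [{"loss": 0.9, "step": 0}, {"loss": 0.5, "step": 1}]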


def _update_from_index(datapoints: list[dict], new_field: str):
    for index, datapoint in enumerate(datapoints):
        datapoint[new_field] = index


def _update_all(datapoints: list[dict], update_dict: dict):
    for datapoint in datapoints:
        datapoint.update(update_dict)
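

# A minimal end-to-end usage sketch. The file name, fields, and values are
# assumptions made for illustration and are not taken from a real project:
#
#   converter = VegaConverter(
#       "loss.json",
#       data={"loss.json": [{"step": 0, "loss": 0.9}, {"step": 1, "loss": 0.5}]},
#       properties={"x": "step", "y": "loss"},
#   )
#   datapoints, props = converter.flat_datapoints("HEAD")
#
# Each returned datapoint keeps its original keys ("step", "loss") and
# additionally carries the REVISION ("HEAD"), FILENAME ("loss.json"), and
# FIELD ("loss") keys, while `props` includes the inferred "x", "y",
# "x_label", "y_label", and "anchors_y_definitions" entries.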