zincware/ZnTrack

View on GitHub
zntrack/fields/field.py

Summary

Maintainability
C
7 hrs
Test Coverage
"""The base class for all fields."""

import abc
import contextlib
import enum
import json
import logging
import pathlib
import shutil
import typing

import yaml
import zninit

from zntrack.utils import LazyOption, config

if typing.TYPE_CHECKING:
    from zntrack.core.node import Node

log = logging.getLogger(__name__)


class FieldGroup(enum.Enum):
    """Characterizes the group of a field."""

    PARAMETER = enum.auto()
    RESULT = enum.auto()


class Field(zninit.Descriptor, abc.ABC):
    """Base class for fields.

    Handles all the file I/O for the given field.

    Attributes
    ----------
    dvc_option : str
        The dvc command option for this field.

    """

    dvc_option: str = None
    group: FieldGroup = None

    @abc.abstractmethod
    def save(self, instance: "Node"):
        """Save the field to disk.

        Parameters
        ----------
        instance : Node
            The Node instance to save the field for.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_data(self, instance: "Node") -> any:
        """Get the value of the field from the file."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_files(self, instance: "Node") -> list:
        """Get the files affected by this field.

        Parameters
        ----------
        instance : Node
            The Node instance to get the affected files for.

        Returns
        -------
        list
            The affected files.

        """
        raise NotImplementedError

    def load(self, instance: "Node", lazy: bool = None):
        """Load the field from disk.

        Parameters
        ----------
        instance : Node
            The Node instance to load the field for.
        lazy : bool, optional
            Whether to load the field lazily.
            This only applies to 'LazyField' classes.

        """
        try:
            instance.__dict__[self.name] = self.get_data(instance)
        except FileNotFoundError:
            # if something was not loaded, we set the loaded state to False
            log.warning(f"Could not load field {self.name} for node {instance.name}.")
            instance.state.loaded = False

    def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
        """Get the dvc stage add argument for this field.

        Parameters
        ----------
        instance : Node
            The Node instance to get the stage add argument for.

        Returns
        -------
        typing.List[tuple]
            The stage add argument for this field.

        """
        return [
            (f"--{self.dvc_option}", pathlib.Path(x).as_posix())
            for x in self.get_files(instance)
        ]

    def get_optional_dvc_cmd(
        self, instance: "Node", git_only_repo: bool
    ) -> typing.List[typing.List[str]]:
        """Get optional dvc commands that will be executed beside the main dvc command.

        This could be 'plots modify ...' or 'stage add --name node_helper'

        Parameters
        ----------
        instance : Node
            The Node instance to get the optional dvc commands for.
        git_only_repo : bool
            Whether the repo is a git only repo or has a dvc remote.

        Returns
        -------
        typing.List[str]
            The optional dvc commands.

        """
        return []

    def _write_value_to_config(self, value, instance: "Node", encoder=None):
        """Write the value of this field to the zntrack config file.

        Parameters
        ----------
        value: any
            The value to write to the config file.
        instance : Node
            The node instance to which this field belongs.
        encoder : json.JSONEncoder, optional
            The JSON encoder to use, by default None.

        """
        try:
            with open(config.files.zntrack, "r") as f:
                zntrack_dict = json.load(f)
        except FileNotFoundError:
            zntrack_dict = {}

        if instance.name not in zntrack_dict:
            zntrack_dict[instance.name] = {}
        zntrack_dict[instance.name][self.name] = value
        # use the __dict__ to avoid the nwd replacement
        with open(config.files.zntrack, "w") as f:
            json.dump(zntrack_dict, f, indent=4, cls=encoder)


class DataIsLazyError(Exception):
    """Exception to raise when a field is accessed that contains lazy data."""


class LazyField(Field):
    """Base class for fields that are loaded lazily."""

    def get_value_except_lazy(self, instance):
        """Get the value of the field.

        If the value is lazy, raise an Error.

        Raises
        ------
        DataIsLazyError
            If the value is lazy.

        """
        with contextlib.suppress(KeyError):
            if instance.__dict__[self.name] is LazyOption:
                raise DataIsLazyError()

        return getattr(instance, self.name, None)

    def __get__(self, instance, owner=None):
        """Load the field from disk if it is not already loaded."""
        if instance is None:
            return self
        if instance.__dict__.get(self.name) is LazyOption:
            self.load(instance, lazy=False)

        return super().__get__(instance, owner)

    def load(self, instance: "Node", lazy: bool = None):
        """Load the field from disk.

        Parameters
        ----------
        instance : Node
            The Node instance to load the field for.
        lazy : bool, optional
            Whether to load the field lazily, by default 'zntrack.config.lazy'.

        """
        if lazy in {None, True} and config.lazy:
            instance.__dict__[self.name] = LazyOption
        else:
            super().load(instance)


class PlotsMixin(Field):
    """DVC Plots Option including 'dvc plots modify' command."""

    def __init__(
        self,
        *args,
        template=None,
        x=None,
        y=None,
        x_label=None,
        y_label=None,
        title=None,
        use_global_plots: bool = True,
        **kwargs,
    ):
        """Create a DVCOption field.

        Attributes
        ----------
        use_global_plots : bool
            Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file.

        """
        super().__init__(*args, **kwargs)
        self.plots_options = {}
        self.use_global_plots = use_global_plots
        if self.use_global_plots:
            if self.dvc_option == "plots":
                self.dvc_option = "outs"
            elif self.dvc_option == "plots-no-cache":
                self.dvc_option = "outs-no-cache"
        if template is not None:
            self.plots_options["--template"] = pathlib.Path(template).as_posix()
        if x is not None:
            self.plots_options["-x"] = x
        if y is not None:
            self.plots_options["-y"] = y
        if x_label is not None:
            self.plots_options["--x-label"] = x_label
        if y_label is not None:
            self.plots_options["--y-label"] = y_label
        if title is not None:
            self.plots_options["--title"] = title

    def save(self, instance: "Node"):
        """Save plots options to dvc.yaml, if use_global_plots is True."""
        if self.plots_options.get("--template") is not None:
            template = pathlib.Path(self.plots_options["--template"]).resolve()
            if pathlib.Path.cwd() not in template.parents:
                # copy template to dvc_plots/templates if it is not in the cwd
                template_dir = pathlib.Path.cwd() / "dvc_plots" / "templates"
                template_dir.mkdir(parents=True, exist_ok=True)
                shutil.copy(template, template_dir)
                self.plots_options["--template"] = (
                    (template_dir / template.name)
                    .relative_to(pathlib.Path.cwd())
                    .as_posix()
                )

        with contextlib.suppress(NotImplementedError):
            super().save(instance=instance)
        if not self.use_global_plots:
            return

        dvc_file = config.files.dvc
        if not dvc_file.exists():
            dvc_file.write_text(yaml.safe_dump({}))
        dvc_config = yaml.safe_load(dvc_file.read_text())
        plots = dvc_config.get("plots", [])

        # remove leading "-/--"
        for key in list(self.plots_options):
            if key.startswith("--"):
                self.plots_options[key[2:]] = self.plots_options[key]
                del self.plots_options[key]
            elif key.startswith("-"):
                self.plots_options[key[1:]] = self.plots_options[key]
                del self.plots_options[key]
        # replace "-" with "_"
        for key in list(self.plots_options):
            if key.replace("-", "_") != key:
                self.plots_options[key.replace("-", "_")] = self.plots_options[key]
                del self.plots_options[key]

        for file in self.get_files(instance):
            replaced = False
            for entry in plots:  # entry: dict{filename: {x:, y:, ...}}
                if pathlib.Path(file) == pathlib.Path(next(iter(entry))):
                    entry = {pathlib.Path(file).as_posix(): self.plots_options}
                    replaced = True
            if not replaced:
                plots.append({pathlib.Path(file).as_posix(): self.plots_options})

        dvc_config["plots"] = plots
        dvc_file.write_text(yaml.dump(dvc_config))

    def get_optional_dvc_cmd(
        self, instance: "Node", git_only_repo: bool
    ) -> typing.List[typing.List[str]]:
        """Add 'dvc plots modify' to this option."""
        if not self.use_global_plots:
            cmds = []
            for file in self.get_files(instance):
                cmd = ["plots", "modify", pathlib.Path(file).as_posix()]
                for key, value in self.plots_options.items():
                    cmd.append(f"{key}")
                    cmd.append(pathlib.Path(value).as_posix())
                cmds.append(cmd)
            return cmds
        else:
            return []