zincware/ZnTrack

View on GitHub
zntrack/fields/meta/__init__.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Additional fields that are neither dvc/zn i/o fields."""

import json
import pathlib
import typing

import yaml
import znjson

from zntrack.fields.field import Field, FieldGroup
from zntrack.utils import config, file_io

if typing.TYPE_CHECKING:
    from zntrack import Node


class Text(Field):
    """A metadata field.

    The value is not passed to the DVC stage command. Depending on
    ``use_dvc_yaml`` it is read from the ``meta`` entry of the node's stage in
    ``config.files.dvc`` or from the node's entry in ``config.files.zntrack``.
    """

    dvc_option: str = None
    group = FieldGroup.PARAMETER
    # Opt-in switch: keep the value in the DVC file instead of the ZnTrack config.
    use_dvc_yaml: bool = False

    def get_files(self, instance) -> list:
        """Return the files affected by this field; there are none."""
        return []

    def save(self, instance):
        """Save the field to disk.

        Parameters
        ----------
        instance : Node
            The node whose attribute ``self.name`` is written.
        """
        value = getattr(instance, self.name)
        if config.files.dvc.exists() and self.use_dvc_yaml:
            # Store the value next to the node's stage in the DVC file.
            file_io.update_meta(
                file=config.files.dvc,
                node_name=instance.name,
                data={self.name: value},
            )
        else:
            # Store the value in the ZnTrack config, encoded with znjson.
            self._write_value_to_config(value, instance, encoder=znjson.ZnEncoder)

    def get_data(self, instance: "Node") -> typing.Any:
        """Get the value of the field from the file.

        Parameters
        ----------
        instance : Node
            The node to read the value for. Files are read through
            ``instance.state.fs``.

        Returns
        -------
        typing.Any
            The stored value. In the DVC branch ``None`` is returned if the
            stage's ``meta`` entry has no key for this field.
        """
        # NOTE(review): the existence check uses the local path while the read
        # goes through ``instance.state.fs`` - confirm both refer to the same
        # revision.
        if config.files.dvc.exists() and self.use_dvc_yaml:
            dvc_dict = yaml.safe_load(instance.state.fs.read_text(config.files.dvc))
            return dvc_dict["stages"][instance.name]["meta"].get(self.name)

        zntrack_dict = json.loads(instance.state.fs.read_text(config.files.zntrack))
        # Only this field's entry is re-serialised and decoded with ZnDecoder,
        # so entries of other nodes/fields are never run through the decoder.
        return json.loads(
            json.dumps(zntrack_dict[instance.name][self.name]), cls=znjson.ZnDecoder
        )

    def get_stage_add_argument(self, instance) -> typing.List[tuple]:
        """Get the dvc command for this field; it contributes no arguments."""
        return []


class Environment(Field):
    """Environment variables to export.

    Values are stored in ``env.yaml`` under
    ``stages -> <node name> -> <field name>``.
    """

    dvc_option: str = None
    group = FieldGroup.PARAMETER

    def __init__(self, *args, is_parameter: bool = False, **kwargs):
        """Initialize the field.

        Parameters
        ----------
        is_parameter : bool, default=False
            If True, the ``env.yaml`` entry of this field is added to the
            stage via ``--params`` (see ``get_stage_add_argument``).
        """
        self.is_parameter = is_parameter
        super().__init__(*args, **kwargs)

    def get_files(self, instance) -> list:
        """There are no affected files."""
        return []

    def save(self, instance):
        """Save the field to disk.

        The value is merged into ``env.yaml`` in the current working
        directory; entries of other nodes and fields are kept. Nothing is
        written if the value is ``None``.

        Raises
        ------
        ValueError
            If the value is neither a string, a dict nor ``None``.
        """
        file = pathlib.Path("env.yaml")
        try:
            # An empty or comment-only file is parsed as None, not as a dict.
            context = yaml.safe_load(file.read_text()) or {}
        except FileNotFoundError:
            context = {}

        # A bare ``stages:`` key is parsed as None as well.
        stages = context.get("stages") or {}
        # TODO: when to reset the environment variables?

        value = getattr(instance, self.name)
        if value is None:
            return
        if not isinstance(value, (str, dict)):
            raise ValueError(
                f"Environment value must be a string or dict, not {type(value)}"
            )

        node_context = stages.get(instance.name) or {}
        node_context[self.name] = value
        stages[instance.name] = node_context

        context["stages"] = stages
        file.write_text(yaml.safe_dump(context))

    def get_data(self, instance: "Node") -> typing.Any:
        """Get the value of the field from the file.

        Returns
        -------
        typing.Any
            The stored value, or ``None`` if ``env.yaml`` has no entry for
            this node and field (including an empty file).
        """
        # Each level may be missing or parsed as None; fall back to an empty
        # mapping so the lookup yields None instead of raising AttributeError.
        env_dict = yaml.safe_load(instance.state.fs.read_text("env.yaml")) or {}
        stages = env_dict.get("stages") or {}
        node_context = stages.get(instance.name) or {}
        return node_context.get(self.name)

    def get_stage_add_argument(self, instance) -> typing.List[tuple]:
        """Get the dvc command for this field.

        Returns
        -------
        list of tuple
            A single ``--params`` argument pointing at this field's entry in
            ``env.yaml`` if ``is_parameter`` is set, otherwise an empty list.
        """
        if self.is_parameter:
            return [("--params", f"env.yaml:stages.{instance.name}.{self.name}")]
        return []