zntrack/fields/dvc/options.py
"""Base classes for 'zntrack.<field>_path' fields."""
import json
import pathlib
import typing
import znflow.utils
import znjson
from zntrack.fields.field import Field, FieldGroup, PlotsMixin
from zntrack.utils import DISABLE_TMP_PATH, get_nwd, node_wd
if typing.TYPE_CHECKING:
from zntrack import Node
class _LoadIntoTmpPath(znflow.utils.IterableHandler):
def default(self, value, **kwargs):
instance = kwargs["instance"]
path = value
if instance.state.fs.isdir(pathlib.Path(path).as_posix()):
instance.state.fs.get(
pathlib.Path(path).as_posix(),
instance.state.tmp_path.as_posix(),
recursive=True,
)
_path = instance.state.tmp_path / pathlib.Path(path).name
else:
temp_file = instance.state.tmp_path / pathlib.Path(path).name
instance.state.fs.get(pathlib.Path(path).as_posix(), temp_file.as_posix())
_path = temp_file
if isinstance(path, pathlib.PurePath):
return _path
else:
return _path.as_posix()
class DVCOption(Field):
"""A field that is used as a dvc option.
The DVCOption field is designed for paths only.
"""
group = FieldGroup.PARAMETER
def __init__(self, *args, **kwargs):
"""Create a DVCOption field."""
if node_wd.nwd in args or node_wd.nwd in kwargs.values():
raise ValueError(
"Can not set `zntrack.nwd` as value for {self}. Please use"
" `zntrack.nwd/...` to create a path relative to the node working"
" directory."
)
self.dvc_option = kwargs.pop("dvc_option")
super().__init__(*args, **kwargs)
def get_files(self, instance: "Node") -> list:
"""Get the files affected by this field.
Parameters
----------
instance : Node
The node instance to get the files for.
Returns
-------
list of str
A list of file paths affected by this field.
"""
value = getattr(instance, self.name)
if not isinstance(value, list):
value = [value]
return [pathlib.Path(file).as_posix() for file in value if file is not None]
def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
"""Get the dvc command for this field.
Parameters
----------
instance : Node
The node instance to get the command for.
Returns
-------
list of tuple of str
A list of command-line arguments to use when adding
this field to the DVC stage.
"""
if self.dvc_option == "params":
return [
(f"--{self.dvc_option}", f"{file}:") for file in self.get_files(instance)
]
else:
return [(f"--{self.dvc_option}", file) for file in self.get_files(instance)]
def get_data(self, instance: "Node") -> any:
"""Get the value of the field from the configuration file.
Parameters
----------
instance : Node
The Node instance to get the field value for.
decoder : Any, optional
The decoder to use when parsing the configuration file, by default None.
Returns
-------
any
The value of the field from the configuration file.
"""
zntrack_dict = json.loads(
instance.state.fs.read_text("zntrack.json"),
)
return json.loads(
json.dumps(zntrack_dict[instance.name][self.name]), cls=znjson.ZnDecoder
)
def save(self, instance: "Node"):
"""Save the field to config file.
Parameters
----------
instance : Node
The node instance to save the field for.
"""
try:
value = instance.__dict__[self.name]
except KeyError:
try:
# default value is not stored in __dict__
# TODO: not sure if I like this
value = getattr(instance, self.name)
except AttributeError:
return
self._write_value_to_config(value, instance, encoder=znjson.ZnEncoder)
def __get__(self, instance: "Node", owner=None):
"""Add replacement of the nwd to the get method.
Parameters
----------
instance : Node
The node instance to get the value for.
owner : type, optional
The owner class of the descriptor, by default None
Returns
-------
Any
The value of the attribute.
"""
if instance is None:
return self
value = super().__get__(instance, owner)
path = node_wd.ReplaceNWD()(value, nwd=get_nwd(instance))
if instance.state.tmp_path not in [None, DISABLE_TMP_PATH]:
loader = _LoadIntoTmpPath()
return loader(path, instance=instance)
else:
return path
class PlotsOption(PlotsMixin, DVCOption):
"""Field with DVC plots kwargs."""