zntrack/utils/config.py
"""Configuration for ZnTrack: global settings and important file paths."""
import contextlib
import dataclasses
import logging
import sys
import typing
from pathlib import Path
@dataclasses.dataclass(frozen=True)
class Files:
    """File paths that ZnTrack relies on.

    Notes
    -----
    The dataclass is frozen on purpose: changing these values is not tested.
    """

    zntrack: Path = dataclasses.field(default=Path("zntrack.json"))
    params: Path = dataclasses.field(default=Path("params.yaml"))
    dvc: Path = dataclasses.field(default=Path("dvc.yaml"))
@dataclasses.dataclass
class Config:
    """Collection of Node configurations.

    Attributes
    ----------
    nb_name: str, optional
        Name of the Jupyter notebook, if the Nodes are defined in a notebook.
    nb_class_path: Path, default = Path("src")
        The path where Jupyter notebooks should write the generated *.py files.
    lazy: bool, default = True
        Use lazy loading for Node.load(). This means that all ZnTrackOptions are
        only loaded from files when they are first accessed.
    allow_empty_loading: bool, default = False
        Allow "Node.load(lazy=False)" even if nothing can be loaded, e.g.
        zntrack.json / params.yaml does not exist or does not contain data
        for the respective Node.
    log_level: int, default = logging.INFO
        The log level of the "zntrack" logger. Assigning a new value calls
        ``logging.getLogger("zntrack").setLevel`` with it; the default value is
        only stored on this object and is not pushed to the logger by this class.
        Use logging.WARNING for less output, or logging.DEBUG for in-depth
        information when investigating more complex issues.
    interpreter: str|Path, default = Path(sys.executable).name
        Set the Python interpreter to be used for the 'dvc cmd'.
        Defaults to the file name of the interpreter running this process.
        If None, ZnTrack will try to automatically determine the interpreter.
        Use e.g. `config.interpreter=sys.executable` to use a specific version.
        Note that changing the command will also affect your graph, and you might
        not be able to use the existing cache.
    dvc_api: bool, default = True
        Use the `dvc.cli.main` function instead of subprocess.
    disable_operating_directory: bool, default = False
        Global config to disable the operating directory context manager.
    files: Files
        Important file paths, see `Files`.
    """

    nb_name: typing.Optional[str] = None
    nb_class_path: Path = Path("src")
    lazy: bool = True
    allow_empty_loading: bool = False
    interpreter: typing.Union[str, Path] = Path(sys.executable).name
    dvc_api: bool = True
    disable_operating_directory: bool = False
    files: Files = dataclasses.field(default_factory=Files)
    # Backing storage for the ``log_level`` property; not an __init__ argument.
    _log_level: int = dataclasses.field(default=logging.INFO, init=False, repr=True)

    @property
    def log_level(self) -> int:
        """Get the log level."""
        return self._log_level

    @log_level.setter
    def log_level(self, value: int) -> None:
        """Store the log level and apply it to the "zntrack" logger."""
        self._log_level = value
        logger = logging.getLogger("zntrack")
        logger.setLevel(self._log_level)

    @contextlib.contextmanager
    def updated_config(self, **kwargs) -> typing.Iterator[None]:
        """Temporarily update the config.

        Parameters
        ----------
        **kwargs
            Names of attributes of this config mapped to their temporary values.

        Yields
        ------
        None
            Environment with temporarily changed config. Every given attribute
            is restored to its previous value on exit, also when an exception
            is raised inside the ``with`` block.

        Raises
        ------
        AttributeError
            If a key is not an existing attribute. In that case no attribute
            is modified.
        """
        # Snapshot all current values before mutating anything, so an unknown
        # key fails fast instead of leaving earlier keys permanently changed.
        state = {key: getattr(self, key) for key in kwargs}
        try:
            for key, value in kwargs.items():
                setattr(self, key, value)
            yield
        finally:
            for key, value in state.items():
                setattr(self, key, value)
# Shared, module-level configuration instance.
config: Config = Config()
class _DisableTmpPathMeta(type):
    """Metaclass giving the ``DISABLE_TMP_PATH`` sentinel a readable representation.

    ``DISABLE_TMP_PATH`` cannot be instantiated, so it is used as the class object
    itself. ``repr()``/``str()`` of a class are looked up on its metaclass, which
    is why the representation has to be defined here rather than in the class body.
    """

    def __repr__(cls) -> str:
        """Provide better representation."""
        return cls.__name__

    def __str__(cls) -> str:
        """Provide better representation."""
        return cls.__name__


class DISABLE_TMP_PATH(metaclass=_DisableTmpPathMeta):
    """Identifier for disabling loading data into a temporary directory.

    Use the class object itself as the marker value; instantiating it raises.
    """

    def __init__(self) -> None:
        """Prohibit instantiation."""
        raise NotImplementedError("This class can not be instantiated.")