examples/dataclasses_for_parameters.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# ZnTrack Parameters with dataclasses\n",
"\n",
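"To structure the parameters used in a Node, it can be useful to pass them as a dataclass. The following notebook illustrates a small example: a custom `znjson` converter makes the dataclass serializable, so it can be used as a `zn.params()` field."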
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"from zntrack import config\n",
"import znjson\n",
"import dataclasses\n",
"\n",
"config.nb_name = \"dataclasses_for_parameters.ipynb\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from zntrack.utils import cwd_temp_dir\n",
"\n",
"temp_dir = cwd_temp_dir()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized empty Git repository in /tmp/tmpv_5nwynt/.git/\n",
"Initialized DVC repository.\n",
"\n",
"You can now commit the changes to git.\n",
"\n",
"+---------------------------------------------------------------------+\n",
"| |\n",
"| DVC has enabled anonymous aggregate usage analytics. |\n",
"| Read the analytics documentation (and how to opt-out) here: |\n",
"| <https://dvc.org/doc/user-guide/analytics> |\n",
"| |\n",
"+---------------------------------------------------------------------+\n",
"\n",
"What's next?\n",
"------------\n",
"- Check out the documentation: <https://dvc.org/doc>\n",
"- Get help and share ideas: <https://dvc.org/chat>\n",
"- Star us on GitHub: <https://github.com/iterative/dvc>\n"
]
}
],
"source": [
"!git init\n",
"\n",
"!dvc init"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from zntrack import Node, zn, Project\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"@dataclasses.dataclass\n",
"class Parameter:\n",
"    \"\"\"Arguments for ``random.randrange`` bundled into a single node parameter.\n",
"\n",
"    ``ComputeRandomNumber`` draws its value from ``range(start, stop, step)``.\n",
"    \"\"\"\n",
"\n",
"    start: int  # first candidate value (inclusive)\n",
"    stop: int  # end of the range (exclusive)\n",
"    step: int = dataclasses.field(default=1)  # spacing between candidate values\n",
"\n",
"\n",
"class ParameterConverter(znjson.ConverterBase):\n",
"    \"\"\"znjson converter that (de)serializes ``Parameter`` through a plain dict.\n",
"\n",
"    Once registered, ZnTrack writes a ``Parameter`` to ``params.yaml`` tagged\n",
"    with ``_type: Parameter`` and with the dataclass fields stored under ``value``.\n",
"    \"\"\"\n",
"\n",
"    instance = Parameter  # objects of this class are handled by this converter\n",
"    representation = \"Parameter\"  # tag stored next to the encoded value\n",
"\n",
"    def encode(self, obj: Parameter) -> dict:\n",
"        \"\"\"Return the fields of ``obj`` as a dict (``dataclasses.asdict``).\"\"\"\n",
"        return dataclasses.asdict(obj)\n",
"\n",
"    def decode(self, value: dict) -> Parameter:\n",
"        \"\"\"Rebuild a ``Parameter`` from the dict produced by ``encode``.\"\"\"\n",
"        return Parameter(**value)\n",
"\n",
"\n",
"class ComputeRandomNumber(Node):\n",
"    \"\"\"Node that draws one random integer from the range described by a ``Parameter``.\"\"\"\n",
"\n",
"    param: Parameter = zn.params()  # stored in params.yaml via ParameterConverter\n",
"    number = zn.outs()  # the drawn integer\n",
"\n",
"    # Register the converter while the class body executes. It may be placed\n",
"    # anywhere, as long as it runs before the node is first instantiated.\n",
"    _ = znjson.config.register(ParameterConverter)\n",
"\n",
"    def __init__(self, param: Parameter = None, **kwargs):\n",
"        \"\"\"Store ``param``; all other keyword arguments are forwarded to ``Node``.\"\"\"\n",
"        super().__init__(**kwargs)\n",
"        self.param = param\n",
"\n",
"    def run(self):\n",
"        \"\"\"Pick a value from ``range(start, stop, step)`` and save it as the output.\"\"\"\n",
"        bounds = self.param\n",
"        self.number = random.randrange(bounds.start, bounds.stop, bounds.step)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Running DVC command: 'stage add --name ComputeRandomNumber --force ...'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating 'dvc.yaml'\n",
"Adding stage 'ComputeRandomNumber' in 'dvc.yaml'\n",
"\n",
"To track the changes with git, run:\n",
"\n",
"\tgit add nodes/ComputeRandomNumber/.gitignore dvc.yaml\n",
"\n",
"To enable auto staging, run:\n",
"\n",
"\tdvc config core.autostage true\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"Submit issues to https://github.com/zincware/ZnTrack.\n",
"[NbConvertApp] Converting notebook dataclasses_for_parameters.ipynb to script\n",
"[NbConvertApp] Writing 1907 bytes to dataclasses_for_parameters.py\n",
"Running DVC command: 'repro'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running stage 'ComputeRandomNumber':\n",
"> zntrack run src.ComputeRandomNumber.ComputeRandomNumber --name ComputeRandomNumber\n",
"Generating lock file 'dvc.lock'\n",
"Updating lock file 'dvc.lock'\n",
"\n",
"To track the changes with git, run:\n",
"\n",
"\tgit add dvc.lock\n",
"\n",
"To enable auto staging, run:\n",
"\n",
"\tdvc config core.autostage true\n",
"Use `dvc push` to send your updates to remote storage.\n"
]
}
],
"source": [
"# Describe the graph inside the Project context, then let DVC reproduce it.\n",
"parameter = Parameter(start=100, stop=200, step=1)\n",
"with Project() as proj:\n",
"    ComputeRandomNumber(param=parameter)\n",
"proj.run(repro=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"134\n",
"Parameter(start=100, stop=200, step=1)\n"
]
}
],
"source": [
"# Load the stored node once and reuse it instead of calling from_rev() twice.\n",
"node = ComputeRandomNumber.from_rev()\n",
"assert node.param == parameter  # the dataclass survives the params.yaml round-trip\n",
"assert parameter.start <= node.number < parameter.stop\n",
"print(node.number)\n",
"print(node.param)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The fields of the dataclass are saved in the `params.yaml` file, tagged with the converter's `representation` (`_type: Parameter`), and can also be modified there."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ComputeRandomNumber:\n",
" param:\n",
" _type: Parameter\n",
" value:\n",
" start: 100\n",
" step: 1\n",
" stop: 200\n",
"\n"
]
}
],
"source": [
"# print() instead of a bare expression so the YAML renders with real line breaks\n",
"params_yaml = pathlib.Path(\"params.yaml\").read_text()\n",
"print(params_yaml)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"temp_dir.cleanup()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}