zincware/ZnTrack

View on GitHub
examples/dataclasses_for_parameters.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ZnTrack Parameters with dataclasses\n",
    "\n",
    "To structure the parameters used in a Node, it can be useful to pass them as a dataclass. The following notebook illustrates this with a small example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "from zntrack import config\n",
    "import znjson\n",
    "import dataclasses\n",
    "\n",
    "# NOTE(review): presumably tells ZnTrack which notebook file to convert; the\n",
    "# recorded nbconvert log further down names this same file - confirm.\n",
    "config.nb_name = \"dataclasses_for_parameters.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from zntrack.utils import cwd_temp_dir\n",
    "\n",
    "# NOTE(review): presumably moves the working directory into a fresh temporary\n",
    "# directory (the recorded `git init` output shows a /tmp/... path) - confirm.\n",
    "temp_dir = cwd_temp_dir()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialized empty Git repository in /tmp/tmpv_5nwynt/.git/\n",
      "Initialized DVC repository.\n",
      "\n",
      "You can now commit the changes to git.\n",
      "\n",
      "+---------------------------------------------------------------------+\n",
      "|                                                                     |\n",
      "|        DVC has enabled anonymous aggregate usage analytics.         |\n",
      "|     Read the analytics documentation (and how to opt-out) here:     |\n",
      "|             <https://dvc.org/doc/user-guide/analytics>              |\n",
      "|                                                                     |\n",
      "+---------------------------------------------------------------------+\n",
      "\n",
      "What's next?\n",
      "------------\n",
      "- Check out the documentation: <https://dvc.org/doc>\n",
      "- Get help and share ideas: <https://dvc.org/chat>\n",
      "- Star us on GitHub: <https://github.com/iterative/dvc>\n"
     ]
    }
   ],
   "source": [
    "# Initialize an empty Git repository, then a DVC repository, in the current directory.\n",
    "!git init\n",
    "\n",
    "!dvc init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from zntrack import Node, zn, Project\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class Parameter:\n",
    "    \"\"\"Arguments that ``ComputeRandomNumber.run`` forwards to ``random.randrange``.\"\"\"\n",
    "\n",
    "    start: int  # first argument given to randrange\n",
    "    stop: int  # second argument given to randrange\n",
    "    step: int = 1  # third argument given to randrange\n",
    "\n",
    "\n",
    "class ParameterConverter(znjson.ConverterBase):\n",
    "    \"\"\"znjson converter that stores ``Parameter`` instances as plain dictionaries.\"\"\"\n",
    "\n",
    "    # type handled by this converter\n",
    "    instance = Parameter\n",
    "    # label written next to the encoded value (shows up as `_type` in params.yaml)\n",
    "    representation = \"Parameter\"\n",
    "\n",
    "    def encode(self, obj: Parameter) -> dict:\n",
    "        \"\"\"Return the dataclass fields of ``obj`` as a dictionary.\"\"\"\n",
    "        return dataclasses.asdict(obj)\n",
    "\n",
    "    def decode(self, value: dict) -> Parameter:\n",
    "        \"\"\"Rebuild a ``Parameter`` from the dictionary produced by ``encode``.\"\"\"\n",
    "        return Parameter(**value)\n",
    "\n",
    "\n",
    "class ComputeRandomNumber(Node):\n",
    "    \"\"\"Node that draws one random integer from the range described by ``param``.\"\"\"\n",
    "\n",
    "    param: Parameter = zn.params()\n",
    "    number = zn.outs()\n",
    "\n",
    "    # Converter registration\n",
    "    # ----------------------\n",
    "    # The call may live anywhere in the notebook; it only has to run\n",
    "    # before a node instance is created for the first time.\n",
    "    _ = znjson.config.register(ParameterConverter)\n",
    "\n",
    "    def __init__(self, param: Parameter = None, **kwargs):\n",
    "        \"\"\"Forward ``kwargs`` to ``Node`` and store ``param`` on the instance.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.param = param\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Pick a random integer via ``random.randrange`` and save it in ``number``.\"\"\"\n",
    "        bounds = self.param\n",
    "        self.number = random.randrange(bounds.start, bounds.stop, bounds.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running DVC command: 'stage add --name ComputeRandomNumber --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating 'dvc.yaml'\n",
      "Adding stage 'ComputeRandomNumber' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add nodes/ComputeRandomNumber/.gitignore dvc.yaml\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
      "Submit issues to https://github.com/zincware/ZnTrack.\n",
      "[NbConvertApp] Converting notebook dataclasses_for_parameters.ipynb to script\n",
      "[NbConvertApp] Writing 1907 bytes to dataclasses_for_parameters.py\n",
      "Running DVC command: 'repro'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running stage 'ComputeRandomNumber':\n",
      "> zntrack run src.ComputeRandomNumber.ComputeRandomNumber --name ComputeRandomNumber\n",
      "Generating lock file 'dvc.lock'\n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.lock\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n",
      "Use `dvc push` to send your updates to remote storage.\n"
     ]
    }
   ],
   "source": [
    "# Range to draw from; `step` keeps its default of 1.\n",
    "parameter = Parameter(start=100, stop=200)\n",
    "# The node is instantiated inside the Project context.\n",
    "with Project() as proj:\n",
    "    ComputeRandomNumber(param=parameter)\n",
    "# Per the recorded output, this adds the DVC stage and then runs `dvc repro`.\n",
    "proj.run(repro=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "134\n",
      "Parameter(start=100, stop=200, step=1)\n"
     ]
    }
   ],
   "source": [
    "# Load the node once and reuse it instead of calling `from_rev()` per attribute,\n",
    "# so both printed values come from the same loaded instance.\n",
    "loaded_node = ComputeRandomNumber.from_rev()\n",
    "print(loaded_node.number)\n",
    "print(loaded_node.param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The arguments of the dataclass are saved in the `params.yaml` file and can also be modified there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ComputeRandomNumber:\n",
      "    param:\n",
      "        _type: Parameter\n",
      "        value:\n",
      "            start: 100\n",
      "            step: 1\n",
      "            stop: 200\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the parameter file; the dataclass fields are stored under `value`.\n",
    "print(pathlib.Path(\"params.yaml\").read_text())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up the `temp_dir` object returned by `cwd_temp_dir()` earlier in the notebook.\n",
    "temp_dir.cleanup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}