{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ac2ff6f8-c61e-4dc7-9e9f-d0ba103540ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import zntrack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c7c9d484-81b6-4269-b506-c25fd5efe471",
   "metadata": {},
   "outputs": [],
   "source": [
    "zntrack.config.nb_name = \"06_named_nodes.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d883b41b-61f2-448f-b935-52dec6baaa87",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from zntrack.utils import cwd_temp_dir\n",
    "\n",
    "temp_dir = cwd_temp_dir()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "02b1bed1-bf12-47bb-8d60-2b40575cdb63",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialized empty Git repository in /tmp/tmpeeo_o0hl/.git/\n",
      "Initialized DVC repository.\n",
      "\n",
      "You can now commit the changes to git.\n",
      "\n",
      "+---------------------------------------------------------------------+\n",
      "|                                                                     |\n",
      "|        DVC has enabled anonymous aggregate usage analytics.         |\n",
      "|     Read the analytics documentation (and how to opt-out) here:     |\n",
      "|             <https://dvc.org/doc/user-guide/analytics>              |\n",
      "|                                                                     |\n",
      "+---------------------------------------------------------------------+\n",
      "\n",
      "What's next?\n",
      "------------\n",
      "- Check out the documentation: <https://dvc.org/doc>\n",
      "- Get help and share ideas: <https://dvc.org/chat>\n",
      "- Star us on GitHub: <https://github.com/iterative/dvc>\n"
     ]
    }
   ],
   "source": [
    "!git init\n",
    "\n",
    "!dvc init"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d20eda47-4757-4de5-b920-1db19ee0ef37",
   "metadata": {},
   "source": [
    "# Named Nodes\n",
    "Named Nodes allow us to use the same Node class multiple times in a single graph, e.g. at different steps of a workflow. To do so, we pass a `name` argument to the `__init__` of our Node. Each Node then becomes its own stage in the graph; a Node without a `name` uses its class name as the stage name.\n",
    "\n",
    "<blockquote>Note that this is one of the very few scenarios in which we pass an argument directly to the <code>__init__</code>.</blockquote>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6a704798-b67a-4863-b96c-60cdeb65d559",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HelloWorld(zntrack.Node):\n",
    "    \"\"\"Copy the `inputs` parameter to `outputs`.\"\"\"\n",
    "\n",
    "    inputs = zntrack.zn.params()  # tracked as a parameter\n",
    "    outputs = zntrack.zn.outs()  # tracked as an output\n",
    "\n",
    "    def run(self):\n",
    "        self.outputs = self.inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2a04b805-bbda-49ca-97ef-efb761f67165",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running DVC command: 'stage add --name HelloWorld --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating 'dvc.yaml'\n",
      "Adding stage 'HelloWorld' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml nodes/HelloWorld/.gitignore\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
      "Submit issues to https://github.com/zincware/ZnTrack.\n",
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n",
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n",
      "Running DVC command: 'stage add --name Test01 --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding stage 'Test01' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add nodes/Test01/.gitignore dvc.yaml\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n",
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n",
      "Running DVC command: 'stage add --name Test02 --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding stage 'Test02' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml nodes/Test02/.gitignore\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running stage 'Test02':\n",
      "> zntrack run src.HelloWorld.HelloWorld --name Test02\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating lock file 'dvc.lock'\n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "Running stage 'Test01':\n",
      "> zntrack run src.HelloWorld.HelloWorld --name Test01\n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "Running stage 'HelloWorld':\n",
      "> zntrack run src.HelloWorld.HelloWorld --name HelloWorld\n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.lock\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n",
      "Use `dvc push` to send your updates to remote storage.\n"
     ]
    }
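   ],
   "source": [
    "with zntrack.Project() as project:\n",
    "    # Without `name`, the stage is called \"HelloWorld\" after the class.\n",
    "    hw1 = HelloWorld(inputs=3)\n",
    "    # With `name`, each instance becomes its own stage.\n",
    "    hw2 = HelloWorld(inputs=17, name=\"Test01\")\n",
    "    hw3 = HelloWorld(inputs=42, name=\"Test02\")\n",
    "project.run()"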
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2b99fa5a-bc72-4834-8f75-98d79be58e4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+ \n",
      "| HelloWorld | \n",
      "+------------+ \n",
      "+--------+ \n",
      "| Test01 | \n",
      "+--------+ \n",
      "+--------+ \n",
      "| Test02 | \n",
      "+--------+ \n"
     ]
    }
   ],
   "source": [
    "!dvc dag"
   ]
  },
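  {
   "cell_type": "markdown",
   "id": "a8cbb47f-8a5b-4f56-9408-4b365e487ba8",
   "metadata": {},
   "source": [
    "We can now also build a Node that depends on several instances of the same Node class. The `FindMaximum` Node below receives a list of Nodes through `deps` and compares the `outputs` of each."
   ]
  },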
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dc84f265-bffb-4e32-abf3-de7483a5d51d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FindMaximum(zntrack.Node):\n",
    "    deps = zntrack.zn.deps()  # list of Nodes this Node depends on\n",
    "    maximum = zntrack.zn.outs()\n",
    "\n",
    "    def run(self):\n",
    "        # Start from None rather than 0 so that negative values are handled as well.\n",
    "        self.maximum = None\n",
    "        for node in self.deps:\n",
    "            if self.maximum is None or node.outputs > self.maximum:\n",
    "                self.maximum = node.outputs\n",
    "                print(f\"New maximum found {node.outputs}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7612adfd-b00b-4026-a586-def7bf37f91f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running DVC command: 'stage add --name HelloWorld --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Modifying stage 'HelloWorld' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n",
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n",
      "Running DVC command: 'stage add --name Test01 --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Modifying stage 'Test01' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n",
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n",
      "Running DVC command: 'stage add --name Test02 --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Modifying stage 'Test02' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n",
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n",
      "Running DVC command: 'stage add --name FindMaximum --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding stage 'FindMaximum' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml nodes/FindMaximum/.gitignore\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook 06_named_nodes.ipynb to script\n",
      "/data/fzills/miniconda3/envs/zntrack/lib/python3.10/site-packages/nbformat/__init__.py:93: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
      "  validate(nb)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stage 'HelloWorld' didn't change, skipping\n",
      "Stage 'Test01' didn't change, skipping\n",
      "Stage 'Test02' didn't change, skipping\n",
      "Running stage 'FindMaximum':\n",
      "> zntrack run src.FindMaximum.FindMaximum --name FindMaximum\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Writing 1821 bytes to 06_named_nodes.py\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New maximum found 3.\n",
      "New maximum found 17.\n",
      "New maximum found 42.\n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.lock\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n",
      "Use `dvc push` to send your updates to remote storage.\n"
     ]
    }
   ],
   "source": [
    "with project:\n",
    "    max_node = FindMaximum(deps=[hw1, hw2, hw3])\n",
    "project.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "73562b4c-8240-4e2d-807f-b328aa315d6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+          +--------+          +--------+ \n",
      "| HelloWorld |          | Test01 |          | Test02 | \n",
      "+------------+**        +--------+       ***+--------+ \n",
      "                ***          *        ***              \n",
      "                   ****     *     ****                 \n",
      "                       **   *   **                     \n",
      "                    +-------------+                    \n",
      "                    | FindMaximum |                    \n",
      "                    +-------------+                    \n"
     ]
    }
   ],
   "source": [
    "!dvc dag"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97067bfd-a085-4125-8adf-29e544c7f18a",
   "metadata": {},
   "source": [
    "Using this combined Node, we can, for example, find the maximum of the values produced by the three `HelloWorld` Nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a1455d7b-d53f-4d36-a539-cb128e316e95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42\n"
     ]
    }
   ],
   "source": [
    "# Load the results written by `project.run()` into the Node instance.\n",
    "max_node.load()\n",
    "print(max_node.maximum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New maximum found 3.\n",
      "New maximum found 17.\n",
      "New maximum found 42.\n"
     ]
    }
   ],
   "source": [
    "# Running it manually to highlight the print statements\n",
    "FindMaximum.from_rev().run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7b59179c-0b15-45a6-97b1-b2d6cc1ebce3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_dir.cleanup()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}