zincware/ZnTrack

View on GitHub
examples/docs/08_custom_zntrackoptions.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "# Custom ZnTrackOptions\n",
    "\n",
    "ZnTrack allows you to create a custom ZnTrackOption similar to `zn.outs`.\n",
    "ZnTrack tries to handle some standard types automatically within the `zn.outs` option, but it can be useful to write custom ones.\n",
    "In the following example we use [Atomic Simulation Environment](https://wiki.fysik.dtu.dk/ase/index.html) to store / load objects to a custom datafile."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from zntrack import config\n",
    "\n",
    "# When using ZnTrack we can write our code inside a Jupyter notebook.\n",
    "# We can make use of this functionality by setting the `nb_name` config as follows:\n",
    "config.nb_name = \"08_custom_zntrackoptions.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from zntrack.utils import cwd_temp_dir\n",
    "\n",
    "temp_dir = cwd_temp_dir()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialized empty Git repository in /tmp/tmp5twr4kp_/.git/\n",
      "Initialized DVC repository.\n",
      "\n",
      "You can now commit the changes to git.\n",
      "\n",
      "+---------------------------------------------------------------------+\n",
      "|                                                                     |\n",
      "|        DVC has enabled anonymous aggregate usage analytics.         |\n",
      "|     Read the analytics documentation (and how to opt-out) here:     |\n",
      "|             <https://dvc.org/doc/user-guide/analytics>              |\n",
      "|                                                                     |\n",
      "+---------------------------------------------------------------------+\n",
      "\n",
      "What's next?\n",
      "------------\n",
      "- Check out the documentation: <https://dvc.org/doc>\n",
      "- Get help and share ideas: <https://dvc.org/chat>\n",
      "- Star us on GitHub: <https://github.com/iterative/dvc>\n"
     ]
    }
   ],
   "source": [
    "!git init\n",
    "\n",
    "!dvc init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the `ZnTrackOption` to build our new custom options."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import zntrack\n",
    "import ase.db\n",
    "import ase.io\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class Atoms(zntrack.Field):\n",
    "    # we will save the file as dvc run --outs\n",
    "    dvc_option = \"outs\"\n",
    "    group = zntrack.FieldGroup.RESULT  # you can choose from RESULT or PARAMETER\n",
    "\n",
    "    def get_files(self, instance) -> list:\n",
    "        \"\"\"Define the filename that is passed to dvc (used if tracked=True)\"\"\"\n",
    "        # self.name is the name of the class attribute we use for this database\n",
    "        return [instance.nwd / f\"{self.name}.db\"]\n",
    "\n",
    "    def save(self, instance):\n",
    "        \"\"\"Save the values to file\"\"\"\n",
    "        # we gather the actual values using __get__\n",
    "        atoms = getattr(instance, self.name)\n",
    "        # get the file name\n",
    "        file = self.get_files(instance)[0]\n",
    "        # save the data to the file\n",
    "        with ase.db.connect(file) as db:\n",
    "            for atom in tqdm.tqdm(atoms, ncols=70, desc=f\"Writing atoms to {file}\"):\n",
    "                db.write(atom, group=instance.name)\n",
    "\n",
    "    def get_data(self, instance):\n",
    "        \"\"\"Load data with ase.db.connect from file\"\"\"\n",
    "        # get the file name\n",
    "        file = self.get_files(instance)[0]\n",
    "        # load the data\n",
    "        atoms = []\n",
    "        with ase.db.connect(file) as db:\n",
    "            for row in tqdm.tqdm(\n",
    "                db.select(), ncols=70, desc=f\"Loading atoms from {file}\"\n",
    "            ):\n",
    "                atoms.append(row.toatoms())\n",
    "        # return the data so it can be saved in __dict__\n",
    "        return atoms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have defined our custom ZnTrackOption we can use it as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AtomsClass(zntrack.Node):\n",
    "    atoms = Atoms()\n",
    "\n",
    "    def run(self):\n",
    "        self.atoms = [ase.Atoms(\"N2\", positions=[[0, 0, -1], [0, 0, 1]])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running DVC command: 'stage add --name AtomsClass --force ...'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating 'dvc.yaml'\n",
      "Adding stage 'AtomsClass' in 'dvc.yaml'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.yaml nodes/AtomsClass/.gitignore\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
      "Submit issues to https://github.com/zincware/ZnTrack.\n",
      "[NbConvertApp] Converting notebook 08_custom_zntrackoptions.ipynb to script\n",
      "[NbConvertApp] Writing 2881 bytes to 08_custom_zntrackoptions.py\n"
     ]
    }
   ],
   "source": [
    "with zntrack.Project() as project:\n",
    "    node = AtomsClass()\n",
    "project.run(repro=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running stage 'AtomsClass':                                           \n",
      "> zntrack run src.AtomsClass.AtomsClass --name AtomsClass\n",
      "Loading atoms from nodes/AtomsClass/atoms.db: 0it [00:00, ?it/s]\n",
      "Writing atoms to nodes/AtomsClass/atoms.db: 100%|█| 1/1 [00:00<00:00, \n",
      "Generating lock file 'dvc.lock'                                                 \n",
      "Updating lock file 'dvc.lock'\n",
      "\n",
      "To track the changes with git, run:\n",
      "\n",
      "\tgit add dvc.lock\n",
      "\n",
      "To enable auto staging, run:\n",
      "\n",
      "\tdvc config core.autostage true\n",
      "Use `dvc push` to send your updates to remote storage.\n"
     ]
    }
   ],
   "source": [
    "!dvc repro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 1151.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Atoms(symbols='N2', pbc=False)]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 3231.36it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Atoms(symbols='N2', pbc=False)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node.load()\n",
    "print(node.atoms)\n",
    "# or\n",
    "AtomsClass.from_rev().atoms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "temp_dir.cleanup()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}