examples/docs/08_custom_zntrackoptions.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
},
"tags": []
},
"source": [
"# Custom ZnTrackOptions\n",
"\n",
"ZnTrack allows you to create a custom ZnTrackOption, similar to `zn.outs`, by subclassing `zntrack.Field`.\n",
"ZnTrack tries to handle some standard types automatically within the `zn.outs` option, but it can be useful to write a custom option for data that should be stored in its own file format.\n",
"In the following example we use the [Atomic Simulation Environment](https://wiki.fysik.dtu.dk/ase/index.html) (ASE) to store objects in, and load them from, a custom database file."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from zntrack import config\n",
"\n",
"# ZnTrack can work with Nodes that are defined inside a Jupyter notebook.\n",
"# To enable this, set the `nb_name` config to the file name of this notebook:\n",
"config.nb_name = \"08_custom_zntrackoptions.ipynb\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"nbsphinx": "hidden",
"tags": []
},
"outputs": [],
"source": [
"from zntrack.utils import cwd_temp_dir\n",
"\n",
"temp_dir = cwd_temp_dir()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"nbsphinx": "hidden",
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized empty Git repository in /tmp/tmp5twr4kp_/.git/\n",
"Initialized DVC repository.\n",
"\n",
"You can now commit the changes to git.\n",
"\n",
"+---------------------------------------------------------------------+\n",
"| |\n",
"| DVC has enabled anonymous aggregate usage analytics. |\n",
"| Read the analytics documentation (and how to opt-out) here: |\n",
"| <https://dvc.org/doc/user-guide/analytics> |\n",
"| |\n",
"+---------------------------------------------------------------------+\n",
"\n",
"What's next?\n",
"------------\n",
"- Check out the documentation: <https://dvc.org/doc>\n",
"- Get help and share ideas: <https://dvc.org/chat>\n",
"- Star us on GitHub: <https://github.com/iterative/dvc>\n"
]
}
],
"source": [
"!git init\n",
"\n",
"!dvc init"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
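"We will subclass `zntrack.Field` to build our new custom option. The subclass defines where the data is stored (`get_files`), how it is written to disk (`save`) and how it is read back (`get_data`)."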
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import zntrack\n",
"import ase.db\n",
"import ase.io\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"class Atoms(zntrack.Field):\n",
"    # the file is registered with DVC as an output (`dvc stage add --outs`)\n",
" dvc_option = \"outs\"\n",
" group = zntrack.FieldGroup.RESULT # you can choose from RESULT or PARAMETER\n",
"\n",
" def get_files(self, instance) -> list:\n",
" \"\"\"Define the filename that is passed to dvc (used if tracked=True)\"\"\"\n",
" # self.name is the name of the class attribute we use for this database\n",
" return [instance.nwd / f\"{self.name}.db\"]\n",
"\n",
" def save(self, instance):\n",
" \"\"\"Save the values to file\"\"\"\n",
" # we gather the actual values using __get__\n",
" atoms = getattr(instance, self.name)\n",
" # get the file name\n",
" file = self.get_files(instance)[0]\n",
" # save the data to the file\n",
" with ase.db.connect(file) as db:\n",
" for atom in tqdm.tqdm(atoms, ncols=70, desc=f\"Writing atoms to {file}\"):\n",
" db.write(atom, group=instance.name)\n",
"\n",
" def get_data(self, instance):\n",
" \"\"\"Load data with ase.db.connect from file\"\"\"\n",
" # get the file name\n",
" file = self.get_files(instance)[0]\n",
" # load the data\n",
" atoms = []\n",
" with ase.db.connect(file) as db:\n",
" for row in tqdm.tqdm(\n",
" db.select(), ncols=70, desc=f\"Loading atoms from {file}\"\n",
" ):\n",
" atoms.append(row.toatoms())\n",
" # return the data so it can be saved in __dict__\n",
" return atoms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have defined our custom ZnTrackOption, we can use it as a class attribute of a Node as follows."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class AtomsClass(zntrack.Node):\n",
" atoms = Atoms()\n",
"\n",
" def run(self):\n",
" self.atoms = [ase.Atoms(\"N2\", positions=[[0, 0, -1], [0, 0, 1]])]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Running DVC command: 'stage add --name AtomsClass --force ...'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating 'dvc.yaml'\n",
"Adding stage 'AtomsClass' in 'dvc.yaml'\n",
"\n",
"To track the changes with git, run:\n",
"\n",
"\tgit add dvc.yaml nodes/AtomsClass/.gitignore\n",
"\n",
"To enable auto staging, run:\n",
"\n",
"\tdvc config core.autostage true\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"Submit issues to https://github.com/zincware/ZnTrack.\n",
"[NbConvertApp] Converting notebook 08_custom_zntrackoptions.ipynb to script\n",
"[NbConvertApp] Writing 2881 bytes to 08_custom_zntrackoptions.py\n"
]
}
],
"source": [
"with zntrack.Project() as project:\n",
" node = AtomsClass()\n",
"project.run(repro=False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running stage 'AtomsClass': \n",
"> zntrack run src.AtomsClass.AtomsClass --name AtomsClass\n",
"Loading atoms from nodes/AtomsClass/atoms.db: 0it [00:00, ?it/s]\n",
"Writing atoms to nodes/AtomsClass/atoms.db: 100%|█| 1/1 [00:00<00:00, \n",
"Generating lock file 'dvc.lock' \n",
"Updating lock file 'dvc.lock'\n",
"\n",
"To track the changes with git, run:\n",
"\n",
"\tgit add dvc.lock\n",
"\n",
"To enable auto staging, run:\n",
"\n",
"\tdvc config core.autostage true\n",
"Use `dvc push` to send your updates to remote storage.\n"
]
}
],
"source": [
"!dvc repro"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 1151.02it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Atoms(symbols='N2', pbc=False)]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 3231.36it/s]\n"
]
},
{
"data": {
"text/plain": [
"[Atoms(symbols='N2', pbc=False)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"node.load()\n",
"print(node.atoms)\n",
"# or\n",
"AtomsClass.from_rev().atoms"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"temp_dir.cleanup()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}