{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Cluster GCN on Neo4j"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This example demonstrates how to run Cluster GCN on a dataset stored entirely on disk with Neo4j. Our Neo4j Cluster GCN implementation iterates through user-specified graph clusters and only ever stores the edges and features of one cluster in memory at any given time. This enables Cluster GCN to be used on extremely large datasets that don't fit into memory.\n",
    "\n",
    "We use Cora here as an example; see [this notebook](./load-cora-into-neo4j.ipynb) for instructions on how to load the Cora dataset into Neo4j."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import stellargraph as sg\n",
    "from stellargraph.connector.neo4j import Neo4jStellarGraph\n",
    "from stellargraph.layer import GCN\n",
    "from stellargraph.mapper import ClusterNodeGenerator\n",
    "import tensorflow as tf\n",
    "import py2neo\n",
    "import os\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sps\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Connect to Neo4j\n",
    "\n",
    "First we connect to the Neo4j database with `py2neo`, then we create a `Neo4jStellarGraph` object from that connection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_host = os.environ.get(\"STELLARGRAPH_NEO4J_HOST\")\n",
    "\n",
    "# Create the Neo4j Graph database object; port, user and password parameters can be added to specify location and authentication\n",
    "graph = py2neo.Graph(host=default_host)\n",
    "neo4j_sg = Neo4jStellarGraph(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## Data labels\n",
    "\n",
    "Here we get the node labels. Cluster GCN is semi-supervised and requires labels to be specified for some nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the node labels from a separate file\n",
    "# note this function also returns a StellarGraph\n",
    "# which we won't be using for this demo - we only need Neo4jStellarGraph!\n",
    "_, labels = sg.datasets.Cora().load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split the node labels into train/val/test: 140 training nodes, 500 validation nodes and the rest for testing\n",
    "\n",
    "train_labels, test_labels = model_selection.train_test_split(\n",
    "    labels, train_size=140, test_size=None, stratify=labels\n",
    ")\n",
    "val_labels, test_labels = model_selection.train_test_split(\n",
    "    test_labels, train_size=500, test_size=None, stratify=test_labels\n",
    ")\n",
    "\n",
    "target_encoding = preprocessing.LabelBinarizer()\n",
    "\n",
    "train_targets = target_encoding.fit_transform(train_labels)\n",
    "val_targets = target_encoding.transform(val_labels)\n",
    "test_targets = target_encoding.transform(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "## Neo4j Clustering\n",
    "\n",
    "We use one of the Neo4j Graph Data Science Library's community detection algorithms to split our graph into clusters for Cluster GCN.\n",
    "\n",
    "We can use:\n",
    "\n",
    "- `louvain`: https://neo4j.com/docs/graph-data-science/current/algorithms/louvain/\n",
    "- `labelPropagation`: https://neo4j.com/docs/graph-data-science/current/algorithms/label-propagation/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters = neo4j_sg.clusters(method=\"louvain\")"
   ]
  },
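  {
   "cell_type": "markdown",
   "id": "12a",
   "metadata": {},
   "source": [
    "Optionally, we can summarize the clusters before building the generator. This sketch assumes `clusters` is a list of lists of node IDs, which is the form `ClusterNodeGenerator` accepts for its `clusters` argument. Assuming the default of one cluster per batch, the largest cluster gives a rough idea of how much of the graph is held in memory at any one time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: summarize cluster sizes (assumes `clusters` is a list of lists of node IDs)\n",
    "cluster_sizes = pd.Series([len(cluster) for cluster in clusters])\n",
    "\n",
    "print(\"number of clusters:\", len(cluster_sizes))\n",
    "print(\"largest cluster:\", cluster_sizes.max())\n",
    "print(\"single-node clusters:\", (cluster_sizes == 1).sum())\n",
    "print(cluster_sizes.describe())"
   ]
  },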
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "## Keras sequences\n",
    "\n",
    "We now use `StellarGraph` to create Keras sequences for training, testing, and validation. Under the hood, these sequences connect to your Neo4j database and lazily load data for each cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of clusters 173\n",
      "0 cluster has size 118\n",
      "1 cluster has size 413\n",
      "2 cluster has size 167\n",
      "3 cluster has size 38\n",
      "4 cluster has size 67\n",
      "5 cluster has size 226\n",
      "6 cluster has size 18\n",
      "7 cluster has size 68\n",
      "8 cluster has size 26\n",
      "9 cluster has size 60\n",
      "10 cluster has size 83\n",
      "11 cluster has size 118\n",
      "12 cluster has size 73\n",
      "13 cluster has size 3\n",
      "14 cluster has size 22\n",
      "15 cluster has size 36\n",
      "16 cluster has size 237\n",
      "17 cluster has size 12\n",
      "18 cluster has size 1\n",
      "19 cluster has size 128\n",
      "20 cluster has size 41\n",
      "21 cluster has size 11\n",
      "22 cluster has size 60\n",
      "23 cluster has size 5\n",
      "24 cluster has size 3\n",
      "25 cluster has size 100\n",
      "26 cluster has size 27\n",
      "27 cluster has size 1\n",
      "28 cluster has size 39\n",
      "29 cluster has size 2\n",
      "30 cluster has size 18\n",
      "31 cluster has size 11\n",
      "32 cluster has size 3\n",
      "33 cluster has size 5\n",
      "34 cluster has size 8\n",
      "35 cluster has size 14\n",
      "36 cluster has size 3\n",
      "37 cluster has size 2\n",
      "38 cluster has size 51\n",
      "39 cluster has size 2\n",
      "40 cluster has size 5\n",
      "41 cluster has size 4\n",
      "42 cluster has size 10\n",
      "43 cluster has size 5\n",
      "44 cluster has size 15\n",
      "45 cluster has size 5\n",
      "46 cluster has size 1\n",
      "47 cluster has size 4\n",
      "48 cluster has size 3\n",
      "49 cluster has size 6\n",
      "50 cluster has size 1\n",
      "51 cluster has size 3\n",
      "52 cluster has size 6\n",
      "53 cluster has size 8\n",
      "54 cluster has size 2\n",
      "55 cluster has size 2\n",
      "56 cluster has size 1\n",
      "57 cluster has size 13\n",
      "58 cluster has size 1\n",
      "59 cluster has size 6\n",
      "60 cluster has size 1\n",
      "61 cluster has size 10\n",
      "62 cluster has size 10\n",
      "63 cluster has size 5\n",
      "64 cluster has size 2\n",
      "65 cluster has size 4\n",
      "66 cluster has size 12\n",
      "67 cluster has size 3\n",
      "68 cluster has size 2\n",
      "69 cluster has size 5\n",
      "70 cluster has size 1\n",
      "71 cluster has size 2\n",
      "72 cluster has size 1\n",
      "73 cluster has size 2\n",
      "74 cluster has size 2\n",
      "75 cluster has size 2\n",
      "76 cluster has size 2\n",
      "77 cluster has size 1\n",
      "78 cluster has size 2\n",
      "79 cluster has size 2\n",
      "80 cluster has size 9\n",
      "81 cluster has size 1\n",
      "82 cluster has size 2\n",
      "83 cluster has size 2\n",
      "84 cluster has size 1\n",
      "85 cluster has size 2\n",
      "86 cluster has size 2\n",
      "87 cluster has size 2\n",
      "88 cluster has size 1\n",
      "89 cluster has size 1\n",
      "90 cluster has size 2\n",
      "91 cluster has size 2\n",
      "92 cluster has size 1\n",
      "93 cluster has size 1\n",
      "94 cluster has size 2\n",
      "95 cluster has size 1\n",
      "96 cluster has size 2\n",
      "97 cluster has size 2\n",
      "98 cluster has size 2\n",
      "99 cluster has size 2\n",
      "100 cluster has size 3\n",
      "101 cluster has size 5\n",
      "102 cluster has size 5\n",
      "103 cluster has size 4\n",
      "104 cluster has size 1\n",
      "105 cluster has size 13\n",
      "106 cluster has size 1\n",
      "107 cluster has size 3\n",
      "108 cluster has size 3\n",
      "109 cluster has size 2\n",
      "110 cluster has size 2\n",
      "111 cluster has size 3\n",
      "112 cluster has size 2\n",
      "113 cluster has size 1\n",
      "114 cluster has size 1\n",
      "115 cluster has size 1\n",
      "116 cluster has size 2\n",
      "117 cluster has size 3\n",
      "118 cluster has size 2\n",
      "119 cluster has size 2\n",
      "120 cluster has size 2\n",
      "121 cluster has size 2\n",
      "122 cluster has size 12\n",
      "123 cluster has size 1\n",
      "124 cluster has size 4\n",
      "125 cluster has size 1\n",
      "126 cluster has size 2\n",
      "127 cluster has size 2\n",
      "128 cluster has size 1\n",
      "129 cluster has size 2\n",
      "130 cluster has size 1\n",
      "131 cluster has size 2\n",
      "132 cluster has size 8\n",
      "133 cluster has size 2\n",
      "134 cluster has size 2\n",
      "135 cluster has size 2\n",
      "136 cluster has size 1\n",
      "137 cluster has size 2\n",
      "138 cluster has size 1\n",
      "139 cluster has size 2\n",
      "140 cluster has size 3\n",
      "141 cluster has size 1\n",
      "142 cluster has size 2\n",
      "143 cluster has size 4\n",
      "144 cluster has size 2\n",
      "145 cluster has size 2\n",
      "146 cluster has size 2\n",
      "147 cluster has size 2\n",
      "148 cluster has size 2\n",
      "149 cluster has size 2\n",
      "150 cluster has size 2\n",
      "151 cluster has size 2\n",
      "152 cluster has size 1\n",
      "153 cluster has size 4\n",
      "154 cluster has size 1\n",
      "155 cluster has size 3\n",
      "156 cluster has size 2\n",
      "157 cluster has size 2\n",
      "158 cluster has size 1\n",
      "159 cluster has size 1\n",
      "160 cluster has size 3\n",
      "161 cluster has size 2\n",
      "162 cluster has size 1\n",
      "163 cluster has size 2\n",
      "164 cluster has size 2\n",
      "165 cluster has size 2\n",
      "166 cluster has size 2\n",
      "167 cluster has size 1\n",
      "168 cluster has size 2\n",
      "169 cluster has size 2\n",
      "170 cluster has size 1\n",
      "171 cluster has size 2\n",
      "172 cluster has size 2\n"
     ]
    }
   ],
   "source": [
    "# create the ClusterNodeGenerator from the Neo4jStellarGraph\n",
    "# and the Keras sequence objects for each split\n",
    "\n",
    "generator = ClusterNodeGenerator(neo4j_sg, clusters=clusters)\n",
    "\n",
    "train_gen = generator.flow(train_labels.index, targets=train_targets)\n",
    "val_gen = generator.flow(val_labels.index, targets=val_targets)\n",
    "test_gen = generator.flow(test_labels.index, targets=test_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "## Create and train your model!\n",
    "\n",
    "Now we create and train the Cluster GCN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the model\n",
    "cluster_gcn = GCN(\n",
    "    layer_sizes=[32, 32], generator=generator, activations=[\"relu\", \"relu\"], dropout=0.5,\n",
    ")\n",
    "\n",
    "x_in, x_out = cluster_gcn.in_out_tensors()\n",
    "predictions = tf.keras.layers.Dense(units=val_targets.shape[1], activation=\"softmax\")(\n",
    "    x_out\n",
    ")\n",
    "model = tf.keras.Model(x_in, predictions)\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),\n",
    "    loss=tf.keras.losses.categorical_crossentropy,\n",
    "    metrics=[\"acc\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n",
      "Train for 173 steps, validate for 173 steps\n",
      "Epoch 1/10\n",
      "173/173 [==============================] - 9s 50ms/step - loss: 0.4474 - acc: 0.2357 - val_loss: 0.8697 - val_acc: 0.3040\n",
      "Epoch 2/10\n",
      "173/173 [==============================] - 7s 41ms/step - loss: 0.3968 - acc: 0.3786 - val_loss: 0.6322 - val_acc: 0.6740\n",
      "Epoch 3/10\n",
      "173/173 [==============================] - 7s 42ms/step - loss: 0.2307 - acc: 0.5500 - val_loss: 0.7893 - val_acc: 0.4980\n",
      "Epoch 4/10\n",
      "173/173 [==============================] - 7s 41ms/step - loss: 0.1732 - acc: 0.6214 - val_loss: 0.7399 - val_acc: 0.5860\n",
      "Epoch 5/10\n",
      "173/173 [==============================] - 7s 42ms/step - loss: 0.1245 - acc: 0.8000 - val_loss: 0.8939 - val_acc: 0.6680\n",
      "Epoch 6/10\n",
      "173/173 [==============================] - 7s 41ms/step - loss: 0.1719 - acc: 0.8500 - val_loss: 1.3818 - val_acc: 0.5820\n",
      "Epoch 7/10\n",
      "173/173 [==============================] - 7s 41ms/step - loss: 0.1494 - acc: 0.8357 - val_loss: 0.9496 - val_acc: 0.7220\n",
      "Epoch 8/10\n",
      "173/173 [==============================] - 7s 41ms/step - loss: 0.1257 - acc: 0.8929 - val_loss: 0.7883 - val_acc: 0.7200\n",
      "Epoch 9/10\n",
      "173/173 [==============================] - 7s 40ms/step - loss: 0.1114 - acc: 0.9286 - val_loss: 0.6892 - val_acc: 0.7140\n",
      "Epoch 10/10\n",
      "173/173 [==============================] - 7s 42ms/step - loss: 0.1298 - acc: 0.8929 - val_loss: 0.7618 - val_acc: 0.7320\n"
     ]
    }
   ],
   "source": [
    "# train the model!\n",
    "history = model.fit(train_gen, validation_data=val_gen, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "173/173 [==============================] - 4s 21ms/step - loss: 1.4353 - acc: 0.7689\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1.4352797645499302, 0.7688588]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate the model\n",
    "model.evaluate(test_gen)"
   ]
  },
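  {
   "cell_type": "markdown",
   "id": "18a",
   "metadata": {},
   "source": [
    "`model.evaluate` returns a plain list of values, in the order given by `model.metrics_names`. Pairing the two makes the result easier to read. Note that this optional cell runs the evaluation again, so the test clusters are streamed from Neo4j a second time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pair each test metric value with its name (here: loss and acc)\n",
    "test_metrics = model.evaluate(test_gen, verbose=0)\n",
    "for name, value in zip(model.metrics_names, test_metrics):\n",
    "    print(f\"test {name}: {value:.4f}\")"
   ]
  },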
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "And that's it! We've trained a graph neural network without ever loading the whole graph into memory.\n",
    "\n",
    "Please refer to [cluster-gcn-node-classification](./../../node-classification/cluster-gcn-node-classification.ipynb) for **node embedding visualization**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}