{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Node classification with Relational Graph Convolutional Network (RGCN)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/node-classification/rgcn-node-classification.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/node-classification/rgcn-node-classification.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This example demonstrates how to use an RGCN [1] on the AIFB dataset with StellarGraph.\n",
    "\n",
    "[1] Modeling Relational Data with Graph Convolutional Networks. Michael Schlichtkrull, Thomas N. Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov, Max Welling (2017). https://arxiv.org/pdf/1703.06103.pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "First we load the required libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdflib.extras.external_graph_libs import *\n",
    "from rdflib import Graph, URIRef, Literal\n",
    "\n",
    "import networkx as nx\n",
    "from networkx.classes.function import info\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import RelationalFullBatchNodeGenerator\n",
    "from stellargraph.layer import RGCN\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.models import Model\n",
    "\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from collections import Counter\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "## Loading the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The AIFB dataset describes the AIFB research institute in terms of its staff, research group, and publications. First used for machine learning with RDF in Bloehdorn, Stephan and Sure, York, \"Kernel Methods for Mining Instance Data in Ontologies\", The Semantic Web (2008), http://dx.doi.org/10.1007/978-3-540-76298-0_5. It contains ~8k entities, ~29k edges, and 45 different relationships or edge types. In (Bloehdorn et al 2007) the dataset was first used to predict the affiliation (i.e., research group) for people in the dataset. The dataset contains 178 members of a research group with 5 different research groups. The goal is to predict which research group a researcher belongs to."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.AIFB()\n",
    "display(HTML(dataset.description))\n",
    "G, affiliation = dataset.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarDiGraph: Directed multigraph\n",
      " Nodes: 8285, Edges: 29043\n",
      "\n",
      " Node types:\n",
      "  default: [8285]\n",
      "    Features: float32 vector, length 8285\n",
      "    Edge types: default-http://swrc.ontoware.org/ontology#abstract->default, default-http://swrc.ontoware.org/ontology#address->default, default-http://swrc.ontoware.org/ontology#author->default, default-http://swrc.ontoware.org/ontology#booktitle->default, default-http://swrc.ontoware.org/ontology#carriedOutBy->default, ... (40 more)\n",
      "\n",
      " Edge types:\n",
      "    default-http://swrc.ontoware.org/ontology#publication->default: [4163]\n",
      "    default-http://www.w3.org/1999/02/22-rdf-syntax-ns#type->default: [4124]\n",
      "    default-http://swrc.ontoware.org/ontology#author->default: [3986]\n",
      "    default-http://swrc.ontoware.org/ontology#isAbout->default: [2477]\n",
      "    default-http://swrc.ontoware.org/ontology#name->default: [1302]\n",
      "    default-http://swrc.ontoware.org/ontology#year->default: [1227]\n",
      "    default-http://swrc.ontoware.org/ontology#title->default: [1227]\n",
      "    default-http://swrc.ontoware.org/ontology#publishes->default: [1217]\n",
      "    default-http://swrc.ontoware.org/ontology#projectInfo->default: [952]\n",
      "    default-http://swrc.ontoware.org/ontology#hasProject->default: [952]\n",
      "    default-http://swrc.ontoware.org/ontology#booktitle->default: [765]\n",
      "    default-http://swrc.ontoware.org/ontology#month->default: [759]\n",
      "    default-http://swrc.ontoware.org/ontology#isWorkedOnBy->default: [571]\n",
      "    default-http://swrc.ontoware.org/ontology#pages->default: [548]\n",
      "    default-http://swrc.ontoware.org/ontology#abstract->default: [534]\n",
      "    default-http://swrc.ontoware.org/ontology#dealtWithIn->default: [357]\n",
      "    default-http://swrc.ontoware.org/ontology#member->default: [339]\n",
      "    default-http://swrc.ontoware.org/ontology#volume->default: [311]\n",
      "    default-http://swrc.ontoware.org/ontology#series->default: [298]\n",
      "    default-http://swrc.ontoware.org/ontology#homepage->default: [239]\n",
      "    ... (25 more)\n"
     ]
    }
   ],
   "source": [
    "print(G.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "The 'affiliation' relationship records which research group a researcher is affiliated with, e.g. (researcher, research group, affiliation). It is used to create the one-hot labels in the `affiliation` DataFrame. Edges of this type, and of its inverse relationship 'employs', are not included in `G`. The idea here is to test whether we can recover a 'missing' relationship."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "## Input preparation\n",
    "\n",
    "The nodes don't natively have features, so each node has been given a one-hot indicator vector as its features, which allows the model to learn from the graph structure alone. Only the people with known affiliations have labels, so we split the `affiliation` DataFrame into training and test sets."
   ]
  },
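  {
   "cell_type": "markdown",
   "id": "12a",
   "metadata": {},
   "source": [
    "As a minimal sketch of why one-hot indicator features let the model learn from structure alone (the 4-node toy size and the variable names below are assumptions for illustration, not part of the AIFB data): multiplying a one-hot feature matrix by a layer's weight matrix simply selects one weight row per node, so each node effectively gets its own learnable embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# toy graph with 4 nodes (hypothetical; the AIFB graph has 8285)\n",
    "num_nodes = 4\n",
    "one_hot_features = np.eye(num_nodes, dtype=np.float32)\n",
    "\n",
    "# stand-in for a first-layer weight matrix mapping features to 3 hidden units\n",
    "hidden_dim = 3\n",
    "weights = np.arange(num_nodes * hidden_dim, dtype=np.float32).reshape(\n",
    "    num_nodes, hidden_dim\n",
    ")\n",
    "\n",
    "# a one-hot row times the weights picks out the matching weight row,\n",
    "# so the product of all one-hot rows with the weights is the weights themselves\n",
    "assert np.array_equal(one_hot_features @ weights, weights)"
   ]
  },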
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_targets, test_targets = model_selection.train_test_split(\n",
    "    affiliation, train_size=0.8, test_size=None\n",
    ")"
   ]
  },
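  {
   "cell_type": "markdown",
   "id": "13a",
   "metadata": {},
   "source": [
    "The split above is random and unseeded, so results differ between runs. As a hedged sketch on made-up toy labels (the `toy_labels` frame and its group names are hypothetical), `train_test_split` can be made reproducible with `random_state` and can keep the class proportions similar in both splits with `stratify`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn import model_selection\n",
    "\n",
    "# hypothetical one-hot labels: 10 people in 2 made-up groups\n",
    "groups = pd.Series([\"group_a\"] * 5 + [\"group_b\"] * 5, name=\"group\")\n",
    "toy_labels = pd.get_dummies(groups).astype(int)\n",
    "\n",
    "toy_train, toy_test = model_selection.train_test_split(\n",
    "    toy_labels,\n",
    "    train_size=0.8,\n",
    "    # stratify on the group name recovered from the one-hot columns\n",
    "    stratify=toy_labels.idxmax(axis=\"columns\"),\n",
    "    # a fixed seed makes the split repeatable\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# per-group counts in each split\n",
    "print(toy_train.sum(), toy_test.sum(), sep=\"\\n\")"
   ]
  },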
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = RelationalFullBatchNodeGenerator(G, sparse=True)\n",
    "\n",
    "train_gen = generator.flow(train_targets.index, targets=train_targets)\n",
    "test_gen = generator.flow(test_targets.index, targets=test_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "## RGCN model creation and training\n",
    "\n",
    "We use StellarGraph to create an RGCN object, which builds a stack of relational graph convolutional layers. We then add a softmax layer to transform the features created by RGCN into class predictions, create a Keras model, and train it with the StellarGraph generators.\n",
    "\n",
    "Each RGCN layer creates a weight matrix for each relationship in the graph. If `num_bases==0`, these weight matrices are completely independent. If `num_bases!=0`, each weight matrix is a different linear combination of the same `num_bases` basis matrices. This introduces parameter sharing and reduces the number of parameters in the model. See the paper for more details."
   ]
  },
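  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "The basis decomposition can be sketched in NumPy as follows. This illustrates the idea from the paper and is not StellarGraph's implementation: the toy layer sizes and all variable names are assumptions, and the parameter counts ignore any self-loop weights and biases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "num_relationships, num_bases = 45, 20  # values used in this notebook\n",
    "in_dim, out_dim = 16, 32  # toy sizes, chosen only for illustration\n",
    "\n",
    "# shared basis matrices and per-relationship mixing coefficients\n",
    "bases = rng.normal(size=(num_bases, in_dim, out_dim))\n",
    "coefficients = rng.normal(size=(num_relationships, num_bases))\n",
    "\n",
    "# each relationship's weight matrix is a linear combination of the bases\n",
    "relationship_weights = np.einsum(\"rb,bio->rio\", coefficients, bases)\n",
    "assert relationship_weights.shape == (num_relationships, in_dim, out_dim)\n",
    "\n",
    "# parameter counts for one layer, without and with basis sharing\n",
    "independent = num_relationships * in_dim * out_dim\n",
    "shared = num_bases * in_dim * out_dim + num_relationships * num_bases\n",
    "print(f\"independent: {independent}, with bases: {shared}\")"
   ]
  },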
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "rgcn = RGCN(\n",
    "    layer_sizes=[32, 32],\n",
    "    activations=[\"relu\", \"relu\"],\n",
    "    generator=generator,\n",
    "    bias=True,\n",
    "    num_bases=20,\n",
    "    dropout=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_in, x_out = rgcn.in_out_tensors()\n",
    "predictions = Dense(train_targets.shape[-1], activation=\"softmax\")(x_out)\n",
    "model = Model(inputs=x_in, outputs=predictions)\n",
    "model.compile(\n",
    "    loss=\"categorical_crossentropy\",\n",
    "    optimizer=keras.optimizers.Adam(0.01),\n",
    "    metrics=[\"acc\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "1/1 [==============================] - 27s 27s/step - loss: 1.6109 - acc: 0.2746 - val_loss: 1.5623 - val_acc: 0.3611\n",
      "Epoch 2/20\n",
      "1/1 [==============================] - 23s 23s/step - loss: 1.5564 - acc: 0.5000 - val_loss: 1.4438 - val_acc: 0.4167\n",
      "Epoch 3/20\n",
      "1/1 [==============================] - 22s 22s/step - loss: 1.4328 - acc: 0.5070 - val_loss: 1.2094 - val_acc: 0.5000\n",
      "Epoch 4/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 1.2018 - acc: 0.5141 - val_loss: 0.9568 - val_acc: 0.6389\n",
      "Epoch 5/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.8872 - acc: 0.7606 - val_loss: 0.7373 - val_acc: 0.6944\n",
      "Epoch 6/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.7686 - acc: 0.8099 - val_loss: 0.5692 - val_acc: 0.7778\n",
      "Epoch 7/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.6025 - acc: 0.8662 - val_loss: 0.4802 - val_acc: 0.8889\n",
      "Epoch 8/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.4335 - acc: 0.8944 - val_loss: 0.4364 - val_acc: 0.9444\n",
      "Epoch 9/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.3616 - acc: 0.9437 - val_loss: 0.4061 - val_acc: 0.9444\n",
      "Epoch 10/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.3286 - acc: 0.9437 - val_loss: 0.3821 - val_acc: 0.9444\n",
      "Epoch 11/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.3106 - acc: 0.9507 - val_loss: 0.3619 - val_acc: 0.9444\n",
      "Epoch 12/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.2678 - acc: 0.9437 - val_loss: 0.3498 - val_acc: 0.9167\n",
      "Epoch 13/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.2236 - acc: 0.9507 - val_loss: 0.3463 - val_acc: 0.9167\n",
      "Epoch 14/20\n",
      "1/1 [==============================] - 21s 21s/step - loss: 0.2434 - acc: 0.9296 - val_loss: 0.3552 - val_acc: 0.9167\n",
      "Epoch 15/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.2236 - acc: 0.9296 - val_loss: 0.3680 - val_acc: 0.9167\n",
      "Epoch 16/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.1783 - acc: 0.9437 - val_loss: 0.3912 - val_acc: 0.9167\n",
      "Epoch 17/20\n",
      "1/1 [==============================] - 20s 20s/step - loss: 0.1887 - acc: 0.9437 - val_loss: 0.4214 - val_acc: 0.9167\n",
      "Epoch 18/20\n",
      "1/1 [==============================] - 19s 19s/step - loss: 0.1636 - acc: 0.9437 - val_loss: 0.4550 - val_acc: 0.9167\n",
      "Epoch 19/20\n",
      "1/1 [==============================] - 18s 18s/step - loss: 0.1699 - acc: 0.9437 - val_loss: 0.4450 - val_acc: 0.9167\n",
      "Epoch 20/20\n",
      "1/1 [==============================] - 18s 18s/step - loss: 0.1848 - acc: 0.9437 - val_loss: 0.4342 - val_acc: 0.9167\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(train_gen, validation_data=test_gen, epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
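  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Now we assess the accuracy of our trained model on the test set. In the run recorded here it reaches about 92% accuracy. Note that the same test set was also passed as `validation_data` during training, and the train/test split is not seeded, so the exact figures will vary between runs."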
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Metrics:\n",
      "\tloss: 0.4342\n",
      "\tacc: 0.9167\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_gen)\n",
    "print(\"\\nTest Set Metrics:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "## Node embeddings\n",
    "\n",
    "We take the node embeddings to be the activations of the last graph convolution layer in the RGCN layer stack and visualise them, coloring nodes by their true research group. We expect to see clear clusters of researchers in the node embedding space, with researchers from the same group belonging to the same cluster.\n",
    "\n",
    "To calculate the node embeddings rather than the class predictions, we create a new model with the same inputs as before, `x_in`, but with the embeddings `x_out` as the output rather than the predicted class. Note that the weights trained previously are kept in the new model."
   ]
  },
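  {
   "cell_type": "markdown",
   "id": "22a",
   "metadata": {},
   "source": [
    "The weight sharing between the two models can be seen in a small, self-contained Keras sketch. The toy dense layers below are a stand-in for the RGCN stack, and every `toy_` name is hypothetical: two `Model` objects built from the same tensors hold the very same layer objects, so the second model reuses whatever the first one has learned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.models import Model\n",
    "\n",
    "# toy stand-in for the RGCN stack: one hidden layer, then a softmax classifier\n",
    "toy_in = keras.Input(shape=(4,))\n",
    "toy_hidden = Dense(3, activation=\"relu\", name=\"hidden\")(toy_in)\n",
    "toy_predictions = Dense(2, activation=\"softmax\")(toy_hidden)\n",
    "\n",
    "toy_classifier = Model(inputs=toy_in, outputs=toy_predictions)\n",
    "toy_embedder = Model(inputs=toy_in, outputs=toy_hidden)\n",
    "\n",
    "# both models contain the same layer object, so its weights are shared\n",
    "assert toy_classifier.get_layer(\"hidden\") is toy_embedder.get_layer(\"hidden\")\n",
    "\n",
    "# the embedding model outputs the hidden activations instead of class scores\n",
    "print(toy_embedder.predict(np.ones((1, 4), dtype=\"float32\")).shape)"
   ]
  },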
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# get embeddings for all people nodes\n",
    "all_gen = generator.flow(affiliation.index, targets=affiliation)\n",
    "embedding_model = Model(inputs=x_in, outputs=x_out)\n",
    "emb = embedding_model.predict(all_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the full-batch prediction has a leading batch dimension of size 1; drop it\n",
    "X = emb.squeeze(0)\n",
    "# recover each person's group name from the one-hot labels\n",
    "y = affiliation.idxmax(axis=\"columns\").astype(\"category\")\n",
    "\n",
    "if X.shape[1] > 2:\n",
    "    # project the embeddings down to 2 dimensions for plotting\n",
    "    transform = TSNE\n",
    "\n",
    "    trans = transform(n_components=2)\n",
    "    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=affiliation.index)\n",
    "    emb_transformed[\"label\"] = y\n",
    "else:\n",
    "    # the embeddings are already at most 2-dimensional, so use them directly\n",
    "    emb_transformed = pd.DataFrame(X, index=affiliation.index)\n",
    "    emb_transformed = emb_transformed.rename(columns={\"0\": 0, \"1\": 1})\n",
    "    emb_transformed[\"label\"] = y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x504 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "alpha = 0.7\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.scatter(\n",
    "    emb_transformed[0],\n",
    "    emb_transformed[1],\n",
    "    c=emb_transformed[\"label\"].cat.codes,\n",
    "    cmap=\"jet\",\n",
    "    alpha=alpha,\n",
    ")\n",
    "ax.set(aspect=\"equal\", xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
    "plt.title(\n",
    "    \"{} visualization of RGCN embeddings for AIFB dataset\".format(transform.__name__)\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {},
   "source": [
    "Aside from a slight overlap, the classes are well separated despite the embeddings being projected down to only two dimensions. This indicates that our model is performing well at clustering the researchers into the right groups."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}