stellargraph/stellargraph

View on GitHub
demos/ensembles/ensemble-node-classification-example.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Ensemble models for node classification\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/ensembles/ensemble-node-classification-example.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/ensembles/ensemble-node-classification-example.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This notebook demonstrates the use of `stellargraph`'s `Ensemble` class for node attribute inference using the Cora and Pubmed-Diabetes citation datasets.\n",
    "\n",
    "The `Ensemble` class brings ensemble learning to `stellargraph`'s graph neural network models, e.g., `GraphSAGE`, `GCN` and `GAT`, quantifying prediction variance and potentially improving prediction accuracy.\n",
    "\n",
    "**References**\n",
    "\n",
    "1. Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec. arXiv:1706.02216 [cs.SI], 2017.\n",
    "\n",
    "\n",
    "2. Semi-Supervised Classification with Graph Convolutional Networks. T. Kipf, M. Welling. ICLR 2017. arXiv:1609.02907\n",
    "\n",
    "\n",
    "3. Graph Attention Networks. P. Veličković et al. ICLR 2018"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import GraphSAGENodeGenerator, FullBatchNodeGenerator\n",
    "from stellargraph.layer import GraphSAGE, GCN, GAT\n",
    "from stellargraph import globalvar\n",
    "\n",
    "from stellargraph.ensemble import Ensemble, BaggingEnsemble\n",
    "\n",
    "from tensorflow.keras import layers, optimizers, losses, metrics, Model, models\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading the network data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "The PubMed Diabetes dataset consists of 19717 scientific publications from PubMed database pertaining to diabetes classified into one of three classes. The citation network consists of 44338 links. Each publication in the dataset is described by a TF/IDF weighted word vector from a dictionary which consists of 500 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(HTML(datasets.Cora().description))\n",
    "display(HTML(datasets.PubMedDiabetes().description))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "First, we select the dataset to use, either Cora or Pubmed-Diabetes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cora = True  # Select the dataset; if False, then Pubmed-Diabetes dataset is used.\n",
    "if use_cora:\n",
    "    dataset = datasets.Cora()\n",
    "else:\n",
    "    dataset = datasets.PubMedDiabetes()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "Load the graph data."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [],
   "source": [
    "G, labels = dataset.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "We aim to train a graph-ML model that will predict the \"subject\" or \"label\" attribute on the nodes depending on the selected dataset. These subjects are one of 7 or 3 categories for Cora and PubMed-Diabetes respectively:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Genetic_Algorithms', 'Neural_Networks', 'Case_Based', 'Rule_Learning', 'Probabilistic_Methods', 'Reinforcement_Learning', 'Theory'}\n"
     ]
    }
   ],
   "source": [
    "# Print the class names for the selected dataset\n",
    "print(set(labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "### Splitting the data"
   ]
  },
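  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "For machine learning we want to take a subset of the nodes for training, and use the rest for validation and testing. We'll use scikit-learn's `train_test_split` to do this: 20% of the nodes are used for training, 20% of the remaining nodes for validation, and the rest for testing (each split is stratified by label)."
   ]
  },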
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_labels, test_labels = model_selection.train_test_split(\n",
    "    labels, train_size=0.2, test_size=None, stratify=labels, random_state=42,  # 140\n",
    ")\n",
    "val_labels, test_labels = model_selection.train_test_split(\n",
    "    test_labels,\n",
    "    train_size=0.2,  # 500,\n",
    "    test_size=None,\n",
    "    stratify=test_labels,\n",
    "    random_state=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "### Converting to numeric arrays"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "For our categorical target, we will use one-hot vectors that will be fed into a soft-max Keras layer during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_encoding = preprocessing.LabelBinarizer()\n",
    "\n",
    "train_targets = target_encoding.fit_transform(train_labels)\n",
    "val_targets = target_encoding.transform(val_labels)\n",
    "test_targets = target_encoding.transform(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "### Specify global parameters\n",
    "\n",
    "Here we specify some parameters that control the type of model we are going to use: the base model type (GCN, GraphSAGE or GAT), the number of estimators in the ensemble, the number of predictions per estimator, the number of training epochs, and model-specific parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "22",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "model_type = \"graphsage\"  # Can be either gcn, gat, or graphsage\n",
    "use_bagging = (\n",
    "    True  # If True, each model in the ensemble is trained on a bootstrapped sample\n",
    ")\n",
    "# of the given training data; otherwise, the same training data are used\n",
    "# for training each model.\n",
    "\n",
    "if model_type == \"graphsage\":\n",
    "    # For GraphSAGE model\n",
    "    batch_size = 50\n",
    "    num_samples = [10, 10]\n",
    "    n_estimators = 5  # The number of estimators in the ensemble\n",
    "    n_predictions = 10  # The number of predictions per estimator per query point\n",
    "    epochs = 50  # The number of training epochs\n",
    "elif model_type == \"gcn\":\n",
    "    # For GCN model\n",
    "    n_estimators = 5  # The number of estimators in the ensemble\n",
    "    n_predictions = 10  # The number of predictions per estimator per query point\n",
    "    epochs = 50  # The number of training epochs\n",
    "elif model_type == \"gat\":\n",
    "    # For GAT model\n",
    "    layer_sizes = [8, train_targets.shape[1]]\n",
    "    attention_heads = 8\n",
    "    n_estimators = 5  # The number of estimators in the ensemble\n",
    "    n_predictions = 10  # The number of predictions per estimator per query point\n",
    "    epochs = 200  # The number of training epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "## Creating the base graph machine learning model in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "To feed data from the graph into the Keras model we need a generator. Generators are specialized to the model and the learning task: here we use `GraphSAGENodeGenerator` for GraphSAGE and `FullBatchNodeGenerator` for GCN and GAT."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "For training we use only the training nodes returned from our splitter and the target values. The `shuffle=True` argument is given to the `flow` method to improve training for those generators that support shuffling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_type == \"graphsage\":\n",
    "    generator = GraphSAGENodeGenerator(G, batch_size, num_samples)\n",
    "    train_gen = generator.flow(train_labels.index, train_targets, shuffle=True)\n",
    "elif model_type == \"gcn\":\n",
    "    generator = FullBatchNodeGenerator(G, method=\"gcn\")\n",
    "    train_gen = generator.flow(\n",
    "        train_labels.index, train_targets\n",
    "    )  # does not support shuffle\n",
    "elif model_type == \"gat\":\n",
    "    generator = FullBatchNodeGenerator(G, method=\"gat\")\n",
    "    train_gen = generator.flow(\n",
    "        train_labels.index, train_targets\n",
    "    )  # does not support shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "541"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_labels.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "Now we can specify our machine learning model. This needs a few more parameters, and these parameters are model-specific."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_type == \"graphsage\":\n",
    "    base_model = GraphSAGE(\n",
    "        layer_sizes=[16, 16], generator=generator, bias=True, dropout=0.5, normalize=\"l2\"\n",
    "    )\n",
    "    x_inp, x_out = base_model.in_out_tensors()\n",
    "    prediction = layers.Dense(units=train_targets.shape[1], activation=\"softmax\")(x_out)\n",
    "elif model_type == \"gcn\":\n",
    "    base_model = GCN(\n",
    "        layer_sizes=[32, train_targets.shape[1]],\n",
    "        generator=generator,\n",
    "        bias=True,\n",
    "        dropout=0.5,\n",
    "        activations=[\"elu\", \"softmax\"],\n",
    "    )\n",
    "    x_inp, x_out = base_model.in_out_tensors()\n",
    "    prediction = x_out\n",
    "elif model_type == \"gat\":\n",
    "    base_model = GAT(\n",
    "        layer_sizes=layer_sizes,\n",
    "        attn_heads=attention_heads,\n",
    "        generator=generator,\n",
    "        bias=True,\n",
    "        in_dropout=0.5,\n",
    "        attn_dropout=0.5,\n",
    "        activations=[\"elu\", \"softmax\"],\n",
    "    )\n",
    "    x_inp, prediction = base_model.in_out_tensors()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30",
   "metadata": {},
   "source": [
    "Let's have a look at the shape of the output tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([None, 7])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "### Create a Keras model and then an Ensemble"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "Now let's create the actual Keras model with the graph inputs `x_inp` provided by the `base_model` and outputs being the predictions from the softmax layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "34",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(inputs=x_inp, outputs=prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "Next, we create the ensemble model consisting of `n_estimators` models.\n",
    "\n",
    "We are also going to specify that we want to make `n_predictions` per query point per model. These predictions will differ because of the application of `dropout` and, in the case of ensembling GraphSAGE models, the sampling of node neighbourhoods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36",
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_bagging:\n",
    "    model = BaggingEnsemble(model, n_estimators=n_estimators, n_predictions=n_predictions)\n",
    "else:\n",
    "    model = Ensemble(model, n_estimators=n_estimators, n_predictions=n_predictions)\n",
    "\n",
    "model.compile(\n",
    "    optimizer=optimizers.Adam(lr=0.005),\n",
    "    loss=losses.categorical_crossentropy,\n",
    "    metrics=[\"acc\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<stellargraph.ensemble.BaggingEnsemble at 0x1419399d0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The model is of type stellargraph.ensemble.Ensemble but has\n",
    "# a very similar interface to a Keras model\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38",
   "metadata": {},
   "source": [
    "The ensemble has `n_estimators` models. Let's have a look at the first model's layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x141725490>,\n",
       " <tensorflow.python.keras.engine.input_layer.InputLayer at 0x141725590>,\n",
       " <tensorflow.python.keras.engine.input_layer.InputLayer at 0x1416fbf10>,\n",
       " <tensorflow.python.keras.layers.core.Reshape at 0x141725a50>,\n",
       " <tensorflow.python.keras.layers.core.Reshape at 0x141928090>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x1427b83d0>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x141725a90>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x1419c8910>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x142528c50>,\n",
       " <stellargraph.layer.graphsage.MeanAggregator at 0x1093d8f90>,\n",
       " <tensorflow.python.keras.layers.core.Reshape at 0x141928ad0>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x141976610>,\n",
       " <tensorflow.python.keras.layers.core.Dropout at 0x1418e4910>,\n",
       " <stellargraph.layer.graphsage.MeanAggregator at 0x10945b690>,\n",
       " <tensorflow.python.keras.layers.core.Reshape at 0x141970dd0>,\n",
       " <tensorflow.python.keras.layers.core.Lambda at 0x14172ac90>,\n",
       " <tensorflow.python.keras.layers.core.Dense at 0x14172add0>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.layers(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Train the model, keeping track of its loss and accuracy on the training set, and its performance on the validation set during the training (e.g., for early stopping), and generalization performance of the final model on a held-out test set (we need to create another generator over the test data for this)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "41",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_gen = generator.flow(val_labels.index, val_targets)\n",
    "test_gen = generator.flow(test_labels.index, test_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "Note that the time needed to train the ensemble grows linearly with `n_estimators`.\n",
    "\n",
    "Also, we are going to use early stopping by monitoring the accuracy on the validation data and stopping if the accuracy does not increase after 10 training epochs (this is the default grace value specified by the `Ensemble` class but we can set it to a different value by using `model.early_stopping_patience=20` for example.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n"
     ]
    }
   ],
   "source": [
    "if use_bagging:\n",
    "    # When using bootstrap samples to train each model in the ensemble, we must specify\n",
    "    # the IDs of the training nodes (train_data) and their corresponding target values\n",
    "    # (train_targets)\n",
    "    history = model.fit(\n",
    "        generator,\n",
    "        train_data=train_labels.index,\n",
    "        train_targets=train_targets,\n",
    "        epochs=epochs,\n",
    "        validation_data=val_gen,\n",
    "        verbose=0,\n",
    "        shuffle=False,\n",
    "        bag_size=None,\n",
    "        use_early_stopping=True,  # Enable early stopping\n",
    "        early_stopping_monitor=\"val_acc\",\n",
    "    )\n",
    "else:\n",
    "    history = model.fit(\n",
    "        train_gen,\n",
    "        epochs=epochs,\n",
    "        validation_data=val_gen,\n",
    "        verbose=0,\n",
    "        shuffle=False,\n",
    "        use_early_stopping=True,  # Enable early stopping\n",
    "        early_stopping_monitor=\"val_acc\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45",
   "metadata": {},
   "source": [
    "Now we have trained the model, let's evaluate it on the test set. Note that the `.evaluate()` method of the `Ensemble` class returns mean and standard deviation of each evaluation metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "  ['...']\n",
      "\n",
      "Test Set Metrics of the trained models:\n",
      "\tloss: 0.6872±0.0270\n",
      "\tacc: 0.8085±0.0109\n"
     ]
    }
   ],
   "source": [
    "test_metrics_mean, test_metrics_std = model.evaluate(test_gen)\n",
    "\n",
    "print(\"\\nTest Set Metrics of the trained models:\")\n",
    "for name, m, s in zip(model.metrics_names, test_metrics_mean, test_metrics_std):\n",
    "    print(\"\\t{}: {:0.4f}±{:0.4f}\".format(name, m, s))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47",
   "metadata": {},
   "source": [
    "### Making predictions with the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48",
   "metadata": {},
   "source": [
    "Now let's get the predictions for all nodes, using a new generator for all nodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "49",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes = labels.index\n",
    "all_gen = generator.flow(all_nodes)\n",
    "all_predictions = model.predict(generator=all_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "50",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10, 2708, 7)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_predictions.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51",
   "metadata": {},
   "source": [
    "For full-batch methods, the batch dimension is 1, so we remove any singleton dimensions with `np.squeeze`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "52",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_predictions = np.squeeze(all_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10, 2708, 7)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_predictions.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54",
   "metadata": {},
   "source": [
    "These predictions are the outputs of the softmax layer, with one column per class. To turn column indices back into the original categories, we use the `classes_` attribute of our target encoding (`target_encoding`), which lists the category names in column order."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55",
   "metadata": {},
   "source": [
    "For demonstration, we are going to select one of the nodes in the graph, and plot the ensemble's predictions for that node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "56",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_query_point = -1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57",
   "metadata": {},
   "source": [
    "The array `all_predictions` has dimensionality $M\\times K\\times N\\times F$ where $M$ is the number of estimators in the ensemble (`n_estimators`); $K$ is the number of predictions per query point per estimator (`n_predictions`); $N$ is the number of query points (`len(all_nodes)`); and $F$ is the output dimensionality, determined by the shape of the output layer (here, the number of classes).\n",
    "\n",
    "Since we are only interested in the predictions for a single query node, i.e., `selected_query_point`, we are going to slice the array to extract them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10, 7)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Select the predictions for the point specified by selected_query_point\n",
    "qp_predictions = all_predictions[:, :, selected_query_point, :]\n",
    "# The shape should be n_estimators x n_predictions x size_output_layer\n",
    "qp_predictions.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59",
   "metadata": {},
   "source": [
    "Next, to facilitate plotting the predictions using either a density plot or a box plot, we are going to reshape `qp_predictions` to $R\\times F$ where $R$ is equal to $M\\times K$ as above and $F$ is the output dimensionality of the output layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 7)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qp_predictions = qp_predictions.reshape(\n",
    "    np.product(qp_predictions.shape[0:-1]), qp_predictions.shape[-1]\n",
    ")\n",
    "qp_predictions.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 'Case_Based',\n",
       " 1: 'Genetic_Algorithms',\n",
       " 2: 'Neural_Networks',\n",
       " 3: 'Probabilistic_Methods',\n",
       " 4: 'Reinforcement_Learning',\n",
       " 5: 'Rule_Learning',\n",
       " 6: 'Theory'}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inv_subject_mapper = {k: v for k, v in enumerate(target_encoding.classes_)}\n",
    "inv_subject_mapper"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62",
   "metadata": {},
   "source": [
    "We'd like to assess the ensemble's confidence in its predictions in order to decide if we can trust them or not. Utilising density plots, we can visually inspect the ensemble's distribution of prediction probabilities for a node's label.\n",
    "\n",
    "This is better demonstrated if the ensemble's base model is `GraphSAGE`, because the predictions of the base model vary the most (when compared to GCN and GAT) due to the random sampling of node neighbours during prediction, in addition to the inherent stochasticity of the ensemble itself. (The density plot below is only drawn for GraphSAGE.)\n",
    "\n",
    "If the density plot for the predicted node label is well separated from those of the other labels with little overlap then we can be confident trusting the model's prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "if model_type not in [\"gcn\", \"gat\"]:\n",
    "    fig, ax = plt.subplots(figsize=(12, 6))\n",
    "    for i in range(qp_predictions.shape[1]):\n",
    "        sns.kdeplot(data=qp_predictions[:, i].reshape((-1,)), label=inv_subject_mapper[i])\n",
    "    plt.xlabel(\"Predicted Probability\")\n",
    "    plt.title(\"Density plots of predicted probabilities for each subject\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64",
   "metadata": {},
   "source": [
    "An alternative and possibly more informative view of the distribution of node predictions is a box plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'Subject')"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.boxplot(x=qp_predictions)\n",
    "ax.set_xticklabels(target_encoding.classes_)\n",
    "ax.tick_params(axis=\"x\", rotation=45)\n",
    "if model_type == \"graphsage\":\n",
    "    y = np.argmax(target_encoding.transform(labels), axis=1)\n",
    "elif model_type == \"gcn\" or model_type == \"gat\":\n",
    "    y = np.argmax(target_encoding.transform(labels.reindex(G.nodes())), axis=1)\n",
    "plt.title(f\"Correct {target_encoding.classes_[y[selected_query_point]]}\")\n",
    "plt.ylabel(\"Predicted Probability\")\n",
    "plt.xlabel(\"Subject\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66",
   "metadata": {},
   "source": [
    "The above example shows that the ensemble predicts the correct node label with high confidence, so we can trust its prediction.\n",
    "\n",
    "(Note that due to the stochastic nature of training neural network algorithms, the above conclusion may not be valid if you re-run the notebook; however, the general conclusion that ensemble learning can be used to quantify the model's uncertainty about its predictions still holds.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67",
   "metadata": {},
   "source": [
    "## Node embeddings\n",
    "\n",
    "Compute node embeddings as the output activations of one of the graph convolutional or aggregation layers in the ensemble model, and visualise them, coloring nodes by their subject label.\n",
    "\n",
    "You can find the index of the layer of interest by calling the `Ensemble` class's method `layers`, e.g., `model.layers()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "68",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_type == \"graphsage\":\n",
    "    # For GraphSAGE, we are going to use the output activations of the second GraphSAGE layer\n",
    "    # as the node embeddings\n",
    "    emb = model.predict(\n",
    "        generator=generator, predict_data=labels.index, output_layer=-4\n",
    "    )  # this selects the output layer\n",
    "elif model_type == \"gcn\" or model_type == \"gat\":\n",
    "    # For GCN and GAT, we are going to use the output activations of the first GCN or Graph\n",
    "    # Attention layer as the node embeddings\n",
    "    emb = model.predict(\n",
    "        generator=generator, predict_data=labels.index, output_layer=6\n",
    "    )  # this selects the output layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69",
   "metadata": {},
   "source": [
    "The array `emb` has dimensionality $M \\times K \\times N \\times F$ (or $M \\times K \\times 1 \\times N \\times F$ for full-batch methods), where $M$ is the number of estimators in the ensemble (`n_estimators`), $K$ is the number of predictions per query point per estimator (`n_predictions`), $N$ is the number of query points (`len(labels.index)`), and $F$ is the output dimensionality of the specified layer, determined by the shape of the readout layer as specified above. Any singleton dimensions are removed below with `np.squeeze`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10, 2708, 1, 16)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10, 2708, 16)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb = np.squeeze(emb)\n",
    "emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72",
   "metadata": {},
   "source": [
    "Next, we average the embeddings over the ensemble members (axis 0) and over the predictions per query point (axis 1).\n",
    "\n",
    "The resulting array has dimensionality $N \\times F$, where $N$ is the number of query points (equal to the number of nodes in the graph in this example) and $F$ is the dimensionality of the embeddings, which depends on the output shape of the readout layer as specified above.\n",
    "\n",
    "Note that we could have achieved the same result by specifying `summarise=True` in the call to the `predict` method above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "73",
   "metadata": {},
   "outputs": [],
   "source": [
    "emb = np.mean(emb, axis=(0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2708, 16)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75",
   "metadata": {},
   "source": [
    "Project the embeddings to 2D using either t-SNE or PCA, and visualise them, coloring nodes by their subject label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "76",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = emb\n",
    "y = np.argmax(target_encoding.transform(labels), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "77",
   "metadata": {},
   "outputs": [],
   "source": [
    "if X.shape[1] > 2:\n",
    "    transform = TSNE  # PCA\n",
    "\n",
    "    trans = transform(n_components=2)\n",
    "    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=labels.index)\n",
    "    emb_transformed[\"label\"] = y\n",
    "else:\n",
    "    emb_transformed = pd.DataFrame(X, index=labels.index)\n",
    "    emb_transformed = emb_transformed.rename(columns={\"0\": 0, \"1\": 1})\n",
    "    emb_transformed[\"label\"] = y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x504 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "alpha = 0.7\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.scatter(\n",
    "    emb_transformed[0],\n",
    "    emb_transformed[1],\n",
    "    c=emb_transformed[\"label\"].astype(\"category\"),\n",
    "    cmap=\"jet\",\n",
    "    alpha=alpha,\n",
    ")\n",
    "ax.set(aspect=\"equal\", xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
    "plt.title(\n",
    "    \"{} visualization of {} embeddings for cora dataset\".format(\n",
    "        model_type, transform.__name__\n",
    "    )\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "79",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "80",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/ensembles/ensemble-node-classification-example.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/ensembles/ensemble-node-classification-example.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}