{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Node classification with Simplified Graph Convolutions (SGC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/node-classification/sgc-node-classification.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/node-classification/sgc-node-classification.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This notebook demonstrates the use of `StellarGraph`'s GCN [[2]](#refs) class to train the simplified graph convolution (SGC) model introduced in [[1]](#refs).\n",
    "\n",
    "We show how to use `StellarGraph` to perform node attribute inference on the Cora citation network using SGC by creating a single-layer GCN model with softmax activation.\n",
    "\n",
    "SGC simplifies GCN in the following ways:\n",
    "\n",
    " - It removes the non-linearities in the graph convolutional layers.\n",
    " - It smooths the node input features using powers of the normalized adjacency matrix with self loops (see [[2]](#refs)).\n",
    " - It uses a single softmax layer such that GCN is simplified to logistic regression on smoothed node features.\n",
    "\n",
    "For a graph with $N$ nodes, $F$-dimensional node features and $C$ classes, SGC replaces GCN with the following logistic regression classifier on smoothed features:\n",
    "\n",
    "$\\hat{\\boldsymbol{Y}}_{SGC} = \\mathtt{softmax}(\\boldsymbol{S}^K \\boldsymbol{X}\\; \\boldsymbol{\\Theta})$\n",
    "\n",
    "where $\\hat{\\boldsymbol{Y}}_{SGC} \\in \\mathbb{R}^{N\\times C}$ are the class predictions; $\\boldsymbol{S}^K \\in \\mathbb{R}^{N\\times N}$ is the normalised graph adjacency matrix with self loops raised to the $K$-th power; $\\boldsymbol{X}\\in \\mathbb{R}^{N\\times F}$ are the node input features; and $\\boldsymbol{\\Theta} \\in \\mathbb{R}^{F\\times C}$ are the classifier's parameters to be learned.\n",
    "\n",
    "<a name=\"refs\"></a>\n",
    "**References**\n",
    "\n",
    "[1] Simplifying Graph Convolutional Networks. F. Wu, T. Zhang, A. H. de Souza Jr., C. Fifty, T. Yu, and K. Q. Weinberger, arXiv: 1902.07153. [link](https://arxiv.org/abs/1902.07153)\n",
    "\n",
    "[2] Semi-Supervised Classification with Graph Convolutional Networks. T. N. Kipf and M. Welling, ICLR 2017. [link](https://arxiv.org/abs/1609.02907)\n"
   ]
  },
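  {
   "cell_type": "markdown",
   "id": "2a",
   "metadata": {},
   "source": [
    "As a brief sketch of where this classifier comes from, following the argument in [[1]](#refs) (the symbols $\\boldsymbol{A}$, $\\tilde{\\boldsymbol{A}}$, $\\tilde{\\boldsymbol{D}}$, $\\boldsymbol{H}^{(k)}$ and $\\boldsymbol{\\Theta}^{(k)}$ are notation assumed here for illustration): with adjacency matrix $\\boldsymbol{A}$, self loops are added and the result is symmetrically normalised,\n",
    "\n",
    "$$\\tilde{\\boldsymbol{A}} = \\boldsymbol{A} + \\boldsymbol{I}, \\qquad \\tilde{\\boldsymbol{D}}_{ii} = \\sum_j \\tilde{\\boldsymbol{A}}_{ij}, \\qquad \\boldsymbol{S} = \\tilde{\\boldsymbol{D}}^{-1/2} \\tilde{\\boldsymbol{A}} \\tilde{\\boldsymbol{D}}^{-1/2}.$$\n",
    "\n",
    "A GCN layer computes $\\boldsymbol{H}^{(k)} = \\mathrm{ReLU}(\\boldsymbol{S} \\boldsymbol{H}^{(k-1)} \\boldsymbol{\\Theta}^{(k)})$ with $\\boldsymbol{H}^{(0)} = \\boldsymbol{X}$. Dropping the non-linearity from $K$ stacked layers leaves only matrix products, which collapse into a single weight matrix:\n",
    "\n",
    "$$\\boldsymbol{H}^{(K)} = \\boldsymbol{S} \\cdots \\boldsymbol{S} \\boldsymbol{X} \\boldsymbol{\\Theta}^{(1)} \\cdots \\boldsymbol{\\Theta}^{(K)} = \\boldsymbol{S}^K \\boldsymbol{X} \\boldsymbol{\\Theta}, \\qquad \\boldsymbol{\\Theta} = \\boldsymbol{\\Theta}^{(1)} \\cdots \\boldsymbol{\\Theta}^{(K)}.$$\n",
    "\n",
    "Applying a softmax to $\\boldsymbol{H}^{(K)}$ gives $\\hat{\\boldsymbol{Y}}_{SGC}$ as written above."
   ]
  },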
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "import numpy as np\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import FullBatchNodeGenerator\n",
    "from stellargraph.layer import GCN\n",
    "\n",
    "from tensorflow.keras import layers, optimizers, losses, metrics, Model, regularizers\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading the CORA network\n",
    "\n",
    "Create the graph using the `Cora` loader from `datasets`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "G, node_subjects = dataset.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2708, Edges: 5429\n",
      "\n",
      " Node types:\n",
      "  paper: [2708]\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5429]\n"
     ]
    }
   ],
   "source": [
    "print(G.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "We aim to train a graph-ML model that will predict the \"subject\" attribute on the nodes. Each subject is one of 7 categories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Case_Based',\n",
       " 'Genetic_Algorithms',\n",
       " 'Neural_Networks',\n",
       " 'Probabilistic_Methods',\n",
       " 'Reinforcement_Learning',\n",
       " 'Rule_Learning',\n",
       " 'Theory'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set(node_subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "### Splitting the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "For machine learning we want to take a subset of the nodes for training, and use the rest for validation and testing. We'll use scikit-learn to do this.\n",
    "\n",
    "Here we're taking 140 node labels for training, 500 for validation, and the rest for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_subjects, test_subjects = model_selection.train_test_split(\n",
    "    node_subjects, train_size=140, test_size=None, stratify=node_subjects\n",
    ")\n",
    "val_subjects, test_subjects = model_selection.train_test_split(\n",
    "    test_subjects, train_size=500, test_size=None, stratify=test_subjects\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Note that stratified sampling gives the following class counts in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Theory': 18,\n",
       "         'Neural_Networks': 42,\n",
       "         'Genetic_Algorithms': 22,\n",
       "         'Probabilistic_Methods': 22,\n",
       "         'Case_Based': 16,\n",
       "         'Rule_Learning': 9,\n",
       "         'Reinforcement_Learning': 11})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "Counter(train_subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "The training set has a class imbalance that might need to be compensated for, e.g., by using a weighted cross-entropy loss in model training, with class weights inversely proportional to class support. However, for simplicity we will ignore the class imbalance in this example."
   ]
  },
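  {
   "cell_type": "markdown",
   "id": "17a",
   "metadata": {},
   "source": [
    "For reference, here is a minimal sketch of how such class weights could be computed from the counts above, using the common \"balanced\" heuristic $w_c = n / (C \\cdot n_c)$. The names `train_counts` and `class_weights` are illustrative and are not used elsewhere in this notebook; how the weights would be passed to training is not shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Training-set class counts copied from the Counter output above.\n",
    "train_counts = Counter(\n",
    "    {\n",
    "        \"Theory\": 18,\n",
    "        \"Neural_Networks\": 42,\n",
    "        \"Genetic_Algorithms\": 22,\n",
    "        \"Probabilistic_Methods\": 22,\n",
    "        \"Case_Based\": 16,\n",
    "        \"Rule_Learning\": 9,\n",
    "        \"Reinforcement_Learning\": 11,\n",
    "    }\n",
    ")\n",
    "\n",
    "n_samples = sum(train_counts.values())  # number of labelled training nodes\n",
    "n_classes = len(train_counts)  # number of subjects\n",
    "\n",
    "# \"Balanced\" heuristic: weight_c = n_samples / (n_classes * count_c),\n",
    "# so rarer classes receive proportionally larger weights.\n",
    "class_weights = {\n",
    "    label: n_samples / (n_classes * count) for label, count in train_counts.items()\n",
    "}\n",
    "\n",
    "for label, weight in sorted(class_weights.items()):\n",
    "    print(f\"{label}: {weight:.3f}\")"
   ]
  },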
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "### Converting to numeric arrays"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "For our categorical target, we will use one-hot vectors that are compared against the output of the softmax layer during training. To do this conversion we use scikit-learn's `LabelBinarizer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_encoding = preprocessing.LabelBinarizer()\n",
    "\n",
    "train_targets = target_encoding.fit_transform(train_subjects)\n",
    "val_targets = target_encoding.transform(val_subjects)\n",
    "test_targets = target_encoding.transform(test_subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "The node attributes used to predict the subject need no such conversion here: they are already numeric feature vectors stored in the graph `G`, and the Keras model will use them as input. The CORA dataset contains attributes 'w_x' that correspond to words found in each publication. If a word occurs at least once in a publication the relevant attribute is set to one, otherwise it is zero."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "## Prepare node generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "To feed data from the graph to the Keras model we need a generator. Since SGC is a full-batch model, we use the `FullBatchNodeGenerator` class to feed the node features and the graph adjacency matrix to the model.\n",
    "\n",
    "For SGC, we need to tell the generator to smooth the node features by some power of the normalised adjacency matrix with self loops before multiplying by the model parameters.\n",
    "\n",
    "We achieve this by specifying `method='sgc'` and `k=2`, which in this example selects the SGC method with the square of the normalised adjacency matrix. With `k=2` each node's smoothed features aggregate its 2-hop neighbourhood, the same receptive field as a 2-layer GCN. We can set `k` larger to consider larger node neighbourhoods, but this carries an associated computational penalty."
   ]
  },
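  {
   "cell_type": "markdown",
   "id": "23a",
   "metadata": {},
   "source": [
    "To make the smoothing step concrete, here is a small self-contained NumPy sketch on a made-up 4-node path graph. It illustrates the idea only and is not the code that `FullBatchNodeGenerator` runs internally; the names `A`, `A_tilde`, `S` and `smoothed` are assumptions for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy undirected path graph 0 - 1 - 2 - 3 (made-up example data).\n",
    "A = np.array(\n",
    "    [\n",
    "        [0, 1, 0, 0],\n",
    "        [1, 0, 1, 0],\n",
    "        [0, 1, 0, 1],\n",
    "        [0, 0, 1, 0],\n",
    "    ],\n",
    "    dtype=float,\n",
    ")\n",
    "X = np.eye(4)  # one-hot features: row i of S^k X shows which nodes contribute to node i\n",
    "\n",
    "A_tilde = A + np.eye(4)  # add self loops\n",
    "d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))\n",
    "S = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]  # D^-1/2 (A + I) D^-1/2\n",
    "\n",
    "k = 2\n",
    "smoothed = np.linalg.matrix_power(S, k) @ X\n",
    "\n",
    "# Node 0 receives contributions from nodes 0, 1 and 2 (within 2 hops) but not node 3.\n",
    "print(np.round(smoothed[0], 3))"
   ]
  },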
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating 2-th power of normalized A...\n"
     ]
    }
   ],
   "source": [
    "generator = FullBatchNodeGenerator(G, method=\"sgc\", k=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "For training we map only the training nodes returned from our splitter and the target values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen = generator.flow(train_subjects.index, train_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "## Creating the SGC model in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "Now we can specify our machine learning model. We need a few more parameters for this:\n",
    "\n",
    " * `layer_sizes` is a list of hidden feature sizes, one per layer in the model. For SGC, we use a single layer with output dimensionality equal to the number of classes.\n",
    " * `activations` is a list of activation functions, one per layer. For SGC the only layer is the classification layer, so for multi-class classification it should be a `softmax` activation.\n",
    " * Arguments such as `bias` and `dropout` are internal parameters of the model; execute `?GCN` for details.\n",
    " \n",
    "**Note:** The SGC model is a single-layer GCN model with `softmax` activation, combined with the full-batch generator we created above, which smooths the node features based on the graph structure. So, our SGC model is declared as a `stellargraph.layer.GCN` model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "sgc = GCN(\n",
    "    layer_sizes=[train_targets.shape[1]],\n",
    "    generator=generator,\n",
    "    bias=True,\n",
    "    dropout=0.5,\n",
    "    activations=[\"softmax\"],\n",
    "    kernel_regularizer=regularizers.l2(5e-4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose the input and output tensors of the SGC model for node prediction,\n",
    "# via GCN.in_out_tensors() method:\n",
    "x_inp, predictions = sgc.in_out_tensors()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "### Training the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "Now let's create the actual Keras model, with `x_inp` as the input tensors and `predictions`, the output of the softmax layer, as the output tensors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "33",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(inputs=x_inp, outputs=predictions)\n",
    "model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=0.2),  # `lr` is a deprecated alias\n",
    "    loss=losses.categorical_crossentropy,\n",
    "    metrics=[\"acc\"],\n",
    ")"
   ]
  },
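  {
   "cell_type": "markdown",
   "id": "33a",
   "metadata": {},
   "source": [
    "Put together, the compiled model minimises, roughly, a cross-entropy over the training nodes plus the L2 penalty set through `kernel_regularizer`. As a hedged sketch that ignores dropout and the exact way Keras averages over nodes, with $\\mathcal{V}_{train}$ the set of training nodes, $\\boldsymbol{Y}$ the one-hot targets and $\\lambda = 5\\times 10^{-4}$ (the symbols $\\mathcal{V}_{train}$ and $\\lambda$ are notation assumed here):\n",
    "\n",
    "$$\\mathcal{L} = -\\frac{1}{|\\mathcal{V}_{train}|} \\sum_{i \\in \\mathcal{V}_{train}} \\sum_{c=1}^{C} \\boldsymbol{Y}_{ic} \\log \\hat{\\boldsymbol{Y}}_{ic} + \\lambda \\sum_{f=1}^{F} \\sum_{c=1}^{C} \\boldsymbol{\\Theta}_{fc}^2$$"
   ]
  },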
  {
   "cell_type": "markdown",
   "id": "34",
   "metadata": {},
   "source": [
    "We will train the model while keeping track of its loss and accuracy on the training set, and of its generalisation performance on the validation set. For the latter we need to create another generator over the validation data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "35",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_gen = generator.flow(val_subjects.index, val_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36",
   "metadata": {},
   "source": [
    "Create callbacks for early stopping (if validation accuracy stops improving) and best model checkpoint saving:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "37",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "\n",
    "os.makedirs(\"logs\", exist_ok=True)\n",
    "\n",
    "# patience is the number of epochs without improvement to wait before stopping early;\n",
    "# with patience=50 and the 50 training epochs used below, early stopping cannot trigger\n",
    "es_callback = EarlyStopping(monitor=\"val_acc\", patience=50)\n",
    "mc_callback = ModelCheckpoint(\n",
    "    \"logs/best_model.h5\", monitor=\"val_acc\", save_best_only=True, save_weights_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": []
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_gen,\n",
    "    epochs=50,\n",
    "    validation_data=val_gen,\n",
    "    verbose=0,\n",
    "    shuffle=False,  # this should be False, since shuffling data means shuffling the whole graph\n",
    "    callbacks=[es_callback, mc_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Plot the training history:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "Reload the saved weights of the best model found during training (according to validation accuracy):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "43",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_weights(\"logs/best_model.h5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44",
   "metadata": {},
   "source": [
    "Evaluate the best model on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "45",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_gen = generator.flow(test_subjects.index, test_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Metrics:\n",
      "\tloss: 0.8700\n",
      "\tacc: 0.7863\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_gen)\n",
    "print(\"\\nTest Set Metrics:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47",
   "metadata": {},
   "source": [
    "### Making predictions with the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48",
   "metadata": {},
   "source": [
    "Now let's get the predictions for all nodes:\n",
    "\n",
    "Note that the `predict` function now operates differently to the `GraphSAGE` or `HinSAGE` models\n",
    "in that if you give it a set of nodes, it will still return predictions for **all** nodes in the graph, and in a fixed order defined by the order of nodes in `X` and `A` (which is defined by the order of `G.nodes()`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "49",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes = node_subjects.index\n",
    "all_gen = generator.flow(all_nodes)\n",
    "all_predictions = model.predict(all_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50",
   "metadata": {},
   "source": [
    "Note that for full-batch methods the batch size is 1 and the predictions have shape $(1, N_{nodes}, N_{classes})$ so we we remove the batch dimension to obtain predictions of shape $(N_{nodes}, N_{classes})$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "51",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_predictions = all_predictions.squeeze()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52",
   "metadata": {},
   "source": [
    "These predictions are the output of the softmax layer, so to get the final categories we use the `inverse_transform` method of our target encoding to turn these values back into the original categories:"
   ]
  },
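  {
   "cell_type": "markdown",
   "id": "52a",
   "metadata": {},
   "source": [
    "In effect, for a multi-class target `inverse_transform` picks the class with the highest predicted probability in each row. Here is a minimal NumPy sketch of that idea, using made-up class names and probabilities (`classes`, `probs` and `decoded` are illustrative names, not part of this notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up class names (in sorted order) and softmax outputs for three nodes.\n",
    "classes = np.array([\"Case_Based\", \"Neural_Networks\", \"Theory\"])\n",
    "probs = np.array(\n",
    "    [\n",
    "        [0.7, 0.2, 0.1],\n",
    "        [0.1, 0.3, 0.6],\n",
    "        [0.2, 0.5, 0.3],\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Take the most probable class index per row and map it back to its name.\n",
    "decoded = classes[probs.argmax(axis=1)]\n",
    "print(decoded)"
   ]
  },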
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "53",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_predictions = target_encoding.inverse_transform(all_predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54",
   "metadata": {},
   "source": [
    "Let's have a look at a few:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Predicted</th>\n",
       "      <th>True</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>31336</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1061127</td>\n",
       "      <td>Rule_Learning</td>\n",
       "      <td>Rule_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1106406</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13195</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37879</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1126012</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1107140</td>\n",
       "      <td>Theory</td>\n",
       "      <td>Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1102850</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31349</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1106418</td>\n",
       "      <td>Theory</td>\n",
       "      <td>Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1123188</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1128990</td>\n",
       "      <td>Genetic_Algorithms</td>\n",
       "      <td>Genetic_Algorithms</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>109323</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>217139</td>\n",
       "      <td>Theory</td>\n",
       "      <td>Case_Based</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31353</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32083</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1126029</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1118017</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49482</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>753265</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      Predicted                    True\n",
       "31336     Probabilistic_Methods         Neural_Networks\n",
       "1061127           Rule_Learning           Rule_Learning\n",
       "1106406  Reinforcement_Learning  Reinforcement_Learning\n",
       "13195    Reinforcement_Learning  Reinforcement_Learning\n",
       "37879     Probabilistic_Methods   Probabilistic_Methods\n",
       "1126012   Probabilistic_Methods   Probabilistic_Methods\n",
       "1107140                  Theory                  Theory\n",
       "1102850         Neural_Networks         Neural_Networks\n",
       "31349     Probabilistic_Methods         Neural_Networks\n",
       "1106418                  Theory                  Theory\n",
       "1123188         Neural_Networks         Neural_Networks\n",
       "1128990      Genetic_Algorithms      Genetic_Algorithms\n",
       "109323    Probabilistic_Methods   Probabilistic_Methods\n",
       "217139                   Theory              Case_Based\n",
       "31353     Probabilistic_Methods         Neural_Networks\n",
       "32083           Neural_Networks         Neural_Networks\n",
       "1126029  Reinforcement_Learning  Reinforcement_Learning\n",
       "1118017         Neural_Networks         Neural_Networks\n",
       "49482           Neural_Networks         Neural_Networks\n",
       "753265          Neural_Networks         Neural_Networks"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame({\"Predicted\": node_predictions, \"True\": node_subjects})\n",
    "df.head(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56",
   "metadata": {},
   "source": [
    "## Node representations\n",
    "\n",
    "Evaluate node representations as activations of the output layer and visualise them, coloring nodes by their true subject label. We expect to see distinct clusters of papers in the node representation space, with papers of the same subject belonging to the same cluster.\n",
    "\n",
    "To do so, we project the node representations to 2D using either a t-SNE or a PCA transform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "57",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = all_predictions\n",
    "y = np.argmax(target_encoding.transform(node_subjects), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "58",
   "metadata": {},
   "outputs": [],
   "source": [
    "if X.shape[1] > 2:\n",
    "    transform = TSNE  # PCA\n",
    "\n",
    "    trans = transform(n_components=2)\n",
    "    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=list(G.nodes()))\n",
    "    emb_transformed[\"label\"] = y\n",
    "else:\n",
    "    emb_transformed = pd.DataFrame(X, index=list(G.nodes()))\n",
    "    emb_transformed = emb_transformed.rename(columns={\"0\": 0, \"1\": 1})\n",
    "    emb_transformed[\"label\"] = y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x504 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "alpha = 0.7\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.scatter(\n",
    "    emb_transformed[0],\n",
    "    emb_transformed[1],\n",
    "    c=emb_transformed[\"label\"].astype(\"category\"),\n",
    "    cmap=\"jet\",\n",
    "    alpha=alpha,\n",
    ")\n",
    "ax.set(aspect=\"equal\", xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
    "plt.title(\n",
    "    \"{} visualization of SGC activations for cora dataset\".format(transform.__name__)\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/node-classification/sgc-node-classification.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/node-classification/sgc-node-classification.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}