{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Node classification with GraphSAGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/node-classification/graphsage-node-classification.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/node-classification/graphsage-node-classification.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Import stellargraph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import GraphSAGENodeGenerator\n",
    "from stellargraph.layer import GraphSAGE\n",
    "\n",
    "from tensorflow.keras import layers, optimizers, losses, metrics, Model\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading the CORA network"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "G, node_subjects = dataset.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2708, Edges: 5429\n",
      "\n",
      " Node types:\n",
      "  paper: [2708]\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5429]\n"
     ]
    }
   ],
   "source": [
    "print(G.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "We aim to train a graph-ML model that will predict the \"subject\" attribute on the nodes. These subjects are one of 7 categories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Case_Based',\n",
       " 'Genetic_Algorithms',\n",
       " 'Neural_Networks',\n",
       " 'Probabilistic_Methods',\n",
       " 'Reinforcement_Learning',\n",
       " 'Rule_Learning',\n",
       " 'Theory'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set(node_subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "### Splitting the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "For machine learning we want to take a subset of the nodes for training, and use the rest for testing. We'll use scikit-learn's `train_test_split` for this, keeping 10% of the nodes for training and stratifying by subject:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_subjects, test_subjects = model_selection.train_test_split(\n",
    "    node_subjects, train_size=0.1, test_size=None, stratify=node_subjects\n",
    ")"
   ]
  },
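  {
   "cell_type": "markdown",
   "id": "sketch-stratify",
   "metadata": {},
   "source": [
    "As an illustration of what `stratify` does, here is a small sketch on hypothetical toy labels (150 made-up nodes, 100 of class `a` and 50 of class `b`). With `train_size=0.1`, the 15 training labels keep the same 2:1 ratio; `random_state` is added only to make the toy run repeatable:\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn import model_selection\n",
    "\n",
    "# hypothetical toy labels: 100 nodes of class a, 50 nodes of class b\n",
    "toy_subjects = pd.Series([\"a\"] * 100 + [\"b\"] * 50)\n",
    "\n",
    "toy_train, toy_test = model_selection.train_test_split(\n",
    "    toy_subjects, train_size=0.1, test_size=None, stratify=toy_subjects, random_state=0\n",
    ")\n",
    "print(Counter(toy_train))  # 10 labels of class a, 5 of class b\n",
    "```"
   ]
  },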
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Note that stratified sampling keeps the class proportions of the full dataset, giving the following counts in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Probabilistic_Methods': 42,\n",
       "         'Genetic_Algorithms': 42,\n",
       "         'Reinforcement_Learning': 22,\n",
       "         'Rule_Learning': 18,\n",
       "         'Neural_Networks': 81,\n",
       "         'Case_Based': 30,\n",
       "         'Theory': 35})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "Counter(train_subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "The training set has class imbalance that might need to be compensated for, e.g., by using a weighted cross-entropy loss in model training, with class weights inversely proportional to class support. However, we will ignore the class imbalance in this example, for simplicity."
   ]
  },
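  {
   "cell_type": "markdown",
   "id": "sketch-class-weights",
   "metadata": {},
   "source": [
    "If you do want to compensate for the imbalance, one option (not used in this demo) is the `class_weight` argument of Keras's `model.fit`. The sketch below computes \"balanced\" weights, `n_samples / (n_classes * count)`, from the training counts shown above with scikit-learn's `compute_class_weight`. The dictionary keys are column indices; they line up with the one-hot columns built later because both follow the sorted class order:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "# training-set class counts, copied from the Counter(train_subjects) output above\n",
    "counts = {\n",
    "    \"Case_Based\": 30,\n",
    "    \"Genetic_Algorithms\": 42,\n",
    "    \"Neural_Networks\": 81,\n",
    "    \"Probabilistic_Methods\": 42,\n",
    "    \"Reinforcement_Learning\": 22,\n",
    "    \"Rule_Learning\": 18,\n",
    "    \"Theory\": 35,\n",
    "}\n",
    "# rebuild a flat array of 270 labels from the counts\n",
    "labels = np.repeat(list(counts), list(counts.values()))\n",
    "\n",
    "# np.unique returns the class names sorted, the same order LabelBinarizer uses\n",
    "classes = np.unique(labels)\n",
    "weights = compute_class_weight(class_weight=\"balanced\", classes=classes, y=labels)\n",
    "class_weights = dict(enumerate(weights))  # rarer classes get larger weights\n",
    "```\n",
    "\n",
    "The resulting dictionary could then be passed as `model.fit(..., class_weight=class_weights)`, assuming your TensorFlow version accepts `class_weight` together with a `Sequence` input such as `train_gen`; whether it improves accuracy on Cora is not evaluated here."
   ]
  },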
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "### Converting to numeric arrays"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "For our categorical target, we will use one-hot vectors that will be fed into a softmax Keras layer during training. To do this conversion, we use scikit-learn's `LabelBinarizer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_encoding = preprocessing.LabelBinarizer()\n",
    "\n",
    "train_targets = target_encoding.fit_transform(train_subjects)\n",
    "test_targets = target_encoding.transform(test_subjects)"
   ]
  },
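  {
   "cell_type": "markdown",
   "id": "sketch-label-binarizer",
   "metadata": {},
   "source": [
    "For illustration, here is what `LabelBinarizer` does on a few hypothetical labels: each label becomes a one-hot row, and the columns follow the sorted order stored in `classes_`:\n",
    "\n",
    "```python\n",
    "from sklearn import preprocessing\n",
    "\n",
    "toy_encoding = preprocessing.LabelBinarizer()\n",
    "toy_targets = toy_encoding.fit_transform([\"Theory\", \"Case_Based\", \"Rule_Learning\", \"Theory\"])\n",
    "\n",
    "print(toy_encoding.classes_)  # ['Case_Based' 'Rule_Learning' 'Theory']\n",
    "print(toy_targets)\n",
    "# [[0 0 1]\n",
    "#  [1 0 0]\n",
    "#  [0 1 0]\n",
    "#  [0 0 1]]\n",
    "```\n",
    "\n",
    "With only two classes, `LabelBinarizer` returns a single column instead of two, which would not match a softmax output; Cora's seven classes avoid this case."
   ]
  },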
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "The node attributes that the Keras model will use as input to predict the subject need no similar conversion: they are already stored as numeric feature vectors in the `StellarGraph` object `G` returned by `dataset.load()`. For the CORA dataset, each feature corresponds to a word from the dictionary and is set to one if the word occurs in the publication, and zero otherwise."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "## Creating the GraphSAGE model in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "To feed data from the graph to the Keras model we need a data generator. Generators are specialized to the model and the learning task, so we choose the `GraphSAGENodeGenerator`, as we are predicting node attributes with a GraphSAGE model.\n",
    "\n",
    "We need two other parameters: the `batch_size` to use for training, and the number of neighbours to sample at each level of the model. Here we choose a two-level model that samples 10 neighbours at the first level and 5 at the second (for each first-level neighbour)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "num_samples = [10, 5]"
   ]
  },
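  {
   "cell_type": "markdown",
   "id": "sketch-sampling-arithmetic",
   "metadata": {},
   "source": [
    "As a rough check on what these settings cost: each head node samples 10 neighbours at the first level and 5 for each of those at the second, so one head node needs features for 1 + 10 + 10 * 5 = 61 sampled positions (these count positions, not distinct nodes). A quick sketch of that arithmetic:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "sample_sizes = [10, 5]  # same values as num_samples above\n",
    "batch = 50  # same value as batch_size above\n",
    "\n",
    "# sampled positions per level for one head node: itself, then 10, then 10 * 5\n",
    "per_hop = np.cumprod([1] + sample_sizes)\n",
    "print(per_hop.tolist())  # [1, 10, 50]\n",
    "print(per_hop.sum())  # 61 feature vectors per head node\n",
    "print(batch * per_hop.sum())  # 3050 feature vectors per full batch\n",
    "```"
   ]
  },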
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "A `GraphSAGENodeGenerator` object is required to send the node features in sampled subgraphs to Keras:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = GraphSAGENodeGenerator(G, batch_size, num_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "Using the `generator.flow()` method, we can create iterators over nodes that should be used to train, validate, or evaluate the model. For training we use only the training nodes returned from our splitter and the target values. The `shuffle=True` argument is given to the `flow` method so that the training nodes are visited in a new random order each epoch, which usually helps training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "Now we can specify our machine learning model. We need a few more parameters for this:\n",
    "\n",
    " * `layer_sizes` is a list of hidden feature sizes of each layer in the model; its length must match the length of `num_samples`. In this example we use 32-dimensional hidden node features at each layer.\n",
    " * `bias` controls whether each layer learns a bias vector, and `dropout` sets the dropout rate applied in each layer during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "graphsage_model = GraphSAGE(\n",
    "    layer_sizes=[32, 32], generator=generator, bias=True, dropout=0.5,\n",
    ")"
   ]
  },
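  {
   "cell_type": "markdown",
   "id": "sketch-mean-aggregator",
   "metadata": {},
   "source": [
    "To make the role of `layer_sizes` concrete, here is a minimal NumPy sketch of one layer of the mean-aggregator update described in the GraphSAGE paper: average the sampled neighbours' features, concatenate the result with the node's own features, then apply a linear map and a ReLU. It is an illustration under those assumptions, not StellarGraph's `MeanAggregator` code, which may split or combine the weights differently; the paper's optional L2 normalisation is omitted, and the random inputs are stand-ins:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mean_aggregate(h_self, h_neigh, weight):\n",
    "    # h_self:  (batch, d_in) features of the head nodes\n",
    "    # h_neigh: (batch, n_sampled, d_in) features of their sampled neighbours\n",
    "    # weight:  (2 * d_in, d_out) trainable matrix\n",
    "    neigh_mean = h_neigh.mean(axis=1)  # (batch, d_in)\n",
    "    combined = np.concatenate([h_self, neigh_mean], axis=1)  # (batch, 2 * d_in)\n",
    "    return np.maximum(combined @ weight, 0.0)  # ReLU, (batch, d_out)\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "d_in, d_out = 1433, 32  # Cora feature size and the hidden size used above\n",
    "h_self = rng.random((4, d_in))\n",
    "h_neigh = rng.random((4, 10, d_in))  # 10 sampled neighbours, as at the first level\n",
    "weight = rng.normal(scale=0.01, size=(2 * d_in, d_out))\n",
    "\n",
    "print(mean_aggregate(h_self, h_neigh, weight).shape)  # (4, 32)\n",
    "```"
   ]
  },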
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "Now we create a model to predict the 7 categories using a Keras dense layer with softmax activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_inp, x_out = graphsage_model.in_out_tensors()\n",
    "prediction = layers.Dense(units=train_targets.shape[1], activation=\"softmax\")(x_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "### Training the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34",
   "metadata": {},
   "source": [
    "Now let's create the actual Keras model with the graph inputs `x_inp` provided by the `graphsage_model` and outputs being the predictions from the softmax layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "35",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(inputs=x_inp, outputs=prediction)\n",
    "model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=0.005),\n",
    "    loss=losses.categorical_crossentropy,\n",
    "    metrics=[\"acc\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36",
   "metadata": {},
   "source": [
    "Train the model, keeping track of its loss and accuracy on the training set, and its generalisation performance on the test set (we need to create another generator over the test data for this):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "37",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_gen = generator.flow(test_subjects.index, test_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "6/6 - 2s - loss: 1.8488 - acc: 0.3037 - val_loss: 1.6904 - val_acc: 0.3794\n",
      "Epoch 2/20\n",
      "6/6 - 2s - loss: 1.6272 - acc: 0.4852 - val_loss: 1.5230 - val_acc: 0.5349\n",
      "Epoch 3/20\n",
      "6/6 - 2s - loss: 1.4474 - acc: 0.6333 - val_loss: 1.3641 - val_acc: 0.6829\n",
      "Epoch 4/20\n",
      "6/6 - 2s - loss: 1.2771 - acc: 0.7630 - val_loss: 1.2483 - val_acc: 0.7186\n",
      "Epoch 5/20\n",
      "6/6 - 2s - loss: 1.1698 - acc: 0.8444 - val_loss: 1.1501 - val_acc: 0.7498\n",
      "Epoch 6/20\n",
      "6/6 - 2s - loss: 1.0364 - acc: 0.9000 - val_loss: 1.0619 - val_acc: 0.7756\n",
      "Epoch 7/20\n",
      "6/6 - 2s - loss: 0.9260 - acc: 0.8963 - val_loss: 0.9960 - val_acc: 0.7896\n",
      "Epoch 8/20\n",
      "6/6 - 2s - loss: 0.8232 - acc: 0.9000 - val_loss: 0.9372 - val_acc: 0.7986\n",
      "Epoch 9/20\n",
      "6/6 - 2s - loss: 0.7396 - acc: 0.9481 - val_loss: 0.8897 - val_acc: 0.8056\n",
      "Epoch 10/20\n",
      "6/6 - 2s - loss: 0.6708 - acc: 0.9630 - val_loss: 0.8496 - val_acc: 0.8056\n",
      "Epoch 11/20\n",
      "6/6 - 2s - loss: 0.5816 - acc: 0.9667 - val_loss: 0.8084 - val_acc: 0.8162\n",
      "Epoch 12/20\n",
      "6/6 - 2s - loss: 0.5232 - acc: 0.9852 - val_loss: 0.7748 - val_acc: 0.8175\n",
      "Epoch 13/20\n",
      "6/6 - 2s - loss: 0.4801 - acc: 0.9778 - val_loss: 0.7515 - val_acc: 0.8154\n",
      "Epoch 14/20\n",
      "6/6 - 2s - loss: 0.4383 - acc: 0.9852 - val_loss: 0.7452 - val_acc: 0.8097\n",
      "Epoch 15/20\n",
      "6/6 - 2s - loss: 0.4116 - acc: 0.9778 - val_loss: 0.7161 - val_acc: 0.8187\n",
      "Epoch 16/20\n",
      "6/6 - 2s - loss: 0.3584 - acc: 0.9889 - val_loss: 0.7039 - val_acc: 0.8187\n",
      "Epoch 17/20\n",
      "6/6 - 2s - loss: 0.3559 - acc: 0.9815 - val_loss: 0.6767 - val_acc: 0.8240\n",
      "Epoch 18/20\n",
      "6/6 - 2s - loss: 0.3104 - acc: 0.9889 - val_loss: 0.6849 - val_acc: 0.8146\n",
      "Epoch 19/20\n",
      "6/6 - 2s - loss: 0.2925 - acc: 0.9815 - val_loss: 0.6698 - val_acc: 0.8162\n",
      "Epoch 20/20\n",
      "6/6 - 2s - loss: 0.2690 - acc: 0.9852 - val_loss: 0.6559 - val_acc: 0.8199\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_gen, epochs=20, validation_data=test_gen, verbose=2, shuffle=False\n",
    ")"
   ]
  },
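  {
   "cell_type": "markdown",
   "id": "sketch-steps-per-epoch",
   "metadata": {},
   "source": [
    "The `6/6` in each log line is the number of batches per epoch: the 270 training nodes (the sum of the class counts shown earlier) split into batches of `batch_size = 50`, rounding up so that the final, smaller batch is included:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "n_train = 42 + 42 + 22 + 18 + 81 + 30 + 35  # class counts from Counter(train_subjects)\n",
    "print(n_train)  # 270\n",
    "print(math.ceil(n_train / 50))  # 6 batches per epoch\n",
    "```"
   ]
  },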
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Now that we have trained the model, we can evaluate it on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Metrics:\n",
      "\tloss: 0.6601\n",
      "\tacc: 0.8228\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_gen)\n",
    "print(\"\\nTest Set Metrics:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "### Making predictions with the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43",
   "metadata": {},
   "source": [
    "Now let's get the predictions themselves for all nodes using another node iterator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "44",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes = node_subjects.index\n",
    "all_mapper = generator.flow(all_nodes)\n",
    "all_predictions = model.predict(all_mapper)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45",
   "metadata": {},
   "source": [
    "These predictions will be the output of the softmax layer, so to get final categories we'll use the `inverse_transform` method of our target encoding (`target_encoding`) to turn these values back into the original categories, picking the highest-scoring category for each node:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "46",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_predictions = target_encoding.inverse_transform(all_predictions)"
   ]
  },
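  {
   "cell_type": "markdown",
   "id": "sketch-inverse-transform",
   "metadata": {},
   "source": [
    "For multiclass targets, `inverse_transform` picks the class with the highest score in each row, so it is equivalent to taking the `argmax` of the softmax output and indexing `classes_`. A small sketch with made-up probabilities:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn import preprocessing\n",
    "\n",
    "toy_encoding = preprocessing.LabelBinarizer().fit([\"Case_Based\", \"Rule_Learning\", \"Theory\"])\n",
    "\n",
    "# made-up softmax outputs for two nodes\n",
    "probs = np.array([[0.2, 0.7, 0.1], [0.5, 0.1, 0.4]])\n",
    "\n",
    "print(toy_encoding.inverse_transform(probs))  # ['Rule_Learning' 'Case_Based']\n",
    "print(toy_encoding.classes_[probs.argmax(axis=1)])  # the same labels\n",
    "```"
   ]
  },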
  {
   "cell_type": "markdown",
   "id": "47",
   "metadata": {},
   "source": [
    "Let's have a look at a few:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Predicted</th>\n",
       "      <th>True</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>31336</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1061127</td>\n",
       "      <td>Rule_Learning</td>\n",
       "      <td>Rule_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1106406</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13195</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37879</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1126012</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Probabilistic_Methods</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1107140</td>\n",
       "      <td>Reinforcement_Learning</td>\n",
       "      <td>Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1102850</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31349</td>\n",
       "      <td>Neural_Networks</td>\n",
       "      <td>Neural_Networks</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1106418</td>\n",
       "      <td>Theory</td>\n",
       "      <td>Theory</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      Predicted                    True\n",
       "31336           Neural_Networks         Neural_Networks\n",
       "1061127           Rule_Learning           Rule_Learning\n",
       "1106406  Reinforcement_Learning  Reinforcement_Learning\n",
       "13195    Reinforcement_Learning  Reinforcement_Learning\n",
       "37879     Probabilistic_Methods   Probabilistic_Methods\n",
       "1126012  Reinforcement_Learning   Probabilistic_Methods\n",
       "1107140  Reinforcement_Learning                  Theory\n",
       "1102850         Neural_Networks         Neural_Networks\n",
       "31349           Neural_Networks         Neural_Networks\n",
       "1106418                  Theory                  Theory"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame({\"Predicted\": node_predictions, \"True\": node_subjects})\n",
    "df.head(10)"
   ]
  },
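  {
   "cell_type": "markdown",
   "id": "sketch-confusion",
   "metadata": {},
   "source": [
    "Beyond reading individual rows, `pd.crosstab` builds a confusion table of true versus predicted subjects. The sketch below runs it on just the ten rows displayed above, copied by hand; on the full `df` you would likely first restrict it to the test nodes, e.g. `df.loc[test_subjects.index]`, since `df` also contains the training nodes:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# the ten (Predicted, True) pairs shown by df.head(10) above, indexed by node ID\n",
    "rows = pd.DataFrame(\n",
    "    {\n",
    "        \"Predicted\": [\n",
    "            \"Neural_Networks\", \"Rule_Learning\", \"Reinforcement_Learning\",\n",
    "            \"Reinforcement_Learning\", \"Probabilistic_Methods\", \"Reinforcement_Learning\",\n",
    "            \"Reinforcement_Learning\", \"Neural_Networks\", \"Neural_Networks\", \"Theory\",\n",
    "        ],\n",
    "        \"True\": [\n",
    "            \"Neural_Networks\", \"Rule_Learning\", \"Reinforcement_Learning\",\n",
    "            \"Reinforcement_Learning\", \"Probabilistic_Methods\", \"Probabilistic_Methods\",\n",
    "            \"Theory\", \"Neural_Networks\", \"Neural_Networks\", \"Theory\",\n",
    "        ],\n",
    "    },\n",
    "    index=[31336, 1061127, 1106406, 13195, 37879, 1126012, 1107140, 1102850, 31349, 1106418],\n",
    ")\n",
    "\n",
    "# table rows are true subjects, columns are predicted subjects\n",
    "confusion = pd.crosstab(rows[\"True\"], rows[\"Predicted\"])\n",
    "accuracy = (rows[\"Predicted\"] == rows[\"True\"]).mean()\n",
    "print(confusion)\n",
    "print(accuracy)  # 0.8, i.e. 8 of these 10 rows are correct\n",
    "```"
   ]
  },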
  {
   "cell_type": "markdown",
   "id": "49",
   "metadata": {},
   "source": [
    "Create a NetworkX graph to save as GraphML, e.g. for visualisation in [Gephi](https://gephi.org). Before saving, we also add the true and predicted subjects to each node as attributes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "50",
   "metadata": {},
   "outputs": [],
   "source": [
    "Gnx = G.to_networkx(feature_attr=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "51",
   "metadata": {},
   "outputs": [],
   "source": [
    "for nid, pred, true in zip(df.index, df[\"Predicted\"], df[\"True\"]):\n",
    "    Gnx.nodes[nid][\"subject\"] = true\n",
    "    Gnx.nodes[nid][\"PREDICTED_subject\"] = pred.split(\"=\")[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52",
   "metadata": {},
   "source": [
    "Also add `isTrain` and `isCorrect` node attributes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "53",
   "metadata": {},
   "outputs": [],
   "source": [
    "for nid in train_subjects.index:\n",
    "    Gnx.nodes[nid][\"isTrain\"] = True\n",
    "\n",
    "for nid in test_subjects.index:\n",
    "    Gnx.nodes[nid][\"isTrain\"] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "54",
   "metadata": {},
   "outputs": [],
   "source": [
    "for nid in Gnx.nodes():\n",
    "    Gnx.nodes[nid][\"isCorrect\"] = (\n",
    "        Gnx.nodes[nid][\"subject\"] == Gnx.nodes[nid][\"PREDICTED_subject\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55",
   "metadata": {},
   "source": [
    "Save in GraphML format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "56",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_fname = \"pred_n={}.graphml\".format(num_samples)\n",
    "nx.write_graphml(Gnx, os.path.join(dataset.data_directory, pred_fname))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57",
   "metadata": {},
   "source": [
    "## Node embeddings\n",
    "We evaluate node embeddings as the output activations of the GraphSAGE layer stack, and visualise them, coloring nodes by their subject label.\n",
    "\n",
    "The GraphSAGE embeddings are the output of the GraphSAGE layers, namely the `x_out` variable. Let's create a new model with the same inputs as before (`x_inp`), but with the embeddings as output rather than the predicted class. Because the new model is built from the same layer objects, it reuses the weights trained above, so no retraining is needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "58",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_model = Model(inputs=x_inp, outputs=x_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2708, 32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb = embedding_model.predict(all_mapper)\n",
    "emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60",
   "metadata": {},
   "source": [
    "Project the embeddings to 2D using either t-SNE or PCA, and visualise them, coloring nodes by their subject label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "61",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "62",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = emb\n",
    "y = np.argmax(target_encoding.transform(node_subjects), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "63",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = TSNE  # or PCA; defined here so the plot title below always works\n",
    "\n",
    "if X.shape[1] > 2:\n",
    "    # reduce the embeddings to 2 dimensions for plotting\n",
    "    trans = transform(n_components=2)\n",
    "    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=node_subjects.index)\n",
    "else:\n",
    "    # the embeddings are already 2-dimensional; the columns are named 0 and 1\n",
    "    emb_transformed = pd.DataFrame(X, index=node_subjects.index)\n",
    "\n",
    "emb_transformed[\"label\"] = y"
   ]
  },
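  {
   "cell_type": "markdown",
   "id": "sketch-pca",
   "metadata": {},
   "source": [
    "The 2D projection step can be tried in isolation. The sketch below uses PCA, which is much faster than t-SNE, on random stand-in data with the same shape as `emb`, (2708, 32); the random data is an assumption for illustration only, so the explained variance it prints says nothing about the real embeddings:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "fake_emb = rng.normal(size=(2708, 32))  # stand-in for the real embeddings\n",
    "\n",
    "pca = PCA(n_components=2)\n",
    "projected = pd.DataFrame(pca.fit_transform(fake_emb))  # columns named 0 and 1\n",
    "print(projected.shape)  # (2708, 2)\n",
    "print(pca.explained_variance_ratio_)  # share of variance kept by each component\n",
    "```"
   ]
  },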
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x504 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "alpha = 0.7\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.scatter(\n",
    "    emb_transformed[0],\n",
    "    emb_transformed[1],\n",
    "    c=emb_transformed[\"label\"].astype(\"category\"),\n",
    "    cmap=\"jet\",\n",
    "    alpha=alpha,\n",
    ")\n",
    "ax.set(aspect=\"equal\", xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
    "plt.title(\n",
    "    \"{} visualization of GraphSAGE embeddings for cora dataset\".format(transform.__name__)\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/node-classification/graphsage-node-classification.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/node-classification/graphsage-node-classification.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}