{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Link prediction with GCN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/gcn-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/gcn-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "In this example, we use our implementation of the [GCN](https://arxiv.org/abs/1609.02907) algorithm to build a model that predicts citation links in the Cora dataset (see below). The problem is treated as a supervised link prediction problem on a homogeneous citation network with nodes representing papers (with attributes such as binary keyword indicators and categorical subject) and links corresponding to paper-paper citations. \n",
    "\n",
    "To address this problem, we build a model with the following architecture. First we build a two-layer GCN model that takes labeled node pairs (`citing-paper` -> `cited-paper`)  corresponding to possible citation links, and outputs a pair of node embeddings for the `citing-paper` and `cited-paper` nodes of the pair. These embeddings are then fed into a link classification layer, which first applies a binary operator to those node embeddings (e.g., concatenating them) to construct the embedding of the potential link. Thus obtained link embeddings are passed through the dense link classification layer to obtain link predictions - probability for these candidate links to actually exist in the network. The entire model is trained end-to-end by minimizing the loss function of choice (e.g., binary cross-entropy between predicted link probabilities and true link labels, with true/false citation links having labels 1/0) using stochastic gradient descent (SGD) updates of the model parameters, with minibatches of 'training' links fed into the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import stellargraph as sg\n",
    "from stellargraph.data import EdgeSplitter\n",
    "from stellargraph.mapper import FullBatchLinkGenerator\n",
    "from stellargraph.layer import GCN, LinkEmbedding\n",
    "\n",
    "\n",
    "from tensorflow import keras\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "\n",
    "from stellargraph import globalvar\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading the CORA network data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "G, _ = dataset.load(subject_as_feature=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2708, Edges: 5429\n",
      "\n",
      " Node types:\n",
      "  paper: [2708]\n",
      "    Features: float32 vector, length 1440\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5429]\n",
      "        Weights: all 1 (default)\n",
      "        Features: none\n"
     ]
    }
   ],
   "source": [
    "print(G.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "We aim to train a link prediction model, hence we need to prepare the train and test sets of links and the corresponding graphs with those links removed.\n",
    "\n",
    "We are going to split our input graph into a train and test graphs using the EdgeSplitter class in `stellargraph.data`. We will use the train graph for training the model (a binary classifier that, given two nodes, predicts whether a link between these two nodes should exist or not) and the test graph for evaluating the model's performance on hold out data.\n",
    "Each of these graphs will have the same number of nodes as the input graph, but the number of links will differ (be reduced) as some of the links will be removed during each split and used as the positive samples for training/testing the link prediction classifier."
   ]
  },
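  {
   "cell_type": "markdown",
   "id": "10a",
   "metadata": {},
   "source": [
    "The idea behind such a split can be sketched on a toy graph. The cell below is only a minimal illustration in plain Python, not how `EdgeSplitter` is implemented: the graph, the variable names and the sample size are all hypothetical, and options such as `keep_connected` are not reproduced. It holds out a few existing edges as positive examples, removes them from the graph, and samples the same number of unlinked node pairs as negative examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Toy undirected graph (hypothetical data, not Cora); each edge is stored as (u, v) with u < v.\n",
    "nodes = list(range(6))\n",
    "edges = {(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5), (1, 5), (0, 4)}\n",
    "\n",
    "rng = random.Random(0)\n",
    "n_samples = 2\n",
    "\n",
    "# Positive examples: existing edges that are held out and removed from the graph.\n",
    "positive = rng.sample(sorted(edges), n_samples)\n",
    "reduced_edges = edges - set(positive)\n",
    "\n",
    "# Negative examples: node pairs that are not linked in the original graph.\n",
    "negative = []\n",
    "while len(negative) < n_samples:\n",
    "    u, v = sorted(rng.sample(nodes, 2))\n",
    "    if (u, v) not in edges and (u, v) not in negative:\n",
    "        negative.append((u, v))\n",
    "\n",
    "# Labels: 1 for held-out true links, 0 for sampled non-links.\n",
    "edge_ids = positive + negative\n",
    "edge_labels = [1] * n_samples + [0] * n_samples\n",
    "print(len(reduced_edges), edge_ids, edge_labels)"
   ]
  },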
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "From the original graph G, extract a randomly sampled subset of test edges (true and false citation links) and the reduced graph G_test with the positive test edges removed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 542 positive and 542 negative edges. **\n"
     ]
    }
   ],
   "source": [
    "# Define an edge splitter on the original graph G:\n",
    "edge_splitter_test = EdgeSplitter(G)\n",
    "\n",
    "# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G, and obtain the\n",
    "# reduced graph G_test with the sampled links removed:\n",
    "G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(\n",
    "    p=0.1, method=\"global\", keep_connected=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "The reduced graph G_test, together with the test ground truth set of links (edge_ids_test, edge_labels_test), will be used for testing the model.\n",
    "\n",
    "Now repeat this procedure to obtain the training data for the model. From the reduced graph G_test, extract a randomly sampled subset of train edges (true and false citation links) and the reduced graph G_train with the positive train edges removed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 488 positive and 488 negative edges. **\n"
     ]
    }
   ],
   "source": [
    "# Define an edge splitter on the reduced graph G_test:\n",
    "edge_splitter_train = EdgeSplitter(G_test)\n",
    "\n",
    "# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G_test, and obtain the\n",
    "# reduced graph G_train with the sampled links removed:\n",
    "G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(\n",
    "    p=0.1, method=\"global\", keep_connected=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "G_train, together with the train ground truth set of links (edge_ids_train, edge_labels_train), will be used for training the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "## Creating the GCN link model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "Next, we create the link generators for the train and test link examples to the model. The link generators take the pairs of nodes (`citing-paper`, `cited-paper`) that are given in the `.flow` method to the Keras model, together with the corresponding binary labels indicating whether those pairs represent true or false links.\n",
    "\n",
    "The number of epochs for training the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "For training we create a generator on the `G_train` graph, and make an iterator over the training links using the generator's `flow()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GCN (local pooling) filters...\n"
     ]
    }
   ],
   "source": [
    "train_gen = FullBatchLinkGenerator(G_train, method=\"gcn\")\n",
    "train_flow = train_gen.flow(edge_ids_train, edge_labels_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GCN (local pooling) filters...\n"
     ]
    }
   ],
   "source": [
    "test_gen = FullBatchLinkGenerator(G_test, method=\"gcn\")\n",
    "test_flow = test_gen.flow(edge_ids_test, edge_labels_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "Now we can specify our machine learning model, we need a few more parameters for this:\n",
    "\n",
    " * the `layer_sizes` is a list of hidden feature sizes of each layer in the model. In this example we use two GCN layers with 16-dimensional hidden node features at each layer.\n",
    " * `activations` is a list of activations applied to each layer's output\n",
    " * `dropout=0.3` specifies a 30% dropout at each layer. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "We create a GCN model as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "gcn = GCN(\n",
    "    layer_sizes=[16, 16], activations=[\"relu\", \"relu\"], generator=train_gen, dropout=0.3\n",
    ")"
   ]
  },
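  {
   "cell_type": "markdown",
   "id": "24a",
   "metadata": {},
   "source": [
    "To give a feel for what each GCN layer computes, the cell below sketches a single layer in NumPy, following the propagation rule from the GCN paper linked at the top of this notebook: node features are aggregated over each node's neighborhood using a symmetrically normalized adjacency matrix with self-loops, projected by a weight matrix, and passed through ReLU. The toy graph, the feature values and the random weights are assumptions made purely for illustration; the `GCN` class above takes care of this internally, along with details such as dropout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy graph with 4 nodes and 3 features per node (hypothetical values, not Cora).\n",
    "A = np.array(\n",
    "    [[0, 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 0]], dtype=float\n",
    ")\n",
    "X = np.arange(12, dtype=float).reshape(4, 3)\n",
    "\n",
    "# Add self-loops and normalize symmetrically: D^(-1/2) (A + I) D^(-1/2).\n",
    "A_tilde = A + np.eye(4)\n",
    "d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))\n",
    "A_hat = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]\n",
    "\n",
    "# One GCN layer: aggregate neighbor features, project to 16 dimensions, apply ReLU.\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.normal(size=(3, 16))\n",
    "H = np.maximum(A_hat @ X @ W, 0.0)\n",
    "print(H.shape)"
   ]
  },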
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "To create a Keras model we now expose the input and output tensors of the GCN model for link prediction, via the `GCN.in_out_tensors` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_inp, x_out = gcn.in_out_tensors()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "Final link classification layer that takes a pair of node embeddings produced by the GCN model, applies a binary operator to them to produce the corresponding link embedding (`ip` for inner product; other options for the binary operator can be seen by running a cell with `?LinkEmbedding` in it), and passes it through a dense layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction = LinkEmbedding(activation=\"relu\", method=\"ip\")(x_out)"
   ]
  },
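  {
   "cell_type": "markdown",
   "id": "28a",
   "metadata": {},
   "source": [
    "As a rough illustration of the `ip` operator, the NumPy sketch below computes the inner product of each pair of node embeddings and then applies ReLU. The embedding values are made up, and the computation is an assumption about how `LinkEmbedding` behaves based on its description above, not a copy of its implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical 4-dimensional embeddings for two candidate links (not real GCN output).\n",
    "src = np.array([[1.0, 0.0, 2.0, 1.0], [0.5, 1.0, 0.0, 0.0]])\n",
    "dst = np.array([[1.0, 1.0, 0.5, 0.0], [0.0, 0.0, 3.0, 1.0]])\n",
    "\n",
    "# The inner product of each (source, destination) pair gives one score per link.\n",
    "scores = np.sum(src * dst, axis=-1, keepdims=True)\n",
    "\n",
    "# ReLU, mirroring the relu activation used above, clips negative scores to zero.\n",
    "link_scores = np.maximum(scores, 0.0)\n",
    "print(link_scores.shape)"
   ]
  },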
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "The predictions need to be reshaped from `(X, 1)` to `(X,)` to match the shape of the targets we have supplied above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction = keras.layers.Reshape((-1,))(prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "Stack the GCN and prediction layers into a Keras model, and specify the loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "\n",
    "model.compile(\n",
    "    optimizer=keras.optimizers.Adam(lr=0.01),\n",
    "    loss=keras.losses.binary_crossentropy,\n",
    "    # not just \"acc\" due to https://github.com/tensorflow/tensorflow/issues/41361\n",
    "    metrics=[\"binary_accuracy\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "Evaluate the initial (untrained) model on the train and test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "34",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "1/1 [==============================] - 0s 109ms/step - loss: 2.0672 - binary_accuracy: 0.5000\n",
      "  ['...']\n",
      "1/1 [==============================] - 0s 11ms/step - loss: 2.0854 - binary_accuracy: 0.5000\n",
      "\n",
      "Train Set Metrics of the initial (untrained) model:\n",
      "\tloss: 2.0672\n",
      "\tbinary_accuracy: 0.5000\n",
      "\n",
      "Test Set Metrics of the initial (untrained) model:\n",
      "\tloss: 2.0854\n",
      "\tbinary_accuracy: 0.5000\n"
     ]
    }
   ],
   "source": [
    "init_train_metrics = model.evaluate(train_flow)\n",
    "init_test_metrics = model.evaluate(test_flow)\n",
    "\n",
    "print(\"\\nTrain Set Metrics of the initial (untrained) model:\")\n",
    "for name, val in zip(model.metrics_names, init_train_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))\n",
    "\n",
    "print(\"\\nTest Set Metrics of the initial (untrained) model:\")\n",
    "for name, val in zip(model.metrics_names, init_test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n",
      "Train for 1 steps, validate for 1 steps\n",
      "Epoch 1/50\n",
      "1/1 - 1s - loss: 2.0177 - binary_accuracy: 0.5000 - val_loss: 0.6800 - val_binary_accuracy: 0.6282\n",
      "Epoch 2/50\n",
      "1/1 - 0s - loss: 0.7811 - binary_accuracy: 0.6148 - val_loss: 2.7203 - val_binary_accuracy: 0.5304\n",
      "Epoch 3/50\n",
      "1/1 - 0s - loss: 3.0683 - binary_accuracy: 0.5512 - val_loss: 0.8094 - val_binary_accuracy: 0.6070\n",
      "Epoch 4/50\n",
      "1/1 - 0s - loss: 1.0147 - binary_accuracy: 0.6281 - val_loss: 0.6638 - val_binary_accuracy: 0.6144\n",
      "Epoch 5/50\n",
      "1/1 - 0s - loss: 0.6450 - binary_accuracy: 0.6383 - val_loss: 0.7782 - val_binary_accuracy: 0.5452\n",
      "Epoch 6/50\n",
      "1/1 - 0s - loss: 0.7345 - binary_accuracy: 0.5594 - val_loss: 0.8198 - val_binary_accuracy: 0.5360\n",
      "Epoch 7/50\n",
      "1/1 - 0s - loss: 0.7581 - binary_accuracy: 0.5420 - val_loss: 0.7800 - val_binary_accuracy: 0.5424\n",
      "Epoch 8/50\n",
      "1/1 - 0s - loss: 0.7302 - binary_accuracy: 0.5635 - val_loss: 0.6993 - val_binary_accuracy: 0.5793\n",
      "Epoch 9/50\n",
      "1/1 - 0s - loss: 0.6579 - binary_accuracy: 0.6178 - val_loss: 0.6471 - val_binary_accuracy: 0.6448\n",
      "Epoch 10/50\n",
      "1/1 - 0s - loss: 0.5960 - binary_accuracy: 0.6527 - val_loss: 0.6303 - val_binary_accuracy: 0.6651\n",
      "Epoch 11/50\n",
      "1/1 - 0s - loss: 0.6916 - binary_accuracy: 0.7049 - val_loss: 0.6082 - val_binary_accuracy: 0.6753\n",
      "Epoch 12/50\n",
      "1/1 - 0s - loss: 0.6069 - binary_accuracy: 0.7182 - val_loss: 0.6083 - val_binary_accuracy: 0.6633\n",
      "Epoch 13/50\n",
      "1/1 - 0s - loss: 0.5257 - binary_accuracy: 0.7131 - val_loss: 0.5991 - val_binary_accuracy: 0.6550\n",
      "Epoch 14/50\n",
      "1/1 - 0s - loss: 0.5381 - binary_accuracy: 0.7111 - val_loss: 0.5900 - val_binary_accuracy: 0.6688\n",
      "Epoch 15/50\n",
      "1/1 - 0s - loss: 0.5440 - binary_accuracy: 0.7305 - val_loss: 0.5756 - val_binary_accuracy: 0.6873\n",
      "Epoch 16/50\n",
      "1/1 - 0s - loss: 0.5004 - binary_accuracy: 0.7480 - val_loss: 0.5669 - val_binary_accuracy: 0.7011\n",
      "Epoch 17/50\n",
      "1/1 - 0s - loss: 0.5103 - binary_accuracy: 0.7572 - val_loss: 0.5710 - val_binary_accuracy: 0.7168\n",
      "Epoch 18/50\n",
      "1/1 - 0s - loss: 0.5410 - binary_accuracy: 0.7510 - val_loss: 0.5528 - val_binary_accuracy: 0.7389\n",
      "Epoch 19/50\n",
      "1/1 - 0s - loss: 0.5042 - binary_accuracy: 0.7602 - val_loss: 0.5363 - val_binary_accuracy: 0.7555\n",
      "Epoch 20/50\n",
      "1/1 - 0s - loss: 0.5035 - binary_accuracy: 0.7818 - val_loss: 0.5337 - val_binary_accuracy: 0.7565\n",
      "Epoch 21/50\n",
      "1/1 - 0s - loss: 0.4343 - binary_accuracy: 0.7900 - val_loss: 0.5315 - val_binary_accuracy: 0.7518\n",
      "Epoch 22/50\n",
      "1/1 - 0s - loss: 0.4395 - binary_accuracy: 0.7920 - val_loss: 0.5290 - val_binary_accuracy: 0.7509\n",
      "Epoch 23/50\n",
      "1/1 - 0s - loss: 0.4513 - binary_accuracy: 0.7838 - val_loss: 0.5253 - val_binary_accuracy: 0.7528\n",
      "Epoch 24/50\n",
      "1/1 - 0s - loss: 0.4329 - binary_accuracy: 0.8064 - val_loss: 0.5292 - val_binary_accuracy: 0.7528\n",
      "Epoch 25/50\n",
      "1/1 - 0s - loss: 0.3979 - binary_accuracy: 0.8289 - val_loss: 0.5225 - val_binary_accuracy: 0.7620\n",
      "Epoch 26/50\n",
      "1/1 - 0s - loss: 0.4230 - binary_accuracy: 0.8084 - val_loss: 0.5259 - val_binary_accuracy: 0.7685\n",
      "Epoch 27/50\n",
      "1/1 - 0s - loss: 0.4280 - binary_accuracy: 0.8340 - val_loss: 0.5319 - val_binary_accuracy: 0.7703\n",
      "Epoch 28/50\n",
      "1/1 - 0s - loss: 0.3886 - binary_accuracy: 0.8320 - val_loss: 0.5297 - val_binary_accuracy: 0.7786\n",
      "Epoch 29/50\n",
      "1/1 - 0s - loss: 0.3921 - binary_accuracy: 0.8525 - val_loss: 0.5542 - val_binary_accuracy: 0.7860\n",
      "Epoch 30/50\n",
      "1/1 - 0s - loss: 0.3724 - binary_accuracy: 0.8576 - val_loss: 0.5854 - val_binary_accuracy: 0.7878\n",
      "Epoch 31/50\n",
      "1/1 - 0s - loss: 0.3583 - binary_accuracy: 0.8525 - val_loss: 0.5993 - val_binary_accuracy: 0.7851\n",
      "Epoch 32/50\n",
      "1/1 - 0s - loss: 0.3930 - binary_accuracy: 0.8504 - val_loss: 0.5984 - val_binary_accuracy: 0.7934\n",
      "Epoch 33/50\n",
      "1/1 - 0s - loss: 0.4009 - binary_accuracy: 0.8627 - val_loss: 0.5854 - val_binary_accuracy: 0.7952\n",
      "Epoch 34/50\n",
      "1/1 - 0s - loss: 0.3854 - binary_accuracy: 0.8617 - val_loss: 0.5798 - val_binary_accuracy: 0.7989\n",
      "Epoch 35/50\n",
      "1/1 - 0s - loss: 0.3616 - binary_accuracy: 0.8873 - val_loss: 0.5744 - val_binary_accuracy: 0.7980\n",
      "Epoch 36/50\n",
      "1/1 - 0s - loss: 0.3418 - binary_accuracy: 0.8637 - val_loss: 0.5535 - val_binary_accuracy: 0.8044\n",
      "Epoch 37/50\n",
      "1/1 - 0s - loss: 0.3682 - binary_accuracy: 0.8730 - val_loss: 0.5457 - val_binary_accuracy: 0.8035\n",
      "Epoch 38/50\n",
      "1/1 - 0s - loss: 0.3270 - binary_accuracy: 0.8842 - val_loss: 0.5584 - val_binary_accuracy: 0.8044\n",
      "Epoch 39/50\n",
      "1/1 - 0s - loss: 0.2986 - binary_accuracy: 0.8955 - val_loss: 0.5798 - val_binary_accuracy: 0.8054\n",
      "Epoch 40/50\n",
      "1/1 - 0s - loss: 0.3134 - binary_accuracy: 0.8740 - val_loss: 0.6044 - val_binary_accuracy: 0.8026\n",
      "Epoch 41/50\n",
      "1/1 - 0s - loss: 0.3217 - binary_accuracy: 0.8832 - val_loss: 0.5901 - val_binary_accuracy: 0.7961\n",
      "Epoch 42/50\n",
      "1/1 - 0s - loss: 0.3200 - binary_accuracy: 0.8791 - val_loss: 0.6040 - val_binary_accuracy: 0.7943\n",
      "Epoch 43/50\n",
      "1/1 - 0s - loss: 0.3124 - binary_accuracy: 0.8740 - val_loss: 0.6031 - val_binary_accuracy: 0.7961\n",
      "Epoch 44/50\n",
      "1/1 - 0s - loss: 0.3157 - binary_accuracy: 0.8822 - val_loss: 0.6116 - val_binary_accuracy: 0.7998\n",
      "Epoch 45/50\n",
      "1/1 - 0s - loss: 0.3034 - binary_accuracy: 0.8852 - val_loss: 0.6333 - val_binary_accuracy: 0.7998\n",
      "Epoch 46/50\n",
      "1/1 - 0s - loss: 0.2818 - binary_accuracy: 0.8914 - val_loss: 0.6404 - val_binary_accuracy: 0.7998\n",
      "Epoch 47/50\n",
      "1/1 - 0s - loss: 0.2646 - binary_accuracy: 0.8914 - val_loss: 0.6476 - val_binary_accuracy: 0.8007\n",
      "Epoch 48/50\n",
      "1/1 - 0s - loss: 0.2637 - binary_accuracy: 0.8975 - val_loss: 0.6712 - val_binary_accuracy: 0.8063\n",
      "Epoch 49/50\n",
      "1/1 - 0s - loss: 0.2802 - binary_accuracy: 0.9047 - val_loss: 0.6873 - val_binary_accuracy: 0.8054\n",
      "Epoch 50/50\n",
      "1/1 - 0s - loss: 0.2481 - binary_accuracy: 0.9109 - val_loss: 0.7384 - val_binary_accuracy: 0.8035\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_flow, epochs=epochs, validation_data=test_flow, verbose=2, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {},
   "source": [
    "Plot the training history:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39",
   "metadata": {},
   "source": [
    "Evaluate the trained model on test citation links:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "1/1 [==============================] - 0s 11ms/step - loss: 0.1937 - binary_accuracy: 0.9426\n",
      "  ['...']\n",
      "1/1 [==============================] - 0s 11ms/step - loss: 0.7384 - binary_accuracy: 0.8035\n",
      "\n",
      "Train Set Metrics of the trained model:\n",
      "\tloss: 0.1937\n",
      "\tbinary_accuracy: 0.9426\n",
      "\n",
      "Test Set Metrics of the trained model:\n",
      "\tloss: 0.7384\n",
      "\tbinary_accuracy: 0.8035\n"
     ]
    }
   ],
   "source": [
    "train_metrics = model.evaluate(train_flow)\n",
    "test_metrics = model.evaluate(test_flow)\n",
    "\n",
    "print(\"\\nTrain Set Metrics of the trained model:\")\n",
    "for name, val in zip(model.metrics_names, train_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))\n",
    "\n",
    "print(\"\\nTest Set Metrics of the trained model:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/gcn-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/gcn-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}