{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Comparison of link prediction with random walks based node embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/homogeneous-comparison-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/homogeneous-comparison-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This demo notebook compares the link prediction performance of the embeddings learned by Node2Vec [1], Attri2Vec [2], GraphSAGE [3] and GCN [4] on the Cora dataset, under the same edge train-test-split setting. Node2Vec and Attri2Vec are learned by capturing the random walk context node similarity. GraphSAGE and GCN are learned in an unsupervised way by making nodes co-occurring in short random walks represented closely in the embedding space.\n",
    "\n",
    "We're going to tackle link prediction as a supervised learning problem on top of node representations/embeddings. After obtaining embeddings, a binary classifier can be used to predict a link, or not, between any two nodes in the graph. Various hyperparameters could be relevant in obtaining the best link classifier - this demo demonstrates incorporating model selection into the pipeline for choosing the best binary operator to apply on a pair of node embeddings.\n",
    "\n",
    "There are four steps:\n",
    "\n",
    "1. Obtain embeddings for each node\n",
    "2. For each set of hyperparameters, train a classifier\n",
    "3. Select the classifier that performs the best\n",
    "4. Evaluate the selected classifier on unseen data to validate its ability to generalise\n",
    "\n",
    "\n",
    "**References:** \n",
    "\n",
    "[1] Node2Vec: Scalable Feature Learning for Networks. A. Grover, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2016.\n",
    "\n",
    "[2] Attributed Network Embedding via Subspace Discovery. D. Zhang, Y. Jie, X. Zhu and C. Zhang. Data Mining and Knowledge Discovery, 2019. \n",
    "\n",
    "[3] Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec. Neural Information Processing Systems (NIPS), 2017.\n",
    "\n",
    "[4] Graph Convolutional Networks (GCN): Semi-Supervised Classification with Graph Convolutional Networks. Thomas N. Kipf, Max Welling. International Conference on Learning Representations (ICLR), 2017"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from math import isclose\n",
    "from sklearn.decomposition import PCA\n",
    "import os\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from stellargraph import StellarGraph, datasets\n",
    "from stellargraph.data import EdgeSplitter\n",
    "from collections import Counter\n",
    "import multiprocessing\n",
    "from IPython.display import display, HTML\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Load the dataset\n",
    "\n",
    "The Cora dataset is a homogeneous network where all nodes are papers and edges between nodes are citation links, e.g. paper A cites paper B."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "graph, _ = dataset.load(largest_connected_component_only=True, str_node_ids=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 5209\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5209]\n",
      "        Weights: all 1 (default)\n",
      "        Features: none\n"
     ]
    }
   ],
   "source": [
    "print(graph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## Construct splits of the input data\n",
    "\n",
    "We have to carefully split the data to avoid data leakage and evaluate the algorithms correctly:\n",
    "\n",
    "* For computing node embeddings, a **Train Graph** (`graph_train`)\n",
    "* For training classifiers, a classifier **Training Set** (`examples_train`) of positive and negative edges that weren't used for computing node embeddings\n",
    "* For choosing the best classifier, a **Model Selection Test Set** (`examples_model_selection`) of positive and negative edges that weren't used for computing node embeddings or training the classifier \n",
    "* For the final evaluation, with the learned node embeddings from the **Train Graph** (`graph_train`), the chosen best classifier is applied to a **Test Set** (`examples_test`) of positive and negative edges not used for neither computing the node embeddings or for classifier training or model selection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "###  Test Graph\n",
    "\n",
    "We begin with the full graph and use the `EdgeSplitter` class to produce:\n",
    "\n",
    "* Test Graph\n",
    "* Test set of positive/negative link examples\n",
    "\n",
    "The Test Graph is the reduced graph we obtain from removing the test set of links from the full graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 520 positive and 520 negative edges. **\n",
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 4689\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [4689]\n",
      "        Weights: all 1 (default)\n",
      "        Features: none\n"
     ]
    }
   ],
   "source": [
    "# Define an edge splitter on the original graph:\n",
    "edge_splitter_test = EdgeSplitter(graph)\n",
    "\n",
    "# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from graph, and obtain the\n",
    "# reduced graph graph_test with the sampled links removed:\n",
    "graph_test, examples_test, labels_test = edge_splitter_test.train_test_split(\n",
    "    p=0.1, method=\"global\"\n",
    ")\n",
    "\n",
    "print(graph_test.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "### Train Graph\n",
    "\n",
    "This time, we use the `EdgeSplitter` on the Test Graph, and perform a train/test split on the examples to produce:\n",
    "\n",
    "* Train Graph\n",
    "* Training set of link examples\n",
    "* Set of link examples for model selection\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 468 positive and 468 negative edges. **\n",
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 4221\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [4221]\n",
      "        Weights: all 1 (default)\n",
      "        Features: none\n"
     ]
    }
   ],
   "source": [
    "# Do the same process to compute a training subset from within the test graph\n",
    "edge_splitter_train = EdgeSplitter(graph_test)\n",
    "graph_train, examples, labels = edge_splitter_train.train_test_split(\n",
    "    p=0.1, method=\"global\"\n",
    ")\n",
    "(\n",
    "    examples_train,\n",
    "    examples_model_selection,\n",
    "    labels_train,\n",
    "    labels_model_selection,\n",
    ") = train_test_split(examples, labels, train_size=0.75, test_size=0.25)\n",
    "\n",
    "print(graph_train.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Below is a summary of the different splits that have been created in this section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Number of Examples</th>\n",
       "      <th>Hidden from</th>\n",
       "      <th>Picked from</th>\n",
       "      <th>Use</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Split</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Training Set</th>\n",
       "      <td>702</td>\n",
       "      <td>Train Graph</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Train the Link Classifier</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Model Selection</th>\n",
       "      <td>234</td>\n",
       "      <td>Train Graph</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Select the best Link Classifier model</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Test set</th>\n",
       "      <td>1040</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Full Graph</td>\n",
       "      <td>Evaluate the best Link Classifier</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Number of Examples  Hidden from Picked from  \\\n",
       "Split                                                          \n",
       "Training Set                    702  Train Graph  Test Graph   \n",
       "Model Selection                 234  Train Graph  Test Graph   \n",
       "Test set                       1040   Test Graph  Full Graph   \n",
       "\n",
       "                                                   Use  \n",
       "Split                                                   \n",
       "Training Set                 Train the Link Classifier  \n",
       "Model Selection  Select the best Link Classifier model  \n",
       "Test set             Evaluate the best Link Classifier  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(\n",
    "    [\n",
    "        (\n",
    "            \"Training Set\",\n",
    "            len(examples_train),\n",
    "            \"Train Graph\",\n",
    "            \"Test Graph\",\n",
    "            \"Train the Link Classifier\",\n",
    "        ),\n",
    "        (\n",
    "            \"Model Selection\",\n",
    "            len(examples_model_selection),\n",
    "            \"Train Graph\",\n",
    "            \"Test Graph\",\n",
    "            \"Select the best Link Classifier model\",\n",
    "        ),\n",
    "        (\n",
    "            \"Test set\",\n",
    "            len(examples_test),\n",
    "            \"Test Graph\",\n",
    "            \"Full Graph\",\n",
    "            \"Evaluate the best Link Classifier\",\n",
    "        ),\n",
    "    ],\n",
    "    columns=(\"Split\", \"Number of Examples\", \"Hidden from\", \"Picked from\", \"Use\"),\n",
    ").set_index(\"Split\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "## Create random walker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "We define the helper function to generate biased random walks from the given graph with the fixed random walk parameters:\n",
    "\n",
    "* `p` - Random walk parameter \"p\" that defines probability, \"1/p\", of returning to source node\n",
    "* `q` - Random walk parameter \"q\" that defines probability, \"1/q\", for moving to a node away from the source node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.data import BiasedRandomWalk\n",
    "\n",
    "\n",
    "def create_biased_random_walker(graph, walk_num, walk_length):\n",
    "    # parameter settings for \"p\" and \"q\":\n",
    "    p = 1.0\n",
    "    q = 1.0\n",
    "    return BiasedRandomWalk(graph, n=walk_num, length=walk_length, p=p, q=q)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "## Parameter Settings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "We train Node2Vec, Attri2Vec, GraphSAGE, and GCN by following the same unsupervised learning procedure: we firstly generate a set of short random walks from the given graph and then learn node embeddings from batches of `target, context` pairs collected from random walks. For learning node embeddings, we need to specify the following parameters:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "* `dimension` - Dimensionality of node embeddings\n",
    "* `walk_number` - Number of walks from each node\n",
    "* `walk_length` - Length of each random walk\n",
    "* `epochs` - The number of epochs to train embedding learning model\n",
    "* `batch_size` - The batch size to train embedding learning model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "We consistently set the node embedding dimension to 128 for all algorithms. However, we use different hidden layers to learn node embeddings for different algorithms to exert their respective power. For the remaining parameters, we set them as:\n",
    "\n",
    "|               | Node2Vec | Attri2Vec | GraphSAGE | GCN |\n",
    "|---------------|----------|-----------|-----------|-----|\n",
    "| `walk_number` |    20    |     4     |     1     |  1  |\n",
    "| `walk_length` |     5    |     5     |     5     |  5  |\n",
    "| `epochs`      |    6     |     6     |     6     |  6  |\n",
    "| `batch_size`  |    50    |     50    |    50     |  50 |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "As all algorithms use the same `walk_length`, `batch_size` and `epochs` values, we uniformly set them here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "walk_length = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "26",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "For different algorithms, users can find the best parameter setting with the `Model Selection` edge set."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "## Node2Vec\n",
    "\n",
    "We use Node2Vec [1], to calculate node embeddings. These embeddings are learned in such a way to ensure that nodes that are close in the graph remain close in the embedding space. We train Node2Vec with the Stellargraph Node2Vec components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.data import UnsupervisedSampler\n",
    "from stellargraph.mapper import Node2VecLinkGenerator, Node2VecNodeGenerator\n",
    "from stellargraph.layer import Node2Vec, link_classification\n",
    "from tensorflow import keras\n",
    "\n",
    "\n",
    "def node2vec_embedding(graph, name):\n",
    "\n",
    "    # Set the embedding dimension and walk number:\n",
    "    dimension = 128\n",
    "    walk_number = 20\n",
    "\n",
    "    print(f\"Training Node2Vec for '{name}':\")\n",
    "\n",
    "    graph_node_list = list(graph.nodes())\n",
    "\n",
    "    # Create the biased random walker to generate random walks\n",
    "    walker = create_biased_random_walker(graph, walk_number, walk_length)\n",
    "\n",
    "    # Create the unsupervised sampler to sample (target, context) pairs from random walks\n",
    "    unsupervised_samples = UnsupervisedSampler(\n",
    "        graph, nodes=graph_node_list, walker=walker\n",
    "    )\n",
    "\n",
    "    # Define a Node2Vec training generator, which generates batches of training pairs\n",
    "    generator = Node2VecLinkGenerator(graph, batch_size)\n",
    "\n",
    "    # Create the Node2Vec model\n",
    "    node2vec = Node2Vec(dimension, generator=generator)\n",
    "\n",
    "    # Build the model and expose input and output sockets of Node2Vec, for node pair inputs\n",
    "    x_inp, x_out = node2vec.in_out_tensors()\n",
    "\n",
    "    # Use the link_classification function to generate the output of the Node2Vec model\n",
    "    prediction = link_classification(\n",
    "        output_dim=1, output_act=\"sigmoid\", edge_embedding_method=\"dot\"\n",
    "    )(x_out)\n",
    "\n",
    "    # Stack the Node2Vec encoder and prediction layer into a Keras model, and specify the loss\n",
    "    model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(lr=1e-3),\n",
    "        loss=keras.losses.binary_crossentropy,\n",
    "        metrics=[keras.metrics.binary_accuracy],\n",
    "    )\n",
    "\n",
    "    # Train the model\n",
    "    model.fit(\n",
    "        generator.flow(unsupervised_samples),\n",
    "        epochs=epochs,\n",
    "        verbose=2,\n",
    "        use_multiprocessing=False,\n",
    "        workers=4,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "    # Build the model to predict node representations from node ids with the learned Node2Vec model parameters\n",
    "    x_inp_src = x_inp[0]\n",
    "    x_out_src = x_out[0]\n",
    "    embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)\n",
    "\n",
    "    # Get representations for all nodes in ``graph``\n",
    "    node_gen = Node2VecNodeGenerator(graph, batch_size).flow(graph_node_list)\n",
    "    node_embeddings = embedding_model.predict(node_gen, workers=1, verbose=0)\n",
    "\n",
    "    def get_embedding(u):\n",
    "        u_index = graph_node_list.index(u)\n",
    "        return node_embeddings[u_index]\n",
    "\n",
    "    return get_embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "## Attri2Vec\n",
    "\n",
    "We use Attri2Vec [2] to calculate node embeddings. Attri2Vec learns node representations through performing a linear/non-linear mapping on node content attributes and simultaneously making nodes sharing similar context nodes in random walks have similar representations. With the node content features are used to learn node embeddings, we wish that Attri2Vec can achieve better link prediction performance than the only structure preserving network embedding algorithm Node2Vec."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.mapper import Attri2VecLinkGenerator, Attri2VecNodeGenerator\n",
    "from stellargraph.layer import Attri2Vec\n",
    "\n",
    "\n",
    "def attri2vec_embedding(graph, name):\n",
    "\n",
    "    # Set the embedding dimension and walk number:\n",
    "    dimension = [128]\n",
    "    walk_number = 4\n",
    "\n",
    "    print(f\"Training Attri2Vec for '{name}':\")\n",
    "\n",
    "    graph_node_list = list(graph.nodes())\n",
    "\n",
    "    # Create the biased random walker to generate random walks\n",
    "    walker = create_biased_random_walker(graph, walk_number, walk_length)\n",
    "\n",
    "    # Create the unsupervised sampler to sample (target, context) pairs from random walks\n",
    "    unsupervised_samples = UnsupervisedSampler(\n",
    "        graph, nodes=graph_node_list, walker=walker\n",
    "    )\n",
    "\n",
    "    # Define an Attri2Vec training generator, which generates batches of training pairs\n",
    "    generator = Attri2VecLinkGenerator(graph, batch_size)\n",
    "\n",
    "    # Create the Attri2Vec model\n",
    "    attri2vec = Attri2Vec(\n",
    "        layer_sizes=dimension, generator=generator, bias=False, normalize=None\n",
    "    )\n",
    "\n",
    "    # Build the model and expose input and output sockets of Attri2Vec, for node pair inputs\n",
    "    x_inp, x_out = attri2vec.in_out_tensors()\n",
    "\n",
    "    # Use the link_classification function to generate the output of the Attri2Vec model\n",
    "    prediction = link_classification(\n",
    "        output_dim=1, output_act=\"sigmoid\", edge_embedding_method=\"ip\"\n",
    "    )(x_out)\n",
    "\n",
    "    # Stack the Attri2Vec encoder and prediction layer into a Keras model, and specify the loss\n",
    "    model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(lr=1e-3),\n",
    "        loss=keras.losses.binary_crossentropy,\n",
    "        metrics=[keras.metrics.binary_accuracy],\n",
    "    )\n",
    "\n",
    "    # Train the model\n",
    "    model.fit(\n",
    "        generator.flow(unsupervised_samples),\n",
    "        epochs=epochs,\n",
    "        verbose=2,\n",
    "        use_multiprocessing=False,\n",
    "        workers=1,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "    # Build the model to predict node representations from node features with the learned Attri2Vec model parameters\n",
    "    x_inp_src = x_inp[0]\n",
    "    x_out_src = x_out[0]\n",
    "    embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)\n",
    "\n",
    "    # Get representations for all nodes in ``graph``\n",
    "    node_gen = Attri2VecNodeGenerator(graph, batch_size).flow(graph_node_list)\n",
    "    node_embeddings = embedding_model.predict(node_gen, workers=1, verbose=0)\n",
    "\n",
    "    def get_embedding(u):\n",
    "        u_index = graph_node_list.index(u)\n",
    "        return node_embeddings[u_index]\n",
    "\n",
    "    return get_embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "## GraphSAGE\n",
    "\n",
    "GraphSAGE [3] learns node embeddings for attributed graphs through aggregating neighboring node attributes. The aggregation parameters are learned by encouraging node pairs co-occurring in short random walks to have similar representations. As node attributes are also leveraged, GraphSAGE is expected to perform better than Node2Vec in link prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "34",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.mapper import GraphSAGELinkGenerator, GraphSAGENodeGenerator\n",
    "from stellargraph.layer import GraphSAGE\n",
    "\n",
    "\n",
    "def graphsage_embedding(graph, name):\n",
    "\n",
    "    # Set the embedding dimensions, the numbers of sampled neighboring nodes and walk number:\n",
    "    dimensions = [128, 128]\n",
    "    num_samples = [10, 5]\n",
    "    walk_number = 1\n",
    "\n",
    "    print(f\"Training GraphSAGE for '{name}':\")\n",
    "\n",
    "    graph_node_list = list(graph.nodes())\n",
    "\n",
    "    # Create the biased random walker to generate random walks\n",
    "    walker = create_biased_random_walker(graph, walk_number, walk_length)\n",
    "\n",
    "    # Create the unsupervised sampler to sample (target, context) pairs from random walks\n",
    "    unsupervised_samples = UnsupervisedSampler(\n",
    "        graph, nodes=graph_node_list, walker=walker\n",
    "    )\n",
    "\n",
    "    # Define a GraphSAGE training generator, which generates batches of training pairs\n",
    "    generator = GraphSAGELinkGenerator(graph, batch_size, num_samples)\n",
    "\n",
    "    # Create the GraphSAGE model\n",
    "    graphsage = GraphSAGE(\n",
    "        layer_sizes=dimensions,\n",
    "        generator=generator,\n",
    "        bias=True,\n",
    "        dropout=0.0,\n",
    "        normalize=\"l2\",\n",
    "    )\n",
    "\n",
    "    # Build the model and expose input and output sockets of GraphSAGE, for node pair inputs\n",
    "    x_inp, x_out = graphsage.in_out_tensors()\n",
    "\n",
    "    # Use the link_classification function to generate the output of the GraphSAGE model\n",
    "    prediction = link_classification(\n",
    "        output_dim=1, output_act=\"sigmoid\", edge_embedding_method=\"ip\"\n",
    "    )(x_out)\n",
    "\n",
    "    # Stack the GraphSAGE encoder and prediction layer into a Keras model, and specify the loss\n",
    "    model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(lr=1e-3),\n",
    "        loss=keras.losses.binary_crossentropy,\n",
    "        metrics=[keras.metrics.binary_accuracy],\n",
    "    )\n",
    "\n",
    "    # Train the model\n",
    "    model.fit(\n",
    "        generator.flow(unsupervised_samples),\n",
    "        epochs=epochs,\n",
    "        verbose=2,\n",
    "        use_multiprocessing=False,\n",
    "        workers=4,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "    # Build the model to predict node representations from node features with the learned GraphSAGE model parameters\n",
    "    x_inp_src = x_inp[0::2]\n",
    "    x_out_src = x_out[0]\n",
    "    embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)\n",
    "\n",
    "    # Get representations for all nodes in ``graph``\n",
    "    node_gen = GraphSAGENodeGenerator(graph, batch_size, num_samples).flow(\n",
    "        graph_node_list\n",
    "    )\n",
    "    node_embeddings = embedding_model.predict(node_gen, workers=1, verbose=0)\n",
    "\n",
    "    def get_embedding(u):\n",
    "        u_index = graph_node_list.index(u)\n",
    "        return node_embeddings[u_index]\n",
    "\n",
    "    return get_embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "## GCN\n",
    "\n",
    "GCN [4] learns node embeddings through graph convolution. Traditional GCN relies on node labels as a supervision to perform training. Here, we consider the unsupervised link prediction setting and we try to learn informative GCN node embeddings by making nodes co-occurring in short random walks represented closely, as is performed in training GraphSAGE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.mapper import FullBatchLinkGenerator, FullBatchNodeGenerator\n",
    "from stellargraph.layer import GCN, LinkEmbedding\n",
    "\n",
    "\n",
    "def gcn_embedding(graph, name):\n",
    "\n",
    "    # Set the embedding dimensions and walk number:\n",
    "    dimensions = [128, 128]\n",
    "    walk_number = 1\n",
    "\n",
    "    print(f\"Training GCN for '{name}':\")\n",
    "\n",
    "    graph_node_list = list(graph.nodes())\n",
    "\n",
    "    # Create the biased random walker to generate random walks\n",
    "    walker = create_biased_random_walker(graph, walk_number, walk_length)\n",
    "\n",
    "    # Create the unsupervised sampler to sample (target, context) pairs from random walks\n",
    "    unsupervised_samples = UnsupervisedSampler(\n",
    "        graph, nodes=graph_node_list, walker=walker\n",
    "    )\n",
    "\n",
    "    # Define a GCN training generator, which generates the full batch of training pairs\n",
    "    generator = FullBatchLinkGenerator(graph, method=\"gcn\")\n",
    "\n",
    "    # Create the GCN model\n",
    "    gcn = GCN(\n",
    "        layer_sizes=dimensions,\n",
    "        activations=[\"relu\", \"relu\"],\n",
    "        generator=generator,\n",
    "        dropout=0.3,\n",
    "    )\n",
    "\n",
    "    # Build the model and expose input and output sockets of GCN, for node pair inputs\n",
    "    x_inp, x_out = gcn.in_out_tensors()\n",
    "\n",
    "    # Use the dot product of node embeddings to make node pairs co-occurring in short random walks represented closely\n",
    "    prediction = LinkEmbedding(activation=\"sigmoid\", method=\"ip\")(x_out)\n",
    "    prediction = keras.layers.Reshape((-1,))(prediction)\n",
    "\n",
    "    # Stack the GCN encoder and prediction layer into a Keras model, and specify the loss\n",
    "    model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(lr=1e-3),\n",
    "        loss=keras.losses.binary_crossentropy,\n",
    "        metrics=[keras.metrics.binary_accuracy],\n",
    "    )\n",
    "\n",
    "    # Train the model\n",
    "    batches = unsupervised_samples.run(batch_size)\n",
    "    for epoch in range(epochs):\n",
    "        print(f\"Epoch: {epoch+1}/{epochs}\")\n",
    "        batch_iter = 1\n",
    "        for batch in batches:\n",
    "            samples = generator.flow(batch[0], targets=batch[1], use_ilocs=True)[0]\n",
    "            [loss, accuracy] = model.train_on_batch(x=samples[0], y=samples[1])\n",
    "            output = (\n",
    "                f\"{batch_iter}/{len(batches)} - loss:\"\n",
    "                + \" {:6.4f}\".format(loss)\n",
    "                + \" - binary_accuracy:\"\n",
    "                + \" {:6.4f}\".format(accuracy)\n",
    "            )\n",
    "            if batch_iter == len(batches):\n",
    "                print(output)\n",
    "            else:\n",
    "                print(output, end=\"\\r\")\n",
    "            batch_iter = batch_iter + 1\n",
    "\n",
    "    # Get representations for all nodes in ``graph``\n",
    "    embedding_model = keras.Model(inputs=x_inp, outputs=x_out)\n",
    "    node_embeddings = embedding_model.predict(\n",
    "        generator.flow(list(zip(graph_node_list, graph_node_list)))\n",
    "    )\n",
    "    node_embeddings = node_embeddings[0][:, 0, :]\n",
    "\n",
    "    def get_embedding(u):\n",
    "        u_index = graph_node_list.index(u)\n",
    "        return node_embeddings[u_index]\n",
    "\n",
    "    return get_embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {},
   "source": [
    "## Train and evaluate the link prediction model\n",
    "\n",
    "There are a few steps involved in using the learned embeddings to perform link prediction:\n",
    "1. We calculate link/edge embeddings for the positive and negative edge samples by applying a binary operator on the embeddings of the source and target nodes of each sampled edge.\n",
    "2. Given the embeddings of the positive and negative examples, we train a logistic regression classifier to predict a binary value indicating whether an edge between two nodes should exist or not.\n",
    "3. We evaluate the performance of the link classifier for each of the 4 operators on the training data with node embeddings calculated on the **Train Graph** (`graph_train`), and select the best classifier.\n",
    "4. The best classifier is then used to calculate scores on the test data with node embeddings trained on the **Train Graph** (`graph_train`).\n",
    "\n",
    "Below are a set of helper functions that let us repeat these steps for each of the binary operators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.linear_model import LogisticRegressionCV\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "# 1. link embeddings\n",
    "def link_examples_to_features(link_examples, transform_node, binary_operator):\n",
    "    return [\n",
    "        binary_operator(transform_node(src), transform_node(dst))\n",
    "        for src, dst in link_examples\n",
    "    ]\n",
    "\n",
    "\n",
    "# 2. training classifier\n",
    "def train_link_prediction_model(\n",
    "    link_examples, link_labels, get_embedding, binary_operator\n",
    "):\n",
    "    clf = link_prediction_classifier()\n",
    "    link_features = link_examples_to_features(\n",
    "        link_examples, get_embedding, binary_operator\n",
    "    )\n",
    "    clf.fit(link_features, link_labels)\n",
    "    return clf\n",
    "\n",
    "\n",
    "def link_prediction_classifier(max_iter=5000):\n",
    "    lr_clf = LogisticRegressionCV(Cs=10, cv=10, scoring=\"roc_auc\", max_iter=max_iter)\n",
    "    return Pipeline(steps=[(\"sc\", StandardScaler()), (\"clf\", lr_clf)])\n",
    "\n",
    "\n",
    "# 3. and 4. evaluate classifier\n",
    "def evaluate_link_prediction_model(\n",
    "    clf, link_examples_test, link_labels_test, get_embedding, binary_operator\n",
    "):\n",
    "    link_features_test = link_examples_to_features(\n",
    "        link_examples_test, get_embedding, binary_operator\n",
    "    )\n",
    "    score = evaluate_roc_auc(clf, link_features_test, link_labels_test)\n",
    "    return score\n",
    "\n",
    "\n",
    "def evaluate_roc_auc(clf, link_features, link_labels):\n",
    "    predicted = clf.predict_proba(link_features)\n",
    "\n",
    "    # check which class corresponds to positive links\n",
    "    positive_column = list(clf.classes_).index(1)\n",
    "    return roc_auc_score(link_labels, predicted[:, positive_column])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39",
   "metadata": {},
   "source": [
    "We consider 4 different operators: \n",
    "\n",
    "* *Hadamard*\n",
    "* $L_1$\n",
    "* $L_2$\n",
    "* *average*\n",
    "\n",
    "The paper [1] provides a detailed description of these operators. All operators produce link embeddings that have equal dimensionality to the input node embeddings (128 dimensions for our example). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "40",
   "metadata": {},
   "outputs": [],
   "source": [
    "def operator_hadamard(u, v):\n",
    "    return u * v\n",
    "\n",
    "\n",
    "def operator_l1(u, v):\n",
    "    return np.abs(u - v)\n",
    "\n",
    "\n",
    "def operator_l2(u, v):\n",
    "    return (u - v) ** 2\n",
    "\n",
    "\n",
    "def operator_avg(u, v):\n",
    "    return (u + v) / 2.0\n",
    "\n",
    "\n",
    "def run_link_prediction(binary_operator, embedding_train):\n",
    "    clf = train_link_prediction_model(\n",
    "        examples_train, labels_train, embedding_train, binary_operator\n",
    "    )\n",
    "    score = evaluate_link_prediction_model(\n",
    "        clf,\n",
    "        examples_model_selection,\n",
    "        labels_model_selection,\n",
    "        embedding_train,\n",
    "        binary_operator,\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"classifier\": clf,\n",
    "        \"binary_operator\": binary_operator,\n",
    "        \"score\": score,\n",
    "    }\n",
    "\n",
    "\n",
    "binary_operators = [operator_hadamard, operator_l1, operator_l2, operator_avg]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41",
   "metadata": {},
   "source": [
    "### Train and evaluate the link model with the specified embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(embedding, name):\n",
    "\n",
    "    embedding_train = embedding(graph_train, \"Train Graph\")\n",
    "\n",
    "    # Train the link classification model with the learned embedding\n",
    "    results = [run_link_prediction(op, embedding_train) for op in binary_operators]\n",
    "    best_result = max(results, key=lambda result: result[\"score\"])\n",
    "    print(\n",
    "        f\"\\nBest result with '{name}' embeddings from '{best_result['binary_operator'].__name__}'\"\n",
    "    )\n",
    "    display(\n",
    "        pd.DataFrame(\n",
    "            [(result[\"binary_operator\"].__name__, result[\"score\"]) for result in results],\n",
    "            columns=(\"name\", \"ROC AUC\"),\n",
    "        ).set_index(\"name\")\n",
    "    )\n",
    "\n",
    "    # Evaluate the best model using the test set\n",
    "    test_score = evaluate_link_prediction_model(\n",
    "        best_result[\"classifier\"],\n",
    "        examples_test,\n",
    "        labels_test,\n",
    "        embedding_train,\n",
    "        best_result[\"binary_operator\"],\n",
    "    )\n",
    "\n",
    "    return test_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43",
   "metadata": {},
   "source": [
    "### Collect the link prediction results for Node2Vec, Attri2Vec, GraphSAGE and GCN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44",
   "metadata": {},
   "source": [
    "#### Get Node2Vec link prediction result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Node2Vec for 'Train Graph':\n",
      "link_classification: using 'dot' method to combine node embeddings into edge embeddings\n",
      "Train for 7674 steps\n",
      "Epoch 1/6\n",
      "7674/7674 - 27s - loss: 0.5544 - binary_accuracy: 0.6739\n",
      "Epoch 2/6\n",
      "7674/7674 - 38s - loss: 0.4469 - binary_accuracy: 0.7548\n",
      "Epoch 3/6\n",
      "7674/7674 - 38s - loss: 0.2881 - binary_accuracy: 0.8811\n",
      "Epoch 4/6\n",
      "7674/7674 - 38s - loss: 0.1686 - binary_accuracy: 0.9420\n",
      "Epoch 5/6\n",
      "7674/7674 - 39s - loss: 0.1369 - binary_accuracy: 0.9522\n",
      "Epoch 6/6\n",
      "7674/7674 - 41s - loss: 0.1277 - binary_accuracy: 0.9551\n",
      "\n",
      "Best result with 'Node2Vec' embeddings from 'operator_l2'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>operator_hadamard</th>\n",
       "      <td>0.810491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l1</th>\n",
       "      <td>0.835257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l2</th>\n",
       "      <td>0.845412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_avg</th>\n",
       "      <td>0.513223</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    ROC AUC\n",
       "name                       \n",
       "operator_hadamard  0.810491\n",
       "operator_l1        0.835257\n",
       "operator_l2        0.845412\n",
       "operator_avg       0.513223"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "node2vec_result = train_and_evaluate(node2vec_embedding, \"Node2Vec\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46",
   "metadata": {},
   "source": [
    "#### Get Attri2Vec link prediction result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "47",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Attri2Vec for 'Train Graph':\n",
      "link_classification: using 'ip' method to combine node embeddings into edge embeddings\n",
      "Train for 1535 steps\n",
      "Epoch 1/6\n",
      "1535/1535 - 4s - loss: 0.6997 - binary_accuracy: 0.5324\n",
      "Epoch 2/6\n",
      "1535/1535 - 4s - loss: 0.6275 - binary_accuracy: 0.6479\n",
      "Epoch 3/6\n",
      "1535/1535 - 4s - loss: 0.4614 - binary_accuracy: 0.8077\n",
      "Epoch 4/6\n",
      "1535/1535 - 4s - loss: 0.3255 - binary_accuracy: 0.8848\n",
      "Epoch 5/6\n",
      "1535/1535 - 5s - loss: 0.2451 - binary_accuracy: 0.9198\n",
      "Epoch 6/6\n",
      "1535/1535 - 4s - loss: 0.1901 - binary_accuracy: 0.9427\n",
      "\n",
      "Best result with 'Attri2Vec' embeddings from 'operator_l1'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>operator_hadamard</th>\n",
       "      <td>0.871274</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l1</th>\n",
       "      <td>0.893191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l2</th>\n",
       "      <td>0.883913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_avg</th>\n",
       "      <td>0.542300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    ROC AUC\n",
       "name                       \n",
       "operator_hadamard  0.871274\n",
       "operator_l1        0.893191\n",
       "operator_l2        0.883913\n",
       "operator_avg       0.542300"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "attri2vec_result = train_and_evaluate(attri2vec_embedding, \"Attri2Vec\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48",
   "metadata": {},
   "source": [
    "#### Get GraphSAGE link prediction result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "49",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training GraphSAGE for 'Train Graph':\n",
      "link_classification: using 'ip' method to combine node embeddings into edge embeddings\n",
      "Train for 384 steps\n",
      "Epoch 1/6\n",
      "384/384 - 31s - loss: 0.5610 - binary_accuracy: 0.7552\n",
      "Epoch 2/6\n",
      "384/384 - 26s - loss: 0.5380 - binary_accuracy: 0.7844\n",
      "Epoch 3/6\n",
      "384/384 - 29s - loss: 0.5376 - binary_accuracy: 0.7840\n",
      "Epoch 4/6\n",
      "384/384 - 26s - loss: 0.5339 - binary_accuracy: 0.7943\n",
      "Epoch 5/6\n",
      "384/384 - 26s - loss: 0.5321 - binary_accuracy: 0.7958\n",
      "Epoch 6/6\n",
      "384/384 - 26s - loss: 0.5313 - binary_accuracy: 0.7936\n",
      "\n",
      "Best result with 'GraphSAGE' embeddings from 'operator_l2'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>operator_hadamard</th>\n",
       "      <td>0.883986</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l1</th>\n",
       "      <td>0.882963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l2</th>\n",
       "      <td>0.888150</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_avg</th>\n",
       "      <td>0.507890</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    ROC AUC\n",
       "name                       \n",
       "operator_hadamard  0.883986\n",
       "operator_l1        0.882963\n",
       "operator_l2        0.888150\n",
       "operator_avg       0.507890"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "graphsage_result = train_and_evaluate(graphsage_embedding, \"GraphSAGE\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50",
   "metadata": {},
   "source": [
    "#### Get GCN link prediction result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training GCN for 'Train Graph':\n",
      "Using GCN (local pooling) filters...\n",
      "Epoch: 1/6\n",
      "384/384 - loss: 0.4742 - binary_accuracy: 0.5882\n",
      "Epoch: 2/6\n",
      "384/384 - loss: 0.3981 - binary_accuracy: 0.8235\n",
      "Epoch: 3/6\n",
      "384/384 - loss: 0.3889 - binary_accuracy: 0.9412\n",
      "Epoch: 4/6\n",
      "384/384 - loss: 0.3537 - binary_accuracy: 0.9118\n",
      "Epoch: 5/6\n",
      "384/384 - loss: 0.3784 - binary_accuracy: 0.8824\n",
      "Epoch: 6/6\n",
      "384/384 - loss: 0.3334 - binary_accuracy: 0.8824\n",
      "\n",
      "Best result with 'GCN' embeddings from 'operator_hadamard'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>operator_hadamard</th>\n",
       "      <td>0.881794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l1</th>\n",
       "      <td>0.824445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l2</th>\n",
       "      <td>0.760666</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_avg</th>\n",
       "      <td>0.595777</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    ROC AUC\n",
       "name                       \n",
       "operator_hadamard  0.881794\n",
       "operator_l1        0.824445\n",
       "operator_l2        0.760666\n",
       "operator_avg       0.595777"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gcn_result = train_and_evaluate(gcn_embedding, \"GCN\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52",
   "metadata": {},
   "source": [
    "#### Comparison between Node2Vec, Attri2Vec, GraphSAGE and GCN on the test set"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53",
   "metadata": {},
   "source": [
    "The ROC AUC scores on the test set of links of different embeddings with their corresponding best operators:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Node2Vec</th>\n",
       "      <td>0.843417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Attri2Vec</th>\n",
       "      <td>0.936742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GraphSAGE</th>\n",
       "      <td>0.925422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GCN</th>\n",
       "      <td>0.902997</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            ROC AUC\n",
       "name               \n",
       "Node2Vec   0.843417\n",
       "Attri2Vec  0.936742\n",
       "GraphSAGE  0.925422\n",
       "GCN        0.902997"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(\n",
    "    [\n",
    "        (\"Node2Vec\", node2vec_result),\n",
    "        (\"Attri2Vec\", attri2vec_result),\n",
    "        (\"GraphSAGE\", graphsage_result),\n",
    "        (\"GCN\", gcn_result),\n",
    "    ],\n",
    "    columns=(\"name\", \"ROC AUC\"),\n",
    ").set_index(\"name\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56",
   "metadata": {},
   "source": [
    "This example has demonstrated how to use the `stellargraph` library to build a link prediction algorithm for homogeneous graphs using the unsupervised embeddings learned by Node2Vec [1], Attri2Vec [2] and GraphSAGE [3] and GCN [4]. \n",
    "\n",
    "For more information about the link prediction process, all of these algorithms have specific demos with more details:\n",
    "\n",
    "- [Node2Vec](node2vec-link-prediction.ipynb)\n",
    "- [Attri2Vec](attri2vec-link-prediction.ipynb)\n",
    "- [GraphSAGE](graphsage-link-prediction.ipynb)\n",
    "- [GCN](gcn-link-prediction.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}