{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Link prediction with Node2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/node2vec-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/node2vec-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This demo notebook demonstrates how to predict citation links/edges between papers using Node2Vec on the Cora dataset.\n",
    "\n",
    "We're going to tackle link prediction as a supervised learning problem on top of node representations/embeddings. The embeddings are computed with the unsupervised node2vec algorithm. After obtaining embeddings, a binary classifier can be used to predict a link, or not, between any two nodes in the graph. Various hyperparameters could be relevant in obtaining the best link classifier - this demo demonstrates incorporating model selection into the pipeline for choosing the best binary operator to apply on a pair of node embeddings.\n",
    "\n",
    "There are four steps:\n",
    "\n",
    "1. Obtain embeddings for each node\n",
    "2. For each set of hyperparameters, train a classifier\n",
    "3. Select the classifier that performs the best\n",
    "4. Evaluate the selected classifier on unseen data to validate its ability to generalise\n",
    "\n",
    "<a name=\"refs\"></a>\n",
    "**References:** \n",
    "\n",
    "[1] Node2Vec: Scalable Feature Learning for Networks. A. Grover, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2016. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from math import isclose\n",
    "from sklearn.decomposition import PCA\n",
    "import os\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from stellargraph import StellarGraph, datasets\n",
    "from stellargraph.data import EdgeSplitter\n",
    "from collections import Counter\n",
    "import multiprocessing\n",
    "from IPython.display import display, HTML\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Load the dataset\n",
    "\n",
    "The Cora dataset is a homogeneous network where all nodes are papers and edges between nodes are citation links, e.g. paper A cites paper B."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "graph, _ = dataset.load(largest_connected_component_only=True, str_node_ids=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 5209\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5209]\n"
     ]
    }
   ],
   "source": [
    "print(graph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## Construct splits of the input data\n",
    "\n",
    "We have to carefully split the data to avoid data leakage and evaluate the algorithms correctly:\n",
    "\n",
    "* For computing node embeddings, a **Train Graph** (`graph_train`)\n",
    "* For training classifiers, a classifier **Training Set** (`examples_train`) of positive and negative edges that weren't used for computing node embeddings\n",
    "* For choosing the best classifier, an **Model Selection Test Set** (`examples_model_selection`) of positive and negative edges that weren't used for computing node embeddings or training the classifier \n",
    "* For the final evaluation, a **Test Graph** (`graph_test`) to compute test node embeddings with more edges than the Train Graph, and a **Test Set** (`examples_test`) of positive and negative edges not used for neither computing the test node embeddings or for classifier training or model selection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "###  Test Graph\n",
    "\n",
    "We begin with the full graph and use the `EdgeSplitter` class to produce:\n",
    "\n",
    "* Test Graph\n",
    "* Test set of positive/negative link examples\n",
    "\n",
    "The Test Graph is the reduced graph we obtain from removing the test set of links from the full graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 520 positive and 520 negative edges. **\n",
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 4689\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [4689]\n"
     ]
    }
   ],
   "source": [
    "# Define an edge splitter on the original graph:\n",
    "edge_splitter_test = EdgeSplitter(graph)\n",
    "\n",
    "# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from graph, and obtain the\n",
    "# reduced graph graph_test with the sampled links removed:\n",
    "graph_test, examples_test, labels_test = edge_splitter_test.train_test_split(\n",
    "    p=0.1, method=\"global\"\n",
    ")\n",
    "\n",
    "print(graph_test.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "### Train Graph\n",
    "\n",
    "This time, we use the `EdgeSplitter` on the Test Graph, and perform a train/test split on the examples to produce:\n",
    "\n",
    "* Train Graph\n",
    "* Training set of link examples\n",
    "* Set of link examples for model selection\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** Sampled 468 positive and 468 negative edges. **\n",
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2485, Edges: 4221\n",
      "\n",
      " Node types:\n",
      "  paper: [2485]\n",
      "    Features: float32 vector, length 1433\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [4221]\n"
     ]
    }
   ],
   "source": [
    "# Do the same process to compute a training subset from within the test graph\n",
    "edge_splitter_train = EdgeSplitter(graph_test, graph)\n",
    "graph_train, examples, labels = edge_splitter_train.train_test_split(\n",
    "    p=0.1, method=\"global\"\n",
    ")\n",
    "(\n",
    "    examples_train,\n",
    "    examples_model_selection,\n",
    "    labels_train,\n",
    "    labels_model_selection,\n",
    ") = train_test_split(examples, labels, train_size=0.75, test_size=0.25)\n",
    "\n",
    "print(graph_train.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Below is a summary of the different splits that have been created in this section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Number of Examples</th>\n",
       "      <th>Hidden from</th>\n",
       "      <th>Picked from</th>\n",
       "      <th>Use</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Split</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Training Set</th>\n",
       "      <td>702</td>\n",
       "      <td>Train Graph</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Train the Link Classifier</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Model Selection</th>\n",
       "      <td>234</td>\n",
       "      <td>Train Graph</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Select the best Link Classifier model</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Test set</th>\n",
       "      <td>1040</td>\n",
       "      <td>Test Graph</td>\n",
       "      <td>Full Graph</td>\n",
       "      <td>Evaluate the best Link Classifier</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Number of Examples  Hidden from Picked from  \\\n",
       "Split                                                          \n",
       "Training Set                    702  Train Graph  Test Graph   \n",
       "Model Selection                 234  Train Graph  Test Graph   \n",
       "Test set                       1040   Test Graph  Full Graph   \n",
       "\n",
       "                                                   Use  \n",
       "Split                                                   \n",
       "Training Set                 Train the Link Classifier  \n",
       "Model Selection  Select the best Link Classifier model  \n",
       "Test set             Evaluate the best Link Classifier  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(\n",
    "    [\n",
    "        (\n",
    "            \"Training Set\",\n",
    "            len(examples_train),\n",
    "            \"Train Graph\",\n",
    "            \"Test Graph\",\n",
    "            \"Train the Link Classifier\",\n",
    "        ),\n",
    "        (\n",
    "            \"Model Selection\",\n",
    "            len(examples_model_selection),\n",
    "            \"Train Graph\",\n",
    "            \"Test Graph\",\n",
    "            \"Select the best Link Classifier model\",\n",
    "        ),\n",
    "        (\n",
    "            \"Test set\",\n",
    "            len(examples_test),\n",
    "            \"Test Graph\",\n",
    "            \"Full Graph\",\n",
    "            \"Evaluate the best Link Classifier\",\n",
    "        ),\n",
    "    ],\n",
    "    columns=(\"Split\", \"Number of Examples\", \"Hidden from\", \"Picked from\", \"Use\"),\n",
    ").set_index(\"Split\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "## Node2Vec\n",
    "\n",
    "We use Node2Vec [[1]](#refs), to calculate node embeddings. These embeddings are learned in such a way to ensure that nodes that are close in the graph remain close in the embedding space. Node2Vec first involves running random walks on the graph to obtain our context pairs, and using these to train a Word2Vec model.\n",
    "\n",
    "These are the set of parameters we can use:\n",
    "\n",
    "* `p` - Random walk parameter \"p\"\n",
    "* `q` - Random walk parameter \"q\"\n",
    "* `dimensions` - Dimensionality of node2vec embeddings\n",
    "* `num_walks` - Number of walks from each node\n",
    "* `walk_length` - Length of each random walk\n",
    "* `window_size` - Context window size for Word2Vec\n",
    "* `num_iter` - number of SGD iterations (epochs)\n",
    "* `workers` - Number of workers for Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "p = 1.0\n",
    "q = 1.0\n",
    "dimensions = 128\n",
    "num_walks = 10\n",
    "walk_length = 80\n",
    "window_size = 10\n",
    "num_iter = 1\n",
    "workers = multiprocessing.cpu_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.data import BiasedRandomWalk\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "\n",
    "def node2vec_embedding(graph, name):\n",
    "    rw = BiasedRandomWalk(graph)\n",
    "    walks = rw.run(graph.nodes(), n=num_walks, length=walk_length, p=p, q=q)\n",
    "    print(f\"Number of random walks for '{name}': {len(walks)}\")\n",
    "\n",
    "    model = Word2Vec(\n",
    "        walks,\n",
    "        vector_size=dimensions,\n",
    "        window=window_size,\n",
    "        min_count=0,\n",
    "        sg=1,\n",
    "        workers=workers,\n",
    "        epochs=num_iter,\n",
    "    )\n",
    "\n",
    "    def get_embedding(u):\n",
    "        return model.wv[u]\n",
    "\n",
    "    return get_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of random walks for 'Train Graph': 24850\n"
     ]
    }
   ],
   "source": [
    "embedding_train = node2vec_embedding(graph_train, \"Train Graph\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "## Train and evaluate the link prediction model\n",
    "\n",
    "There are a few steps involved in using the Word2Vec model to perform link prediction:\n",
    "1. We calculate link/edge embeddings for the positive and negative edge samples by applying a binary operator on the embeddings of the source and target nodes of each sampled edge.\n",
    "2. Given the embeddings of the positive and negative examples, we train a logistic regression classifier to predict a binary value indicating whether an edge between two nodes should exist or not.\n",
    "3. We evaluate the performance of the link classifier for each of the 4 operators on the training data with node embeddings calculated on the **Train Graph** (`graph_train`), and select the best classifier.\n",
    "4. The best classifier is then used to calculate scores on the test data with node embeddings calculated on the **Test Graph** (`graph_test`).\n",
    "\n",
    "Below are a set of helper functions that let us repeat these steps for each of the binary operators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.linear_model import LogisticRegressionCV\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "# 1. link embeddings\n",
    "def link_examples_to_features(link_examples, transform_node, binary_operator):\n",
    "    return [\n",
    "        binary_operator(transform_node(src), transform_node(dst))\n",
    "        for src, dst in link_examples\n",
    "    ]\n",
    "\n",
    "\n",
    "# 2. training classifier\n",
    "def train_link_prediction_model(\n",
    "    link_examples, link_labels, get_embedding, binary_operator\n",
    "):\n",
    "    clf = link_prediction_classifier()\n",
    "    link_features = link_examples_to_features(\n",
    "        link_examples, get_embedding, binary_operator\n",
    "    )\n",
    "    clf.fit(link_features, link_labels)\n",
    "    return clf\n",
    "\n",
    "\n",
    "def link_prediction_classifier(max_iter=2000):\n",
    "    lr_clf = LogisticRegressionCV(Cs=10, cv=10, scoring=\"roc_auc\", max_iter=max_iter)\n",
    "    return Pipeline(steps=[(\"sc\", StandardScaler()), (\"clf\", lr_clf)])\n",
    "\n",
    "\n",
    "# 3. and 4. evaluate classifier\n",
    "def evaluate_link_prediction_model(\n",
    "    clf, link_examples_test, link_labels_test, get_embedding, binary_operator\n",
    "):\n",
    "    link_features_test = link_examples_to_features(\n",
    "        link_examples_test, get_embedding, binary_operator\n",
    "    )\n",
    "    score = evaluate_roc_auc(clf, link_features_test, link_labels_test)\n",
    "    return score\n",
    "\n",
    "\n",
    "def evaluate_roc_auc(clf, link_features, link_labels):\n",
    "    predicted = clf.predict_proba(link_features)\n",
    "\n",
    "    # check which class corresponds to positive links\n",
    "    positive_column = list(clf.classes_).index(1)\n",
    "    return roc_auc_score(link_labels, predicted[:, positive_column])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "We consider 4 different operators: \n",
    "\n",
    "* *Hadamard*\n",
    "* $L_1$\n",
    "* $L_2$\n",
    "* *average*\n",
    "\n",
    "The paper [[1]](#refs) provides a detailed description of these operators. All operators produce link embeddings that have equal dimensionality to the input node embeddings (128 dimensions for our example). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "def operator_hadamard(u, v):\n",
    "    return u * v\n",
    "\n",
    "\n",
    "def operator_l1(u, v):\n",
    "    return np.abs(u - v)\n",
    "\n",
    "\n",
    "def operator_l2(u, v):\n",
    "    return (u - v) ** 2\n",
    "\n",
    "\n",
    "def operator_avg(u, v):\n",
    "    return (u + v) / 2.0\n",
    "\n",
    "\n",
    "def run_link_prediction(binary_operator):\n",
    "    clf = train_link_prediction_model(\n",
    "        examples_train, labels_train, embedding_train, binary_operator\n",
    "    )\n",
    "    score = evaluate_link_prediction_model(\n",
    "        clf,\n",
    "        examples_model_selection,\n",
    "        labels_model_selection,\n",
    "        embedding_train,\n",
    "        binary_operator,\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"classifier\": clf,\n",
    "        \"binary_operator\": binary_operator,\n",
    "        \"score\": score,\n",
    "    }\n",
    "\n",
    "\n",
    "binary_operators = [operator_hadamard, operator_l1, operator_l2, operator_avg]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best result from 'operator_l1'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ROC AUC score</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>operator_hadamard</th>\n",
       "      <td>0.883362</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l1</th>\n",
       "      <td>0.891015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_l2</th>\n",
       "      <td>0.886894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>operator_avg</th>\n",
       "      <td>0.666127</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   ROC AUC score\n",
       "name                            \n",
       "operator_hadamard       0.883362\n",
       "operator_l1             0.891015\n",
       "operator_l2             0.886894\n",
       "operator_avg            0.666127"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results = [run_link_prediction(op) for op in binary_operators]\n",
    "best_result = max(results, key=lambda result: result[\"score\"])\n",
    "\n",
    "print(f\"Best result from '{best_result['binary_operator'].__name__}'\")\n",
    "\n",
    "pd.DataFrame(\n",
    "    [(result[\"binary_operator\"].__name__, result[\"score\"]) for result in results],\n",
    "    columns=(\"name\", \"ROC AUC score\"),\n",
    ").set_index(\"name\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {},
   "source": [
    "### Evaluate the best model using the test set\n",
    "\n",
    "Now that we've trained and selected our best model, we use a test set of embeddings and calculate a final evaluation score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "27",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of random walks for 'Test Graph': 24850\n"
     ]
    }
   ],
   "source": [
    "embedding_test = node2vec_embedding(graph_test, \"Test Graph\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROC AUC score on test set using 'operator_l1': 0.9120266272189348\n"
     ]
    }
   ],
   "source": [
    "test_score = evaluate_link_prediction_model(\n",
    "    best_result[\"classifier\"],\n",
    "    examples_test,\n",
    "    labels_test,\n",
    "    embedding_test,\n",
    "    best_result[\"binary_operator\"],\n",
    ")\n",
    "print(\n",
    "    f\"ROC AUC score on test set using '{best_result['binary_operator'].__name__}': {test_score}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "### Visualise representations of link embeddings\n",
    "\n",
    "Learned link embeddings have 128 dimensions but for visualisation we project them down to 2 dimensions using the PCA algorithm ([link](https://en.wikipedia.org/wiki/Principal_component_analysis)). \n",
    "\n",
    "Blue points represent positive edges and red points represent negative (no edge should exist between the corresponding vertices) edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "30",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x1586ffa50>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x864 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Calculate edge features for test data\n",
    "link_features = link_examples_to_features(\n",
    "    examples_test, embedding_test, best_result[\"binary_operator\"]\n",
    ")\n",
    "\n",
    "# Learn a projection from 128 dimensions to 2\n",
    "pca = PCA(n_components=2)\n",
    "X_transformed = pca.fit_transform(link_features)\n",
    "\n",
    "# plot the 2-dimensional points\n",
    "plt.figure(figsize=(16, 12))\n",
    "plt.scatter(\n",
    "    X_transformed[:, 0],\n",
    "    X_transformed[:, 1],\n",
    "    c=np.where(labels_test == 1, \"b\", \"r\"),\n",
    "    alpha=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "This example has demonstrated how to use the `stellargraph` library to build a link prediction algorithm for homogeneous graphs using the Node2Vec, [1], representation learning algorithm."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/node2vec-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/node2vec-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}