stellargraph/stellargraph

demos/link-prediction/distmult-link-prediction.ipynb

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Knowledge graph link prediction with DistMult"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/distmult-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/distmult-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This notebook reproduces the experiments done in the paper that introduced the DistMult algorithm: Embedding Entities and Relations for Learning and Inference in Knowledge Bases, Bishan Yang, Scott Wen-tau Yih, Xiaodong He, Jianfeng Gao and Li Deng, ICLR 2015. https://arxiv.org/pdf/1412.6575\n",
    "\n",
    "In table 2, the paper reports 2 metrics measured on the WN18 and FB15K datasets (and FB15k-401): MRR (mean reciprocal rank) and Hits at 10. These are computed as \"filtered\", where known edges (in the train, test or validation sets) are ignored when computing ranks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph import datasets, utils\n",
    "from tensorflow.keras import callbacks, optimizers, losses, metrics, regularizers, Model\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from stellargraph.mapper import KGTripleGenerator\n",
    "from stellargraph.layer import DistMult\n",
    "\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Initialisation\n",
    "\n",
    "We need to set up our model parameters, like the number of epochs to train for, and the dimension of the embedding vectors we compute for each node and for each edge type. \n",
    "\n",
    "The evaluation is performed in three steps:\n",
    "\n",
    "1. Load the data\n",
    "2. Train a model\n",
    "3. Evaluate the model\n",
    "\n",
    "On pages 4 and 5, the paper describes their implementation details. The paper says that it uses:\n",
    "\n",
    "- the AdaGrad optimiser for 100 (for FB15k) or 300 epochs (for WN18)\n",
    "- an embedding dimension of 100\n",
    "- samples 2 corrupted edges per true edge\n",
    "- unit normalization of each entity embedding vector after each epoch: this is not currently supported by TensorFlow ([#33755](https://github.com/tensorflow/tensorflow/issues/33755)), and so may explain the slightly poorer MRR metrics on WN18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 300\n",
    "embedding_dimension = 100\n",
    "negative_samples = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## WN18\n",
    "\n",
    "The paper uses the WN18 and FB15k datasets for validation. These datasets are not good for evaluating algorithms because they contain \"inverse relations\", where `(s, r1, o)` implies `(o, r2, s)` for a pair of relation types `r1` and `r2` (for instance, `_hyponym` (\"is more specific than\") and `_hypernym` (\"is more general than\") in WN18), however, they work fine to demonstrate StellarGraph's functionality, and are appropriate to compare against the published results.\n",
    "\n",
    "### Load the data\n",
    "\n",
    "The dataset comes with a defined train, test and validation split, each consisting of subject, relation, object triples. We can load a `StellarGraph` object with all of the triples, as well as the individual splits as Pandas DataFrames, using the `load` method of the `WN18` dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The WN18 dataset consists of triplets from WordNet 3.0 (http://wordnet.princeton.edu). There are 40,943 synsets and 18 relation types among them. The training set contains 141442 triplets, the validation set 5000 and the test set 5000. Antoine Bordes, Xavier Glorot, Jason Weston and Yoshua Bengio “A Semantic Matching Energy Function for Learning with Multi-relational Data” (2014).\n",
       "\n",
       "Note: this dataset contains many inverse relations, and so should only be used to compare against published results. Prefer WN18RR. See: Kristina Toutanova and Danqi Chen “Observed versus latent features for knowledge base and text inference” (2015), and Dettmers, Tim, Pasquale Minervini, Pontus Stenetorp and Sebastian Riedel “Convolutional 2D Knowledge Graph Embeddings” (2017)."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wn18 = datasets.WN18()\n",
    "display(HTML(wn18.description))\n",
    "wn18_graph, wn18_train, wn18_test, wn18_valid = wn18.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarDiGraph: Directed multigraph\n",
      " Nodes: 40943, Edges: 151442\n",
      "\n",
      " Node types:\n",
      "  default: [40943]\n",
      "    Features: none\n",
      "    Edge types: default-_also_see->default, default-_derivationally_related_form->default, default-_has_part->default, default-_hypernym->default, default-_hyponym->default, ... (13 more)\n",
      "\n",
      " Edge types:\n",
      "    default-_hyponym->default: [37221]\n",
      "    default-_hypernym->default: [37221]\n",
      "    default-_derivationally_related_form->default: [31867]\n",
      "    default-_member_meronym->default: [7928]\n",
      "    default-_member_holonym->default: [7928]\n",
      "    default-_part_of->default: [5148]\n",
      "    default-_has_part->default: [5142]\n",
      "    default-_member_of_domain_topic->default: [3341]\n",
      "    default-_synset_domain_topic_of->default: [3335]\n",
      "    default-_instance_hyponym->default: [3150]\n",
      "    default-_instance_hypernym->default: [3150]\n",
      "    default-_also_see->default: [1396]\n",
      "    default-_verb_group->default: [1220]\n",
      "    default-_member_of_domain_region->default: [983]\n",
      "    default-_synset_domain_region_of->default: [982]\n",
      "    default-_member_of_domain_usage->default: [675]\n",
      "    default-_synset_domain_usage_of->default: [669]\n",
      "    default-_similar_to->default: [86]\n"
     ]
    }
   ],
   "source": [
    "print(wn18_graph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "### Train a model\n",
    "\n",
    "The DistMult algorithm consists of some embedding layers and a scoring layer, but the `DistMult` object means these details are invisible to us. The `DistMult` model consumes \"knowledge-graph triples\", which can be produced in the appropriate format using `KGTripleGenerator`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "wn18_gen = KGTripleGenerator(\n",
    "    wn18_graph, batch_size=len(wn18_train) // 10  # ~10 batches per epoch\n",
    ")\n",
    "\n",
    "wn18_distmult = DistMult(\n",
    "    wn18_gen,\n",
    "    embedding_dimension=embedding_dimension,\n",
    "    embeddings_regularizer=regularizers.l2(1e-7),\n",
    ")\n",
    "\n",
    "wn18_inp, wn18_out = wn18_distmult.in_out_tensors()\n",
    "\n",
    "wn18_model = Model(inputs=wn18_inp, outputs=wn18_out)\n",
    "\n",
    "wn18_model.compile(\n",
    "    optimizer=optimizers.Adam(lr=0.001),\n",
    "    loss=losses.BinaryCrossentropy(from_logits=True),\n",
    "    metrics=[metrics.BinaryAccuracy(threshold=0.0)],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "Inputs for training are produced by calling the `KGTripleGenerator.flow` method, this takes a dataframe with `source`, `label` and `target` columns, where each row is a true edge in the knowledge graph.  The `negative_samples` parameter controls how many random edges are created for each positive edge to use as negative examples for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "wn18_train_gen = wn18_gen.flow(\n",
    "    wn18_train, negative_samples=negative_samples, shuffle=True\n",
    ")\n",
    "wn18_valid_gen = wn18_gen.flow(wn18_valid, negative_samples=negative_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n"
     ]
    }
   ],
   "source": [
    "wn18_es = callbacks.EarlyStopping(monitor=\"val_loss\", patience=50)\n",
    "wn18_history = wn18_model.fit(\n",
    "    wn18_train_gen,\n",
    "    validation_data=wn18_valid_gen,\n",
    "    epochs=epochs,\n",
    "    callbacks=[wn18_es],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "utils.plot_history(wn18_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "### Evaluate the model\n",
    "\n",
    "We've now trained a model, so we can apply the evaluation procedure from the paper to it. This is done by taking each test edge `E = (s, r, o)`, and scoring it against all mutations `(s, r, n)` and `(n, r, o)` for every node `n` in the graph, that is, doing a prediction for every one of these edges similar to `E`. The \"raw\" rank is the number of mutated edges that have a higher predicted score than the true `E`.\n",
    "\n",
    "The DistMult paper uses only 10 batches per epoch, which results in large batch sizes: ~15 thousand edges per batch for WN18, and ~60 thousand edges per batch for the FB15k dataset below. Evaluation with `rank_edges_against_all_nodes` uses bulk operations for efficient reasons, at the cost of memory usage proportional to `O(batch size * number of nodes)`; a more moderate batch size gives similar performance without using large amounts of memory. We can swap the batch size by creating a new generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "wn18_smaller_gen = KGTripleGenerator(wn18_graph, batch_size=5000)\n",
    "\n",
    "wn18_raw_ranks, wn18_filtered_ranks = wn18_distmult.rank_edges_against_all_nodes(\n",
    "    wn18_smaller_gen.flow(wn18_test), wn18_graph\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper function to compute metrics from an array of ranks\n",
    "def results_as_dataframe(mrr, hits_at_10):\n",
    "    return pd.DataFrame(\n",
    "        [(mrr, hits_at_10)], columns=[\"mrr\", \"hits at 10\"], index=[\"filtered\"],\n",
    "    )\n",
    "\n",
    "\n",
    "def summarise(ranks):\n",
    "    return results_as_dataframe(np.mean(1 / ranks), np.mean(ranks <= 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>hits at 10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>filtered</td>\n",
       "      <td>0.709954</td>\n",
       "      <td>0.9303</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               mrr  hits at 10\n",
       "filtered  0.709954      0.9303"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarise(wn18_filtered_ranks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "For comparison, Table 2 in the paper gives the following results for WN18. All of the numbers are similar:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>hits at 10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>filtered</td>\n",
       "      <td>0.83</td>\n",
       "      <td>0.942</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           mrr  hits at 10\n",
       "filtered  0.83       0.942"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_as_dataframe(0.83, 0.942)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "## FB15k\n",
    "\n",
    "Now that we know the process, we can apply the model on the FB15k dataset in the same way.\n",
    "\n",
    "### Loading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "25",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "This FREEBASE FB15k DATA consists of a collection of triplets (synset, relation_type, triplet)extracted from Freebase (http://www.freebase.com). There are 14,951 mids and 1,345 relation types among them. The training set contains 483142 triplets, the validation set 50000 and the test set 59071. Antoine Bordes, Nicolas Usunier, Alberto Garcia-Durán, Jason Weston and Oksana Yakhnenko “Translating Embeddings for Modeling Multi-relational Data” (2013).\n",
       "\n",
       "Note: this dataset contains many inverse relations, and so should only be used to compare against published results. Prefer FB15k_237. See: Kristina Toutanova and Danqi Chen “Observed versus latent features for knowledge base and text inference” (2015), and Dettmers, Tim, Pasquale Minervini, Pontus Stenetorp and Sebastian Riedel “Convolutional 2D Knowledge Graph Embeddings” (2017)."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fb15k = datasets.FB15k()\n",
    "display(HTML(fb15k.description))\n",
    "fb15k_graph, fb15k_train, fb15k_test, fb15k_valid = fb15k.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarDiGraph: Directed multigraph\n",
      " Nodes: 14951, Edges: 592213\n",
      "\n",
      " Node types:\n",
      "  default: [14951]\n",
      "    Features: none\n",
      "    Edge types: default-/american_football/football_coach/coaching_history./american_football/football_historical_coach_position/position->default, default-/american_football/football_coach/coaching_history./american_football/football_historical_coach_position/team->default, default-/american_football/football_coach_position/coaches_holding_this_position./american_football/football_historical_coach_position/coach->default, default-/american_football/football_coach_position/coaches_holding_this_position./american_football/football_historical_coach_position/team->default, default-/american_football/football_player/current_team./american_football/football_roster_position/position->default, ... (1340 more)\n",
      "\n",
      " Edge types:\n",
      "    default-/award/award_nominee/award_nominations./award/award_nomination/award_nominee->default: [19764]\n",
      "    default-/film/film/release_date_s./film/film_regional_release_date/film_release_region->default: [15837]\n",
      "    default-/award/award_nominee/award_nominations./award/award_nomination/award->default: [14921]\n",
      "    default-/award/award_category/nominees./award/award_nomination/award_nominee->default: [14921]\n",
      "    default-/people/profession/people_with_this_profession->default: [14220]\n",
      "    default-/people/person/profession->default: [14220]\n",
      "    default-/film/film/starring./film/performance/actor->default: [11638]\n",
      "    default-/film/actor/film./film/performance/film->default: [11638]\n",
      "    default-/award/award_nominated_work/award_nominations./award/award_nomination/award->default: [11594]\n",
      "    default-/award/award_category/nominees./award/award_nomination/nominated_for->default: [11594]\n",
      "    default-/award/award_winner/awards_won./award/award_honor/award_winner->default: [10378]\n",
      "    default-/film/film_genre/films_in_this_genre->default: [8946]\n",
      "    default-/film/film/genre->default: [8946]\n",
      "    default-/award/award_nominee/award_nominations./award/award_nomination/nominated_for->default: [7632]\n",
      "    default-/award/award_nominated_work/award_nominations./award/award_nomination/award_nominee->default: [7632]\n",
      "    default-/film/film_job/films_with_this_crew_job./film/film_crew_gig/film->default: [7400]\n",
      "    default-/film/film/other_crew./film/film_crew_gig/film_crew_role->default: [7400]\n",
      "    default-/common/topic/webpage./common/webpage/category->default: [7232]\n",
      "    default-/common/annotation_category/annotations./common/webpage/topic->default: [7232]\n",
      "    default-/music/genre/artists->default: [7229]\n",
      "    ... (1325 more)\n"
     ]
    }
   ],
   "source": [
    "print(fb15k_graph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "### Train a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "fb15k_gen = KGTripleGenerator(\n",
    "    fb15k_graph, batch_size=len(fb15k_train) // 10  # ~100 batches per epoch\n",
    ")\n",
    "\n",
    "fb15k_distmult = DistMult(\n",
    "    fb15k_gen,\n",
    "    embedding_dimension=embedding_dimension,\n",
    "    embeddings_regularizer=regularizers.l2(1e-8),\n",
    ")\n",
    "\n",
    "fb15k_inp, fb15k_out = fb15k_distmult.in_out_tensors()\n",
    "\n",
    "fb15k_model = Model(inputs=fb15k_inp, outputs=fb15k_out)\n",
    "fb15k_model.compile(\n",
    "    optimizer=optimizers.Adam(lr=0.001),\n",
    "    loss=losses.BinaryCrossentropy(from_logits=True),\n",
    "    metrics=[metrics.BinaryAccuracy(threshold=0.0)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "fb15k_train_gen = fb15k_gen.flow(\n",
    "    fb15k_train, negative_samples=negative_samples, shuffle=True\n",
    ")\n",
    "fb15k_valid_gen = fb15k_gen.flow(fb15k_valid, negative_samples=negative_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "  ['...']\n"
     ]
    }
   ],
   "source": [
    "fb15k_es = callbacks.EarlyStopping(monitor=\"val_loss\", patience=50)\n",
    "fb15k_history = fb15k_model.fit(\n",
    "    fb15k_train_gen,\n",
    "    validation_data=fb15k_valid_gen,\n",
    "    epochs=epochs,\n",
    "    callbacks=[fb15k_es],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "utils.plot_history(fb15k_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "### Evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "33",
   "metadata": {},
   "outputs": [],
   "source": [
    "fb15k_smaller_gen = KGTripleGenerator(fb15k_graph, batch_size=5000)\n",
    "\n",
    "fb15k_raw_ranks, fb15k_filtered_ranks = fb15k_distmult.rank_edges_against_all_nodes(\n",
    "    fb15k_smaller_gen.flow(fb15k_test), fb15k_graph\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>hits at 10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>filtered</td>\n",
       "      <td>0.33919</td>\n",
       "      <td>0.579582</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              mrr  hits at 10\n",
       "filtered  0.33919    0.579582"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarise(fb15k_filtered_ranks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "For comparison, Table 2 in the paper gives the following results for FB15k:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "36",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>hits at 10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>filtered</td>\n",
       "      <td>0.35</td>\n",
       "      <td>0.577</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           mrr  hits at 10\n",
       "filtered  0.35       0.577"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_as_dataframe(0.35, 0.577)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/distmult-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/distmult-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}