stellargraph/stellargraph

demos/link-prediction/hinsage-link-prediction.ipynb

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Link prediction with Heterogeneous GraphSAGE (HinSAGE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/hinsage-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/hinsage-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "In this example, we use our generalisation of the [GraphSAGE](http://snap.stanford.edu/graphsage/) algorithm to heterogeneous graphs (which we call HinSAGE) to build a model that predicts user-movie ratings in the MovieLens dataset (see below). The problem is treated as a supervised link attribute inference problem on a user-movie network with nodes of two types (users and movies, both attributed) and links corresponding to user-movie ratings, with integer `rating` attributes from 1 to 5 (note that if a user hasn't rated a movie, the corresponding user-movie link does not exist in the network).\n",
    "\n",
    "To address this problem, we build a model with the following architecture: a two-layer HinSAGE model takes labeled `(user, movie)` node pairs corresponding to user-movie ratings and outputs a pair of node embeddings for the `user` and `movie` nodes of each pair. These embeddings are fed into a link regression layer, which first applies a binary operator to the two node embeddings (e.g., concatenation) to construct a link embedding, and then maps that link embedding to a predicted user-movie rating. The entire model is trained end-to-end by minimizing the loss function of choice (here, mean squared error between predicted and true ratings) using stochastic gradient descent (SGD) updates of the model parameters, with minibatches of user-movie training links fed into the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import HinSAGELinkGenerator\n",
    "from stellargraph.layer import HinSAGE, link_regression\n",
    "from tensorflow.keras import Model, optimizers, losses, metrics\n",
    "\n",
    "import multiprocessing\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "Specify the minibatch size (number of user-movie links per minibatch) and the number of epochs for training the ML model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "batch_size = 200\n",
    "epochs = 20\n",
    "# Use 70% of edges for training, the rest for testing:\n",
    "train_size = 0.7\n",
    "test_size = 0.3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The MovieLens 100K dataset contains 100,000 ratings from 943 users on 1682 movies."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.MovieLens()\n",
    "display(HTML(dataset.description))\n",
    "G, edges_with_ratings = dataset.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 2625, Edges: 100000\n",
      "\n",
      " Node types:\n",
      "  movie: [1682]\n",
      "    Features: float32 vector, length 19\n",
      "    Edge types: movie-rating->user\n",
      "  user: [943]\n",
      "    Features: float32 vector, length 24\n",
      "    Edge types: user-rating->movie\n",
      "\n",
      " Edge types:\n",
      "    user-rating->movie: [100000]\n"
     ]
    }
   ],
   "source": [
    "print(G.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "Split the edges into train and test sets for model training/evaluation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "edges_train, edges_test = model_selection.train_test_split(\n",
    "    edges_with_ratings, train_size=train_size, test_size=test_size\n",
    ")\n",
    "\n",
    "edgelist_train = list(edges_train[[\"user_id\", \"movie_id\"]].itertuples(index=False))\n",
    "edgelist_test = list(edges_test[[\"user_id\", \"movie_id\"]].itertuples(index=False))\n",
    "\n",
    "labels_train = edges_train[\"rating\"]\n",
    "labels_test = edges_test[\"rating\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "Our machine learning task of learning user-movie ratings can be framed as a supervised Link Attribute Inference: given a graph of user-movie ratings, we train a model for rating prediction using the ratings edges_train, and evaluate it using the test ratings edges_test. The model also requires the user-movie graph structure, to do the neighbour sampling required by the HinSAGE algorithm."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "We create the link mappers for sampling and streaming training and testing data to the model. The link mappers essentially \"map\" user-movie links to the input of HinSAGE: they take minibatches of user-movie links, sample 2-hop subgraphs of G with `(user, movie)` head nodes extracted from those user-movie links, and feed them, together with the corresponding user-movie ratings, to the input layer of the HinSAGE model, for SGD updates of the model parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "Specify the sizes of 1- and 2-hop neighbour samples for HinSAGE:\n",
    "\n",
    "Note that the length of `num_samples` list defines the number of layers/iterations in the HinSAGE model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = [8, 4]"
   ]
  },
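  {
   "cell_type": "markdown",
   "id": "17a",
   "metadata": {},
   "source": [
    "As a rough illustration (not part of the StellarGraph API), the sample sizes above determine how many neighbours are gathered per head node at each hop. In this graph each node type has a single edge type, so the counts simply multiply. The helper name `nodes_per_hop` is hypothetical:\n",
    "\n",
    "```python\n",
    "def nodes_per_hop(num_samples):\n",
    "    # Number of nodes gathered at each hop for a single head node\n",
    "    counts = [1]  # hop 0: the head node itself\n",
    "    for n in num_samples:\n",
    "        counts.append(counts[-1] * n)\n",
    "    return counts\n",
    "\n",
    "\n",
    "hop_counts = nodes_per_hop([8, 4])\n",
    "print(hop_counts)  # [1, 8, 32]\n",
    "```\n",
    "\n",
    "These counts line up with the second dimension of the input shapes (1, 8 and 32) listed in the model summary further down."
   ]
  },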
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "Create the generators to feed data from the graph to the Keras model. We need to specify the nodes types for the user-movie pairs that we will feed to the model. The `shuffle=True` argument is given to the `flow` method to improve training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = HinSAGELinkGenerator(\n",
    "    G, batch_size, num_samples, head_node_types=[\"user\", \"movie\"]\n",
    ")\n",
    "train_gen = generator.flow(edgelist_train, labels_train, shuffle=True)\n",
    "test_gen = generator.flow(edgelist_test, labels_test)"
   ]
  },
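  {
   "cell_type": "markdown",
   "id": "19a",
   "metadata": {},
   "source": [
    "A quick back-of-the-envelope check, assuming the 100,000 ratings reported in the dataset description and the 70/30 split above: the number of minibatches per epoch is the number of links divided by the batch size, rounded up. The names below are illustrative only:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "sketch_num_edges = 100_000  # from the dataset description\n",
    "sketch_batch_size = 200\n",
    "\n",
    "# 70% of the edges for training, the remainder for testing\n",
    "sketch_train_edges = sketch_num_edges * 7 // 10\n",
    "sketch_test_edges = sketch_num_edges - sketch_train_edges\n",
    "\n",
    "train_steps = math.ceil(sketch_train_edges / sketch_batch_size)\n",
    "test_steps = math.ceil(sketch_test_edges / sketch_batch_size)\n",
    "print(train_steps, test_steps)  # 350 150\n",
    "```\n",
    "\n",
    "These match the 350 training steps and 150 validation steps reported in the training log."
   ]
  },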
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Build the model by stacking a two-layer HinSAGE model and a link regression layer on top.\n",
    "\n",
    "First, we define the HinSAGE part of the model, with hidden layer sizes of 32 for both HinSAGE layers, a bias term, and no dropout. (Dropout can be switched on by specifying a positive `dropout` rate, `0 < dropout < 1`)\n",
    "\n",
    "Note that the length of `layer_sizes` list must be equal to the length of `num_samples`, as `len(num_samples)` defines the number of hops (layers) in the HinSAGE model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('user', [2]),\n",
       " ('movie', [3]),\n",
       " ('movie', [4]),\n",
       " ('user', [5]),\n",
       " ('user', []),\n",
       " ('movie', [])]"
      ]
      },
      "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator.schema.type_adjacency_list(generator.head_node_types, len(num_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'user': [EdgeType(n1='user', rel='rating', n2='movie')],\n",
       " 'movie': [EdgeType(n1='movie', rel='rating', n2='user')]}"
      ]
      },
      "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator.schema.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "hinsage_layer_sizes = [32, 32]\n",
    "assert len(hinsage_layer_sizes) == len(num_samples)\n",
    "\n",
    "hinsage = HinSAGE(\n",
    "    layer_sizes=hinsage_layer_sizes, generator=generator, bias=True, dropout=0.0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose input and output sockets of hinsage:\n",
    "x_inp, x_out = hinsage.in_out_tensors()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "Add the final estimator layer for predicting the ratings. The edge_embedding_method argument specifies the way in which node representations (node embeddings) are combined into link representations (recall that links represent user-movie ratings, and are thus pairs of (user, movie) nodes). In this example, we will use `concat`, i.e., node embeddings are concatenated to get link embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "link_regression: using 'concat' method to combine node embeddings into edge embeddings\n"
     ]
    }
   ],
   "source": [
    "# Final estimator layer\n",
    "score_prediction = link_regression(edge_embedding_method=\"concat\")(x_out)"
   ]
  },
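  {
   "cell_type": "markdown",
   "id": "26a",
   "metadata": {},
   "source": [
    "To make the `concat` option concrete, here is a minimal NumPy sketch of the idea. It is not StellarGraph's implementation: the two 32-dimensional node embeddings are concatenated into a 64-dimensional link embedding, which a single dense unit maps to a scalar score. The weights are random placeholders and all names are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "user_emb = rng.normal(size=(5, 32))   # batch of 5 user embeddings\n",
    "movie_emb = rng.normal(size=(5, 32))  # matching movie embeddings\n",
    "\n",
    "# 'concat': join each pair of node embeddings along the feature axis\n",
    "link_emb = np.concatenate([user_emb, movie_emb], axis=1)\n",
    "\n",
    "# One dense unit: 64 weights + 1 bias = 65 parameters\n",
    "dense_w = rng.normal(size=(64, 1))\n",
    "dense_b = np.zeros(1)\n",
    "scores = link_emb @ dense_w + dense_b\n",
    "\n",
    "print(link_emb.shape, scores.shape)  # (5, 64) (5, 1)\n",
    "```\n",
    "\n",
    "The 64-wide link embedding and the 65 dense parameters agree with the `link_embedding` and `dense` rows of the model summary."
   ]
  },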
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "Create the Keras model, and compile it by specifying the optimizer, loss function to optimise, and metrics for diagnostics:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow.keras.backend as K\n",
    "\n",
    "\n",
    "# RMSE is reported as a diagnostic metric; the loss being minimised is MSE\n",
    "def root_mean_square_error(s_true, s_pred):\n",
    "    return K.sqrt(K.mean(K.pow(s_true - s_pred, 2)))\n",
    "\n",
    "\n",
    "model = Model(inputs=x_inp, outputs=score_prediction)\n",
    "model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=1e-2),\n",
    "    loss=losses.mean_squared_error,\n",
    "    metrics=[root_mean_square_error, metrics.mae],\n",
    ")"
   ]
  },
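  {
   "cell_type": "markdown",
   "id": "28a",
   "metadata": {},
   "source": [
    "One subtlety worth sketching: the model is trained on mean squared error, while `root_mean_square_error` is evaluated per minibatch and, as far as we understand Keras's handling of function metrics, averaged over the batches. Because the square root is concave, that average can be slightly lower than the square root of the epoch-level MSE. The NumPy example below uses synthetic data and hypothetical names to illustrate the effect:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# 10 synthetic minibatches of 200 ratings each, with noisy predictions\n",
    "y_true = rng.integers(1, 6, size=(10, 200)).astype(float)\n",
    "y_pred = y_true + rng.normal(scale=1.0, size=y_true.shape)\n",
    "\n",
    "batch_mse = np.mean((y_true - y_pred) ** 2, axis=1)\n",
    "mean_of_batch_rmse = np.mean(np.sqrt(batch_mse))  # average of per-batch RMSE values\n",
    "rmse_of_mean_mse = np.sqrt(np.mean(batch_mse))    # square root of the overall MSE\n",
    "\n",
    "print(mean_of_batch_rmse <= rmse_of_mean_mse)  # True, by Jensen's inequality\n",
    "```\n",
    "\n",
    "This is consistent with the logs, where the reported RMSE sits a little below the square root of the reported loss (for example, a validation loss of 1.0680 against a validation RMSE of 1.0319 in the final epoch)."
   ]
  },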
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "Summary of the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_3 (InputLayer)            [(None, 8, 19)]      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_5 (InputLayer)            [(None, 32, 24)]     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_6 (InputLayer)            [(None, 32, 19)]     0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_1 (InputLayer)            [(None, 1, 24)]      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "reshape (Reshape)               (None, 1, 8, 19)     0           input_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "reshape_2 (Reshape)             (None, 8, 4, 24)     0           input_5[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "input_4 (InputLayer)            [(None, 8, 24)]      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "reshape_3 (Reshape)             (None, 8, 4, 19)     0           input_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 1, 24)        0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout (Dropout)               (None, 1, 8, 19)     0           reshape[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_5 (Dropout)             (None, 8, 19)        0           input_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_4 (Dropout)             (None, 8, 4, 24)     0           reshape_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "input_2 (InputLayer)            [(None, 1, 19)]      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "reshape_1 (Reshape)             (None, 1, 8, 24)     0           input_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_7 (Dropout)             (None, 8, 24)        0           input_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_6 (Dropout)             (None, 8, 4, 19)     0           reshape_3[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "mean_hin_aggregator (MeanHinAgg multiple             720         dropout_1[0][0]                  \n",
      "                                                                 dropout[0][0]                    \n",
      "                                                                 dropout_7[0][0]                  \n",
      "                                                                 dropout_6[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "mean_hin_aggregator_1 (MeanHinA multiple             720         dropout_3[0][0]                  \n",
      "                                                                 dropout_2[0][0]                  \n",
      "                                                                 dropout_5[0][0]                  \n",
      "                                                                 dropout_4[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_3 (Dropout)             (None, 1, 19)        0           input_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_2 (Dropout)             (None, 1, 8, 24)     0           reshape_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "reshape_4 (Reshape)             (None, 1, 8, 32)     0           mean_hin_aggregator_1[1][0]      \n",
      "__________________________________________________________________________________________________\n",
      "reshape_5 (Reshape)             (None, 1, 8, 32)     0           mean_hin_aggregator[1][0]        \n",
      "__________________________________________________________________________________________________\n",
      "dropout_9 (Dropout)             (None, 1, 32)        0           mean_hin_aggregator[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "dropout_8 (Dropout)             (None, 1, 8, 32)     0           reshape_4[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_11 (Dropout)            (None, 1, 32)        0           mean_hin_aggregator_1[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "dropout_10 (Dropout)            (None, 1, 8, 32)     0           reshape_5[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "mean_hin_aggregator_2 (MeanHinA (None, 1, 32)        1056        dropout_9[0][0]                  \n",
      "                                                                 dropout_8[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "mean_hin_aggregator_3 (MeanHinA (None, 1, 32)        1056        dropout_11[0][0]                 \n",
      "                                                                 dropout_10[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "reshape_6 (Reshape)             (None, 32)           0           mean_hin_aggregator_2[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "reshape_7 (Reshape)             (None, 32)           0           mean_hin_aggregator_3[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "lambda (Lambda)                 (None, 32)           0           reshape_6[0][0]                  \n",
      "                                                                 reshape_7[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "link_embedding (LinkEmbedding)  (None, 64)           0           lambda[0][0]                     \n",
      "                                                                 lambda[1][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, 1)            65          link_embedding[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "reshape_8 (Reshape)             (None, 1)            0           dense[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 3,617\n",
      "Trainable params: 3,617\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the number of workers to use for model training\n",
    "num_workers = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "Evaluate the fresh (untrained) model on the test set (for reference):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150/150 [==============================] - 25s 165ms/step - loss: 11.8834 - root_mean_square_error: 3.4467 - mean_absolute_error: 3.2562\n",
      "Untrained model's Test Evaluation:\n",
      "\tloss: 11.8834\n",
      "\troot_mean_square_error: 3.4467\n",
      "\tmean_absolute_error: 3.2562\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(\n",
    "    test_gen, verbose=1, use_multiprocessing=False, workers=num_workers\n",
    ")\n",
    "\n",
    "print(\"Untrained model's Test Evaluation:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34",
   "metadata": {},
   "source": [
    "Train the model by feeding the data from the graph in minibatches, using mapper_train, and get validation metrics after each epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 350 steps, validate for 150 steps\n",
      "Epoch 1/20\n",
      "350/350 [==============================] - 87s 249ms/step - loss: 1.3311 - root_mean_square_error: 1.1321 - mean_absolute_error: 0.9363 - val_loss: 1.1607 - val_root_mean_square_error: 1.0759 - val_mean_absolute_error: 0.8687\n",
      "Epoch 2/20\n",
      "350/350 [==============================] - 85s 244ms/step - loss: 1.1525 - root_mean_square_error: 1.0724 - mean_absolute_error: 0.8711 - val_loss: 1.1315 - val_root_mean_square_error: 1.0624 - val_mean_absolute_error: 0.8693\n",
      "Epoch 3/20\n",
      "350/350 [==============================] - 90s 256ms/step - loss: 1.1357 - root_mean_square_error: 1.0645 - mean_absolute_error: 0.8626 - val_loss: 1.1141 - val_root_mean_square_error: 1.0542 - val_mean_absolute_error: 0.8612\n",
      "Epoch 4/20\n",
      "350/350 [==============================] - 83s 236ms/step - loss: 1.1235 - root_mean_square_error: 1.0588 - mean_absolute_error: 0.8566 - val_loss: 1.1035 - val_root_mean_square_error: 1.0489 - val_mean_absolute_error: 0.8432\n",
      "Epoch 5/20\n",
      "350/350 [==============================] - 89s 255ms/step - loss: 1.1148 - root_mean_square_error: 1.0546 - mean_absolute_error: 0.8526 - val_loss: 1.1057 - val_root_mean_square_error: 1.0500 - val_mean_absolute_error: 0.8372\n",
      "Epoch 6/20\n",
      "350/350 [==============================] - 85s 243ms/step - loss: 1.1123 - root_mean_square_error: 1.0533 - mean_absolute_error: 0.8517 - val_loss: 1.0983 - val_root_mean_square_error: 1.0466 - val_mean_absolute_error: 0.8461\n",
      "Epoch 7/20\n",
      "350/350 [==============================] - 100s 286ms/step - loss: 1.1096 - root_mean_square_error: 1.0523 - mean_absolute_error: 0.8498 - val_loss: 1.0966 - val_root_mean_square_error: 1.0458 - val_mean_absolute_error: 0.8538\n",
      "Epoch 8/20\n",
      "350/350 [==============================] - 91s 261ms/step - loss: 1.1049 - root_mean_square_error: 1.0502 - mean_absolute_error: 0.8478 - val_loss: 1.0947 - val_root_mean_square_error: 1.0449 - val_mean_absolute_error: 0.8489\n",
      "Epoch 9/20\n",
      "350/350 [==============================] - 79s 225ms/step - loss: 1.1031 - root_mean_square_error: 1.0490 - mean_absolute_error: 0.8473 - val_loss: 1.0926 - val_root_mean_square_error: 1.0439 - val_mean_absolute_error: 0.8503\n",
      "Epoch 10/20\n",
      "350/350 [==============================] - 79s 225ms/step - loss: 1.0967 - root_mean_square_error: 1.0461 - mean_absolute_error: 0.8439 - val_loss: 1.0842 - val_root_mean_square_error: 1.0399 - val_mean_absolute_error: 0.8402\n",
      "Epoch 11/20\n",
      "350/350 [==============================] - 79s 226ms/step - loss: 1.0955 - root_mean_square_error: 1.0455 - mean_absolute_error: 0.8437 - val_loss: 1.0832 - val_root_mean_square_error: 1.0394 - val_mean_absolute_error: 0.8402\n",
      "Epoch 12/20\n",
      "350/350 [==============================] - 104s 298ms/step - loss: 1.0948 - root_mean_square_error: 1.0453 - mean_absolute_error: 0.8426 - val_loss: 1.0866 - val_root_mean_square_error: 1.0411 - val_mean_absolute_error: 0.8448\n",
      "Epoch 13/20\n",
      "350/350 [==============================] - 132s 377ms/step - loss: 1.0941 - root_mean_square_error: 1.0448 - mean_absolute_error: 0.8431 - val_loss: 1.0867 - val_root_mean_square_error: 1.0410 - val_mean_absolute_error: 0.8492\n",
      "Epoch 14/20\n",
      "350/350 [==============================] - 109s 311ms/step - loss: 1.0895 - root_mean_square_error: 1.0426 - mean_absolute_error: 0.8408 - val_loss: 1.0872 - val_root_mean_square_error: 1.0411 - val_mean_absolute_error: 0.8332\n",
      "Epoch 15/20\n",
      "350/350 [==============================] - 100s 286ms/step - loss: 1.0880 - root_mean_square_error: 1.0419 - mean_absolute_error: 0.8412 - val_loss: 1.0884 - val_root_mean_square_error: 1.0418 - val_mean_absolute_error: 0.8358\n",
      "Epoch 16/20\n",
      "350/350 [==============================] - 82s 234ms/step - loss: 1.0857 - root_mean_square_error: 1.0409 - mean_absolute_error: 0.8401 - val_loss: 1.0833 - val_root_mean_square_error: 1.0392 - val_mean_absolute_error: 0.8332\n",
      "Epoch 17/20\n",
      "350/350 [==============================] - 87s 247ms/step - loss: 1.0837 - root_mean_square_error: 1.0398 - mean_absolute_error: 0.8387 - val_loss: 1.0744 - val_root_mean_square_error: 1.0351 - val_mean_absolute_error: 0.8348\n",
      "Epoch 18/20\n",
      "350/350 [==============================] - 88s 251ms/step - loss: 1.0840 - root_mean_square_error: 1.0400 - mean_absolute_error: 0.8389 - val_loss: 1.0878 - val_root_mean_square_error: 1.0416 - val_mean_absolute_error: 0.8527\n",
      "Epoch 19/20\n",
      "350/350 [==============================] - 83s 236ms/step - loss: 1.0798 - root_mean_square_error: 1.0380 - mean_absolute_error: 0.8371 - val_loss: 1.0728 - val_root_mean_square_error: 1.0344 - val_mean_absolute_error: 0.8424\n",
      "Epoch 20/20\n",
      "350/350 [==============================] - 79s 227ms/step - loss: 1.0815 - root_mean_square_error: 1.0386 - mean_absolute_error: 0.8387 - val_loss: 1.0680 - val_root_mean_square_error: 1.0319 - val_mean_absolute_error: 0.8302\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_gen,\n",
    "    validation_data=test_gen,\n",
    "    epochs=epochs,\n",
    "    verbose=1,\n",
    "    shuffle=False,\n",
    "    use_multiprocessing=False,\n",
    "    workers=num_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36",
   "metadata": {},
   "source": [
    "Plot the training history:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x864 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38",
   "metadata": {},
   "source": [
    "Evaluate the trained model on test user-movie rankings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  ['...']\n",
      "150/150 [==============================] - 23s 152ms/step - loss: 1.0738 - root_mean_square_error: 1.0347 - mean_absolute_error: 0.8325 44s - loss: 1.0911 - root_mean_square\n",
      "Test Evaluation:\n",
      "\tloss: 1.0738\n",
      "\troot_mean_square_error: 1.0347\n",
      "\tmean_absolute_error: 0.8325\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(\n",
    "    test_gen, use_multiprocessing=False, workers=num_workers, verbose=1\n",
    ")\n",
    "\n",
    "print(\"Test Evaluation:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Compare the predicted test rankings with \"mean baseline\" rankings, to see how much better our model does compared to this (very simplistic) baseline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Baseline Test set metrics:\n",
      "\troot_mean_square_error =  1.1250835578845824\n",
      "\tmean_absolute_error =  0.9442556604067485\n",
      "\n",
      "Model Test set metrics:\n",
      "\troot_mean_square_error =  1.0363648722884464\n",
      "\tmean_absolute_error =  0.8332940764943758\n"
     ]
    }
   ],
   "source": [
    "y_true = labels_test\n",
    "# Predict the rankings using the model:\n",
    "y_pred = model.predict(test_gen)\n",
    "# Mean baseline rankings = mean movie ranking:\n",
    "y_pred_baseline = np.full_like(y_pred, np.mean(y_true))\n",
    "\n",
    "rmse = np.sqrt(mean_squared_error(y_true, y_pred_baseline))\n",
    "mae = mean_absolute_error(y_true, y_pred_baseline)\n",
    "print(\"Mean Baseline Test set metrics:\")\n",
    "print(\"\\troot_mean_square_error = \", rmse)\n",
    "print(\"\\tmean_absolute_error = \", mae)\n",
    "\n",
    "rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
    "mae = mean_absolute_error(y_true, y_pred)\n",
    "print(\"\\nModel Test set metrics:\")\n",
    "print(\"\\troot_mean_square_error = \", rmse)\n",
    "print(\"\\tmean_absolute_error = \", mae)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "Compare the distributions of predicted and true rankings for the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "h_true = plt.hist(y_true, bins=30, facecolor=\"green\", alpha=0.5)\n",
    "h_pred = plt.hist(y_pred, bins=30, facecolor=\"blue\", alpha=0.5)\n",
    "plt.xlabel(\"ranking\")\n",
    "plt.ylabel(\"count\")\n",
    "plt.legend((\"True\", \"Predicted\"))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44",
   "metadata": {},
   "source": [
    "We see that our model beats the \"mean baseline\" by a significant margin. To further improve the model, you can try increasing the number of training epochs, change the dropout rate, change the sample sizes for subgraph sampling `num_samples`, hidden layer sizes `layer_sizes` of the HinSAGE part of the model, or try increasing the number of HinSAGE layers.\n",
    "\n",
    "However, note that the distribution of predicted scores is still very narrow, and rarely gives 1, 2 or 5 as a score."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45",
   "metadata": {},
   "source": [
    "This model uses a bipartite user-movie graph to learn to predict movie ratings. It can be further enhanced by using additional relations, e.g., friendships between users, if they become available. And the best part is: the underlying algorithm of the model does not need to change at all to take these extra relations into account - all that changes is the graph that it learns from!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/hinsage-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/hinsage-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}