{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Link prediction via inductive node representations with attri2vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/attri2vec-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/attri2vec-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to predict links for out-of-sample nodes by learning node representations inductively with attri2vec [1]. The implementation uses StellarGraph components.\n",
    "\n",
    "<a name=\"refs\"></a>\n",
    "**References:** \n",
    "\n",
    "[1] [Attributed Network Embedding via Subspace Discovery](https://link.springer.com/article/10.1007/s10618-019-00650-2). D. Zhang, J. Yin, X. Zhu and C. Zhang. Data Mining and Knowledge Discovery, 2019. \n",
    "\n",
    "## attri2vec\n",
    "\n",
    "attri2vec learns node representations by applying a linear or non-linear mapping to node content attributes. To make the learned representations respect structural similarity, it uses the [DeepWalk](https://dl.acm.org/citation.cfm?id=2623732)/[Node2Vec](https://snap.stanford.edu/node2vec) learning mechanism: nodes that share similar random walk context nodes are placed close together in the subspace. This is achieved by maximizing the occurrence probability of context nodes conditioned on the representation of the target nodes. \n",
    "\n",
    "In this demo, we first train the attri2vec model on the in-sample subgraph to obtain a mapping function from node attributes to node representations. We then apply this mapping function to the content attributes of out-of-sample nodes to obtain their representations. We evaluate the quality of the inferred out-of-sample node representations by using them to predict the links of out-of-sample nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.data import UnsupervisedSampler\n",
    "from stellargraph.mapper import Attri2VecLinkGenerator, Attri2VecNodeGenerator\n",
    "from stellargraph.layer import Attri2Vec, link_classification\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading DBLP network data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "This demo uses a DBLP citation network, a subgraph extracted from [DBLP-Citation-network V3](https://aminer.org/citation). To form this subgraph, papers from four subjects are extracted according to their venue information: *Database, Data Mining, Artificial Intelligence and Computer Vision*, and papers with no citations are removed. The DBLP network contains 18,448 papers and 45,611 citation relations. From paper titles, we construct 2,476-dimensional binary node feature vectors, with each element indicating the presence/absence of the corresponding word. By ignoring the citation direction, we take the DBLP subgraph as an undirected network.\n",
    "\n",
    "Because each paper in DBLP has a publication year, the network evolves over time and can be used to study out-of-sample node representation learning. From the DBLP network, we construct four in-sample subgraphs using papers published before 2006, 2007, 2008 and 2009, and denote them DBLP2006, DBLP2007, DBLP2008 and DBLP2009. For each subgraph, the remaining papers are taken as out-of-sample nodes. We consider the case where newly arriving nodes have no links. We predict the links of out-of-sample nodes using the learned out-of-sample node representations and compare their performance with the node content feature baseline.\n",
    "\n",
    "The dataset used in this demo can be downloaded from https://www.kaggle.com/daozhang/dblp-subgraph.\n",
    "The following is the description of the dataset:\n",
    "\n",
    "> The `content.txt` file contains descriptions of the papers in the following format:\n",
    ">\n",
    ">     <paper_id> <word_attributes> <class_label> <publication_year>\n",
    ">\n",
    "> The first entry in each line contains the unique integer ID (ranging from 0 to 18,447) of the paper followed by binary values indicating whether each word in the vocabulary is present (indicated by 1) or absent (indicated by 0) in the paper. Finally, the last two entries in the line are the class label and the publication year of the paper.\n",
    "> The `edgeList.txt` file contains the citation relations. Each line describes a link in the following format:\n",
    ">\n",
    ">     <ID of paper1> <ID of paper2>\n",
    ">\n",
    "> Each line contains two paper IDs, with paper2 citing paper1 or paper1 citing paper2.\n",
    "\n",
    "\n",
    "Download and unzip the `dblp-subgraph.zip` file to a location on your computer and set the `data_dir` variable to\n",
    "point to the location of the dataset (the `DBLP` directory containing `content.txt` and `edgeList.txt`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"~/data/DBLP\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "Load the graph from the edge list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "edgelist = pd.read_csv(\n",
    "    os.path.join(data_dir, \"edgeList.txt\"),\n",
    "    sep=\"\\t\",\n",
    "    header=None,\n",
    "    names=[\"source\", \"target\"],\n",
    ")\n",
    "edgelist[\"label\"] = \"cites\"  # set the edge type"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "Load paper content features, subjects and publishing years."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_names = [\"w_{}\".format(ii) for ii in range(2476)]\n",
    "node_column_names = feature_names + [\"subject\", \"year\"]\n",
    "node_data = pd.read_csv(\n",
    "    os.path.join(data_dir, \"content.txt\"), sep=\"\\t\", header=None, names=node_column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "Construct the whole graph from the edge list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "G_all_nx = nx.from_pandas_edgelist(edgelist, edge_attr=\"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Specify node types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.set_node_attributes(G_all_nx, \"paper\", \"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17",
   "metadata": {},
   "source": [
    "Get node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_node_features = node_data[feature_names]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "Create the StellarGraph with node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "G_all = sg.StellarGraph.from_networkx(G_all_nx, node_features=all_node_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NetworkXStellarGraph: Undirected multigraph\n",
      " Nodes: 18448, Edges: 45611\n",
      "\n",
      " Node types:\n",
      "  paper: [18448]\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [45611]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(G_all.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "## Get DBLP Subgraph \n",
    "### with papers published before a threshold year"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "Get the edge list connecting in-sample nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "year_thresh = 2006  # threshold year for the in-sample/out-of-sample split; 2007, 2008 or 2009 can also be used\n",
    "subgraph_edgelist = []\n",
    "for ii in range(len(edgelist)):\n",
    "    source_index = edgelist[\"source\"][ii]\n",
    "    target_index = edgelist[\"target\"][ii]\n",
    "    source_year = int(node_data[\"year\"][source_index])\n",
    "    target_year = int(node_data[\"year\"][target_index])\n",
    "    if source_year < year_thresh and target_year < year_thresh:\n",
    "        subgraph_edgelist.append([source_index, target_index])\n",
    "subgraph_edgelist = pd.DataFrame(\n",
    "    np.array(subgraph_edgelist), columns=[\"source\", \"target\"]\n",
    ")\n",
    "subgraph_edgelist[\"label\"] = \"cites\"  # set the edge type"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "Construct the network from the selected edge list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "G_sub_nx = nx.from_pandas_edgelist(subgraph_edgelist, edge_attr=\"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "Specify node types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.set_node_attributes(G_sub_nx, \"paper\", \"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "Get the ids of the nodes in the selected subgraph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "subgraph_node_ids = sorted(list(G_sub_nx.nodes))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "Get the node features of the selected subgraph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "subgraph_node_features = node_data[feature_names].reindex(subgraph_node_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "Create the StellarGraph with node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "34",
   "metadata": {},
   "outputs": [],
   "source": [
    "G_sub = sg.StellarGraph.from_networkx(G_sub_nx, node_features=subgraph_node_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NetworkXStellarGraph: Undirected multigraph\n",
      " Nodes: 11776, Edges: 28937\n",
      "\n",
      " Node types:\n",
      "  paper: [11776]\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [28937]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(G_sub.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36",
   "metadata": {},
   "source": [
    "## Train attri2vec on the DBLP Subgraph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {},
   "source": [
    "Specify the remaining optional parameter values: the root nodes, the number of walks to take per node and the length of each walk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "38",
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = list(G_sub.nodes())\n",
    "number_of_walks = 2\n",
    "length = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39",
   "metadata": {},
   "source": [
    "Create the UnsupervisedSampler instance with the relevant parameters passed to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "40",
   "metadata": {},
   "outputs": [],
   "source": [
    "unsupervised_samples = UnsupervisedSampler(\n",
    "    G_sub, nodes=nodes, length=length, number_of_walks=number_of_walks\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41",
   "metadata": {},
   "source": [
    "Set the batch size and the number of epochs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "42",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "epochs = 6"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43",
   "metadata": {},
   "source": [
    "Define an attri2vec training generator, which generates a batch of (feature of target node, index of context node, label of node pair) triples per iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "44",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = Attri2VecLinkGenerator(G_sub, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45",
   "metadata": {},
   "source": [
    "Build the model. It consists of a 1-hidden-layer node representation ('input embedding') of the `target` node and a parameter vector ('output embedding') for predicting the existence of the `context` node in each `(target, context)` pair. A link classification layer is applied to the dot product of the 'input embedding' of the `target` node and the 'output embedding' of the `context` node.\n",
    "\n",
    "The attri2vec part of the model has a 128-dimensional hidden layer, no bias term and no normalization. (Normalization can be set to 'l2'). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "46",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_sizes = [128]\n",
    "attri2vec = Attri2Vec(\n",
    "    layer_sizes=layer_sizes, generator=generator, bias=False, normalize=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model and expose input and output sockets of attri2vec, for node pair inputs:\n",
    "x_inp, x_out = attri2vec.in_out_tensors()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48",
   "metadata": {},
   "source": [
    "Use the `link_classification` function to generate the prediction, with the `ip` edge embedding generation method and the `sigmoid` activation. This computes the dot product of the 'input embedding' of the target node and the 'output embedding' of the context node, followed by a sigmoid activation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "49",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "link_classification: using 'ip' method to combine node embeddings into edge embeddings\n"
     ]
    }
   ],
   "source": [
    "prediction = link_classification(\n",
    "    output_dim=1, output_act=\"sigmoid\", edge_embedding_method=\"ip\"\n",
    ")(x_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50",
   "metadata": {},
   "source": [
    "Stack the attri2vec encoder and prediction layer into a Keras model, and specify the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "51",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Model(inputs=x_inp, outputs=prediction)\n",
    "\n",
    "model.compile(\n",
    "    optimizer=keras.optimizers.Adam(learning_rate=1e-2),\n",
    "    loss=keras.losses.binary_crossentropy,\n",
    "    metrics=[keras.metrics.binary_accuracy],\n",
    ")"
   ]
  },
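  {
   "cell_type": "markdown",
   "id": "51a",
   "metadata": {},
   "source": [
    "As a sketch of what this model optimizes (the notation is introduced here for illustration and does not appear in the code): write $f(x_i)$ for the 128-dimensional 'input embedding' of a target node with attributes $x_i$, $w_j$ for the 'output embedding' of context node $j$, and $y_{ij} \\in \\{0, 1\\}$ for the label of the node pair. With the `ip` method and the `sigmoid` activation, the prediction and the binary cross-entropy loss for a single pair are\n",
    "\n",
    "$$\\hat{y}_{ij} = \\sigma\\left(f(x_i)^\\top w_j\\right), \\qquad \\mathcal{L}_{ij} = -\\left[y_{ij} \\log \\hat{y}_{ij} + (1 - y_{ij}) \\log\\left(1 - \\hat{y}_{ij}\\right)\\right],$$\n",
    "\n",
    "where $\\sigma(z) = 1 / (1 + e^{-z})$. Minimizing this loss pushes $f(x_i)^\\top w_j$ up for pairs labelled 1 and down for pairs labelled 0."
   ]
  },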
  {
   "cell_type": "markdown",
   "id": "52",
   "metadata": {},
   "source": [
    "Train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/6\n",
      "4711/4711 - 194s - loss: 0.7453 - binary_accuracy: 0.5219\n",
      "Epoch 2/6\n",
      "4711/4711 - 194s - loss: 0.6641 - binary_accuracy: 0.5677\n",
      "Epoch 3/6\n",
      "4711/4711 - 201s - loss: 0.5845 - binary_accuracy: 0.6375\n",
      "Epoch 4/6\n",
      "4711/4711 - 195s - loss: 0.5185 - binary_accuracy: 0.7198\n",
      "Epoch 5/6\n",
      "4711/4711 - 194s - loss: 0.4662 - binary_accuracy: 0.7767\n",
      "Epoch 6/6\n",
      "4711/4711 - 194s - loss: 0.4220 - binary_accuracy: 0.8110\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    generator.flow(unsupervised_samples),\n",
    "    epochs=epochs,\n",
    "    verbose=2,\n",
    "    use_multiprocessing=False,\n",
    "    workers=1,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54",
   "metadata": {},
   "source": [
    "## Predicting links of out-of-sample nodes with the learned attri2vec model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55",
   "metadata": {},
   "source": [
    "Build the node-based model for predicting node representations from node content attributes with the learned parameters. Below, a Keras model is constructed with `x_inp[0]` as input and `x_out[0]` as output. Note that this model's weights are the same as those of the corresponding node encoder in the previously trained node pair classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "56",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_inp_src = x_inp[0]\n",
    "x_out_src = x_out[0]\n",
    "embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57",
   "metadata": {},
   "source": [
    "Get the node embeddings, for both in-sample and out-of-sample nodes, by applying the learned mapping function to node content features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "369/369 [==============================] - 1s 2ms/step\n"
     ]
    }
   ],
   "source": [
    "node_ids = node_data.index\n",
    "node_gen = Attri2VecNodeGenerator(G_all, batch_size).flow(node_ids)\n",
    "node_embeddings = embedding_model.predict(node_gen, workers=4, verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59",
   "metadata": {},
   "source": [
    "Get the positive and negative edges for in-sample nodes and out-of-sample nodes. The edges of the in-sample nodes only include the edges between in-sample nodes. The edges of out-of-sample nodes are all the edges linked to out-of-sample nodes, including the edges connecting in-sample and out-of-sample nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "60",
   "metadata": {},
   "outputs": [],
   "source": [
    "year_thresh = 2006\n",
    "in_sample_edges = []\n",
    "out_of_sample_edges = []\n",
    "for ii in range(len(edgelist)):\n",
    "    source_index = edgelist[\"source\"][ii]\n",
    "    target_index = edgelist[\"target\"][ii]\n",
    "    if source_index > target_index:  # neglect edge direction for the undirected graph\n",
    "        continue\n",
    "    source_year = int(node_data[\"year\"][source_index])\n",
    "    target_year = int(node_data[\"year\"][target_index])\n",
    "    if source_year < year_thresh and target_year < year_thresh:\n",
    "        in_sample_edges.append([source_index, target_index, 1])  # get the positive edge\n",
    "        negative_target_index = unsupervised_samples.random.choices(\n",
    "            node_data.index.tolist(), k=1\n",
    "        )  # generate negative node\n",
    "        in_sample_edges.append(\n",
    "            [source_index, negative_target_index[0], 0]\n",
    "        )  # get the negative edge\n",
    "    else:\n",
    "        out_of_sample_edges.append(\n",
    "            [source_index, target_index, 1]\n",
    "        )  # get the positive edge\n",
    "        negative_target_index = unsupervised_samples.random.choices(\n",
    "            node_data.index.tolist(), k=1\n",
    "        )  # generate negative node\n",
    "        out_of_sample_edges.append(\n",
    "            [source_index, negative_target_index[0], 0]\n",
    "        )  # get the negative edge\n",
    "in_sample_edges = np.array(in_sample_edges)\n",
    "out_of_sample_edges = np.array(out_of_sample_edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61",
   "metadata": {},
   "source": [
    "Construct the edge features from the learned node representations with the l2 normed difference, where edge features are the element-wise square of the difference between the embeddings of the two end nodes of an edge. Other strategies, such as the element-wise product, can also be used to construct edge features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "62",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_sample_edge_feat_from_emb = (\n",
    "    node_embeddings[in_sample_edges[:, 0]] - node_embeddings[in_sample_edges[:, 1]]\n",
    ") ** 2\n",
    "out_of_sample_edge_feat_from_emb = (\n",
    "    node_embeddings[out_of_sample_edges[:, 0]]\n",
    "    - node_embeddings[out_of_sample_edges[:, 1]]\n",
    ") ** 2"
   ]
  },
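  {
   "cell_type": "markdown",
   "id": "62a",
   "metadata": {},
   "source": [
    "In symbols, as a sketch with notation introduced here for illustration: for a candidate edge between nodes $u$ and $v$ with embeddings $z_u$ and $z_v$, the edge feature vector $e_{uv}$ computed above has components\n",
    "\n",
    "$$(e_{uv})_k = \\left(z_{u,k} - z_{v,k}\\right)^2 .$$\n",
    "\n",
    "Summing these components gives the squared Euclidean distance $\\lVert z_u - z_v \\rVert_2^2$. Keeping them separate lets the downstream classifier weight each embedding dimension individually. The element-wise product alternative would instead use $(e_{uv})_k = z_{u,k} \\, z_{v,k}$."
   ]
  },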
  {
   "cell_type": "markdown",
   "id": "63",
   "metadata": {},
   "source": [
    "Train the Logistic Regression classifier from in-sample edges with the edge features constructed from attri2vec embeddings. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
       "                   intercept_scaling=1, l1_ratio=None, max_iter=500,\n",
       "                   multi_class='auto', n_jobs=None, penalty='l2',\n",
       "                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,\n",
       "                   warm_start=False)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_edge_pred_from_emb = LogisticRegression(\n",
    "    verbose=0, solver=\"lbfgs\", multi_class=\"auto\", max_iter=500\n",
    ")\n",
    "clf_edge_pred_from_emb.fit(in_sample_edge_feat_from_emb, in_sample_edges[:, 2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65",
   "metadata": {},
   "source": [
    "Predict the edge existence probability with the trained Logistic Regression classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "66",
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_pred_from_emb = clf_edge_pred_from_emb.predict_proba(\n",
    "    out_of_sample_edge_feat_from_emb\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67",
   "metadata": {},
   "source": [
    "Get the index of the positive class in `clf_edge_pred_from_emb.classes_`, which selects the matching column of `edge_pred_from_emb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "68",
   "metadata": {},
   "outputs": [],
   "source": [
    "if clf_edge_pred_from_emb.classes_[0] == 1:\n",
    "    positive_class_index = 0\n",
    "else:\n",
    "    positive_class_index = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69",
   "metadata": {},
   "source": [
    "Evaluate the AUC score for the prediction with attri2vec embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7880958790510728"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(out_of_sample_edges[:, 2], edge_pred_from_emb[:, positive_class_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71",
   "metadata": {},
   "source": [
    "As the baseline, we also investigate the performance of node content features in predicting the edges of out-of-sample nodes. Firstly, we construct edge features from node content features with the same strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "72",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_sample_edge_rep_from_feat = (\n",
    "    node_data[feature_names].values[in_sample_edges[:, 0]]\n",
    "    - node_data[feature_names].values[in_sample_edges[:, 1]]\n",
    ") ** 2\n",
    "out_of_sample_edge_rep_from_feat = (\n",
    "    node_data[feature_names].values[out_of_sample_edges[:, 0]]\n",
    "    - node_data[feature_names].values[out_of_sample_edges[:, 1]]\n",
    ") ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73",
   "metadata": {},
   "source": [
    "Then we train the Logistic Regression classifier from in-sample edges with the edge features constructed from node content features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
       "                   intercept_scaling=1, l1_ratio=None, max_iter=500,\n",
       "                   multi_class='auto', n_jobs=None, penalty='l2',\n",
       "                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,\n",
       "                   warm_start=False)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_edge_pred_from_feat = LogisticRegression(\n",
    "    verbose=0, solver=\"lbfgs\", multi_class=\"auto\", max_iter=500\n",
    ")\n",
    "clf_edge_pred_from_feat.fit(in_sample_edge_rep_from_feat, in_sample_edges[:, 2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75",
   "metadata": {},
   "source": [
    "Predict the edge existence probability with the trained Logistic Regression classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "76",
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_pred_from_feat = clf_edge_pred_from_feat.predict_proba(\n",
    "    out_of_sample_edge_rep_from_feat\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77",
   "metadata": {},
   "source": [
    "Get the index of the positive class in `clf_edge_pred_from_feat.classes_`, which selects the matching column of `edge_pred_from_feat`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "78",
   "metadata": {},
   "outputs": [],
   "source": [
    "if clf_edge_pred_from_feat.classes_[0] == 1:\n",
    "    positive_class_index = 0\n",
    "else:\n",
    "    positive_class_index = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79",
   "metadata": {},
   "source": [
    "Evaluate the AUC score for the prediction with node content features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6601881857840772"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(out_of_sample_edges[:, 2], edge_pred_from_feat[:, positive_class_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81",
   "metadata": {},
   "source": [
    "attri2vec can inductively infer the representations of out-of-sample nodes from their content attributes. Because the inferred representations capture both structure and node content information, they predict the links of out-of-sample nodes better than raw node content features: in the run recorded above, the AUC is about 0.79 with attri2vec embeddings versus about 0.66 with node content features."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/link-prediction/attri2vec-link-prediction.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/attri2vec-link-prediction.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}