{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Interpreting nodes and edges with saliency maps in GCN (sparse)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/interpretability/gcn-sparse-node-link-importance.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/interpretability/gcn-sparse-node-link-importance.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "This demo shows how to use integrated gradients in graph convolutional networks to obtain accurate importance estimates for both nodes and edges. The notebook consists of three parts:\n",
    "- setting up the node classification problem for the Cora citation network\n",
    "- training and evaluating a GCN model for node classification\n",
    "- calculating node and edge importances for the model's predictions of query (\"target\") nodes\n",
    "\n",
    "<a name=\"refs\"></a>\n",
    "**References**\n",
    "\n",
    "[1] Axiomatic Attribution for Deep Networks. M. Sundararajan, A. Taly, and Q. Yan.\n",
    "    Proceedings of the 34th International Conference on Machine Learning, Sydney, Australia, PMLR 70, 2017\n",
    "    ([link](https://arxiv.org/pdf/1703.01365.pdf)).\n",
    "    \n",
    "[2] Adversarial Examples on Graph Data: Deep Insights into Attack and Defense. H. Wu, C. Wang, Y. Tyshetskiy, A. Docherty, K. Lu, and L. Zhu. arXiv: 1903.01610 ([link](https://arxiv.org/abs/1903.01610)).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import os\n",
    "import time\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import FullBatchNodeGenerator\n",
    "from stellargraph.layer import GCN\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, optimizers, losses, metrics, Model, regularizers\n",
    "from sklearn import preprocessing, feature_extraction, model_selection\n",
    "from copy import deepcopy\n",
    "import matplotlib.pyplot as plt\n",
    "from stellargraph import datasets\n",
    "from IPython.display import display, HTML\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Loading the CORA network"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "tags": [
     "DataLoadingLinks"
    ]
   },
   "source": [
    "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8",
   "metadata": {
    "tags": [
     "DataLoading"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = datasets.Cora()\n",
    "display(HTML(dataset.description))\n",
    "G, subjects = dataset.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "### Splitting the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "For machine learning we want to take a subset of the nodes for training, and use the rest for validation and testing. We'll use scikit-learn to do this.\n",
    "\n",
    "Here we're taking 140 node labels for training, 500 for validation, and the rest for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_subjects, test_subjects = model_selection.train_test_split(\n",
    "    subjects, train_size=140, test_size=None, stratify=subjects\n",
    ")\n",
    "val_subjects, test_subjects = model_selection.train_test_split(\n",
    "    test_subjects, train_size=500, test_size=None, stratify=test_subjects\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "### Converting to numeric arrays"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "For our categorical target, we will use one-hot vectors that will be fed into a soft-max Keras layer during training. To do this conversion we use scikit-learn's `LabelBinarizer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_encoding = preprocessing.LabelBinarizer()\n",
    "\n",
    "train_targets = target_encoding.fit_transform(train_subjects)\n",
    "val_targets = target_encoding.transform(val_subjects)\n",
    "test_targets = target_encoding.transform(test_subjects)\n",
    "\n",
    "all_targets = target_encoding.transform(subjects)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "### Creating the GCN model in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "To feed data from the graph to the Keras model we need a generator. Since GCN is a full-batch model, we use the `FullBatchNodeGenerator` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "17",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GCN (local pooling) filters...\n"
     ]
    }
   ],
   "source": [
    "generator = FullBatchNodeGenerator(G, sparse=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "For training we map only the training nodes returned from our splitter and the target values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen = generator.flow(train_subjects.index, train_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Now we can specify our machine learning model: in this example we use two GCN layers with 16-dimensional hidden node features at each layer and ELU activation functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_sizes = [16, 16]\n",
    "gcn = GCN(\n",
    "    layer_sizes=layer_sizes,\n",
    "    activations=[\"elu\", \"elu\"],\n",
    "    generator=generator,\n",
    "    dropout=0.3,\n",
    "    kernel_regularizer=regularizers.l2(5e-4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose the input and output tensors of the GCN model for node prediction, via GCN.in_out_tensors() method:\n",
    "x_inp, x_out = gcn.in_out_tensors()\n",
    "# Snap the final estimator layer to x_out\n",
    "x_out = layers.Dense(units=train_targets.shape[1], activation=\"softmax\")(x_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23",
   "metadata": {},
   "source": [
    "### Training the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "Now let's create the actual Keras model, with `x_inp` as the input tensors and the predictions `x_out` from the final dense layer as the output tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Model(inputs=x_inp, outputs=x_out)\n",
    "\n",
    "model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=0.01),\n",
    "    loss=losses.categorical_crossentropy,\n",
    "    metrics=[metrics.categorical_accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {},
   "source": [
    "Train the model, keeping track of its loss and accuracy on the training set, and its generalisation performance on the validation set. This requires another generator over the validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_gen = generator.flow(val_subjects.index, val_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "1/1 - 0s - loss: 2.0886 - categorical_accuracy: 0.0786 - val_loss: 1.8735 - val_categorical_accuracy: 0.2860\n",
      "Epoch 2/20\n",
      "1/1 - 0s - loss: 1.8273 - categorical_accuracy: 0.3071 - val_loss: 1.7735 - val_categorical_accuracy: 0.3080\n",
      "Epoch 3/20\n",
      "1/1 - 0s - loss: 1.6816 - categorical_accuracy: 0.3429 - val_loss: 1.7049 - val_categorical_accuracy: 0.3280\n",
      "Epoch 4/20\n",
      "1/1 - 0s - loss: 1.5697 - categorical_accuracy: 0.3714 - val_loss: 1.6350 - val_categorical_accuracy: 0.4240\n",
      "Epoch 5/20\n",
      "1/1 - 0s - loss: 1.4508 - categorical_accuracy: 0.5000 - val_loss: 1.5633 - val_categorical_accuracy: 0.4860\n",
      "Epoch 6/20\n",
      "1/1 - 0s - loss: 1.3410 - categorical_accuracy: 0.5929 - val_loss: 1.4933 - val_categorical_accuracy: 0.5100\n",
      "Epoch 7/20\n",
      "1/1 - 0s - loss: 1.2154 - categorical_accuracy: 0.6714 - val_loss: 1.4239 - val_categorical_accuracy: 0.5420\n",
      "Epoch 8/20\n",
      "1/1 - 0s - loss: 1.1221 - categorical_accuracy: 0.6714 - val_loss: 1.3527 - val_categorical_accuracy: 0.5540\n",
      "Epoch 9/20\n",
      "1/1 - 0s - loss: 1.0248 - categorical_accuracy: 0.7286 - val_loss: 1.2816 - val_categorical_accuracy: 0.5820\n",
      "Epoch 10/20\n",
      "1/1 - 0s - loss: 0.9370 - categorical_accuracy: 0.7429 - val_loss: 1.2150 - val_categorical_accuracy: 0.6100\n",
      "Epoch 11/20\n",
      "1/1 - 0s - loss: 0.8205 - categorical_accuracy: 0.7929 - val_loss: 1.1561 - val_categorical_accuracy: 0.6420\n",
      "Epoch 12/20\n",
      "1/1 - 0s - loss: 0.7672 - categorical_accuracy: 0.8214 - val_loss: 1.1058 - val_categorical_accuracy: 0.6840\n",
      "Epoch 13/20\n",
      "1/1 - 0s - loss: 0.6830 - categorical_accuracy: 0.8500 - val_loss: 1.0636 - val_categorical_accuracy: 0.7120\n",
      "Epoch 14/20\n",
      "1/1 - 0s - loss: 0.6202 - categorical_accuracy: 0.8786 - val_loss: 1.0272 - val_categorical_accuracy: 0.7220\n",
      "Epoch 15/20\n",
      "1/1 - 0s - loss: 0.5606 - categorical_accuracy: 0.9143 - val_loss: 0.9955 - val_categorical_accuracy: 0.7380\n",
      "Epoch 16/20\n",
      "1/1 - 0s - loss: 0.5297 - categorical_accuracy: 0.9071 - val_loss: 0.9688 - val_categorical_accuracy: 0.7560\n",
      "Epoch 17/20\n",
      "1/1 - 0s - loss: 0.4936 - categorical_accuracy: 0.9429 - val_loss: 0.9467 - val_categorical_accuracy: 0.7600\n",
      "Epoch 18/20\n",
      "1/1 - 0s - loss: 0.4496 - categorical_accuracy: 0.9571 - val_loss: 0.9290 - val_categorical_accuracy: 0.7740\n",
      "Epoch 19/20\n",
      "1/1 - 0s - loss: 0.4013 - categorical_accuracy: 0.9643 - val_loss: 0.9158 - val_categorical_accuracy: 0.7780\n",
      "Epoch 20/20\n",
      "1/1 - 0s - loss: 0.3808 - categorical_accuracy: 0.9786 - val_loss: 0.9063 - val_categorical_accuracy: 0.7820\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_gen, shuffle=False, epochs=20, verbose=2, validation_data=val_gen\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "30",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sg.utils.plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "Evaluate the trained model on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Metrics:\n",
      "\tloss: 0.9140\n",
      "\tcategorical_accuracy: 0.7843\n"
     ]
    }
   ],
   "source": [
    "test_gen = generator.flow(test_subjects.index, test_targets)\n",
    "test_metrics = model.evaluate(test_gen)\n",
    "print(\"\\nTest Set Metrics:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "## Node and link importance via saliency maps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34",
   "metadata": {},
   "source": [
    "In order to understand why a selected node is predicted as a certain class we want to find the node feature importance, total node importance, and link importance for nodes and edges in the selected node's neighbourhood (ego-net). These importances give information about the effect of changes in the node's features and its neighbourhood on the prediction of the node, specifically:\n",
    "\n",
    "- **Node feature importance**: Given the selected node $t$ and the model's prediction $s(c)$ for class $c$, the feature importance can be calculated for each node $v$ in the selected node's ego-net. The importance of feature $f$ for node $v$ is the change in the predicted score $s(c)$ for the selected node when the feature $f$ of node $v$ is perturbed.\n",
    "- **Total node importance**: This is defined as the sum of the feature importances for node $v$ over all features. Nodes with high importance (positive or negative) affect the prediction for the selected node more than nodes with low importance.\n",
    "- **Link importance**: This is defined as the change in the selected node's predicted score $s(c)$ if the link $e=(u, v)$ is removed from the graph. Links with high importance (positive or negative) affect the prediction for the selected node more than links with low importance.\n",
    "\n",
    "Node and link importances can be used to assess the role of nodes and links in the model's predictions for the node(s) of interest (the selected node). For datasets like CORA-ML, where the features and edges are binary, vanilla gradients may not perform well, so we use integrated gradients [[1]](#refs) to compute the importances.\n",
    "\n",
    "Another interesting application of node and link importances is to identify model vulnerabilities to attacks via perturbing node features and graph structure (see [[2]](#refs))."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "To investigate these importances we use the StellarGraph `saliency_maps` routines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.interpretability.saliency_maps import IntegratedGradients"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {},
   "source": [
    "Select the target node whose prediction is to be interpreted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "38",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_nodes = list(G.nodes())\n",
    "target_nid = 1109199\n",
    "target_idx = graph_nodes.index(target_nid)\n",
    "y_true = all_targets[target_idx]  # true class of the target node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected node id: 1109199, \n",
      "True label: [0 1 0 0 0 0 0], \n",
      "Predicted scores: [0.02 0.77 0.   0.02 0.1  0.01 0.07]\n"
     ]
    }
   ],
   "source": [
    "all_gen = generator.flow(graph_nodes)\n",
    "y_pred = model.predict(all_gen)[0, target_idx]\n",
    "class_of_interest = np.argmax(y_pred)\n",
    "\n",
    "print(\n",
    "    \"Selected node id: {}, \\nTrue label: {}, \\nPredicted scores: {}\".format(\n",
    "        target_nid, y_true, y_pred.round(2)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Get the node feature importance by using integrated gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "41",
   "metadata": {},
   "outputs": [],
   "source": [
    "int_grad_saliency = IntegratedGradients(model, train_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "For the parameters of the `get_node_importance` method, `X` and `A` are the feature and adjacency matrices, respectively. If the `sparse` option is enabled, `A` will be the non-zero values of the adjacency matrix, with `A_index` being their indices. `target_idx` is the node of interest, and `class_of_interest` is set to the predicted label of the node. `steps` is the number of steps used to approximate the integral in the integrated gradients calculation. A larger value of `steps` gives a better approximation, at the cost of higher computational overhead."
   ]
  },
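  {
   "cell_type": "markdown",
   "id": "42a",
   "metadata": {},
   "source": [
    "To make the role of `steps` concrete, here is the textbook form of integrated gradients from [[1]](#refs), written in that paper's notation rather than in terms of the StellarGraph API. For a model output $F$, an input $x$ and a baseline $x'$, the attribution of input component $i$ is an integral that is approximated in practice by a sum over $m$ points:\n",
    "\n",
    "$$\\mathrm{IG}_i(x) = (x_i - x'_i) \\int_0^1 \\frac{\\partial F\\big(x' + \\alpha (x - x')\\big)}{\\partial x_i} \\, d\\alpha \\;\\approx\\; (x_i - x'_i) \\, \\frac{1}{m} \\sum_{k=1}^{m} \\frac{\\partial F\\big(x' + \\tfrac{k}{m} (x - x')\\big)}{\\partial x_i}$$\n",
    "\n",
    "Under this reading, $m$ plays the role of `steps`, $F$ is the predicted score for `class_of_interest` at node `target_idx`, and $x_i$ ranges over the entries of the feature matrix. This is a sketch under those assumptions: the exact baseline and discretisation used inside `IntegratedGradients` may differ."
   ]
  },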
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "integrated_node_importance = int_grad_saliency.get_node_importance(\n",
    "    target_idx, class_of_interest, steps=50\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2708,)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "integrated_node_importance.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "integrated_node_importance [0. 0. 0. ... 0. 0. 0.]\n",
      "integrate_node_importance.shape = (2708,)\n",
      "integrated self-importance of target node 1109199: 6.31\n"
     ]
    }
   ],
   "source": [
    "print(\"\\nintegrated_node_importance\", integrated_node_importance.round(2))\n",
    "print(\"integrate_node_importance.shape = {}\".format(integrated_node_importance.shape))\n",
    "print(\n",
    "    \"integrated self-importance of target node {}: {}\".format(\n",
    "        target_nid, integrated_node_importance[target_idx].round(2)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46",
   "metadata": {},
   "source": [
    "Check that the number of non-zero node importance values is less than or equal to the number of nodes in the target node's K-hop ego net (where K is the number of GCN layers in the model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "47",
   "metadata": {},
   "outputs": [],
   "source": [
    "G_ego = nx.ego_graph(G.to_networkx(), target_nid, radius=len(gcn.activations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of nodes in the ego graph: 202\n",
      "Number of non-zero elements in integrated_node_importance: 202\n"
     ]
    }
   ],
   "source": [
    "print(\"Number of nodes in the ego graph: {}\".format(len(G_ego.nodes())))\n",
    "print(\n",
    "    \"Number of non-zero elements in integrated_node_importance: {}\".format(\n",
    "        np.count_nonzero(integrated_node_importance)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49",
   "metadata": {},
   "source": [
    "We now compute the link importance using integrated gradients [[1]](#refs). Integrated gradients are obtained by accumulating the gradients along the path between the baseline (all-zero graph) and the current state of the graph. They provide better sensitivity than vanilla gradients for graphs with binary features and edges."
   ]
  },
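  {
   "cell_type": "markdown",
   "id": "49a",
   "metadata": {},
   "source": [
    "As a sketch of what this means for a single link, assume the baseline is the all-zero adjacency matrix mentioned above and write $s_c(X, A)$ for the target node's predicted score for class $c$ given features $X$ and adjacency matrix $A$ (notation introduced here for illustration). The general integrated gradients formula from [[1]](#refs) then reduces to:\n",
    "\n",
    "$$\\mathrm{IG}_{uv} = A_{uv} \\int_0^1 \\left. \\frac{\\partial s_c(X, \\tilde{A})}{\\partial \\tilde{A}_{uv}} \\right|_{\\tilde{A} = \\alpha A} d\\alpha$$\n",
    "\n",
    "Entries with $A_{uv} = 0$ therefore receive zero importance under this baseline. Details such as adjacency normalisation and self-loops are not captured by this simplified expression."
   ]
  },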
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "50",
   "metadata": {},
   "outputs": [],
   "source": [
    "integrate_link_importance = int_grad_saliency.get_integrated_link_masks(\n",
    "    target_idx, class_of_interest, steps=50\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "integrate_link_importance.shape = (2708, 2708)\n",
      "Number of non-zero elements in integrate_link_importance: 210\n"
     ]
    }
   ],
   "source": [
    "integrate_link_importance_dense = np.array(integrate_link_importance.todense())\n",
    "print(\"integrate_link_importance.shape = {}\".format(integrate_link_importance.shape))\n",
    "print(\n",
    "    \"Number of non-zero elements in integrate_link_importance: {}\".format(\n",
    "        np.count_nonzero(integrate_link_importance.todense())\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52",
   "metadata": {},
   "source": [
    "We can now find the links that have the highest importance to the prediction of the selected node:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 10 most important links by integrated gradients are:\n",
      " [(1544, 1206), (1544, 1544), (1544, 163), (1206, 1206), (1206, 789), (1206, 1544), (1544, 566), (1206, 163), (566, 733), (566, 1544)]\n"
     ]
    }
   ],
   "source": [
    "sorted_indices = np.argsort(integrate_link_importance_dense.flatten())\n",
    "N = len(graph_nodes)\n",
    "integrated_link_importance_rank = [(k // N, k % N) for k in sorted_indices[::-1]]\n",
    "topk = 10\n",
    "print(\n",
    "    \"Top {} most important links by integrated gradients are:\\n {}\".format(\n",
    "        topk, integrated_link_importance_rank[:topk]\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the labels as an attribute for the nodes in the graph. The labels are used to color the nodes in different classes.\n",
    "nx.set_node_attributes(G_ego, values={x[0]: {\"subject\": x[1]} for x in subjects.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55",
   "metadata": {},
   "source": [
    "In the following, we plot the link and node importance (computed by integrated gradients) of the nodes within the ego graph of the target node. \n",
    "\n",
    "For nodes, the shape of the node indicates the positive/negative importance the node has. 'round' nodes have positive importance while 'diamond' nodes have negative importance. The size of the node indicates the value of the importance, e.g., a large diamond node has higher negative importance. \n",
    "\n",
    "For links, the color of the link indicates the positive/negative importance the link has. 'red' links have positive importance while 'blue' links have negative importance. The width of the link indicates the value of the importance, e.g., a thicker blue link has higher negative importance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "56",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.918084517035823"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "integrated_node_importance.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.208803627230227"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "integrate_link_importance.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1080x720 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "node_size_factor = 1e2\n",
    "link_width_factor = 2\n",
    "\n",
    "nodes = list(G_ego.nodes())\n",
    "colors = pd.DataFrame(\n",
    "    [v[1][\"subject\"] for v in G_ego.nodes(data=True)], index=nodes, columns=[\"subject\"]\n",
    ")\n",
    "colors = np.argmax(target_encoding.transform(colors), axis=1) + 1\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(15, 10))\n",
    "pos = nx.spring_layout(G_ego)\n",
    "\n",
    "# Node size and shape encode the magnitude and sign of the integrated-gradient importance\n",
    "node_sizes = [integrated_node_importance[graph_nodes.index(k)] for k in nodes]\n",
    "node_shapes = [\"o\" if w > 0 else \"d\" for w in node_sizes]\n",
    "\n",
    "positive_colors, negative_colors = [], []\n",
    "positive_node_sizes, negative_node_sizes = [], []\n",
    "positive_nodes, negative_nodes = [], []\n",
    "node_size_scale = node_size_factor / np.max(node_sizes)\n",
    "for k in range(len(nodes)):\n",
    "    # Skip the target node (drawn separately below); `nodes` holds node IDs, so compare with target_nid\n",
    "    if nodes[k] == target_nid:\n",
    "        continue\n",
    "    if node_shapes[k] == \"o\":\n",
    "        positive_colors.append(colors[k])\n",
    "        positive_nodes.append(nodes[k])\n",
    "        positive_node_sizes.append(node_size_scale * node_sizes[k])\n",
    "\n",
    "    else:\n",
    "        negative_colors.append(colors[k])\n",
    "        negative_nodes.append(nodes[k])\n",
    "        negative_node_sizes.append(node_size_scale * abs(node_sizes[k]))\n",
    "\n",
    "# Plot the ego network with the node importances\n",
    "cmap = plt.get_cmap(\"jet\", np.max(colors) - np.min(colors) + 1)\n",
    "nc = nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=positive_nodes,\n",
    "    node_color=positive_colors,\n",
    "    cmap=cmap,\n",
    "    node_size=positive_node_sizes,\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    node_shape=\"o\",\n",
    ")\n",
    "nc = nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=negative_nodes,\n",
    "    node_color=negative_colors,\n",
    "    cmap=cmap,\n",
    "    node_size=negative_node_sizes,\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    node_shape=\"d\",\n",
    ")\n",
    "# Draw the target node as a large star colored by its true subject\n",
    "nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=[target_nid],\n",
    "    node_size=50 * abs(node_sizes[nodes.index(target_nid)]),\n",
    "    node_shape=\"*\",\n",
    "    node_color=[colors[nodes.index(target_nid)]],\n",
    "    cmap=cmap,\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    label=\"Target\",\n",
    ")\n",
    "\n",
    "# Draw the edges with the edge importances\n",
    "edges = G_ego.edges()\n",
    "weights = [\n",
    "    integrate_link_importance[graph_nodes.index(u), graph_nodes.index(v)]\n",
    "    for u, v in edges\n",
    "]\n",
    "edge_colors = [\"red\" if w > 0 else \"blue\" for w in weights]\n",
    "weights = link_width_factor * np.abs(weights) / np.max(weights)\n",
    "\n",
    "ec = nx.draw_networkx_edges(G_ego, pos, edge_color=edge_colors, width=weights)\n",
    "plt.legend()\n",
    "plt.colorbar(nc, ticks=np.arange(np.min(colors), np.max(colors) + 1))\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59",
   "metadata": {},
   "source": [
    "We then remove each node or edge in the ego graph, one at a time, and check how the prediction changes. This gives the ground truth importance of the nodes and edges. Comparing the following figure with the one above shows the effectiveness of integrated gradients, as the approximated importances are relatively consistent with the ground truth."
   ]
  },
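  {
   "cell_type": "markdown",
   "id": "59a",
   "metadata": {},
   "source": [
    "Written out for nodes, the ground truth described above can be sketched as follows, where $s_c(X, A)$ denotes the target node's predicted score for the predicted class $c$ and $X^{(v \\to 0)}$ is the feature matrix with all features of node $v$ set to zero (both symbols are notation introduced here for illustration):\n",
    "\n",
    "$$\\Delta_v = s_c(X, A) - s_c\\big(X^{(v \\to 0)}, A\\big)$$\n",
    "\n",
    "A positive $\\Delta_v$ means that zeroing the features of node $v$ lowers the score, so $v$ supports the prediction; a negative value means $v$ works against it. The same difference can be taken for a link by removing that link from the adjacency structure instead of zeroing node features."
   ]
  },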
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "60",
   "metadata": {},
   "outputs": [],
   "source": [
    "(X, _, A_index, A), _ = train_gen[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1080x720 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "X_bk = deepcopy(X)\n",
    "A_bk = deepcopy(A)\n",
    "selected_nodes = np.array([[target_idx]], dtype=\"int32\")\n",
    "nodes = [graph_nodes.index(v) for v in G_ego.nodes()]\n",
    "edges = [(graph_nodes.index(u), graph_nodes.index(v)) for u, v in G_ego.edges()]\n",
    "clean_prediction = model.predict([X, selected_nodes, A_index, A]).squeeze()\n",
    "predict_label = np.argmax(clean_prediction)\n",
    "\n",
    "groud_truth_node_importance = np.zeros((N,))\n",
    "for node in nodes:\n",
    "    # we set all the features of the node to zero to check the ground truth node importance.\n",
    "    X_perturb = deepcopy(X_bk)\n",
    "    X_perturb[:, node, :] = 0\n",
    "    predict_after_perturb = model.predict(\n",
    "        [X_perturb, selected_nodes, A_index, A]\n",
    "    ).squeeze()\n",
    "    groud_truth_node_importance[node] = (\n",
    "        clean_prediction[predict_label] - predict_after_perturb[predict_label]\n",
    "    )\n",
    "\n",
    "node_shapes = [\n",
    "    \"o\" if groud_truth_node_importance[k] > 0 else \"d\" for k in range(len(nodes))\n",
    "]\n",
    "positive_colors, negative_colors = [], []\n",
    "positive_node_sizes, negative_node_sizes = [], []\n",
    "positive_nodes, negative_nodes = [], []\n",
    "# node_size_scale is used for better visulization of nodes\n",
    "node_size_scale = node_size_factor / max(groud_truth_node_importance)\n",
    "\n",
    "for k in range(len(node_shapes)):\n",
    "    if nodes[k] == target_idx:\n",
    "        continue\n",
    "    if node_shapes[k] == \"o\":\n",
    "        positive_colors.append(colors[k])\n",
    "        positive_nodes.append(graph_nodes[nodes[k]])\n",
    "        positive_node_sizes.append(\n",
    "            node_size_scale * groud_truth_node_importance[nodes[k]]\n",
    "        )\n",
    "    else:\n",
    "        negative_colors.append(colors[k])\n",
    "        negative_nodes.append(graph_nodes[nodes[k]])\n",
    "        negative_node_sizes.append(\n",
    "            node_size_scale * abs(groud_truth_node_importance[nodes[k]])\n",
    "        )\n",
    "X = deepcopy(X_bk)\n",
    "groud_truth_edge_importance = np.zeros((N, N))\n",
    "G_edge_indices = [(A_index[0, k, 0], A_index[0, k, 1]) for k in range(A.shape[1])]\n",
    "\n",
    "for edge in edges:\n",
    "    edge_index = G_edge_indices.index((edge[0], edge[1]))\n",
    "    origin_val = A[0, edge_index]\n",
    "\n",
    "    A[0, edge_index] = 0\n",
    "    # we set the weight of a given edge to zero to check the ground truth link importance\n",
    "    predict_after_perturb = model.predict([X, selected_nodes, A_index, A]).squeeze()\n",
    "    groud_truth_edge_importance[edge[0], edge[1]] = (\n",
    "        predict_after_perturb[predict_label] - clean_prediction[predict_label]\n",
    "    ) / (0 - 1)\n",
    "    A[0, edge_index] = origin_val\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(15, 10))\n",
    "cmap = plt.get_cmap(\"jet\", np.max(colors) - np.min(colors) + 1)\n",
    "# Draw the target node as a large star colored by its true subject\n",
    "nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=[target_nid],\n",
    "    node_size=50 * abs(node_sizes[nodes.index(target_idx)]),\n",
    "    node_color=[colors[nodes.index(target_idx)]],\n",
    "    cmap=cmap,\n",
    "    node_shape=\"*\",\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    label=\"Target\",\n",
    ")\n",
    "# Draw the ego net\n",
    "nc = nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=positive_nodes,\n",
    "    node_color=positive_colors,\n",
    "    cmap=cmap,\n",
    "    node_size=positive_node_sizes,\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    node_shape=\"o\",\n",
    ")\n",
    "nc = nx.draw_networkx_nodes(\n",
    "    G_ego,\n",
    "    pos,\n",
    "    nodelist=negative_nodes,\n",
    "    node_color=negative_colors,\n",
    "    cmap=cmap,\n",
    "    node_size=negative_node_sizes,\n",
    "    vmin=np.min(colors) - 0.5,\n",
    "    vmax=np.max(colors) + 0.5,\n",
    "    node_shape=\"d\",\n",
    ")\n",
    "edges = G_ego.edges()\n",
    "weights = [\n",
    "    groud_truth_edge_importance[graph_nodes.index(u), graph_nodes.index(v)]\n",
    "    for u, v in edges\n",
    "]\n",
    "edge_colors = [\"red\" if w > 0 else \"blue\" for w in weights]\n",
    "weights = link_width_factor * np.abs(weights) / np.max(weights)\n",
    "\n",
    "ec = nx.draw_networkx_edges(G_ego, pos, edge_color=edge_colors, width=weights)\n",
    "plt.legend()\n",
    "plt.colorbar(nc, ticks=np.arange(np.min(colors), np.max(colors) + 1))\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62",
   "metadata": {},
   "source": [
    "Comparing the two figures above shows that the importances from integrated gradients are largely consistent with those from the brute-force approach. The main benefit of integrated gradients is scalability: gradients are very efficient to compute in deep learning frameworks, especially with the parallelism provided by GPUs, whereas the brute-force approach needs a separate model prediction for every perturbed node or edge. Integrated gradients can also give the importance of individual node features for all nodes in the graph, which is often impractical to obtain by brute force."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}