{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Intepretability on Hateful Twitter Datasets\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/interpretability/hateful-twitters-interpretability.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/interpretability/hateful-twitters-interpretability.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "In this demo, we apply saliency maps (with support of sparse tensors) on the task on the detection of Twitter users who use hateful lexicon using graph machine learning with Stellargraph.\n",
    "\n",
    "We consider the use-case of identifying hateful users on Twitter motivated by the work in [1] and using the dataset also published in [1]. Classification is based on a graph based on users' retweets and attributes as related to their account activity, and the content of tweets.\n",
    "\n",
    "We pose identifying hateful users as a binary classification problem. We demonstrate the advantage of connected vs unconnected data in a semi-supervised setting with few training examples.\n",
    "\n",
    "For connected data, we use Graph Convolutional Networks [2] as implemented in the `stellargraph` library. We pose the problem of identifying hateful tweeter users as node attribute inference in graphs.\n",
    "\n",
    "We then use the interpretability tool (i.e., saliency maps) implemented in our library to demonstrate how to obtain the importance of the node features and links to gain insights into the model.\n",
    "\n",
    "**References**\n",
    "\n",
    "1. \"Like Sheep Among Wolves\": Characterizing Hateful Users on Twitter. M. H. Ribeiro, P. H. Calais, Y. A. Santos, V. A. F. Almeida, and W. Meira Jr.  arXiv preprint arXiv:1801.00317 (2017).\n",
    "\n",
    "\n",
    "2. Semi-Supervised Classification with Graph Convolutional Networks. T. Kipf, M. Welling. ICLR 2017. arXiv:1609.02907 \n"
   ]
  },
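  {
   "cell_type": "markdown",
   "id": "2a",
   "metadata": {},
   "source": [
    "As rough intuition for what a (vanilla gradient) saliency map measures: each input is scored by the magnitude of the gradient of the model's output with respect to that input. The sketch below is a minimal illustration under stated assumptions. It uses a hypothetical one-layer logistic model with hand-picked weights `w`, bias `b` and input `x`, together with its analytic gradient. It is not StellarGraph's implementation, which works on the trained graph model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "\n",
    "# Hypothetical toy model: p = sigmoid(w . x + b)\n",
    "w = np.array([2.0, -0.5, 0.0])\n",
    "b = 0.1\n",
    "x = np.array([0.3, 1.2, -0.7])\n",
    "\n",
    "p = sigmoid(w @ x + b)\n",
    "# Analytic gradient of p with respect to x: p * (1 - p) * w\n",
    "grad = p * (1 - p) * w\n",
    "# Saliency: magnitude of the gradient for each input feature\n",
    "saliency = np.abs(grad)\n",
    "print(saliency)\n",
    "```\n",
    "\n",
    "The feature with zero weight gets zero saliency. The largest-magnitude gradient marks the input to which the prediction is most sensitive."
   ]
  },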
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.3.0b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.3.0b\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.3.0b, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.linear_model import LogisticRegressionCV\n",
    "\n",
    "import stellargraph as sg\n",
    "from stellargraph.mapper import GraphSAGENodeGenerator, FullBatchNodeGenerator\n",
    "from stellargraph.layer import GraphSAGE, GCN, GAT\n",
    "from stellargraph import globalvar\n",
    "\n",
    "from tensorflow.keras import layers, optimizers, losses, metrics, Model, models\n",
    "from sklearn import preprocessing, feature_extraction\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import metrics\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy.sparse import csr_matrix, lil_matrix\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Train GCN model on the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = os.path.expanduser(\"~/data/hateful-twitter-users\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "### First load and prepare the node features\n",
    "\n",
    "Each node in the graph is associated with a large number of features (also referred to as attributes). \n",
    "\n",
    "The list of features is given [here](https://www.kaggle.com/manoelribeiro/hateful-users-on-twitter). We repeated here for convenience.\n",
    "\n",
    "hate :(\"hateful\"|\"normal\"|\"other\")\n",
    "  if user was annotated as hateful, normal, or not annotated.\n",
    "  \n",
    "  (is_50|is_50_2) :bool\n",
    "  whether user was deleted up to 12/12/17 or 14/01/18. \n",
    "  \n",
    "  (is_63|is_63_2) :bool\n",
    "  whether user was suspended up to 12/12/17 or 14/01/18. \n",
    "        \n",
    "  (hate|normal)_neigh :bool\n",
    "  is the user on the neighborhood of a (hateful|normal) user? \n",
    "  \n",
    "  [c_] (statuses|follower|followees|favorites)_count :int\n",
    "  number of (tweets|follower|followees|favorites) a user has.\n",
    "  \n",
    "  [c_] listed_count:int\n",
    "  number of lists a user is in.\n",
    "\n",
    "  [c_] (betweenness|eigenvector|in_degree|outdegree) :float\n",
    "  centrality measurements for each user in the retweet graph.\n",
    "  \n",
    "  [c_] *_empath :float\n",
    "  occurrences of empath categories in the users latest 200 tweets.\n",
    "\n",
    "  [c_] *_glove :float          \n",
    "  glove vector calculated for users latest 200 tweets.\n",
    "  \n",
    "  [c_] (sentiment|subjectivity) :float\n",
    "  average sentiment and subjectivity of users tweets.\n",
    "  \n",
    "  [c_] (time_diff|time_diff_median) :float\n",
    "  average and median time difference between tweets.\n",
    "  \n",
    "  [c_] (tweet|retweet|quote) number :float\n",
    "  percentage of direct tweets, retweets and quotes of an user.\n",
    "  \n",
    "  [c_] (number urls|number hashtags|baddies|mentions) :float\n",
    "  number of bad words|mentions|urls|hashtags per tweet in average.\n",
    "  \n",
    "  [c_] status length :float\n",
    "  average status length.\n",
    "  \n",
    "  hashtags :string\n",
    "  all hashtags employed by the user separated by spaces.\n",
    "  \n",
    "**Notice** that c_ are attributes calculated for the 1-neighborhood of a user in the retweet network (averaged out)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "First, we are going to load the user features and prepare them for machine learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "users_feat = pd.read_csv(os.path.join(data_dir, \"users_neighborhood_anon.csv\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "### Data cleaning and preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "The dataset as given includes a large number of graph related features that are manually extracted. \n",
    "\n",
    "Since we are going to employ modern graph neural networks methods for classification, we are going to drop these manually engineered features. \n",
    "\n",
    "The power of Graph Neural Networks stems from their ability to learn useful graph-related features eliminating the need for manual feature engineering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_cleaning(feat):\n",
    "    feat = feat.drop(columns=[\"hate_neigh\", \"normal_neigh\"])\n",
    "\n",
    "    # Convert target values in hate column from strings to integers (0,1,2)\n",
    "    feat[\"hate\"] = np.where(\n",
    "        feat[\"hate\"] == \"hateful\", 1, np.where(feat[\"hate\"] == \"normal\", 0, 2)\n",
    "    )\n",
    "\n",
    "    # missing information\n",
    "    number_of_missing = feat.isnull().sum()\n",
    "    number_of_missing[number_of_missing != 0]\n",
    "\n",
    "    # Replace NA with 0\n",
    "    feat.fillna(0, inplace=True)\n",
    "\n",
    "    # droping info about suspension and deletion as it is should not be use din the predictive model\n",
    "    feat.drop(\n",
    "        feat.columns[feat.columns.str.contains(\"is_|_glove|c_|sentiment\")],\n",
    "        axis=1,\n",
    "        inplace=True,\n",
    "    )\n",
    "\n",
    "    # drop hashtag feature\n",
    "    feat.drop([\"hashtags\"], axis=1, inplace=True)\n",
    "\n",
    "    # Drop centrality based measures\n",
    "    feat.drop(\n",
    "        columns=[\"betweenness\", \"eigenvector\", \"in_degree\", \"out_degree\"], inplace=True\n",
    "    )\n",
    "\n",
    "    feat.drop(columns=[\"created_at\"], inplace=True)\n",
    "\n",
    "    return feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_data = data_cleaning(users_feat)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "The continous features in our dataset have distributions with very long tails. We apply normalization to correct for this."
   ]
  },
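  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "To see why a power transform helps, the following sketch applies the same scikit-learn `PowerTransformer` (Yeo-Johnson, `standardize=True`) to a synthetic feature and compares its skewness before and after. The synthetic feature is hypothetical and long-tailed, drawn from a log-normal distribution. It is an assumption for illustration only and is not taken from the dataset.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.stats import skew\n",
    "from sklearn.preprocessing import PowerTransformer\n",
    "\n",
    "# Hypothetical long-tailed, non-negative feature (e.g. a count), as a single column\n",
    "rng = np.random.default_rng(0)\n",
    "values = rng.lognormal(mean=3.0, sigma=1.5, size=(5000, 1))\n",
    "\n",
    "pt = PowerTransformer(method=\"yeo-johnson\", standardize=True)\n",
    "values_t = pt.fit_transform(values)\n",
    "\n",
    "print(\"skew before:\", skew(values[:, 0]))\n",
    "print(\"skew after:\", skew(values_t[:, 0]))\n",
    "print(\"fitted lambda:\", pt.lambdas_[0])\n",
    "```\n",
    "\n",
    "Because `standardize=True`, the transformed column also has zero mean and unit variance."
   ]
  },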
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ignore the first two columns because those are user_id and hate (the target variable)\n",
    "df_values = node_data.iloc[:, 2:].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt = preprocessing.PowerTransformer(method=\"yeo-johnson\", standardize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_values_log = pt.fit_transform(df_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_data.iloc[:, 2:] = df_values_log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the dataframe index to be the same as the user_id and drop the user_id columns\n",
    "node_data.index = node_data.index.map(str)\n",
    "node_data.drop(columns=[\"user_id\"], inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "### Next load the graph\n",
    "\n",
    "Now that we have the node features prepared for machine learning, let us load the retweet graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "g_nx = nx.read_edgelist(path=os.path.expanduser(os.path.join(data_dir, \"users.edges\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100386, 2194979)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g_nx.number_of_nodes(), g_nx.number_of_edges()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "The graph has just over 100k nodes and approximately 2.2m edges.\n",
    "\n",
    "We aim to train a graph neural network model that will predict the \"hate\"attribute on the nodes.\n",
    "\n",
    "For computation convenience, we have mapped the target labels **normal**, **hateful**, and **other** to the numeric values **0**, **1**, and **2** respectively."
   ]
  },
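  {
   "cell_type": "markdown",
   "id": "24a",
   "metadata": {},
   "source": [
    "The nested `np.where` in `data_cleaning` implements this mapping. Here is a minimal sketch on a few made-up labels in the dataset's string format:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up example labels\n",
    "labels = pd.Series([\"hateful\", \"normal\", \"other\", \"normal\"])\n",
    "\n",
    "# hateful -> 1, normal -> 0, anything else (\"other\") -> 2\n",
    "codes = np.where(labels == \"hateful\", 1, np.where(labels == \"normal\", 0, 2))\n",
    "print(codes)  # [1 0 2 0]\n",
    "```"
   ]
  },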
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "25",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0, 1, 2}\n"
     ]
    }
   ],
   "source": [
    "print(set(node_data[\"hate\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "26",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>hate</th>\n",
       "      <th>statuses_count</th>\n",
       "      <th>followers_count</th>\n",
       "      <th>followees_count</th>\n",
       "      <th>favorites_count</th>\n",
       "      <th>listed_count</th>\n",
       "      <th>negotiate_empath</th>\n",
       "      <th>vehicle_empath</th>\n",
       "      <th>science_empath</th>\n",
       "      <th>timidity_empath</th>\n",
       "      <th>...</th>\n",
       "      <th>number hashtags</th>\n",
       "      <th>tweet number</th>\n",
       "      <th>retweet number</th>\n",
       "      <th>quote number</th>\n",
       "      <th>status length</th>\n",
       "      <th>number urls</th>\n",
       "      <th>baddies</th>\n",
       "      <th>mentions</th>\n",
       "      <th>time_diff</th>\n",
       "      <th>time_diff_median</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10999</td>\n",
       "      <td>2</td>\n",
       "      <td>0.651057</td>\n",
       "      <td>-0.228440</td>\n",
       "      <td>0.539018</td>\n",
       "      <td>1.468664</td>\n",
       "      <td>0.319936</td>\n",
       "      <td>0.060148</td>\n",
       "      <td>-1.573040</td>\n",
       "      <td>0.468232</td>\n",
       "      <td>-0.446347</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.347727</td>\n",
       "      <td>-0.087181</td>\n",
       "      <td>0.355153</td>\n",
       "      <td>1.193070</td>\n",
       "      <td>0.010627</td>\n",
       "      <td>0.314380</td>\n",
       "      <td>0.581937</td>\n",
       "      <td>0.017239</td>\n",
       "      <td>-0.772738</td>\n",
       "      <td>-0.713314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55317</td>\n",
       "      <td>2</td>\n",
       "      <td>0.527130</td>\n",
       "      <td>0.159289</td>\n",
       "      <td>0.603327</td>\n",
       "      <td>0.116831</td>\n",
       "      <td>0.400391</td>\n",
       "      <td>-0.170600</td>\n",
       "      <td>0.731748</td>\n",
       "      <td>-0.155481</td>\n",
       "      <td>0.487008</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.159648</td>\n",
       "      <td>0.863400</td>\n",
       "      <td>-0.628442</td>\n",
       "      <td>1.058797</td>\n",
       "      <td>-0.400813</td>\n",
       "      <td>-0.034034</td>\n",
       "      <td>-0.023220</td>\n",
       "      <td>0.088925</td>\n",
       "      <td>0.209697</td>\n",
       "      <td>0.501357</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44622</td>\n",
       "      <td>2</td>\n",
       "      <td>-0.972049</td>\n",
       "      <td>0.513316</td>\n",
       "      <td>0.003403</td>\n",
       "      <td>0.041867</td>\n",
       "      <td>0.682879</td>\n",
       "      <td>0.398669</td>\n",
       "      <td>-0.434141</td>\n",
       "      <td>-0.439622</td>\n",
       "      <td>0.134869</td>\n",
       "      <td>...</td>\n",
       "      <td>1.059839</td>\n",
       "      <td>-0.068104</td>\n",
       "      <td>0.338591</td>\n",
       "      <td>-0.254387</td>\n",
       "      <td>1.066497</td>\n",
       "      <td>1.200203</td>\n",
       "      <td>0.243681</td>\n",
       "      <td>0.661312</td>\n",
       "      <td>1.318291</td>\n",
       "      <td>1.403518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71821</td>\n",
       "      <td>2</td>\n",
       "      <td>1.003596</td>\n",
       "      <td>1.295017</td>\n",
       "      <td>0.219550</td>\n",
       "      <td>0.198376</td>\n",
       "      <td>1.810431</td>\n",
       "      <td>-0.601582</td>\n",
       "      <td>-1.187685</td>\n",
       "      <td>0.012743</td>\n",
       "      <td>0.684971</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.705789</td>\n",
       "      <td>0.335796</td>\n",
       "      <td>-0.035509</td>\n",
       "      <td>-1.125292</td>\n",
       "      <td>-0.736826</td>\n",
       "      <td>-0.555163</td>\n",
       "      <td>-0.429600</td>\n",
       "      <td>0.542465</td>\n",
       "      <td>-0.675596</td>\n",
       "      <td>-0.164192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57907</td>\n",
       "      <td>2</td>\n",
       "      <td>1.158887</td>\n",
       "      <td>1.763834</td>\n",
       "      <td>2.302950</td>\n",
       "      <td>-0.603070</td>\n",
       "      <td>1.965467</td>\n",
       "      <td>1.635436</td>\n",
       "      <td>-1.573040</td>\n",
       "      <td>-1.285986</td>\n",
       "      <td>-1.540435</td>\n",
       "      <td>...</td>\n",
       "      <td>0.994608</td>\n",
       "      <td>1.001552</td>\n",
       "      <td>-0.818391</td>\n",
       "      <td>0.511212</td>\n",
       "      <td>0.249450</td>\n",
       "      <td>-0.184754</td>\n",
       "      <td>0.682368</td>\n",
       "      <td>1.253365</td>\n",
       "      <td>-0.766926</td>\n",
       "      <td>-0.781316</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 205 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       hate  statuses_count  followers_count  followees_count  \\\n",
       "10999     2        0.651057        -0.228440         0.539018   \n",
       "55317     2        0.527130         0.159289         0.603327   \n",
       "44622     2       -0.972049         0.513316         0.003403   \n",
       "71821     2        1.003596         1.295017         0.219550   \n",
       "57907     2        1.158887         1.763834         2.302950   \n",
       "\n",
       "       favorites_count  listed_count  negotiate_empath  vehicle_empath  \\\n",
       "10999         1.468664      0.319936          0.060148       -1.573040   \n",
       "55317         0.116831      0.400391         -0.170600        0.731748   \n",
       "44622         0.041867      0.682879          0.398669       -0.434141   \n",
       "71821         0.198376      1.810431         -0.601582       -1.187685   \n",
       "57907        -0.603070      1.965467          1.635436       -1.573040   \n",
       "\n",
       "       science_empath  timidity_empath  ...  number hashtags  tweet number  \\\n",
       "10999        0.468232        -0.446347  ...        -0.347727     -0.087181   \n",
       "55317       -0.155481         0.487008  ...        -0.159648      0.863400   \n",
       "44622       -0.439622         0.134869  ...         1.059839     -0.068104   \n",
       "71821        0.012743         0.684971  ...        -1.705789      0.335796   \n",
       "57907       -1.285986        -1.540435  ...         0.994608      1.001552   \n",
       "\n",
       "       retweet number  quote number  status length  number urls   baddies  \\\n",
       "10999        0.355153      1.193070       0.010627     0.314380  0.581937   \n",
       "55317       -0.628442      1.058797      -0.400813    -0.034034 -0.023220   \n",
       "44622        0.338591     -0.254387       1.066497     1.200203  0.243681   \n",
       "71821       -0.035509     -1.125292      -0.736826    -0.555163 -0.429600   \n",
       "57907       -0.818391      0.511212       0.249450    -0.184754  0.682368   \n",
       "\n",
       "       mentions  time_diff  time_diff_median  \n",
       "10999  0.017239  -0.772738         -0.713314  \n",
       "55317  0.088925   0.209697          0.501357  \n",
       "44622  0.661312   1.318291          1.403518  \n",
       "71821  0.542465  -0.675596         -0.164192  \n",
       "57907  1.253365  -0.766926         -0.781316  \n",
       "\n",
       "[5 rows x 205 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node_data = node_data.loc[list(g_nx.nodes())]\n",
    "node_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "### Splitting the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28",
   "metadata": {},
   "source": [
    "For machine learning we want to take a subset of the nodes for training, and use the rest for validation and testing. We'll use scikit-learn again to split our data into training and test sets.\n",
    "\n",
    "The total number of annotated nodes is very small when compared to the total number of nodes in the graph. We are only going to use 15% of the annotated nodes for training and the remaining 85% of nodes for testing.\n",
    "\n",
    "First, we are going to select the subset of nodes that are annotated as hateful or normal. These will be the nodes that have 'hate' values that are either 0 or 1."
   ]
  },
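  {
   "cell_type": "markdown",
   "id": "28a",
   "metadata": {},
   "source": [
    "With `test_size=0.85`, scikit-learn's `train_test_split` rounds the test-set size up and assigns the remaining samples to training. The sketch below checks this on a placeholder array with one row per annotated user: 4427 normal + 544 hateful = 4971, using the counts reported below. The all-zeros array is an assumption that stands in for the real features.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "n_annotated = 4427 + 544  # normal + hateful users\n",
    "test_size = 0.85\n",
    "\n",
    "# Expected sizes: the test set size is rounded up, training gets the remainder\n",
    "n_test = math.ceil(test_size * n_annotated)\n",
    "n_train = n_annotated - n_test\n",
    "\n",
    "# Placeholder data with one row per annotated user\n",
    "X = np.zeros((n_annotated, 1))\n",
    "X_train, X_test = train_test_split(X, test_size=test_size, random_state=101)\n",
    "print(n_train, n_test, X_train.shape, X_test.shape)  # 745 4226 (745, 1) (4226, 1)\n",
    "```"
   ]
  },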
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose the nodes annotated with normal or hateful classes\n",
    "annotated_users = node_data[node_data[\"hate\"] != 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "annotated_user_features = annotated_users.drop(columns=[\"hate\"])\n",
    "annotated_user_targets = annotated_users[[\"hate\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "There are 4971 annoted nodes out of a possible, approximately, 100k nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    4427\n",
      "1     544\n",
      "Name: hate, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(annotated_user_targets.hate.value_counts())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sizes and class distributions for train/test data\n",
      "Shape train_data (745, 204)\n",
      "Shape test_data (4226, 204)\n",
      "Train data number of 0s 667 and 1s 78\n",
      "Test data number of 0s 3760 and 1s 466\n"
     ]
    }
   ],
   "source": [
    "# split the data\n",
    "train_data, test_data, train_targets, test_targets = train_test_split(\n",
    "    annotated_user_features, annotated_user_targets, test_size=0.85, random_state=101\n",
    ")\n",
    "train_targets = train_targets.values\n",
    "test_targets = test_targets.values\n",
    "print(\"Sizes and class distributions for train/test data\")\n",
    "print(\"Shape train_data {}\".format(train_data.shape))\n",
    "print(\"Shape test_data {}\".format(test_data.shape))\n",
    "print(\n",
    "    \"Train data number of 0s {} and 1s {}\".format(\n",
    "        np.sum(train_targets == 0), np.sum(train_targets == 1)\n",
    "    )\n",
    ")\n",
    "print(\n",
    "    \"Test data number of 0s {} and 1s {}\".format(\n",
    "        np.sum(test_targets == 0), np.sum(test_targets == 1)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((745, 1), (4226, 1))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_targets.shape, test_targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((745, 204), (4226, 204))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.shape, test_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36",
   "metadata": {},
   "source": [
    "We are going to use 745 nodes for training and 4226 nodes for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# choosing features to assign to a graph, excluding target variable\n",
    "node_features = node_data.drop(columns=[\"hate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38",
   "metadata": {},
   "source": [
    "### Dealing with imbalanced data\n",
    "\n",
    "Because the training data exhibit high imbalance, we introduce class weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 0.5584707646176912, 1: 4.7756410256410255}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "class_weights = compute_class_weight(\n",
    "    \"balanced\", np.unique(train_targets), train_targets[:, 0]\n",
    ")\n",
    "train_class_weights = dict(zip(np.unique(train_targets), class_weights))\n",
    "train_class_weights"
   ]
  },
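  {
   "cell_type": "markdown",
   "id": "39a",
   "metadata": {},
   "source": [
    "The `\"balanced\"` heuristic gives each class the weight `n_samples / (n_classes * n_samples_in_class)`. A quick check using the training split counts printed earlier (667 normal, 78 hateful) reproduces the weights above:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "# Training labels with the class counts reported above: 667 normal (0), 78 hateful (1)\n",
    "y_train = np.array([0] * 667 + [1] * 78)\n",
    "classes = np.unique(y_train)\n",
    "\n",
    "# Manual \"balanced\" weights: n_samples / (n_classes * count per class)\n",
    "manual = len(y_train) / (len(classes) * np.bincount(y_train))\n",
    "auto = compute_class_weight(class_weight=\"balanced\", classes=classes, y=y_train)\n",
    "print(manual)\n",
    "print(np.allclose(manual, auto))  # True\n",
    "```"
   ]
  },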
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Our data is now ready for machine learning.\n",
    "\n",
    "Node features are stored in the Pandas DataFrame `node_features`.\n",
    "\n",
    "The graph in networkx format is stored in the variable `g_nx`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41",
   "metadata": {},
   "source": [
    "### Specify global parameters\n",
    "\n",
    "Here we specify some parameters that control the type of model we are going to use. For example, we specify the base model type, e.g., GCN, GraphSAGE, etc, as well as model-specific parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "42",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43",
   "metadata": {},
   "source": [
    "## Creating the base graph machine learning model in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44",
   "metadata": {},
   "source": [
    "Now create a `StellarGraph` object from the `NetworkX` graph and the node features and targets. It is `StellarGraph` objects that we use in this library to perform machine learning tasks on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "45",
   "metadata": {},
   "outputs": [],
   "source": [
    "G = sg.StellarGraph.from_networkx(g_nx, node_features=node_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46",
   "metadata": {},
   "source": [
    "To feed data from the graph to the Keras model we need a generator. The generators are specialized to the model and the learning task. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47",
   "metadata": {},
   "source": [
    "For training we map only the training nodes returned from our splitter and the target values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GCN (local pooling) filters...\n"
     ]
    }
   ],
   "source": [
    "generator = FullBatchNodeGenerator(G, method=\"gcn\", sparse=True)\n",
    "train_gen = generator.flow(train_data.index, train_targets)"
   ]
  },
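  {
   "cell_type": "markdown",
   "id": "48a",
   "metadata": {},
   "source": [
    "As we understand it, `method=\"gcn\"` pre-processes the adjacency matrix into the filter of [2]. That filter is the symmetrically normalized adjacency with self-loops, `D^-1/2 (A + I) D^-1/2`, where `D` is the diagonal degree matrix of `A + I`. Treat the link to the \"local pooling\" message as an assumption. Below is a minimal SciPy sketch on a hypothetical three-node path graph:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import scipy.sparse as sps\n",
    "\n",
    "# Hypothetical toy graph: undirected path 0 - 1 - 2\n",
    "A = sps.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float))\n",
    "\n",
    "# Add self-loops, then normalize symmetrically by the degrees of A + I\n",
    "A_hat = A + sps.identity(A.shape[0], format=\"csr\")\n",
    "deg = np.asarray(A_hat.sum(axis=1)).ravel()\n",
    "D_inv_sqrt = sps.diags(1.0 / np.sqrt(deg))\n",
    "A_norm = (D_inv_sqrt @ A_hat @ D_inv_sqrt).tocsr()\n",
    "\n",
    "print(A_norm.toarray().round(3))\n",
    "```\n",
    "\n",
    "Keeping the result sparse (`sparse=True`) matters here. With 100386 nodes, a dense 100386 × 100386 float32 matrix would need roughly 40 GB."
   ]
  },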
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "49",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = GCN(\n",
    "    layer_sizes=[32, 16],\n",
    "    generator=generator,\n",
    "    bias=True,\n",
    "    dropout=0.5,\n",
    "    activations=[\"elu\", \"elu\"],\n",
    ")\n",
    "x_inp, x_out = base_model.in_out_tensors()\n",
    "prediction = layers.Dense(units=1, activation=\"sigmoid\")(x_out)"
   ]
  },
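  {
   "cell_type": "markdown",
   "id": "49a",
   "metadata": {},
   "source": [
    "Each graph convolution layer has a weight matrix of shape `(input_dim, output_dim)`. With `bias=True`, it also has a bias vector of length `output_dim`; dropout layers add no parameters. The sketch below computes the resulting parameter counts for the 204 input features, `layer_sizes=[32, 16]`, and the one-unit dense output layer:\n",
    "\n",
    "```python\n",
    "def layer_params(n_in, n_out, bias=True):\n",
    "    \"\"\"Number of trainable parameters in a dense-style layer.\"\"\"\n",
    "    return n_in * n_out + (n_out if bias else 0)\n",
    "\n",
    "\n",
    "n_features = 204\n",
    "dims = [n_features, 32, 16, 1]  # input, GCN layer sizes, dense sigmoid output\n",
    "counts = [layer_params(a, b) for a, b in zip(dims[:-1], dims[1:])]\n",
    "print(counts)  # [6560, 528, 17]\n",
    "print(sum(counts))  # 7105\n",
    "```\n",
    "\n",
    "These per-layer numbers can be compared with the `Param #` column of `model.summary()`."
   ]
  },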
  {
   "cell_type": "markdown",
   "id": "50",
   "metadata": {},
   "source": [
    "### Create a Keras model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51",
   "metadata": {},
   "source": [
    "Now let's create the actual Keras model with the graph inputs `x_inp` provided by the `base_model` and outputs being the predictions from the softmax layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "52",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(inputs=x_inp, outputs=prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53",
   "metadata": {},
   "source": [
    "We compile our Keras model to use the `Adam` optimiser and the binary cross entropy loss."
   ]
  },
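  {
   "cell_type": "markdown",
   "id": "53a",
   "metadata": {},
   "source": [
    "For a sigmoid output `p` and a label `y` in {0, 1}, binary cross-entropy is `-(y log p + (1 - y) log(1 - p))`, averaged over the training nodes. The NumPy sketch below uses made-up predictions. The clipping constant `eps` is an assumption, chosen to resemble a typical Keras epsilon.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def binary_crossentropy(y_true, p_pred, eps=1e-7):\n",
    "    # Clip probabilities away from 0 and 1 to keep the logarithms finite\n",
    "    p = np.clip(p_pred, eps, 1 - eps)\n",
    "    return -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))\n",
    "\n",
    "\n",
    "y = np.array([1.0, 0.0, 1.0])  # made-up labels\n",
    "p = np.array([0.9, 0.2, 0.6])  # made-up predicted probabilities\n",
    "print(binary_crossentropy(y, p))  # about 0.2798\n",
    "```"
   ]
  },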
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "54",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=optimizers.Adam(lr=0.005), loss=losses.binary_crossentropy, metrics=[\"acc\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "55",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(1, 100386, 204)]   0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_3 (InputLayer)            [(1, None, 2)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_4 (InputLayer)            [(1, None)]          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dropout (Dropout)               (1, 100386, 204)     0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "input_2 (InputLayer)            [(1, None)]          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "squeezed_sparse_conversion (Squ (100386, 100386)     0           input_3[0][0]                    \n",
      "                                                                 input_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "graph_convolution (GraphConvolu (1, 100386, 32)      6560        dropout[0][0]                    \n",
      "                                                                 input_2[0][0]                    \n",
      "                                                                 squeezed_sparse_conversion[0][0] \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (1, 100386, 32)      0           graph_convolution[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "graph_convolution_1 (GraphConvo (1, None, 16)        528         dropout_1[0][0]                  \n",
      "                                                                 input_2[0][0]                    \n",
      "                                                                 squeezed_sparse_conversion[0][0] \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (1, None, 1)         17          graph_convolution_1[0][0]        \n",
      "==================================================================================================\n",
      "Total params: 7,105\n",
      "Trainable params: 7,105\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56",
   "metadata": {},
   "source": [
    "Train the model, tracking its loss and accuracy on the training set and, after each epoch, on the test set. The test set is passed as `validation_data` only for monitoring: it does not influence the weight updates and is used solely to measure the trained model's generalization performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "57",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "1/1 - 2s - loss: 0.7628 - acc: 0.4430 - val_loss: 0.6149 - val_acc: 0.7274\n",
      "Epoch 2/20\n",
      "1/1 - 2s - loss: 0.6236 - acc: 0.7074 - val_loss: 0.5340 - val_acc: 0.8173\n",
      "Epoch 3/20\n",
      "1/1 - 2s - loss: 0.5439 - acc: 0.7866 - val_loss: 0.4680 - val_acc: 0.8504\n",
      "Epoch 4/20\n",
      "1/1 - 2s - loss: 0.4800 - acc: 0.8309 - val_loss: 0.4154 - val_acc: 0.8623\n",
      "Epoch 5/20\n",
      "1/1 - 2s - loss: 0.4137 - acc: 0.8510 - val_loss: 0.3731 - val_acc: 0.8857\n",
      "Epoch 6/20\n",
      "1/1 - 2s - loss: 0.3704 - acc: 0.8738 - val_loss: 0.3377 - val_acc: 0.9035\n",
      "Epoch 7/20\n",
      "1/1 - 2s - loss: 0.3347 - acc: 0.9020 - val_loss: 0.3077 - val_acc: 0.9113\n",
      "Epoch 8/20\n",
      "1/1 - 2s - loss: 0.2946 - acc: 0.9195 - val_loss: 0.2834 - val_acc: 0.9139\n",
      "Epoch 9/20\n",
      "1/1 - 2s - loss: 0.2637 - acc: 0.9221 - val_loss: 0.2654 - val_acc: 0.9169\n",
      "Epoch 10/20\n",
      "1/1 - 2s - loss: 0.2394 - acc: 0.9221 - val_loss: 0.2523 - val_acc: 0.9188\n",
      "Epoch 11/20\n",
      "1/1 - 2s - loss: 0.2333 - acc: 0.9154 - val_loss: 0.2424 - val_acc: 0.9200\n",
      "Epoch 12/20\n",
      "1/1 - 2s - loss: 0.2169 - acc: 0.9262 - val_loss: 0.2352 - val_acc: 0.9214\n",
      "Epoch 13/20\n",
      "1/1 - 2s - loss: 0.2034 - acc: 0.9329 - val_loss: 0.2311 - val_acc: 0.9210\n",
      "Epoch 14/20\n",
      "1/1 - 2s - loss: 0.1907 - acc: 0.9315 - val_loss: 0.2299 - val_acc: 0.9210\n",
      "Epoch 15/20\n",
      "1/1 - 2s - loss: 0.1882 - acc: 0.9329 - val_loss: 0.2305 - val_acc: 0.9205\n",
      "Epoch 16/20\n",
      "1/1 - 2s - loss: 0.1867 - acc: 0.9289 - val_loss: 0.2315 - val_acc: 0.9198\n",
      "Epoch 17/20\n",
      "1/1 - 2s - loss: 0.1746 - acc: 0.9450 - val_loss: 0.2321 - val_acc: 0.9212\n",
      "Epoch 18/20\n",
      "1/1 - 2s - loss: 0.1785 - acc: 0.9289 - val_loss: 0.2333 - val_acc: 0.9224\n",
      "Epoch 19/20\n",
      "1/1 - 2s - loss: 0.1716 - acc: 0.9329 - val_loss: 0.2347 - val_acc: 0.9238\n",
      "Epoch 20/20\n",
      "1/1 - 2s - loss: 0.1691 - acc: 0.9383 - val_loss: 0.2366 - val_acc: 0.9245\n"
     ]
    }
   ],
   "source": [
    "test_gen = generator.flow(test_data.index, test_targets)\n",
    "history = model.fit(\n",
    "    train_gen,\n",
    "    epochs=epochs,\n",
    "    validation_data=test_gen,\n",
    "    verbose=2,\n",
    "    shuffle=False,\n",
    "    class_weight=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58",
   "metadata": {},
   "source": [
    "### Model Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59",
   "metadata": {},
   "source": [
    "Now that we have trained the model, let's evaluate it on the test set.\n",
    "\n",
    "We report two metrics calculated on the test set: accuracy and the area under the ROC curve (AU-ROC)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60",
   "metadata": {},
   "source": [
    "#### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Metrics:\n",
      "\tloss: 0.2366\n",
      "\tacc: 0.9245\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_gen)\n",
    "print(\"\\nTest Set Metrics:\")\n",
    "for name, val in zip(model.metrics_names, test_metrics):\n",
    "    print(\"\\t{}: {:0.4f}\".format(name, val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "62",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes = node_data.index\n",
    "all_gen = generator.flow(all_nodes)\n",
    "all_predictions = model.predict(all_gen).squeeze()[..., np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100386, 1)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_predictions.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "64",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_predictions_df = pd.DataFrame(all_predictions, index=node_data.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65",
   "metadata": {},
   "source": [
    "Let's extract the predictions for the test data only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "66",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_preds = all_predictions_df.loc[test_data.index, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4226, 1)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_preds.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68",
   "metadata": {},
   "source": [
    "The predictions are the probability of the positive class, which here is the probability that a user is hateful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The AUC on test set:\n",
      "\n",
      "0.8730678134416948\n"
     ]
    }
   ],
   "source": [
    "test_predictions = test_preds.values\n",
    "test_predictions_class = ((test_predictions > 0.5) * 1).flatten()\n",
    "test_df = pd.DataFrame(\n",
    "    {\n",
    "        \"Predicted_score\": test_predictions.flatten(),\n",
    "        \"Predicted_class\": test_predictions_class,\n",
    "        \"True\": test_targets[:, 0],\n",
    "    }\n",
    ")\n",
    "roc_auc = metrics.roc_auc_score(test_df[\"True\"].values, test_df[\"Predicted_score\"].values)\n",
    "print(\"The AUC on test set:\\n\")\n",
    "print(roc_auc)"
   ]
  },
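  {
   "cell_type": "markdown",
   "id": "sketch-metrics",
   "metadata": {},
   "source": [
    "The AU-ROC can be read as the probability that a randomly chosen hateful user receives a higher predicted score than a randomly chosen non-hateful user. The following sketch illustrates this pairwise formulation, together with the 0.5 threshold used above and a confusion matrix. It uses plain NumPy on hypothetical toy scores and is for intuition only; `metrics.roc_auc_score` remains the function used in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy labels (1 = hateful) and predicted probabilities\n",
    "y_true = np.array([0, 0, 1, 1])\n",
    "y_score = np.array([0.1, 0.4, 0.35, 0.8])\n",
    "\n",
    "# Threshold at 0.5, as in the cell above\n",
    "y_pred = (y_score > 0.5).astype(int)\n",
    "accuracy = np.mean(y_pred == y_true)\n",
    "\n",
    "# Confusion matrix: rows are true classes, columns are predicted classes\n",
    "confusion = np.zeros((2, 2), dtype=int)\n",
    "for t, p in zip(y_true, y_pred):\n",
    "    confusion[t, p] += 1\n",
    "\n",
    "# AU-ROC as the fraction of (positive, negative) pairs ranked correctly,\n",
    "# counting ties as half a correct pair\n",
    "pos = y_score[y_true == 1]\n",
    "neg = y_score[y_true == 0]\n",
    "diff = pos[:, None] - neg[None, :]\n",
    "auc = np.mean((diff > 0) + 0.5 * (diff == 0))\n",
    "\n",
    "print(accuracy, auc)\n",
    "print(confusion)\n",
    "```"
   ]
  },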
  {
   "cell_type": "markdown",
   "id": "70",
   "metadata": {},
   "source": [
    "## Interpretability by Saliency Maps\n",
    "\n",
    "To understand which node features and edges the model relies on when making a prediction, we use the saliency-map interpretability tools in the StellarGraph library (here, `IntegratedGradients`) to measure the importance of node features and edges for the prediction on a target user."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "71",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph.interpretability.saliency_maps import IntegratedGradients\n",
    "\n",
    "int_saliency = IntegratedGradients(model, all_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "target_idx = 36367, target_nid = 77692\n",
      "prediction score for node 36367 is [0.9576855]\n",
      "ground truth score for node 36367 is [1]\n"
     ]
    }
   ],
   "source": [
    "# we first select a list of nodes which are confidently classified as hateful.\n",
    "predicted_hateful_index = set(np.where(all_predictions > 0.9)[0].tolist())\n",
    "test_indices_set = set([int(k) for k in test_data.index.tolist()])\n",
    "hateful_in_test = list(predicted_hateful_index.intersection(test_indices_set))\n",
    "\n",
    "# let's pick one node from the predicted hateful users as an example.\n",
    "idx = 2\n",
    "target_idx = hateful_in_test[idx]\n",
    "target_nid = list(G.nodes())[target_idx]\n",
    "print(\"target_idx = {}, target_nid = {}\".format(target_idx, target_nid))\n",
    "print(\n",
    "    \"prediction score for node {} is {}\".format(target_idx, all_predictions[target_idx])\n",
    ")\n",
    "print(\n",
    "    \"ground truth score for node {} is {}\".format(\n",
    "        target_idx, test_targets[test_data.index.tolist().index(str(target_nid))]\n",
    "    )\n",
    ")\n",
    "[X, all_targets, A_index, A], y_true_all = all_gen[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73",
   "metadata": {},
   "source": [
    "For the prediction on the target node, we then calculate the importance of every feature of every node in the graph. Because StellarGraph's saliency maps support sparse tensors, this computation remains efficient at the scale of this dataset (100,386 nodes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "74",
   "metadata": {},
   "outputs": [],
   "source": [
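    "# Integrated-gradient feature masks for the prediction on the target node.\n",
    "# The second argument selects the output to explain; this model has a single output (index 0).\n",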
    "node_feature_importance = int_saliency.get_integrated_node_masks(target_idx, 0)"
   ]
  },
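  {
   "cell_type": "markdown",
   "id": "sketch-integrated-gradients",
   "metadata": {},
   "source": [
    "Integrated gradients attribute a prediction to the input features. They do this by averaging the gradient of the output along a straight-line path from a baseline input to the actual input, and then multiplying by the difference between the input and the baseline. The following minimal NumPy sketch illustrates the idea on a hypothetical toy logistic model with an analytic gradient. It is not StellarGraph's implementation, and the all-zero baseline, the midpoint rule, and the 50 steps are all assumptions. A useful sanity check is the completeness property: the attributions should sum to approximately f(input) - f(baseline).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "\n",
    "def model_fn(x, w):\n",
    "    # Toy stand-in for the GCN: a logistic model f(x) = sigmoid(w . x)\n",
    "    return sigmoid(np.dot(w, x))\n",
    "\n",
    "\n",
    "def model_grad(x, w):\n",
    "    # Analytic gradient of f with respect to the input features x\n",
    "    s = sigmoid(np.dot(w, x))\n",
    "    return s * (1.0 - s) * w\n",
    "\n",
    "\n",
    "def integrated_gradients(x, baseline, w, steps=50):\n",
    "    # Midpoint Riemann sum of the gradients along the straight path\n",
    "    # from the baseline to the input\n",
    "    alphas = (np.arange(steps) + 0.5) / steps\n",
    "    grads = [model_grad(baseline + a * (x - baseline), w) for a in alphas]\n",
    "    return (x - baseline) * np.mean(grads, axis=0)\n",
    "\n",
    "\n",
    "w = np.array([0.5, -1.0, 2.0])\n",
    "x = np.array([1.0, 0.5, 0.25])\n",
    "baseline = np.zeros_like(x)\n",
    "attributions = integrated_gradients(x, baseline, w)\n",
    "# Completeness: the attributions sum to roughly f(x) - f(baseline)\n",
    "print(attributions.sum(), model_fn(x, w) - model_fn(baseline, w))\n",
    "```"
   ]
  },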
  {
   "cell_type": "markdown",
   "id": "75",
   "metadata": {},
   "source": [
    "Since `node_feature_importance` is a matrix in which `node_feature_importance[i][j]` indicates the importance of the j-th feature of node i to the prediction on the target node, we sum each node's feature importances to obtain a single importance score per node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.41207344  0.21552302  0.20715151 ... -0.0263247  -0.03053751\n",
      " -0.03189572]\n",
      "node_importance has 12721 non-zero values\n"
     ]
    }
   ],
   "source": [
    "node_importance = np.sum(node_feature_importance, axis=-1)\n",
    "node_importance_rank = np.argsort(node_importance)[::-1]\n",
    "print(node_importance[node_importance_rank])\n",
    "print(\n",
    "    \"node_importance has {} non-zero values\".format(\n",
    "        np.where(node_importance != 0)[0].shape[0]\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77",
   "metadata": {},
   "source": [
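    "The model has two graph convolution layers, so only nodes within two hops of the target can influence its prediction. We therefore expect the number of non-zero values of `node_importance` to match the number of nodes in the target's 2-hop ego graph, which includes the target itself."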
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The ego graph of the target node has 12721 nodes\n"
     ]
    }
   ],
   "source": [
    "G_ego = nx.ego_graph(g_nx, target_nid, radius=2)\n",
    "print(\"The ego graph of the target node has {} nodes\".format(len(G_ego.nodes())))"
   ]
  },
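  {
   "cell_type": "markdown",
   "id": "sketch-receptive-field",
   "metadata": {},
   "source": [
    "The expectation above follows from the receptive field of a two-layer graph convolutional network. Each layer mixes a node's representation with those of its direct neighbours, so after two layers the output for a node depends only on nodes within two hops. A minimal NumPy sketch on a hypothetical 5-node path graph shows this through the sparsity pattern of (A + I)^2. Degree normalization and nonlinearities are omitted because they do not enlarge the set of nodes that an output depends on.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 5-node path graph 0-1-2-3-4 as a dense adjacency matrix\n",
    "n = 5\n",
    "adj = np.zeros((n, n))\n",
    "for i in range(n - 1):\n",
    "    adj[i, i + 1] = 1.0\n",
    "    adj[i + 1, i] = 1.0\n",
    "\n",
    "# One propagation step with self-loops is A + I; two layers give (A + I)^2\n",
    "a_hat = adj + np.eye(n)\n",
    "two_layer = a_hat @ a_hat\n",
    "\n",
    "# Non-zero entries in the target's row mark the nodes that can influence it\n",
    "target = 0\n",
    "receptive_field = np.nonzero(two_layer[target])[0]\n",
    "print(receptive_field)  # nodes 0, 1 and 2: node 0 plus everything within 2 hops\n",
    "```"
   ]
  },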
  {
   "cell_type": "markdown",
   "id": "79",
   "metadata": {},
   "source": [
    "We then analyze the feature importance of the 250 most important nodes; the output below shows the first 5. In each row, the first column is the node index and the remaining columns list that node's features in decreasing order of importance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>195</th>\n",
       "      <th>196</th>\n",
       "      <th>197</th>\n",
       "      <th>198</th>\n",
       "      <th>199</th>\n",
       "      <th>200</th>\n",
       "      <th>201</th>\n",
       "      <th>202</th>\n",
       "      <th>203</th>\n",
       "      <th>204</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>36367</td>\n",
       "      <td>fear_empath</td>\n",
       "      <td>torment_empath</td>\n",
       "      <td>favorites_count</td>\n",
       "      <td>legend_empath</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>computer_empath</td>\n",
       "      <td>furniture_empath</td>\n",
       "      <td>royalty_empath</td>\n",
       "      <td>office_empath</td>\n",
       "      <td>...</td>\n",
       "      <td>listen_empath</td>\n",
       "      <td>medieval_empath</td>\n",
       "      <td>divine_empath</td>\n",
       "      <td>confusion_empath</td>\n",
       "      <td>statuses_count</td>\n",
       "      <td>gain_empath</td>\n",
       "      <td>wealthy_empath</td>\n",
       "      <td>family_empath</td>\n",
       "      <td>tweet number</td>\n",
       "      <td>retweet number</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>34534</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>fear_empath</td>\n",
       "      <td>favorites_count</td>\n",
       "      <td>ridicule_empath</td>\n",
       "      <td>pet_empath</td>\n",
       "      <td>computer_empath</td>\n",
       "      <td>children_empath</td>\n",
       "      <td>shame_empath</td>\n",
       "      <td>royalty_empath</td>\n",
       "      <td>...</td>\n",
       "      <td>hearing_empath</td>\n",
       "      <td>weakness_empath</td>\n",
       "      <td>masculine_empath</td>\n",
       "      <td>home_empath</td>\n",
       "      <td>gain_empath</td>\n",
       "      <td>positive_emotion_empath</td>\n",
       "      <td>competing_empath</td>\n",
       "      <td>statuses_count</td>\n",
       "      <td>tweet number</td>\n",
       "      <td>retweet number</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>69814</td>\n",
       "      <td>ridicule_empath</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>joy_empath</td>\n",
       "      <td>anger_empath</td>\n",
       "      <td>money_empath</td>\n",
       "      <td>retweet number</td>\n",
       "      <td>listen_empath</td>\n",
       "      <td>tweet number</td>\n",
       "      <td>negotiate_empath</td>\n",
       "      <td>...</td>\n",
       "      <td>contentment_empath</td>\n",
       "      <td>divine_empath</td>\n",
       "      <td>deception_empath</td>\n",
       "      <td>fear_empath</td>\n",
       "      <td>dominant_heirarchical_empath</td>\n",
       "      <td>furniture_empath</td>\n",
       "      <td>subjectivity</td>\n",
       "      <td>white_collar_job_empath</td>\n",
       "      <td>health_empath</td>\n",
       "      <td>giving_empath</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>69254</td>\n",
       "      <td>ridicule_empath</td>\n",
       "      <td>legend_empath</td>\n",
       "      <td>hipster_empath</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>science_empath</td>\n",
       "      <td>fun_empath</td>\n",
       "      <td>farming_empath</td>\n",
       "      <td>furniture_empath</td>\n",
       "      <td>favorites_count</td>\n",
       "      <td>...</td>\n",
       "      <td>status length</td>\n",
       "      <td>kill_empath</td>\n",
       "      <td>clothing_empath</td>\n",
       "      <td>weakness_empath</td>\n",
       "      <td>breaking_empath</td>\n",
       "      <td>retweet number</td>\n",
       "      <td>giving_empath</td>\n",
       "      <td>pet_empath</td>\n",
       "      <td>gain_empath</td>\n",
       "      <td>quote number</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>71135</td>\n",
       "      <td>royalty_empath</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>statuses_count</td>\n",
       "      <td>leader_empath</td>\n",
       "      <td>weakness_empath</td>\n",
       "      <td>legend_empath</td>\n",
       "      <td>joy_empath</td>\n",
       "      <td>children_empath</td>\n",
       "      <td>ridicule_empath</td>\n",
       "      <td>...</td>\n",
       "      <td>divine_empath</td>\n",
       "      <td>religion_empath</td>\n",
       "      <td>healing_empath</td>\n",
       "      <td>poor_empath</td>\n",
       "      <td>sexual_empath</td>\n",
       "      <td>deception_empath</td>\n",
       "      <td>help_empath</td>\n",
       "      <td>farming_empath</td>\n",
       "      <td>tweet number</td>\n",
       "      <td>favorites_count</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 205 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     0                1               2                3                4    \\\n",
       "0  36367      fear_empath  torment_empath  favorites_count    legend_empath   \n",
       "1  34534      rage_empath     fear_empath  favorites_count  ridicule_empath   \n",
       "2  69814  ridicule_empath     rage_empath       joy_empath     anger_empath   \n",
       "3  69254  ridicule_empath   legend_empath   hipster_empath      rage_empath   \n",
       "4  71135   royalty_empath     rage_empath   statuses_count    leader_empath   \n",
       "\n",
       "               5                6                 7                 8    \\\n",
       "0      rage_empath  computer_empath  furniture_empath    royalty_empath   \n",
       "1       pet_empath  computer_empath   children_empath      shame_empath   \n",
       "2     money_empath   retweet number     listen_empath      tweet number   \n",
       "3   science_empath       fun_empath    farming_empath  furniture_empath   \n",
       "4  weakness_empath    legend_empath        joy_empath   children_empath   \n",
       "\n",
       "                9    ...                 195              196  \\\n",
       "0     office_empath  ...       listen_empath  medieval_empath   \n",
       "1    royalty_empath  ...      hearing_empath  weakness_empath   \n",
       "2  negotiate_empath  ...  contentment_empath    divine_empath   \n",
       "3   favorites_count  ...       status length      kill_empath   \n",
       "4   ridicule_empath  ...       divine_empath  religion_empath   \n",
       "\n",
       "                197               198                           199  \\\n",
       "0     divine_empath  confusion_empath                statuses_count   \n",
       "1  masculine_empath       home_empath                   gain_empath   \n",
       "2  deception_empath       fear_empath  dominant_heirarchical_empath   \n",
       "3   clothing_empath   weakness_empath               breaking_empath   \n",
       "4    healing_empath       poor_empath                 sexual_empath   \n",
       "\n",
       "                       200               201                      202  \\\n",
       "0              gain_empath    wealthy_empath            family_empath   \n",
       "1  positive_emotion_empath  competing_empath           statuses_count   \n",
       "2         furniture_empath      subjectivity  white_collar_job_empath   \n",
       "3           retweet number     giving_empath               pet_empath   \n",
       "4         deception_empath       help_empath           farming_empath   \n",
       "\n",
       "             203              204  \n",
       "0   tweet number   retweet number  \n",
       "1   tweet number   retweet number  \n",
       "2  health_empath    giving_empath  \n",
       "3    gain_empath     quote number  \n",
       "4   tweet number  favorites_count  \n",
       "\n",
       "[5 rows x 205 columns]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_names = annotated_users.keys()[1:].values\n",
    "feature_importance_rank = np.argsort(node_feature_importance[target_idx])[::-1]\n",
    "df = pd.DataFrame(\n",
    "    [\n",
    "        ([k] + list(feature_names[np.argsort(node_feature_importance[k])[::-1]]))\n",
    "        for k in node_importance_rank[:250]\n",
    "    ],\n",
    "    columns=range(205),\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81",
   "metadata": {},
   "source": [
    "As a sanity check, we expect the target node itself to have a relatively high importance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4120734350877242\n",
      "The node itself is the 1-th important node\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>194</th>\n",
       "      <th>195</th>\n",
       "      <th>196</th>\n",
       "      <th>197</th>\n",
       "      <th>198</th>\n",
       "      <th>199</th>\n",
       "      <th>200</th>\n",
       "      <th>201</th>\n",
       "      <th>202</th>\n",
       "      <th>203</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>fear_empath</td>\n",
       "      <td>torment_empath</td>\n",
       "      <td>favorites_count</td>\n",
       "      <td>legend_empath</td>\n",
       "      <td>rage_empath</td>\n",
       "      <td>computer_empath</td>\n",
       "      <td>furniture_empath</td>\n",
       "      <td>royalty_empath</td>\n",
       "      <td>office_empath</td>\n",
       "      <td>science_empath</td>\n",
       "      <td>...</td>\n",
       "      <td>listen_empath</td>\n",
       "      <td>medieval_empath</td>\n",
       "      <td>divine_empath</td>\n",
       "      <td>confusion_empath</td>\n",
       "      <td>statuses_count</td>\n",
       "      <td>gain_empath</td>\n",
       "      <td>wealthy_empath</td>\n",
       "      <td>family_empath</td>\n",
       "      <td>tweet number</td>\n",
       "      <td>retweet number</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1 rows × 204 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           0               1                2              3            4    \\\n",
       "0  fear_empath  torment_empath  favorites_count  legend_empath  rage_empath   \n",
       "\n",
       "               5                 6               7              8    \\\n",
       "0  computer_empath  furniture_empath  royalty_empath  office_empath   \n",
       "\n",
       "              9    ...            194              195            196  \\\n",
       "0  science_empath  ...  listen_empath  medieval_empath  divine_empath   \n",
       "\n",
       "                197             198          199             200  \\\n",
       "0  confusion_empath  statuses_count  gain_empath  wealthy_empath   \n",
       "\n",
       "             201           202             203  \n",
       "0  family_empath  tweet number  retweet number  \n",
       "\n",
       "[1 rows x 204 columns]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "self_feature_importance_rank = np.argsort(node_feature_importance[target_idx])\n",
    "print(np.sum(node_feature_importance[target_idx]))\n",
    "print(\n",
    "    \"The node itself is the {}-th important node\".format(\n",
    "        1 + node_importance_rank.tolist().index(target_idx)\n",
    "    )\n",
    ")\n",
    "df = pd.DataFrame([feature_names[self_feature_importance_rank][::-1]], columns=range(204))\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83",
   "metadata": {},
   "source": [
    "The same feature can have different ranks for different nodes. To understand the overall importance of each feature, we compute its average importance rank across the 250 most important nodes selected above; a lower average rank indicates a more important feature."
   ]
  },
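  {
   "cell_type": "markdown",
   "id": "sketch-average-rank",
   "metadata": {},
   "source": [
    "The averaging described above can be illustrated with a minimal NumPy sketch on a hypothetical importance matrix of 3 nodes by 4 features. Each node's features are ranked from most to least important (rank 0 is the most important in this sketch), and the ranks are then averaged per feature. This illustrates the idea only; it is not the notebook's own code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical importance matrix: 3 nodes x 4 features\n",
    "importance = np.array([\n",
    "    [0.9, 0.1, 0.5, 0.0],\n",
    "    [0.8, 0.2, 0.1, 0.3],\n",
    "    [0.7, 0.0, 0.6, 0.4],\n",
    "])\n",
    "\n",
    "# Order each node's features from most to least important\n",
    "order = np.argsort(-importance, axis=1)\n",
    "# Convert the ordering into per-feature ranks (0 = most important)\n",
    "ranks = np.empty_like(order)\n",
    "rows = np.arange(importance.shape[0])[:, None]\n",
    "ranks[rows, order] = np.arange(importance.shape[1])\n",
    "\n",
    "# Average rank of each feature across nodes; lower means more important\n",
    "avg_rank = ranks.mean(axis=0)\n",
    "print(avg_rank)\n",
    "```"
   ]
  },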
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ridicule_empath 11.096\n",
      "hipster_empath 23.612\n",
      "furniture_empath 27.564\n",
      "science_empath 28.28\n",
      "cleaning_empath 33.94\n",
      "legend_empath 34.536\n",
      "joy_empath 35.688\n",
      "anger_empath 36.344\n",
      "farming_empath 37.444\n",
      "children_empath 37.54\n",
      "listen_empath 42.396\n",
      "negotiate_empath 51.996\n",
      "listed_count 52.784\n",
      "vacation_empath 52.952\n",
      "rural_empath 53.42\n",
      "anonymity_empath 53.688\n",
      "traveling_empath 53.988\n",
      "time_diff 55.836\n",
      "speaking_empath 56.932\n",
      "money_empath 56.968\n",
      "noise_empath 57.368\n",
      "death_empath 59.676\n",
      "exasperation_empath 60.744\n",
      "hiking_empath 61.0\n",
      "violence_empath 61.692\n",
      "positive_emotion_empath 62.132\n",
      "social_media_empath 63.252\n",
      "leisure_empath 64.448\n",
      "royalty_empath 64.732\n",
      "pride_empath 65.788\n",
      "weakness_empath 66.6\n",
      "office_empath 68.584\n",
      "wedding_empath 69.616\n",
      "sympathy_empath 70.0\n",
      "subjectivity 70.424\n",
      "terrorism_empath 70.74\n",
      "kill_empath 71.808\n",
      "computer_empath 73.048\n",
      "number hashtags 73.092\n",
      "fight_empath 74.428\n",
      "favorites_count 74.444\n",
      "emotional_empath 75.48\n",
      "followers_count 75.584\n",
      "dispute_empath 75.74\n",
      "baddies 75.756\n",
      "shame_empath 76.06\n",
      "zest_empath 78.936\n",
      "rage_empath 78.996\n",
      "statuses_count 81.088\n",
      "swearing_terms_empath 81.52\n",
      "attractive_empath 83.192\n",
      "dominant_personality_empath 84.864\n",
      "fire_empath 84.964\n",
      "business_empath 85.288\n",
      "government_empath 85.732\n",
      "fear_empath 85.992\n",
      "love_empath 86.236\n",
      "musical_empath 86.564\n",
      "car_empath 86.772\n",
      "number urls 87.404\n",
      "play_empath 87.632\n",
      "white_collar_job_empath 88.052\n",
      "disgust_empath 88.116\n",
      "fun_empath 88.44\n",
      "payment_empath 88.768\n",
      "vehicle_empath 89.568\n",
      "childish_empath 90.0\n",
      "prison_empath 90.096\n",
      "writing_empath 91.224\n",
      "healing_empath 91.296\n",
      "torment_empath 91.848\n",
      "school_empath 92.644\n",
      "hygiene_empath 92.852\n",
      "achievement_empath 94.408\n",
      "ancient_empath 96.328\n",
      "medical_emergency_empath 96.336\n",
      "philosophy_empath 97.288\n",
      "crime_empath 97.416\n",
      "driving_empath 97.96\n",
      "pet_empath 97.964\n",
      "timidity_empath 98.2\n",
      "cold_empath 98.296\n",
      "warmth_empath 98.408\n",
      "power_empath 98.756\n",
      "leader_empath 98.832\n",
      "dance_empath 99.204\n",
      "liquid_empath 99.288\n",
      "communication_empath 99.496\n",
      "movement_empath 99.916\n",
      "monster_empath 102.284\n",
      "ocean_empath 102.676\n",
      "tweet number 103.136\n",
      "confusion_empath 103.38\n",
      "surprise_empath 103.496\n",
      "order_empath 103.508\n",
      "followees_count 103.968\n",
      "alcohol_empath 104.008\n",
      "worship_empath 104.636\n",
      "hearing_empath 104.8\n",
      "trust_empath 104.8\n",
      "affection_empath 105.088\n",
      "shape_and_size_empath 105.32\n",
      "urban_empath 105.448\n",
      "toy_empath 105.556\n",
      "suffering_empath 105.896\n",
      "college_empath 106.852\n",
      "beauty_empath 106.896\n",
      "real_estate_empath 108.068\n",
      "tool_empath 108.428\n",
      "health_empath 108.624\n",
      "appearance_empath 108.812\n",
      "stealing_empath 108.904\n",
      "technology_empath 109.016\n",
      "occupation_empath 109.512\n",
      "youth_empath 109.76\n",
      "time_diff_median 110.224\n",
      "eating_empath 110.26\n",
      "ugliness_empath 110.3\n",
      "banking_empath 110.56\n",
      "lust_empath 110.98\n",
      "competing_empath 111.256\n",
      "cooking_empath 111.484\n",
      "fashion_empath 111.792\n",
      "ship_empath 111.9\n",
      "journalism_empath 112.076\n",
      "negative_emotion_empath 112.676\n",
      "economics_empath 113.6\n",
      "water_empath 113.956\n",
      "retweet number 114.104\n",
      "sexual_empath 114.336\n",
      "night_empath 114.776\n",
      "smell_empath 115.676\n",
      "air_travel_empath 115.736\n",
      "party_empath 115.884\n",
      "war_empath 116.0\n",
      "law_empath 116.676\n",
      "tourism_empath 116.708\n",
      "body_empath 116.88\n",
      "status length 116.904\n",
      "reading_empath 117.008\n",
      "politeness_empath 117.156\n",
      "family_empath 117.42\n",
      "meeting_empath 117.504\n",
      "shopping_empath 117.708\n",
      "art_empath 117.904\n",
      "medieval_empath 118.684\n",
      "pain_empath 119.612\n",
      "feminine_empath 120.672\n",
      "wealthy_empath 121.316\n",
      "optimism_empath 121.552\n",
      "irritability_empath 121.664\n",
      "plant_empath 121.912\n",
      "independence_empath 122.204\n",
      "animal_empath 122.304\n",
      "horror_empath 122.38\n",
      "phone_empath 122.416\n",
      "nervousness_empath 122.712\n",
      "quote number 122.86\n",
      "anticipation_empath 124.308\n",
      "cheerfulness_empath 125.0\n",
      "politics_empath 125.672\n",
      "internet_empath 126.012\n",
      "military_empath 126.056\n",
      "sailing_empath 126.088\n",
      "celebration_empath 126.304\n",
      "giving_empath 127.656\n",
      "religion_empath 127.848\n",
      "weapon_empath 128.356\n",
      "sound_empath 129.372\n",
      "strength_empath 131.172\n",
      "restaurant_empath 133.104\n",
      "breaking_empath 133.148\n",
      "injury_empath 133.204\n",
      "sports_empath 134.016\n",
      "swimming_empath 134.496\n",
      "sadness_empath 134.884\n",
      "mentions 135.048\n",
      "gain_empath 135.116\n",
      "valuable_empath 137.04\n",
      "exercise_empath 137.804\n",
      "help_empath 137.828\n",
      "programming_empath 137.892\n",
      "hate_empath 138.292\n",
      "weather_empath 138.476\n",
      "friends_empath 138.752\n",
      "clothing_empath 139.02\n",
      "contentment_empath 139.568\n",
      "neglect_empath 139.652\n",
      "blue_collar_job_empath 141.772\n",
      "work_empath 143.168\n",
      "beach_empath 144.776\n",
      "disappointment_empath 144.944\n",
      "deception_empath 145.72\n",
      "messaging_empath 147.116\n",
      "sleep_empath 148.896\n",
      "poor_empath 148.932\n",
      "envy_empath 150.56\n",
      "dominant_heirarchical_empath 154.408\n",
      "morning_empath 156.448\n",
      "superhero_empath 158.392\n",
      "home_empath 159.552\n",
      "aggression_empath 160.376\n",
      "divine_empath 160.444\n",
      "masculine_empath 161.288\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# average each feature's rank (0 = most important) over the top_n most important nodes\n",
    "top_n = 250\n",
    "average_feature_rank = defaultdict(int)\n",
    "for i in node_importance_rank[:top_n]:\n",
    "    feature_rank = list(feature_names[np.argsort(node_feature_importance[i])[::-1]])\n",
    "    for rank, feat in enumerate(feature_rank):\n",
    "        average_feature_rank[feat] += rank\n",
    "for k in average_feature_rank.keys():\n",
    "    average_feature_rank[k] /= top_n\n",
    "# print features from most to least important on average\n",
    "sorted_avg_feature_rank = sorted(average_feature_rank.items(), key=lambda a: a[1])\n",
    "for feat, avg_rank in sorted_avg_feature_rank:\n",
    "    print(feat, avg_rank)"
   ]
  },
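  {
   "cell_type": "markdown",
   "id": "84a",
   "metadata": {},
   "source": [
    "The same average ranks can be computed without Python loops. Below is a minimal NumPy sketch on a small hypothetical matrix `importance` (2 nodes by 3 features) that stands in for the rows of `node_feature_importance` belonging to the top-ranked nodes. Applying `argsort` twice converts each row of scores into ranks (0 = most important), and averaging over rows gives each feature's average rank. Ties may be ordered differently from the `[::-1]` reversal used above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical importance scores: 2 nodes (rows) x 3 features (columns)\n",
    "importance = np.array([[0.9, 0.1, 0.5],\n",
    "                       [0.7, 0.8, 0.1]])\n",
    "names = np.array(['f_a', 'f_b', 'f_c'])\n",
    "\n",
    "# order[i] lists the feature indices of node i from most to least important\n",
    "order = np.argsort(-importance, axis=1)\n",
    "# ranks[i, j] is the rank of feature j for node i (0 = most important)\n",
    "ranks = np.argsort(order, axis=1)\n",
    "average_rank = ranks.mean(axis=0)\n",
    "\n",
    "for feat, avg in sorted(zip(names, average_rank), key=lambda p: p[1]):\n",
    "    print(feat, avg)  # prints: f_a 0.5 / f_b 1.0 / f_c 1.5\n",
    "```\n",
    "\n",
    "With the real data, `importance` would be `node_feature_importance[node_importance_rank[:250]]`, assuming `node_feature_importance` is a 2-D NumPy array indexed by node."
   ]
  },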
  {
   "cell_type": "markdown",
   "id": "85",
   "metadata": {},
   "source": [
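    "A lower average rank means a more important feature. Averaged over the 250 nodes that matter most for the target node's prediction, topics relevant to cleaning, hipster, etc. appear important, while those such as leisure, ship, government, etc. appear less important."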
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86",
   "metadata": {},
   "source": [
    "We then calculate the importance of the edges within k hops of the target node, since a GCN with k graph convolution layers only aggregates information from the k-hop neighborhood of a node (k = 2 for our GCN model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "87",
   "metadata": {},
   "outputs": [],
   "source": [
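    "# integrated-gradients link importance for the target node (class of interest: 0);\n",
    "# steps is the number of interpolation steps used to approximate the integral, not k\n",
    "link_importance = int_saliency.get_integrated_link_masks(target_idx, 0, steps=2)"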
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 4289572, 2) (1, 4289572)\n"
     ]
    }
   ],
   "source": [
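    "# row and column indices of the edges with non-zero importance\n",
    "(x, y) = link_importance.nonzero()\n",
    "# full-batch inputs: node features, target indices, edge indices and edge weights\n",
    "[X, all_targets, A_index, A], y_true_all = all_gen[0]\n",
    "print(A_index.shape, A.shape)\n",
    "G_edge_indices = [(A_index[0, k, 0], A_index[0, k, 1]) for k in range(A_index.shape[1])]\n",
    "# map each (row, column) pair to its position in A_index and A, so its weight can be edited\n",
    "link_dict = {(A_index[0, k, 0], A_index[0, k, 1]): k for k in range(A_index.shape[1])}"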
   ]
  },
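  {
   "cell_type": "markdown",
   "id": "88a",
   "metadata": {},
   "source": [
    "The dictionary comprehension above indexes `A_index` one element at a time, which is slow for the roughly 4.3 million edges here. Below is a minimal sketch of an equivalent lookup on a hypothetical four-edge toy graph in the same layout (edge indices of shape (1, E, 2) holding (row, column) pairs, edge weights of shape (1, E)); it converts the index columns to Python lists once before building the dictionary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical toy graph in the same layout as A_index and A\n",
    "toy_index = np.array([[[0, 1], [1, 0], [1, 2], [2, 1]]])\n",
    "toy_weights = np.array([[0.5, 0.5, 0.25, 0.25]])\n",
    "\n",
    "# convert the row and column columns to plain Python ints once\n",
    "rows = toy_index[0, :, 0].tolist()\n",
    "cols = toy_index[0, :, 1].tolist()\n",
    "# map each (row, column) pair to its position in the edge arrays\n",
    "edge_position = dict(zip(zip(rows, cols), range(len(rows))))\n",
    "\n",
    "# weight of the edge 1 -> 2\n",
    "print(toy_weights[0, edge_position[(1, 2)]])  # 0.25\n",
    "```\n",
    "\n",
    "Lookups with NumPy integer keys, such as the entries of `x` and `y`, should still succeed, because NumPy integers hash and compare equal to the Python ints with the same value."
   ]
  },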
  {
   "cell_type": "markdown",
   "id": "89",
   "metadata": {},
   "source": [
    "As a sanity check, we expect the most important edge to connect important nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 26145 edges within the ego graph of the target node\n",
      "The most important edge connects the nodes ranked 0 and 132 by node importance\n"
     ]
    }
   ],
   "source": [
    "# importance scores of all edges with non-zero importance\n",
    "nonzero_importance_val = link_importance[(x, y)].flatten().tolist()[0]\n",
    "# edge positions (into x and y) sorted from most to least important\n",
    "link_importance_rank = np.argsort(nonzero_importance_val)[::-1]\n",
    "edge_number_in_ego_graph = link_importance_rank.shape[0]\n",
    "print(\n",
    "    \"There are {} edges within the ego graph of the target node\".format(\n",
    "        edge_number_in_ego_graph\n",
    "    )\n",
    ")\n",
    "# endpoints of the edges, ordered by edge importance\n",
    "x_rank, y_rank = x[link_importance_rank], y[link_importance_rank]\n",
    "# node rank 0 is the most important node\n",
    "print(\n",
    "    \"The most important edge connects the nodes ranked {} and {} by node importance\".format(\n",
    "        node_importance_rank.tolist().index(x_rank[0]),\n",
    "        node_importance_rank.tolist().index(y_rank[0]),\n",
    "    )\n",
    ")"
   ]
  },
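  {
   "cell_type": "markdown",
   "id": "90a",
   "metadata": {},
   "source": [
    "Each call to `node_importance_rank.tolist().index(...)` scans the whole list. When the ranks of many nodes are needed, for example the endpoints of all 26145 edges, the permutation can be inverted once instead. Below is a minimal sketch with hypothetical scores, assuming `node_importance_rank` is an `argsort`-based ordering from most to least important node.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical node importance scores for 5 nodes\n",
    "node_importance = np.array([0.1, 0.7, 0.3, 0.9, 0.5])\n",
    "\n",
    "# node indices sorted from most to least important\n",
    "order = np.argsort(node_importance)[::-1]\n",
    "# invert the permutation once: rank_of[node] is the position of node in order\n",
    "rank_of = np.empty_like(order)\n",
    "rank_of[order] = np.arange(len(order))\n",
    "\n",
    "print(order)  # [3 1 4 2 0]\n",
    "print(rank_of[[3, 2]])  # [0 3]\n",
    "```\n",
    "\n",
    "Here `rank_of[node]` gives the same value as `order.tolist().index(node)`, but each lookup is a constant-time array access."
   ]
  },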
  {
   "cell_type": "markdown",
   "id": "91",
   "metadata": {},
   "source": [
    "To check that the edge importance scores are meaningful, we perturb the most important edges: if we remove the top 1% of the edges in the ego graph (261 of the 26145) according to the computed importance scores, we expect the prediction for the target node to change."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "92",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A_perturb.shape = (1, 4289572)\n"
     ]
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "selected_nodes = np.array([[target_idx]], dtype=\"int32\")\n",
    "# prediction for the target node on the unperturbed graph\n",
    "prediction_clean = model.predict([X, selected_nodes, A_index, A]).squeeze()\n",
    "A_perturb = deepcopy(A)\n",
    "print(\"A_perturb.shape = {}\".format(A_perturb.shape))\n",
    "# remove the top 1% most important edges in the ego graph by setting their weights to zero\n",
    "topk = int(edge_number_in_ego_graph * 0.01)\n",
    "\n",
    "for i in range(topk):\n",
    "    edge_x, edge_y = x_rank[i], y_rank[i]\n",
    "    # position of this edge in the sparse adjacency arrays\n",
    "    edge_index = link_dict[(edge_x, edge_y)]\n",
    "    A_perturb[0, edge_index] = 0"
   ]
  },
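  {
   "cell_type": "markdown",
   "id": "92a",
   "metadata": {},
   "source": [
    "The loop above zeroes one weight per iteration. Below is a minimal sketch of the same idea as a single indexed assignment, on hypothetical toy inputs: `weights` stands in for `A`, `edge_position` for `link_dict`, and `ranked_edges` for the `(x_rank[i], y_rank[i])` pairs.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical toy inputs: edge weights of shape (1, E) and a (row, column) -> position map\n",
    "weights = np.ones((1, 5), dtype='float32')\n",
    "edge_position = {(0, 1): 0, (1, 0): 1, (1, 2): 2, (2, 1): 3, (2, 3): 4}\n",
    "# hypothetical edges ranked by importance, most important first\n",
    "ranked_edges = [(1, 2), (2, 3), (0, 1), (1, 0), (2, 1)]\n",
    "\n",
    "topk = 2\n",
    "# copy so the original weights stay available for the clean prediction\n",
    "perturbed = weights.copy()\n",
    "positions = [edge_position[edge] for edge in ranked_edges[:topk]]\n",
    "# 'removing' an edge here means setting its weight to zero\n",
    "perturbed[0, positions] = 0\n",
    "\n",
    "print(perturbed)  # [[1. 1. 0. 1. 0.]]\n",
    "```\n",
    "\n",
    "As in the cell above, a removed edge stays in the index array; it simply contributes nothing to the graph convolutions because its weight is zero."
   ]
  },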
  {
   "cell_type": "markdown",
   "id": "93",
   "metadata": {},
   "source": [
    "As expected, the prediction score drops sharply after the perturbation, from about 0.96 to about 0.14, so the target node is now predicted as non-hateful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The prediction score changes from 0.9576854109764099 to 0.14111952483654022 after the perturbation\n"
     ]
    }
   ],
   "source": [
    "prediction = model.predict([X, selected_nodes, A_index, A_perturb]).squeeze()\n",
    "print(\n",
    "    \"The prediction score changes from {} to {} after the perturbation\".format(\n",
    "        prediction_clean, prediction\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95",
   "metadata": {},
   "source": [
    "Note for the UX team: this notebook shows how to compute the importance of nodes and edges. However, the ego graph of a target node in the Twitter dataset is often very large (here, 26145 edges have non-zero importance for the target node), so a visualization should probably show only the most important nodes and edges."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/interpretability/hateful-twitters-interpretability.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/interpretability/hateful-twitters-interpretability.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}