Lambda-School-Labs/Labs26-StorySquad-DS-TeamB

notebooks/clustering.ipynb

{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "Python 3.7.6 64-bit ('anaconda3': virtualenv)",
   "display_name": "Python 3.7.6 64-bit ('anaconda3': virtualenv)",
   "metadata": {
    "interpreter": {
     "hash": "b8b1a584ca6d5769a8ab9f1c518f3f9bdf9236a52e581b5d219269464762273d"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Clustering For Match Making"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### Remainder Problem\n",
    "Instead of leaving 1 user grouped with 3 bots, make each of the last three groups contain 3 users and 1 bot."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### Clustering algorithms\n",
    "Split based on ranked squad_score: simplest MVP\n",
    "- We ended up implementing this method due to time constraints. The other methods use all of the metrics, which would have required setting up a DS database that we didn't have implemented at the time.\n",
    "\n",
    "K Means: Doesn't split into even groups like we need, so cluster into 5 similar tiers and then group within each cluster based on Squad Score.\n",
    "\n",
    "KNN: Select a random user, grab the 3 closest neighbors, repeat until everyone is matched\n",
    "\n",
    "Agglomerative clustering: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html\n",
    "\n",
    "Linear sum assignment: https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.optimize.linear_sum_assignment.html"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import pandas as pd\n",
    "from sklearn.cluster import KMeans, AgglomerativeClustering\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neighbors import NearestNeighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "   story_id  story_length  avg_word_len  quotes_num  unique_words_num  \\\n",
       "0      3132          1375      5.092593           6               138   \n",
       "1      3104           903      4.961538           0               110   \n",
       "2      3103           750      5.000000           1                93   \n",
       "3      3117           439      4.877778           1                56   \n",
       "4      3102          1812      4.897297           0               193   \n",
       "\n",
       "   squad_score  \n",
       "0    39.177001  \n",
       "1    26.173076  \n",
       "2    24.497113  \n",
       "3    16.454330  \n",
       "4    41.083353  "
      ],
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>story_id</th>\n      <th>story_length</th>\n      <th>avg_word_len</th>\n      <th>quotes_num</th>\n      <th>unique_words_num</th>\n      <th>squad_score</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>3132</td>\n      <td>1375</td>\n      <td>5.092593</td>\n      <td>6</td>\n      <td>138</td>\n      <td>39.177001</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>3104</td>\n      <td>903</td>\n      <td>4.961538</td>\n      <td>0</td>\n      <td>110</td>\n      <td>26.173076</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>3103</td>\n      <td>750</td>\n      <td>5.000000</td>\n      <td>1</td>\n      <td>93</td>\n      <td>24.497113</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>3117</td>\n      <td>439</td>\n      <td>4.877778</td>\n      <td>1</td>\n      <td>56</td>\n      <td>16.454330</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>3102</td>\n      <td>1812</td>\n      <td>4.897297</td>\n      <td>0</td>\n      <td>193</td>\n      <td>41.083353</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "source": [
    "# Read in metrics csv\n",
    "df = pd.read_csv('story_squad_metrics.csv', index_col='Unnamed: 0')\n",
    "df.head()"
   ]
  },
  {
   "source": [
    "### Using K Means Clustering\n",
    "- Segment the users into different clusters based on all of the metrics - 5 'tiers'\n",
    "- Within their clusters, group into groups of 4 based on their Squad Score"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "   story_id  story_length  avg_word_len  quotes_num  unique_words_num  \\\n",
       "0      3132          1375      5.092593           6               138   \n",
       "1      3104           903      4.961538           0               110   \n",
       "2      3103           750      5.000000           1                93   \n",
       "3      3117           439      4.877778           1                56   \n",
       "4      3102          1812      4.897297           0               193   \n",
       "\n",
       "   squad_score  cluster  \n",
       "0    39.177001        3  \n",
       "1    26.173076        0  \n",
       "2    24.497113        0  \n",
       "3    16.454330        0  \n",
       "4    41.083353        3  "
      ],
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>story_id</th>\n      <th>story_length</th>\n      <th>avg_word_len</th>\n      <th>quotes_num</th>\n      <th>unique_words_num</th>\n      <th>squad_score</th>\n      <th>cluster</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>3132</td>\n      <td>1375</td>\n      <td>5.092593</td>\n      <td>6</td>\n      <td>138</td>\n      <td>39.177001</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>3104</td>\n      <td>903</td>\n      <td>4.961538</td>\n      <td>0</td>\n      <td>110</td>\n      <td>26.173076</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>3103</td>\n      <td>750</td>\n      <td>5.000000</td>\n      <td>1</td>\n      <td>93</td>\n      <td>24.497113</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>3117</td>\n      <td>439</td>\n      <td>4.877778</td>\n      <td>1</td>\n      <td>56</td>\n      <td>16.454330</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>3102</td>\n      <td>1812</td>\n      <td>4.897297</td>\n      <td>0</td>\n      <td>193</td>\n      <td>41.083353</td>\n      <td>3</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "# Get clusters\n",
    "\n",
    "# Instantiate scaler\n",
    "scaler = StandardScaler()\n",
    "\n",
    "# Pull out features\n",
    "features = df.drop('story_id', axis=1)\n",
    "\n",
    "# Scale data\n",
    "norm_x = scaler.fit_transform(features)\n",
    "\n",
    "# Instantiate model - 5 clusters\n",
    "model = KMeans(n_clusters=5)\n",
    "\n",
    "# Predict clusters\n",
    "df['cluster'] = model.fit_predict(norm_x)\n",
    "\n",
    "# View df\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "3    49\n",
       "0    45\n",
       "2    39\n",
       "4    19\n",
       "1    15\n",
       "Name: cluster, dtype: int64"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "source": [
    "# Cluster distribution\n",
    "df['cluster'].value_counts()"
   ]
  },
  {
   "source": [
    "The label number assigned to each cluster changes each time you run the model (KMeans is randomly initialized and no random_state is set here), but the distribution of cluster sizes remains the same."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "# MVP Clustering from Squad Scores\n",
    "\n",
    "Grouping function:\n",
    "- uses the remainder from len(df) % 4 to decide how to group the last users\n",
    "- edge-case conditions handle cohorts of fewer than 6 users; without them, some keys would be returned with empty lists"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_4(df):\n",
    "    '''\n",
    "    Split the given dataframe into groups of 4 based on ranked squad_score.\n",
    "    When the number of users is not evenly divisible by 4, the remainder is split so that there is never more than 1 computer user in a group, unless the cohort is too small (fewer than 4 users, or exactly 5).\n",
    "\n",
    "        Input: df to be grouped containing the column 'squad_score' and 'story_id'\n",
    "        Output: Dictionary of groupings. {group #: list of story_id's}\n",
    "\n",
    "    '''\n",
    "    # Rank by squad_score\n",
    "    df = df.sort_values(by= ['squad_score'], ascending= False)\n",
    "\n",
    "    # Initial variables\n",
    "    split = len(df) // 4\n",
    "    remainder = len(df) % 4\n",
    "    group_dict = {}\n",
    "\n",
    "    # Edge cases:\n",
    "    # - fewer than 4 users: they are all in one group\n",
    "    # - exactly 5 users: one group of 3 and one group of 2\n",
    "    if len(df) == 5:\n",
    "        group_dict[1] = list(df['story_id'][:3])\n",
    "        group_dict[2] = list(df['story_id'][3:])\n",
    "        return group_dict\n",
    "    \n",
    "    if len(df) < 4:\n",
    "        group_dict[1] = list(df['story_id'][:])\n",
    "        return group_dict\n",
    "\n",
    "    # If the remainder is 3 -> last group will be a group of 3 users\n",
    "    if remainder == 3:\n",
    "        for i in range(split):\n",
    "            # Group by top 4 squad scores\n",
    "            group_dict[i+1] = list(df['story_id'][:4])\n",
    "            # Drop stories you have grouped already \n",
    "            df = df[4:]\n",
    "        \n",
    "        # Final group is the last 3 remainders\n",
    "        group_dict[split +1] = list(df['story_id'][:])\n",
    "        return group_dict\n",
    "\n",
    "    # If the remainder is 2 -> last 2 groups will be groups of 3\n",
    "    elif remainder == 2:\n",
    "        # Leave the last 2 groups to split into 2 groups of 3\n",
    "        for i in range(split -1):\n",
    "            # Group by top 4 squad scores\n",
    "            group_dict[i+1] = list(df['story_id'][:4])\n",
    "            # Drop stories you have grouped already\n",
    "            df = df[4:]\n",
    "\n",
    "        # The last two groups will be groups of 3\n",
    "        group_dict[split] = list(df['story_id'][:3])\n",
    "        group_dict[split + 1] = list(df['story_id'][3:])\n",
    "        return group_dict\n",
    "\n",
    "    # If the remainder is 1 -> last 3 groups will be groups of 3\n",
    "    elif remainder == 1:\n",
    "        # Leave the last 3 groups to be split into 3 groups of 3\n",
    "        for i in range(split -2):\n",
    "            # Group by top 4 squad scores\n",
    "            group_dict[i+1] = list(df['story_id'][:4])\n",
    "            # Drop stories you have already grouped\n",
    "            df = df[4:]\n",
    "\n",
    "        # The last three groups as groups of 3\n",
    "        group_dict[split -1] = list(df['story_id'][:3])\n",
    "        group_dict[split] = list(df['story_id'][3:6])\n",
    "        group_dict[split + 1] = list(df['story_id'][6:])\n",
    "        return group_dict\n",
    "    \n",
    "    # If the remainder is 0, split evenly by 4\n",
    "    elif remainder == 0:\n",
    "        for i in range(split):\n",
    "            # Group by top 4 squad scores\n",
    "            group_dict[i+1] = list(df['story_id'][:4])\n",
    "            # Drop stories you have already grouped\n",
    "            df = df[4:]\n",
    "        return group_dict\n",
    "\n",
    "    else:\n",
    "        return 'Invalid number of remaining users'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract each cluster 'tier'\n",
    "first = df[df['cluster']== 0]\n",
    "second = df[df['cluster']== 1]\n",
    "third = df[df['cluster']== 2]\n",
    "fourth = df[df['cluster']== 3]\n",
    "fifth = df[df['cluster']== 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create all the cluster dictionaries\n",
    "first_cluster = group_4(first)\n",
    "second_cluster = group_4(second)\n",
    "third_cluster = group_4(third)\n",
    "fourth_cluster = group_4(fourth)\n",
    "fifth_cluster = group_4(fifth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "First Cluster: {1: [5209, 3216, 5123, 3215], 2: [5229, 5120, 5102, 3247], 3: [3128, 5232, 5243, 3210], 4: [3123, 3116, 5258, 3111], 5: [3223, 5207, 3227, 3104], 6: [5110, 3225, 3125, 3103], 7: [3207, 3238, 3124, 3113], 8: [3222, 3206, 3131, 3107], 9: [3214, 3226, 3236, 5233], 10: [3119, 3235, 3121], 11: [3127, 3110, 3117], 12: [3202, 3240, 3229]}\nSecond Cluster: {1: [5234, 5235, 5213, 5219], 2: [5107, 5215, 5117, 5118], 3: [5129, 5227, 5119, 5222], 4: [5262, 5247, 5206]}\nThird Cluster: {1: [5244, 5122, 3129, 5202], 2: [3234, 3246, 3221, 5256], 3: [3248, 3203, 3243, 5103], 4: [3217, 5238, 5126, 5112], 5: [3237, 3218, 5216, 5264], 6: [3208, 5115, 5101, 3205], 7: [3241, 3105, 3244, 5104], 8: [3108, 5116, 3126, 5214], 9: [5109, 3112, 3106, 3120], 10: [3231, 3201, 3228]}\nFourth Cluster: {1: [5248, 5257, 5111, 5261], 2: [5241, 5217, 5208, 5230], 3: [5105, 3115, 5121, 3219], 4: [5113, 5204, 5251, 5259], 5: [3204, 5218, 5255, 5203], 6: [3122, 5237, 3109, 5210], 7: [5125, 5246, 3102, 5108], 8: [3211, 5132, 5228, 3224], 9: [3101, 3239, 5263, 3132], 10: [5225, 3213, 5114, 3232], 11: [5221, 3220, 5249], 12: [3114, 5239, 3118], 13: [3245, 5260, 5240]}\nFifth Cluster: {1: [5254, 5124, 5130, 5223], 2: [5245, 3209, 5106, 5236], 3: [5224, 3230, 3130, 5131], 4: [5253, 5220, 5252, 5205], 5: [3212, 5250, 5242]}\n"
     ]
    }
   ],
   "source": [
    "# View each cluster and their groupings\n",
    "print(f'First Cluster: {first_cluster}')\n",
    "print(f'Second Cluster: {second_cluster}')\n",
    "print(f'Third Cluster: {third_cluster}')\n",
    "print(f'Fourth Cluster: {fourth_cluster}')\n",
    "print(f'Fifth Cluster: {fifth_cluster}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Testing edge cases\n",
    "remainder_0 = first[3:]\n",
    "remainder_1 = first[2:]\n",
    "remainder_2 = first[1:]\n",
    "remainder_3 = first\n",
    "small_num_9 = first[-9:]\n",
    "small_num_5 = first[-5:]\n",
    "small_num_4 = first[-4:]\n",
    "small_num_3 = first[-3:]\n",
    "small_num_2 = first[-2:]\n",
    "small_num_1 = first[-1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{0: [5104]}"
      ]
     },
     "metadata": {},
     "execution_count": 59
    }
   ],
   "source": [
    "clust_dict = group_4(small_num_1)\n",
    "clust_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "metadata": {},
     "execution_count": 15
    }
   ],
   "source": [
    "# Sanity check - each story from the grouped subset was grouped\n",
    "unique = set()\n",
    "\n",
    "for inner_list in clust_dict.values():\n",
    "    for item in inner_list:\n",
    "        unique.add(item)\n",
    "    \n",
    "# clust_dict was built from small_num_1 above, so compare against that subset\n",
    "len(unique) == len(small_num_1)"
   ]
  },
  {
   "source": [
    "## KNN\n",
    "- Pull a story (the current implementation takes the first remaining row rather than a random one), then find the three nearest stories to create a group\n",
    "- Drop the stories that you have already grouped\n",
    "- TODO: Deal with the remainder problem"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap in function that continues until the remainder problem\n",
    "# Refit the data after you drop the grouped stories \n",
    "    # - otherwise it could suggest a story that we have already grouped and dropped\n",
    "\n",
    "def group_nn(df):\n",
    "    '''\n",
    "    Create groups of four from the input df using Nearest Neighbors.\n",
    "    Pulls one user, then finds the three most similar users in the cohort to form the group.\n",
    "    Remainder handling is still a TODO: grouping stops once fewer than 12 users remain, and those users are left ungrouped.\n",
    "\n",
    "        Input: df to be grouped containing the column 'story_id'\n",
    "        Output: dictionary of groupings {group ID #: list of story_id's}\n",
    "    '''\n",
    "    # Empty dictionary to store the groupings\n",
    "    groups = {}\n",
    "\n",
    "    # Instantiate scaler\n",
    "    scaler = StandardScaler()\n",
    "\n",
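    "    # Reset the index so the positional indices returned by kneighbors match the row labels\n",
    "    df = df.reset_index(drop=True)\n",
    "\n",
    "    # Pull out features\n",
    "    features = df.drop('story_id', axis=1)\n",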
    "\n",
    "    # Scale data\n",
    "    norm_x = scaler.fit_transform(features)\n",
    "\n",
    "    # Turn into df\n",
    "    df_norm_x = pd.DataFrame(norm_x)\n",
    "\n",
    "    # Instantiate model - groups of 4\n",
    "    nn = NearestNeighbors(n_neighbors=4, algorithm='kd_tree')\n",
    "\n",
    "    # Counter to use as key for groups in dictionary\n",
    "    counter = 1\n",
    "\n",
    "    # While loop that takes the top user and creates a group with its three closest users\n",
    "    # Drops grouped users and continues until there are fewer than 12 users left to group\n",
    "    # Remainder problem will be dealt with after the while loop runs\n",
    "    while len(df_norm_x) > 11:\n",
    "        # Fit the nearest neighbors model\n",
    "        nn.fit(df_norm_x)\n",
    "\n",
    "        # Find nearest neighbors\n",
    "        array_1, array_2 = nn.kneighbors([df_norm_x.iloc[0].values])\n",
    "\n",
    "        # Put story_id list into groups dictionary\n",
    "        groups[counter] = [df['story_id'][item] for item in array_2[0]]\n",
    "\n",
    "        # Increment the counter\n",
    "        counter += 1\n",
    "\n",
    "        # Drop the users you have already grouped\n",
    "        # From both df's that you are using\n",
    "        df_norm_x = df_norm_x.drop(array_2[0])\n",
    "        df = df.drop(array_2[0])\n",
    "\n",
    "        # Reset the index\n",
    "        # For both datasets that you are using\n",
    "        df_norm_x.reset_index(inplace= True, drop= True)\n",
    "        df.reset_index(inplace= True, drop= True)\n",
    "\n",
    "    # TODO: Deal with remainders\n",
    "    print(f\"Remaining users: {len(df_norm_x)}\")\n",
    "\n",
    "\n",
    "    return groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Remaining users: 11\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{1: [3132, 3220, 3245, 5132],\n",
       " 2: [3104, 3111, 5110, 3124],\n",
       " 3: [3103, 3222, 3238, 3207],\n",
       " 4: [3117, 3127, 3110, 3119],\n",
       " 5: [3102, 3211, 5113, 3101],\n",
       " 6: [3105, 5101, 5104, 3208],\n",
       " 7: [3129, 3234, 3221, 3246],\n",
       " 8: [3116, 3223, 3247, 3128],\n",
       " 9: [3118, 3232, 5249, 5225],\n",
       " 10: [3120, 3112, 3231, 3201],\n",
       " 11: [3121, 3202, 3235, 5233],\n",
       " 12: [3126, 5109, 3108, 3241],\n",
       " 13: [3131, 3214, 3206, 3227],\n",
       " 14: [3109, 5218, 5203, 5210],\n",
       " 15: [3107, 3226, 3113, 5258],\n",
       " 16: [3106, 3228, 3205, 3237],\n",
       " 17: [3130, 5106, 3209, 5224],\n",
       " 18: [3115, 5241, 5251, 5237],\n",
       " 19: [3123, 5232, 5207, 3210],\n",
       " 20: [3125, 5120, 5123, 3215],\n",
       " 21: [3122, 5255, 3204, 5208],\n",
       " 22: [3114, 5260, 5240, 3213],\n",
       " 23: [3216, 5102, 5243, 5229],\n",
       " 24: [3229, 3225, 3240, 3236],\n",
       " 25: [3218, 3217, 5264, 3203],\n",
       " 26: [3243, 5126, 5112, 5256],\n",
       " 27: [3244, 5115, 5116, 5216],\n",
       " 28: [3219, 5105, 5217, 3239],\n",
       " 29: [3212, 5242, 5245, 5220],\n",
       " 30: [3224, 5263, 5114, 5125],\n",
       " 31: [3248, 5103, 5244, 5122],\n",
       " 32: [3230, 5257, 5111, 5205],\n",
       " 33: [5254, 5130, 5236, 5223],\n",
       " 34: [5253, 5252, 5248, 5131],\n",
       " 35: [5262, 5117, 5261, 5118],\n",
       " 36: [5209, 5221, 5239, 5214],\n",
       " 37: [5238, 5246, 5202, 5228],\n",
       " 38: [5206, 5222, 5247, 5119],\n",
       " 39: [5230, 5204, 5121, 5259]}"
      ]
     },
     "metadata": {},
     "execution_count": 117
    }
   ],
   "source": [
    "# Group 167 stories from db\n",
    "nn_groups = group_nn(df)\n",
    "nn_groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "metadata": {},
     "execution_count": 107
    }
   ],
   "source": [
    "# Sanity check - every story except the 11 ungrouped remainders was grouped\n",
    "unique = set()\n",
    "\n",
    "for inner_list in nn_groups.values():\n",
    "    for item in inner_list:\n",
    "        unique.add(item)\n",
    "    \n",
    "len(unique) == len(df) - 11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}