awarebayes/RecNN

View on GitHub
examples/0. Embeddings Generation/Pipelines/ML20M/2. NLP.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NLP with RoBERTa.\n",
    "\n",
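    "Yeah, I am somewhat of an NLP engineer myself.\n",
    "\n",
    "This notebook encodes each movie's plot text (TMDb overview + OMDb plot) with RoBERTa-base, averages the token features into one 768-dimensional vector per movie, and saves the result to `data/engineering/roberta.csv`."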
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import json\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from fairseq.data.data_utils import collate_tokens\n",
    "from tqdm.auto import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parsed OMDb / TMDb metadata; the context managers close each file once it is parsed.\n",
    "with open(\"../../../../data/parsed/omdb.json\", \"r\") as omdb_file:\n",
    "    omdb = json.load(omdb_file)\n",
    "with open(\"../../../../data/parsed/tmdb.json\", \"r\") as tmdb_file:\n",
    "    tmdb = json.load(tmdb_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of plot texts encoded per forward pass.\n",
    "batch_size = 4\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.\n",
    "# The name `cuda` is kept because later cells refer to it.\n",
    "cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunks(l, n):\n",
    "    \"\"\"Yield consecutive slices of `l`, each containing at most `n` items.\"\"\"\n",
    "    for start in range(0, len(l), n):\n",
    "        yield l[start:start + n]\n",
    "\n",
    "# One (movie id, text) pair per TMDb entry; the text is the TMDb overview\n",
    "# followed by the OMDb plot, separated by a single space.\n",
    "movie_texts = []\n",
    "for movie_id in tmdb:\n",
    "    omdb_plot = omdb[movie_id]['omdb'].get('Plot', '')\n",
    "    tmdb_plot = tmdb[movie_id]['tmdb'].get('overview', '')\n",
    "    movie_texts.append((movie_id, tmdb_plot + ' ' + omdb_plot))\n",
    "\n",
    "# Keep only texts longer than 4 characters, ordered from shortest to longest.\n",
    "# Filtering before the (stable) sort yields the same order as sorting first.\n",
    "movie_texts = [pair for pair in movie_texts if len(pair[1]) > 4]\n",
    "movie_texts.sort(key=lambda pair: len(pair[1]))\n",
    "\n",
    "# Parallel lists of batches: ids[k][j] is the movie id of the text plots[k][j].\n",
    "ids = list(chunks([movie_id for movie_id, _ in movie_texts], batch_size))\n",
    "plots = list(chunks([text for _, text in movie_texts], batch_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/dev/.cache/torch/hub/pytorch_fairseq_master\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz from cache at /home/dev/.cache/torch/pytorch_fairseq/37d2bc14cf6332d61ed5abeb579948e6054e46cc724c7d23426382d11a31b2d6.ae5852b4abc6bf762e0b6b30f19e741aa05562471e9eb8f4a6ae261f04f9b350\n",
      "| dictionary: 50264 types\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fetch the pretrained 'roberta.base' model through torch.hub and move it to the device.\n",
    "roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')\n",
    "roberta = roberta.to(cuda)\n",
    "# Switch the model to evaluation mode.\n",
    "roberta.eval()\n",
    "# The trailing print() keeps the cell from ending on the eval() expression.\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps movie id -> pooled RoBERTa feature vector (numpy array); filled by extract_features.\n",
    "fs = {}\n",
    "\n",
    "def extract_features(batch, ids):\n",
    "    \"\"\"Encode a batch of texts with RoBERTa and store one pooled vector per id in `fs`.\n",
    "\n",
    "    batch: sequence of strings to encode.\n",
    "    ids:   keys for `fs`, aligned with `batch` (ids[j] belongs to batch[j]).\n",
    "\n",
    "    Returns nothing; the results are written into the module-level dict `fs`.\n",
    "    \"\"\"\n",
    "    tokens = collate_tokens([roberta.encode(sent) for sent in batch], pad_idx=1).to(cuda)\n",
    "    # Keep at most 512 token positions per text (presumably the model's input limit).\n",
    "    tokens = tokens[:, :512]\n",
    "    # Inference only: skip autograd bookkeeping so no graph is kept alive per batch.\n",
    "    with torch.no_grad():\n",
    "        features = roberta.extract_features(tokens)\n",
    "    # Average over the sequence dimension. NOTE(review): padded positions\n",
    "    # (pad_idx=1) are included in this average - confirm that is intended.\n",
    "    pooled = F.avg_pool2d(features, (features.size(1), 1))\n",
    "    # Drop only the pooled dimension. A bare squeeze() would also remove the batch\n",
    "    # dimension of a single-text batch and break the per-row loop below.\n",
    "    pooled = pooled.squeeze(1).cpu().numpy()\n",
    "    for row, movie_id in enumerate(ids):\n",
    "        fs[movie_id] = pooled[row]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51a46152153e49e6ab7c6744feb97427",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=6779), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Process the batches in reverse order, i.e. the longest texts first.\n",
    "# The loop variable must not be called `ids`: that would overwrite the global list\n",
    "# of id batches and make this cell pair the wrong data when it is run again.\n",
    "for batch, batch_ids in tqdm(zip(plots[::-1], ids[::-1]), total=len(plots)):\n",
    "    extract_features(batch, batch_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `fs` maps id -> vector, so the frame is built with one column per id and\n",
    "# transposed to get one row per movie.\n",
    "features_by_id = pd.DataFrame(fs).transpose()\n",
    "# The ids are string keys; cast them to integers so the rows sort numerically.\n",
    "features_by_id.index = features_by_id.index.astype(int)\n",
    "transformed = features_by_id.sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>758</th>\n",
       "      <th>759</th>\n",
       "      <th>760</th>\n",
       "      <th>761</th>\n",
       "      <th>762</th>\n",
       "      <th>763</th>\n",
       "      <th>764</th>\n",
       "      <th>765</th>\n",
       "      <th>766</th>\n",
       "      <th>767</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.005599</td>\n",
       "      <td>0.138494</td>\n",
       "      <td>0.047051</td>\n",
       "      <td>-0.099981</td>\n",
       "      <td>0.208267</td>\n",
       "      <td>0.163597</td>\n",
       "      <td>-0.050247</td>\n",
       "      <td>0.035369</td>\n",
       "      <td>0.021860</td>\n",
       "      <td>-0.001333</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.053477</td>\n",
       "      <td>0.014401</td>\n",
       "      <td>-0.035731</td>\n",
       "      <td>-0.068612</td>\n",
       "      <td>0.146932</td>\n",
       "      <td>0.106177</td>\n",
       "      <td>-0.128289</td>\n",
       "      <td>-0.231606</td>\n",
       "      <td>0.047912</td>\n",
       "      <td>-0.046285</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.028936</td>\n",
       "      <td>0.053734</td>\n",
       "      <td>0.066000</td>\n",
       "      <td>-0.130739</td>\n",
       "      <td>0.197591</td>\n",
       "      <td>0.014505</td>\n",
       "      <td>-0.001784</td>\n",
       "      <td>0.091164</td>\n",
       "      <td>0.036338</td>\n",
       "      <td>-0.002871</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.058495</td>\n",
       "      <td>0.049999</td>\n",
       "      <td>-0.049668</td>\n",
       "      <td>-0.037801</td>\n",
       "      <td>0.088053</td>\n",
       "      <td>0.142559</td>\n",
       "      <td>-0.166629</td>\n",
       "      <td>-0.081439</td>\n",
       "      <td>0.034168</td>\n",
       "      <td>-0.023142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.023951</td>\n",
       "      <td>0.082014</td>\n",
       "      <td>0.041002</td>\n",
       "      <td>-0.058334</td>\n",
       "      <td>0.188524</td>\n",
       "      <td>0.099200</td>\n",
       "      <td>0.009292</td>\n",
       "      <td>0.044268</td>\n",
       "      <td>0.051445</td>\n",
       "      <td>0.032975</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.031117</td>\n",
       "      <td>-0.017112</td>\n",
       "      <td>-0.016568</td>\n",
       "      <td>-0.009261</td>\n",
       "      <td>0.070678</td>\n",
       "      <td>0.122078</td>\n",
       "      <td>-0.029504</td>\n",
       "      <td>-0.045054</td>\n",
       "      <td>0.114256</td>\n",
       "      <td>0.064617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.028417</td>\n",
       "      <td>0.169414</td>\n",
       "      <td>0.063841</td>\n",
       "      <td>-0.036933</td>\n",
       "      <td>0.114328</td>\n",
       "      <td>0.082039</td>\n",
       "      <td>0.017422</td>\n",
       "      <td>0.084967</td>\n",
       "      <td>-0.001609</td>\n",
       "      <td>0.048082</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.081082</td>\n",
       "      <td>-0.044695</td>\n",
       "      <td>0.164680</td>\n",
       "      <td>0.029210</td>\n",
       "      <td>0.015597</td>\n",
       "      <td>0.080508</td>\n",
       "      <td>0.006273</td>\n",
       "      <td>-0.155380</td>\n",
       "      <td>0.039771</td>\n",
       "      <td>0.049289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.011459</td>\n",
       "      <td>0.131149</td>\n",
       "      <td>0.039703</td>\n",
       "      <td>-0.037407</td>\n",
       "      <td>0.289072</td>\n",
       "      <td>0.121404</td>\n",
       "      <td>-0.046844</td>\n",
       "      <td>-0.013482</td>\n",
       "      <td>-0.103010</td>\n",
       "      <td>0.039538</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.075606</td>\n",
       "      <td>0.007551</td>\n",
       "      <td>0.031218</td>\n",
       "      <td>-0.000565</td>\n",
       "      <td>0.113364</td>\n",
       "      <td>0.092764</td>\n",
       "      <td>0.033090</td>\n",
       "      <td>-0.285467</td>\n",
       "      <td>0.050361</td>\n",
       "      <td>0.061391</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 768 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        0         1         2         3         4         5         6    \\\n",
       "1 -0.005599  0.138494  0.047051 -0.099981  0.208267  0.163597 -0.050247   \n",
       "2 -0.028936  0.053734  0.066000 -0.130739  0.197591  0.014505 -0.001784   \n",
       "3  0.023951  0.082014  0.041002 -0.058334  0.188524  0.099200  0.009292   \n",
       "4  0.028417  0.169414  0.063841 -0.036933  0.114328  0.082039  0.017422   \n",
       "5  0.011459  0.131149  0.039703 -0.037407  0.289072  0.121404 -0.046844   \n",
       "\n",
       "        7         8         9    ...       758       759       760       761  \\\n",
       "1  0.035369  0.021860 -0.001333  ... -0.053477  0.014401 -0.035731 -0.068612   \n",
       "2  0.091164  0.036338 -0.002871  ... -0.058495  0.049999 -0.049668 -0.037801   \n",
       "3  0.044268  0.051445  0.032975  ... -0.031117 -0.017112 -0.016568 -0.009261   \n",
       "4  0.084967 -0.001609  0.048082  ... -0.081082 -0.044695  0.164680  0.029210   \n",
       "5 -0.013482 -0.103010  0.039538  ... -0.075606  0.007551  0.031218 -0.000565   \n",
       "\n",
       "        762       763       764       765       766       767  \n",
       "1  0.146932  0.106177 -0.128289 -0.231606  0.047912 -0.046285  \n",
       "2  0.088053  0.142559 -0.166629 -0.081439  0.034168 -0.023142  \n",
       "3  0.070678  0.122078 -0.029504 -0.045054  0.114256  0.064617  \n",
       "4  0.015597  0.080508  0.006273 -0.155380  0.039771  0.049289  \n",
       "5  0.113364  0.092764  0.033090 -0.285467  0.050361  0.061391  \n",
       "\n",
       "[5 rows x 768 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the feature table.\n",
    "transformed.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the feature table to CSV; the integer movie id index is stored in the 'idx' column.\n",
    "output_path = '../../../../data/engineering/roberta.csv'\n",
    "transformed.to_csv(output_path, index=True, index_label='idx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}