examples/0. Embeddings Generation/Pipelines/ML20M/2. NLP.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# NLP with RoBERTa.\n",
"\n",
"Yeah, I am somewhat of an NLP engineer myself"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import json\n",
"import torch.nn.functional as F\n",
"import pandas as pd\n",
"from fairseq.data.data_utils import collate_tokens\n",
"from tqdm.auto import tqdm\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"omdb = json.load(open(\"../../../../data/parsed/omdb.json\", \"r\") )\n",
"tmdb = json.load(open(\"../../../../data/parsed/tmdb.json\", \"r\") )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
"cuda = torch.device('cuda')"
]
},
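{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above pins everything to a CUDA device. A minimal fallback sketch (not part of the original pipeline): pick the GPU when one is visible, otherwise fall back to CPU so the rest of the notebook still runs, only slower."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional fallback (assumption: the original notebook expects a GPU to be present).\n",
"# Redefining `cuda` is enough, since every later cell moves tensors through this device object.\n",
"cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print('running on', cuda)"
]
},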
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"plots = []\n",
"for i in tmdb.keys():\n",
" omdb_plot = omdb[i]['omdb'].get('Plot', '')\n",
" tmdb_plot = tmdb[i]['tmdb'].get('overview', '')\n",
" plot = tmdb_plot + ' ' + omdb_plot\n",
" plots.append((i, plot, len(plot)))\n",
" \n",
"plots = list(sorted(plots, key=lambda x: x[2]))\n",
"plots = list(filter(lambda x: x[2] > 4, plots))\n",
"\n",
"def chunks(l, n):\n",
" for i in range(0, len(l), n):\n",
" yield l[i:i + n]\n",
"\n",
"ids = [i[0] for i in plots]\n",
"plots = [i[1] for i in plots]\n",
"plots = list(chunks(plots, batch_size))\n",
"ids = list(chunks(ids, batch_size))"
]
},
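{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustrative check of the batching above (not part of the pipeline): because the plots were sorted by length before chunking, each batch holds similarly sized texts, so `collate_tokens` will add little padding within a batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: character lengths within one batch should be close to each other.\n",
"print(len(plots), 'batches of up to', batch_size, 'plots each')\n",
"print('shortest batch lengths:', [len(p) for p in plots[0]])\n",
"print('longest batch lengths: ', [len(p) for p in plots[-1]])"
]
},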
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /home/dev/.cache/torch/hub/pytorch_fairseq_master\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz from cache at /home/dev/.cache/torch/pytorch_fairseq/37d2bc14cf6332d61ed5abeb579948e6054e46cc724c7d23426382d11a31b2d6.ae5852b4abc6bf762e0b6b30f19e741aa05562471e9eb8f4a6ae261f04f9b350\n",
"| dictionary: 50264 types\n",
"\n"
]
}
],
"source": [
"roberta = torch.hub.load('pytorch/fairseq', 'roberta.base').to(cuda)\n",
"roberta.eval()\n",
"print()"
]
},
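{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sanity check of the hub interface on a made-up sentence (a sketch, assuming the standard fairseq `RobertaHubInterface` API): `encode` returns BPE token ids wrapped in `<s>`/`</s>`, and `extract_features` returns one 768-dimensional vector per token for `roberta.base`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check on a made-up sentence: token ids and per-token feature shape.\n",
"sample = roberta.encode('A retired spy is pulled back in for one last job.')\n",
"print(sample)\n",
"with torch.no_grad():\n",
"    sample_features = roberta.extract_features(sample.unsqueeze(0).to(cuda))\n",
"print(sample_features.shape)  # expected: (1, num_tokens, 768) for roberta.base"
]
},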
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"fs = {}\n",
"\n",
"def extract_features(batch, ids):\n",
" batch = collate_tokens([roberta.encode(sent) for sent in batch], pad_idx=1).to(cuda)\n",
" batch = batch[:, :512]\n",
" features = roberta.extract_features(batch)\n",
" pooled_features = F.avg_pool2d(features, (features.size(1), 1)).squeeze()\n",
" for i in range(pooled_features.size(0)):\n",
" fs[ids[i]] = pooled_features[i].detach().cpu().numpy()"
]
},
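{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the `F.avg_pool2d` call in `extract_features` with a `(seq_len, 1)` kernel is just a mean over the token dimension. The tiny check below demonstrates the equivalence on random data shaped like the RoBERTa features; note that, as written, padded positions are included in the average, which the length-sorted batches keep small."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative equivalence check on random data shaped (batch, seq_len, hidden).\n",
"x = torch.randn(2, 7, 768)\n",
"pooled = F.avg_pool2d(x, (x.size(1), 1)).squeeze(1)\n",
"assert torch.allclose(pooled, x.mean(dim=1), atol=1e-5)\n",
"print(pooled.shape)  # (2, 768)"
]
},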
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "51a46152153e49e6ab7c6744feb97427",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=6779), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"for batch, ids in tqdm(zip(plots[::-1], ids[::-1]), total=len(plots)):\n",
" extract_features(batch, ids)"
]
},
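{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick count check (not part of the pipeline): there should be one pooled vector per movie that kept a non-trivial plot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One embedding per kept movie id.\n",
"print(len(fs), 'embeddings for', sum(len(chunk) for chunk in plots), 'kept plots')"
]
},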
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"transformed = pd.DataFrame(fs).T\n",
"transformed.index = transformed.index.astype(int)\n",
"transformed = transformed.sort_index()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>758</th>\n",
" <th>759</th>\n",
" <th>760</th>\n",
" <th>761</th>\n",
" <th>762</th>\n",
" <th>763</th>\n",
" <th>764</th>\n",
" <th>765</th>\n",
" <th>766</th>\n",
" <th>767</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-0.005599</td>\n",
" <td>0.138494</td>\n",
" <td>0.047051</td>\n",
" <td>-0.099981</td>\n",
" <td>0.208267</td>\n",
" <td>0.163597</td>\n",
" <td>-0.050247</td>\n",
" <td>0.035369</td>\n",
" <td>0.021860</td>\n",
" <td>-0.001333</td>\n",
" <td>...</td>\n",
" <td>-0.053477</td>\n",
" <td>0.014401</td>\n",
" <td>-0.035731</td>\n",
" <td>-0.068612</td>\n",
" <td>0.146932</td>\n",
" <td>0.106177</td>\n",
" <td>-0.128289</td>\n",
" <td>-0.231606</td>\n",
" <td>0.047912</td>\n",
" <td>-0.046285</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-0.028936</td>\n",
" <td>0.053734</td>\n",
" <td>0.066000</td>\n",
" <td>-0.130739</td>\n",
" <td>0.197591</td>\n",
" <td>0.014505</td>\n",
" <td>-0.001784</td>\n",
" <td>0.091164</td>\n",
" <td>0.036338</td>\n",
" <td>-0.002871</td>\n",
" <td>...</td>\n",
" <td>-0.058495</td>\n",
" <td>0.049999</td>\n",
" <td>-0.049668</td>\n",
" <td>-0.037801</td>\n",
" <td>0.088053</td>\n",
" <td>0.142559</td>\n",
" <td>-0.166629</td>\n",
" <td>-0.081439</td>\n",
" <td>0.034168</td>\n",
" <td>-0.023142</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.023951</td>\n",
" <td>0.082014</td>\n",
" <td>0.041002</td>\n",
" <td>-0.058334</td>\n",
" <td>0.188524</td>\n",
" <td>0.099200</td>\n",
" <td>0.009292</td>\n",
" <td>0.044268</td>\n",
" <td>0.051445</td>\n",
" <td>0.032975</td>\n",
" <td>...</td>\n",
" <td>-0.031117</td>\n",
" <td>-0.017112</td>\n",
" <td>-0.016568</td>\n",
" <td>-0.009261</td>\n",
" <td>0.070678</td>\n",
" <td>0.122078</td>\n",
" <td>-0.029504</td>\n",
" <td>-0.045054</td>\n",
" <td>0.114256</td>\n",
" <td>0.064617</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.028417</td>\n",
" <td>0.169414</td>\n",
" <td>0.063841</td>\n",
" <td>-0.036933</td>\n",
" <td>0.114328</td>\n",
" <td>0.082039</td>\n",
" <td>0.017422</td>\n",
" <td>0.084967</td>\n",
" <td>-0.001609</td>\n",
" <td>0.048082</td>\n",
" <td>...</td>\n",
" <td>-0.081082</td>\n",
" <td>-0.044695</td>\n",
" <td>0.164680</td>\n",
" <td>0.029210</td>\n",
" <td>0.015597</td>\n",
" <td>0.080508</td>\n",
" <td>0.006273</td>\n",
" <td>-0.155380</td>\n",
" <td>0.039771</td>\n",
" <td>0.049289</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.011459</td>\n",
" <td>0.131149</td>\n",
" <td>0.039703</td>\n",
" <td>-0.037407</td>\n",
" <td>0.289072</td>\n",
" <td>0.121404</td>\n",
" <td>-0.046844</td>\n",
" <td>-0.013482</td>\n",
" <td>-0.103010</td>\n",
" <td>0.039538</td>\n",
" <td>...</td>\n",
" <td>-0.075606</td>\n",
" <td>0.007551</td>\n",
" <td>0.031218</td>\n",
" <td>-0.000565</td>\n",
" <td>0.113364</td>\n",
" <td>0.092764</td>\n",
" <td>0.033090</td>\n",
" <td>-0.285467</td>\n",
" <td>0.050361</td>\n",
" <td>0.061391</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 768 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"1 -0.005599 0.138494 0.047051 -0.099981 0.208267 0.163597 -0.050247 \n",
"2 -0.028936 0.053734 0.066000 -0.130739 0.197591 0.014505 -0.001784 \n",
"3 0.023951 0.082014 0.041002 -0.058334 0.188524 0.099200 0.009292 \n",
"4 0.028417 0.169414 0.063841 -0.036933 0.114328 0.082039 0.017422 \n",
"5 0.011459 0.131149 0.039703 -0.037407 0.289072 0.121404 -0.046844 \n",
"\n",
" 7 8 9 ... 758 759 760 761 \\\n",
"1 0.035369 0.021860 -0.001333 ... -0.053477 0.014401 -0.035731 -0.068612 \n",
"2 0.091164 0.036338 -0.002871 ... -0.058495 0.049999 -0.049668 -0.037801 \n",
"3 0.044268 0.051445 0.032975 ... -0.031117 -0.017112 -0.016568 -0.009261 \n",
"4 0.084967 -0.001609 0.048082 ... -0.081082 -0.044695 0.164680 0.029210 \n",
"5 -0.013482 -0.103010 0.039538 ... -0.075606 0.007551 0.031218 -0.000565 \n",
"\n",
" 762 763 764 765 766 767 \n",
"1 0.146932 0.106177 -0.128289 -0.231606 0.047912 -0.046285 \n",
"2 0.088053 0.142559 -0.166629 -0.081439 0.034168 -0.023142 \n",
"3 0.070678 0.122078 -0.029504 -0.045054 0.114256 0.064617 \n",
"4 0.015597 0.080508 0.006273 -0.155380 0.039771 0.049289 \n",
"5 0.113364 0.092764 0.033090 -0.285467 0.050361 0.061391 \n",
"\n",
"[5 rows x 768 columns]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transformed.head()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"transformed.to_csv('../../../../data/engineering/roberta.csv', index=True, index_label='idx')"
]
},
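{
"cell_type": "markdown",
"metadata": {},
"source": [
"A round-trip check (a sketch, not part of the original notebook): read the saved CSV back and confirm one 768-dimensional row per movie id."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload the embeddings written above and verify the shape and the index column.\n",
"reloaded = pd.read_csv('../../../../data/engineering/roberta.csv', index_col='idx')\n",
"print(reloaded.shape)\n",
"assert reloaded.shape[1] == 768"
]
},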
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}