awarebayes/RecNN

View on GitHub
examples/0. Embeddings Generation/Pipelines/ML20M/5. The Big Merge.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
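  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Big Merge\n",
    "\n",
    "Concatenates the feature tables `roberta.csv`, `mca.csv`, and `pca.csv` from `data/engineering` into one frame, projects it to 128 dimensions with probabilistic PCA (`ppca`), and pickles the result as a dictionary of tensors keyed by the frame's index.\n",
    "\n",
    "Other methods will be added soon."
   ]
  },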
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load each feature table and index it right away, so the tables can later be\n",
    "# aligned on their index labels.\n",
    "roberta = pd.read_csv('../../../../data/engineering/roberta.csv').set_index('idx')\n",
    "\n",
    "# The MCA table is indexed by whatever its first column is, not by name.\n",
    "# NOTE(review): presumably that column holds the same ids as 'idx' -- confirm.\n",
    "cat = pd.read_csv('../../../../data/engineering/mca.csv').pipe(\n",
    "    lambda frame: frame.set_index(frame.columns[0])\n",
    ")\n",
    "\n",
    "num = pd.read_csv('../../../../data/engineering/pca.csv').set_index('idx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reads `links.csv` from the `ml-20m` data directory.\n",
    "# NOTE(review): `movies` is not referenced by any later cell in this notebook --\n",
    "# confirm whether this load is still needed.\n",
    "movies = pd.read_csv('../../../../data/ml-20m/links.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the three feature tables side by side. pandas aligns the rows on the\n",
    "# index labels; with the default outer join, a label missing from one table\n",
    "# gets NaN in that table's columns.\n",
    "df = pd.concat([roberta, cat, num], axis='columns')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dev/.local/lib/python3.7/site-packages/ppca/_ppca.py:82: RuntimeWarning: divide by zero encountered in log\n",
      "  det = np.log(np.linalg.det(Sx))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "0.3197493112573442\n",
      "0.08642294395895767\n",
      "0.011613807065083748\n",
      "0.006629248893893269\n",
      "0.011409862478935606\n",
      "0.8079228072819928\n",
      "0.06312485654055533\n",
      "0.05411146255289068\n",
      "0.04663160316817749\n",
      "0.04029385825484155\n",
      "0.03489849054666294\n",
      "0.03031603746989564\n",
      "0.02643809621940929\n",
      "0.023162995698678968\n",
      "0.0203959192899823\n",
      "0.018052295018768483\n",
      "0.016059788417409626\n",
      "0.014358180687624955\n",
      "0.012897967368266539\n",
      "0.011638581618058197\n",
      "0.010546746675690777\n",
      "0.009595126661316566\n",
      "0.008761272289783184\n",
      "0.008026793373915764\n",
      "0.007376680908001365\n",
      "0.006798720521605572\n",
      "0.0062829681706675355\n",
      "0.005821283576526337\n",
      "0.005406929219866408\n",
      "0.005034242882371531\n",
      "0.0046983850262487525\n",
      "0.004395154395472112\n",
      "0.0041208596504158646\n",
      "0.003872232710471879\n",
      "0.003646370274946076\n",
      "0.0034406925665739774\n",
      "0.003252911535067904\n",
      "0.003081003730827092\n",
      "0.0029231853657245566\n",
      "0.0027778886027165495\n",
      "0.0026437389509543774\n",
      "0.002519533975680055\n",
      "0.0024042235610552964\n",
      "0.002296891856018224\n",
      "0.002196740901195815\n",
      "0.002103075838896906\n",
      "0.0020152915670275107\n",
      "0.0019328607029456268\n",
      "0.0018553227544089168\n",
      "0.001782274431118891\n",
      "0.0017133610573070168\n",
      "0.0016482690561274715\n",
      "0.0015867194714918043\n",
      "0.0015284624761753296\n",
      "0.0014732727944593016\n",
      "0.0014209459480452047\n",
      "0.0013712952201179185\n",
      "0.001324149226606064\n",
      "0.0012793499847434386\n",
      "0.0012367513771092131\n",
      "0.0011962179212254842\n",
      "0.0011576237692405567\n",
      "0.001120851876855733\n",
      "0.0010857932949968063\n",
      "0.0010523465494305384\n",
      "0.001020417084160119\n",
      "0.0009899167520883712\n",
      "0.0009607633424089101\n",
      "0.0009328801386421226\n",
      "0.0009061955037479308\n",
      "0.0008806424908571753\n",
      "0.000856158479084268\n",
      "0.0008326848342841142\n",
      "0.0008101665949125092\n",
      "0.000788552183035085\n",
      "0.0007677931401675053\n",
      "0.0007478438876611371\n",
      "0.0007286615106072425\n",
      "0.0007102055644578886\n",
      "0.0006924379028039329\n",
      "0.0006753225248765649\n",
      "0.0006588254409225502\n",
      "0.000642914553711682\n",
      "0.0006275595539855239\n",
      "0.000612731828080193\n",
      "0.0005984043756386281\n",
      "0.0005845517355467234\n",
      "0.0005711499183786994\n",
      "0.0005581763436834919\n",
      "0.0005456097808138605\n",
      "0.0005334302919983713\n",
      "0.0005216191767176692\n",
      "0.0005101589167735288\n",
      "0.0004990331212120225\n",
      "0.0004882264711523199\n",
      "0.00047772466406392766\n",
      "0.00046751435769309957\n",
      "0.0004575831136903741\n",
      "0.00044791934120325116\n",
      "0.0004385122407848385\n",
      "0.0004293517489004639\n",
      "0.0004204284835234162\n",
      "0.00041173369106006774\n",
      "0.0004032591951146358\n",
      "0.0003949973473198476\n",
      "0.00038694098056102355\n",
      "0.00037908336481762284\n",
      "0.00037141816578634135\n",
      "0.0003639394064953727\n",
      "0.00035664143180103025\n",
      "0.0003495188759925494\n",
      "0.00034256663333187554\n",
      "0.00033577983148291857\n",
      "0.0003291538078078471\n",
      "0.00032268408822133665\n",
      "0.000316366368663612\n",
      "0.0003101964988172501\n",
      "0.0003041704680177837\n",
      "0.0002982843930814383\n",
      "0.000292534507940978\n",
      "0.00028691715479323143\n",
      "0.0002814287767278767\n",
      "0.0002760659114651176\n",
      "0.000270825186255097\n",
      "0.000265703313633292\n",
      "0.0002606970879552861\n",
      "0.00025580338255681845\n",
      "0.00025101914752978516\n",
      "0.00024634140783730274\n",
      "0.0002417672618386657\n",
      "0.00023729388011872743\n",
      "0.0002329185044387394\n",
      "0.00022863844695875102\n",
      "0.0002244510894975349\n",
      "0.00022035388291419267\n",
      "0.0002163443465119652\n",
      "0.00021242006746824416\n",
      "0.000208578700262807\n",
      "0.000204817966149351\n",
      "0.00020113565245827303\n",
      "0.00019752961203156616\n",
      "0.0001939977624916267\n",
      "0.0001905380855302674\n",
      "0.00018714862607027705\n",
      "0.00018382749150203104\n",
      "0.00018057285068118212\n",
      "0.00017738293307356656\n",
      "0.00017425602765497317\n",
      "0.00017119048191438502\n",
      "0.00016818470067314628\n",
      "0.00016523714497562736\n",
      "0.00016234633080802752\n",
      "0.00015951082789844584\n",
      "0.0001567292583348756\n",
      "0.0001540002952917785\n",
      "0.00015132266158923713\n",
      "0.0001486951282998472\n",
      "0.00014611651331408737\n",
      "0.0001435856798353008\n",
      "0.00014110153492197242\n",
      "0.00013866302796849972\n",
      "0.00013626914917508337\n",
      "0.00013391892803160665\n",
      "0.00013161143180129287\n",
      "0.00012934576395484676\n",
      "0.0001271210626927477\n",
      "0.00012493649938116747\n",
      "0.00012279127712133686\n",
      "0.00012068462918946032\n",
      "0.000118615817648271\n",
      "0.00011658413186843575\n",
      "0.00011458888713344884\n",
      "0.00011262942328538195\n",
      "0.00011070510334643124\n",
      "0.00010881531226547558\n",
      "0.00010695945561289832\n",
      "0.00010513695836045223\n",
      "0.00010334726374328085\n",
      "0.00010158983205243999\n",
      "9.986413961660112e-05\n"
     ]
    }
   ],
   "source": [
    "from ppca import PPCA\n",
    "\n",
    "# Fit probabilistic PCA with d=128 on the merged features, cast to float.\n",
    "# verbose=True makes fit() print progress values (the stream output of this\n",
    "# cell).\n",
    "feature_matrix = df.values.astype(float)\n",
    "ppca = PPCA()\n",
    "ppca.fit(data=feature_matrix, d=128, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.10027779, 0.15507462, 0.19134567, 0.21945611, 0.24460353,\n",
       "       0.26845611, 0.29102592, 0.31095034, 0.32996193, 0.34643802,\n",
       "       0.3620572 , 0.37633324, 0.38997795, 0.40261482, 0.41483703,\n",
       "       0.42643099, 0.4368385 , 0.44636099, 0.45568047, 0.46451587,\n",
       "       0.47289856, 0.480961  , 0.48863717, 0.49602041, 0.50310459,\n",
       "       0.51003156, 0.51656425, 0.52289624, 0.52908073, 0.53498392,\n",
       "       0.54078847, 0.54642035, 0.55186093, 0.55712029, 0.56207776,\n",
       "       0.56691434, 0.57167964, 0.57630872, 0.58075463, 0.58516137,\n",
       "       0.58940848, 0.59361722, 0.597649  , 0.60165605, 0.60551434,\n",
       "       0.60932604, 0.61301356, 0.61662949, 0.62021721, 0.62374785,\n",
       "       0.62722111, 0.63066072, 0.63403963, 0.63737077, 0.64063838,\n",
       "       0.64385491, 0.64703044, 0.65019842, 0.65329135, 0.65630261,\n",
       "       0.65926582, 0.66220377, 0.66510903, 0.66796339, 0.67077644,\n",
       "       0.67351453, 0.67620537, 0.67883389, 0.68142576, 0.68400129,\n",
       "       0.6865571 , 0.68904796, 0.69152211, 0.69396168, 0.69638304,\n",
       "       0.69873471, 0.70108001, 0.70338456, 0.70567939, 0.70792041,\n",
       "       0.71015812, 0.71239188, 0.71457543, 0.71672394, 0.71886385,\n",
       "       0.72098915, 0.72306577, 0.7251115 , 0.72713712, 0.72914898,\n",
       "       0.73113489, 0.73309689, 0.73505694, 0.73698951, 0.73889228,\n",
       "       0.74079208, 0.74265585, 0.74450515, 0.74634531, 0.7481664 ,\n",
       "       0.74995547, 0.75172597, 0.75347283, 0.75519582, 0.75690472,\n",
       "       0.75859582, 0.76027799, 0.7619423 , 0.7635889 , 0.76521979,\n",
       "       0.76682305, 0.76841543, 0.77000291, 0.77157517, 0.77313363,\n",
       "       0.77467194, 0.77619465, 0.77770459, 0.77920782, 0.78069985,\n",
       "       0.78217445, 0.78363201, 0.78507767, 0.78651602, 0.78793485,\n",
       "       0.78934381, 0.79074488, 0.79214315])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the fitted model's `var_exp` array (presumably the cumulative\n",
    "# explained variance per component -- TODO confirm against the ppca package).\n",
    "ppca.var_exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Int64Index([     1,      2,      3,      4,      5,      6,      7,      8,\n",
       "                 9,     10,\n",
       "            ...\n",
       "            131241, 131243, 131248, 131250, 131252, 131254, 131256, 131258,\n",
       "            131260, 131262],\n",
       "           dtype='int64', length=27278)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index labels of the merged frame; they are used as the dictionary keys when\n",
    "# `films_dict` is built.\n",
    "df.index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "\n",
    "# transform() is called with no argument -- presumably it projects the data\n",
    "# that was passed to fit(); confirm against the ppca package.\n",
    "transformed = ppca.transform()\n",
    "\n",
    "# Pair each projected row with the index label at the same position and store\n",
    "# it as a float32 tensor.\n",
    "films_dict = {\n",
    "    label: torch.tensor(row).float()\n",
    "    for label, row in zip(df.index, transformed)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the embeddings dictionary. The context manager closes the file even\n",
    "# if pickling fails, so the handle is never left open and the data is flushed\n",
    "# to disk before the cell finishes.\n",
    "with open('../../../../data/embeddings/ml20_pca128.pkl', 'wb') as embeddings_file:\n",
    "    pickle.dump(films_dict, embeddings_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}