examples/0. Embeddings Generation/Pipelines/ML20M/5. The Big Merge.ipynb
{
"cells": [
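{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Big Merge\n",
"\n",
"Concatenates the RoBERTa, MCA and PCA feature tables by index and compresses the result to 128 dimensions with probabilistic PCA (PPCA). Other merge methods will be added soon."
]
},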
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the three pre-computed feature tables and index each one by its id column.\n",
"roberta = pd.read_csv('../../../../data/engineering/roberta.csv').set_index('idx')\n",
"\n",
"# The id column of mca.csv is addressed by position (first column) rather than by name.\n",
"cat = pd.read_csv('../../../../data/engineering/mca.csv')\n",
"cat = cat.set_index(cat.columns[0])\n",
"\n",
"num = pd.read_csv('../../../../data/engineering/pca.csv').set_index('idx')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Read the ML-20M links table.\n",
"# NOTE(review): `movies` is not referenced by any later cell of this notebook - confirm it is still needed.\n",
"movies = pd.read_csv('../../../../data/ml-20m/links.csv')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Place the three feature tables side by side, aligning rows on their index labels.\n",
"# pd.concat defaults to an outer join, so a label missing from one table yields NaN in that table's columns.\n",
"df = pd.concat(\n",
"    [roberta, cat, num],\n",
"    axis='columns',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dev/.local/lib/python3.7/site-packages/ppca/_ppca.py:82: RuntimeWarning: divide by zero encountered in log\n",
" det = np.log(np.linalg.det(Sx))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"0.3197493112573442\n",
"0.08642294395895767\n",
"0.011613807065083748\n",
"0.006629248893893269\n",
"0.011409862478935606\n",
"0.8079228072819928\n",
"0.06312485654055533\n",
"0.05411146255289068\n",
"0.04663160316817749\n",
"0.04029385825484155\n",
"0.03489849054666294\n",
"0.03031603746989564\n",
"0.02643809621940929\n",
"0.023162995698678968\n",
"0.0203959192899823\n",
"0.018052295018768483\n",
"0.016059788417409626\n",
"0.014358180687624955\n",
"0.012897967368266539\n",
"0.011638581618058197\n",
"0.010546746675690777\n",
"0.009595126661316566\n",
"0.008761272289783184\n",
"0.008026793373915764\n",
"0.007376680908001365\n",
"0.006798720521605572\n",
"0.0062829681706675355\n",
"0.005821283576526337\n",
"0.005406929219866408\n",
"0.005034242882371531\n",
"0.0046983850262487525\n",
"0.004395154395472112\n",
"0.0041208596504158646\n",
"0.003872232710471879\n",
"0.003646370274946076\n",
"0.0034406925665739774\n",
"0.003252911535067904\n",
"0.003081003730827092\n",
"0.0029231853657245566\n",
"0.0027778886027165495\n",
"0.0026437389509543774\n",
"0.002519533975680055\n",
"0.0024042235610552964\n",
"0.002296891856018224\n",
"0.002196740901195815\n",
"0.002103075838896906\n",
"0.0020152915670275107\n",
"0.0019328607029456268\n",
"0.0018553227544089168\n",
"0.001782274431118891\n",
"0.0017133610573070168\n",
"0.0016482690561274715\n",
"0.0015867194714918043\n",
"0.0015284624761753296\n",
"0.0014732727944593016\n",
"0.0014209459480452047\n",
"0.0013712952201179185\n",
"0.001324149226606064\n",
"0.0012793499847434386\n",
"0.0012367513771092131\n",
"0.0011962179212254842\n",
"0.0011576237692405567\n",
"0.001120851876855733\n",
"0.0010857932949968063\n",
"0.0010523465494305384\n",
"0.001020417084160119\n",
"0.0009899167520883712\n",
"0.0009607633424089101\n",
"0.0009328801386421226\n",
"0.0009061955037479308\n",
"0.0008806424908571753\n",
"0.000856158479084268\n",
"0.0008326848342841142\n",
"0.0008101665949125092\n",
"0.000788552183035085\n",
"0.0007677931401675053\n",
"0.0007478438876611371\n",
"0.0007286615106072425\n",
"0.0007102055644578886\n",
"0.0006924379028039329\n",
"0.0006753225248765649\n",
"0.0006588254409225502\n",
"0.000642914553711682\n",
"0.0006275595539855239\n",
"0.000612731828080193\n",
"0.0005984043756386281\n",
"0.0005845517355467234\n",
"0.0005711499183786994\n",
"0.0005581763436834919\n",
"0.0005456097808138605\n",
"0.0005334302919983713\n",
"0.0005216191767176692\n",
"0.0005101589167735288\n",
"0.0004990331212120225\n",
"0.0004882264711523199\n",
"0.00047772466406392766\n",
"0.00046751435769309957\n",
"0.0004575831136903741\n",
"0.00044791934120325116\n",
"0.0004385122407848385\n",
"0.0004293517489004639\n",
"0.0004204284835234162\n",
"0.00041173369106006774\n",
"0.0004032591951146358\n",
"0.0003949973473198476\n",
"0.00038694098056102355\n",
"0.00037908336481762284\n",
"0.00037141816578634135\n",
"0.0003639394064953727\n",
"0.00035664143180103025\n",
"0.0003495188759925494\n",
"0.00034256663333187554\n",
"0.00033577983148291857\n",
"0.0003291538078078471\n",
"0.00032268408822133665\n",
"0.000316366368663612\n",
"0.0003101964988172501\n",
"0.0003041704680177837\n",
"0.0002982843930814383\n",
"0.000292534507940978\n",
"0.00028691715479323143\n",
"0.0002814287767278767\n",
"0.0002760659114651176\n",
"0.000270825186255097\n",
"0.000265703313633292\n",
"0.0002606970879552861\n",
"0.00025580338255681845\n",
"0.00025101914752978516\n",
"0.00024634140783730274\n",
"0.0002417672618386657\n",
"0.00023729388011872743\n",
"0.0002329185044387394\n",
"0.00022863844695875102\n",
"0.0002244510894975349\n",
"0.00022035388291419267\n",
"0.0002163443465119652\n",
"0.00021242006746824416\n",
"0.000208578700262807\n",
"0.000204817966149351\n",
"0.00020113565245827303\n",
"0.00019752961203156616\n",
"0.0001939977624916267\n",
"0.0001905380855302674\n",
"0.00018714862607027705\n",
"0.00018382749150203104\n",
"0.00018057285068118212\n",
"0.00017738293307356656\n",
"0.00017425602765497317\n",
"0.00017119048191438502\n",
"0.00016818470067314628\n",
"0.00016523714497562736\n",
"0.00016234633080802752\n",
"0.00015951082789844584\n",
"0.0001567292583348756\n",
"0.0001540002952917785\n",
"0.00015132266158923713\n",
"0.0001486951282998472\n",
"0.00014611651331408737\n",
"0.0001435856798353008\n",
"0.00014110153492197242\n",
"0.00013866302796849972\n",
"0.00013626914917508337\n",
"0.00013391892803160665\n",
"0.00013161143180129287\n",
"0.00012934576395484676\n",
"0.0001271210626927477\n",
"0.00012493649938116747\n",
"0.00012279127712133686\n",
"0.00012068462918946032\n",
"0.000118615817648271\n",
"0.00011658413186843575\n",
"0.00011458888713344884\n",
"0.00011262942328538195\n",
"0.00011070510334643124\n",
"0.00010881531226547558\n",
"0.00010695945561289832\n",
"0.00010513695836045223\n",
"0.00010334726374328085\n",
"0.00010158983205243999\n",
"9.986413961660112e-05\n"
]
}
],
"source": [
"import numpy as np\n",
"from ppca import PPCA\n",
"\n",
"# Seed NumPy's global RNG so repeated runs produce the same embeddings.\n",
"# NOTE(review): assumes PPCA draws its initial loadings from np.random - confirm against the installed ppca version.\n",
"np.random.seed(42)\n",
"\n",
"# Fit a 128-component probabilistic PCA on the merged features.\n",
"# verbose=True prints a convergence value for each iteration of the fit.\n",
"ppca = PPCA()\n",
"ppca.fit(data=df.values.astype(float), d=128, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.10027779, 0.15507462, 0.19134567, 0.21945611, 0.24460353,\n",
" 0.26845611, 0.29102592, 0.31095034, 0.32996193, 0.34643802,\n",
" 0.3620572 , 0.37633324, 0.38997795, 0.40261482, 0.41483703,\n",
" 0.42643099, 0.4368385 , 0.44636099, 0.45568047, 0.46451587,\n",
" 0.47289856, 0.480961 , 0.48863717, 0.49602041, 0.50310459,\n",
" 0.51003156, 0.51656425, 0.52289624, 0.52908073, 0.53498392,\n",
" 0.54078847, 0.54642035, 0.55186093, 0.55712029, 0.56207776,\n",
" 0.56691434, 0.57167964, 0.57630872, 0.58075463, 0.58516137,\n",
" 0.58940848, 0.59361722, 0.597649 , 0.60165605, 0.60551434,\n",
" 0.60932604, 0.61301356, 0.61662949, 0.62021721, 0.62374785,\n",
" 0.62722111, 0.63066072, 0.63403963, 0.63737077, 0.64063838,\n",
" 0.64385491, 0.64703044, 0.65019842, 0.65329135, 0.65630261,\n",
" 0.65926582, 0.66220377, 0.66510903, 0.66796339, 0.67077644,\n",
" 0.67351453, 0.67620537, 0.67883389, 0.68142576, 0.68400129,\n",
" 0.6865571 , 0.68904796, 0.69152211, 0.69396168, 0.69638304,\n",
" 0.69873471, 0.70108001, 0.70338456, 0.70567939, 0.70792041,\n",
" 0.71015812, 0.71239188, 0.71457543, 0.71672394, 0.71886385,\n",
" 0.72098915, 0.72306577, 0.7251115 , 0.72713712, 0.72914898,\n",
" 0.73113489, 0.73309689, 0.73505694, 0.73698951, 0.73889228,\n",
" 0.74079208, 0.74265585, 0.74450515, 0.74634531, 0.7481664 ,\n",
" 0.74995547, 0.75172597, 0.75347283, 0.75519582, 0.75690472,\n",
" 0.75859582, 0.76027799, 0.7619423 , 0.7635889 , 0.76521979,\n",
" 0.76682305, 0.76841543, 0.77000291, 0.77157517, 0.77313363,\n",
" 0.77467194, 0.77619465, 0.77770459, 0.77920782, 0.78069985,\n",
" 0.78217445, 0.78363201, 0.78507767, 0.78651602, 0.78793485,\n",
" 0.78934381, 0.79074488, 0.79214315])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Explained-variance ratios of the fitted components; the displayed array is cumulative (monotonically increasing).\n",
"ppca.var_exp"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Int64Index([ 1, 2, 3, 4, 5, 6, 7, 8,\n",
" 9, 10,\n",
" ...\n",
" 131241, 131243, 131248, 131250, 131252, 131254, 131256, 131258,\n",
" 131260, 131262],\n",
" dtype='int64', length=27278)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the row labels of the merged frame; a later cell uses them as the keys of the embedding dict.\n",
"df.index"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import torch\n",
"\n",
"# Called without arguments; presumably returns the projection of the data passed to fit() - confirm.\n",
"transformed = ppca.transform()\n",
"\n",
"# zip() silently truncates to the shorter input, so make sure every index label has a row.\n",
"assert len(df.index) == transformed.shape[0], 'row count mismatch between df and PPCA output'\n",
"\n",
"# Map each index label of df to its embedding row as a float32 tensor.\n",
"films_dict = {\n",
"    film_id: torch.tensor(row).float()\n",
"    for film_id, row in zip(df.index, transformed)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# Persist the embedding dict; the context manager flushes and closes the file even if dump raises.\n",
"with open('../../../../data/embeddings/ml20_pca128.pkl', 'wb') as embeddings_file:\n",
"    pickle.dump(films_dict, embeddings_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}