examples/0. Embeddings Generation/Pipelines/ML20M/4. Graph Embeddings.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import json\n",
"import pickle\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from sklearn import preprocessing\n",
"from torch_geometric.data import Data, DataLoader\n",
"from torch_geometric.nn import GCNConv\n",
"from tqdm.auto import tqdm\n",
"\n",
"# == recnn ==\n",
"sys.path.append(\"../../../../\")\n",
"import recnn\n",
"\n",
"# Fall back to the CPU when CUDA is unavailable so the notebook still runs end to end.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the parsed OMDB metadata; the context manager closes the file handle promptly.\n",
"with open(\"../../../../data/parsed/omdb.json\", \"r\") as omdb_file:\n",
"    omdb = json.load(omdb_file)\n",
"\n",
"# Map every movie key to the list of actor names in its 'Actors' field.\n",
"# A record without that field gets an empty cast instead of crashing on\n",
"# False.split(), which is what the previous default of False led to.\n",
"movies = {}\n",
"for movie_key, record in omdb.items():\n",
"    cast = record['omdb'].get('Actors')\n",
"    movies[movie_key] = cast.split(', ') if cast is not None else []"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load the movie embeddings keyed by movie id; the context manager closes the file handle.\n",
"# NOTE(review): pickle.load can execute arbitrary code - only load files you trust.\n",
"with open('../../../../data/embeddings/ml20_pca128.pkl', 'rb') as embeddings_file:\n",
"    movie_embeddings_key_dict = pickle.load(embeddings_file)\n",
"\n",
"# Presumably returns the stacked embedding tensor plus key -> row and row -> key\n",
"# lookups (inferred from the names; confirm against recnn.data.make_items_tensor).\n",
"movies_embeddings_tensor, key_to_id, id_to_key = recnn.data.make_items_tensor(movie_embeddings_key_dict)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26db04c4ebbc437f9d132da7e5d32192",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=27278), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4459e6d52135469aaadab219f7d4e124",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=45583), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Flatten every movie's cast into one list of actor names (duplicates included).\n",
"actor_list = list(itertools.chain.from_iterable(movies.values()))\n",
"\n",
"# Give every distinct actor name a stable integer id. Sorting the de-duplicated\n",
"# names keeps the ids reproducible across runs (sorting first and then building\n",
"# a set, as before, throws the order away), and the dict makes each lookup O(1)\n",
"# instead of a linear list.index() scan per credit.\n",
"le = sorted(set(actor_list))\n",
"actor_to_id = {name: actor_id for actor_id, name in enumerate(le)}\n",
"\n",
"# Re-express each movie's cast as actor ids.\n",
"# NOTE: this overwrites `movies`, so re-running this cell requires re-running\n",
"# the OMDB loading cell first.\n",
"movies = {movie_key: [actor_to_id[name] for name in cast]\n",
"          for movie_key, cast in tqdm(movies.items())}\n",
"\n",
"# Invert the mapping: actor id -> keys of the movies that actor appears in.\n",
"actors = {}\n",
"for movie_key, cast_ids in movies.items():\n",
"    for actor_id in cast_ids:\n",
"        actors.setdefault(actor_id, []).append(movie_key)\n",
"\n",
"# Keep only actors credited in more than three movies.\n",
"actors = {actor_id: movie_keys for actor_id, movie_keys in tqdm(actors.items())\n",
"          if len(movie_keys) > 3}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2019-09-11 17:26:40,299 DEBUG:Loaded backend module://ipykernel.pylab.backend_inline version unknown.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x7fe2b82074e0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAUfElEQVR4nO3df6zd9X3f8eer5leFIwwltZixZrJ5WgisFO4IUqbqOrRgyB8mUiI5QolJqVxtoCUak2JaddAmSGRKghY1pXOGG6fJcsPyQ1hAxjyHK5Q/CODUwRiXcgtWaoywOozJTTo2Z+/9cT52jy/32uf+8LnnuM+HdHS+38/5nO95nS/XvPz9nu89TlUhSfqH7ZcWO4AkafFZBpIky0CSZBlIkrAMJEnAGYsd4EQuvPDCWrVqVc/zf/azn3HuueeeukCnyDDmHsbMYO5+G8bcw5gZjs+9c+fOv62qd85qA1U1sLerrrqqZuPxxx+f1fxBMYy5hzFzlbn7bRhzD2PmquNzA8/ULP9/62kiSZJlIEmyDCRJWAaSJCwDSRKWgSQJy0CShGUgScIykCQx4F9HMV+rNj1ybHnfvR9YxCSSNNg8MpAkWQaSJMtAkoRlIEnCMpAkYRlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJoocySHJOkqeS/DjJniR/2MYvSfLDJC8m+WaSs9r42W19oj2+qmtbd7bxF5Jcf6relCRpdno5MngLeH9V/RpwBbA2yTXAZ4H7qmo1cAi4tc2/FThUVf8UuK/NI8mlwHrgPcBa4E+SLFnINyNJmpuTlkF1TLbVM9utgPcD32rjW4Gb2vK6tk57/NokaeNjVfVWVb0MTABXL8i7kCTNS0+fGSRZkmQXcBDYDvw18EZVHWlT9gMr2vIK4G8A2uOHgV/pHp/mOZKkRdTTP25TVb8ArkiyDPgu8O7pprX7zPDYTOPHSbIR2AiwfPlyxsfHe4kIwOTk5HHz77j8yLHl2Wyn36bmHgbDmBnM3W/DmHsYM8P8c8/qXzqrqjeSjAPXAMuSnNH+9n8xcKBN2w+sBPYnOQM4D3i9a/yo7ud0v8ZmYDPAyMhIjY6O9pxvfHyc7vm3dP9LZzf3vp1+m5p7GAxjZjB3vw1j7mHMDPPP3cvVRO9sRwQk+WXgN4G9wOPAh9q0DcBDbXlbW6c9/v2qqja+vl1tdAmwGnhqzsklSQumlyODi4Ct7cqfXwIerKqHkzwPjCX5DPAXwANt/gPAnyeZoHNEsB6gqvYkeRB4HjgC3NZOP0mSFtlJy6CqngV+fZrxl5jmaqCq+t/Ah2fY1j3APbOPKUk6lfwNZEmSZSBJsgwkSVgGkiQsA0kSloEkCctAkoRlIEnCMpAkYRlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJwjKQJGEZSJKwDCRJWAaSJCwDSRKWgSQJy0CSRA9lkGRlkseT7E2yJ8kn2vjdSV5Jsqvdbux6zp1JJpK8kOT6rvG1bWwiyaZT85YkSbN1Rg9zjgB3VNWPkrwD2Jlke3vsvqr6XPfkJJcC64H3AP8I+J9J/ll7+EvAbwH7gaeTbKuq5xfijUiS5u6kZVBVrwKvtuWfJtkLrDjBU9YBY1X1FvBykgng6vbYRFW9BJBkrM21DCRpkaWqep+crAKeAC4D/h1wC/Am8Aydo4dDSf4YeLKqvtae8wDwvbaJtVX1O238o8B7q+r2Ka+xEdgIsHz58qvGxsZ6zjc5OcnSpUuPre9+5fCx5ctXnNfzdvptau5hMIyZwdz9Noy5hzEzHJ97zZo1O6tqZDbP7+U0EQBJlgLfBj5ZVW8muR/4NFDt/vPAbwOZ5unF9J9PvK2JqmozsBlgZGSkRkdHe43I+Pg43fNv2fTIseV9N/e+nX6bmnsYDGNmMHe/DWPuYcwM88/dUxkkOZNOEXy9qr4DUFWvdT3+ZeDhtrofWNn19IuBA2
15pnFJ0iLq5WqiAA8Ae6vqC13jF3VN+yDwXFveBqxPcnaSS4DVwFPA08DqJJckOYvOh8zbFuZtSJLmo5cjg/cBHwV2J9nVxn4P+EiSK+ic6tkH/C5AVe1J8iCdD4aPALdV1S8AktwOPAYsAbZU1Z4FfC+SpDnq5WqiHzD95wCPnuA59wD3TDP+6ImeJ0laHP4GsiTJMpAkWQaSJCwDSRKWgSQJy0CShGUgScIykCRhGUiSsAwkSVgGkiQsA0kSloEkCctAkoRlIEnCMpAkYRlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJoocySLIyyeNJ9ibZk+QTbfyCJNuTvNjuz2/jSfLFJBNJnk1yZde2NrT5LybZcOreliRpNno5MjgC3FFV7wauAW5LcimwCdhRVauBHW0d4AZgdbttBO6HTnkAdwHvBa4G7jpaIJKkxXXSMqiqV6vqR235p8BeYAWwDtjapm0FbmrL64CvVseTwLIkFwHXA9ur6vWqOgRsB9Yu6LuRJM1Jqqr3yckq4AngMuAnVbWs67FDVXV+koeBe6vqB218B/ApYBQ4p6o+08b/APi7qvrclNfYSOeIguXLl181NjbWc77JyUmWLl16bH33K4ePLV++4ryet9NvU3MPg2HMDObut2HMPYyZ4fjca9as2VlVI7N5/hm9TkyyFPg28MmqejPJjFOnGasTjB8/ULUZ2AwwMjJSo6OjvUZkfHyc7vm3bHrk2PK+m3vfTr9NzT0MhjEzmLvfhjH3MGaG+efu6WqiJGfSKYKvV9V32vBr7fQP7f5gG98PrOx6+sXAgROMS5IWWS9XEwV4ANhbVV/oemgbcPSKoA3AQ13jH2tXFV0DHK6qV4HHgOuSnN8+OL6ujUmSFlkvp4neB3wU2J1kVxv7PeBe4MEktwI/AT7cHnsUuBGYAH4OfBygql5P8mng6Tbvj6rq9QV5F5KkeTlpGbQPgmf6gODaaeYXcNsM29oCbJlNQEnSqedvIEuSLANJkmUgScIykCRhGUiSsAwkSVgGkiQsA0kSloEkCctAkoRlIEnCMpAkYRlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJwjKQJGEZSJKwDCRJWAaSJHoogyRbkhxM8lzX2N1JXkmyq91u7HrsziQTSV5Icn3X+No2NpFk08K/FUnSXPVyZPAVYO004/dV1RXt9ihAkkuB9cB72nP+JMmSJEuALwE3AJcCH2lzJUkD4IyTTaiqJ5Ks6nF764CxqnoLeDnJBHB1e2yiql4CSDLW5j4/68SSpAWXqjr5pE4ZPFxVl7X1u4FbgDeBZ4A7qupQkj8Gnqyqr7V5DwDfa5tZW1W/08Y/Cry3qm6f5rU2AhsBli9fftXY2FjPb2ZycpKlS5ceW9/9yuFjy5evOK/n7fTb1NzDYBgzg7n7bRhzD2NmOD73mjVrdlbVyGyef9IjgxncD3waqHb/eeC3gUwzt5j+dNS0LVRVm4HNACMjIzU6OtpzqPHxcbrn37LpkWPL+27ufTv9NjX3MBjGzGDufhvG3MOYGeafe05lUFWvHV1O8mXg4ba6H1jZNfVi4EBbnmlckrTI5nRpaZKLulY/CBy90mgbsD7J2UkuAVYDTwFPA6uTXJLkLDofMm+be2xJ0kI66ZFBkm8Ao8CFSfYDdwGjSa6gc6pnH/C7AFW1J8mDdD4YPgLcVlW/aNu5HXgMWAJsqao9C/5uJElz0svVRB+ZZviBE8y/B7hnmvFHgUdnlU6S1Bf+BrIkyTKQJFkGkiQsA0kSloEkCctAkoRlIEnCMpAkYRlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJwjKQJGEZSJKwDCRJWAaSJCwDSRKWgSQJy0CSRA9lkGRLkoNJnusauyDJ9iQvtvvz23iSfDHJRJJnk1zZ9ZwNbf6LSTacmrcjSZqLXo4MvgKsnTK2CdhRVauBHW0d4AZgdbttBO6HTnkAdwHvBa4G7jpaIJKkxXfSMq
iqJ4DXpwyvA7a25a3ATV3jX62OJ4FlSS4Crge2V9XrVXUI2M7bC0aStEhSVSeflKwCHq6qy9r6G1W1rOvxQ1V1fpKHgXur6gdtfAfwKWAUOKeqPtPG/wD4u6r63DSvtZHOUQXLly+/amxsrOc3Mzk5ydKlS4+t737l8LHly1ec1/N2+m1q7mEwjJnB3P02jLmHMTMcn3vNmjU7q2pkNs8/Y4HzZJqxOsH42werNgObAUZGRmp0dLTnFx8fH6d7/i2bHjm2vO/m3rfTb1NzD4NhzAzm7rdhzD2MmWH+ued6NdFr7fQP7f5gG98PrOyadzFw4ATjkqQBMNcy2AYcvSJoA/BQ1/jH2lVF1wCHq+pV4DHguiTntw+Or2tjkqQBcNLTREm+Qeec/4VJ9tO5Kuhe4MEktwI/AT7cpj8K3AhMAD8HPg5QVa8n+TTwdJv3R1U19UNpSdIiOWkZVNVHZnjo2mnmFnDbDNvZAmyZVTpJUl/4G8iSJMtAkmQZSJKwDCRJWAaSJBb+N5AH1qru30a+9wOLmESSBo9HBpIky0CSZBlIkrAMJElYBpIkLANJEpaBJAnLQJKEZSBJwjKQJGEZSJKwDCRJWAaSJCwDSRKWgSQJy0CShGUgSWKeZZBkX5LdSXYleaaNXZBke5IX2/35bTxJvphkIsmzSa5ciDcgSZq/hTgyWFNVV1TVSFvfBOyoqtXAjrYOcAOwut02AvcvwGtLkhbAqThNtA7Y2pa3Ajd1jX+1Op4EliW56BS8viRpllJVc39y8jJwCCjgP1fV5iRvVNWyrjmHqur8JA8D91bVD9r4DuBTVfXMlG1upHPkwPLly68aGxvrOc/k5CRLly49tr77lcPTzrt8xXk9b7MfpuYeBsOYGczdb8OYexgzw/G516xZs7PrbE1Pzpjn67+vqg4k+VVge5K/PMHcTDP2tiaqqs3AZoCRkZEaHR3tOcz4+Djd82/Z9Mj0E3f/7Njivns/0PP2T5WpuYfBMGYGc/fbMOYexsww/9zzOk1UVQfa/UHgu8DVwGtHT/+0+4Nt+n5gZdfTLwYOzOf1JUkLY85lkOTcJO84ugxcBzwHbAM2tGkbgIfa8jbgY+2qomuAw1X16pyTS5IWzHxOEy0Hvpvk6Hb+a1X99yRPAw8muRX4CfDhNv9R4EZgAvg58PF5vLYkaQHNuQyq6iXg16YZ/1/AtdOMF3DbXF9PknTqzPcD5KG3qutD5kH4MFmSFoNfRyFJsgwkSZaBJAnLQJKEZSBJwjKQJGEZSJKwDCRJWAaSJCwDSRKWgSQJy0CShGUgScJvLT2O32Aq6R8qjwwkSZaBJMkykCThZwYz6v78oJufJUg6HXlkIEmyDCRJniaaNS8/lXQ6sgzmwWKQdLroexkkWQv8J2AJ8F+q6t5+ZzgVLAZJw6yvZZBkCfAl4LeA/cDTSbZV1fP9zHGqeSWSpGHT7yODq4GJqnoJIMkYsA44rcpgJjOVxB2XH+GW9piFIWkx9LsMVgB/07W+H3hv94QkG4GNbXUyyQuz2P6FwN/OK+Ei+LddufPZRQ7Tu6Hc15i734Yx9zBmhuNz/+PZPrnfZZBpxuq4larNwOY5bTx5pqpG5vLcxTSMuYcxM5i734Yx9zBmhvnn7vfvGewHVnatXwwc6HMGSdIU/S6Dp4HVSS5JchawHtjW5wySpCn6epqoqo4kuR14jM6lpVuqas8CvsScTi8NgGHMPYyZwdz9Noy5hzEzzDN3qurksyRJpzW/m0iSZBlIkk6jMkiyNskLSSaSbFrsPDNJsi/J7iS7kjzTxi5Isj3Ji+3+/AHIuSXJwSTPdY1NmzMdX2z7/tkkVw5Y7ruTvNL2+a4kN3Y9dmfL/UKS6xcp88okjyfZm2RPkk+08YHe3yfIPej7+5wkTyX5ccv9h238kiQ/bPv7m+0iF5Kc3dYn2uOrBijzV5K83LWvr2jjs/8Zqaqhv9H5MPqvgXcBZwE/Bi5d7FwzZN0HXDhl7D
8Cm9ryJuCzA5DzN4ArgedOlhO4Efgend8juQb44YDlvhv499PMvbT9rJwNXNJ+hpYsQuaLgCvb8juAv2rZBnp/nyD3oO/vAEvb8pnAD9t+fBBY38b/FPjXbfnfAH/altcD3xygzF8BPjTN/Fn/jJwuRwbHvuaiqv4PcPRrLobFOmBrW94K3LSIWQCoqieA16cMz5RzHfDV6ngSWJbkov4kPd4MuWeyDhirqreq6mVggs7PUl9V1atV9aO2/FNgL53f1h/o/X2C3DMZlP1dVTXZVs9stwLeD3yrjU/d30f/O3wLuDbJdL9Ae8qcIPNMZv0zcrqUwXRfc3GiH8rFVMD/SLKzffUGwPKqehU6f8CAX120dCc2U85h2P+3t8PlLV2n4QYudzsF8et0/uY3NPt7Sm4Y8P2dZEmSXcBBYDudo5Q3qurINNmO5W6PHwZ+pb+J3565qo7u63vavr4vydlTMzcn3denSxmc9GsuBsj7qupK4AbgtiS/sdiBFsCg7//7gX8CXAG8Cny+jQ9U7iRLgW8Dn6yqN080dZqxQco98Pu7qn5RVVfQ+RaEq4F3Tzet3Q9E7qmZk1wG3An8c+BfAhcAn2rTZ535dCmDofmai6o60O4PAt+l84P42tFDuHZ/cPESntBMOQd6/1fVa+0P0v8Dvszfn5oYmNxJzqTzP9SvV9V32vDA7+/pcg/D/j6qqt4AxumcV1+W5Ogv4nZnO5a7PX4evZ+KXHBdmde2U3VVVW8Bf8Y89vXpUgZD8TUXSc5N8o6jy8B1wHN0sm5o0zYADy1OwpOaKec24GPtCoZrgMNHT28MginnSj9IZ59DJ/f6drXIJcBq4KlFyBfgAWBvVX2h66GB3t8z5R6C/f3OJMva8i8Dv0nn847HgQ+1aVP399H/Dh8Cvl/tU9p+mSHzX3b9ZSF0PuPo3tez+xnp96fip+pG59Pzv6Jz7u/3FzvPDBnfRedqih8De47mpHP+cQfwYru/YACyfoPOIf7/pfO3jFtnyknnkPRLbd/vBkYGLPeft1zPtj8kF3XN//2W+wXghkXK/K/oHMI/C+xqtxsHfX+fIPeg7+9/AfxFy/cc8B/a+LvolNME8N+As9v4OW19oj3+rgHK/P22r58DvsbfX3E0658Rv45CknTanCaSJM2DZSBJsgwkSZaBJAnLQJKEZSBJwjKQJAH/HwEpF6JIMmIVAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"# Distribution of how many movies each retained actor appears in.\n",
"# pandas is imported with the other dependencies in the first cell.\n",
"movies_per_actor = pd.Series([len(movie_keys) for movie_keys in actors.values()])\n",
"ax = movies_per_actor.hist(bins=100)\n",
"# Label the figure so it can be read on its own; the trailing ';' hides the repr.\n",
"ax.set(title='Movies per actor', xlabel='Number of movies', ylabel='Number of actors');"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9ff94bf4a5aa4e0cb493ce7b6b2e9fde",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=5701), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Connect two movies whenever they share an actor. permutations() emits both\n",
"# (a, b) and (b, a), so every link is stored as two directed edges; the set\n",
"# removes pairs produced by more than one shared actor.\n",
"edge_pairs = set()\n",
"for movie_keys in tqdm(actors.values()):\n",
"    # Translate movie keys to embedding row ids up front. Indexing key_to_id\n",
"    # raises a KeyError for a movie without an embedding, whereas .get() would\n",
"    # silently inject None and fail later with an opaque dtype error.\n",
"    movie_ids = [key_to_id[int(movie_key)] for movie_key in movie_keys]\n",
"    edge_pairs.update(itertools.permutations(movie_ids, 2))\n",
"\n",
"# Sort for a reproducible edge order and force int64 indices. The reshape keeps\n",
"# the (num_edges, 2) layout even when no edges were found.\n",
"edge_index = torch.from_numpy(np.array(sorted(edge_pairs), dtype=np.int64).reshape(-1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Build the regression target (the OMDB 'imdbRating' field) with one slot per\n",
"# embedding row, so that y[i] describes the same movie as\n",
"# movies_embeddings_tensor[i]. Taking the ratings in omdb.values() order, as\n",
"# before, ties y to the JSON file order and can yield a different length than x.\n",
"# `omdb` is reused from the loading cell instead of re-reading the file.\n",
"# Assumes key_to_id values are row indices into movies_embeddings_tensor, which\n",
"# is how they are already used when building edge_index.\n",
"y = np.full(movies_embeddings_tensor.shape[0], np.nan)\n",
"for movie_key, record in omdb.items():\n",
"    row = key_to_id.get(int(movie_key))\n",
"    if row is None:\n",
"        continue  # metadata exists but there is no embedding row for this movie\n",
"    rating = record['omdb'].get('imdbRating', np.nan)\n",
"    y[row] = float(rating) if rating != 'N/A' else np.nan\n",
"\n",
"# Impute missing ratings with the mean of the known ones.\n",
"y[np.isnan(y)] = y[~np.isnan(y)].mean()\n",
"# Use float32 so the target dtype matches a default-dtype model output in the loss.\n",
"y = torch.tensor(y, dtype=torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Assemble the graph: movie embeddings as node features, shared-actor links as\n",
"# edges, ratings as targets. edge_index is built as (num_edges, 2), so it is\n",
"# transposed to the (2, num_edges) layout before being handed to Data.\n",
"dataset = Data(x=movies_embeddings_tensor, edge_index=edge_index.t().contiguous(), y=y).to(device)\n",
"# Placeholders for a future train/test split (nothing is assigned yet):\n",
"# dataset.train_idx = torch.tensor([...], dtype=torch.long)\n",
"# dataset.test_mask = torch.tensor([...], dtype=torch.uint8)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Data(edge_index=[2, 975154], x=[27279, 128], y=[27278])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity-check the assembled graph: x and y should report the same number of nodes.\n",
"# NOTE(review): the saved output shows x with 27279 rows but y with 27278\n",
"# entries - confirm the two agree after re-running.\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# DataLoader iterates over a collection of graphs and indexes it with integers.\n",
"# A bare Data object only supports attribute-name indexing, which is what raised\n",
"# \"getattr(): attribute name must be string\"; wrap the single graph in a list.\n",
"dataloader = DataLoader([dataset], batch_size=32, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"class Net(torch.nn.Module):\n",
"    \"\"\"Two-layer graph convolutional network producing one output vector per node.\n",
"\n",
"    Args:\n",
"        in_channels: size of each node feature vector. Defaults to 128 to match\n",
"            the 128-dimensional movie embeddings used as node features; the\n",
"            previously hard-coded value of 1 cannot consume them.\n",
"        hidden_channels: width of the hidden layer.\n",
"        out_channels: number of values predicted per node.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=128, hidden_channels=16, out_channels=1):\n",
"        super().__init__()\n",
"        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
"        self.conv2 = GCNConv(hidden_channels, out_channels)\n",
"\n",
"    def forward(self, data):\n",
"        \"\"\"Run both graph convolutions over `data.x` using `data.edge_index`.\n",
"\n",
"        Returns the output of the second convolution without a final\n",
"        activation, which suits a regression target.\n",
"        \"\"\"\n",
"        x, edge_index = data.x, data.edge_index\n",
"\n",
"        x = self.conv1(x, edge_index)\n",
"        x = F.relu(x)\n",
"        # Dropout is only active while the module is in training mode.\n",
"        x = F.dropout(x, training=self.training)\n",
"        x = self.conv2(x, edge_index)\n",
"\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e8fe461a8a554e0b97070078c58026f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=200), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"ename": "TypeError",
"evalue": "getattr(): attribute name must be string",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-25-cb7ee50e0f9f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'a'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m#optimizer.zero_grad()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 344\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 346\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 347\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.local/lib/python3.7/site-packages/torch_geometric/data/data.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34mr\"\"\"Gets the data of the attribute :obj:`key`.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__setitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: getattr(): attribute name must be string"
]
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = Net().to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
"\n",
"# Run the optimisation step that was previously commented out behind a debug\n",
"# print. No train/test mask is defined yet, so the loss covers every node.\n",
"model.train()\n",
"progress = tqdm(range(200))\n",
"for epoch in progress:\n",
"    for batch in dataloader:\n",
"        batch = batch.to(device)\n",
"        optimizer.zero_grad()\n",
"        # Drop the trailing singleton dimension so predictions line up with y.\n",
"        out = model(batch).squeeze(-1)\n",
"        target = batch.y.to(out.dtype)\n",
"        # Fail loudly rather than let mse_loss broadcast mismatched shapes.\n",
"        if out.shape != target.shape:\n",
"            raise ValueError(\n",
"                f'prediction shape {tuple(out.shape)} does not match target shape {tuple(target.shape)}'\n",
"            )\n",
"        loss = F.mse_loss(out, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        # Show the latest loss on the progress bar instead of printing per step.\n",
"        progress.set_postfix(loss=loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}