examples/0. Embeddings Generation/1. (proof of concept) DQN.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN\n",
"P.S. This is not my code.\n",
"I mainly copied it from higgsfield's RL-Adventure\n",
"\n",
"This is a proof of concept for embeddings generation. Although it is possible to generate new embeddings using transfer learning, I would advise you to use it only for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import torch_optimizer as optim"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"# Prefer the GPU, but fall back to the CPU so that `.to(device)` does not fail on machines without CUDA.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# ---\n",
"frame_size = 10   # items (and their ratings) that make up one state\n",
"batch_size = 10   # NOTE: reassigned to 25 in the data-loader cell before any loader is built\n",
"embed_dim = 128   # width of each item embedding vector\n",
"# --- \n",
"\n",
"# Register tqdm's progress-bar hooks on pandas objects.\n",
"tqdm.pandas()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done!\n"
]
}
],
"source": [
"# https://drive.google.com/open?id=1kTyu05ZmtP2MA33J5hWdX8OyUYEDW4iI\n",
"# download ml20m dataset yourself\n",
"import gc  # needed for gc.collect() below; no other cell imports it, so a fresh kernel raised NameError\n",
"\n",
"ratings = pd.read_csv('../../data/ml-20m/ratings.csv')\n",
"\n",
"# Map every raw movieId to a contiguous index 0..len(keys)-1 (sorted order), usable as an embedding row.\n",
"keys = sorted(ratings['movieId'].unique())\n",
"key_to_id = {key: idx for idx, key in enumerate(keys)}\n",
"user_dict, users = recnn.data.prepare_dataset(ratings, key_to_id, frame_size)\n",
"\n",
"# The raw frame is no longer needed once the per-user data is built; release its memory.\n",
"del ratings\n",
"gc.collect()\n",
"clear_output(True)\n",
"print('Done!')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class DuelDQN(nn.Module):\n",
"    \"\"\"Dueling Q-network: a shared feature layer feeding separate value and advantage streams.\n",
"\n",
"    Args:\n",
"        input_dim: size of the flattened state vector.\n",
"        action_dim: number of discrete actions; the network emits one Q-value per action.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, action_dim):\n",
"        super(DuelDQN, self).__init__()\n",
"        self.feature = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())\n",
"        self.advantage = nn.Sequential(nn.Linear(128, 128), nn.ReLU(),\n",
"                                       nn.Linear(128, action_dim))\n",
"        self.value = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1))\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return Q-values of shape (..., action_dim) for states of shape (..., input_dim).\"\"\"\n",
"        x = self.feature(x)\n",
"        advantage = self.advantage(x)\n",
"        value = self.value(x)\n",
"        # Centre the advantages per state, over the action axis only. A global .mean() would\n",
"        # also average over the batch, making each sample's Q-values depend on the other\n",
"        # samples that happen to share its batch.\n",
"        return value + advantage - advantage.mean(dim=-1, keepdim=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def dqn_update(step, batch, params, learn=True):\n",
"    \"\"\"Run one DQN step on a batch: optimise when learn is true, otherwise only evaluate.\n",
"\n",
"    Relies on notebook globals: device, embeddings, dqn, target_dqn, writer,\n",
"    embeddings_optimizer and value_optimizer.\n",
"\n",
"    Args:\n",
"        step: global step, used as the TensorBoard x-axis value.\n",
"        batch: seven tensors unpacked as (items, next_items, ratings, next_ratings,\n",
"            action, reward, done).\n",
"        params: dict of hyper-parameters; only params['gamma'] is read here.\n",
"        learn: when true, backpropagate, step both optimizers and log to 'value/train';\n",
"            when false, log the Q-value histogram and 'value/test' without updating.\n",
"\n",
"    Returns:\n",
"        The mean squared TD error of the batch as a Python float.\n",
"    \"\"\"\n",
"    batch = [i.to(device) for i in batch]\n",
"    items, next_items, ratings, next_ratings, action, reward, done = batch\n",
"    b_size = items.size(0)\n",
"\n",
"    # Only build an autograd graph when we are actually going to backpropagate.\n",
"    with torch.set_grad_enabled(bool(learn)):\n",
"        # State = flattened item embeddings of the frame, concatenated with the ratings.\n",
"        state = torch.cat([embeddings(items).view(b_size, -1), ratings], 1)\n",
"        q_values = dqn(state)\n",
"        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)\n",
"\n",
"    # The bootstrap target never receives gradients.\n",
"    with torch.no_grad():\n",
"        next_state = torch.cat([embeddings(next_items).view(b_size, -1), next_ratings], 1)\n",
"        next_q_value = target_dqn(next_state).max(1)[0]\n",
"        # NOTE(review): assumes `done` is a numeric 0/1 tensor (not bool) - confirm against recnn.\n",
"        expected_q_value = reward + params['gamma'] * next_q_value * (1 - done)\n",
"\n",
"    loss = (q_value - expected_q_value).pow(2).mean()\n",
"\n",
"    if learn:\n",
"        writer.add_scalar('value/train', loss, step)\n",
"        embeddings_optimizer.zero_grad()\n",
"        value_optimizer.zero_grad()\n",
"        loss.backward()\n",
"        # The previous call, clip_grad_norm_(params, -1, 1), meant max_norm=-1 and norm_type=1:\n",
"        # a negative max_norm gives a negative scale factor, which flips the sign of every\n",
"        # gradient. Clamp each gradient element to [-1, 1] instead.\n",
"        torch.nn.utils.clip_grad_value_(dqn.parameters(), 1.0)\n",
"        embeddings_optimizer.step()\n",
"        value_optimizer.step()\n",
"    else:\n",
"        writer.add_histogram('q_values', q_values, step)\n",
"        writer.add_scalar('value/test', loss, step)\n",
"\n",
"    return loss.item()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def run_tests():\n",
"    \"\"\"Evaluate on one batch drawn from test_dataloader and return its loss.\n",
"\n",
"    Reads the notebook globals step, params and test_dataloader; no weights are updated\n",
"    because dqn_update is called with learn=False.\n",
"    \"\"\"\n",
"    held_out_batch = next(iter(test_dataloader))\n",
"    return dqn_update(step, held_out_batch, params, learn=False)\n",
"\n",
"def soft_update(net, target_net, soft_tau=1e-2):\n",
"    \"\"\"Polyak-average the weights of net into target_net, in place.\n",
"\n",
"    Each target parameter becomes (1 - soft_tau) * target + soft_tau * source;\n",
"    the parameters of net are left untouched.\n",
"    \"\"\"\n",
"    keep = 1.0 - soft_tau\n",
"    for param, target_param in zip(net.parameters(), target_net.parameters()):\n",
"        blended = target_param.data * keep + param.data * soft_tau\n",
"        target_param.data.copy_(blended)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# === DQN settings ===\n",
"\n",
"params = dict(\n",
"    gamma=0.99,          # discount factor applied to the bootstrapped next-state value\n",
"    value_lr=1e-5,       # learning rate of the optimizer over the DQN weights\n",
"    embeddings_lr=1e-5,  # learning rate of the optimizer over the item embeddings\n",
")\n",
"\n",
"# === end ==="
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# One state = frame_size item embeddings plus one rating per item.\n",
"state_dim = (embed_dim + 1) * frame_size\n",
"n_items = len(keys)\n",
"\n",
"dqn = DuelDQN(state_dim, n_items).to(device)\n",
"target_dqn = DuelDQN(state_dim, n_items).to(device)\n",
"embeddings = nn.Embedding(n_items, embed_dim).to(device)\n",
"# Warm-start from the saved embeddings. map_location remaps the stored tensors onto the\n",
"# current device, so a checkpoint written on another GPU (or loaded on CPU) still loads.\n",
"embeddings.load_state_dict(\n",
"    torch.load('../../models/embeddings/dqn.pt', map_location=device))\n",
"# The target network starts as an exact copy of the online network and is never trained\n",
"# directly; it only tracks the online weights through soft_update.\n",
"target_dqn.load_state_dict(dqn.state_dict())\n",
"target_dqn.eval()\n",
"\n",
"value_optimizer = optim.RAdam(dqn.parameters(),\n",
"                              lr=params['value_lr'])\n",
"embeddings_optimizer = optim.RAdam(embeddings.parameters(),\n",
"                                   lr=params['embeddings_lr'])\n",
"writer = SummaryWriter(log_dir='../../runs')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "741d4f1a60274ec48703b381861d84ac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"n_epochs = 100\n",
"batch_size = 25\n",
"\n",
"epoch_bar = tqdm(total=n_epochs)\n",
"# Hold out the last 5000 users for evaluation; all earlier users are used for training.\n",
"train_users, test_users = users[:-5000], users[-5000:]\n",
"\n",
"def prepare_batch_wrapper(x):\n",
"    \"\"\"collate_fn for both loaders: delegate to recnn with the notebook's frame_size.\"\"\"\n",
"    return recnn.data.prepare_batch_static_size(x, frame_size=frame_size)\n",
"\n",
"# Both loaders share every setting except the dataset they draw from.\n",
"loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=4,\n",
"                      collate_fn=prepare_batch_wrapper)\n",
"\n",
"train_user_dataset = recnn.data.UserDataset(train_users, user_dict)\n",
"test_user_dataset = recnn.data.UserDataset(test_users, user_dict)\n",
"train_dataloader = DataLoader(train_user_dataset, **loader_options)\n",
"test_dataloader = DataLoader(test_user_dataset, **loader_options)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 510\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-609ee27765a3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'step'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mmem_usage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_memory_allocated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0mtest_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_tests\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0mtest_step\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mtest_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-11-0cd1c2b49881>\u001b[0m in \u001b[0;36mrun_tests\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_tests\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mtest_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdqn_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-5-fbde6b67f8e7>\u001b[0m in \u001b[0;36mdqn_update\u001b[0;34m(step, batch, params, learn)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mvalue_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_histogram\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'q_values'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'value/test'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py\u001b[0m in \u001b[0;36madd_histogram\u001b[0;34m(self, tag, values, global_step, bins, walltime, max_bins)\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[0mbins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault_bins\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 375\u001b[0m self._get_file_writer().add_summary(\n\u001b[0;32m--> 376\u001b[0;31m histogram(tag, values, bins, max_bins=max_bins), global_step, walltime)\n\u001b[0m\u001b[1;32m 377\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 378\u001b[0m def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/summary.py\u001b[0m in \u001b[0;36mhistogram\u001b[0;34m(name, values, bins, max_bins)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_clean_tag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mhist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_histogram\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_bins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mSummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSummary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhisto\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/summary.py\u001b[0m in \u001b[0;36mmake_histogram\u001b[0;34m(values, bins, max_bins)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'The input has no element.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m \u001b[0mcounts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlimits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistogram\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbins\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 155\u001b[0m \u001b[0mnum_bins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcounts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmax_bins\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnum_bins\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mmax_bins\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/numpy/lib/histograms.py\u001b[0m in \u001b[0;36mhistogram\u001b[0;34m(a, bins, range, normed, weights, density)\u001b[0m\n\u001b[1;32m 863\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 864\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_range\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBLOCK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 865\u001b[0;31m \u001b[0msa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mBLOCK\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 866\u001b[0m \u001b[0mcum_n\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0m_search_sorted_inclusive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbin_edges\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 867\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36msort\u001b[0;34m(a, axis, kind, order)\u001b[0m\n\u001b[1;32m 932\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 933\u001b[0m \u001b[0ma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"K\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 934\u001b[0;31m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkind\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkind\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 935\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 936\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# --- config ---\n",
"plot_every = 30  # every N steps: evaluate on a test batch, record memory, redraw the curves\n",
"# --- end ---\n",
"\n",
"step = 1\n",
"\n",
"train_loss = []\n",
"test_loss = []\n",
"test_step = []\n",
"mem_usage = []\n",
"\n",
"# CUDA bookkeeping only applies when the models actually live on a GPU.\n",
"on_cuda = device.type == 'cuda'\n",
"if on_cuda:\n",
"    torch.cuda.empty_cache()\n",
"    torch.cuda.reset_max_memory_allocated()\n",
"\n",
"for epoch in range(n_epochs):\n",
"    for batch in tqdm(train_dataloader):\n",
"        loss = dqn_update(step, batch, params)\n",
"        train_loss.append(loss)\n",
"        step += 1\n",
"        # The previous guard `if step % 30:` was truthy on every step EXCEPT multiples of 30,\n",
"        # so the target network was updated (and the CUDA cache flushed) on 29 of every 30\n",
"        # steps. Track the online network on every step, and keep the comparatively\n",
"        # expensive cache flush for the periodic block below.\n",
"        # NOTE(review): if a 30-step target update period was intended, move this call\n",
"        # into the `plot_every` block instead.\n",
"        soft_update(dqn, target_dqn)\n",
"        if step % plot_every == 0:\n",
"            if on_cuda:\n",
"                torch.cuda.empty_cache()\n",
"                mem_usage.append(torch.cuda.max_memory_allocated())\n",
"            clear_output(True)\n",
"            print('step', step)\n",
"            test_ = run_tests()\n",
"            test_step.append(step)\n",
"            test_loss.append(test_)\n",
"            plt.plot(train_loss, label='train')\n",
"            plt.plot(test_step, test_loss, label='test')\n",
"            plt.xlabel('step')\n",
"            plt.ylabel('TD loss')\n",
"            plt.legend()\n",
"            plt.show()\n",
"    # Tick once the epoch has actually finished, so the bar counts completed epochs.\n",
"    epoch_bar.update(1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[-0.1295 1.16935 -0.48958 1.15725 -1.12578 -1.12871 1.20861\n",
" 0.7092 -1.75433 0.78336 -0.06654 0.01701 0.10032 -1.66869\n",
" -1.66936 1.31066 -1.89035 0.85733 0.29467 -2.57335 1.09865\n",
" -0.37118 0.28674 -0.03325 -0.04116 0.31573 0.17564 -0.85095\n",
" -0.72411 1.78047 -0.08764 -1.82273 -1.7149 0.42091 -1.9658\n",
" -0.37696 -0.99364 -0.02163 0.59692 0.43212 0.84377 -1.03082\n",
" -1.10816 1.07965 -0.51705 0.18891 0.54048 -0.62935 0.76171\n",
" -1.03486 -1.33935 -0.72902 1.07903 -0.77892 -0.06467 0.37879\n",
" 0.3149 0.83153 -0.73904 -0.77743 0.61622 0.25018 -1.14094\n",
" -0.60113 -0.77198 0.43308 -0.02435 1.53947 1.10815 0.55461\n",
" -0.59295 0.30839 1.29305 -1.54144 -0.4722 0.42438 -1.42892\n",
" 0.30123 -1.38345 0.89604 -0.46374 1.01234 0.79323 -1.8172\n",
" 1.76816 1.31062 -0.77374 0.85284 0.22765 0.91387 0.40847\n",
" 0.27434 0.98777 1.06231 0.05701 1.17557 0.20852 -0.32834\n",
" -0.11413 -0.94112 1.02541 0.38978 0.86436 -0.7692 -0.03691\n",
" 0.327 2.11549 -0.13522 0.45396 -0.1976 0.3953 -1.70671\n",
" -1.22622 -0.18697 -0.14552 -0.66944 -1.26578 1.10903 -1.47451\n",
" -0.58925 1.12494 -0.70216 -0.97481 0.29219 1.33927 1.24211\n",
" 0.18459 0.22939]]]\n"
]
}
],
"source": [
"# Print the learned vector for item index 686 with 5 decimals and no scientific notation.\n",
"np.set_printoptions(precision=5, suppress=True)\n",
"sample_item = torch.tensor([[686]]).to(device)\n",
"print(embeddings(sample_item).detach().cpu().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained item embeddings. This is the same file the model-setup cell loads,\n",
"# so it is overwritten in place.\n",
"embeddings_path = \"../../models/embeddings/dqn.pt\"\n",
"torch.save(embeddings.state_dict(), embeddings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f94280c1710>]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEDCAYAAADOc0QpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFQhJREFUeJzt3X2QXXV5wPHvk2yChBejECPyIuAgaq1R2AKKBZEOFWtLrTh9ob5QmEwchtJOnYL9w5naOq2j06pjJZNBUaeibZFUx9EUx2pxWmC6KSAISDEgxGCzQUFkN7l77336xz277K6b7A2555zcs9/PTGbvOeeXm4fLeZ788pzfPScyE0lSsyyrOwBJ0uBZ3CWpgSzuktRAFndJaiCLuyQ1kMVdkhqo1uIeEZ+OiJ0RcU8fY18cEd+MiO9GxLcj4rgqYpSkYVT3zP0zwJv6HPsR4HOZ+SrgA8DflBWUJA27Wot7Zt4C/GT2voh4SURsiYitEfGdiHhZcegVwDeL198CLqowVEkaKnXP3BeyCbgyM08H3gt8sth/F/C24vVbgSMi4qga4pOkg95I3QHMFhGHA68D/iUipncfUvx8L/CJiHg3cAvwI6BddYySNAwOquJO718ST2Tmq+cfyMwdwO/AzF8Cb8vMJyuOT5KGwkHVlsnMnwEPRcTbAaJnXfH66IiYjvd9wKdrClOSDnp1L4X8AnArcGpEbI+Iy4BLgMsi4i7gezxz4fQNwPcj4gFgLfDBGkKWpKEQ/dzyNyJWA9cBrwQS+KPMvHXW8UuAq4vNnwPvycy7Bh+uJKkf/fbcPwZsycyLI2IlsGre8YeAczPzpxFxIb0VL2cOME5J0n5YdOYeEUfSW4Z4cvYxzY+I5wH3ZOax+xp39NFH54knnrgfoUqStm7duisz1yw2rp+Z+8nAOHB9cXFzK3BVZj69l/GXAV9f6EBErAfWA5xwwgmMjY318cdLkqZFxA/7GdfPBdUR4DTg2sx8DfA0cM1e/tDz6BX3qxc6npmbMnM0M0fXrFn0Lx5J0rPUT3HfDmzPzNuL7RvpFfs5IuJV9C66XpSZjw8uREnS/lq0uGfmj4FHI+LUYtf5wL2zx0TECcBNwDsy84GBRylJ2i/9rpa5Evh8sVJmG3BpRGwAyMyNwPuBo4BPFrcNaGfmaAnxSpL60Fdxz8w7gfnFeuOs45cDlw8wLknSATiobj8gSRoMi7skNdDBdldILSG3/uBxbv3BrrrDkCo3euLzOeel5S4Ht7irNn+75X7uevQJnrl1v7Q0bDj3JRZ3NdfuVoc3/dIL2fiO0+sORWoce+6qTavTZeWIp6BUBjNLtWm1Le5SWcws1WaPxV0qjZml2rTaHVYu9xSUymBmqTatTpdDnLlLpTCzVBt77lJ5zCzVot3p0k1sy0glMbNUiz3tLoAzd6kkZpZq0bK4S6Uys1SLVsfiLpXJzFItZmbu9tylUphZqoU9d6lcZpZqMT1zd527VA4zS7Ww5y6Vy8xSLZ6ZuS+vORKpmSzuqoVLIaVymVmqRavTAVwtI5XFzFItnLlL5eorsyJidUTcGBH3R8R9EfHaeccjIj4eEQ9GxHcj4rRywlVTuBRSKle/z1D9GLAlMy+OiJXAqnnHLwROKX6dCVxb/JQW5JeYpHItmlkRcSRwDvApgMxsZeYT84ZdBHwue24DVkfEMQOPVo0xvRTSde5SOfrJrJOBceD6iLgjIq6LiMPmjTkWeHTW9vZi3xwRsT4ixiJibHx8/FkHreFnz10qVz+ZNQKcBlybma8BngaumTcmFvh9+Qs7Mjdl5mhmjq5Zs2a/g1Vz2HOXytVPZm0Htmfm7cX2jfSK/fwxx8/aPg7YceDhqansuUvlWjSzMvPHwKMRcWqx63zg3nnDvgK8s1g1cxbwZGY+NthQ1SStdpdlASMWd6kU/a6WuRL4fLFSZhtwaURsAMjMjcDXgDcDDw
ITwKUlxKoGaXV8fqpUpr6Ke2beCYzO271x1vEErhhgXGq4VrtrS0YqkdmlWuxpd1npTcOk0ljcVYtWu+sad6lEZpdqYc9dKpfZpVq02h177lKJzC7VotXucsgKTz+pLGaXatHquFpGKpPZpVq02vbcpTKZXaqFxV0ql9mlWuzxS0xSqcwu1cKZu1Qus0u12GNxl0pldqkWrY7fUJXKZHapFt44TCqX2aVa2HOXymV2qRbeW0Yql9mlynW6SaebrFzuLX+lsljcVbmWD8eWSmd2qXIWd6l8Zpcqt6fTASzuUpnMLlVueuZ+iEshpdKYXaqcbRmpfGaXKtfqFDN3i7tUmpF+BkXEw8BTQAdoZ+bovOPPBf4ROKF4z49k5vWDDVVN4cxdKl9fxb1wXmbu2suxK4B7M/M3I2IN8P2I+Hxmtg48RDWNxV0q36CyK4EjIiKAw4GfAO0BvbcaZs90cfeCqlSafrMrgZsjYmtErF/g+CeAlwM7gLuBqzKzO39QRKyPiLGIGBsfH3/WQWu4OXOXytdvdp2dmacBFwJXRMQ5847/OnAn8CLg1cAnIuLI+W+SmZsyczQzR9esWXMgcWuI7bG4S6XrK7syc0fxcyewGThj3pBLgZuy50HgIeBlgwxUzeFqGal8i2ZXRBwWEUdMvwYuAO6ZN+wR4PxizFrgVGDbYENVU8y0ZbxxmFSaflbLrAU2966VMgLckJlbImIDQGZuBP4K+ExE3A0EcPU+VtZoibPnLpVv0eKemduAdQvs3zjr9Q56M3ppUa2295aRymZ2qXLTPXeLu1Qes0uVa7nOXSqd2aXKTRf3Fcuj5kik5rK4q3J7iuenFhfpJZXA4q7Ktdpd7+UulcwMU+Va7a4XU6WSmWGqnMVdKp8Zpsq1Ol1vPSCVzAxT5fZMOXOXymaGqXKtjsVdKpsZpsq12l2/wCSVzAxT5bygKpXPDFPlel9i8na/Upks7qqcbRmpfGaYKtdqd1wKKZXMDFPlXC0jlc8MU+Vsy0jlM8NUOVfLSOUzw1Q5i7tUPjNMlbPnLpXPDFOlut1kqpP23KWSmWGqlA/HlqphhqlSe4rnp7rOXSrXSD+DIuJh4CmgA7Qzc3SBMW8APgqsAHZl5rmDC1NNMf1wbGfuUrn6Ku6F8zJz10IHImI18EngTZn5SES8YCDRqXGm2zLO3KVyDSrD/gC4KTMfAcjMnQN6XzWMM3epGv1mWAI3R8TWiFi/wPGXAs+LiG8XY9650JtExPqIGIuIsfHx8Wcbs4bYTHFf7l0hpTL125Y5OzN3FO2Wb0TE/Zl5y7z3OR04HzgUuDUibsvMB2a/SWZuAjYBjI6O5oGHr2HjzF2qRl8Zlpk7ip87gc3AGfOGbAe2ZObTRV/+FmDdIANVM7Q6HcDiLpVt0QyLiMMi4ojp18AFwD3zhn0Z+NWIGImIVcCZwH2DDlbDb89MW8biLpWpn7bMWmBzREyPvyEzt0TEBoDM3JiZ90XEFuC7QBe4LjPn/wUg2ZaRKrJocc/MbSzQYsnMjfO2Pwx8eHChqYlafolJqoQZpkp5+wGpGmaYKtWy5y5VwgxTpey5S9Uww1Qp2zJSNcwwVcqZu1QNM0yVcp27VA0zTJWyuEvVMMNUqVa7y4rlwbJlUXcoUqNZ3FWpVrvLISPeEVIqm8VdlWp1Ol5MlSpglqlSrXbXfrtUAbNMlWq1u87cpQqYZapUq2Nxl6pglqlStmWkaphlqtQe2zJSJcwyVcqeu1QNs0yVanW6PqhDqoBZpkrZc5eqYZapUrZlpGqYZaqUSyGlaphlqtSeKdsyUhXMMlXKmbtUjb6yLCIejoi7I+LOiBjbx7hfiYhORFw8uBDVJPbcpWqM7MfY8zJz194ORsRy4EPAvx1wVGosi7tUjUFm2ZXAl4CdA3xPNUhm9ta523OXStdvliVwc0RsjYj18w9GxLHAW4GN+3qTiFgfEWMRMTY+Pr7/0WqotTq9R+wdssKHdUhl67ctc3
Zm7oiIFwDfiIj7M/OWWcc/ClydmZ2IvT8+LTM3AZsARkdH89kGrXJdct1t3LvjZwN/3+n/4X5DVSpfX8U9M3cUP3dGxGbgDGB2cR8FvlgU9qOBN0dEOzP/dcDxqmTdbvKfDz7OuuNXs+645w78/ZcvCy785WMG/r6S5lq0uEfEYcCyzHyqeH0B8IHZYzLzpFnjPwN81cI+nHa3OwBc+MoXsuHcl9QcjaRnq5+Z+1pgczErHwFuyMwtEbEBIDP32WfXcJls9Yr7qpX2xaVhtmhxz8xtwLoF9i9Y1DPz3QceluoyURT353jRUxpqXtnSHJNTztylJrC4a47ptsyhztyloWZx1xzTbZlDnblLQ83irjl2z7Rl9ufOFJIONhZ3zTFhW0ZqBIu75photQEvqErDzuKuOabbMi6FlIabxV1zTPglJqkRLO6aY9KZu9QIFnfNMdnqcMjIMpYv2/vdPSUd/CzummOi1bElIzWAxV1zTE51XAYpNYDFXXNMtjp+O1VqAIu75picsrhLTWBx1xwTrTarVnjrAWnYWdw1x+RU15m71AAWd80x2Wp7QVVqAIu75nAppNQMFnfNsXuqw3Ms7tLQs7hrjolWh1W2ZaShZ3HXjMxkcsq2jNQEFnfN2NPukoltGakBLO6aMXO7X9sy0tDr69sqEfEw8BTQAdqZOTrv+CXA1cXmz4H3ZOZdA4xTFZi+3a/r3KXhtz9fRTwvM3ft5dhDwLmZ+dOIuBDYBJx5wNGpUpPFI/YO9eHY0tAbSBZn5n/N2rwNOG4Q76tqTba6gG0ZqQn67bkncHNEbI2I9YuMvQz4+kIHImJ9RIxFxNj4+Pj+xKkKTMzM3C3u0rDrd+Z+dmbuiIgXAN+IiPsz85b5gyLiPHrF/fULvUlmbqLXsmF0dDSfZcwqyYQ9d6kx+pq5Z+aO4udOYDNwxvwxEfEq4Drgosx8fJBBqhq7i9Uy3ltGGn6LFveIOCwijph+DVwA3DNvzAnATcA7MvOBMgJV+WaWQjpzl4ZeP22ZtcDmiJgef0NmbomIDQCZuRF4P3AU8Mli3C8sl9TBb2YppDN3aegtWtwzcxuwboH9G2e9vhy4fLChqWqTLXvuUlP4DVXNmLDnLjWGxV0zJqc6rFy+jJHlnhbSsDOLNWOy1bYlIzWExV0zJqc6tmSkhrC4a4aP2JOaw+KuGbunOrZlpIawuGvGRMu2jNQUFnfNmGg5c5eawuKuGbu9oCo1hsVdM7ygKjWHxV0zJqc6PoVJagiLu2ZMekFVagyLuwDITCZabdsyUkNY3AVAq9Olm94RUmoKi7uAWbf7tS0jNYLFXcCsB3U4c5caweIuwEfsSU1jcRdgW0ZqGou7ANsyUtNY3AXYlpGaxuIu4Jm2zHNsy0iNYHEXAJNTbQBWefsBqRH6Ku4R8XBE3B0Rd0bE2ALHIyI+HhEPRsR3I+K0wYeqMk22uoBtGakp9meadl5m7trLsQuBU4pfZwLXFj81JCZavZm7bRmpGQbVlrkI+Fz23AasjohjBvTeqsCkF1SlRum3uCdwc0RsjYj1Cxw/Fnh01vb2Yt8cEbE+IsYiYmx8fHz/o1VpJqc6jCwLViz3MozUBP1m8tmZeRq99ssVEXHOvOOxwO/JX9iRuSkzRzNzdM2aNfsZqsrkI/akZumruGfmjuLnTmAzcMa8IduB42dtHwfsGESAqsbuKZ/CJDXJosU9Ig6LiCOmXwMXAPfMG/YV4J3FqpmzgCcz87GBR6vSTPigDqlR+lktsxbYHBHT42/IzC0RsQEgMzcCXwPeDDwITACXlhOuytJry7jGXWqKRbM5M7cB6xbYv3HW6wSuGGxoqtLuqQ6HrvBiqtQUZrMAikfsOXOXmsLiLgAmp7qulpEaxOIuACZbbS+oSg1icRfQu6DqUkipOSzuAnrfUPW+MlJzWNwF9O4t48xdag6Lu5jqdGl305671CAWd808Ys/VMlJzWNw163a/rnOXmsLiLianpmfung
5SU5jNmnkK06ErnLlLTTF02fwfD4zz11+9t+4wGmV325671DRDV9wPP2SEU9YeXncYjXPmSUfx6uNX1x2GpAEZuuJ++oufx+kvPr3uMCTpoGbPXZIayOIuSQ1kcZekBrK4S1IDWdwlqYEs7pLUQBZ3SWogi7skNVBkZj1/cMQ48MNn+duPBnYNMJxh5efgZzDNz2HpfAYvzsw1iw2qrbgfiIgYy8zRuuOom5+Dn8E0Pwc/g/lsy0hSA1ncJamBhrW4b6o7gIOEn4OfwTQ/Bz+DOYay5y5J2rdhnblLkvbB4i5JDTR0xT0i3hQR34+IByPimrrjqUJEHB8R34qI+yLiexFxVbH/+RHxjYj43+Ln8+qOtQoRsTwi7oiIrxbbJ0XE7cXn8E8RsbLuGMsUEasj4saIuL84J167FM+FiPjTIh/uiYgvRMRzltq5sC9DVdwjYjnwD8CFwCuA34+IV9QbVSXawJ9l5suBs4Ariv/ua4BvZuYpwDeL7aXgKuC+WdsfAv6++Bx+ClxWS1TV+RiwJTNfBqyj91ksqXMhIo4F/hgYzcxXAsuB32PpnQt7NVTFHTgDeDAzt2VmC/gicFHNMZUuMx/LzP8pXj9FL5mPpfff/tli2GeB364nwupExHHAbwDXFdsBvBG4sRjS6M8hIo4EzgE+BZCZrcx8giV4LtB7TOihETECrAIeYwmdC4sZtuJ+LPDorO3txb4lIyJOBF4D3A6szczHoPcXAPCC+iKrzEeBPwe6xfZRwBOZ2S62m35OnAyMA9cXranrIuIwlti5kJk/Aj4CPEKvqD8JbGVpnQv7NGzFPRbYt2TWckbE4cCXgD/JzJ/VHU/VIuItwM7M3Dp79wJDm3xOjACnAddm5muAp2l4C2YhxTWFi4CTgBcBh9Fr187X5HNhn4atuG8Hjp+1fRywo6ZYKhURK+gV9s9n5k3F7v+LiGOK48cAO+uKryJnA78VEQ/Ta8m9kd5MfnXxT3No/jmxHdiembcX2zfSK/ZL7Vz4NeChzBzPzCngJuB1LK1zYZ+Grbj/N3BKcUV8Jb0LKF+pOabSFX3lTwH3ZebfzTr0FeBdxet3AV+uOrYqZeb7MvO4zDyR3v/7f8/MS4BvARcXwxr9OWTmj4FHI+LUYtf5wL0ssXOBXjvmrIhYVeTH9OewZM6FxQzdN1Qj4s30ZmvLgU9n5gdrDql0EfF64DvA3TzTa/4Len33fwZOoHeyvz0zf1JLkBWLiDcA783Mt0TEyfRm8s8H7gD+MDP31BlfmSLi1fQuKK8EtgGX0puoLalzISL+EvhdeqvJ7gAup9djXzLnwr4MXXGXJC1u2NoykqQ+WNwlqYEs7pLUQBZ3SWogi7skNZDFXZIayOIuSQ30/7dk1tTBsaK6AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Peak CUDA memory (torch.cuda.max_memory_allocated, in bytes), sampled once per plot_every steps.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(mem_usage)\n",
"ax.set(title='Peak CUDA memory during training',\n",
"       xlabel='checkpoint (one per plot_every steps)',\n",
"       ylabel='max memory allocated (bytes)');"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5626054144"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}