examples/99.To be released, but working/4. SearchNet/1. DDPG_SN.ipynb
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## Deep TopK Search with Critic Adjustment"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
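{
"cell_type": "markdown",
"source": [
"This notebook trains a standard DDPG actor-critic on ML-20m movie embeddings and initializes an extra `SearchK` network that is meant to propose top-K candidate actions conditioned on the actor's output, re-ranked (\"adjusted\") by the critic. Note that the `SearchK` loss is not wired into the update yet; a hedged sketch of one possible version is included after `ddpg_sn_update`."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},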
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from abc import ABC\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import torch.nn.functional as F\n",
"import torch_optimizer as optim\n",
"\n",
"from tqdm.auto import tqdm\n",
"\n",
"from IPython.display import clear_output\n",
"%matplotlib inline\n",
"\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"cuda = torch.device('cuda')\n",
"\n",
"# ---\n",
"frame_size = 10\n",
"batch_size = 25\n",
"n_epochs = 100\n",
"plot_every = 30\n",
"step = 0\n",
"# ---\n",
"\n",
"tqdm.pandas()\n",
"\n",
"from jupyterthemes import jtplot\n",
"jtplot.style(theme='grade3')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
"dirs = recnn.data.env.DataPath(\n",
" base=\"../../data/\",\n",
" embeddings=\"embeddings/ml20_pca128.pkl\",\n",
" ratings=\"ml-20m/ratings.csv\",\n",
" cache=\"cache/frame_env.pkl\", # cache will generate after you run\n",
" use_cache=True\n",
")\n",
"env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
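{
"cell_type": "markdown",
"source": [
"`FrameEnv` slices each user's history into frames of `frame_size` items; the state is built from the 128-dim PCA embeddings and the ratings of the last 10 items, which matches the 1290-dim input used by the networks below (10 * 128 + 10 = 1290)."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},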
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"class Actor(nn.Module):\n",
" def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-1):\n",
" super(Actor, self).__init__()\n",
"\n",
" self.drop_layer = nn.Dropout(p=0.5)\n",
"\n",
" self.linear1 = nn.Linear(input_dim, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear3 = nn.Linear(hidden_size, action_dim)\n",
"\n",
" self.linear3.weight.data.uniform_(-init_w, init_w)\n",
" self.linear3.bias.data.uniform_(-init_w, init_w)\n",
"\n",
" def forward(self, state):\n",
" # state = self.state_rep(state)\n",
" x = F.relu(self.linear1(state))\n",
" x = self.drop_layer(x)\n",
" x = F.relu(self.linear2(x))\n",
" x = self.drop_layer(x)\n",
" # x = torch.tanh(self.linear3(x)) # in case embeds are -1 1 normalized\n",
" x = self.linear3(x) # in case embeds are standard scaled / wiped using PCA whitening\n",
" # return state, x\n",
" return x"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"class Critic(nn.Module):\n",
" def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-5):\n",
" super(Critic, self).__init__()\n",
"\n",
" self.drop_layer = nn.Dropout(p=0.5)\n",
"\n",
" self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear3 = nn.Linear(hidden_size, 1)\n",
"\n",
" self.linear3.weight.data.uniform_(-init_w, init_w)\n",
" self.linear3.bias.data.uniform_(-init_w, init_w)\n",
"\n",
" def forward(self, state, action):\n",
" x = torch.cat([state, action], 1)\n",
" x = F.relu(self.linear1(x))\n",
" x = self.drop_layer(x)\n",
" x = F.relu(self.linear2(x))\n",
" x = self.drop_layer(x)\n",
" x = self.linear3(x)\n",
" return x"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"class SearchK(nn.Module):\n",
" def __init__(self, input_dim, action_dim, hidden_size, topK, init_w=3e-1):\n",
" super(SearchK, self).__init__()\n",
"\n",
" self.drop_layer = nn.Dropout(p=0.5)\n",
" self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear3 = nn.Linear(hidden_size, action_dim*topK)\n",
"\n",
" self.linear3.weight.data.uniform_(-init_w, init_w)\n",
" self.linear3.bias.data.uniform_(-init_w, init_w)\n",
"\n",
" def forward(self, state, action):\n",
" x = torch.cat([state, action], 1)\n",
" x = F.relu(self.linear1(x))\n",
" x = self.drop_layer(x)\n",
" x = F.relu(self.linear2(x))\n",
" x = self.drop_layer(x)\n",
" x = self.linear3(x)\n",
" return x"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def soft_update(net, target_net, soft_tau=1e-2):\n",
" for target_param, param in zip(target_net.parameters(), net.parameters()):\n",
" target_param.data.copy_(\n",
" target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
" )\n",
"\n",
"def run_tests():\n",
" test_batch = next(iter(env.test_dataloader))\n",
" losses = ddpg_sn_update(test_batch, params, learn=False, step=step)\n",
"\n",
" gen_actions = debug['next_action']\n",
" true_actions = env.base.embeddings.detach().cpu().numpy()\n",
"\n",
" f = plotter.kde_reconstruction_error(ad, gen_actions, true_actions, cuda)\n",
" writer.add_figure('rec_error',f, losses['step'])\n",
" return losses"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
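{
"cell_type": "markdown",
"source": [
"`soft_update` is Polyak averaging of the target network parameters:\n",
"\n",
"$$\\theta_{target} \\leftarrow \\tau\\,\\theta + (1 - \\tau)\\,\\theta_{target}$$\n",
"\n",
"Calling it with `soft_tau=1.0` (as done during initialization below) reduces to a hard copy of the weights."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},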
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def ddpg_sn_update(batch, params, learn=True, step=-1):\n",
"\n",
" state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
"\n",
" # --------------------------------------------------------#\n",
" # Value Learning\n",
"\n",
" with torch.no_grad():\n",
" next_action = target_policy_net(next_state)\n",
" target_value = target_value_net(next_state, next_action.detach())\n",
" expected_value = reward + (1.0 - done) * params['gamma'] * target_value\n",
" expected_value = torch.clamp(expected_value,\n",
" params['min_value'], params['max_value'])\n",
"\n",
" value = value_net(state, action)\n",
"\n",
" value_loss = torch.pow(value - expected_value.detach(), 2).mean()\n",
"\n",
" if learn:\n",
" value_optimizer.zero_grad()\n",
" value_loss.backward()\n",
" value_optimizer.step()\n",
" else:\n",
" debug['next_action'] = next_action\n",
" writer.add_figure('next_action',\n",
" recnn.utils.pairwise_distances_fig(next_action[:50]), step)\n",
" writer.add_histogram('value', value, step)\n",
" writer.add_histogram('target_value', target_value, step)\n",
" writer.add_histogram('expected_value', expected_value, step)\n",
"\n",
" # --------------------------------------------------------#\n",
" # Policy learning\n",
"\n",
" gen_action = policy_net(state)\n",
" policy_loss = -value_net(state, gen_action)\n",
"\n",
" if not learn:\n",
" debug['gen_action'] = gen_action\n",
" writer.add_histogram('policy_loss', policy_loss, step)\n",
" writer.add_figure('next_action',\n",
" recnn.utils.pairwise_distances_fig(gen_action[:50]), step)\n",
"\n",
" policy_loss = policy_loss.mean()\n",
"\n",
" if learn and step % params['policy_step']== 0:\n",
" policy_optimizer.zero_grad()\n",
" policy_loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(policy_net.parameters(), -1, 1)\n",
" policy_optimizer.step()\n",
"\n",
" soft_update(value_net, target_value_net, soft_tau=params['soft_tau'])\n",
" soft_update(policy_net, target_policy_net, soft_tau=params['soft_tau'])\n",
"\n",
" # dont forget search loss here !\n",
" losses = {'value': value_loss.item(), 'policy': policy_loss.item(), 'step': step}\n",
" recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')\n",
" return losses"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
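{
"cell_type": "markdown",
"source": [
"`ddpg_sn_update` leaves the search loss as a TODO. Below is a hedged sketch (an assumption, not the released implementation) of one way to train `SearchK`: reshape its flat output into `topK` candidate actions and ascend the critic's value over them, mirroring the policy loss. The unused `search_criterion = nn.MSELoss()` defined later may hint at a supervised variant instead. The helper name `search_update` and the reshaping convention are hypothetical."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Hypothetical sketch: train SearchK by maximizing the critic's value of its\n",
"# K candidate actions. Assumes SearchK's flat output is K stacked actions.\n",
"def search_update(batch, topK=10, action_dim=128):\n",
"    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
"\n",
"    candidates = search_net(state, action)               # (B, K * action_dim)\n",
"    candidates = candidates.view(-1, topK, action_dim)   # (B, K, action_dim)\n",
"\n",
"    # score every candidate with the critic by tiling the state K times\n",
"    flat_state = state.unsqueeze(1).expand(-1, topK, -1).reshape(-1, state.size(-1))\n",
"    q_values = value_net(flat_state, candidates.reshape(-1, action_dim))\n",
"\n",
"    search_loss = -q_values.mean()  # ascend critic value\n",
"\n",
"    # only the search net's optimizer steps; the critic is left untouched\n",
"    search_optimizer.zero_grad()\n",
"    search_loss.backward()\n",
"    search_optimizer.step()\n",
"\n",
"    soft_update(search_net, target_search_net, soft_tau=params['soft_tau'])\n",
"    return search_loss.item()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},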
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# === ddpg settings ===\n",
"\n",
"params = {\n",
" 'gamma' : 0.99,\n",
" 'min_value' : -10,\n",
" 'max_value' : 10,\n",
" 'policy_step': 10,\n",
" 'soft_tau' : 0.001,\n",
"\n",
" 'policy_lr' : 1e-5,\n",
" 'value_lr' : 1e-5,\n",
" 'search_lr' : 1e-5,\n",
" 'actor_weight_init': 54e-2,\n",
" 'search_weight_init': 54e-2,\n",
" 'critic_weight_init': 6e-1,\n",
"}\n",
"\n",
"# === end ==="
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"value_net = Critic(1290, 128, 256, params['critic_weight_init']).to(cuda)\n",
"policy_net = Actor(1290, 128, 256, params['actor_weight_init']).to(cuda)\n",
"search_net = SearchK(1290, 128, 2048, topK=10, init_w=params['search_weight_init']).to(cuda)\n",
"\n",
"target_value_net = Critic(1290, 128, 256).to(cuda)\n",
"target_policy_net = Actor(1290, 128, 256).to(cuda)\n",
"target_search_net = SearchK(1290, 128, 2048, topK=10).to(cuda)\n",
"\n",
"ad = recnn.nn.models.AnomalyDetector().to(cuda)\n",
"ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n",
"ad.eval()\n",
"\n",
"target_policy_net.eval()\n",
"target_value_net.eval()\n",
"\n",
"soft_update(value_net, target_value_net, soft_tau=1.0)\n",
"soft_update(policy_net, target_policy_net, soft_tau=1.0)\n",
"soft_update(search_net, target_search_net, soft_tau=1.0)\n",
"\n",
"value_criterion = nn.MSELoss()\n",
"search_criterion = nn.MSELoss()\n",
"\n",
"# from good to bad: Ranger Radam Adam RMSprop\n",
"value_optimizer = optim.Ranger(value_net.parameters(),\n",
" lr=params['value_lr'], weight_decay=1e-2)\n",
"policy_optimizer = optim.Ranger(policy_net.parameters(),\n",
" lr=params['policy_lr'], weight_decay=1e-5)\n",
"search_optimizer = optim.Ranger(search_net.parameters(),\n",
" weight_decay=1e-5,\n",
" lr=params['search_lr'])\n",
"\n",
"loss = {\n",
" 'test': {'value': [], 'policy': [], 'search': [], 'step': []},\n",
" 'train': {'value': [], 'policy': [], 'search': [], 'step': []}\n",
" }\n",
"\n",
"debug = {}\n",
"\n",
"writer = SummaryWriter(log_dir='../../runs')\n",
"plotter = recnn.utils.Plotter(loss, [['value', 'policy', 'search']],)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
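{
"cell_type": "markdown",
"source": [
"A hypothetical inference-time sketch of the \"critic adjustment\" from the title: let `SearchK` expand the actor's action into K candidates, then rank the candidates by their critic values. The helper name `topk_with_critic` is an assumption, not part of recnn."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Hypothetical sketch: generate K candidate actions and re-rank them with the critic.\n",
"def topk_with_critic(state, topK=10, action_dim=128):\n",
"    with torch.no_grad():\n",
"        action = policy_net(state)                       # (B, action_dim)\n",
"        cands = search_net(state, action).view(-1, topK, action_dim)\n",
"        flat_state = state.unsqueeze(1).expand(-1, topK, -1).reshape(-1, state.size(-1))\n",
"        q = value_net(flat_state, cands.reshape(-1, action_dim)).view(-1, topK)\n",
"        order = q.argsort(dim=1, descending=True)        # best candidate first\n",
"    return cands, order"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},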
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"for epoch in range(n_epochs):\n",
" for batch in tqdm(env.train_dataloader):\n",
" loss = ddpg_sn_update(batch, params, step=step)\n",
" plotter.log_losses(loss)\n",
" step += 1\n",
" if step % plot_every == 0:\n",
" clear_output(True)\n",
" print('step', step)\n",
" test_loss = run_tests()\n",
" plotter.log_losses(test_loss, test=True)\n",
" plotter.plot_loss()\n",
" if step > 1000:\n",
" assert False"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"torch.save(value_net.state_dict(), \"../../models/ddpg_value.pt\")\n",
"torch.save(policy_net.state_dict(), \"../../models/ddpg_policy.pt\")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# Reconstruction error"
],
"metadata": {
"collapsed": false
}
},
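{
"cell_type": "markdown",
"source": [
"The anomaly detector is an autoencoder trained on real action embeddings; comparing the reconstruction-error densities of generated vs. true actions shows whether the policy stays on the embedding manifold. Heavily overlapping KDEs are good; a shifted generated-action curve signals off-distribution actions."
],
"metadata": {
"collapsed": false
}
},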
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"gen_actions = debug['next_action']\n",
"true_actions = env.base.embeddings.numpy()\n",
"\n",
"\n",
"ad = recnn.nn.AnomalyDetector().to(cuda)\n",
"ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n",
"ad.eval()\n",
"\n",
"plotter.plot_kde_reconstruction_error(ad, gen_actions, true_actions, cuda)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}