awarebayes/RecNN

View on GitHub
examples/1. Vanilla RL/4. SAC.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Soft Actor Critic\n",
    "inspiration: [higgsfield/RL Adventure 2](https://github.com/higgsfield/RL-Adventure-2/blob/master/7.soft%20actor-critic.ipynb)\n",
    "\n",
    "paper https://arxiv.org/abs/1801.01290\n",
    "\n",
    "### Notice: this notebook does not work well, and I have not kept up with the development of this module for a long time. I am afraid to run it at this point because the code here is too deprecated."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Note on these tutorials:**\n",
    "**They mostly contain low level implementations explaining what is going on inside the library.**\n",
    "\n",
    "**Most of the stuff explained here is already available out of the box for your usage.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.nn.functional as F\n",
    "import torch_optimizer as optim\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# comment out if you are not using themes\n",
    "from jupyterthemes import jtplot\n",
    "jtplot.style(theme='grade3')\n",
    "\n",
    "# == recnn ==\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "# Fall back to the CPU when no GPU is visible, so the `.to(cuda)` calls in later\n",
    "# cells do not crash. The name `cuda` is kept because every later cell refers to it.\n",
    "cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# --- settings shared by the cells below ---\n",
    "frame_size = 10   # passed to recnn.data.env.FrameEnv\n",
    "batch_size = 25   # passed to recnn.data.env.FrameEnv\n",
    "n_epochs   = 100  # outer iterations of the training loop\n",
    "plot_every = 30   # run the test pass and redraw the loss plots every N steps\n",
    "step       = 0    # global step counter, incremented by the training loop\n",
    "# --- \n",
    "\n",
    "\n",
    "tqdm.pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f8d90b9e8214d8aa6986f45628d6d68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "19409b3356934d378074819563ac0484",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7725f391e0554753889ca1a6063c695a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
    "# Locations of the input files; presumably the relative entries are resolved\n",
    "# against `base` - confirm in recnn.data.env.DataPath.\n",
    "dirs = recnn.data.env.DataPath(\n",
    "    base=\"../../../data/\",\n",
    "    embeddings=\"embeddings/ml20_pca128.pkl\",\n",
    "    ratings=\"ml-20m/ratings.csv\",\n",
    "    cache=\"cache/frame_env.pkl\", # cache will generate after you run\n",
    "    use_cache=True\n",
    ")\n",
    "# Environment that serves the train/test dataloaders used by the cells below.\n",
    "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)\n",
    "\n",
    "# optional: apply tanh -1 1 to actions\n",
    "# This squashes every embedding component into (-1, 1), the same range as the\n",
    "# tanh output of StochasticActor.evaluate.\n",
    "# NOTE(review): later cells read `env.embeddings`, not `env.base.embeddings` -\n",
    "# confirm that both names refer to the same tensor.\n",
    "with torch.no_grad():\n",
    "    env.base.embeddings = torch.tanh(env.base.embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_update(net, target_net, soft_tau=1e-2):\n",
    "    \"\"\"Blend the weights of `net` into `target_net` in place.\n",
    "\n",
    "    Each target parameter becomes (1 - soft_tau) * target + soft_tau * source.\n",
    "    Parameters are paired positionally, in `parameters()` order; `net` is not modified.\n",
    "    With soft_tau=1.0 the target receives the source values.\n",
    "    \"\"\"\n",
    "    pairs = zip(target_net.parameters(), net.parameters())\n",
    "    for slow, fast in pairs:\n",
    "        blended = slow.data * (1.0 - soft_tau) + fast.data * soft_tau\n",
    "        slow.data.copy_(blended)\n",
    "\n",
    "# Pre-trained anomaly detector; run_tests() passes it to plotter.kde_reconstruction_error.\n",
    "ad = recnn.models.AnomalyDetector().to(cuda)\n",
    "# map_location remaps the stored tensors onto `cuda`, so a checkpoint saved on a\n",
    "# different device (for example another GPU index) still loads.\n",
    "ad.load_state_dict(torch.load('../../models/anomaly.pt', map_location=cuda))\n",
    "ad.eval()\n",
    "\n",
    "def run_tests():\n",
    "    \"\"\"Evaluate one test batch without learning and log a reconstruction-error figure.\n",
    "\n",
    "    Reads the module-level `step`, `params`, `env`, `debugger`, `plotter`, `ad`,\n",
    "    `writer` and `cuda`. Returns the loss dict produced by soft_q_update.\n",
    "    \"\"\"\n",
    "    test_batch = next(iter(env.test_dataloader))\n",
    "    # learn=False: no optimizer steps; soft_q_update logs 'test next_action' itself.\n",
    "    losses = soft_q_update(step, test_batch, params, learn=False)\n",
    "\n",
    "    # 'next_action' comes from the latest training step, 'test next_action' from the call above.\n",
    "    logged = debugger.debug_dict['mat']\n",
    "    train_actions = logged['next_action']\n",
    "    test_actions = logged['test next_action']\n",
    "    embeddings_np = env.embeddings.detach().cpu().numpy()\n",
    "    figure = plotter.kde_reconstruction_error(\n",
    "        ad, train_actions, test_actions, np.stack(embeddings_np), cuda\n",
    "    )\n",
    "    writer.add_figure('rec error', figure, losses['step'])\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StateCritic(nn.Module):\n",
    "    \"\"\"Three-layer MLP that maps a state vector to a single value estimate.\"\"\"\n",
    "\n",
    "    def __init__(self, state_dim, hidden_dim, init_w=3e-3):\n",
    "        \"\"\"Build the layers; the output layer is re-drawn uniformly from [-init_w, init_w].\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.linear1 = nn.Linear(state_dim, hidden_dim)\n",
    "        self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.linear3 = nn.Linear(hidden_dim, 1)\n",
    "\n",
    "        # Weight first, then bias, so the random draws happen in a fixed order.\n",
    "        for tensor in (self.linear3.weight, self.linear3.bias):\n",
    "            tensor.data.uniform_(-init_w, init_w)\n",
    "\n",
    "    def forward(self, state):\n",
    "        \"\"\"Return the value estimate for `state`; the last output dimension is 1.\"\"\"\n",
    "        hidden = F.relu(self.linear1(state))\n",
    "        hidden = F.relu(self.linear2(hidden))\n",
    "        return self.linear3(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftQ(nn.Module):\n",
    "    \"\"\"Three-layer MLP that scores a (state, action) pair with a single Q-value.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, action_dim, hidden_dim, init_w=3e-3):\n",
    "        \"\"\"Build the layers; the output layer is re-drawn uniformly from [-init_w, init_w].\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # The first layer consumes the state and the action concatenated together.\n",
    "        self.linear1 = nn.Linear(input_dim + action_dim, hidden_dim)\n",
    "        self.linear2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.linear3 = nn.Linear(hidden_dim, 1)\n",
    "\n",
    "        # Weight first, then bias, so the random draws happen in a fixed order.\n",
    "        for tensor in (self.linear3.weight, self.linear3.bias):\n",
    "            tensor.data.uniform_(-init_w, init_w)\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        \"\"\"Concatenate `state` and `action` along dim 1 and return the Q-value.\"\"\"\n",
    "        joint = torch.cat((state, action), dim=1)\n",
    "        hidden = F.relu(self.linear1(joint))\n",
    "        hidden = F.relu(self.linear2(hidden))\n",
    "        return self.linear3(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StochasticActor(nn.Module):\n",
    "    \"\"\"Gaussian policy whose samples are squashed into (-1, 1) with tanh.\n",
    "\n",
    "    `params` must provide 'log_std_min' and 'log_std_max' (clamp range for the\n",
    "    predicted log standard deviation) plus 'mean_initw' and 'std_initw'\n",
    "    (half-widths of the uniform init of the two output heads).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, action_dim, hidden_size, params):\n",
    "        super(StochasticActor, self).__init__()\n",
    "\n",
    "        self.log_std_min = params['log_std_min']\n",
    "        self.log_std_max = params['log_std_max']\n",
    "        self.drop_layer = nn.Dropout(p=0.5)\n",
    "        self.normal_dist = torch.distributions.Normal(0, 1)\n",
    "\n",
    "        self.linear1 = nn.Linear(input_dim, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
    "\n",
    "        self.mean_linear = nn.Linear(hidden_size, action_dim)\n",
    "        self.mean_linear.weight.data.uniform_(-params['mean_initw'],\n",
    "                                              params['mean_initw'])\n",
    "        self.mean_linear.bias.data.uniform_(-params['mean_initw'],\n",
    "                                            params['mean_initw'])\n",
    "\n",
    "        self.log_std_linear = nn.Linear(hidden_size, action_dim)\n",
    "        self.log_std_linear.weight.data.uniform_(-params['std_initw'],\n",
    "                                                  params['std_initw'])\n",
    "        self.log_std_linear.bias.data.uniform_(-params['std_initw'],\n",
    "                                                params['std_initw'])\n",
    "\n",
    "    def forward(self, state):\n",
    "        \"\"\"Return (mean, log_std) of the pre-tanh Gaussian for `state`.\n",
    "\n",
    "        Dropout is applied after each hidden layer, so outputs are stochastic in\n",
    "        train mode. log_std is clamped to [log_std_min, log_std_max].\n",
    "        \"\"\"\n",
    "        x = F.relu(self.linear1(state))\n",
    "        x = self.drop_layer(x)\n",
    "        x = F.relu(self.linear2(x))\n",
    "        x = self.drop_layer(x)\n",
    "        mean    = self.mean_linear(x)\n",
    "        log_std = self.log_std_linear(x)\n",
    "        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)\n",
    "\n",
    "        return mean, log_std\n",
    "\n",
    "    def evaluate(self, state, epsilon=1e-6):\n",
    "        \"\"\"Sample a squashed action for `state` and compute its log-probability.\n",
    "\n",
    "        Returns:\n",
    "            (action, log_prob, z, mean, log_std). `action` is tanh(mean + z * std);\n",
    "            `z` is standard-normal noise with the same shape as `mean`;\n",
    "            `log_prob` is per action dimension (it is not summed over the last axis).\n",
    "        \"\"\"\n",
    "        mean, log_std = self.forward(state)\n",
    "\n",
    "        std = log_std.exp()\n",
    "        # Independent noise for every batch element and action dimension, created on\n",
    "        # the device of the network output rather than on a global device.\n",
    "        z = self.normal_dist.sample(mean.shape).to(mean.device)\n",
    "\n",
    "        pre_tanh = mean + z * std\n",
    "        action = torch.tanh(pre_tanh)\n",
    "        # Change of variables for a = tanh(u): the Gaussian density is evaluated at the\n",
    "        # pre-tanh sample u, then corrected by log(1 - tanh(u)^2). `epsilon` keeps the\n",
    "        # log finite when the action saturates at +/-1.\n",
    "        log_prob = torch.distributions.Normal(mean, std).log_prob(pre_tanh)\n",
    "        log_prob = log_prob - torch.log(1 - action.pow(2) + epsilon)\n",
    "        # NOTE(review): the reduction over action dimensions is left disabled, as before.\n",
    "        # log_prob = log_prob.sum(-1, keepdim=True)\n",
    "\n",
    "        return action, log_prob, z, mean, log_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_q_update(step, batch, params, learn=True):\n",
    "    \"\"\"Run one update (or a logging-only evaluation pass) on a single batch.\n",
    "\n",
    "    Relies on the module-level networks, optimizers, criteria, `debugger` and `writer`.\n",
    "\n",
    "    Args:\n",
    "        step: global step counter; used as the step argument of the `writer` calls\n",
    "            and echoed in the returned dict.\n",
    "        batch: raw batch, unpacked by recnn.data.get_base_batch.\n",
    "        params: dict read for 'gamma', 'soft_tau', 'mean_lambda', 'std_lambda', 'z_lambda'.\n",
    "        learn: when True, step all three optimizers and soft-update the target value\n",
    "            net; when False, take no optimizer step and write histograms and a figure\n",
    "            to `writer` instead.\n",
    "\n",
    "    Returns:\n",
    "        dict with float losses under 'value', 'softq', 'policy' and the given 'step'.\n",
    "    \"\"\"\n",
    "    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
    "    \n",
    "    # --------------------------------------------------------#\n",
    "    # SoftQ Learning\n",
    "    expected_softq_value = soft_q_net(state, action)\n",
    "    expected_value   = value_net(state)\n",
    "    # The policy is evaluated on `state` (not `next_state`): `next_action` is a freshly\n",
    "    # sampled action for the current state.\n",
    "    next_action, log_prob, z, mean, log_std = policy_net.evaluate(state)\n",
    "    \n",
    "    target_value = target_value_net(next_state)\n",
    "    # NOTE(review): assumes `reward` and `done` broadcast against target_value, whose\n",
    "    # last dimension is 1 - confirm the shapes returned by get_base_batch.\n",
    "    next_q_value = reward + (1 - done) * params['gamma'] * target_value\n",
    "    q_value_loss = soft_q_criterion(expected_softq_value, next_q_value.detach())\n",
    "    if learn:\n",
    "        soft_q_optimizer.zero_grad()\n",
    "        q_value_loss.backward()\n",
    "        soft_q_optimizer.step()\n",
    "        \n",
    "    # --------------------------------------------------------#\n",
    "    # StateValue Learning\n",
    "    # This forward pass runs after the soft-Q optimizer step above (when learn=True),\n",
    "    # so it already uses the updated Q weights.\n",
    "    expected_next_softq_value = soft_q_net(state, next_action)\n",
    "    # NOTE(review): log_prob is per action dimension (the sum in StochasticActor.evaluate\n",
    "    # is commented out), so next_value has one column per action dimension and is\n",
    "    # broadcast against expected_value inside the loss.\n",
    "    next_value = expected_next_softq_value - log_prob\n",
    "    value_loss = value_criterion(expected_value, next_value.detach())\n",
    "    \n",
    "    if learn:\n",
    "        value_optimizer.zero_grad()\n",
    "        value_loss.backward()\n",
    "        value_optimizer.step()\n",
    "        soft_update(value_net, target_value_net, params['soft_tau'])\n",
    "    # --------------------------------------------------------#\n",
    "    # Policy Learning\n",
    "    log_prob_target = expected_next_softq_value - expected_value\n",
    "    # The second factor is detached, so gradients reach the policy only through the\n",
    "    # leading log_prob term.\n",
    "    policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()\n",
    "\n",
    "    # Regularizers on the policy outputs, weighted by the *_lambda entries of `params`.\n",
    "    mean_loss = params['mean_lambda'] * mean.pow(2).mean()\n",
    "    std_loss  = params['std_lambda']  * log_std.pow(2).mean()\n",
    "    z_loss    = params['z_lambda']    * z.pow(2).sum(0).mean()\n",
    "\n",
    "    policy_loss += mean_loss + std_loss + z_loss\n",
    "    \n",
    "    if learn:\n",
    "        # Stored under 'next_action'; run_tests() reads it back from debugger.debug_dict.\n",
    "        debugger.log_object('next_action', next_action)\n",
    "        policy_optimizer.zero_grad()\n",
    "        policy_loss.backward()\n",
    "        policy_optimizer.step()\n",
    "    \n",
    "    if not learn:\n",
    "        # Evaluation pass: log every intermediate tensor so the run can be inspected.\n",
    "        debugger.log_object('test next_action', next_action)\n",
    "        writer.add_figure('next_action',\n",
    "                          recnn.plot.pairwise_distances_fig(next_action[:50]), step)\n",
    "        writer.add_histogram('expected_softq_value', expected_softq_value, step)\n",
    "        writer.add_histogram('expected_value', expected_value, step)\n",
    "        writer.add_histogram('log_prob', log_prob, step)\n",
    "        writer.add_histogram('z', z, step)\n",
    "        writer.add_histogram('mean', mean, step)\n",
    "        writer.add_histogram('log_std', log_std, step)\n",
    "        writer.add_histogram('target_value', target_value, step)\n",
    "        writer.add_histogram('next_q_value', next_q_value, step)\n",
    "        writer.add_histogram('expected_next_softq_value', expected_next_softq_value, step)\n",
    "        writer.add_histogram('next_value', next_value, step)\n",
    "        writer.add_histogram('log_prob_target', log_prob_target, step)\n",
    "        writer.add_histogram('mean_loss', mean_loss, step)\n",
    "        writer.add_histogram('std_loss', std_loss, step)\n",
    "        writer.add_histogram('z_loss', z_loss, step)\n",
    "        # NOTE(review): the writer is closed on every evaluation pass although\n",
    "        # run_tests() and later passes keep writing to it; flush() may be what is\n",
    "        # intended - confirm.\n",
    "        writer.close()\n",
    "\n",
    "    losses = {'value': value_loss.item(),\n",
    "              'softq': q_value_loss.item(),\n",
    "              'policy': policy_loss.item(),\n",
    "              'step'  : step}\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings read by soft_q_update and by the model/optimizer setup cell.\n",
    "params = {\n",
    "    # discount factor applied to the target value in the soft-Q target\n",
    "    'gamma': 0.99,\n",
    "    # blend factor passed to soft_update for the target value network\n",
    "    'soft_tau': 0.001,\n",
    "    # learning rates of the policy, value and soft-Q optimizers\n",
    "    'policy_lr': 1e-5,\n",
    "    'value_lr': 1e-5,\n",
    "    'soft_qvalue_lr': 1e-5,\n",
    "    \n",
    "    # init_w (output-layer init half-width) of value_net and soft_q_net\n",
    "    'critic_weight_init': 1e-1,\n",
    "    'soft_critic_weight_init': 2e-1,\n",
    "    # weights of the mean / log_std / z penalty terms added to the policy loss\n",
    "    'mean_lambda':1e-3,\n",
    "    'std_lambda':1e-3,\n",
    "    'z_lambda': 1e-10,\n",
    "}\n",
    "\n",
    "# Settings read by StochasticActor.__init__.\n",
    "stochastic_actor_params = {\n",
    "    # uniform init half-widths of the mean head and the log_std head\n",
    "    'mean_initw': 1e-1,\n",
    "    'std_initw': 6e-1,\n",
    "    # clamp range applied to the predicted log standard deviation\n",
    "    'log_std_min':-2,\n",
    "    'log_std_max':2\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f49edb70828>"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sizes used below: 1290 = state vector, 128 = action vector, 256 = hidden width.\n",
    "# NOTE(review): 1290 presumably equals frame_size * (128 + 1) - confirm against FrameEnv.\n",
    "value_net  = StateCritic(1290, 256, params['critic_weight_init']).to(cuda)\n",
    "# Created with the default init_w; its weights are overwritten by the hard copy below.\n",
    "target_value_net = StateCritic(1290, 256).to(cuda)\n",
    "\n",
    "soft_q_net = SoftQ(1290, 128, 256, params['soft_critic_weight_init']).to(cuda)\n",
    "policy_net = StochasticActor(1290, 128, 256, stochastic_actor_params).to(cuda)\n",
    "\n",
    "# soft_tau=1.0 copies value_net into target_value_net so both start identical.\n",
    "soft_update(value_net, target_value_net, soft_tau=1.0)\n",
    "\n",
    "value_criterion  = nn.MSELoss()\n",
    "soft_q_criterion = nn.MSELoss()\n",
    "\n",
    "# One optimizer per trainable network; the target network has none because it is\n",
    "# only ever changed through soft_update.\n",
    "value_optimizer = optim.RAdam(value_net.parameters(), lr=params['value_lr'])\n",
    "soft_q_optimizer = optim.RAdam(soft_q_net.parameters(), lr=params['soft_qvalue_lr'])\n",
    "policy_optimizer = optim.RAdam(policy_net.parameters(), lr=params['policy_lr'])\n",
    "\n",
    "\n",
    "# Loss series tracked by the debugger; the keys match the dict returned by soft_q_update.\n",
    "layout = {\n",
    "    'train': {'value': [], 'softq': [], 'policy': [], 'step': []},\n",
    "    'test': {'value': [], 'softq': [], 'policy': [], 'step': []}\n",
    "    }\n",
    "\n",
    "writer = SummaryWriter(log_dir='../../runs')\n",
    "# run_tests is handed to the debugger; presumably debugger.test() invokes it - confirm in recnn.\n",
    "debugger = recnn.Debugger(layout, run_tests, writer)\n",
    "plotter = recnn.Plotter(debugger, [['policy'], ['value', 'softq']],)\n",
    "# NOTE(review): anomaly detection is a debugging aid that adds checks to the backward\n",
    "# pass - consider disabling it for real training runs.\n",
    "torch.autograd.set_detect_anomaly(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 150\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x432 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-35-651adf7232a7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msoft_q_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m         \u001b[0mdebugger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_losses\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mstep\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-32-c0b50579c38e>\u001b[0m in \u001b[0;36msoft_q_update\u001b[0;34m(step, batch, params, learn)\u001b[0m\n\u001b[1;32m     42\u001b[0m         \u001b[0mdebugger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_object\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'next_action'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_action\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0mpolicy_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m         \u001b[0mpolicy_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0mpolicy_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    116\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    117\u001b[0m         \"\"\"\n\u001b[0;32m--> 118\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    120\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     91\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     92\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Training loop: one soft_q_update per batch, with a test pass and a refresh of the\n",
    "# loss plots every `plot_every` steps. `step` is the module-level counter.\n",
    "for epoch in range(n_epochs):\n",
    "    for batch in tqdm(env.train_dataloader):\n",
    "        loss = soft_q_update(step, batch, params)\n",
    "        debugger.log_losses(loss)\n",
    "        step = step + 1\n",
    "        debugger.log_step(step)\n",
    "        if step % plot_every != 0:\n",
    "            continue\n",
    "        # Replace the previous output instead of appending to it.\n",
    "        clear_output(True)\n",
    "        print('step', step)\n",
    "        debugger.test()\n",
    "        plotter.plot_loss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction-error plot after training. This mirrors run_tests(): the actions\n",
    "# come from the debugger's log (the old `debug` name is never defined in this\n",
    "# notebook) and the anomaly detector `ad` loaded earlier is reused.\n",
    "gen_actions = debugger.debug_dict['mat']['next_action']\n",
    "gen_test_actions = debugger.debug_dict['mat']['test next_action']\n",
    "# detach + cpu first, so the conversion works wherever the embeddings live.\n",
    "true_actions = env.embeddings.detach().cpu().numpy()\n",
    "\n",
    "plotter.kde_reconstruction_error(ad, gen_actions, gen_test_actions, true_actions, cuda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}