awarebayes/RecNN

View on GitHub
examples/1. Vanilla RL/5. LSTM State Encoder.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Doesn't work yet!\n",
    "## PM me if I get lazy and don't implement it for too long!\n",
    "## LSTM state encoder\n",
    "\n",
    "P.S. This snippet uses the library version of the learning function; you can see the visualization in TensorBoard.\n",
    "\n",
    "\n",
    "**Also, this LSTM state encoder is experimental. RecNN is an RL library; you are expected to do the data loading yourself.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Note on these tutorials:**\n",
    "**They mostly contain low-level implementations explaining what is going on inside the library.**\n",
    "\n",
    "**Most of the stuff explained here is already available out of the box for your usage.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.nn.functional as F\n",
    "import torch_optimizer as optim\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# == recnn ==\n",
    "# Put the directory two levels up on sys.path before importing recnn, so the\n",
    "# package is looked up there as well.\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "# Device that the state encoder and all networks are moved to with .to(cuda),\n",
    "# and that is also passed to ddpg_update.\n",
    "cuda = torch.device('cuda')\n",
    "\n",
    "# ---\n",
    "# frame_size and batch_size are passed to recnn.data.env.FrameEnv when the\n",
    "# environment is constructed.\n",
    "frame_size = 10\n",
    "batch_size = 25\n",
    "# NOTE(review): n_epochs and plot_every are not referenced by any other cell\n",
    "# visible in this notebook - confirm whether they are still needed.\n",
    "n_epochs   = 100\n",
    "plot_every = 30\n",
    "# Step counter handed to ddpg_update as step=...; reset again by the training cell.\n",
    "step       = 0\n",
    "# --- \n",
    "\n",
    "# Hook tqdm into pandas (tqdm.pandas); presumably used for progress bars while\n",
    "# the environment prepares its data - confirm.\n",
    "tqdm.pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d917602d37c9498f9cd653522889a4e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "239dfc1ec119486a8b4c553e360fb061",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "587859c1219b4d22a702a6dcaefa9fbf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# LSTM state encoder: input size 129, hidden size 256, batch-first tensors.\n",
    "# NOTE(review): 129 is presumably the 128-dim PCA item embedding plus one rating\n",
    "# value, and 256 matches the first size given to the Actor/Critic networks - confirm.\n",
    "# NOTE(review): the encoder is not passed to FrameEnv in this cell; verify how the\n",
    "# environment gets hold of it.\n",
    "state_encoder = nn.LSTM(129, 256, batch_first=True).to(cuda)\n",
    "\n",
    "# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
    "# Locations of the input files, given relative to `base`.\n",
    "dirs = recnn.data.env.DataPath(\n",
    "    base=\"../../../data/\",\n",
    "    embeddings=\"embeddings/ml20_pca128.pkl\",\n",
    "    ratings=\"ml-20m/ratings.csv\",\n",
    "    cache=\"cache/frame_env.pkl\", # cache will generate after you run\n",
    "    use_cache=True\n",
    ")\n",
    "# Environment that yields the train/test batches consumed by the cells below.\n",
    "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_tests():\n",
    "    \"\"\"Run one evaluation pass on a single test batch, without training.\n",
    "\n",
    "    Relies on the notebook globals env, params, nets, optimizer, cuda,\n",
    "    debug, writer and step.\n",
    "\n",
    "    Returns:\n",
    "        Whatever recnn.nn.update.ddpg_update returns for that batch when\n",
    "        it is called with learn=False.\n",
    "    \"\"\"\n",
    "    batch = next(env.test_batch())\n",
    "    # ddpg_update is never imported as a bare name in this notebook, so call it\n",
    "    # through the package path, the same way the training loop does.\n",
    "    loss = recnn.nn.update.ddpg_update(batch, params, nets, optimizer,\n",
    "                                       cuda, debug, writer, step=step, learn=False)\n",
    "    # Return the value that was actually computed; the previous code returned\n",
    "    # the undefined name `losses`, which raised NameError.\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === ddpg settings ===\n",
    "\n",
    "# Hyperparameters shared by the network/optimizer setup and the DDPG update call.\n",
    "params = {\n",
    "    # values consumed by the update step\n",
    "    'gamma': 0.99,\n",
    "    'min_value': -10,\n",
    "    'max_value': 10,\n",
    "    'policy_step': 10,\n",
    "    'soft_tau': 0.001,\n",
    "    # learning rates of the two optimizers\n",
    "    'policy_lr': 0.00001,\n",
    "    'value_lr': 0.00001,\n",
    "    # weight-init values passed to the Actor / Critic constructors\n",
    "    'actor_weight_init': 0.54,\n",
    "    'critic_weight_init': 0.6,\n",
    "}\n",
    "\n",
    "# === end ==="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Online critic and actor, built with the weight-init values from params.\n",
    "# NOTE(review): the leading 256 matches the LSTM state encoder's hidden size; 128 is\n",
    "# presumably the action (embedding) size and the last 256 a hidden width - confirm\n",
    "# against recnn.nn.Critic / recnn.nn.Actor.\n",
    "value_net  = recnn.nn.Critic(256, 128, 256, params['critic_weight_init']).to(cuda)\n",
    "policy_net = recnn.nn.Actor(256, 128, 256, params['actor_weight_init']).to(cuda)\n",
    "\n",
    "# Target copies, created without an explicit weight-init argument.\n",
    "target_value_net = recnn.nn.Critic(256, 128, 256).to(cuda)\n",
    "target_policy_net = recnn.nn.Actor(256, 128, 256).to(cuda)\n",
    "\n",
    "# Target networks are kept in eval mode.\n",
    "target_policy_net.eval()\n",
    "target_value_net.eval()\n",
    "\n",
    "\n",
    "# soft_tau=1.0 presumably makes this a full copy of the online weights into the\n",
    "# targets - confirm against recnn.utils.soft_update.\n",
    "recnn.utils.soft_update(value_net, target_value_net, soft_tau=1.0)\n",
    "recnn.utils.soft_update(policy_net, target_policy_net, soft_tau=1.0)\n",
    "\n",
    "# The policy optimizer also receives the state encoder's parameters, so it steps\n",
    "# the actor and the LSTM together; the critic has its own optimizer.\n",
    "pm = list(policy_net.parameters()) + list(state_encoder.parameters())\n",
    "value_optimizer = optim.RAdam(value_net.parameters(),\n",
    "                              lr=params['value_lr'], weight_decay=1e-2)\n",
    "policy_optimizer = optim.RAdam(pm, lr=params['policy_lr'] , weight_decay=1e-2)\n",
    "\n",
    "# Containers in the form handed to recnn.nn.update.ddpg_update.\n",
    "nets = {\n",
    "    'value_net': value_net,\n",
    "    'target_value_net': target_value_net,\n",
    "    'policy_net': policy_net,\n",
    "    'target_policy_net': target_policy_net,\n",
    "}\n",
    "\n",
    "optimizer = {\n",
    "    'policy_optimizer': policy_optimizer,\n",
    "    'value_optimizer':  value_optimizer\n",
    "}\n",
    "\n",
    "# debug starts empty and is passed to ddpg_update together with the writer.\n",
    "debug = {}\n",
    "# TensorBoard event files are written under ../../runs.\n",
    "writer = SummaryWriter(log_dir='../../runs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Reset the counter so that re-running this cell starts again from step 0.\n",
    "step = 0\n",
    "for batch in tqdm(env.train_batch()):\n",
    "    # One DDPG update per training batch; the writer receives the current step.\n",
    "    loss = recnn.nn.update.ddpg_update(batch, params, nets, optimizer,\n",
    "                                       cuda, debug, writer, step=step)\n",
    "    # if step % 10 == 0:\n",
    "    #     run_tests()\n",
    "    # Advance the counter: it was never incremented before, so every update was\n",
    "    # called (and logged) with step=0.\n",
    "    step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}