examples/1. Vanilla RL/5. LSTM State Encoder.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Doesn't work yet!\n",
"## PM if I get lazy and dont implement it for too long!\n",
"## LSTM state encoder\n",
"\n",
"P.S. This snippet uses the library version of the learning function, you can see the visualization in the tensorboard\n",
"\n",
"\n",
"**And aslo this lstm state encoder is experemental. Recnn is an RL library, you are expected to do dataloading by yourself**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Note on this tutorials:**\n",
"**They mostly contain low level implementations explaining what is going on inside the library.**\n",
"\n",
"**Most of the stuff explained here is already available out of the box for your usage.**"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import torch.nn.functional as F\n",
"import torch_optimizer as optim\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"cuda = torch.device('cuda')\n",
"\n",
"# ---\n",
"frame_size = 10\n",
"batch_size = 25\n",
"n_epochs = 100\n",
"plot_every = 30\n",
"step = 0\n",
"# --- \n",
"\n",
"tqdm.pandas()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d917602d37c9498f9cd653522889a4e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "239dfc1ec119486a8b4c553e360fb061",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "587859c1219b4d22a702a6dcaefa9fbf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"state_encoder = nn.LSTM(129, 256, batch_first=True).to(cuda)\n",
"\n",
"# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
"dirs = recnn.data.env.DataPath(\n",
" base=\"../../../data/\",\n",
" embeddings=\"embeddings/ml20_pca128.pkl\",\n",
" ratings=\"ml-20m/ratings.csv\",\n",
" cache=\"cache/frame_env.pkl\", # cache will generate after you run\n",
" use_cache=True\n",
")\n",
"env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def run_tests():\n",
" batch = next(env.test_batch())\n",
" loss = ddpg_update(batch, params, nets, optimizer,\n",
" cuda, debug, writer, step=step, learn=False)\n",
" return losses"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# === ddpg settings ===\n",
"\n",
"params = {\n",
" 'gamma' : 0.99,\n",
" 'min_value' : -10,\n",
" 'max_value' : 10,\n",
" 'policy_step': 10,\n",
" 'soft_tau' : 0.001,\n",
" \n",
" 'policy_lr' : 1e-5,\n",
" 'value_lr' : 1e-5,\n",
" 'actor_weight_init': 54e-2,\n",
" 'critic_weight_init': 6e-1,\n",
"}\n",
"\n",
"# === end ==="
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"\n",
"value_net = recnn.nn.Critic(256, 128, 256, params['critic_weight_init']).to(cuda)\n",
"policy_net = recnn.nn.Actor(256, 128, 256, params['actor_weight_init']).to(cuda)\n",
"\n",
"target_value_net = recnn.nn.Critic(256, 128, 256).to(cuda)\n",
"target_policy_net = recnn.nn.Actor(256, 128, 256).to(cuda)\n",
"\n",
"target_policy_net.eval()\n",
"target_value_net.eval()\n",
"\n",
"\n",
"recnn.utils.soft_update(value_net, target_value_net, soft_tau=1.0)\n",
"recnn.utils.soft_update(policy_net, target_policy_net, soft_tau=1.0)\n",
"\n",
"pm = list(policy_net.parameters()) + list(state_encoder.parameters())\n",
"value_optimizer = optim.RAdam(value_net.parameters(),\n",
" lr=params['value_lr'], weight_decay=1e-2)\n",
"policy_optimizer = optim.RAdam(pm, lr=params['policy_lr'] , weight_decay=1e-2)\n",
"\n",
"nets = {\n",
" 'value_net': value_net,\n",
" 'target_value_net': target_value_net,\n",
" 'policy_net': policy_net,\n",
" 'target_policy_net': target_policy_net,\n",
"}\n",
"\n",
"optimizer = {\n",
" 'policy_optimizer': policy_optimizer,\n",
" 'value_optimizer': value_optimizer\n",
"}\n",
"\n",
"debug = {}\n",
"writer = SummaryWriter(log_dir='../../runs')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6b4a29b0a45144738fbd7647972156cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fd1469b920064b06b41e2faffafcd4ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=5263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-ee66f16313bb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m loss = recnn.nn.update.ddpg_update(batch, params, nets, optimizer,\n\u001b[1;32m 4\u001b[0m cuda, debug, writer, step=step)\n\u001b[1;32m 5\u001b[0m \u001b[0;31m#if step % 10 == 0:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/tqdm/_tqdm_notebook.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 223\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 224\u001b[0m \u001b[0;31m# return super(tqdm...) will not catch exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/tqdm/_tqdm.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1003\u001b[0m \"\"\"), fp_write=getattr(self.fp, 'write', sys.stderr.write))\n\u001b[1;32m 1004\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1005\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1006\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1007\u001b[0m \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Documents/RecNN/recnn/data/env.py\u001b[0m in \u001b[0;36mtrain_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mreward\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mratings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0mnext_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 546\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 548\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_packed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 564\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 565\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mGRU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRNNBase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_tensor\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0munsorted_indices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 543\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 544\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munsorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, input, hx, batch_sizes, max_batch_size, sorted_indices)\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 525\u001b[0m result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,\n\u001b[0;32m--> 526\u001b[0;31m self.dropout, self.training, self.bidirectional, self.batch_first)\n\u001b[0m\u001b[1;32m 527\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 528\u001b[0m result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"step = 0\n",
"for batch in tqdm(env.train_batch()):\n",
" loss = recnn.nn.update.ddpg_update(batch, params, nets, optimizer,\n",
" cuda, debug, writer, step=step)\n",
" #if step % 10 == 0:\n",
" # run_tests()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}