examples/99.To be released, but working/2. BCQ/2. BCQ Pyro.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BCQ with Pyro Inference"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim  # kept importable; no longer bound to the `optim` alias\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"# torch_optimizer provides RAdam and owns the `optim` alias used below.\n",
"# BUGFIX: previously `torch.optim` was also imported `as optim` and was\n",
"# silently shadowed by this line.\n",
"import torch_optimizer as optim"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro.infer import SVI, JitTrace_ELBO\n",
"import pyro.optim as poptim"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"import pickle\n",
"import gc\n",
"import json\n",
"import h5py\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"cuda = torch.device('cuda')\n",
"\n",
"# ---\n",
"frame_size = 10\n",
"batch_size = 1\n",
"# --- \n",
"\n",
"tqdm.pandas()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done!\n"
]
}
],
"source": [
"# Embeddings: https://drive.google.com/open?id=1kTyu05ZmtP2MA33J5hWdX8OyUYEDW4iI\n",
"# SECURITY NOTE: pickle.load executes arbitrary code -- only load trusted files.\n",
"movie_embeddings_key_dict = pickle.load(open('../../data/infos_pca128.pytorch', 'rb'))\n",
"movies_embeddings_tensor, \\\n",
"key_to_id, id_to_key = recnn.data.make_items_tensor(movie_embeddings_key_dict)\n",
"# download ml20m dataset yourself\n",
"ratings = pd.read_csv('../../data/ml-20m/ratings.csv')\n",
"user_dict, users = recnn.data.prepare_dataset(ratings, key_to_id, frame_size)\n",
"# Free the raw ratings frame -- it is large and no longer needed.\n",
"del ratings\n",
"gc.collect()\n",
"clear_output(True)  # was called twice back-to-back; once is enough\n",
"print('Done!')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from jupyterthemes import jtplot\n",
"jtplot.style(theme='monokai')\n",
"\n",
"def run_tests():\n",
" test_batch = next(iter(test_dataloader))\n",
" losses = bcq_update(test_batch, params, learn=False, step=step)\n",
" return losses\n",
"\n",
"def soft_update(net, target_net, soft_tau=1e-2):\n",
" for target_param, param in zip(target_net.parameters(), net.parameters()):\n",
" target_param.data.copy_(\n",
" target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Perturbation network (BCQ's \"actor\"): given a candidate action from the\n",
"# generator, it outputs a small state-conditioned correction that is added\n",
"# to the action, keeping it close to the behavioral data distribution.\n",
"# (Original author's pun preserved: soundcloud.com/perturbator)\n",
"class Perturbator(nn.Module):\n",
"    def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-1):\n",
"        super(Perturbator, self).__init__()\n",
"\n",
"        self.drop_layer = nn.Dropout(p=0.5)\n",
"\n",
"        self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n",
"        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
"        self.linear3 = nn.Linear(hidden_size, action_dim)\n",
"\n",
"        # Small uniform init on the output layer => small initial perturbations.\n",
"        self.linear3.weight.data.uniform_(-init_w, init_w)\n",
"        self.linear3.bias.data.uniform_(-init_w, init_w)\n",
"\n",
"    def forward(self, state, action):\n",
"        \"\"\"Return `action` plus a learned correction conditioned on `state`.\"\"\"\n",
"        x = torch.cat([state, action], 1)\n",
"        x = self.drop_layer(F.relu(self.linear1(x)))\n",
"        x = self.drop_layer(F.relu(self.linear2(x)))\n",
"        delta = self.linear3(x)\n",
"        return delta + action"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
"    \"\"\"Encode (state, action) into mean/std of the latent posterior q(z|s,a).\"\"\"\n",
"    def __init__(self, input_dim, action_dim, latent_dim):\n",
"        super(Encoder, self).__init__()\n",
"        self.e1 = nn.Linear(input_dim + action_dim, 750)\n",
"        self.e2 = nn.Linear(750, 750)\n",
"\n",
"        self.mean = nn.Linear(750, latent_dim)\n",
"        self.log_std = nn.Linear(750, latent_dim)\n",
"    \n",
"    def forward(self, state, action):\n",
"        z = F.relu(self.e1(torch.cat([state, action], 1)))\n",
"        z = F.relu(self.e2(z))\n",
"        mean = self.mean(z)\n",
"        # Clamp log-std before exponentiating for numerical stability.\n",
"        log_std = self.log_std(z).clamp(-4, 4)\n",
"        std = torch.exp(log_std)\n",
"        return mean, std\n",
"    \n",
"class Decoder(nn.Module):\n",
"    \"\"\"Decode (state, latent z) back into an action vector.\"\"\"\n",
"    def __init__(self, input_dim, action_dim, latent_dim):\n",
"        super(Decoder, self).__init__()\n",
"        self.d1 = nn.Linear(input_dim + latent_dim, 750)\n",
"        self.d2 = nn.Linear(750, 750)\n",
"        self.d3 = nn.Linear(750, action_dim)\n",
"    \n",
"    def forward(self, state, z):\n",
"        z = torch.cat([state, z], 1)\n",
"        z = F.relu(self.d1(z))\n",
"        z = F.relu(self.d2(z))\n",
"        # NOTE(review): the final ReLU constrains decoded actions to be\n",
"        # non-negative -- confirm the action embeddings really live in R+.\n",
"        z = F.relu(self.d3(z))\n",
"        return z\n",
"\n",
"\n",
"class Generator(nn.Module):\n",
"    \"\"\"Conditional VAE over actions, trained via Pyro SVI (model/guide pair).\"\"\"\n",
"    def __init__(self, input_dim, action_dim, latent_dim):\n",
"        super(Generator, self).__init__()\n",
"        self.encoder = Encoder(input_dim, action_dim, latent_dim)\n",
"        self.decoder = Decoder(input_dim, action_dim, latent_dim)\n",
"        self.latent_dim = latent_dim\n",
"        self.normal = torch.distributions.Normal(0, 1)\n",
"    \n",
"    def model(self, state, action):\n",
"        # Generative model p(a|s,z)p(z); decoder params registered with Pyro.\n",
"        pyro.module(\"decoder\", self.decoder)\n",
"        batch_size = state.size(0)\n",
"        with pyro.plate(\"data\", batch_size):\n",
"            # setup hyperparameters for prior p(z)\n",
"            z_loc = torch.zeros(torch.Size((batch_size, self.latent_dim))).to(cuda)\n",
"            z_scale = torch.ones(torch.Size((batch_size, self.latent_dim))).to(cuda)\n",
"            # sample from prior (value will be sampled by guide when computing the ELBO)\n",
"            z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1)).to(cuda)\n",
"            # decode the latent code z\n",
"            decoded = self.decoder.forward(state, z)\n",
"            \n",
"            # score the observed actions.\n",
"            # NOTE(review): loc/scale are reduced over dim 0 (the batch), so\n",
"            # all samples in the batch share a single obs distribution. The\n",
"            # usual VAE choice is loc=decoded per-sample -- confirm intent.\n",
"            d_loc, d_scale = decoded.mean(0), decoded.std(0) + 1e-8 # small epsilon\n",
"            obs = pyro.sample(\"obs\", dist.Normal(d_loc, d_scale).to_event(1), obs=action)\n",
"\n",
"    \n",
"    def guide(self, state, action):\n",
"        # Variational posterior q(z|s,a); encoder params registered with Pyro.\n",
"        pyro.module(\"encoder\", self.encoder)\n",
"        batch_size = state.size(0)\n",
"        with pyro.plate(\"data\", batch_size):\n",
"            # use the encoder to get the parameters used to define q(z|x)\n",
"            z_loc, z_scale = self.encoder.forward(state, action)\n",
"            # sample the latent code z\n",
"            pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n",
"    \n",
"    def decode(self, state, z=None):\n",
"        # When z is not given, sample from a clipped standard normal\n",
"        # (BCQ-style truncation keeps sampled actions near the data manifold).\n",
"        if z is None:\n",
"            z = self.normal.sample([state.size(0), self.latent_dim])\n",
"            z = z.clamp(-0.9, 0.9).to(cuda)\n",
"        \n",
"        return self.decoder(state, z)\n",
"    \n",
"    def forward(self, state, action):\n",
"        # encode the (state, action) pair\n",
"        z_loc, z_scale = self.encoder(state, action)\n",
"        # sample in latent space\n",
"        z = dist.Normal(z_loc, z_scale).sample()\n",
"        # decode (note: no sampling in action space here).\n",
"        # BUGFIX: arguments were passed as (z, state), but Decoder.forward\n",
"        # expects (state, z). The swap went unnoticed only because the\n",
"        # concatenated width (512 + 1290) equals (1290 + 512).\n",
"        action = self.decoder(state, z)\n",
"        return action"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Trace Shapes: \n",
" Param Sites: \n",
"decoder$$$d1.weight 750 1802 \n",
" decoder$$$d1.bias 750 \n",
"decoder$$$d2.weight 750 750 \n",
" decoder$$$d2.bias 750 \n",
"decoder$$$d3.weight 128 750 \n",
" decoder$$$d3.bias 128 \n",
" Sample Sites: \n",
" data dist | \n",
" value 100 | \n",
" log_prob | \n",
" latent dist 100 | 512\n",
" value 100 | 512\n",
" log_prob 100 | \n",
" obs dist 100 | 128\n",
" value 100 | 128\n",
" log_prob 100 | \n"
]
}
],
"source": [
"generator_net = Generator(1290, 128, 512).to(cuda)\n",
"trace = pyro.poutine.trace(generator_net.model).get_trace(torch.zeros(100, 1290).to(cuda),\n",
" torch.zeros(100, 128).to(cuda))\n",
"trace.compute_log_prob() # optional, but allows printing of log_prob shapes\n",
"print(trace.format_shapes())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def bcq_update(batch, params, learn=True, step=-1):\n",
"    \"\"\"Run one BCQ step: VAE (generator), twin critics, and perturbator.\n",
"\n",
"    batch  -- (state, action, reward, next_state, done) tensors\n",
"    params -- hyperparameter dict (gamma, soft_tau, n_generator_samples, ...)\n",
"    learn  -- True: optimize the networks; False: evaluate + log to tensorboard\n",
"    step   -- global step, used for logging and the delayed perturbator update\n",
"    Returns a dict with scalar 'value', 'perturbator', 'generator' losses.\n",
"    \"\"\"\n",
"    \n",
"    batch = [i.to(cuda) for i in batch]\n",
"    state, action, reward, next_state, done = batch\n",
"    reward = reward.unsqueeze(1)\n",
"    done = done.unsqueeze(1)\n",
"    batch_size = done.size(0)\n",
"    \n",
"    # --------------------------------------------------------#\n",
"    # Variational Auto-Encoder Learning\n",
"    recon = generator_net(state, action)\n",
"    if not learn:\n",
"        generator_loss = generator_optimizer.evaluate_loss(state, action)\n",
"        debugger.log_object('recon test', recon)\n",
"        writer.add_figure('reconstructed',\n",
"                          recnn.plot.pairwise_distances_fig(recon[:50]), step)\n",
"    \n",
"    if learn:\n",
"        generator_loss = generator_optimizer.step(state, action)\n",
"        debugger.log_object('recon', recon)\n",
"    \n",
"    # --------------------------------------------------------#\n",
"    # Value Learning\n",
"    with torch.no_grad():\n",
"        # Sample candidate actions per next_state, perturb them, and back up\n",
"        # a softened max over candidates (the BCQ target).\n",
"        # p.s. repeat_interleave was added in torch 1.1\n",
"        # if an error pops up, run 'conda update pytorch'\n",
"        state_rep = torch.repeat_interleave(next_state, params['n_generator_samples'], 0)\n",
"        sampled_action = generator_net.decode(state_rep)\n",
"        perturbed_action = target_perturbator_net(state_rep, sampled_action)\n",
"        target_Q1 = target_value_net1(state_rep, perturbed_action)\n",
"        # BUGFIX: was target_value_net1 twice, making the twin-critic\n",
"        # min/max combination below a no-op.\n",
"        target_Q2 = target_value_net2(state_rep, perturbed_action)\n",
"        target_value = 0.75 * torch.min(target_Q1, target_Q2) # value soft update\n",
"        target_value+= 0.25 * torch.max(target_Q1, target_Q2) #\n",
"        target_value = target_value.view(batch_size, -1).max(1)[0].view(-1, 1)\n",
"        \n",
"        expected_value = reward + (1.0 - done) * params['gamma'] * target_value\n",
"\n",
"    value1 = value_net1(state, action)\n",
"    value2 = value_net2(state, action)\n",
"    # BUGFIX: value_net2 previously received no gradient even though\n",
"    # value_optimizer2 was stepped and target_value_net2 was soft-updated\n",
"    # from it; train both critics against the shared target, as in BCQ.\n",
"    value_loss = torch.pow(value1 - expected_value.detach(), 2).mean() \\\n",
"               + torch.pow(value2 - expected_value.detach(), 2).mean()\n",
"    debugger.log_error('value', value1, test=(not learn))\n",
"    debugger.log_error('target_value ', target_value, test=(not learn))\n",
"    \n",
"    if learn:\n",
"        value_optimizer1.zero_grad()\n",
"        value_optimizer2.zero_grad()\n",
"        value_loss.backward()\n",
"        value_optimizer1.step()\n",
"        value_optimizer2.step()\n",
"    else:\n",
"        writer.add_histogram('value', value1, step)\n",
"        writer.add_histogram('target value', target_value, step)\n",
"        writer.add_histogram('expected value', expected_value, step)\n",
"        # NOTE(review): close() inside the loop forces a flush; the writer\n",
"        # reopens on the next write, but writer.flush() would be cleaner.\n",
"        writer.close()\n",
"    \n",
"    # --------------------------------------------------------#\n",
"    # Perturbator learning\n",
"    sampled_actions = generator_net.decode(state)\n",
"    perturbed_actions= perturbator_net(state, sampled_actions)\n",
"    perturbator_loss = -value_net1(state, perturbed_actions)\n",
"    if not learn:\n",
"        writer.add_histogram('perturbator_loss', perturbator_loss, step)\n",
"    perturbator_loss = perturbator_loss.mean()\n",
"    \n",
"    debugger.log_object('sampled_actions', sampled_actions, test=(not learn))\n",
"    debugger.log_object('perturbed_actions', perturbed_actions, test=(not learn))\n",
"    \n",
"    if learn:\n",
"        # BUGFIX: the delayed-update test was inverted -- `if step % d:`\n",
"        # fires on every step that is NOT a multiple of d. Update once\n",
"        # every params['perturbator_step'] steps (TD3/BCQ-style delay).\n",
"        if step % params['perturbator_step'] == 0:\n",
"            perturbator_optimizer.zero_grad()\n",
"            perturbator_loss.backward()\n",
"            # BUGFIX: clip_grad_norm_ takes (parameters, max_norm, norm_type);\n",
"            # the old call passed max_norm=-1 (with norm_type=1), which scales\n",
"            # gradients by a negative factor. Clip to a max L2 norm of 1.\n",
"            torch.nn.utils.clip_grad_norm_(perturbator_net.parameters(), 1.0)\n",
"            perturbator_optimizer.step()\n",
"            \n",
"            soft_update(value_net1, target_value_net1, soft_tau=params['soft_tau'])\n",
"            soft_update(value_net2, target_value_net2, soft_tau=params['soft_tau'])\n",
"            soft_update(perturbator_net, target_perturbator_net, soft_tau=params['soft_tau'])\n",
"    else:\n",
"        writer.add_figure('sampled_actions',\n",
"                          recnn.plot.pairwise_distances_fig(sampled_actions[:50]), step)\n",
"        writer.add_figure('perturbed_actions',\n",
"                          recnn.plot.pairwise_distances_fig(perturbed_actions[:50]), step)\n",
"    \n",
"    # --------------------------------------------------------#\n",
"\n",
"    losses = {'value': value_loss.item(),\n",
"              'perturbator': perturbator_loss.item(),\n",
"              'generator': generator_loss,\n",
"              'step': step}\n",
"    \n",
"    return losses"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# === BCQ settings ===\n",
"params = {\n",
" # algorithm parameters\n",
" 'gamma' : 0.99,\n",
" 'soft_tau' : 0.001,\n",
" 'n_generator_samples': 1,\n",
" 'perturbator_step' : 10,\n",
" \n",
" # learning rates\n",
" 'perturbator_lr' : 1e-5,\n",
" 'value_lr' : 1e-5,\n",
" 'generator_lr' : 1e-5,\n",
"}\n",
"# === end ==="
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Online networks: generator (cVAE), twin critics, perturbator.\n",
"generator_net = Generator(1290, 128, 512).to(cuda)\n",
"value_net1 = recnn.models.Critic(1290, 128, 256, init_w=8e-1).to(cuda)\n",
"value_net2 = recnn.models.Critic(1290, 128, 256, init_w=8e-1).to(cuda)\n",
"perturbator_net = Perturbator(1290, 128, 256).to(cuda)\n",
"\n",
"# Target networks: frozen copies updated only by Polyak averaging.\n",
"target_value_net1 = recnn.models.Critic(1290, 128, 256).to(cuda)\n",
"target_value_net2 = recnn.models.Critic(1290, 128, 256).to(cuda)\n",
"target_perturbator_net = Perturbator(1290, 128, 256).to(cuda)\n",
"\n",
"target_perturbator_net.eval()\n",
"target_value_net1.eval()\n",
"target_value_net2.eval()\n",
"\n",
"# Hard-copy online weights into the targets (soft_tau=1.0 => full copy).\n",
"soft_update(value_net1, target_value_net1, soft_tau=1.0)\n",
"soft_update(value_net2, target_value_net2, soft_tau=1.0)\n",
"soft_update(perturbator_net, target_perturbator_net, soft_tau=1.0)\n",
"\n",
"\n",
"# optim.Adam can be replaced with RAdam\n",
"value_optimizer1 = optim.RAdam(value_net1.parameters(),\n",
"                               lr=params['value_lr'])\n",
"# BUGFIX: value_optimizer2 used params['perturbator_lr'] and the perturbator\n",
"# used params['value_lr'] -- the keys were swapped. The values coincide today,\n",
"# but the dict exists precisely so they can be tuned independently.\n",
"value_optimizer2 = optim.RAdam(value_net2.parameters(),\n",
"                               lr=params['value_lr'])\n",
"perturbator_optimizer = optim.RAdam(perturbator_net.parameters(),\n",
"                                    lr=params['perturbator_lr'], weight_decay=1e-1)\n",
"generator_optimizer = SVI(generator_net.model, generator_net.guide,\n",
"                          poptim.Adam({\"lr\": params['generator_lr']}), loss=JitTrace_ELBO())\n",
"# I would advice you not to weight decay generator\n",
"\n",
"layout = {\n",
"    'train': {'value': [], 'perturbator': [], 'generator': [], 'step': []},\n",
"    'test': {'value': [], 'perturbator': [], 'generator': [], 'step': []},\n",
"    }\n",
"\n",
"writer = SummaryWriter(log_dir='../../runs')\n",
"debugger = recnn.Debugger(layout, run_tests, writer)\n",
"plotter = recnn.Plotter(debugger, [['generator'], ['value', 'perturbator']],)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pyro.enable_validation(True) \n",
"# torch.autograd.set_detect_anomaly(True)\n",
"\n",
"step = 1\n",
"n_epochs = 100\n",
"batch_size = 25\n",
"\n",
"epoch_bar = tqdm(total=n_epochs)\n",
"\n",
"test_losses = [[], [], [], []]\n",
"train_users = users[:-5000]\n",
"test_users = users[-5000:]\n",
"\n",
"def prepare_batch_wrapper(x):\n",
" batch = recnn.data.prepare_batch_static_size(x, movies_embeddings_tensor,\n",
" frame_size=frame_size)\n",
" return batch\n",
"\n",
"train_user_dataset = recnn.data.UserDataset(train_users, user_dict)\n",
"test_user_dataset = recnn.data.UserDataset(test_users, user_dict)\n",
"train_dataloader = DataLoader(train_user_dataset, batch_size=batch_size,\n",
" shuffle=False, num_workers=1,collate_fn=prepare_batch_wrapper)\n",
"test_dataloader = DataLoader(test_user_dataset, batch_size=batch_size,\n",
" shuffle=False, num_workers=1,collate_fn=prepare_batch_wrapper)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# --- config ---\n",
"plot_every = 10\n",
"# --- end ---\n",
"\n",
"for epoch in range(n_epochs):\n",
" epoch_bar.update(1)\n",
" for batch in tqdm(train_dataloader):\n",
" loss = bcq_update(batch, params, step=step)\n",
" debugger.log_losses(loss)\n",
" step += 1\n",
" debugger.log_step(step)\n",
" if step % plot_every == 0:\n",
" clear_output(True)\n",
" print('step', step)\n",
" debugger.test()\n",
" plotter.plot_loss()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}