awarebayes/RecNN

View on GitHub
examples/99.To be released, but working/2. BCQ/2. BCQ Pyro.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BCQ with Pyro Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim  # un-aliased: the torch_optimizer import below owns `optim`\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "# torch_optimizer supplies RAdam and deliberately takes the `optim` alias;\n",
    "# previously `import torch.optim as optim` was silently shadowed here.\n",
    "import torch_optimizer as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import SVI, JitTrace_ELBO\n",
    "import pyro.optim as poptim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm\n",
    "import pickle\n",
    "import gc\n",
    "import json\n",
    "import h5py\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# == recnn ==\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "cuda = torch.device('cuda')\n",
    "\n",
    "# ---\n",
    "frame_size = 10\n",
    "batch_size = 1\n",
    "# --- \n",
    "\n",
    "tqdm.pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# embeddings: https://drive.google.com/open?id=1kTyu05ZmtP2MA33J5hWdX8OyUYEDW4iI\n",
    "# WARNING: pickle.load executes arbitrary code -- only load files you trust.\n",
    "movie_embeddings_key_dict = pickle.load(open('../../data/infos_pca128.pytorch', 'rb'))\n",
    "movies_embeddings_tensor, \\\n",
    "key_to_id, id_to_key = recnn.data.make_items_tensor(movie_embeddings_key_dict)\n",
    "# download the ml-20m dataset yourself\n",
    "ratings = pd.read_csv('../../data/ml-20m/ratings.csv')\n",
    "user_dict, users = recnn.data.prepare_dataset(ratings, key_to_id, frame_size)\n",
    "del ratings\n",
    "gc.collect()\n",
    "clear_output(True)  # was duplicated; one call is enough\n",
    "print('Done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jupyterthemes import jtplot\n",
    "jtplot.style(theme='monokai')\n",
    "\n",
    "def run_tests():\n",
    "    \"\"\"Evaluate the current networks on one test batch without learning.\"\"\"\n",
    "    test_batch = next(iter(test_dataloader))\n",
    "    losses = bcq_update(test_batch, params, learn=False, step=step)\n",
    "    return losses\n",
    "\n",
    "def soft_update(net, target_net, soft_tau=1e-2):\n",
    "    \"\"\"Polyak-average `net` weights into `target_net` (soft_tau=1.0 copies them).\"\"\"\n",
    "    for target_param, param in zip(target_net.parameters(), net.parameters()):\n",
    "        blended = target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n",
    "        target_param.data.copy_(blended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# P.S. This is not a usual Actor.\n",
    "# It is a peturbative network that takes an input from the Generator\n",
    "# And adjusts (perturbates) it to look like normal action\n",
    "# P.S. Yep, this is also a reference, check out his soundcloud:\n",
    "# soundcloud.com/perturbator\n",
    "class Perturbator(nn.Module):\n",
    "    \"\"\"BCQ perturbation network: predicts a residual that is added to a\n",
    "    generator-produced action.\n",
    "\n",
    "    input_dim: state feature size; action_dim: action embedding size;\n",
    "    hidden_size: width of the two hidden layers; init_w: uniform init\n",
    "    range of the output layer, kept small so initial residuals are small.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-1):\n",
    "        super(Perturbator, self).__init__()\n",
    "        \n",
    "        self.drop_layer = nn.Dropout(p=0.5)\n",
    "        \n",
    "        self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.linear3 = nn.Linear(hidden_size, action_dim)\n",
    "        \n",
    "        # small uniform init keeps the initial perturbation near zero\n",
    "        self.linear3.weight.data.uniform_(-init_w, init_w)\n",
    "        self.linear3.bias.data.uniform_(-init_w, init_w)\n",
    "        \n",
    "    def forward(self, state, action):\n",
    "        \"\"\"Return the perturbed action: action + residual(state, action).\"\"\"\n",
    "        a = torch.cat([state, action], 1)\n",
    "        a = F.relu(self.linear1(a))\n",
    "        a = self.drop_layer(a)\n",
    "        a = F.relu(self.linear2(a))\n",
    "        a = self.drop_layer(a)\n",
    "        a = self.linear3(a) \n",
    "        # residual connection: the net adjusts the action rather than replacing it\n",
    "        return a + action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Amortized posterior q(z | state, action) -> (mean, std).\"\"\"\n",
    "    def __init__(self, input_dim, action_dim, latent_dim):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.e1 = nn.Linear(input_dim + action_dim, 750)\n",
    "        self.e2 = nn.Linear(750, 750)\n",
    "\n",
    "        self.mean = nn.Linear(750, latent_dim)\n",
    "        self.log_std = nn.Linear(750, latent_dim)\n",
    "    \n",
    "    def forward(self, state, action):\n",
    "        z = F.relu(self.e1(torch.cat([state, action], 1)))\n",
    "        z = F.relu(self.e2(z))\n",
    "        mean = self.mean(z)\n",
    "        # clamp keeps std within [e^-4, e^4] for numerical stability\n",
    "        log_std = self.log_std(z).clamp(-4, 4)\n",
    "        std = torch.exp(log_std)\n",
    "        return mean, std\n",
    "    \n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Decoder p(action | state, z): maps a latent sample back to an action.\"\"\"\n",
    "    def __init__(self, input_dim, action_dim, latent_dim):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.d1 = nn.Linear(input_dim + latent_dim, 750)\n",
    "        self.d2 = nn.Linear(750, 750)\n",
    "        self.d3 = nn.Linear(750, action_dim)\n",
    "        \n",
    "    def forward(self, state, z):\n",
    "        z = torch.cat([state, z], 1)\n",
    "        z = F.relu(self.d1(z))\n",
    "        z = F.relu(self.d2(z))\n",
    "        # NOTE(review): the final relu restricts decoded actions to >= 0,\n",
    "        # but PCA item embeddings can be negative -- confirm this is intended.\n",
    "        z = F.relu(self.d3(z))\n",
    "        return z\n",
    "\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Conditional VAE over actions, trained with pyro SVI via model/guide.\"\"\"\n",
    "    def __init__(self, input_dim, action_dim, latent_dim):\n",
    "        super(Generator, self).__init__()\n",
    "        self.encoder = Encoder(input_dim, action_dim, latent_dim)\n",
    "        self.decoder = Decoder(input_dim, action_dim, latent_dim)\n",
    "        self.latent_dim = latent_dim\n",
    "        self.normal = torch.distributions.Normal(0, 1)\n",
    "    \n",
    "    def model(self, state, action):\n",
    "        \"\"\"Generative model: standard normal prior on z, Normal obs on action.\"\"\"\n",
    "        pyro.module(\"decoder\", self.decoder)\n",
    "        batch_size = state.size(0)\n",
    "        with pyro.plate(\"data\", batch_size):\n",
    "            # setup hyperparameters for prior p(z)\n",
    "            z_loc = torch.zeros(torch.Size((batch_size, self.latent_dim))).to(cuda)\n",
    "            z_scale = torch.ones(torch.Size((batch_size, self.latent_dim))).to(cuda)\n",
    "            # sample from prior (value will be sampled by guide when computing the ELBO)\n",
    "            z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1)).to(cuda)\n",
    "            # decode the latent code z\n",
    "            decoded = self.decoder.forward(state, z)\n",
    "            \n",
    "            # NOTE(review): mean/std are taken over the batch dim, so every sample\n",
    "            # is scored against the same Normal; VAEs usually use the per-sample\n",
    "            # decoded value as the obs loc -- confirm this aggregation is intended.\n",
    "            d_loc, d_scale = decoded.mean(0), decoded.std(0) + 1e-8 # small epsilon\n",
    "            obs = pyro.sample(\"obs\", dist.Normal(d_loc, d_scale).to_event(1), obs=action)\n",
    "\n",
    "            \n",
    "    def guide(self, state, action):\n",
    "        \"\"\"Variational posterior q(z | state, action) given by the encoder.\"\"\"\n",
    "        pyro.module(\"encoder\", self.encoder)\n",
    "        batch_size = state.size(0)\n",
    "        with pyro.plate(\"data\", batch_size):\n",
    "            # use the encoder to get the parameters used to define q(z|x)\n",
    "            z_loc, z_scale = self.encoder.forward(state, action)\n",
    "            # sample the latent code z\n",
    "            pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n",
    "            \n",
    "    def decode(self, state, z=None):\n",
    "        \"\"\"Decode an action for `state`; samples and clips z when none is given.\"\"\"\n",
    "        if z is None:\n",
    "            z = self.normal.sample([state.size(0), self.latent_dim])\n",
    "            z = z.clamp(-0.9, 0.9).to(cuda)\n",
    "            \n",
    "        return self.decoder(state, z)\n",
    "    \n",
    "    def forward(self, state, action):\n",
    "        \"\"\"Reconstruct `action` through the VAE (sampled latent, no obs sampling).\"\"\"\n",
    "        z_loc, z_scale = self.encoder(state, action)\n",
    "        # sample in latent space\n",
    "        z = dist.Normal(z_loc, z_scale).sample()\n",
    "        # BUGFIX: arguments were swapped (`self.decoder(z, state)`); it only ran\n",
    "        # because latent_dim + input_dim equals input_dim + latent_dim after cat.\n",
    "        action = self.decoder(state, z)\n",
    "        return action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Trace Shapes:             \n",
      "       Param Sites:             \n",
      "decoder$$$d1.weight 750 1802    \n",
      "  decoder$$$d1.bias      750    \n",
      "decoder$$$d2.weight 750  750    \n",
      "  decoder$$$d2.bias      750    \n",
      "decoder$$$d3.weight 128  750    \n",
      "  decoder$$$d3.bias      128    \n",
      "      Sample Sites:             \n",
      "          data dist        |    \n",
      "              value 100    |    \n",
      "           log_prob        |    \n",
      "        latent dist 100    | 512\n",
      "              value 100    | 512\n",
      "           log_prob 100    |    \n",
      "           obs dist 100    | 128\n",
      "              value 100    | 128\n",
      "           log_prob 100    |    \n"
     ]
    }
   ],
   "source": [
    "generator_net = Generator(1290, 128, 512).to(cuda)\n",
    "trace = pyro.poutine.trace(generator_net.model).get_trace(torch.zeros(100, 1290).to(cuda),\n",
    "                                                         torch.zeros(100, 128).to(cuda))\n",
    "trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
    "print(trace.format_shapes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bcq_update(batch, params, learn=True, step=-1):\n",
    "    \"\"\"Run one BCQ step on `batch`: trains the nets when learn=True, otherwise\n",
    "    evaluates and logs figures/histograms to tensorboard. Returns a dict with\n",
    "    scalar entries 'value', 'perturbator', 'generator' and 'step'.\"\"\"\n",
    "    batch = [i.to(cuda) for i in batch]\n",
    "    state, action, reward, next_state, done = batch\n",
    "    reward     = reward.unsqueeze(1)\n",
    "    done       = done.unsqueeze(1)\n",
    "    batch_size = done.size(0)\n",
    "    \n",
    "    # --------------------------------------------------------#\n",
    "    # Variational Auto-Encoder Learning\n",
    "    recon = generator_net(state, action)\n",
    "    if not learn:\n",
    "        generator_loss = generator_optimizer.evaluate_loss(state, action)\n",
    "        debugger.log_object('recon test', recon)\n",
    "        writer.add_figure('reconstructed',\n",
    "                          recnn.plot.pairwise_distances_fig(recon[:50]), step)\n",
    "        \n",
    "    if learn:\n",
    "        generator_loss = generator_optimizer.step(state, action)\n",
    "        debugger.log_object('recon', recon)\n",
    "        \n",
    "    # --------------------------------------------------------#\n",
    "    # Value Learning\n",
    "    with torch.no_grad():\n",
    "        # p.s. repeat_interleave was added in torch 1.1\n",
    "        # if an error pops up, run 'conda update pytorch'\n",
    "        state_rep = torch.repeat_interleave(next_state, params['n_generator_samples'], 0)\n",
    "        sampled_action = generator_net.decode(state_rep)\n",
    "        perturbed_action = target_perturbator_net(state_rep, sampled_action)\n",
    "        target_Q1 = target_value_net1(state_rep, perturbed_action)\n",
    "        # BUGFIX: both Qs came from target_value_net1, making min/max a no-op\n",
    "        target_Q2 = target_value_net2(state_rep, perturbed_action)\n",
    "        target_value = 0.75 * torch.min(target_Q1, target_Q2) # soft clipped double-Q\n",
    "        target_value += 0.25 * torch.max(target_Q1, target_Q2)\n",
    "        target_value = target_value.view(batch_size, -1).max(1)[0].view(-1, 1)\n",
    "        \n",
    "        expected_value = reward + (1.0 - done) * params['gamma'] * target_value\n",
    "\n",
    "    value = value_net1(state, action)\n",
    "    # BUGFIX: include the second critic -- previously value_optimizer2.step()\n",
    "    # ran with zero gradients, so value_net2 never learned\n",
    "    value2 = value_net2(state, action)\n",
    "    value_loss = (torch.pow(value - expected_value.detach(), 2).mean() +\n",
    "                  torch.pow(value2 - expected_value.detach(), 2).mean())\n",
    "    debugger.log_error('value', value, test=(not learn))\n",
    "    debugger.log_error('target_value ', target_value, test=(not learn))\n",
    "    \n",
    "    if learn:\n",
    "        value_optimizer1.zero_grad()\n",
    "        value_optimizer2.zero_grad()\n",
    "        value_loss.backward()\n",
    "        value_optimizer1.step()\n",
    "        value_optimizer2.step()\n",
    "    else:\n",
    "        writer.add_histogram('value', value, step)\n",
    "        writer.add_histogram('target value', target_value, step)\n",
    "        writer.add_histogram('expected value', expected_value, step)\n",
    "        writer.close()\n",
    "    \n",
    "    # --------------------------------------------------------#\n",
    "    # Perturbator learning\n",
    "    sampled_actions = generator_net.decode(state)\n",
    "    perturbed_actions = perturbator_net(state, sampled_actions)\n",
    "    perturbator_loss = -value_net1(state, perturbed_actions)\n",
    "    if not learn:\n",
    "        writer.add_histogram('perturbator_loss', perturbator_loss, step)\n",
    "    perturbator_loss = perturbator_loss.mean()\n",
    "    \n",
    "    debugger.log_object('sampled_actions', sampled_actions, test=(not learn))\n",
    "    debugger.log_object('perturbed_actions', perturbed_actions, test=(not learn))\n",
    "    \n",
    "    if learn:\n",
    "        # BUGFIX: delayed update every Nth step; the bare modulo truth test\n",
    "        # updated on every step EXCEPT multiples of perturbator_step\n",
    "        if step % params['perturbator_step'] == 0:\n",
    "            perturbator_optimizer.zero_grad()\n",
    "            perturbator_loss.backward()\n",
    "            # BUGFIX: clip_grad_norm_ takes (parameters, max_norm[, norm_type]);\n",
    "            # the old call passed an invalid negative max_norm of -1\n",
    "            torch.nn.utils.clip_grad_norm_(perturbator_net.parameters(), 1.0)\n",
    "            perturbator_optimizer.step()\n",
    "        \n",
    "        soft_update(value_net1, target_value_net1, soft_tau=params['soft_tau'])\n",
    "        soft_update(value_net2, target_value_net2, soft_tau=params['soft_tau'])\n",
    "        soft_update(perturbator_net, target_perturbator_net, soft_tau=params['soft_tau'])\n",
    "    else:\n",
    "        writer.add_figure('sampled_actions',\n",
    "            recnn.plot.pairwise_distances_fig(sampled_actions[:50]), step)\n",
    "        writer.add_figure('perturbed_actions',\n",
    "            recnn.plot.pairwise_distances_fig(perturbed_actions[:50]), step)\n",
    "        \n",
    "    # --------------------------------------------------------#\n",
    "\n",
    "    losses = {'value': value_loss.item(),\n",
    "              'perturbator': perturbator_loss.item(),\n",
    "              'generator': generator_loss,\n",
    "              'step': step}\n",
    "    \n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === BCQ settings ===\n",
    "params = {\n",
    "    # algorithm parameters\n",
    "    'gamma'              : 0.99,\n",
    "    'soft_tau'           : 0.001,\n",
    "    'n_generator_samples': 1,\n",
    "    'perturbator_step'   : 10,\n",
    "    \n",
    "    # learning rates\n",
    "    'perturbator_lr' : 1e-5,\n",
    "    'value_lr'       : 1e-5,\n",
    "    'generator_lr'   : 1e-5,\n",
    "}\n",
    "# === end ==="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_net = Generator(1290, 128, 512).to(cuda)\n",
    "value_net1  = recnn.models.Critic(1290, 128, 256, init_w=8e-1).to(cuda)\n",
    "value_net2  = recnn.models.Critic(1290, 128, 256, init_w=8e-1).to(cuda)\n",
    "perturbator_net = Perturbator(1290, 128, 256).to(cuda)\n",
    "\n",
    "target_value_net1 = recnn.models.Critic(1290, 128, 256).to(cuda)\n",
    "target_value_net2 = recnn.models.Critic(1290, 128, 256).to(cuda)\n",
    "target_perturbator_net = Perturbator(1290, 128, 256).to(cuda)\n",
    "\n",
    "target_perturbator_net.eval()\n",
    "target_value_net1.eval()\n",
    "target_value_net2.eval()\n",
    "\n",
    "# soft_tau=1.0 is a hard copy: target nets start identical to the online nets\n",
    "soft_update(value_net1, target_value_net1, soft_tau=1.0)\n",
    "soft_update(value_net2, target_value_net2, soft_tau=1.0)\n",
    "soft_update(perturbator_net, target_perturbator_net, soft_tau=1.0)\n",
    "\n",
    "\n",
    "# `optim` here is torch_optimizer, so RAdam is available\n",
    "value_optimizer1 = optim.RAdam(value_net1.parameters(),\n",
    "                              lr=params['value_lr'])\n",
    "# BUGFIX: the lr keys of the two optimizers below were swapped (copy-paste);\n",
    "# harmless only while all learning rates happen to be equal\n",
    "value_optimizer2 = optim.RAdam(value_net2.parameters(),\n",
    "                              lr=params['value_lr'])\n",
    "perturbator_optimizer = optim.RAdam(perturbator_net.parameters(),\n",
    "                              lr=params['perturbator_lr'], weight_decay=1e-1)\n",
    "generator_optimizer = SVI(generator_net.model, generator_net.guide,\n",
    "                          poptim.Adam({\"lr\": params['generator_lr']}), loss=JitTrace_ELBO())\n",
    "# I would advise you not to weight decay the generator\n",
    "\n",
    "layout = {\n",
    "    'train': {'value': [], 'perturbator': [], 'generator': [], 'step': []},\n",
    "    'test': {'value': [], 'perturbator': [], 'generator': [], 'step': []},\n",
    "    }\n",
    "\n",
    "writer = SummaryWriter(log_dir='../../runs')\n",
    "debugger = recnn.Debugger(layout, run_tests, writer)\n",
    "plotter = recnn.Plotter(debugger, [['generator'], ['value', 'perturbator']],)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pyro.enable_validation(True) \n",
    "# torch.autograd.set_detect_anomaly(True)\n",
    "\n",
    "step = 1\n",
    "n_epochs = 100\n",
    "batch_size = 25\n",
    "\n",
    "epoch_bar = tqdm(total=n_epochs)\n",
    "\n",
    "test_losses = [[], [], [], []]\n",
    "train_users = users[:-5000]\n",
    "test_users = users[-5000:]\n",
    "\n",
    "def prepare_batch_wrapper(x):\n",
    "    batch = recnn.data.prepare_batch_static_size(x, movies_embeddings_tensor,\n",
    "                                                 frame_size=frame_size)\n",
    "    return batch\n",
    "\n",
    "train_user_dataset = recnn.data.UserDataset(train_users, user_dict)\n",
    "test_user_dataset = recnn.data.UserDataset(test_users, user_dict)\n",
    "train_dataloader = DataLoader(train_user_dataset, batch_size=batch_size,\n",
    "                        shuffle=False, num_workers=1,collate_fn=prepare_batch_wrapper)\n",
    "test_dataloader = DataLoader(test_user_dataset, batch_size=batch_size,\n",
    "                        shuffle=False, num_workers=1,collate_fn=prepare_batch_wrapper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# --- config ---\n",
    "plot_every = 10\n",
    "# --- end ---\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    epoch_bar.update(1)\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        loss = bcq_update(batch, params, step=step)\n",
    "        debugger.log_losses(loss)\n",
    "        step += 1\n",
    "        debugger.log_step(step)\n",
    "        if step % plot_every == 0:\n",
    "            clear_output(True)\n",
    "            print('step', step)\n",
    "            debugger.test()\n",
    "            plotter.plot_loss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}