awarebayes/RecNN

View on GitHub
examples/99.To be released, but working/4. SearchNet/1. DDPG_SN.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Deep TopK Search with Critic Adjustment"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from abc import ABC\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.nn.functional as F\n",
    "import torch_optimizer as optim\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from IPython.display import clear_output\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# == recnn ==\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "cuda = torch.device('cuda')\n",
    "\n",
    "# ---\n",
    "frame_size = 10\n",
    "batch_size = 25\n",
    "n_epochs   = 100\n",
    "plot_every = 30\n",
    "step       = 0\n",
    "# ---\n",
    "\n",
    "tqdm.pandas()\n",
    "\n",
    "from jupyterthemes import jtplot\n",
    "jtplot.style(theme='grade3')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
    "# Locations of the input files and of the environment cache, gathered in one\n",
    "# mapping and handed to DataPath as keyword arguments.\n",
    "path_settings = {\n",
    "    'base': \"../../data/\",\n",
    "    'embeddings': \"embeddings/ml20_pca128.pkl\",\n",
    "    'ratings': \"ml-20m/ratings.csv\",\n",
    "    'cache': \"cache/frame_env.pkl\",  # the cache file is generated the first time this runs\n",
    "    'use_cache': True,\n",
    "}\n",
    "dirs = recnn.data.env.DataPath(**path_settings)\n",
    "\n",
    "# Environment that exposes env.train_dataloader / env.test_dataloader used further down.\n",
    "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Actor(nn.Module):\n",
    "    \"\"\"Feed-forward policy network: maps a state vector to an action vector.\n",
    "\n",
    "    Two ReLU hidden layers, each followed by dropout (p=0.5), feed a linear\n",
    "    output head that is left unbounded (no tanh).\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the incoming state vector.\n",
    "        action_dim: size of the produced action vector.\n",
    "        hidden_size: width of both hidden layers.\n",
    "        init_w: the output layer's weights and biases are drawn from\n",
    "            U(-init_w, init_w).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.drop_layer = nn.Dropout(p=0.5)\n",
    "\n",
    "        self.linear1 = nn.Linear(input_dim, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.linear3 = nn.Linear(hidden_size, action_dim)\n",
    "\n",
    "        # Only the output head is re-initialised; the hidden layers keep the\n",
    "        # nn.Linear defaults.\n",
    "        for head_tensor in (self.linear3.weight, self.linear3.bias):\n",
    "            nn.init.uniform_(head_tensor, -init_w, init_w)\n",
    "\n",
    "    def forward(self, state):\n",
    "        \"\"\"Return the action produced for a batch of state vectors.\"\"\"\n",
    "        hidden = self.drop_layer(F.relu(self.linear1(state)))\n",
    "        hidden = self.drop_layer(F.relu(self.linear2(hidden)))\n",
    "        # Linear (unbounded) output, for embeddings that are standard scaled /\n",
    "        # PCA whitened; apply torch.tanh here instead if the embeddings are\n",
    "        # normalised to [-1, 1].\n",
    "        return self.linear3(hidden)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Critic(nn.Module):\n",
    "    \"\"\"State-action value network: maps a (state, action) pair to one scalar.\n",
    "\n",
    "    The state and action are concatenated along dim 1 and passed through two\n",
    "    ReLU hidden layers, each followed by dropout (p=0.5), and a linear head\n",
    "    with a single output unit.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the state vector.\n",
    "        action_dim: size of the action vector.\n",
    "        hidden_size: width of both hidden layers.\n",
    "        init_w: the output layer's weights and biases are drawn from\n",
    "            U(-init_w, init_w).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-5):\n",
    "        super().__init__()\n",
    "\n",
    "        self.drop_layer = nn.Dropout(p=0.5)\n",
    "\n",
    "        joint_dim = input_dim + action_dim\n",
    "        self.linear1 = nn.Linear(joint_dim, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.linear3 = nn.Linear(hidden_size, 1)\n",
    "\n",
    "        # Only the output head gets the custom uniform initialisation.\n",
    "        nn.init.uniform_(self.linear3.weight, -init_w, init_w)\n",
    "        nn.init.uniform_(self.linear3.bias, -init_w, init_w)\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        \"\"\"Return the value estimate for each (state, action) row.\"\"\"\n",
    "        joint = torch.cat((state, action), dim=1)\n",
    "        hidden = self.drop_layer(F.relu(self.linear1(joint)))\n",
    "        hidden = self.drop_layer(F.relu(self.linear2(hidden)))\n",
    "        return self.linear3(hidden)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class SearchK(nn.Module):\n",
    "    \"\"\"Maps a (state, action) pair to a flat vector of action_dim * topK values.\n",
    "\n",
    "    Same body as the critic (concatenate along dim 1, two ReLU hidden layers\n",
    "    with dropout p=0.5), but the linear head has action_dim * topK outputs.\n",
    "    NOTE(review): presumably these are topK candidate action vectors laid out\n",
    "    back to back - confirm against the code that consumes the output.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the state vector.\n",
    "        action_dim: size of one action vector.\n",
    "        hidden_size: width of both hidden layers.\n",
    "        topK: multiplier applied to action_dim for the output width.\n",
    "        init_w: the output layer's weights and biases are drawn from\n",
    "            U(-init_w, init_w).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, action_dim, hidden_size, topK, init_w=3e-1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.drop_layer = nn.Dropout(p=0.5)\n",
    "        self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.linear3 = nn.Linear(hidden_size, action_dim * topK)\n",
    "\n",
    "        # Only the output head gets the custom uniform initialisation.\n",
    "        head = self.linear3\n",
    "        nn.init.uniform_(head.weight, -init_w, init_w)\n",
    "        nn.init.uniform_(head.bias, -init_w, init_w)\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        \"\"\"Return the flat (batch, action_dim * topK) output for each row.\"\"\"\n",
    "        features = torch.cat((state, action), dim=1)\n",
    "        features = F.relu(self.linear1(features))\n",
    "        features = self.drop_layer(features)\n",
    "        features = F.relu(self.linear2(features))\n",
    "        features = self.drop_layer(features)\n",
    "        return self.linear3(features)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def soft_update(net, target_net, soft_tau=1e-2):\n",
    "    \"\"\"Blend the weights of ``net`` into ``target_net`` in place.\n",
    "\n",
    "    Every target parameter becomes\n",
    "    ``(1 - soft_tau) * target + soft_tau * source``, so ``soft_tau=1.0``\n",
    "    copies ``net`` into ``target_net`` outright. ``net`` is not modified.\n",
    "    \"\"\"\n",
    "    param_pairs = zip(target_net.parameters(), net.parameters())\n",
    "    for target_param, param in param_pairs:\n",
    "        kept_share = target_param.data * (1.0 - soft_tau)\n",
    "        new_share = param.data * soft_tau\n",
    "        target_param.data.copy_(kept_share + new_share)\n",
    "\n",
    "def run_tests():\n",
    "    \"\"\"Evaluate one test batch without learning and log the reconstruction error.\n",
    "\n",
    "    Reads the notebook globals env, params, step, debug, ad, plotter, writer\n",
    "    and cuda.\n",
    "\n",
    "    Returns:\n",
    "        The loss dict produced by ddpg_sn_update for the test batch.\n",
    "    \"\"\"\n",
    "    batch = next(iter(env.test_dataloader))\n",
    "    test_losses = ddpg_sn_update(batch, params, learn=False, step=step)\n",
    "\n",
    "    # ddpg_sn_update(learn=False) stores the target policy's actions in `debug`.\n",
    "    generated = debug['next_action']\n",
    "    reference = env.base.embeddings.detach().cpu().numpy()\n",
    "\n",
    "    figure = plotter.kde_reconstruction_error(ad, generated, reference, cuda)\n",
    "    writer.add_figure('rec_error', figure, test_losses['step'])\n",
    "    return test_losses"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def ddpg_sn_update(batch, params, learn=True, step=-1):\n",
    "    \"\"\"Run one DDPG critic/actor update, or an evaluation pass, on ``batch``.\n",
    "\n",
    "    Works on the notebook-level globals ``value_net``, ``policy_net``, their\n",
    "    target copies, both optimizers, ``writer`` and ``debug``.\n",
    "\n",
    "    Args:\n",
    "        batch: batch from the environment dataloader; unpacked with\n",
    "            ``recnn.data.get_base_batch``.\n",
    "        params: hyper-parameter dict; the keys ``gamma``, ``min_value``,\n",
    "            ``max_value``, ``policy_step`` and ``soft_tau`` are read.\n",
    "        learn: if True the networks are optimised; if False only the losses\n",
    "            are computed and debugging figures/histograms go to TensorBoard.\n",
    "        step: global step, used for logging and to schedule the delayed\n",
    "            policy/target update.\n",
    "\n",
    "    Returns:\n",
    "        dict with the scalar ``value`` and ``policy`` losses and ``step``.\n",
    "    \"\"\"\n",
    "\n",
    "    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
    "\n",
    "    # --------------------------------------------------------#\n",
    "    # Value Learning\n",
    "\n",
    "    # Bootstrapped target from the target networks; no gradients are tracked,\n",
    "    # so no explicit detach is needed on these tensors.\n",
    "    with torch.no_grad():\n",
    "        next_action = target_policy_net(next_state)\n",
    "        target_value = target_value_net(next_state, next_action)\n",
    "        expected_value = reward + (1.0 - done) * params['gamma'] * target_value\n",
    "        expected_value = torch.clamp(expected_value,\n",
    "                                     params['min_value'], params['max_value'])\n",
    "\n",
    "    value = value_net(state, action)\n",
    "\n",
    "    value_loss = torch.pow(value - expected_value, 2).mean()\n",
    "\n",
    "    if learn:\n",
    "        value_optimizer.zero_grad()\n",
    "        value_loss.backward()\n",
    "        value_optimizer.step()\n",
    "    else:\n",
    "        # Kept for run_tests, which plots the reconstruction error of these actions.\n",
    "        debug['next_action'] = next_action\n",
    "        writer.add_figure('next_action',\n",
    "                    recnn.utils.pairwise_distances_fig(next_action[:50]), step)\n",
    "        writer.add_histogram('value', value, step)\n",
    "        writer.add_histogram('target_value', target_value, step)\n",
    "        writer.add_histogram('expected_value', expected_value, step)\n",
    "\n",
    "    # --------------------------------------------------------#\n",
    "    # Policy learning\n",
    "\n",
    "    gen_action = policy_net(state)\n",
    "    # The actor is trained to maximise the critic's value of its own actions.\n",
    "    policy_loss = -value_net(state, gen_action)\n",
    "\n",
    "    if not learn:\n",
    "        debug['gen_action'] = gen_action\n",
    "        writer.add_histogram('policy_loss', policy_loss, step)\n",
    "        # Logged under its own tag so it does not collide with the\n",
    "        # 'next_action' figure written in the value section.\n",
    "        writer.add_figure('gen_action',\n",
    "                    recnn.utils.pairwise_distances_fig(gen_action[:50]), step)\n",
    "\n",
    "    policy_loss = policy_loss.mean()\n",
    "\n",
    "    # Delayed update: actor and target networks move only every policy_step steps.\n",
    "    if learn and step % params['policy_step'] == 0:\n",
    "        policy_optimizer.zero_grad()\n",
    "        policy_loss.backward()\n",
    "        # Clamp every gradient element to [-1, 1]. clip_grad_norm_ must not be\n",
    "        # called with a negative max_norm: it rescales the gradients by\n",
    "        # max_norm / total_norm, which would reverse their direction.\n",
    "        torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)\n",
    "        policy_optimizer.step()\n",
    "\n",
    "        soft_update(value_net, target_value_net, soft_tau=params['soft_tau'])\n",
    "        soft_update(policy_net, target_policy_net, soft_tau=params['soft_tau'])\n",
    "\n",
    "    # TODO: the search-network loss is not computed or logged here yet.\n",
    "    losses = {'value': value_loss.item(), 'policy': policy_loss.item(), 'step': step}\n",
    "    recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')\n",
    "    return losses"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# === ddpg settings ===\n",
    "\n",
    "params = dict(\n",
    "    # -- update rule (read by ddpg_sn_update) --\n",
    "    gamma=0.99,        # discount applied to the target critic's value\n",
    "    min_value=-10,     # lower clamp of the bootstrapped value target\n",
    "    max_value=10,      # upper clamp of the bootstrapped value target\n",
    "    policy_step=10,    # the actor and the targets are updated every N steps\n",
    "    soft_tau=0.001,    # blending factor passed to soft_update\n",
    "\n",
    "    # -- optimizer learning rates --\n",
    "    policy_lr=1e-5,\n",
    "    value_lr=1e-5,\n",
    "    search_lr=1e-5,\n",
    "\n",
    "    # -- init_w of each network's output layer --\n",
    "    actor_weight_init=54e-2,\n",
    "    search_weight_init=54e-2,\n",
    "    critic_weight_init=6e-1,\n",
    ")\n",
    "\n",
    "# === end ==="
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sizes shared by every network below.\n",
    "state_dim = 1290   # NOTE(review): looks like frame_size * (128 + 1) - confirm against FrameEnv\n",
    "action_dim = 128   # NOTE(review): presumably the width of the ml20_pca128 embeddings - confirm\n",
    "top_k = 10         # output multiplier of the search network\n",
    "\n",
    "value_net  = Critic(state_dim, action_dim, 256, params['critic_weight_init']).to(cuda)\n",
    "policy_net = Actor(state_dim, action_dim, 256, params['actor_weight_init']).to(cuda)\n",
    "search_net = SearchK(state_dim, action_dim, 2048, topK=top_k,\n",
    "                     init_w=params['search_weight_init']).to(cuda)\n",
    "\n",
    "# The targets are built with default init; their weights are overwritten by\n",
    "# the hard copy below.\n",
    "target_value_net = Critic(state_dim, action_dim, 256).to(cuda)\n",
    "target_policy_net = Actor(state_dim, action_dim, 256).to(cuda)\n",
    "target_search_net = SearchK(state_dim, action_dim, 2048, topK=top_k).to(cuda)\n",
    "\n",
    "# Pre-trained anomaly detector, used by run_tests for the reconstruction-error figure.\n",
    "ad = recnn.nn.models.AnomalyDetector().to(cuda)\n",
    "ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n",
    "ad.eval()\n",
    "\n",
    "# Target networks are never trained directly, so dropout is switched off on\n",
    "# all of them (the search target included, for consistency).\n",
    "target_policy_net.eval()\n",
    "target_value_net.eval()\n",
    "target_search_net.eval()\n",
    "\n",
    "# soft_tau=1.0 is a hard copy: every target starts identical to its online network.\n",
    "soft_update(value_net, target_value_net, soft_tau=1.0)\n",
    "soft_update(policy_net, target_policy_net, soft_tau=1.0)\n",
    "soft_update(search_net, target_search_net, soft_tau=1.0)\n",
    "\n",
    "value_criterion = nn.MSELoss()\n",
    "search_criterion = nn.MSELoss()\n",
    "\n",
    "# from good to bad: Ranger Radam Adam RMSprop\n",
    "value_optimizer = optim.Ranger(value_net.parameters(),\n",
    "                              lr=params['value_lr'], weight_decay=1e-2)\n",
    "policy_optimizer = optim.Ranger(policy_net.parameters(),\n",
    "                               lr=params['policy_lr'], weight_decay=1e-5)\n",
    "search_optimizer = optim.Ranger(search_net.parameters(),\n",
    "                                weight_decay=1e-5,\n",
    "                                lr=params['search_lr'])\n",
    "\n",
    "# Loss history handed to the Plotter.\n",
    "loss = {\n",
    "    'test': {'value': [], 'policy': [], 'search': [], 'step': []},\n",
    "    'train': {'value': [], 'policy': [], 'search': [], 'step': []}\n",
    "    }\n",
    "\n",
    "# Filled by ddpg_sn_update(learn=False) with the most recent generated actions.\n",
    "debug = {}\n",
    "\n",
    "writer = SummaryWriter(log_dir='../../runs')\n",
    "plotter = recnn.utils.Plotter(loss, [['value', 'policy', 'search']],)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Upper bound on the number of updates performed by this notebook.\n",
    "max_steps = 1000\n",
    "reached_limit = False\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for batch in tqdm(env.train_dataloader):\n",
    "        # Named train_loss so that the `loss` history dict handed to the\n",
    "        # Plotter is not rebound.\n",
    "        train_loss = ddpg_sn_update(batch, params, step=step)\n",
    "        plotter.log_losses(train_loss)\n",
    "        step += 1\n",
    "        if step % plot_every == 0:\n",
    "            clear_output(True)\n",
    "            print('step', step)\n",
    "            test_loss = run_tests()\n",
    "            plotter.log_losses(test_loss, test=True)\n",
    "            plotter.plot_loss()\n",
    "        if step > max_steps:\n",
    "            # Leave both loops cleanly rather than raising, so the cells after\n",
    "            # this one still run under Restart & Run All.\n",
    "            reached_limit = True\n",
    "            break\n",
    "    if reached_limit:\n",
    "        break"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Persist the trained critic and actor weights (state_dicts only).\n",
    "checkpoints = (\n",
    "    (value_net, \"../../models/ddpg_value.pt\"),\n",
    "    (policy_net, \"../../models/ddpg_policy.pt\"),\n",
    ")\n",
    "for trained_net, checkpoint_path in checkpoints:\n",
    "    torch.save(trained_net.state_dict(), checkpoint_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Reconstruction error"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Actions stored by the last ddpg_sn_update(learn=False) call.\n",
    "gen_actions = debug['next_action']\n",
    "# Same conversion as in run_tests: detach and move to host memory first, so\n",
    "# .numpy() also works when the embeddings live on the GPU or track gradients.\n",
    "true_actions = env.base.embeddings.detach().cpu().numpy()\n",
    "\n",
    "\n",
    "# Fresh copy of the pre-trained anomaly detector, in eval mode.\n",
    "ad = recnn.nn.AnomalyDetector().to(cuda)\n",
    "ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n",
    "ad.eval()\n",
    "\n",
    "plotter.plot_kde_reconstruction_error(ad, gen_actions, true_actions, cuda)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}