awarebayes/RecNN
examples/2. REINFORCE TopK Off Policy Correction/0. Inner workings of REINFORCE inside recnn (optional).ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE\n",
    "\n",
    "The following code contains an implementation of the REINFORCE algorithm, **without Off Policy Correction, LSTM state encoder, and Noise Contrastive Estimation**. Look for these in other notebooks.\n",
    "\n",
    "Also, I am not google staff, and unlike the paper authors, I cannot have online feedback concerning the recommendations.\n",
    "\n",
    "**I use actor-critic for reward assigning.** In a real-world scenario that would be done through interactive user feedback, but here I use a neural network (critic) that aims to emulate it.\n",
    "\n",
    "### **Note on this tutorials:**\n",
    "**They mostly contain low level implementations explaining what is going on inside the library.**\n",
    "\n",
    "**Most of the stuff explained here is already available out of the box for your usage.**\n",
    "\n",
    "If you do not care about the detailed implementation with code, go to the [Library Basics]/algorithms how to/reinforce, there is a 20 liner version"
   ]
  },
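  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation: the quantity optimized below is the standard (textbook) REINFORCE policy-gradient estimate, nothing recnn-specific:\n",
    "\n",
    "$$\\nabla_\\theta J(\\theta) \\approx \\sum_t \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid s_t) \\, R_t$$\n",
    "\n",
    "where $R_t$ is the discounted return. In code this becomes the loss $-\\sum_t \\log \\pi_\\theta(a_t \\mid s_t) R_t$, which is exactly what `basic_reinforce` computes further down."
   ]
  },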
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.nn.functional as F\n",
    "from torch.distributions import Categorical\n",
    "import torch_optimizer as optim\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm\n",
    "from time import gmtime, strftime\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# == recnn ==\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "cuda = torch.device('cuda')\n",
    "\n",
    "# ---\n",
    "frame_size = 10\n",
    "batch_size = 10\n",
    "n_epochs   = 100\n",
    "plot_every = 30\n",
    "step       = 0\n",
    "num_items    = 5000 # n items to recommend. Can be adjusted for your vram \n",
    "# --- \n",
    "tqdm.pandas()\n",
    "\n",
    "\n",
    "from jupyterthemes import jtplot\n",
    "jtplot.style(theme='grade3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## I will drop low freq items because it doesnt fit into my videocard vram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Dict, Callable\n",
    "\n",
    "# Plain args. Shouldn't be mutated\n",
    "class DataFuncKwargs:\n",
    "    def __init__(self, **kwargs):\n",
    "        self.kwargs = kwargs\n",
    "        \n",
    "    def get(self, name: str):\n",
    "        if name not in self.kwargs:\n",
    "            example = \"\"\"\n",
    "                # example on how to use kwargs:\n",
    "                def prepare_dataset(args, args_mut):\n",
    "                    args.set_kwarg('{}', your_value) # set kwargs for your functions here!\n",
    "                    pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]\n",
    "                    recnn.data.build_data_pipeline(pipeline, args, args_mut)\n",
    "            \"\"\"\n",
    "            raise AttributeError(\"No kwarg with name {} found!\\n{}\".format(name, example.format(err_desc)))\n",
    "        return self.kwargs[name]\n",
    "    \n",
    "    def set(self, name: str, value):\n",
    "        self.kwargs[name] = value\n",
    "\n",
    "# Used for returning, arguments are mutable\n",
    "class DataFuncArgsMut:\n",
    "    def __init__(self, df, base, users: List[int], user_dict: Dict[int, Dict[str, np.ndarray]]):\n",
    "        self.base = base\n",
    "        self.users = users\n",
    "        self.user_dict = user_dict\n",
    "        self.df = df"
   ]
  },
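  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how these two containers are meant to be used (the values below are made up for illustration):\n",
    "\n",
    "```python\n",
    "# immutable-ish configuration passed down the pipeline\n",
    "kwargs = DataFuncKwargs(frame_size=10)\n",
    "kwargs.set('reduce_items_to', 5000)\n",
    "print(kwargs.get('frame_size'))  # -> 10\n",
    "\n",
    "# mutable results container: each pipeline step fills in or updates its fields\n",
    "# (df, base, users and user_dict are placeholders here)\n",
    "args_mut = DataFuncArgsMut(df=None, base=None, users=[], user_dict={})\n",
    "```"
   ]
  },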
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):\n",
    "\n",
    "    \"\"\"\n",
    "        Basic prepare dataset function. Automatically makes index linear, in ml20 movie indices look like:\n",
    "        [1, 34, 123, 2000], recnn makes it look like [0,1,2,3] for you.\n",
    "    \"\"\"\n",
    "\n",
    "    # get args\n",
    "    frame_size = kwargs.get('frame_size')\n",
    "    key_to_id = args_mut.base.key_to_id\n",
    "    df = args_mut.df\n",
    "    \n",
    "    # rating range mapped from [0, 5] to [-5, 5]\n",
    "    df['rating'] = try_progress_apply(df['rating'], lambda i: 2 * (i - 2.5))\n",
    "    # id's tend to be inconsistent and sparse so they are remapped here\n",
    "    df['movieId'] = try_progress_apply(df['movieId'], lambda i: key_to_id.get(i))\n",
    "\n",
    "    users = df[['userId', 'movieId']].groupby(['userId']).size()\n",
    "    users = users[users > frame_size].sort_values(ascending=False).index\n",
    "\n",
    "    if pd.get_type() == \"modin\": df = df._to_pandas() # pandas groupby is sync and doesnt affect performance \n",
    "    ratings = df.sort_values(by='timestamp').set_index('userId').drop('timestamp', axis=1).groupby('userId')\n",
    "\n",
    "    # Groupby user\n",
    "    user_dict = {}\n",
    "\n",
    "    def app(x):\n",
    "        userid = x.index[0]\n",
    "        user_dict[int(userid)] = {}\n",
    "        user_dict[int(userid)]['items'] = x['movieId'].values\n",
    "        user_dict[int(userid)]['ratings'] = x['rating'].values\n",
    "\n",
    "    try_progress_apply(ratings, app)\n",
    "\n",
    "    args_mut.user_dict = user_dict\n",
    "    args_mut.users = users\n",
    "\n",
    "    return args_mut, kwargs"
   ]
  },
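  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the rating remap used above (plain pandas, toy values):\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "ratings = pd.Series([0.5, 2.5, 5.0])   # original ML-20M ratings live in [0.5, 5]\n",
    "remapped = 2 * (ratings - 2.5)         # same lambda as in prepare_dataset\n",
    "print(remapped.tolist())               # [-4.0, 0.0, 5.0]\n",
    "```"
   ]
  },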
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):\n",
    "    \"\"\"\n",
    "        Truncate #items to reduct_items_to provided in the kwargs\n",
    "    \"\"\"\n",
    "\n",
    "    # here are adjusted n items to keep\n",
    "    num_items = kwargs.get('reduce_items_to')\n",
    "    df = args_mut.df\n",
    "    \n",
    "    to_remove = df['movieId'].value_counts().sort_values()[:-num_items].index\n",
    "    to_keep = df['movieId'].value_counts().sort_values()[-num_items:].index\n",
    "    to_remove_indices = df[df['movieId'].isin(to_remove)].index\n",
    "    num_removed = len(to_remove)\n",
    "\n",
    "    df.drop(to_remove_indices, inplace=True)\n",
    "\n",
    "    for i in list(args_mut.base.movie_embeddings_key_dict.keys()):\n",
    "        if i not in to_keep:\n",
    "            del args_mut.base.movie_embeddings_key_dict[i]\n",
    "\n",
    "    args_mut.base.embeddings, args_mut.base.key_to_id, \\\n",
    "    args_mut.base.id_to_key = recnn.data.make_items_tensor(args_mut.base.movie_embeddings_key_dict)\n",
    "    args_mut.df = df\n",
    "\n",
    "    print('action space is reduced to {} - {} = {}'.format(num_items + num_removed, num_removed,\n",
    "                                                           num_items))\n",
    "\n",
    "    return args_mut, kwargs"
   ]
  },
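  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The truncation above boils down to keeping the `num_items` most frequent ids. A toy illustration of the same `value_counts` trick:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "movie_ids = pd.Series(['a', 'a', 'a', 'b', 'b', 'c'])\n",
    "counts = movie_ids.value_counts().sort_values()   # ascending by frequency\n",
    "to_keep = counts[-2:].index                        # the 2 most frequent ids\n",
    "to_remove = counts[:-2].index                      # everything else\n",
    "print(list(to_keep), list(to_remove))              # ['b', 'a'] ['c']\n",
    "```"
   ]
  },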
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_contstate_discaction(batch, item_embeddings_tensor, frame_size, num_items, *args, **kwargs):\n",
    "    \n",
    "    \"\"\"\n",
    "    Embed Batch: continuous state discrete action\n",
    "    \"\"\"\n",
    "    \n",
    "    from recnn.data.utils import get_irsu\n",
    "    \n",
    "    items_t, ratings_t, sizes_t, users_t = get_irsu(batch)\n",
    "    items_emb = item_embeddings_tensor[items_t.long()]\n",
    "    b_size = ratings_t.size(0)\n",
    "\n",
    "    items = items_emb[:, :-1, :].view(b_size, -1)\n",
    "    next_items = items_emb[:, 1:, :].view(b_size, -1)\n",
    "    ratings = ratings_t[:, :-1]\n",
    "    next_ratings = ratings_t[:, 1:]\n",
    "\n",
    "    state = torch.cat([items, ratings], 1)\n",
    "    next_state = torch.cat([next_items, next_ratings], 1)\n",
    "    action = items_t[:, -1]\n",
    "    reward = ratings_t[:, -1]\n",
    "\n",
    "    done = torch.zeros(b_size)\n",
    "    done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1\n",
    "    \n",
    "    one_hot_action = torch.zeros(action.size(0), num_items)\n",
    "    one_hot_action.scatter_(1, action.view(-1,1), 1)\n",
    "\n",
    "    batch = {'state': state, 'action': one_hot_action, 'reward': reward, 'next_state': next_state, 'done': done,\n",
    "             'meta': {'users': users_t, 'sizes': sizes_t}}\n",
    "    return batch\n",
    "\n",
    "def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):\n",
    "    return batch_contstate_discaction(batch, item_embeddings_tensor, frame_size=frame_size, num_items=num_items)"
   ]
  },
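  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The only non-obvious line above is the one-hot encoding of the discrete action via `scatter_`. A tiny standalone example of what it does:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "action = torch.tensor([2, 0, 3])             # item index chosen for each of 3 users\n",
    "one_hot = torch.zeros(action.size(0), 5)     # 5 = toy number of items\n",
    "one_hot.scatter_(1, action.view(-1, 1), 1)   # put a 1 at the chosen index of each row\n",
    "print(one_hot)\n",
    "# tensor([[0., 0., 1., 0., 0.],\n",
    "#         [1., 0., 0., 0., 0.],\n",
    "#         [0., 0., 0., 1., 0.]])\n",
    "```"
   ]
  },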
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "[<function truncate_dataset at 0x7f4675108c10>, <function prepare_dataset at 0x7f4675108940>]\n  0%|          | 0/18946308 [00:00<?, ?it/s]action space is reduced to 26744 - 21744 = 5000\nexecuted!\n100%|██████████| 18946308/18946308 [00:12<00:00, 1503667.62it/s]\n100%|██████████| 18946308/18946308 [00:14<00:00, 1315333.93it/s]\n100%|██████████| 138493/138493 [00:07<00:00, 19649.79it/s]\nexecuted!\n"
    }
   ],
   "source": [
    "def build_data_pipeline(chain, kwargs: DataFuncKwargs, args_mut: DataFuncArgsMut):\n",
    "    \"\"\"\n",
    "        :param chain: array of callable\n",
    "        :param **kwargs: any kwargs you like\n",
    "    \"\"\"\n",
    "    print(chain)\n",
    "    for call in chain:\n",
    "        # note: returned kwargs are not utilized to guarantee immutability\n",
    "        args_mut, _ = call(args_mut, kwargs)\n",
    "    return kwargs, args_mut\n",
    "\n",
    "def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):\n",
    "    return batch_contstate_discaction(batch, item_embeddings_tensor,\n",
    "                                                 frame_size=frame_size, num_items=num_items)\n",
    "\n",
    "    \n",
    "def prepare_dataset(args_mut, kwargs):\n",
    "    kwargs.set('reduce_items_to', num_items) # set kwargs for your functions here!\n",
    "    pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]\n",
    "    build_data_pipeline(pipeline, kwargs, args_mut)\n",
    "    \n",
    "\n",
    "# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
    "env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n",
    "                              '../../data/ml-20m/ratings.csv', frame_size, batch_size,\n",
    "                              embed_batch=embed_batch, prepare_dataset=prepare_dataset,\n",
    "                              num_workers=0)"
   ]
  },
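  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming the dataloader yields the dict produced by `embed_batch` above, you can peek at one batch like this (purely illustrative; shapes depend on your settings):\n",
    "\n",
    "```python\n",
    "batch = next(iter(env.train_dataloader))\n",
    "print(batch.keys())                                  # state / action / reward / next_state / done / meta\n",
    "print(batch['state'].shape, batch['action'].shape)\n",
    "```"
   ]
  },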
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DiscreteActor(nn.Module):\n",
    "    def __init__(self, hidden_size, num_inputs, num_actions):\n",
    "        super(DiscreteActor, self).__init__()\n",
    "\n",
    "        self.linear1 = nn.Linear(num_inputs, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, num_actions)\n",
    "        \n",
    "        self.saved_log_probs = []\n",
    "        self.rewards = []\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        x = inputs\n",
    "        x = F.relu(self.linear1(x))\n",
    "        action_scores = self.linear2(x)\n",
    "        return F.softmax(action_scores)\n",
    "    \n",
    "    \n",
    "    def select_action(self, state):\n",
    "        probs = self.forward(state)\n",
    "        m = Categorical(probs)\n",
    "        action = m.sample()\n",
    "        self.saved_log_probs.append(m.log_prob(action))\n",
    "        return action, probs"
   ]
  },
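  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick smoke test of the actor defined above (the 1290-dimensional state matches the networks configured later; purely illustrative):\n",
    "\n",
    "```python\n",
    "actor = DiscreteActor(hidden_size=2048, num_inputs=1290, num_actions=num_items)\n",
    "state = torch.randn(batch_size, 1290)   # fake batch of states\n",
    "action, probs = actor.select_action(state)\n",
    "print(action.shape, probs.shape)        # [batch_size], [batch_size, num_items]\n",
    "print(len(actor.saved_log_probs))       # one log-prob tensor stored per call\n",
    "```"
   ]
  },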
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Because I do not have a dynamic environment, I also will include a critic. If you have a real non static environment, you can do w/o citic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChooseREINFORCE():\n",
    "    \n",
    "    def __init__(self, method=None):\n",
    "        if method is None:\n",
    "            method = ChooseREINFORCE.reinforce\n",
    "        self.method = method\n",
    "    \n",
    "    @staticmethod\n",
    "    def basic_reinforce(policy, returns, *args, **kwargs):\n",
    "        policy_loss = []\n",
    "        for log_prob, R in zip(policy.saved_log_probs, returns):\n",
    "            policy_loss.append(-log_prob * R)\n",
    "        policy_loss = torch.cat(policy_loss).sum()\n",
    "        return policy_loss\n",
    "    \n",
    "    @staticmethod\n",
    "    def reinforce_with_correction():\n",
    "        raise NotImplemented\n",
    "\n",
    "    def __call__(self, policy, optimizer, learn=True):\n",
    "        R = 0\n",
    "        \n",
    "        returns = []\n",
    "        for r in policy.rewards[::-1]:\n",
    "            R = r + 0.99 * R\n",
    "            returns.insert(0, R)\n",
    "            \n",
    "        returns = torch.tensor(returns)\n",
    "        returns = (returns - returns.mean()) / (returns.std() + 0.0001)\n",
    "\n",
    "        policy_loss = self.method(policy, returns)\n",
    "        \n",
    "        if learn:\n",
    "            optimizer.zero_grad()\n",
    "            policy_loss.backward()\n",
    "            optimizer.step()\n",
    "        \n",
    "        del policy.rewards[:]\n",
    "        del policy.saved_log_probs[:]\n",
    "\n",
    "        return policy_loss"
   ]
  },
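  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The discounted-return loop in `__call__` is easy to verify by hand on a toy reward sequence:\n",
    "\n",
    "```python\n",
    "rewards, gamma = [1.0, 0.0, 2.0], 0.99\n",
    "R, returns = 0.0, []\n",
    "for r in reversed(rewards):   # walk backwards, accumulating the discounted sum\n",
    "    R = r + gamma * R\n",
    "    returns.insert(0, R)\n",
    "print(returns)                # approximately [2.9602, 1.98, 2.0]\n",
    "```\n",
    "\n",
    "After that, the returns are standardized (zero mean, unit variance) before being multiplied with the log-probabilities, which is a common variance-reduction trick."
   ]
  },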
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === reinforce settings ===\n",
    "\n",
    "params = {\n",
    "    'reinforce': ChooseREINFORCE(ChooseREINFORCE.basic_reinforce),\n",
    "    'gamma'      : 0.99,\n",
    "    'min_value'  : -10,\n",
    "    'max_value'  : 10,\n",
    "    'policy_step': 10,\n",
    "    'soft_tau'   : 0.001,\n",
    "    \n",
    "    'policy_lr'  : 1e-5,\n",
    "    'value_lr'   : 1e-5,\n",
    "    'actor_weight_init': 54e-2,\n",
    "    'critic_weight_init': 6e-1,\n",
    "}\n",
    "\n",
    "# === end ==="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "nets = {\n",
    "    'value_net': recnn.nn.Critic(1290, num_items, 2048, params['critic_weight_init']).to(cuda),\n",
    "    'target_value_net': recnn.nn.Critic(1290, num_items, 2048, params['actor_weight_init']).to(cuda).eval(),\n",
    "    \n",
    "    'policy_net':  DiscreteActor(2048, 1290, num_items).to(cuda),\n",
    "    'target_policy_net': DiscreteActor(2048, 1290, num_items).to(cuda).eval(),\n",
    "}\n",
    "\n",
    "\n",
    "# from good to bad: Ranger Radam Adam RMSprop\n",
    "optimizer = {\n",
    "    'value_optimizer': optim.Ranger(nets['value_net'].parameters(),\n",
    "                                          lr=params['value_lr'], weight_decay=1e-2),\n",
    "\n",
    "    'policy_optimizer': optim.Ranger(nets['policy_net'].parameters(),\n",
    "                                           lr=params['policy_lr'], weight_decay=1e-5)\n",
    "}\n",
    "\n",
    "\n",
    "loss = {\n",
    "    'test': {'value': [], 'policy': [], 'step': []},\n",
    "    'train': {'value': [], 'policy': [], 'step': []}\n",
    "    }\n",
    "\n",
    "debug = {}\n",
    "\n",
    "reinforce.writer = SummaryWriter(log_dir='../../runs/Reinforce{}/'.format(strftime(\"%H_%M\", gmtime())))\n",
    "plotter = recnn.utils.Plotter(loss, [['value', 'policy']],)\n",
    "device = cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reinforce_update(batch, params, nets, optimizer,\n",
    "                     device=torch.device('cpu'),\n",
    "                     debug=None, writer=recnn.utils.DummyWriter(),\n",
    "                     learn=False, step=-1):\n",
    "    \n",
    "    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
    "    \n",
    "    predicted_action, predicted_probs = nets['policy_net'].select_action(state)\n",
    "    reward = nets['value_net'](state, predicted_probs).detach()\n",
    "    nets['policy_net'].rewards.append(reward.mean())\n",
    "    \n",
    "    value_loss = recnn.nn.value_update(batch, params, nets, optimizer,\n",
    "                     writer=writer,\n",
    "                     device=device,\n",
    "                     debug=debug, learn=learn, step=step)\n",
    "    \n",
    "    \n",
    "    \n",
    "    if step % params['policy_step'] == 0 and step > 0:\n",
    "        \n",
    "        policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'], learn=learn)\n",
    "        \n",
    "        del nets['policy_net'].rewards[:]\n",
    "        del nets['policy_net'].saved_log_probs[:]\n",
    "        \n",
    "        print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())\n",
    "    \n",
    "        recnn.utils.soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])\n",
    "        recnn.utils.soft_update(nets['policy_net'], nets['target_policy_net'], soft_tau=params['soft_tau'])\n",
    "\n",
    "        losses = {'value': value_loss.item(),\n",
    "                  'policy': policy_loss.item(),\n",
    "                  'step': step}\n",
    "\n",
    "        recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')\n",
    "\n",
    "        return losses"
   ]
  },
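  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`recnn.utils.soft_update` performs the usual Polyak averaging of the target networks. A minimal sketch of the rule (not the library's exact code):\n",
    "\n",
    "```python\n",
    "def soft_update_sketch(net, target_net, soft_tau=1e-3):\n",
    "    # slowly move the target parameters towards the online parameters\n",
    "    for param, target_param in zip(net.parameters(), target_net.parameters()):\n",
    "        target_param.data.copy_(soft_tau * param.data + (1.0 - soft_tau) * target_param.data)\n",
    "```"
   ]
  },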
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "output_type": "error",
     "ename": "NameError",
     "evalue": "name 'env' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-1a1791fbfb7b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m         loss = reinforce_update(batch, params, nets, optimizer,\n\u001b[1;32m      5\u001b[0m                      \u001b[0mwriter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwriter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'env' is not defined"
     ]
    }
   ],
   "source": [
    "step = 0\n",
    "for epoch in range(n_epochs):\n",
    "    for batch in tqdm(env.train_dataloader):\n",
    "        loss = reinforce_update(batch, params, nets, optimizer,\n",
    "                     writer=writer,\n",
    "                     device=device,\n",
    "                     debug=debug, learn=True, step=step)\n",
    "        if loss:\n",
    "            plotter.log_losses(loss)\n",
    "        step += 1\n",
    "        if step % plot_every == 0:\n",
    "            # clear_output(True)\n",
    "            print('step', step)\n",
    "            #test_loss = run_tests()\n",
    "            #plotter.log_losses(test_loss, test=True)\n",
    "            #plotter.plot_loss()\n",
    "        #if step > 1000:\n",
    "        #    pass\n",
    "        #    assert False"
   ]
  },
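  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once trained, getting recommendations is just ranking the policy's action probabilities. A hypothetical inference sketch (assumes `batch` is one dict from the dataloader above):\n",
    "\n",
    "```python\n",
    "with torch.no_grad():\n",
    "    state = batch['state'].to(cuda)\n",
    "    probs = nets['policy_net'](state)                  # probabilities over the num_items items\n",
    "    top10 = torch.topk(probs, k=10, dim=-1).indices    # 10 most likely item indices (remapped ids) per user\n",
    "print(top10[:3])\n",
    "```"
   ]
  },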
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}