examples/2. REINFORCE TopK Off Policy Correction/0. Inner workings of REINFORCE inside recnn (optional).ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# REINFORCE\n",
"\n",
"The following code contains an implementation of the REINFORCE algorithm, **without Off Policy Correction, LSTM state encoder, and Noise Contrastive Estimation**. Look for these in other notebooks.\n",
"\n",
"Also, I am not google staff, and unlike the paper authors, I cannot have online feedback concerning the recommendations.\n",
"\n",
"**I use actor-critic for reward assigning.** In a real-world scenario that would be done through interactive user feedback, but here I use a neural network (critic) that aims to emulate it.\n",
"\n",
"### **Note on this tutorials:**\n",
"**They mostly contain low level implementations explaining what is going on inside the library.**\n",
"\n",
"**Most of the stuff explained here is already available out of the box for your usage.**\n",
"\n",
"If you do not care about the detailed implementation with code, go to the [Library Basics]/algorithms how to/reinforce, there is a 20 liner version"
]
},
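{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick refresher (the textbook form of REINFORCE, which is what `ChooseREINFORCE.basic_reinforce` below implements):\n",
"\n",
"$$J(\\theta) = \\mathbb{E}_{\\tau \\sim \\pi_\\theta}\\Big[\\sum_t \\gamma^t r_t\\Big], \\qquad \\nabla_\\theta J(\\theta) \\approx \\sum_t \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid s_t)\\, R_t$$\n",
"\n",
"where $R_t = \\sum_{k \\ge t} \\gamma^{k-t} r_k$ is the discounted return. The code below minimizes $-\\sum_t \\log \\pi_\\theta(a_t \\mid s_t)\\, R_t$ with per-batch normalized returns."
]
},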
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import torch.nn.functional as F\n",
"from torch.distributions import Categorical\n",
"import torch_optimizer as optim\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"from time import gmtime, strftime\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"cuda = torch.device('cuda')\n",
"\n",
"# ---\n",
"frame_size = 10\n",
"batch_size = 10\n",
"n_epochs = 100\n",
"plot_every = 30\n",
"step = 0\n",
"num_items = 5000 # n items to recommend. Can be adjusted for your vram \n",
"# --- \n",
"tqdm.pandas()\n",
"\n",
"\n",
"from jupyterthemes import jtplot\n",
"jtplot.style(theme='grade3')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## I will drop low freq items because it doesnt fit into my videocard vram"
]
},
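{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, purely illustrative way to see how skewed item popularity is before picking `num_items` (it assumes the same `ratings.csv` that `FrameEnv` loads further below and is not part of the pipeline):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: inspect item popularity to decide how many items to keep.\n",
"# Assumes the ratings file used by FrameEnv below; adjust the path if needed.\n",
"counts = pd.read_csv('../../data/ml-20m/ratings.csv')['movieId'].value_counts()\n",
"print('unique items:', len(counts))\n",
"print('share of ratings covered by the top {} items: {:.1%}'.format(\n",
"    num_items, counts.iloc[:num_items].sum() / counts.sum()))"
]
},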
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Dict, Callable\n",
"\n",
"# Plain args. Shouldn't be mutated\n",
"class DataFuncKwargs:\n",
" def __init__(self, **kwargs):\n",
" self.kwargs = kwargs\n",
" \n",
" def get(self, name: str):\n",
" if name not in self.kwargs:\n",
" example = \"\"\"\n",
" # example on how to use kwargs:\n",
" def prepare_dataset(args, args_mut):\n",
" args.set_kwarg('{}', your_value) # set kwargs for your functions here!\n",
" pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]\n",
" recnn.data.build_data_pipeline(pipeline, args, args_mut)\n",
" \"\"\"\n",
" raise AttributeError(\"No kwarg with name {} found!\\n{}\".format(name, example.format(err_desc)))\n",
" return self.kwargs[name]\n",
" \n",
" def set(self, name: str, value):\n",
" self.kwargs[name] = value\n",
"\n",
"# Used for returning, arguments are mutable\n",
"class DataFuncArgsMut:\n",
" def __init__(self, df, base, users: List[int], user_dict: Dict[int, Dict[str, np.ndarray]]):\n",
" self.base = base\n",
" self.users = users\n",
" self.user_dict = user_dict\n",
" self.df = df"
]
},
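{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny usage check for the two containers above (illustrative only, safe to skip):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# DataFuncKwargs is a plain, read-mostly bag of arguments shared by the pipeline functions\n",
"kwargs_demo = DataFuncKwargs(frame_size=frame_size)\n",
"print(kwargs_demo.get('frame_size'))       # -> 10\n",
"kwargs_demo.set('reduce_items_to', num_items)\n",
"print(kwargs_demo.get('reduce_items_to'))  # -> 5000\n",
"\n",
"# asking for a kwarg that was never set raises AttributeError with a usage hint\n",
"try:\n",
"    kwargs_demo.get('batch_size')\n",
"except AttributeError as e:\n",
"    print(type(e).__name__)"
]
},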
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):\n",
"\n",
" \"\"\"\n",
" Basic prepare dataset function. Automatically makes index linear, in ml20 movie indices look like:\n",
" [1, 34, 123, 2000], recnn makes it look like [0,1,2,3] for you.\n",
" \"\"\"\n",
"\n",
" # get args\n",
" frame_size = kwargs.get('frame_size')\n",
" key_to_id = args_mut.base.key_to_id\n",
" df = args_mut.df\n",
" \n",
" # rating range mapped from [0, 5] to [-5, 5]\n",
" df['rating'] = try_progress_apply(df['rating'], lambda i: 2 * (i - 2.5))\n",
" # id's tend to be inconsistent and sparse so they are remapped here\n",
" df['movieId'] = try_progress_apply(df['movieId'], lambda i: key_to_id.get(i))\n",
"\n",
" users = df[['userId', 'movieId']].groupby(['userId']).size()\n",
" users = users[users > frame_size].sort_values(ascending=False).index\n",
"\n",
" if pd.get_type() == \"modin\": df = df._to_pandas() # pandas groupby is sync and doesnt affect performance \n",
" ratings = df.sort_values(by='timestamp').set_index('userId').drop('timestamp', axis=1).groupby('userId')\n",
"\n",
" # Groupby user\n",
" user_dict = {}\n",
"\n",
" def app(x):\n",
" userid = x.index[0]\n",
" user_dict[int(userid)] = {}\n",
" user_dict[int(userid)]['items'] = x['movieId'].values\n",
" user_dict[int(userid)]['ratings'] = x['rating'].values\n",
"\n",
" try_progress_apply(ratings, app)\n",
"\n",
" args_mut.user_dict = user_dict\n",
" args_mut.users = users\n",
"\n",
" return args_mut, kwargs"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def truncate_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):\n",
" \"\"\"\n",
" Truncate #items to reduct_items_to provided in the kwargs\n",
" \"\"\"\n",
"\n",
" # here are adjusted n items to keep\n",
" num_items = kwargs.get('reduce_items_to')\n",
" df = args_mut.df\n",
" \n",
" to_remove = df['movieId'].value_counts().sort_values()[:-num_items].index\n",
" to_keep = df['movieId'].value_counts().sort_values()[-num_items:].index\n",
" to_remove_indices = df[df['movieId'].isin(to_remove)].index\n",
" num_removed = len(to_remove)\n",
"\n",
" df.drop(to_remove_indices, inplace=True)\n",
"\n",
" for i in list(args_mut.base.movie_embeddings_key_dict.keys()):\n",
" if i not in to_keep:\n",
" del args_mut.base.movie_embeddings_key_dict[i]\n",
"\n",
" args_mut.base.embeddings, args_mut.base.key_to_id, \\\n",
" args_mut.base.id_to_key = recnn.data.make_items_tensor(args_mut.base.movie_embeddings_key_dict)\n",
" args_mut.df = df\n",
"\n",
" print('action space is reduced to {} - {} = {}'.format(num_items + num_removed, num_removed,\n",
" num_items))\n",
"\n",
" return args_mut, kwargs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def batch_contstate_discaction(batch, item_embeddings_tensor, frame_size, num_items, *args, **kwargs):\n",
" \n",
" \"\"\"\n",
" Embed Batch: continuous state discrete action\n",
" \"\"\"\n",
" \n",
" from recnn.data.utils import get_irsu\n",
" \n",
" items_t, ratings_t, sizes_t, users_t = get_irsu(batch)\n",
" items_emb = item_embeddings_tensor[items_t.long()]\n",
" b_size = ratings_t.size(0)\n",
"\n",
" items = items_emb[:, :-1, :].view(b_size, -1)\n",
" next_items = items_emb[:, 1:, :].view(b_size, -1)\n",
" ratings = ratings_t[:, :-1]\n",
" next_ratings = ratings_t[:, 1:]\n",
"\n",
" state = torch.cat([items, ratings], 1)\n",
" next_state = torch.cat([next_items, next_ratings], 1)\n",
" action = items_t[:, -1]\n",
" reward = ratings_t[:, -1]\n",
"\n",
" done = torch.zeros(b_size)\n",
" done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1\n",
" \n",
" one_hot_action = torch.zeros(action.size(0), num_items)\n",
" one_hot_action.scatter_(1, action.view(-1,1), 1)\n",
"\n",
" batch = {'state': state, 'action': one_hot_action, 'reward': reward, 'next_state': next_state, 'done': done,\n",
" 'meta': {'users': users_t, 'sizes': sizes_t}}\n",
" return batch\n",
"\n",
"def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):\n",
" return batch_contstate_discaction(batch, item_embeddings_tensor, frame_size=frame_size, num_items=num_items)"
]
},
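{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick dimension check for the cell above: each batch row carries one more item than `frame_size`; the first `frame_size` (item, rating) pairs form the state, and the last item is the discrete action. With `frame_size = 10` and the 128-dimensional ML20M embeddings used here:\n",
"\n",
"$$\\text{state dim} = 10 \\cdot 128 + 10 = 1290$$\n",
"\n",
"which is why the actor and critic below are built with 1290 state inputs and an action space of size `num_items`."
]
},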
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[<function truncate_dataset at 0x7f4675108c10>, <function prepare_dataset at 0x7f4675108940>]\n 0%| | 0/18946308 [00:00<?, ?it/s]action space is reduced to 26744 - 21744 = 5000\nexecuted!\n100%|██████████| 18946308/18946308 [00:12<00:00, 1503667.62it/s]\n100%|██████████| 18946308/18946308 [00:14<00:00, 1315333.93it/s]\n100%|██████████| 138493/138493 [00:07<00:00, 19649.79it/s]\nexecuted!\n"
}
],
"source": [
"def build_data_pipeline(chain, kwargs: DataFuncKwargs, args_mut: DataFuncArgsMut):\n",
" \"\"\"\n",
" :param chain: array of callable\n",
" :param **kwargs: any kwargs you like\n",
" \"\"\"\n",
" print(chain)\n",
" for call in chain:\n",
" # note: returned kwargs are not utilized to guarantee immutability\n",
" args_mut, _ = call(args_mut, kwargs)\n",
" return kwargs, args_mut\n",
"\n",
"def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):\n",
" return batch_contstate_discaction(batch, item_embeddings_tensor,\n",
" frame_size=frame_size, num_items=num_items)\n",
"\n",
" \n",
"def prepare_dataset(args_mut, kwargs):\n",
" kwargs.set('reduce_items_to', num_items) # set kwargs for your functions here!\n",
" pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]\n",
" build_data_pipeline(pipeline, kwargs, args_mut)\n",
" \n",
"\n",
"# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
"env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n",
" '../../data/ml-20m/ratings.csv', frame_size, batch_size,\n",
" embed_batch=embed_batch, prepare_dataset=prepare_dataset,\n",
" num_workers=0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class DiscreteActor(nn.Module):\n",
" def __init__(self, hidden_size, num_inputs, num_actions):\n",
" super(DiscreteActor, self).__init__()\n",
"\n",
" self.linear1 = nn.Linear(num_inputs, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, num_actions)\n",
" \n",
" self.saved_log_probs = []\n",
" self.rewards = []\n",
"\n",
" def forward(self, inputs):\n",
" x = inputs\n",
" x = F.relu(self.linear1(x))\n",
" action_scores = self.linear2(x)\n",
" return F.softmax(action_scores)\n",
" \n",
" \n",
" def select_action(self, state):\n",
" probs = self.forward(state)\n",
" m = Categorical(probs)\n",
" action = m.sample()\n",
" self.saved_log_probs.append(m.log_prob(action))\n",
" return action, probs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Because I do not have a dynamic environment, I also will include a critic. If you have a real non static environment, you can do w/o citic."
]
},
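{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make that concrete: the critic is trained on the logged transitions (via `recnn.nn.value_update` in `reinforce_update` below) to estimate an action value $Q_w(s, a)$; roughly, it is regressed toward the usual one-step TD target\n",
"\n",
"$$r + \\gamma\\, Q_{\\bar{w}}(s', a')$$\n",
"\n",
"computed with the target networks. The actor's per-step \"reward\" is then simply the critic's estimate for the actions the actor proposes, averaged over the batch. With a live environment you would use real user feedback instead and could drop the critic."
]
},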
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ChooseREINFORCE():\n",
" \n",
" def __init__(self, method=None):\n",
" if method is None:\n",
" method = ChooseREINFORCE.reinforce\n",
" self.method = method\n",
" \n",
" @staticmethod\n",
" def basic_reinforce(policy, returns, *args, **kwargs):\n",
" policy_loss = []\n",
" for log_prob, R in zip(policy.saved_log_probs, returns):\n",
" policy_loss.append(-log_prob * R)\n",
" policy_loss = torch.cat(policy_loss).sum()\n",
" return policy_loss\n",
" \n",
" @staticmethod\n",
" def reinforce_with_correction():\n",
" raise NotImplemented\n",
"\n",
" def __call__(self, policy, optimizer, learn=True):\n",
" R = 0\n",
" \n",
" returns = []\n",
" for r in policy.rewards[::-1]:\n",
" R = r + 0.99 * R\n",
" returns.insert(0, R)\n",
" \n",
" returns = torch.tensor(returns)\n",
" returns = (returns - returns.mean()) / (returns.std() + 0.0001)\n",
"\n",
" policy_loss = self.method(policy, returns)\n",
" \n",
" if learn:\n",
" optimizer.zero_grad()\n",
" policy_loss.backward()\n",
" optimizer.step()\n",
" \n",
" del policy.rewards[:]\n",
" del policy.saved_log_probs[:]\n",
"\n",
" return policy_loss"
]
},
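{
"cell_type": "markdown",
"metadata": {},
"source": [
"`__call__` above turns the stored per-step rewards into discounted returns with the usual backward recursion and then standardizes them, a common variance-reduction trick:\n",
"\n",
"$$R_t = r_t + \\gamma R_{t+1}, \\qquad \\hat{R}_t = \\frac{R_t - \\bar{R}}{\\sigma_R + 10^{-4}}$$\n",
"\n",
"with $\\gamma = 0.99$ here. `basic_reinforce` then sums $-\\log \\pi_\\theta(a_t \\mid s_t)\\,\\hat{R}_t$ over the saved log-probabilities."
]
},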
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# === reinforce settings ===\n",
"\n",
"params = {\n",
" 'reinforce': ChooseREINFORCE(ChooseREINFORCE.basic_reinforce),\n",
" 'gamma' : 0.99,\n",
" 'min_value' : -10,\n",
" 'max_value' : 10,\n",
" 'policy_step': 10,\n",
" 'soft_tau' : 0.001,\n",
" \n",
" 'policy_lr' : 1e-5,\n",
" 'value_lr' : 1e-5,\n",
" 'actor_weight_init': 54e-2,\n",
" 'critic_weight_init': 6e-1,\n",
"}\n",
"\n",
"# === end ==="
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"nets = {\n",
" 'value_net': recnn.nn.Critic(1290, num_items, 2048, params['critic_weight_init']).to(cuda),\n",
" 'target_value_net': recnn.nn.Critic(1290, num_items, 2048, params['actor_weight_init']).to(cuda).eval(),\n",
" \n",
" 'policy_net': DiscreteActor(2048, 1290, num_items).to(cuda),\n",
" 'target_policy_net': DiscreteActor(2048, 1290, num_items).to(cuda).eval(),\n",
"}\n",
"\n",
"\n",
"# from good to bad: Ranger Radam Adam RMSprop\n",
"optimizer = {\n",
" 'value_optimizer': optim.Ranger(nets['value_net'].parameters(),\n",
" lr=params['value_lr'], weight_decay=1e-2),\n",
"\n",
" 'policy_optimizer': optim.Ranger(nets['policy_net'].parameters(),\n",
" lr=params['policy_lr'], weight_decay=1e-5)\n",
"}\n",
"\n",
"\n",
"loss = {\n",
" 'test': {'value': [], 'policy': [], 'step': []},\n",
" 'train': {'value': [], 'policy': [], 'step': []}\n",
" }\n",
"\n",
"debug = {}\n",
"\n",
"reinforce.writer = SummaryWriter(log_dir='../../runs/Reinforce{}/'.format(strftime(\"%H_%M\", gmtime())))\n",
"plotter = recnn.utils.Plotter(loss, [['value', 'policy']],)\n",
"device = cuda"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def reinforce_update(batch, params, nets, optimizer,\n",
" device=torch.device('cpu'),\n",
" debug=None, writer=recnn.utils.DummyWriter(),\n",
" learn=False, step=-1):\n",
" \n",
" state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n",
" \n",
" predicted_action, predicted_probs = nets['policy_net'].select_action(state)\n",
" reward = nets['value_net'](state, predicted_probs).detach()\n",
" nets['policy_net'].rewards.append(reward.mean())\n",
" \n",
" value_loss = recnn.nn.value_update(batch, params, nets, optimizer,\n",
" writer=writer,\n",
" device=device,\n",
" debug=debug, learn=learn, step=step)\n",
" \n",
" \n",
" \n",
" if step % params['policy_step'] == 0 and step > 0:\n",
" \n",
" policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'], learn=learn)\n",
" \n",
" del nets['policy_net'].rewards[:]\n",
" del nets['policy_net'].saved_log_probs[:]\n",
" \n",
" print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())\n",
" \n",
" recnn.utils.soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])\n",
" recnn.utils.soft_update(nets['policy_net'], nets['target_policy_net'], soft_tau=params['soft_tau'])\n",
"\n",
" losses = {'value': value_loss.item(),\n",
" 'policy': policy_loss.item(),\n",
" 'step': step}\n",
"\n",
" recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')\n",
"\n",
" return losses"
]
},
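{
"cell_type": "markdown",
"metadata": {},
"source": [
"`recnn.utils.soft_update` keeps the target networks as slowly moving (Polyak) averages of the online ones:\n",
"\n",
"$$\\theta_{\\text{target}} \\leftarrow (1 - \\tau)\\,\\theta_{\\text{target}} + \\tau\\,\\theta, \\qquad \\tau = 10^{-3}$$\n",
"\n",
"Also note that the policy update (and the returned loss dict) only happens every `policy_step` batches; on the other steps only the critic is updated, which is why the training loop below checks `if loss:` before logging."
]
},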
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": false
},
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "name 'env' is not defined",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-1a1791fbfb7b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m loss = reinforce_update(batch, params, nets, optimizer,\n\u001b[1;32m 5\u001b[0m \u001b[0mwriter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwriter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'env' is not defined"
]
}
],
"source": [
"step = 0\n",
"for epoch in range(n_epochs):\n",
" for batch in tqdm(env.train_dataloader):\n",
" loss = reinforce_update(batch, params, nets, optimizer,\n",
" writer=writer,\n",
" device=device,\n",
" debug=debug, learn=True, step=step)\n",
" if loss:\n",
" plotter.log_losses(loss)\n",
" step += 1\n",
" if step % plot_every == 0:\n",
" # clear_output(True)\n",
" print('step', step)\n",
" #test_loss = run_tests()\n",
" #plotter.log_losses(test_loss, test=True)\n",
" #plotter.plot_loss()\n",
" #if step > 1000:\n",
" # pass\n",
" # assert False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5-final"
}
},
"nbformat": 4,
"nbformat_minor": 2
}