awarebayes/RecNN

View on GitHub
examples/99.To be released, but working/3. Decentralized Recommendation with PySyft/1. PySyft with RecNN Environment .ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic PySyft Setup\n",
    "\n",
    "### RecNN prerequisites "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import syft as sy\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# == recnn ==\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import recnn\n",
    "\n",
    "# you can enable cuda here\n",
    "cuda = False\n",
    "if cuda:\n",
    "    cuda = torch.device('cuda')\n",
    "    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "\n",
    "# ---\n",
    "frame_size = 10\n",
    "batch_size = 25\n",
    "n_epochs   = 100\n",
    "plot_every = 30\n",
    "step       = 0\n",
    "# --- \n",
    "\n",
    "tqdm.pandas()\n",
    "\n",
    "from jupyterthemes import jtplot\n",
    "jtplot.style(theme='grade3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def overwrite_batch_tensor_embeddings(batch, item_embeddings_tensor, frame_size):\n",
    "    \n",
    "    from recnn.data.utils import get_irsu\n",
    "    items_t, ratings_t, sizes_t, users_t = get_irsu(batch)\n",
    "    \n",
    "    items_emb = item_embeddings_tensor[items_t.long()]\n",
    "    b_size = ratings_t.size(0)\n",
    "\n",
    "    items = items_emb[:, :-1, :].view(b_size, -1)\n",
    "    next_items = items_emb[:, 1:, :].view(b_size, -1)\n",
    "    ratings = ratings_t[:, :-1]\n",
    "    next_ratings = ratings_t[:, 1:]\n",
    "    \n",
    "    state = torch.cat([items, ratings], 1)\n",
    "    next_state = torch.cat([next_items, next_ratings], 1)\n",
    "    action = items_emb[:, -1, :]\n",
    "    reward = ratings_t[:, -1]\n",
    "\n",
    "    done = torch.zeros(b_size)\n",
    "    # for some reason syft dies at this line\n",
    "    # so no done in training. not a big deal\n",
    "    # done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1\n",
    "    \n",
    "    batch = {'state': share(state), 'action': share(action),\n",
    "             'reward': share(reward), 'next_state': share(next_state),\n",
    "             'done': share(done),\n",
    "             'meta': {'users': users_t, 'sizes': sizes_t}}\n",
    "    return batch\n",
    "\n",
    "# overwrite batch generation function\n",
    "recnn.data.utils.batch_tensor_embeddings = overwrite_batch_tensor_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4158c3975d294f21af6dbd8e2e7cfff2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d12e989ce1c045a1ae8d85236248d668",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4babedc28a04872ad9e2cb65bcee89c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
    "env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n",
    "                         '../../data/ml-20m/ratings.csv', frame_size, batch_size, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test function\n",
    "def run_tests():\n",
    "    batch = next(iter(env.test_dataloader))\n",
    "    loss = ddpg.update(batch, learn=False)\n",
    "    return loss\n",
    "\n",
    "\n",
    "\n",
    "value_net  = recnn.nn.Critic(1290, 128, 256, 54e-2)\n",
    "policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1)\n",
    "\n",
    "if cuda:\n",
    "    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "    value_net  = recnn.nn.Critic(1290, 128, 256, 54e-2).to(cuda)\n",
    "    policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1).to(cuda)\n",
    "    torch.set_default_tensor_type('torch.FloatTensor')\n",
    "\n",
    "ddpg = recnn.nn.DDPG(policy_net, value_net)\n",
    "ddpg.writer = SummaryWriter(log_dir='../../runs')\n",
    "plotter = recnn.utils.Plotter(ddpg.loss_layout, [['value', 'policy']],)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Syft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch) \n",
    "\n",
    "alice = sy.VirtualWorker(id=\"alice\", hook=hook)\n",
    "bob = sy.VirtualWorker(id=\"bob\", hook=hook)\n",
    "james = sy.VirtualWorker(id=\"james\", hook=hook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def share_no_cuda(m):\n",
    "    m = m.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
    "    return m\n",
    "\n",
    "def share_cuda(m):\n",
    "    m = m.fix_precision()\n",
    "    m = m.to(cuda).share(bob, alice, crypto_provider=james, requires_grad=True)\n",
    "    return m\n",
    "\n",
    "def share(m):\n",
    "    if cuda:\n",
    "        return share_cuda(m)\n",
    "    else:\n",
    "        return share_no_cuda(m)\n",
    "\n",
    "if not cuda:\n",
    "    ddpg.nets = dict([(k, share(v)) for k, v in ddpg.nets.items()])\n",
    "share_optim = lambda o: o.fix_precision()\n",
    "ddpg.optimizers =  dict([(k, share_optim(v)) for k, v in ddpg.optimizers.items()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warning: it may freeze your system here.\n",
    "\n",
    "On my PC it just freezes it. But google colab seem to withstand the BIG FREEZE. Just wait for 5 minutes and restart, click on stop button and restart the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59d02a596bd84002ac590e0cb28007cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5263), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ERROR:root:Internal Python error in the inspect module.\n",
      "Below is the traceback from this internal error.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 3325, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"<ipython-input-10-21692f94cb16>\", line 19, in <module>\n",
      "    learn()\n",
      "  File \"<ipython-input-10-21692f94cb16>\", line 7, in learn\n",
      "    loss = ddpg.update(batch, learn=True)\n",
      "  File \"../../recnn/nn/algo.py\", line 67, in update\n",
      "    self.device, self.debug, self.writer, step=self._step, learn=learn)\n",
      "  File \"../../recnn/nn/update.py\", line 60, in ddpg_update\n",
      "    next_action = nets['target_policy_net'](next_state)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 541, in __call__\n",
      "    result = self.forward(*input, **kwargs)\n",
      "  File \"../../recnn/nn/models.py\", line 65, in forward\n",
      "    action = F.relu(self.linear1(state))\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 449, in overloaded_func\n",
      "    response = handle_func_command(command)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py\", line 298, in handle_func_command\n",
      "    response = new_type.handle_func_command(new_command)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 237, in handle_func_command\n",
      "    return cmd(*args, **kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 207, in relu\n",
      "    return tensor.relu()\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 141, in method_with_grad\n",
      "    result = getattr(new_self, name)(*new_args, **new_kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 341, in overloaded_syft_method\n",
      "    response = getattr(new_self, attr)(*new_args, **new_kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 805, in relu\n",
      "    return securenn.relu(self)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 434, in relu\n",
      "    return a_sh * relu_deriv(a_sh) + u\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 399, in relu_deriv\n",
      "    y_sh = share_convert(y_sh)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 354, in share_convert\n",
      "    eta_p = private_compare(x_bit_sh, r - 1, eta_pp)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 172, in private_compare\n",
      "    c_beta0 = -x_bit_sh + (j * r_bit) + j + wc\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 362, in __add__\n",
      "    return self.add(other, **kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/overload.py\", line 28, in _hook_method_args\n",
      "    response = attr(self, new_self, *new_args, **new_kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 355, in add\n",
      "    new_shares[k] = (other[k] + v) % self.field\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 484, in overloaded_pointer_method\n",
      "    response = owner.send_command(location, command)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 489, in send_command\n",
      "    ret_val = self.send_msg(Operation(message, return_ids), location=recipient)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 258, in send_msg\n",
      "    bin_response = self._send_msg(bin_message, location)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 7, in _send_msg\n",
      "    return location._recv_msg(message)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 10, in _recv_msg\n",
      "    return self.recv_msg(message)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 292, in recv_msg\n",
      "    response = self._message_router[msg_type](contents)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 412, in execute_command\n",
      "    response = getattr(_self, command_name)(*args, **kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 381, in overloaded_native_method\n",
      "    raise route_method_exception(e, self, args, kwargs)\n",
      "  File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 378, in overloaded_native_method\n",
      "    response = method(*args, **kwargs)\n",
      "KeyboardInterrupt\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 2039, in showtraceback\n",
      "    stb = value._render_traceback_()\n",
      "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n",
      "    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 319, in wrapped\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 353, in _fixed_getinnerframes\n",
      "    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1502, in getinnerframes\n",
      "    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1460, in getframeinfo\n",
      "    filename = getsourcefile(frame) or getfile(frame)\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 693, in getsourcefile\n",
      "    if os.path.exists(filename):\n",
      "  File \"/home/dev/anaconda3/lib/python3.7/genericpath.py\", line 19, in exists\n",
      "    os.stat(path)\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m"
     ]
    }
   ],
   "source": [
    "plot_every = 50\n",
    "n_epochs = 2\n",
    "\n",
    "def learn():\n",
    "    for epoch in range(n_epochs):\n",
    "        for batch in tqdm(env.train_dataloader):\n",
    "            loss = ddpg.update(batch, learn=True)\n",
    "            plotter.log_losses(loss)\n",
    "            ddpg.step()\n",
    "            if ddpg._step % plot_every == 0:\n",
    "                clear_output(True)\n",
    "                print('step', ddpg._step)\n",
    "                test_loss = run_tests()\n",
    "                plotter.log_losses(test_loss, test=True)\n",
    "                plotter.plot_loss()\n",
    "            if ddpg._step > 1000:\n",
    "                return\n",
    "            \n",
    "learn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}