examples/99.To be released, but working/3. Decentralized Recommendation with PySyft/1. PySyft with RecNN Environment .ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic PySyft Setup\n",
"\n",
"### RecNN prerequisites "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import syft as sy\n",
"from tqdm.auto import tqdm\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"\n",
"# == recnn ==\n",
"import sys\n",
"sys.path.append(\"../../\")\n",
"import recnn\n",
"\n",
"# you can enable cuda here\n",
"cuda = False\n",
"if cuda:\n",
" cuda = torch.device('cuda')\n",
" torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
"\n",
"# ---\n",
"frame_size = 10\n",
"batch_size = 25\n",
"n_epochs = 100\n",
"plot_every = 30\n",
"step = 0\n",
"# --- \n",
"\n",
"tqdm.pandas()\n",
"\n",
"from jupyterthemes import jtplot\n",
"jtplot.style(theme='grade3')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def overwrite_batch_tensor_embeddings(batch, item_embeddings_tensor, frame_size):\n",
"    \"\"\"Turn a raw batch into a transition batch whose tensors are secret-shared.\n",
"\n",
"    Replacement for ``recnn.data.utils.batch_tensor_embeddings`` (installed right\n",
"    after this definition). Item ids are looked up in ``item_embeddings_tensor``;\n",
"    all frames but the last form ``state``, all frames but the first form\n",
"    ``next_state``, and the last frame supplies ``action`` and ``reward``.\n",
"    Each of those tensors, plus ``done``, is passed through the notebook-level\n",
"    ``share`` helper, so ``share`` must exist before the first batch is built.\n",
"\n",
"    ``frame_size`` is not read by the active code; it is only referenced by the\n",
"    disabled ``done`` computation kept as a note inside the body.\n",
"    \"\"\"\n",
"    from recnn.data.utils import get_irsu\n",
"\n",
"    items_t, ratings_t, sizes_t, users_t = get_irsu(batch)\n",
"    embedded = item_embeddings_tensor[items_t.long()]\n",
"    n_rows = ratings_t.size(0)\n",
"\n",
"    # Flatten the per-frame embeddings and append the matching ratings.\n",
"    state = torch.cat([embedded[:, :-1, :].view(n_rows, -1), ratings_t[:, :-1]], 1)\n",
"    next_state = torch.cat([embedded[:, 1:, :].view(n_rows, -1), ratings_t[:, 1:]], 1)\n",
"\n",
"    # `done` stays all-zero: syft died on the line below, so episode ends are\n",
"    # not marked during training (not a big deal).\n",
"    # done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1\n",
"    done = torch.zeros(n_rows)\n",
"\n",
"    plain = {\n",
"        'state': state,\n",
"        'action': embedded[:, -1, :],\n",
"        'reward': ratings_t[:, -1],\n",
"        'next_state': next_state,\n",
"        'done': done,\n",
"    }\n",
"    shared_batch = {key: share(tensor) for key, tensor in plain.items()}\n",
"    shared_batch['meta'] = {'users': users_t, 'sizes': sizes_t}\n",
"    return shared_batch\n",
"\n",
"# Monkey-patch recnn so batches are built by the sharing-aware function above.\n",
"recnn.data.utils.batch_tensor_embeddings = overwrite_batch_tensor_embeddings"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4158c3975d294f21af6dbd8e2e7cfff2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d12e989ce1c045a1ae8d85236248d668",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4babedc28a04872ad9e2cb65bcee89c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
"env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n",
" '../../data/ml-20m/ratings.csv', frame_size, batch_size, num_workers=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# test function\n",
"def run_tests():\n",
"    \"\"\"Run one non-learning ``ddpg.update`` on the first test batch and return its result.\"\"\"\n",
"    test_batch = next(iter(env.test_dataloader))\n",
"    return ddpg.update(test_batch, learn=False)\n",
"\n",
"\n",
"\n",
"value_net = recnn.nn.Critic(1290, 128, 256, 54e-2)\n",
"policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1)\n",
"\n",
"if cuda:\n",
" torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
" value_net = recnn.nn.Critic(1290, 128, 256, 54e-2).to(cuda)\n",
" policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1).to(cuda)\n",
" torch.set_default_tensor_type('torch.FloatTensor')\n",
"\n",
"ddpg = recnn.nn.DDPG(policy_net, value_net)\n",
"ddpg.writer = SummaryWriter(log_dir='../../runs')\n",
"plotter = recnn.utils.Plotter(ddpg.loss_layout, [['value', 'policy']],)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Syft"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"hook = sy.TorchHook(torch) \n",
"\n",
"alice = sy.VirtualWorker(id=\"alice\", hook=hook)\n",
"bob = sy.VirtualWorker(id=\"bob\", hook=hook)\n",
"james = sy.VirtualWorker(id=\"james\", hook=hook)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def share_no_cuda(m):\n",
"    \"\"\"Fixed-precision encode ``m`` and share it between bob and alice (james is the crypto provider).\"\"\"\n",
"    return m.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
"\n",
"def share_cuda(m):\n",
"    \"\"\"Same as ``share_no_cuda``, but move the fixed-precision value to the ``cuda`` device before sharing.\"\"\"\n",
"    fixed = m.fix_precision()\n",
"    on_device = fixed.to(cuda)\n",
"    return on_device.share(bob, alice, crypto_provider=james, requires_grad=True)\n",
"\n",
"def share(m):\n",
"    \"\"\"Share ``m`` with the helper matching the notebook's ``cuda`` setting.\"\"\"\n",
"    sharer = share_cuda if cuda else share_no_cuda\n",
"    return sharer(m)\n",
"\n",
"# The networks are only secret-shared when running on CPU.\n",
"if not cuda:\n",
"    ddpg.nets = {name: share(net) for name, net in ddpg.nets.items()}\n",
"\n",
"# Optimizers are switched to fixed precision but are not shared between workers.\n",
"share_optim = lambda o: o.fix_precision()\n",
"ddpg.optimizers = {name: share_optim(opt) for name, opt in ddpg.optimizers.items()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Warning: it may freeze your system here.\n",
"\n",
"On my PC, running the next cell simply freezes the system, but Google Colab seems to withstand the BIG FREEZE. Just wait for about 5 minutes, then click the stop button and restart the cell."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59d02a596bd84002ac590e0cb28007cc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=5263), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:root:Internal Python error in the inspect module.\n",
"Below is the traceback from this internal error.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 3325, in run_code\n",
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \"<ipython-input-10-21692f94cb16>\", line 19, in <module>\n",
" learn()\n",
" File \"<ipython-input-10-21692f94cb16>\", line 7, in learn\n",
" loss = ddpg.update(batch, learn=True)\n",
" File \"../../recnn/nn/algo.py\", line 67, in update\n",
" self.device, self.debug, self.writer, step=self._step, learn=learn)\n",
" File \"../../recnn/nn/update.py\", line 60, in ddpg_update\n",
" next_action = nets['target_policy_net'](next_state)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 541, in __call__\n",
" result = self.forward(*input, **kwargs)\n",
" File \"../../recnn/nn/models.py\", line 65, in forward\n",
" action = F.relu(self.linear1(state))\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 449, in overloaded_func\n",
" response = handle_func_command(command)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py\", line 298, in handle_func_command\n",
" response = new_type.handle_func_command(new_command)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 237, in handle_func_command\n",
" return cmd(*args, **kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 207, in relu\n",
" return tensor.relu()\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 141, in method_with_grad\n",
" result = getattr(new_self, name)(*new_args, **new_kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 341, in overloaded_syft_method\n",
" response = getattr(new_self, attr)(*new_args, **new_kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 805, in relu\n",
" return securenn.relu(self)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 434, in relu\n",
" return a_sh * relu_deriv(a_sh) + u\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 399, in relu_deriv\n",
" y_sh = share_convert(y_sh)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 354, in share_convert\n",
" eta_p = private_compare(x_bit_sh, r - 1, eta_pp)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 172, in private_compare\n",
" c_beta0 = -x_bit_sh + (j * r_bit) + j + wc\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 362, in __add__\n",
" return self.add(other, **kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/overload.py\", line 28, in _hook_method_args\n",
" response = attr(self, new_self, *new_args, **new_kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 355, in add\n",
" new_shares[k] = (other[k] + v) % self.field\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 484, in overloaded_pointer_method\n",
" response = owner.send_command(location, command)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 489, in send_command\n",
" ret_val = self.send_msg(Operation(message, return_ids), location=recipient)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 258, in send_msg\n",
" bin_response = self._send_msg(bin_message, location)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 7, in _send_msg\n",
" return location._recv_msg(message)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 10, in _recv_msg\n",
" return self.recv_msg(message)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 292, in recv_msg\n",
" response = self._message_router[msg_type](contents)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 412, in execute_command\n",
" response = getattr(_self, command_name)(*args, **kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 381, in overloaded_native_method\n",
" raise route_method_exception(e, self, args, kwargs)\n",
" File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 378, in overloaded_native_method\n",
" response = method(*args, **kwargs)\n",
"KeyboardInterrupt\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 2039, in showtraceback\n",
" stb = value._render_traceback_()\n",
"AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n",
" return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
" File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 319, in wrapped\n",
" return f(*args, **kwargs)\n",
" File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 353, in _fixed_getinnerframes\n",
" records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
" File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1502, in getinnerframes\n",
" frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
" File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1460, in getframeinfo\n",
" filename = getsourcefile(frame) or getfile(frame)\n",
" File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 693, in getsourcefile\n",
" if os.path.exists(filename):\n",
" File \"/home/dev/anaconda3/lib/python3.7/genericpath.py\", line 19, in exists\n",
" os.stat(path)\n",
"KeyboardInterrupt\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m"
]
}
],
"source": [
"plot_every = 50\n",
"n_epochs = 2\n",
"\n",
"def learn():\n",
" for epoch in range(n_epochs):\n",
" for batch in tqdm(env.train_dataloader):\n",
" loss = ddpg.update(batch, learn=True)\n",
" plotter.log_losses(loss)\n",
" ddpg.step()\n",
" if ddpg._step % plot_every == 0:\n",
" clear_output(True)\n",
" print('step', ddpg._step)\n",
" test_loss = run_tests()\n",
" plotter.log_losses(test_loss, test=True)\n",
" plotter.plot_loss()\n",
" if ddpg._step > 1000:\n",
" return\n",
" \n",
"learn()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}