ixxi-dante/nw2vec

View on GitHub
projects/correctness/gae-reproduction-citeseer-nw2vec-multitask_arch-shallow_adj-ov=0.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train nw2vec with the original VGAE architecture for Citeseer embeddings + a fixed (identity) kernel for the scalar-product adjacency decoding + an intermediate layer in the feature decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Train on CPU (hide GPU) due to memory constraints\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"\"\n",
    "\n",
    "import contextlib\n",
    "\n",
    "import numpy as np\n",
    "import keras\n",
    "from keras_tqdm import TQDMNotebookCallback as TQDMCallback\n",
    "\n",
    "from nw2vec import layers\n",
    "from nw2vec import ae\n",
    "from nw2vec import utils\n",
    "from nw2vec import batching\n",
    "import settings\n",
    "\n",
    "from gae.input_data import load_data\n",
    "from gae.preprocessing import sparse_to_tuple, mask_test_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@contextlib.contextmanager\n",
    "def gae_directory():\n",
    "    \"\"\"Run the enclosed block with `../../gae` as the current working directory.\n",
    "\n",
    "    Used around `load_data`, which presumably resolves its dataset paths relative\n",
    "    to the working directory. The previous directory is restored on exit, even if\n",
    "    the block raises.\n",
    "    \"\"\"\n",
    "    original_cwd = os.path.abspath(os.curdir)\n",
    "    try:\n",
    "        os.chdir('../../gae')\n",
    "        yield\n",
    "    finally:\n",
    "        os.chdir(original_cwd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Citeseer adjacency matrix and node features (converted to a dense array).\n",
    "with gae_directory():\n",
    "    adj, features = load_data('citeseer')\n",
    "    features = features.toarray()\n",
    "\n",
    "# Hold out validation and test edges (plus the `*_false` non-edge sets) for link\n",
    "# prediction; only `adj_train` is used for training below.\n",
    "# NOTE(review): if `mask_test_edges` samples at random, seed numpy before this call\n",
    "# to make the split (and the scores computed later) reproducible.\n",
    "adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)\n",
    "# The training adjacency must not contain self-loops.\n",
    "assert adj_train.diagonal().sum() == 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_p_builder(dims, feature_codec='SigmoidBernoulli', adj_kernel=None, use_bias=False,\n",
    "                    embedding_slices=None, with_l1=True):\n",
    "    \"\"\"Create the decoder builder (`p_builder`) that is passed to `ae.build_vae`.\n",
    "\n",
    "    Args:\n",
    "        dims: 4-tuple whose first two entries are the feature dimension `dim_data`\n",
    "            and the intermediate layer width `dim_l1`; the last two are ignored here.\n",
    "        feature_codec: 'SigmoidBernoulli', 'OrthogonalGaussian' or 'SoftmaxMultinomial'.\n",
    "        adj_kernel: optional fixed kernel for the `Bilinear` adjacency decoder; when\n",
    "            omitted the kernel is learned with l2 regularization.\n",
    "        use_bias: whether the decoder layers use biases.\n",
    "        embedding_slices: `(adj_slice, v_slice)` choosing which embedding dimensions\n",
    "            feed the adjacency and feature decoders; `None` feeds both the full input.\n",
    "        with_l1: add a `dim_l1`-unit ReLU layer before the feature decoder.\n",
    "\n",
    "    Returns:\n",
    "        `p_builder(p_input) -> ([p_adj, p_v], (adj_codec, feature_codec))`.\n",
    "    \"\"\"\n",
    "    assert feature_codec in ['SigmoidBernoulli', 'OrthogonalGaussian', 'SoftmaxMultinomial']\n",
    "    adj_embedding_slice, v_embedding_slice = (\n",
    "        [slice(None), slice(None)] if embedding_slices is None else embedding_slices\n",
    "    )\n",
    "    dim_data, dim_l1, _, _ = dims\n",
    "\n",
    "    def l2_dense(units, name, activation=None):\n",
    "        # All dense decoder layers share the bias setting and l2 regularization.\n",
    "        return keras.layers.Dense(units, use_bias=use_bias, activation=activation,\n",
    "                                  kernel_regularizer='l2', bias_regularizer='l2',\n",
    "                                  name=name)\n",
    "\n",
    "    def p_builder(p_input):\n",
    "        # Each decoder only sees its own slice of the embedding.\n",
    "        p_input_adj = layers.InnerSlice(adj_embedding_slice)(p_input)\n",
    "        p_input_v = layers.InnerSlice(v_embedding_slice)(p_input)\n",
    "\n",
    "        # Optional intermediate layer, on the feature path only.\n",
    "        p_penultimate_v = (l2_dense(dim_l1, 'p_layer1_v', activation='relu')(p_input_v)\n",
    "                           if with_l1 else p_input_v)\n",
    "\n",
    "        # Adjacency decoder: the fixed kernel if one was given, else a learned one.\n",
    "        if adj_kernel is None:\n",
    "            kernel_kwargs = {'kernel_regularizer': 'l2'}\n",
    "        else:\n",
    "            kernel_kwargs = {'fixed_kernel': adj_kernel}\n",
    "        p_adj = layers.Bilinear(0, use_bias=use_bias, name='p_adj', bias_regularizer='l2',\n",
    "                                **kernel_kwargs)([p_input_adj, p_input_adj])\n",
    "\n",
    "        # Feature decoder, shaped by the requested codec.\n",
    "        if feature_codec in ('SigmoidBernoulli', 'SoftmaxMultinomial'):\n",
    "            p_v = l2_dense(dim_data, 'p_v')(p_penultimate_v)\n",
    "        else:\n",
    "            assert feature_codec == 'OrthogonalGaussian'\n",
    "            p_v_μ_flat = l2_dense(dim_data, 'p_v_mu_flat')(p_penultimate_v)\n",
    "            p_v_logS_flat = l2_dense(dim_data, 'p_v_logS_flat')(p_penultimate_v)\n",
    "            p_v = keras.layers.Concatenate(name='p_v_mulogS_flat')([p_v_μ_flat, p_v_logS_flat])\n",
    "\n",
    "        adj_codec = 'SigmoidBernoulliScaledAdjacency'\n",
    "        return [p_adj, p_v], (adj_codec, feature_codec)\n",
    "\n",
    "    return p_builder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_ξ_samples = 1\n",
    "n_nodes = adj_train.shape[0]\n",
    "dim_data, dim_l1, dim_ξ_adj, dim_ξ_v = features.shape[1], 32, 16, 16\n",
    "# Number of embedding dimensions shared by the adjacency and feature decoders.\n",
    "overlap = 0\n",
    "dims = (dim_data, dim_l1, dim_ξ_adj, dim_ξ_v)\n",
    "# Decoder weights divide by ln(2) times the width of a reconstructed row (n_nodes for\n",
    "# the adjacency, dim_data for the features); the encoder-output weight is 1e-2 over\n",
    "# the dim_ξ_adj - overlap + dim_ξ_v embedding dimensions in total.\n",
    "loss_weights = {\n",
    "    'q_mulogS_flat': 1e-2 * 1.0 / (dim_ξ_adj - overlap + dim_ξ_v),\n",
    "    'p_adj': 1.0 / (n_nodes * np.log(2)),\n",
    "    'p_v': 1.0 / (dim_data * np.log(2)),\n",
    "}\n",
    "\n",
    "# Actual VAE\n",
    "q_model, q_codecs = ae.build_q(dims, overlap=overlap,\n",
    "                               fullbatcher=batching.fullbatches, minibatcher=batching.pq_batches)\n",
    "p_builder = build_p_builder(dims,\n",
    "                            feature_codec='SigmoidBernoulli',\n",
    "                            # Fixed identity kernel over the dim_ξ_adj adjacency dimensions; sized\n",
    "                            # from dim_ξ_adj (not a literal) so it always matches the slice below.\n",
    "                            adj_kernel=np.eye(dim_ξ_adj),\n",
    "                            # Adjacency decoder: first dim_ξ_adj dims; feature decoder: the next\n",
    "                            # dim_ξ_v dims, shifted back by `overlap`.\n",
    "                            embedding_slices=[slice(dim_ξ_adj),\n",
    "                                              slice(dim_ξ_adj - overlap, dim_ξ_adj - overlap + dim_ξ_v)],\n",
    "                            with_l1=True)\n",
    "vae, vae_codecs = ae.build_vae(\n",
    "    (q_model, q_codecs), p_builder,\n",
    "    n_ξ_samples,\n",
    "    loss_weights=loss_weights\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_func(batch_adj, required_nodes, final_nodes):\n",
    "    \"\"\"Return the training targets for one batch.\n",
    "\n",
    "    Presumably ordered like the VAE outputs (q, p_adj, p_v): a placeholder that is\n",
    "    ignored, `batch_adj` with self-loops added, and the features of `final_nodes`,\n",
    "    both expanded with `utils.expand_dims_tile` (presumably one copy per ξ sample).\n",
    "    `required_nodes` is not used.\n",
    "    \"\"\"\n",
    "    adj_with_loops = batch_adj + np.eye(batch_adj.shape[0])\n",
    "    adj_target = utils.expand_dims_tile(\n",
    "        utils.expand_dims_tile(adj_with_loops, 0, n_ξ_samples), 0, 1)\n",
    "    features_target = utils.expand_dims_tile(features[final_nodes], 1, n_ξ_samples)\n",
    "    placeholder = np.zeros(1)  # ignored\n",
    "    return [placeholder, adj_target, features_target]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sl/.conda/envs/base36/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "195cefbddee44765a3f5165e48f56119",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Training', max=200), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 200\n",
    "\n",
    "# Full-batch training on `adj_train` only; progress is reported by a single\n",
    "# TQDM widget over the epochs (inner bars disabled).\n",
    "history = vae.fit_fullbatches(\n",
    "    batcher_kws={'adj': adj_train, 'features': features, 'target_func': target_func},\n",
    "    epochs=n_epochs,\n",
    "    verbose=0,\n",
    "    callbacks=[TQDMCallback(show_inner=False)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the per-epoch metrics, wrapped as {'history': ...} for the plots below.\n",
    "# The isinstance guard makes the cell safe to re-run: the first run replaces the\n",
    "# Keras History object with a dict, which has no `.history` attribute.\n",
    "if not isinstance(history, dict):\n",
    "    history = {'history': history.history}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1440x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One panel per recorded training metric, plotted against the epoch index.\n",
    "metrics = history['history']\n",
    "# squeeze=False keeps `axes` 2-D, so this also works when only one metric is recorded.\n",
    "fig, axes = plt.subplots(1, len(metrics), figsize=(len(metrics) * 5, 4), squeeze=False)\n",
    "for ax, (title, values) in zip(axes[0], metrics.items()):\n",
    "    ax.plot(np.array(values))\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('epoch')\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import average_precision_score\n",
    "from keras import backend as K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_roc_score(edges_pos, edges_neg, adj_pred):\n",
    "    \"\"\"Score predicted edge scores on held-out positive and negative node pairs.\n",
    "\n",
    "    Args:\n",
    "        edges_pos: sequence of (i, j) node pairs that are true edges.\n",
    "        edges_neg: sequence of (i, j) node pairs that are non-edges.\n",
    "        adj_pred: 2-D array of predicted scores, indexed as adj_pred[i, j].\n",
    "\n",
    "    Returns:\n",
    "        (roc_auc, average_precision) for separating positives from negatives.\n",
    "    \"\"\"\n",
    "    preds_pos = [adj_pred[e[0], e[1]] for e in edges_pos]\n",
    "    preds_neg = [adj_pred[e[0], e[1]] for e in edges_neg]\n",
    "\n",
    "    preds_all = np.hstack([preds_pos, preds_neg])\n",
    "    # Label each group by its own size so the labels stay aligned with preds_all\n",
    "    # even when the positive and negative sets differ in length.\n",
    "    labels_all = np.hstack([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])\n",
    "    roc_score = roc_auc_score(labels_all, preds_all)\n",
    "    ap_score = average_precision_score(labels_all, preds_all)\n",
    "\n",
    "    return roc_score, ap_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the full-batch prediction 10 times and average, presumably to smooth out\n",
    "# sampling noise in the embeddings.\n",
    "q_preds, adj_preds, v_preds = zip(*(vae.predict_fullbatch(adj=adj_train, features=features)\n",
    "                                    for _ in range(10)))\n",
    "q_preds, adj_preds = np.array(q_preds), np.array(adj_preds)\n",
    "\n",
    "# Mean encoder output over the repeated runs.\n",
    "q_pred = q_preds.mean(axis=0)\n",
    "# Map adjacency logits to probabilities, then average over axis 1 (presumably the\n",
    "# ξ samples) and over the repeated runs.\n",
    "adj_pred = scipy.special.expit(adj_preds[:, 0, :, :, :]).mean(axis=1).mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9262214708368555, 0.9375222651716126)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test-set link prediction: (ROC AUC, average precision).\n",
    "get_roc_score(edges_pos=test_edges, edges_neg=test_edges_false, adj_pred=adj_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANQAAADuCAYAAABBPynTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAFrdJREFUeJzt3Xm0nVV5x/HvL2FyaYVg0CIiQqVWLBrUImu5Wi0q4FBAnIjLChZFW622DktwQhEU66qo1apREUQLKA6lLa4YBKuWQYMiICwM4BSCAyZxAgK59+kfe5/kzeHee/bJ2Xff3MPvw3pXznnPO+xzuc9997uH51VEYGZ1LJjrApiNEweUWUUOKLOKHFBmFTmgzCpyQJlV5IAyq8gBZVaRA8qsou3mugB273LoX983fr12omjbK6/esDwiDpvlIlXlgLKmbls7wRXLH1K07fa737R4lotTnQPKGgsmYnKuCzFrHFDWVACTjO+AbDdKWHOThf8NIukMSb+UdO00n0vSByXdKOlqSY/tfHaMpFV5OabWd3NAWVNBcHdMFi0FzgRmarR4OrBvXo4HPgIgaVfgJOAJwIHASZIWjfC1NnFAWVMBTBBFy8BjRXwDWDvDJkcAn47kcmAXSbsDhwIrImJtRKwDVjBzYBbzPZQ11/Aeag/gZ533q/O66daPzAFlTQUwUT5LfLGklZ33yyJi2RCn0zRFmG79yBxQ1twQjea3RcTjRzjVamDPzvuHAGvy+if3rf/6COfZxPdQ1lQU3j+V3EMVuAB4cW7tOwj4TUTcCiwHDpG0KDdGHJLXjcxXKGsqAu6udAsl6RzSlWaxpNWklrvt03nio8CFwDOAG4HbgZfkz9ZKeifwnXyokyNipsaNYg4oa0xMTHkLM7yIWDrg8wBeOc1nZwBnVClIhwPKmgpgcnwHSjigrL1aV6htkQPKmkoduw4osyoCuDvGt3HZAWVNBWJijHtrHFDW3GS4ymdWhe+hzKoSE76HMqsjzdh1QJlVESHuioVzXYxZ44Cy5iZ9D2VWR2qUcJXPrBI3SphV40YJs8om3LFrVkcg7o7x/bUb329m26Rxb5QY329m26RATETZUkLSYZJuyNlhT5ji89MlXZWXH0pa3/lsovPZBTW+n69Q1lytRglJC4EPA08jZTL6jqQLIuK63jYR8c+d7f8ROKBziDsiYkmVwmS+QllTETARC4qWAgcCN0bEzRFxF3AuKVvsdJYC51T4GtNyQFlTqVFiYdFCTnTZWY7vO1xxBlhJewF7Axd3Vu+Uj3u5pCNrfD9X+ay5IRolBiW6HCYD7NHA+RHRfXziQyNijaR9gIslXRMRN5UWbiq+QllTgZiMsqXAdJlhp3I0fdW9iFiT/72ZlDn2gHvuNhwHlDU3wYKipcB3gH0l7S1pB1LQ3KO1TtIjgEXAZZ11iyTtmF8vBp4IXNe/77Bc5bOmUl6+On/HI2KjpFeR0igvBM6IiB9IOhlYGRG94FoKnJsTX/Y8EviYpEnSheW0buvg1nJAWWP1MscCRMSFpJTL3XVv63v/9in2uxTYv1pBMgeUNZXSiHmCoVkVEapW5dsWOaCsOc+HMqskzYfy9A2zSjxj16ya1GzuK5RZFb2xfOPKAWXNOaeEWSVp+oarfGbV+B7KrJI02txVPrMq/ARDs6p8hTKryiMlzCpxK59ZZa7ymVXSyykxrsb3T4VtkwLYGAuKlhIFmWOPlfSrTobYl3Y+O0bSqrwcU+P7+QplzdWq8pVkjs3Oi4hX9e27K3AS8HhSnF+Z9103Spl8hbK2ClOIFVYLh80c23UosCIi1uYgWgEctlXfqcMBZU31JhiWLNTLHPscSVdLOl9SL49fcdbZYbjKZ80N0ShRI3PsfwHnRMQGSa8AzgIOLtx3aL5CWVO9CYatMsdGxK8jYkN++3HgcaX7bg0HlDUViI2TC4qWAgMzx0ravfP2
cOD6/Ho5cEjOILsIOCSvG4mrfNZcraFHhZljXy3pcGAjsBY4Nu+7VtI7SUEJcHJErB21TNoyO63Z7Nr5EQ+Kg5YtLdr2q0/+wJUD7qG2Ob5CWVNO0mJWmQPKrJJATJQ1OMxLDihrzvOhzCqJcJXPrKpwQJnVMt7zoRxQ1pyvUGaVRMDEpAPKrBq38plVErjKZ1aRGyXMqhrn8dgOKGvOVT6zSlIrn8fymVXjKp9ZReNc5Rvfa69tkwIRUbaUKMgc+1pJ1+U0Yl+TtFfns4lORtkL+vfdGr5CWXO1anyFmWO/Bzw+Im6X9PfAvwAvyJ/dERFLKhUH8BXKWguISRUtBQZmjo2ISyLi9vz2clK6sFnjgLLmhqjy1coc23Mc8JXO+53ycS+XdGSN7+YqnzU3RCtfjcyxaUPpRaQHAzyps/qhEbFG0j7AxZKuiYibiks3BQeUNVV5LF9R9ldJTwXeDDypk0WWiFiT/71Z0teBA4CRAspVPmsrgFDZMlhJ5tgDgI8Bh0fELzvrF0naMb9eDDwR6H8MztB8hbLmanXsFmaOfS9wP+DzkgB+GhGHA48EPiZpknRhOW2K50oNzQFljRW34BWJiAuBC/vWva3z+qnT7HcpsH+1gmQOKGvPQ4/MKgkPPZq3JD1Z0uq5Lscg+cHK35rrcgBICkkPn9WTROEyD411QM0lSa/KnYYbJJ051+XZtqhwmX9c5atM0nYRsZHUH3IK6eHI92l87m3b5FwXYPbMmyuUpAMkfVfS7ySdJ+lcSacMeYwTJN2Uj3GdpGfn9TtKWitp/862D5R0h6Td8vtn5VHJ6yVdKunRnW1/LOmNkq4G/pB/sb8YEV8Gfr0V3/W9kr4laef8/u8kXS9pnaTlfSOmQ9IrJa0CVnXWvULSqrzPh5XbjAcdb9bV7Yfa5syLgMqddl8GzgZ2BT4PPGcrDnUT8JfAzsA7gM9I2j33np8LvKiz7VLgooj4laTHAmcALwceQOoovKDXMdjZ/pnALlt7lZC0QNLHgUcDh0TEb/IYszcBRwG7Ad8Ezunb9UjgCcB+nXXPAv4CeAzwfNKVksLjzaqIsmU+mhcBBRwEbA+8PyLujojz2fwox2IR8fmIWBMRkxFxHukv+oH547OAF0rq/Uz+lhTAAC8DPhYRV0TEREScBWzI5er5YET8LCLuGP7rAen7nUP6g/E3nRHSLwfeHRHX50B9F7Ck76ry7ohY23fu0yJifUT8FLgEWDLE8WaXGyXm3IOBW2LL55f+ZNiDSHpxp9q2HvhzYDFARFwB/AF4kqQ/Ax7O5mEsewGv6+2X990zl6unO+p5azycNPXgHXkqQs9ewAc6511LumPvjqqe6tw/77y+nTRaoPR4s2uMq3zzpVHiVmAPSeoE1UMZYiBj/gv8ceApwGURMSHpKrZsTjqLVO37OXB+RNyZ1/8MODUiTp3hFKP+Tb2eNFnuK5IOjogb+s792UrnLjnerNI8vfqUmC9XqMtIT/F+taTtJB3F5qpaqfuSfvF+BSDpJaQrVNfZwLNJQfXpzvqPA6+Q9AQl95X0TEl/NN3Jcjl3Io0xWyhpJ0kz/gGLiHNI9zcXSfqTvPqjwImSHpWPu7Ok5xV+56nUPt5wQjBZuMxD8yKgchXoKOBYYB1pCvMXhzzGdcC/koLzF6RxXP/Xt81q4LukwPtmZ/1K0n3Uh/L5b8xlmclbgDuAE0gBekdeN6icZwEnk+bnPCwivgS8BzhX0m+Ba4GnDzrODMeverytK0ThMg8p5mlzSu4sXR0RA39JhzzuGcCa2se1ZMe99ozdT3hN0bY/+Yc3XDlgguE2Z77cQzUh6WGkK+EBc1uSMTc//4YXGanKpwEpnFqQ9CZJv59i+crgvbc4zjtJ1Z/3RsSPZqe0Nu4duwOvUJL2JN2g/zFp0MiyiPiApHeQphX/ELgTOF73TOE0ayLi2M7bd1U43luBt456HBtsnFv5Sqp8G4HXRcR3c6vWlZJWkPotboqI
/QAknUjqR2kSUDaP3ZsDKiJuJfUDERG/k3Q9KZjuD6zvbLqaNPxlC0qpn44H0I47PG77xQ9MHywg9QBNpL9YsV3+KYfyivSvJvKlPyAWdo47CbEgrWdBbKoiaCJvvl3aN3rnAbQxH0Oka23eT5N5ffTW55P0XkfeJ5dr0zkWpvOlYwZMaMtKdF6n/LVQp9xsPq4m8/dbkE/RXxHPf9I1ofRjmUz7xYJ8vIWd77Ggcw6l424qY+9Yofwz2fxz636myc3l7f3yT1muTft1fkYEd/30ltsiYrdptq56hZJ0GPABUvfEJyLitL7PdyTVsB5HGlf5goj4cf7sRFJqsQng1RGxfNTyDNUokW/aDwCuIDUj75cHhK7M62b8UcWGuzbcdcvqa7eqpPUsBm5zGWa1DDMPY6p0f6SyzLHHAesi4uGSjiZ1GbxA0n6kpC6PIo14uUjSn0bExChlKg4oSfcDvgD8U0T8Njdb70zqw3gnKcC+0L9fRCwDluVjrJzrZlCXYY7LULePaVPmWABJvcyx3YA6Anh7fn0+8KE88v4I4Nw8MPpHkm7Mx7tslAIVtfJJ2p4ULJ+NiF6H6leBfUl/jc4kRXqVhOs25so7dmtkjt20TR4M/BvSjIFhs84WKWnlE/BJ4PqIeF/no92AXgqnXYFVEfGDUQtk40/lEwxrZI6dbpvirLPDKKnyPZE0leGaPJgU0nizpaQpAXcCl5KmBQyybGsKWZnLkMxdGepV+Uoyx/a2WZ3HUu5MGmFflHV2WCWtfN9i6mi+cIp1g441579ILsPclkFRtZVvU+ZY4BZSI8ML+7a5ADiGdG/0XODiiAil50H9h6T3kRol9gW+PWqBPPTI2qvUyleYOfaTwNm50WEtKejI232O1ICxEXjlqC184ICyuVCxH6ogc+ydwJTTU/L8tpnmuA2tyfSN2RzzJ+kMSb+UdG1n3a6SViglKVkhaVFeL0kfzOW4WilXRG+fY/L2qyQdM2QZ9pR0iVLikx9Iek3rcuT5Vt+W9P1chnfk9XtLuiIf7zyl/By9xDTn5TJckfsYe8c6Ma+/QdKhw/wsisoaZct8NOsB1el8ezopicjS3KlWy5nAYX3rTgC+FhH7Al/L78ll2DcvxwMfyWXcFTiJNNLjQOCk3i9/od7wrEeS8ky8Mn/HluXYABwcEY8hNRYdJukgUkfm6bkM60gdndDp8AROz9vR1+F5GPDv+f9hHZFa+UqW+ajFFWrgYxtHERHfINWNu44gTWcn/3tkZ/2nI7kc2EXS7qSMQCtyopN1wAruGaQzleHWiPhufv070nT2PVqWIx/r9/nt9nkJ4GBSh+ZUZeiV7XzgKf0dnnnUfa/Ds54xnmDYIqBmpQNtgAflMYi9sYgPHFCWamXsG57VtBySFuaujV+SgvEmYH0nrVn3eE07PLfggBrJrHSgbaVZ7eTrH57Vuhw5xdkSUp/KgaRnIE13vKYdnl2+hxrNrHSgDfCLXIUi/9t7ct10ZRm5jNMMz2peDoCIWA98nXQ/t4s2J4fpHm/TuVp0eN5btAiogY9tnAW9zjzyv//ZWf/i3Mp2EPCbXBVbDhyi9JjIRcAheV2RGYZnNSuHpN0k7ZJf3wd4Kule7hJSh+ZUZeiVbVOHZ15/dG4F3JtKHZ5bGOMq36z3Q03X+Vbr+JLOAZ5MGki5mtRKdhrwOUnHAT9lcz/EhcAzSDfatwMvyWVcqzQFvpeN9uSI6G/omMl0w7NalmN34KzcIrcA+FxE/Lek60gZjk4BvkcKfGjc4blJzN8WvBLzNuuRzU87PXjPeNjLXlu07Q0nv9ZZj8xm0pv4PK4cUNaeA8qsknncJF7CAWXtjXGjhAPKmvMVyqwmB5RZJfO407aEA8qaG+cq37x4PpSNmQZDj6ab3Nm3zRJJl+UJmVdLekHnszMl/UjpEbJXSVrSv/9UHFDWXKMJhtNN7uy6HXhxRPQmU76/Nx4ye0NELMnLVVPsfw8OKGur9Oo0erVw
usmdm4sS8cOIWJVfryHNBpg2J3sJB5Q1pSEWBmeOncl0kzunLpd0ILADWz4I/dRcFTxd6aEDA7lRwtorv/rMmDlW0kWk55b1e/Mwxclz1c4GjomIXmXzRODnpCBbBryR9OzjGTmgrLlarXwR8dRpzyH9QtLuEXFr3+TO/u3uD/wP8Jac36N37Fvzyw2SPgW8vqRMrvJZe23uoaab3LlJnvD6JVLCnM/3fdabaS3S/VfRY5gcUNZWuzRipwFPk7SK9Pyo0wAkPV7SJ/I2zwf+Cjh2iubxz0q6BriG9CytU0pO6iqftdegYzcifg08ZYr1K4GX5tefAT4zzf4Hb815HVDW3DiPlHBAWXsOKLN6fIUyqyXwBEOzWpykxaw2B5RZPRrjXJAOKGvLM3bN6vI9lFlF45zb3AFl7fkKZVaJM8eaVeaAMqvDHbtmlWlyfCPKAWVtjXk/lGfsWnMtZuyWJLrM2010Zute0Fm/t6Qr8v7n5enyAzmgrL02OSVKEl0C3NFJZnl4Z/17gNPz/uuA40pO6oCy5hRly4gGJrqctnwpMcvBwPnD7u+AsrYCiChbRlOa6HKnnETzckm9oHkAsD4iNub3q4E9Sk7qRglrboj7o8WSVnbeL4uIZZuOUyfR5UMjYo2kfYCLc6aj306xXVGEO6CsqSH7oWbMHFsj0WXOaU5E3Czp68ABwBeAXSRtl69SDwHWlBTYVT5rq7S6N3qVryTR5aJeznJJi4EnAtdFRACXAM+daf+pOKCsuUaNEiWJLh8JrJT0fVIAnRYR1+XP3gi8VtKNpHuqT5ac1FU+a2/bSXR5KbD/NPvfDBw47HkdUNacx/KZ1RLAxPhGlAPKmvMVyqwmZz0yq8dXKLNaxnz6hgPKmhIgN0qY1ePMsWa1uMpnVlOVcXrbLAeUNedWPrOafIUyqyTcymdW1/jGkwPK2nOzuVlNDiizSsb8KfCeAm9NiUBRtox0noLMsZL+upM19ipJd/ZSiUk6U9KPOp8tKTmvA8ram5wsW0YzMHNsRFzSyxpLSmx5O/DVziZv6GSVvarkpA4oa6tX5StZRjNs5tjnAl+JiNtHOakDypobosq3OGd17S3HD3Ga0syxPUcD5/StO1XS1ZJO76UbG8SNEtZe+f3RjIkuK2WOJSfC3B9Y3ll9IvBzYAdgGSmt2MmDjuWAssbqDY6tkTk2ez7wpYi4u3PsW/PLDZI+Bby+pEyu8llbvaxHJctoBmaO7VhKX3UvB2HvSRxHAteWnNQBZc21aDanLHMskh4G7An8b9/+n80PDrgGWAycUnJSV/msvQYjJUoyx+b3P2aKR9VExMFbc14HlLUVgB9abVaLZ+ya1eWAMqskgInxHR3rgLLGAsIBZVaPq3xmlbiVz6wyX6HMKnJAmVUSARMTc12KWeOAsvZ8hTKryAFlVku4lc+smoBwx65ZRR56ZFZJRI0UYdssB5S1N8aNEp4Cb83F5GTRMgpJz5P0A0mTkmbKnHSYpBsk3SjphM76vSVdkTPPnidph5LzOqCssTzBsGQZzbXAUcA3pttA0kLgw8DTgf2ApZL2yx+/Bzg9Z55dBxxXclIHlLXVGxxbsoxymojrI+KGAZsdCNwYETdHxF3AucAROdPRwcD5ebuSzLOA76GssQCifOjRYkkrO++XRcSyisXZA/hZ5/1q4AnAA4D1EbGxs/4eiVym4oCytmKoCYZbnTk2ImbKw7fpEFOVcIb1AzmgrLmoNFJipsyxhVaTcvL1PARYA9wG7CJpu3yV6q0fyAFlTf2Odcsvmvzc4sLNb5vVwsB3gH0l7Q3cQnpgwAsjIiRdQnoix7kMzjy7iWKM+wTs3kvSs4F/A3YD1gNXRcShkh4MfCIinpG3ewbwfmAhcEZEnJrX70MKpl2B7wEviogNA8/rgDKrx83mZhU5oMwqckCZVeSAMqvIAWVWkQPKrCIHlFlFDiiziv4fl8Kjr4l0Pd0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMUAAADxCAYAAAB703NLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEfhJREFUeJzt3X2QXXV9x/H3hwCiGBQMUh7FsbFthqngRGiHVnGoEpwR7Ix2iO0UW1rqVHRa+yBaqxRbS22V2pFag0VBi0hx0MhQA1oV2vGBMFoKWDTDYwwlhgSNRYXsfvrHOZucPXvv3nN3755zd/fzmjmz9zzc3/kt5Lu/x3N+sk1E7LNf1xmIGDcJioiaBEVETYIioiZBEVGToIioSVDEoiXpCknbJd3Z57wk/YOkLZLukPTCJukmKGIx+yiwbpbzZwKry+184INNEk1QxKJl+xZg5yyXnA1c5cJXgWdKOnJQuvuPKoMRTZzx0oP96M6JRtfefsdP7gJ+XDm0wfaGIW53NPBQZX9reezh2b6UoIhWPbpzgq9vOq7RtSuO/M6Pba+dx+3U49jAeU0JimiVgUkm27rdVuDYyv4xwLZBX0qbIlplzJOeaLSNwEbgN8teqF8Avm971qoTpKSIDoyqpJD0CeA0YJWkrcA7gQMAbP8TcCPwCmAL8DjwW03STVBEq4yZGNHjCrbXDzhv4A3DppugiNZNDm7rdipBEa0yMJGgiJguJUVEhYEnx/wR6ARFtMo41aeIaQwT4x0TCYpoVzGiPd4SFNEyMdFzStL4SFBEq4qGdoIiYq9inCJBETHNZEqKiH1SUkTUGDEx5k8sJCiidak+RVQY8YRXdJ2NWSUoolXF4F2qTxHTpKEdUWGLCaekiJhmMiVFxD5FQ3u8/9mNd+5iyUlDO6KHiYxTROyTEe2IHibT+xSxTzEhMEERsZcRT2aaR8Q+NmM/eDfeuRtTki6S9PHy83GSfiip558/SfdL+pV2c9gzH3vz3HFOmGy4dSUlxTzZfhB4etf5WCzM+JcUCYpFQtL+tvd0nY9RGPeG9njnbg7K6spbJd0taZekj0g6aMB3DpV0g6Tvld+5QdIxlfPPlfRlSbsl3Qysqpw7XpIlDfwDI+lnJd0n6Zxy/yhJnyrve5+kN1WuvUjSdZI+LukHwOvKY9dKuqrMy12S1la+0ze9cWHEpJttXVlyQVH6deAM4HnA84G3D7h+P+AjwHOA44AfAR+onL8auJ0iGN4FnDtshso1nG8C3mj7Gkn7AZ8F/oticcLTgT+QdEbla2cD1wHPBP6lPHYWcE15bONUPhum17niFTf7N9q6slSD4gO2H7K9E/grYNDiHo/a/pTtx23vLr/zEiga0sCLgD+3/ZNymdrPDpmfX6b4B3yu7RvKYy8CDrd9se0nbN8LXA6cU/neV2x/2vak7R+Vx/7D9o22J4CPAS8YIr0xULwMrcnWlaXapqguE/sAcNRsF0t6GnApxULlh5aHV5Y9SkcBu2z/Xy3NY2nu9cCXbX+xcuw5wFGSHqscWwHc2uf3mPK/lc+PAweVVbcm6XXOjP+I9njnbu6q/2CPY/CKmH8E/Axwiu1DgBeXx0Wx5vKhkg6upTmM1wPHSbq0cuwh4D7bz6xsK22/onLNMK8ibpLeWBj3kmKpBsUbJB0j6TDgbcAnB1y/kqId8Vj5nXdOnbD9ALAZ+AtJB0r6JeCVQ+ZnN0Up9GJJl5THvg78QNJbJD1V0gpJJ0h60ZBpTxl1egvCFpPer9HWhKR1ku6RtEXShT3OHyfpi5K+IekOSQP/SCzVoLiaolF7b7n95YDr/x54KrAD+Crwudr51wKnADspAuaqYTNk+zHgZcCZkt5VtgleCZwI3Ffe+8PAM4ZNu0x/pOktlKKhvaLRNkhZvb0MOBNYA6yXtKZ22duBa22fRNG++sdB6S7VNsVttv+66cW2t1EsPVv1
ocr5eykay73sR/F2+Z4LP9s+vvJ5J/saxlP37dkJYPuiQcds3w/76hnDpteNkT6jfTKwpfz/g6RrKHrs7q5cY+CQ8vMzaLC4/FINijadANxfLk8bAxQN7cbthVWSNlf2N9jeUNk/mumdEVspSvSqi4CbJL0ROBgYOOVm2QSFpLdRtC/qbrV95hzTfDPwp8Ab55O35WaIEe0dttfOcr5XdNX/OK0HPmr7vZJ+EfiYpBNs9107ZskFRbW6Ujv+buDdI77X+4D3jTLNpW5qRHtEtjK9p/EYZlaPzqPo5MD2V8rZDauA7f0SXaoN7Rhjk+zXaGvgNmB1OQ3nQIqG9MbaNQ9SjO4j6eeAg4DvzZZoqyXFqsNW+PhjD+h7/tt3PK3F3MRC2c2uHbYP73XOhicnR/O32PYeSRcAmygGKq+wfZeki4HNtjdSjEFdLukPKapWrxvU/ptXUEhaB7y/zNCHbV8y2/XHH3sAX9/UfyD4jKNOnE92Ykx83tc90O9cUX0aXQXF9o3AjbVj76h8vhs4dZg05xwUlT7il1HU7W6TtLHMRERf4/4u2fmE7N4+YttPUMzcPHs02YqlaqpLdpynjs+n+tSkjxhJ5wPnAxx39JLr7Iqhjbb6tBDmk7smfcTY3mB7re21hz9rvN/iEO1Yys9oN+kjjpim6H0a7z+O8wmKvX3EwHcp+ohfO5JcxZI14sG7BTHnoOjXRzzbd759x9Nm7XbdtO2bfc+lu3bpWNLrU/TqI46YzZATAjuR7qBo3bj3PiUoolW22JOgiJgu1aeIirQpInpIUAxhrt21g74b42NJj1NEzNWSHqeIGJYNe0b0kNFCSVBE61J9iqhImyKiBycoIqZLQzuiwk6bYmQGjUNk2vliISbS+xQxXdoUERWZ+xRR56JdMc4SFNG69D5FVDgN7YiZUn1qSaadLx7pfYqosBMUETOkSzaiJm2KiAojJtP7FDHdmBcUWQgyWlY2tJtsTUhaJ+keSVskXdjnml+TdLekuyRdPSjNlBTRvhEVFU2WmJO0GngrcKrtXZKePSjdZREU85l23uT7MZwRdsnuXWIOQNLUEnPVdRd/F7jM9q7i3u67fvaU+a6Oej+wG5gA9theO5/0YukzMDnZOChWSdpc2d9ge0Nlv8kSc88HkPSfFEtGXGT7c7PddBQlxUtt7xhBOrEcGGheUuwY8Ie2yRJz+wOrgdMoVtu6VdIJth/rl2ga2tE6u9nWQJMl5rYCn7H9pO37gHsogqSv+QaFgZsk3V6ugjqDpPMlbZa0+Ul+Ms/bxZLghttge5eYk3QgxRJzG2vXfBp4KYCkVRTVqXtnS3S+1adTbW8rW/Q3S/of27dULyjrgBsADtFh495FHQuueXfrIP2WmJN0MbDZ9sby3Msl3U3R9v0T24/Olu58l/faVv7cLul6it6AW2b/Vix7I/zT2GuJOdvvqHw28OZya2TO1SdJB0taOfUZeDlw51zTi2XC4Ek12royn5LiCOB6SVPpXD2oq2tc5fU5bVuis2TLAZMXjDAvsVyMectyWYxox5hJUERUDDd414kERbQuDxlF1HXYs9REgiJap5QUi19enzNCzadwdCZBES1TGtoRM6SkiKiZ7DoDs0tQRLsyThExU3qfIurGPCjyOGpETUqKecrrc4aX6lNElck0j4gZUlJETJfqU0RdgiKiJkERsY+c6lPETOl9Wt7y+pyZUlJE1CUoIirSpojoIUERMZ3G/CGjzJKNqElJEe1L9Slms+xen7MIGtoDq0+SrpC0XdKdlWOHSbpZ0nfKn4cubDZjSRnd8l4Lokmb4qPAutqxC4Ev2F4NfKHcj2hmsQdFuYbdztrhs4Ery89XAq8acb5iiRJF71OTrStz7X06wvbDAOXPZ/e7MKujxjTeNylw0NaEpHWS7pG0RVLfGoukV0uypNnW5QZa6JK1vcH2WttrD+ApC327WAxGVH2StAK4DDgTWAOsl7Smx3UrgTcBX2uSvbkG
xSOSjixveCSwfY7pxHI0ujbFycAW2/fafgK4hqJqX/cu4D3Aj5skOteg2AicW34+F/jMHNOJZWiI6tOqqap3uZ1fS+po4KHK/tby2L57SScBx9q+oWn+Bo5TSPoEcFqZwa3AO4FLgGslnQc8CLym6Q2juSU77bx5z9IO27O1AXo9mLE3dUn7AZcCr2t8RxoEhe31fU6dPsyNIoCioT26nqWtwLGV/WOAbZX9lcAJwJfKpa1/Ctgo6Szbm/slmhHtaN/oxiBuA1ZLei7wXeAc4LV7b2N/H1g1tS/pS8AfzxYQkAmB0YFRdcna3gNcAGwCvgVca/suSRdLOmuu+UtJEe0b4Wi17RuBG2vH3tHn2tOapJmgiHZlzbuI6cT4z5JNUCxii3XaeYIioi5BEVGToIioWARP3iUoon0Jiojpxv0VNwmKaF2qTxFVGbyLroz1qq0Jioh9MqId0YMmxzsqEhTRrrQpImZK9SmiLkERMV1Kioi6BEWMo85enzPat3ksiARFtCrjFBG9eLyjIkERrUtJEVGVwbuImdLQjqhJUMSiNJ/X56w4cpaTZuwb2nNdHfUiSd+V9M1ye8XCZjOWklEu77UQ5ro6KsCltk8stxt7nI/obcxXR22yPsUtko5f+KzEcrAYBu/m8yr+CyTdUVav+i4un9VRYxobTTbbujLXoPgg8DzgROBh4L39LszqqDHDYq8+9WL7kanPki4HGi+yF7Ekq09TywWXfhW4s9+1EdMYmHSzrSNzXR31NEknUvyK9wO/1+Rmu9m14/O+7oHKoVXAjiHz3IbkaxY9xiHq+XrOrAmMeUkx19VR/3kuN7N9eHVf0uYBS8J2IvkazrD5GmX1SdI64P3ACuDDti+pnX8z8DvAHuB7wG/bfmBGQhVZCDJaN6reJ0krgMuAM4E1wHpJa2qXfQNYa/vngeuA9wxKN0ER7Wra89SsNDkZ2GL7XttPANcAZ0+7nf1F24+Xu1+lWGt7Vl3PfdrQ8f37Sb6G0zhfxeBd4/rTKknVNa832K7e62jgocr+VuCUWdI7D/i3QTftNChqv+DYSL6GM3S+ms+S3TGgraJe2el5ofQbwFrgJYNu2nVJEcvQECXFIFuBYyv7xwDbZtxP+hXgz4CX2B44rSJtimjXaNsUtwGrJT1X0oHAOcDG6gWSTgI+BJxle3uTRDsJCknrJN0jaYukC7vIQy+S7pf03+V0+M2Dv7Ggeek1Zf8wSTdL+k75s++cs5bzNcSjBKOb+2R7D3ABsAn4FnCt7bskXSzprPKyvwWeDvxrmbeNfZLb9zu65Qc+ym60bwMvoyj+bgPW27671Yz0IOl+iu67zgfIJL0Y+CFwle0TymPvAXbavqT8Y3Ko7beMQb4uAn5o++8Gff+QlUf75JN+v9G9vnDr22/vYlymi5JiYDdaFFP2gZ21w2cDV5afrwRe1Wqm6JuvIRIoHkdtsnWli6Do1Y12dAf56MXATZJul3R+15np4QjbDwOUP5/dcX6qGj1KABSPozbZOtJFUDTuRuvAqbZfSDFC+oayqhCDNX6UABj7qeNdBEWjbrQu2N5W/twOXE9R1Rsnj0zNUC5/NupNWWi2H7E9YXsSuJwB/900Odlo60oXQTGwG60Lkg6WtHLqM/Byxm9K/Ebg3PLzucBnOszLXkM9SmCKwbsmW0daH7yzvUfSVDfaCuAK23e1nY8ejgCulwTFf5erbX+uq8z0mbJ/CXCtpPOAB4HXjEm+Gj9KIDzKwbsF0cmIdvn2j7F6A4jte4EXdJ2PKX2m7AOc3mpGakbyKEGCIqImQRFRMdWmGGMJimhdlz1LTSQoomXdDsw1kaCIdi2CFywnKKJ94117SlBE+zJOEVGXoIiosGFivOtPCYpoX0qKiJoERUTF1AuWx1iCIlpmcNoUEfuYNLQjZkibIqImQRFRlQmBEdMZyNTxiJqUFBFVmeYRMZ3BGaeIqMmIdkRN2hQRFXZ6nyJmSEkRUWU8MdF1JmaV
oIh2Zep4RA9j3iWb1VGjVQY86UZbE4MWFZX0FEmfLM9/TdLxg9JMUES7XD5k1GQboFxU9DKKlafWAOslralddh6wy/ZPA5cCfzMo3QRFtM4TE422BposKlpdPPM64HSVi5D0kzZFtGo3uzZ93tetanj5QbX1zDfY3lDZ77Wo6Cm1NPZeUy4Y9H3gWUDfZaETFNEq2+tGmFyTRUWHXng01adYzJosKrr3Gkn7A89gwDrgCYpYzJosKlpdPPPVwL/bsw+pp/oUi1a/RUUlXQxstr2RYj2+j0naQlFCnDMoXQ0ImohlJ9WniJoERURNgiKiJkERUZOgiKhJUETUJCgiav4fdNYTdKDAoWYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show each layer's kernel (transposed) and, when the layer has one, its bias as a column.\n",
    "for layer in vae.layers:\n",
    "    if not hasattr(layer, 'kernel'):\n",
    "        continue\n",
    "    fig, (ax_kernel, ax_bias) = plt.subplots(1, 2, figsize=(6, 4))\n",
    "    kernel_img = ax_kernel.imshow(K.eval(layer.kernel).T)\n",
    "    ax_kernel.set_title('{} kernel'.format(layer.name))\n",
    "    plt.colorbar(kernel_img, ax=ax_kernel)\n",
    "    bias = getattr(layer, 'bias', None)\n",
    "    if bias is None:\n",
    "        ax_bias.set_visible(False)\n",
    "        continue\n",
    "    bias_img = ax_bias.imshow(K.eval(K.expand_dims(bias, -1)))\n",
    "    ax_bias.set_title('{} bias'.format(layer.name))\n",
    "    plt.colorbar(bias_img, ax=ax_bias)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}