thundergolfer/text-classify-with-cnn

train.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "import datetime\n",
    "import manage_data\n",
    "from text_network import TextNetwork\n",
    "from tensorflow.contrib import learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Parameters:\n",
      "ALLOW_SOFT_PLACEMENT=True\n",
      "BATCH_SIZE=64\n",
      "CHECKPOINT_EVERY=100\n",
      "DROPOUT_KEEP_PROB=0.5\n",
      "EMBEDDING_DIM=128\n",
      "EVALUATE_EVERY=100\n",
      "FILTER_SIZES=3,4,5\n",
      "L2_REG_LAMBDA=0.0\n",
      "LOG_DEVICE_PLACEMENT=False\n",
      "NUM_EPOCHS=200\n",
      "NUM_FILTERS=128\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Model Hyperparameters\n",
    "tf.flags.DEFINE_integer(\"embedding_dim\", 128, \"Dimensionality of character embedding (default: 128)\")\n",
    "tf.flags.DEFINE_string(\"filter_sizes\", \"3,4,5\", \"Comma-separated filter sizes (default: '3,4,5')\")\n",
    "tf.flags.DEFINE_integer(\"num_filters\", 128, \"Number of filters per filter size (default: 128)\")\n",
    "tf.flags.DEFINE_float(\"dropout_keep_prob\", 0.5, \"Dropout keep probability (default: 0.5)\")\n",
    "tf.flags.DEFINE_float(\"l2_reg_lambda\", 0.0, \"L2 regularizaion lambda (default: 0.0)\")\n",
    "\n",
    "# Training parameters\n",
    "tf.flags.DEFINE_integer(\"batch_size\", 64, \"Batch Size (default: 64)\")\n",
    "tf.flags.DEFINE_integer(\"num_epochs\", 200, \"Number of training epochs (default: 200)\")\n",
    "tf.flags.DEFINE_integer(\"evaluate_every\", 100, \"Evaluate model on dev set after this many steps (default: 100)\")\n",
    "tf.flags.DEFINE_integer(\"checkpoint_every\", 100, \"Save model after this many steps (default: 100)\")\n",
    "# Misc Parameters\n",
    "tf.flags.DEFINE_boolean(\"allow_soft_placement\", True, \"Allow device soft device placement\")\n",
    "tf.flags.DEFINE_boolean(\"log_device_placement\", False, \"Log placement of ops on devices\")\n",
    "\n",
    "FLAGS = tf.flags.FLAGS\n",
    "FLAGS._parse_flags()\n",
    "print(\"\\nParameters:\")\n",
    "for attr, value in sorted(FLAGS.__flags.items()):\n",
    "    print(\"{}={}\".format(attr.upper(), value))\n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Preparation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "Vocabulary Size: 18758\n",
      "Train/Dev split: 9662/1000\n"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "print(\"Loading data...\")\n",
    "x_text, y = manage_data.load_data_and_labels()\n",
    "\n",
    "# Build vocabulary\n",
    "max_document_length = max([len(x.split(\" \")) for x in x_text])\n",
    "vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)\n",
    "x = np.array(list(vocab_processor.fit_transform(x_text)))\n",
    "\n",
    "# Randomly shuffle data\n",
    "np.random.seed(10)\n",
    "shuffle_indices = np.random.permutation(np.arange(len(y)))\n",
    "x_shuffled = x[shuffle_indices]\n",
    "y_shuffled = y[shuffle_indices]\n",
    "\n",
    "# Split train/test set\n",
    "# TODO: This is very crude, should use cross-validation\n",
    "x_train, x_dev = x_shuffled[:-1000], x_shuffled[-1000:]\n",
    "y_train, y_dev = y_shuffled[:-1000], y_shuffled[-1000:]\n",
    "print(\"Vocabulary Size: {:d}\".format(len(vocab_processor.vocabulary_)))\n",
    "print(\"Train/Dev split: {:d}/{:d}\".format(len(y_train), len(y_dev)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training The Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing to /home/jono/Github/text-classify-with-cnn/runs/1476486719\n",
      "\n",
      "2016-10-15T10:12:00.872912: step 1, loss 2.17834, acc 0.515625\n",
      "2016-10-15T10:12:01.104044: step 2, loss 2.37684, acc 0.4375\n",
      "2016-10-15T10:12:01.345502: step 3, loss 1.61573, acc 0.53125\n",
      "2016-10-15T10:12:01.568742: step 4, loss 2.46489, acc 0.5625\n",
      "2016-10-15T10:12:01.798536: step 5, loss 1.50424, acc 0.609375\n",
      "2016-10-15T10:12:02.015717: step 6, loss 1.67856, acc 0.484375\n",
      "2016-10-15T10:12:02.235344: step 7, loss 1.87402, acc 0.546875\n",
      "2016-10-15T10:12:02.485685: step 8, loss 1.75293, acc 0.609375\n",
      "2016-10-15T10:12:02.733335: step 9, loss 1.60018, acc 0.578125\n",
      "2016-10-15T10:12:02.952329: step 10, loss 1.65958, acc 0.53125\n",
      "2016-10-15T10:12:03.176495: step 11, loss 1.48052, acc 0.546875\n",
      "2016-10-15T10:12:03.406426: step 12, loss 2.32531, acc 0.46875\n",
      "2016-10-15T10:12:03.640859: step 13, loss 2.31976, acc 0.484375\n",
      "2016-10-15T10:12:03.869680: step 14, loss 2.01, acc 0.46875\n",
      "2016-10-15T10:12:04.095002: step 15, loss 1.75927, acc 0.5\n",
      "2016-10-15T10:12:04.340691: step 16, loss 1.70119, acc 0.484375\n",
      "2016-10-15T10:12:04.563084: step 17, loss 1.6076, acc 0.53125\n",
      "2016-10-15T10:12:04.785239: step 18, loss 1.20017, acc 0.640625\n",
      "2016-10-15T10:12:05.012354: step 19, loss 1.50494, acc 0.46875\n",
      "2016-10-15T10:12:05.242894: step 20, loss 1.65939, acc 0.578125\n",
      "2016-10-15T10:12:05.499912: step 21, loss 1.40056, acc 0.5625\n",
      "2016-10-15T10:12:05.738980: step 22, loss 1.36007, acc 0.59375\n",
      "2016-10-15T10:12:05.967967: step 23, loss 1.82848, acc 0.53125\n",
      "2016-10-15T10:12:06.210751: step 24, loss 2.1214, acc 0.484375\n",
      "2016-10-15T10:12:06.445041: step 25, loss 1.45436, acc 0.546875\n",
      "2016-10-15T10:12:06.673484: step 26, loss 1.73738, acc 0.5\n",
      "2016-10-15T10:12:06.910386: step 27, loss 1.68768, acc 0.5625\n",
      "2016-10-15T10:12:07.147864: step 28, loss 1.30136, acc 0.5625\n",
      "2016-10-15T10:12:07.367672: step 29, loss 1.35643, acc 0.5625\n",
      "2016-10-15T10:12:07.596474: step 30, loss 1.83106, acc 0.4375\n",
      "2016-10-15T10:12:07.823037: step 31, loss 1.86499, acc 0.578125\n",
      "2016-10-15T10:12:08.045291: step 32, loss 1.87879, acc 0.46875\n",
      "2016-10-15T10:12:08.267881: step 33, loss 1.88167, acc 0.484375\n",
      "2016-10-15T10:12:08.485892: step 34, loss 1.20186, acc 0.625\n",
      "2016-10-15T10:12:08.708856: step 35, loss 1.98814, acc 0.4375\n",
      "2016-10-15T10:12:08.940056: step 36, loss 1.56428, acc 0.59375\n",
      "2016-10-15T10:12:09.166673: step 37, loss 2.02957, acc 0.46875\n",
      "2016-10-15T10:12:09.384138: step 38, loss 1.92991, acc 0.46875\n",
      "2016-10-15T10:12:09.602507: step 39, loss 1.38488, acc 0.640625\n",
      "2016-10-15T10:12:09.822334: step 40, loss 1.83013, acc 0.546875\n",
      "2016-10-15T10:12:10.041669: step 41, loss 1.51196, acc 0.625\n",
      "2016-10-15T10:12:10.262227: step 42, loss 1.04271, acc 0.59375\n",
      "2016-10-15T10:12:10.481477: step 43, loss 1.49973, acc 0.515625\n",
      "2016-10-15T10:12:10.701666: step 44, loss 1.49, acc 0.609375\n",
      "2016-10-15T10:12:10.923002: step 45, loss 1.44082, acc 0.53125\n",
      "2016-10-15T10:12:11.141560: step 46, loss 1.43061, acc 0.484375\n",
      "2016-10-15T10:12:11.361128: step 47, loss 1.81411, acc 0.515625\n",
      "2016-10-15T10:12:11.581722: step 48, loss 1.23713, acc 0.546875\n",
      "2016-10-15T10:12:11.796604: step 49, loss 1.5056, acc 0.5625\n",
      "2016-10-15T10:12:12.019603: step 50, loss 1.23, acc 0.625\n",
      "2016-10-15T10:12:12.239029: step 51, loss 1.80079, acc 0.46875\n",
      "2016-10-15T10:12:12.466978: step 52, loss 2.44691, acc 0.421875\n",
      "2016-10-15T10:12:12.691878: step 53, loss 2.14964, acc 0.46875\n",
      "2016-10-15T10:12:12.911293: step 54, loss 1.3923, acc 0.546875\n",
      "2016-10-15T10:12:13.133364: step 55, loss 1.65153, acc 0.515625\n",
      "2016-10-15T10:12:13.351238: step 56, loss 1.46936, acc 0.515625\n",
      "2016-10-15T10:12:13.572986: step 57, loss 1.29942, acc 0.578125\n",
      "2016-10-15T10:12:13.786631: step 58, loss 1.8689, acc 0.46875\n",
      "2016-10-15T10:12:14.002450: step 59, loss 1.28186, acc 0.5625\n",
      "2016-10-15T10:12:14.228103: step 60, loss 1.4272, acc 0.515625\n",
      "2016-10-15T10:12:14.449732: step 61, loss 1.38579, acc 0.546875\n",
      "2016-10-15T10:12:14.672107: step 62, loss 1.54824, acc 0.484375\n",
      "2016-10-15T10:12:14.889157: step 63, loss 1.71746, acc 0.53125\n",
      "2016-10-15T10:12:15.108954: step 64, loss 1.89061, acc 0.40625\n",
      "2016-10-15T10:12:15.329058: step 65, loss 1.14454, acc 0.625\n",
      "2016-10-15T10:12:15.549783: step 66, loss 1.72255, acc 0.53125\n",
      "2016-10-15T10:12:15.769959: step 67, loss 1.08323, acc 0.59375\n",
      "2016-10-15T10:12:15.988682: step 68, loss 1.28509, acc 0.65625\n",
      "2016-10-15T10:12:16.208728: step 69, loss 1.56131, acc 0.515625\n",
      "2016-10-15T10:12:16.430192: step 70, loss 1.59217, acc 0.515625\n",
      "2016-10-15T10:12:16.656004: step 71, loss 1.62195, acc 0.515625\n",
      "2016-10-15T10:12:16.872326: step 72, loss 1.50321, acc 0.546875\n",
      "2016-10-15T10:12:17.091269: step 73, loss 1.29522, acc 0.546875\n",
      "2016-10-15T10:12:17.309066: step 74, loss 1.1761, acc 0.625\n",
      "2016-10-15T10:12:17.527164: step 75, loss 1.27479, acc 0.515625\n",
      "2016-10-15T10:12:17.749515: step 76, loss 1.91551, acc 0.46875\n",
      "2016-10-15T10:12:17.970980: step 77, loss 1.26546, acc 0.5\n",
      "2016-10-15T10:12:18.189730: step 78, loss 1.12266, acc 0.65625\n",
      "2016-10-15T10:12:18.407322: step 79, loss 1.9547, acc 0.484375\n",
      "2016-10-15T10:12:18.627851: step 80, loss 1.00029, acc 0.546875\n",
      "2016-10-15T10:12:18.847346: step 81, loss 1.688, acc 0.546875\n",
      "2016-10-15T10:12:19.065755: step 82, loss 1.54275, acc 0.453125\n",
      "2016-10-15T10:12:19.285684: step 83, loss 1.77438, acc 0.515625\n",
      "2016-10-15T10:12:19.506882: step 84, loss 1.24367, acc 0.53125\n",
      "2016-10-15T10:12:19.732368: step 85, loss 1.42807, acc 0.5625\n",
      "2016-10-15T10:12:19.950892: step 86, loss 1.46412, acc 0.59375\n",
      "2016-10-15T10:12:20.172392: step 87, loss 1.45359, acc 0.5625\n",
      "2016-10-15T10:12:20.388848: step 88, loss 1.70674, acc 0.4375\n",
      "2016-10-15T10:12:20.610542: step 89, loss 1.4345, acc 0.53125\n",
      "2016-10-15T10:12:20.833926: step 90, loss 1.31255, acc 0.578125\n",
      "2016-10-15T10:12:21.053368: step 91, loss 1.07133, acc 0.640625\n",
      "2016-10-15T10:12:21.285574: step 92, loss 1.20513, acc 0.515625\n",
      "2016-10-15T10:12:21.501478: step 93, loss 1.60184, acc 0.515625\n",
      "2016-10-15T10:12:21.722405: step 94, loss 1.26125, acc 0.5625\n",
      "2016-10-15T10:12:21.937582: step 95, loss 1.24756, acc 0.546875\n",
      "2016-10-15T10:12:22.157988: step 96, loss 1.07844, acc 0.625\n",
      "2016-10-15T10:12:22.375592: step 97, loss 1.58785, acc 0.5\n",
      "2016-10-15T10:12:22.608746: step 98, loss 1.69374, acc 0.40625\n",
      "2016-10-15T10:12:22.831462: step 99, loss 1.42811, acc 0.484375\n",
      "2016-10-15T10:12:23.053271: step 100, loss 1.05428, acc 0.671875\n",
      "\n",
      "Evaluation:\n",
      "2016-10-15T10:12:23.856138: step 100, loss 0.719137, acc 0.594\n",
      "\n",
      "Saved model checkpoint to /home/jono/Github/text-classify-with-cnn/runs/1476486719/checkpoints/model-100\n",
      "\n",
      "2016-10-15T10:12:24.657749: step 101, loss 1.60184, acc 0.515625\n",
      "2016-10-15T10:12:24.877484: step 102, loss 1.29558, acc 0.484375\n",
      "2016-10-15T10:12:25.092704: step 103, loss 1.22662, acc 0.53125\n",
      "2016-10-15T10:12:25.309380: step 104, loss 1.99945, acc 0.484375\n",
      "2016-10-15T10:12:25.528338: step 105, loss 1.24838, acc 0.5\n",
      "2016-10-15T10:12:25.742480: step 106, loss 1.23168, acc 0.546875\n",
      "2016-10-15T10:12:25.970456: step 107, loss 1.16021, acc 0.53125\n",
      "2016-10-15T10:12:26.185724: step 108, loss 1.0626, acc 0.578125\n",
      "2016-10-15T10:12:26.402798: step 109, loss 1.36584, acc 0.53125\n",
      "2016-10-15T10:12:26.619056: step 110, loss 1.31638, acc 0.5\n",
      "2016-10-15T10:12:26.834185: step 111, loss 1.40253, acc 0.515625\n",
      "2016-10-15T10:12:27.056232: step 112, loss 1.24609, acc 0.484375\n",
      "2016-10-15T10:12:27.271940: step 113, loss 0.95558, acc 0.6875\n",
      "2016-10-15T10:12:27.490441: step 114, loss 1.20972, acc 0.5\n",
      "2016-10-15T10:12:27.716505: step 115, loss 0.850529, acc 0.671875\n",
      "2016-10-15T10:12:27.933472: step 116, loss 1.21025, acc 0.546875\n",
      "2016-10-15T10:12:28.152836: step 117, loss 1.57255, acc 0.390625\n",
      "2016-10-15T10:12:28.374061: step 118, loss 1.14735, acc 0.515625\n",
      "2016-10-15T10:12:28.589133: step 119, loss 1.11571, acc 0.671875\n",
      "2016-10-15T10:12:28.807469: step 120, loss 1.24635, acc 0.546875\n",
      "2016-10-15T10:12:29.026163: step 121, loss 1.72968, acc 0.515625\n",
      "2016-10-15T10:12:29.243407: step 122, loss 1.31092, acc 0.5\n",
      "2016-10-15T10:12:29.460504: step 123, loss 1.85621, acc 0.4375\n",
      "2016-10-15T10:12:29.695413: step 124, loss 1.06487, acc 0.578125\n",
      "2016-10-15T10:12:29.911933: step 125, loss 1.7949, acc 0.46875\n",
      "2016-10-15T10:12:30.131825: step 126, loss 1.11409, acc 0.53125\n",
      "2016-10-15T10:12:30.346780: step 127, loss 1.01377, acc 0.640625\n",
      "2016-10-15T10:12:30.568089: step 128, loss 1.11762, acc 0.609375\n",
      "2016-10-15T10:12:30.785175: step 129, loss 1.37035, acc 0.4375\n",
      "2016-10-15T10:12:30.999227: step 130, loss 1.02695, acc 0.625\n",
      "2016-10-15T10:12:31.215286: step 131, loss 1.51713, acc 0.578125\n",
      "2016-10-15T10:12:31.448318: step 132, loss 1.59183, acc 0.4375\n",
      "2016-10-15T10:12:31.666681: step 133, loss 1.29736, acc 0.578125\n",
      "2016-10-15T10:12:31.889705: step 134, loss 1.10579, acc 0.65625\n",
      "2016-10-15T10:12:32.104027: step 135, loss 1.36328, acc 0.453125\n",
      "2016-10-15T10:12:32.325099: step 136, loss 0.856365, acc 0.6875\n",
      "2016-10-15T10:12:32.544049: step 137, loss 1.03201, acc 0.5625\n",
      "2016-10-15T10:12:32.761233: step 138, loss 1.73107, acc 0.421875\n",
      "2016-10-15T10:12:32.984635: step 139, loss 1.60015, acc 0.53125\n",
      "2016-10-15T10:12:33.209024: step 140, loss 1.04512, acc 0.609375\n",
      "2016-10-15T10:12:33.426916: step 141, loss 1.34894, acc 0.5\n",
      "2016-10-15T10:12:33.646354: step 142, loss 0.949152, acc 0.625\n",
      "2016-10-15T10:12:33.861454: step 143, loss 1.18836, acc 0.53125\n",
      "2016-10-15T10:12:34.082681: step 144, loss 1.00057, acc 0.5625\n",
      "2016-10-15T10:12:34.300121: step 145, loss 1.17293, acc 0.609375\n",
      "2016-10-15T10:12:34.516054: step 146, loss 1.43594, acc 0.515625\n",
      "2016-10-15T10:12:34.732748: step 147, loss 1.33237, acc 0.484375\n",
      "2016-10-15T10:12:34.948271: step 148, loss 1.29256, acc 0.46875\n",
      "2016-10-15T10:12:35.165617: step 149, loss 1.20981, acc 0.546875\n",
      "2016-10-15T10:12:35.386875: step 150, loss 1.11616, acc 0.5625\n",
      "2016-10-15T10:12:35.604839: step 151, loss 1.01066, acc 0.532258\n",
      "2016-10-15T10:12:35.823732: step 152, loss 1.10394, acc 0.5625\n",
      "2016-10-15T10:12:36.042787: step 153, loss 0.749592, acc 0.703125\n",
      "2016-10-15T10:12:36.259079: step 154, loss 1.08546, acc 0.59375\n",
      "2016-10-15T10:12:36.479301: step 155, loss 0.848611, acc 0.671875\n",
      "2016-10-15T10:12:36.699000: step 156, loss 1.18111, acc 0.53125\n",
      "2016-10-15T10:12:36.915655: step 157, loss 0.903777, acc 0.703125\n",
      "2016-10-15T10:12:37.133633: step 158, loss 0.770446, acc 0.75\n",
      "2016-10-15T10:12:37.350593: step 159, loss 1.01564, acc 0.65625\n",
      "2016-10-15T10:12:37.574548: step 160, loss 1.13375, acc 0.578125\n",
      "2016-10-15T10:12:37.791948: step 161, loss 1.03907, acc 0.515625\n",
      "2016-10-15T10:12:38.012848: step 162, loss 0.972708, acc 0.609375\n",
      "2016-10-15T10:12:38.228790: step 163, loss 0.897121, acc 0.671875\n",
      "2016-10-15T10:12:38.446337: step 164, loss 0.875331, acc 0.6875\n",
      "2016-10-15T10:12:38.665272: step 165, loss 1.39326, acc 0.46875\n",
      "2016-10-15T10:12:38.883191: step 166, loss 0.984082, acc 0.5625\n"
     ]
    }
   ],
   "source": [
    "with tf.Graph().as_default():\n",
    "    session_conf = tf.ConfigProto(\n",
    "      allow_soft_placement=FLAGS.allow_soft_placement,\n",
    "      log_device_placement=FLAGS.log_device_placement)\n",
    "    sess = tf.Session(config=session_conf)\n",
    "    with sess.as_default():\n",
    "        cnn = TextNetwork(\n",
    "            sequence_length=x_train.shape[1],\n",
    "            num_classes=y_train.shape[1],\n",
    "            vocab_size=len(vocab_processor.vocabulary_),\n",
    "            embedding_size=FLAGS.embedding_dim,\n",
    "            filter_sizes=list(map(int, FLAGS.filter_sizes.split(\",\"))),\n",
    "            num_filters=FLAGS.num_filters,\n",
    "            l2_reg_lambda=FLAGS.l2_reg_lambda)\n",
    "\n",
    "        # Define Training procedure\n",
    "        global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
    "        optimizer = tf.train.AdamOptimizer(1e-3)\n",
    "        grads_and_vars = optimizer.compute_gradients(cnn.loss)\n",
    "        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)\n",
    "\n",
    "        # Keep track of gradient values and sparsity (optional)\n",
    "        grad_summaries = []\n",
    "        for g, v in grads_and_vars:\n",
    "            if g is not None:\n",
    "                grad_hist_summary = tf.histogram_summary(\"{}/grad/hist\".format(v.name), g)\n",
    "                sparsity_summary = tf.scalar_summary(\"{}/grad/sparsity\".format(v.name), tf.nn.zero_fraction(g))\n",
    "                grad_summaries.append(grad_hist_summary)\n",
    "                grad_summaries.append(sparsity_summary)\n",
    "        grad_summaries_merged = tf.merge_summary(grad_summaries)\n",
    "\n",
    "        # Output directory for models and summaries\n",
    "        timestamp = str(int(time.time()))\n",
    "        out_dir = os.path.abspath(os.path.join(os.path.curdir, \"runs\", timestamp))\n",
    "        print(\"Writing to {}\\n\".format(out_dir))\n",
    "\n",
    "        # Summaries for loss and accuracy\n",
    "        loss_summary = tf.scalar_summary(\"loss\", cnn.loss)\n",
    "        acc_summary = tf.scalar_summary(\"accuracy\", cnn.accuracy)\n",
    "\n",
    "        # Train Summaries\n",
    "        train_summary_op = tf.merge_summary([loss_summary, acc_summary, grad_summaries_merged])\n",
    "        train_summary_dir = os.path.join(out_dir, \"summaries\", \"train\")\n",
    "        train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph)\n",
    "\n",
    "        # Dev summaries\n",
    "        dev_summary_op = tf.merge_summary([loss_summary, acc_summary])\n",
    "        dev_summary_dir = os.path.join(out_dir, \"summaries\", \"dev\")\n",
    "        dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph)\n",
    "\n",
    "        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it\n",
    "        checkpoint_dir = os.path.abspath(os.path.join(out_dir, \"checkpoints\"))\n",
    "        checkpoint_prefix = os.path.join(checkpoint_dir, \"model\")\n",
    "        if not os.path.exists(checkpoint_dir):\n",
    "            os.makedirs(checkpoint_dir)\n",
    "        saver = tf.train.Saver(tf.all_variables())\n",
    "\n",
    "        # Write vocabulary\n",
    "        vocab_processor.save(os.path.join(out_dir, \"vocab\"))\n",
    "\n",
    "        # Initialize all variables\n",
    "        sess.run(tf.initialize_all_variables())\n",
    "\n",
    "        def train_step(x_batch, y_batch):\n",
    "            \"\"\"\n",
    "            A single training step\n",
    "            \"\"\"\n",
    "            feed_dict = {\n",
    "              cnn.input_x: x_batch,\n",
    "              cnn.input_y: y_batch,\n",
    "              cnn.dropout_keep_prob: FLAGS.dropout_keep_prob\n",
    "            }\n",
    "            _, step, summaries, loss, accuracy = sess.run(\n",
    "                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],\n",
    "                feed_dict)\n",
    "            time_str = datetime.datetime.now().isoformat()\n",
    "            print(\"{}: step {}, loss {:g}, acc {:g}\".format(time_str, step, loss, accuracy))\n",
    "            train_summary_writer.add_summary(summaries, step)\n",
    "\n",
    "        def dev_step(x_batch, y_batch, writer=None):\n",
    "            \"\"\"\n",
    "            Evaluates model on a dev set\n",
    "            \"\"\"\n",
    "            feed_dict = {\n",
    "              cnn.input_x: x_batch,\n",
    "              cnn.input_y: y_batch,\n",
    "              cnn.dropout_keep_prob: 1.0\n",
    "            }\n",
    "            step, summaries, loss, accuracy = sess.run(\n",
    "                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],\n",
    "                feed_dict)\n",
    "            time_str = datetime.datetime.now().isoformat()\n",
    "            print(\"{}: step {}, loss {:g}, acc {:g}\".format(time_str, step, loss, accuracy))\n",
    "            if writer:\n",
    "                writer.add_summary(summaries, step)\n",
    "\n",
    "        # Generate batches\n",
    "        batches = manage_data.batch_iter(\n",
    "            list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)\n",
    "        # Training loop. For each batch...\n",
    "        for batch in batches:\n",
    "            x_batch, y_batch = zip(*batch)\n",
    "            train_step(x_batch, y_batch)\n",
    "            current_step = tf.train.global_step(sess, global_step)\n",
    "            if current_step % FLAGS.evaluate_every == 0:\n",
    "                print(\"\\nEvaluation:\")\n",
    "                dev_step(x_dev, y_dev, writer=dev_summary_writer)\n",
    "                print(\"\")\n",
    "            if current_step % FLAGS.checkpoint_every == 0:\n",
    "                path = saver.save(sess, checkpoint_prefix, global_step=current_step)\n",
    "                print(\"Saved model checkpoint to {}\\n\".format(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}