train.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import os\n",
"import time\n",
"import datetime\n",
"import manage_data\n",
"from text_network import TextNetwork\n",
"from tensorflow.contrib import learn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set Parameters"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Parameters:\n",
"ALLOW_SOFT_PLACEMENT=True\n",
"BATCH_SIZE=64\n",
"CHECKPOINT_EVERY=100\n",
"DROPOUT_KEEP_PROB=0.5\n",
"EMBEDDING_DIM=128\n",
"EVALUATE_EVERY=100\n",
"FILTER_SIZES=3,4,5\n",
"L2_REG_LAMBDA=0.0\n",
"LOG_DEVICE_PLACEMENT=False\n",
"NUM_EPOCHS=200\n",
"NUM_FILTERS=128\n",
"\n"
]
}
],
"source": [
"\n",
"\n",
"# All tunable settings for this notebook are declared as tf.flags so later cells read them from FLAGS.\n",
"\n",
"# Model Hyperparameters\n",
"tf.flags.DEFINE_integer(\"embedding_dim\", 128, \"Dimensionality of character embedding (default: 128)\")\n",
"tf.flags.DEFINE_string(\"filter_sizes\", \"3,4,5\", \"Comma-separated filter sizes (default: '3,4,5')\")\n",
"tf.flags.DEFINE_integer(\"num_filters\", 128, \"Number of filters per filter size (default: 128)\")\n",
"tf.flags.DEFINE_float(\"dropout_keep_prob\", 0.5, \"Dropout keep probability (default: 0.5)\")\n",
"tf.flags.DEFINE_float(\"l2_reg_lambda\", 0.0, \"L2 regularization lambda (default: 0.0)\")\n",
"\n",
"# Training parameters\n",
"tf.flags.DEFINE_integer(\"batch_size\", 64, \"Batch Size (default: 64)\")\n",
"tf.flags.DEFINE_integer(\"num_epochs\", 200, \"Number of training epochs (default: 200)\")\n",
"tf.flags.DEFINE_integer(\"evaluate_every\", 100, \"Evaluate model on dev set after this many steps (default: 100)\")\n",
"tf.flags.DEFINE_integer(\"checkpoint_every\", 100, \"Save model after this many steps (default: 100)\")\n",
"\n",
"# Misc Parameters\n",
"tf.flags.DEFINE_boolean(\"allow_soft_placement\", True, \"Allow soft device placement\")\n",
"tf.flags.DEFINE_boolean(\"log_device_placement\", False, \"Log placement of ops on devices\")\n",
"\n",
"FLAGS = tf.flags.FLAGS\n",
"# NOTE(review): _parse_flags() and __flags are private members of tf.flags; they may not\n",
"# exist in other TensorFlow releases - confirm before upgrading.\n",
"FLAGS._parse_flags()\n",
"\n",
"# Echo every flag (upper-cased name, sorted alphabetically) so the run configuration is recorded in the output.\n",
"print(\"\\nParameters:\")\n",
"for attr, value in sorted(FLAGS.__flags.items()):\n",
"    print(\"{}={}\".format(attr.upper(), value))\n",
"print(\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Preparation "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data...\n",
"Vocabulary Size: 18758\n",
"Train/Dev split: 9662/1000\n"
]
}
],
"source": [
"# Load the raw documents and their labels through the project's data helper.\n",
"print(\"Loading data...\")\n",
"x_text, y = manage_data.load_data_and_labels()\n",
"if len(x_text) != len(y):\n",
"    # The shuffle below indexes documents and labels with the same permutation, so they must line up.\n",
"    raise ValueError(\"Got {:d} documents but {:d} labels\".format(len(x_text), len(y)))\n",
"\n",
"# Build vocabulary: the processor is sized to the longest document, counted in space-separated tokens.\n",
"max_document_length = max(len(text.split(\" \")) for text in x_text)\n",
"vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)\n",
"x = np.array(list(vocab_processor.fit_transform(x_text)))\n",
"\n",
"# Randomly shuffle data; the fixed seed keeps the train/dev split identical across runs.\n",
"np.random.seed(10)\n",
"shuffle_indices = np.random.permutation(np.arange(len(y)))\n",
"x_shuffled = x[shuffle_indices]\n",
"y_shuffled = y[shuffle_indices]\n",
"\n",
"# Split train/dev set: the last DEV_SAMPLE_SIZE shuffled examples are held out for evaluation.\n",
"# TODO: This is very crude, should use cross-validation\n",
"DEV_SAMPLE_SIZE = 1000\n",
"if len(y) <= DEV_SAMPLE_SIZE:\n",
"    # With this few examples the negative slices below would leave the training set empty.\n",
"    raise ValueError(\n",
"        \"Need more than {:d} examples to hold out a dev set, got {:d}\".format(DEV_SAMPLE_SIZE, len(y)))\n",
"x_train, x_dev = x_shuffled[:-DEV_SAMPLE_SIZE], x_shuffled[-DEV_SAMPLE_SIZE:]\n",
"y_train, y_dev = y_shuffled[:-DEV_SAMPLE_SIZE], y_shuffled[-DEV_SAMPLE_SIZE:]\n",
"print(\"Vocabulary Size: {:d}\".format(len(vocab_processor.vocabulary_)))\n",
"print(\"Train/Dev split: {:d}/{:d}\".format(len(y_train), len(y_dev)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training The Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing to /home/jono/Github/text-classify-with-cnn/runs/1476486719\n",
"\n",
"2016-10-15T10:12:00.872912: step 1, loss 2.17834, acc 0.515625\n",
"2016-10-15T10:12:01.104044: step 2, loss 2.37684, acc 0.4375\n",
"2016-10-15T10:12:01.345502: step 3, loss 1.61573, acc 0.53125\n",
"2016-10-15T10:12:01.568742: step 4, loss 2.46489, acc 0.5625\n",
"2016-10-15T10:12:01.798536: step 5, loss 1.50424, acc 0.609375\n",
"2016-10-15T10:12:02.015717: step 6, loss 1.67856, acc 0.484375\n",
"2016-10-15T10:12:02.235344: step 7, loss 1.87402, acc 0.546875\n",
"2016-10-15T10:12:02.485685: step 8, loss 1.75293, acc 0.609375\n",
"2016-10-15T10:12:02.733335: step 9, loss 1.60018, acc 0.578125\n",
"2016-10-15T10:12:02.952329: step 10, loss 1.65958, acc 0.53125\n",
"2016-10-15T10:12:03.176495: step 11, loss 1.48052, acc 0.546875\n",
"2016-10-15T10:12:03.406426: step 12, loss 2.32531, acc 0.46875\n",
"2016-10-15T10:12:03.640859: step 13, loss 2.31976, acc 0.484375\n",
"2016-10-15T10:12:03.869680: step 14, loss 2.01, acc 0.46875\n",
"2016-10-15T10:12:04.095002: step 15, loss 1.75927, acc 0.5\n",
"2016-10-15T10:12:04.340691: step 16, loss 1.70119, acc 0.484375\n",
"2016-10-15T10:12:04.563084: step 17, loss 1.6076, acc 0.53125\n",
"2016-10-15T10:12:04.785239: step 18, loss 1.20017, acc 0.640625\n",
"2016-10-15T10:12:05.012354: step 19, loss 1.50494, acc 0.46875\n",
"2016-10-15T10:12:05.242894: step 20, loss 1.65939, acc 0.578125\n",
"2016-10-15T10:12:05.499912: step 21, loss 1.40056, acc 0.5625\n",
"2016-10-15T10:12:05.738980: step 22, loss 1.36007, acc 0.59375\n",
"2016-10-15T10:12:05.967967: step 23, loss 1.82848, acc 0.53125\n",
"2016-10-15T10:12:06.210751: step 24, loss 2.1214, acc 0.484375\n",
"2016-10-15T10:12:06.445041: step 25, loss 1.45436, acc 0.546875\n",
"2016-10-15T10:12:06.673484: step 26, loss 1.73738, acc 0.5\n",
"2016-10-15T10:12:06.910386: step 27, loss 1.68768, acc 0.5625\n",
"2016-10-15T10:12:07.147864: step 28, loss 1.30136, acc 0.5625\n",
"2016-10-15T10:12:07.367672: step 29, loss 1.35643, acc 0.5625\n",
"2016-10-15T10:12:07.596474: step 30, loss 1.83106, acc 0.4375\n",
"2016-10-15T10:12:07.823037: step 31, loss 1.86499, acc 0.578125\n",
"2016-10-15T10:12:08.045291: step 32, loss 1.87879, acc 0.46875\n",
"2016-10-15T10:12:08.267881: step 33, loss 1.88167, acc 0.484375\n",
"2016-10-15T10:12:08.485892: step 34, loss 1.20186, acc 0.625\n",
"2016-10-15T10:12:08.708856: step 35, loss 1.98814, acc 0.4375\n",
"2016-10-15T10:12:08.940056: step 36, loss 1.56428, acc 0.59375\n",
"2016-10-15T10:12:09.166673: step 37, loss 2.02957, acc 0.46875\n",
"2016-10-15T10:12:09.384138: step 38, loss 1.92991, acc 0.46875\n",
"2016-10-15T10:12:09.602507: step 39, loss 1.38488, acc 0.640625\n",
"2016-10-15T10:12:09.822334: step 40, loss 1.83013, acc 0.546875\n",
"2016-10-15T10:12:10.041669: step 41, loss 1.51196, acc 0.625\n",
"2016-10-15T10:12:10.262227: step 42, loss 1.04271, acc 0.59375\n",
"2016-10-15T10:12:10.481477: step 43, loss 1.49973, acc 0.515625\n",
"2016-10-15T10:12:10.701666: step 44, loss 1.49, acc 0.609375\n",
"2016-10-15T10:12:10.923002: step 45, loss 1.44082, acc 0.53125\n",
"2016-10-15T10:12:11.141560: step 46, loss 1.43061, acc 0.484375\n",
"2016-10-15T10:12:11.361128: step 47, loss 1.81411, acc 0.515625\n",
"2016-10-15T10:12:11.581722: step 48, loss 1.23713, acc 0.546875\n",
"2016-10-15T10:12:11.796604: step 49, loss 1.5056, acc 0.5625\n",
"2016-10-15T10:12:12.019603: step 50, loss 1.23, acc 0.625\n",
"2016-10-15T10:12:12.239029: step 51, loss 1.80079, acc 0.46875\n",
"2016-10-15T10:12:12.466978: step 52, loss 2.44691, acc 0.421875\n",
"2016-10-15T10:12:12.691878: step 53, loss 2.14964, acc 0.46875\n",
"2016-10-15T10:12:12.911293: step 54, loss 1.3923, acc 0.546875\n",
"2016-10-15T10:12:13.133364: step 55, loss 1.65153, acc 0.515625\n",
"2016-10-15T10:12:13.351238: step 56, loss 1.46936, acc 0.515625\n",
"2016-10-15T10:12:13.572986: step 57, loss 1.29942, acc 0.578125\n",
"2016-10-15T10:12:13.786631: step 58, loss 1.8689, acc 0.46875\n",
"2016-10-15T10:12:14.002450: step 59, loss 1.28186, acc 0.5625\n",
"2016-10-15T10:12:14.228103: step 60, loss 1.4272, acc 0.515625\n",
"2016-10-15T10:12:14.449732: step 61, loss 1.38579, acc 0.546875\n",
"2016-10-15T10:12:14.672107: step 62, loss 1.54824, acc 0.484375\n",
"2016-10-15T10:12:14.889157: step 63, loss 1.71746, acc 0.53125\n",
"2016-10-15T10:12:15.108954: step 64, loss 1.89061, acc 0.40625\n",
"2016-10-15T10:12:15.329058: step 65, loss 1.14454, acc 0.625\n",
"2016-10-15T10:12:15.549783: step 66, loss 1.72255, acc 0.53125\n",
"2016-10-15T10:12:15.769959: step 67, loss 1.08323, acc 0.59375\n",
"2016-10-15T10:12:15.988682: step 68, loss 1.28509, acc 0.65625\n",
"2016-10-15T10:12:16.208728: step 69, loss 1.56131, acc 0.515625\n",
"2016-10-15T10:12:16.430192: step 70, loss 1.59217, acc 0.515625\n",
"2016-10-15T10:12:16.656004: step 71, loss 1.62195, acc 0.515625\n",
"2016-10-15T10:12:16.872326: step 72, loss 1.50321, acc 0.546875\n",
"2016-10-15T10:12:17.091269: step 73, loss 1.29522, acc 0.546875\n",
"2016-10-15T10:12:17.309066: step 74, loss 1.1761, acc 0.625\n",
"2016-10-15T10:12:17.527164: step 75, loss 1.27479, acc 0.515625\n",
"2016-10-15T10:12:17.749515: step 76, loss 1.91551, acc 0.46875\n",
"2016-10-15T10:12:17.970980: step 77, loss 1.26546, acc 0.5\n",
"2016-10-15T10:12:18.189730: step 78, loss 1.12266, acc 0.65625\n",
"2016-10-15T10:12:18.407322: step 79, loss 1.9547, acc 0.484375\n",
"2016-10-15T10:12:18.627851: step 80, loss 1.00029, acc 0.546875\n",
"2016-10-15T10:12:18.847346: step 81, loss 1.688, acc 0.546875\n",
"2016-10-15T10:12:19.065755: step 82, loss 1.54275, acc 0.453125\n",
"2016-10-15T10:12:19.285684: step 83, loss 1.77438, acc 0.515625\n",
"2016-10-15T10:12:19.506882: step 84, loss 1.24367, acc 0.53125\n",
"2016-10-15T10:12:19.732368: step 85, loss 1.42807, acc 0.5625\n",
"2016-10-15T10:12:19.950892: step 86, loss 1.46412, acc 0.59375\n",
"2016-10-15T10:12:20.172392: step 87, loss 1.45359, acc 0.5625\n",
"2016-10-15T10:12:20.388848: step 88, loss 1.70674, acc 0.4375\n",
"2016-10-15T10:12:20.610542: step 89, loss 1.4345, acc 0.53125\n",
"2016-10-15T10:12:20.833926: step 90, loss 1.31255, acc 0.578125\n",
"2016-10-15T10:12:21.053368: step 91, loss 1.07133, acc 0.640625\n",
"2016-10-15T10:12:21.285574: step 92, loss 1.20513, acc 0.515625\n",
"2016-10-15T10:12:21.501478: step 93, loss 1.60184, acc 0.515625\n",
"2016-10-15T10:12:21.722405: step 94, loss 1.26125, acc 0.5625\n",
"2016-10-15T10:12:21.937582: step 95, loss 1.24756, acc 0.546875\n",
"2016-10-15T10:12:22.157988: step 96, loss 1.07844, acc 0.625\n",
"2016-10-15T10:12:22.375592: step 97, loss 1.58785, acc 0.5\n",
"2016-10-15T10:12:22.608746: step 98, loss 1.69374, acc 0.40625\n",
"2016-10-15T10:12:22.831462: step 99, loss 1.42811, acc 0.484375\n",
"2016-10-15T10:12:23.053271: step 100, loss 1.05428, acc 0.671875\n",
"\n",
"Evaluation:\n",
"2016-10-15T10:12:23.856138: step 100, loss 0.719137, acc 0.594\n",
"\n",
"Saved model checkpoint to /home/jono/Github/text-classify-with-cnn/runs/1476486719/checkpoints/model-100\n",
"\n",
"2016-10-15T10:12:24.657749: step 101, loss 1.60184, acc 0.515625\n",
"2016-10-15T10:12:24.877484: step 102, loss 1.29558, acc 0.484375\n",
"2016-10-15T10:12:25.092704: step 103, loss 1.22662, acc 0.53125\n",
"2016-10-15T10:12:25.309380: step 104, loss 1.99945, acc 0.484375\n",
"2016-10-15T10:12:25.528338: step 105, loss 1.24838, acc 0.5\n",
"2016-10-15T10:12:25.742480: step 106, loss 1.23168, acc 0.546875\n",
"2016-10-15T10:12:25.970456: step 107, loss 1.16021, acc 0.53125\n",
"2016-10-15T10:12:26.185724: step 108, loss 1.0626, acc 0.578125\n",
"2016-10-15T10:12:26.402798: step 109, loss 1.36584, acc 0.53125\n",
"2016-10-15T10:12:26.619056: step 110, loss 1.31638, acc 0.5\n",
"2016-10-15T10:12:26.834185: step 111, loss 1.40253, acc 0.515625\n",
"2016-10-15T10:12:27.056232: step 112, loss 1.24609, acc 0.484375\n",
"2016-10-15T10:12:27.271940: step 113, loss 0.95558, acc 0.6875\n",
"2016-10-15T10:12:27.490441: step 114, loss 1.20972, acc 0.5\n",
"2016-10-15T10:12:27.716505: step 115, loss 0.850529, acc 0.671875\n",
"2016-10-15T10:12:27.933472: step 116, loss 1.21025, acc 0.546875\n",
"2016-10-15T10:12:28.152836: step 117, loss 1.57255, acc 0.390625\n",
"2016-10-15T10:12:28.374061: step 118, loss 1.14735, acc 0.515625\n",
"2016-10-15T10:12:28.589133: step 119, loss 1.11571, acc 0.671875\n",
"2016-10-15T10:12:28.807469: step 120, loss 1.24635, acc 0.546875\n",
"2016-10-15T10:12:29.026163: step 121, loss 1.72968, acc 0.515625\n",
"2016-10-15T10:12:29.243407: step 122, loss 1.31092, acc 0.5\n",
"2016-10-15T10:12:29.460504: step 123, loss 1.85621, acc 0.4375\n",
"2016-10-15T10:12:29.695413: step 124, loss 1.06487, acc 0.578125\n",
"2016-10-15T10:12:29.911933: step 125, loss 1.7949, acc 0.46875\n",
"2016-10-15T10:12:30.131825: step 126, loss 1.11409, acc 0.53125\n",
"2016-10-15T10:12:30.346780: step 127, loss 1.01377, acc 0.640625\n",
"2016-10-15T10:12:30.568089: step 128, loss 1.11762, acc 0.609375\n",
"2016-10-15T10:12:30.785175: step 129, loss 1.37035, acc 0.4375\n",
"2016-10-15T10:12:30.999227: step 130, loss 1.02695, acc 0.625\n",
"2016-10-15T10:12:31.215286: step 131, loss 1.51713, acc 0.578125\n",
"2016-10-15T10:12:31.448318: step 132, loss 1.59183, acc 0.4375\n",
"2016-10-15T10:12:31.666681: step 133, loss 1.29736, acc 0.578125\n",
"2016-10-15T10:12:31.889705: step 134, loss 1.10579, acc 0.65625\n",
"2016-10-15T10:12:32.104027: step 135, loss 1.36328, acc 0.453125\n",
"2016-10-15T10:12:32.325099: step 136, loss 0.856365, acc 0.6875\n",
"2016-10-15T10:12:32.544049: step 137, loss 1.03201, acc 0.5625\n",
"2016-10-15T10:12:32.761233: step 138, loss 1.73107, acc 0.421875\n",
"2016-10-15T10:12:32.984635: step 139, loss 1.60015, acc 0.53125\n",
"2016-10-15T10:12:33.209024: step 140, loss 1.04512, acc 0.609375\n",
"2016-10-15T10:12:33.426916: step 141, loss 1.34894, acc 0.5\n",
"2016-10-15T10:12:33.646354: step 142, loss 0.949152, acc 0.625\n",
"2016-10-15T10:12:33.861454: step 143, loss 1.18836, acc 0.53125\n",
"2016-10-15T10:12:34.082681: step 144, loss 1.00057, acc 0.5625\n",
"2016-10-15T10:12:34.300121: step 145, loss 1.17293, acc 0.609375\n",
"2016-10-15T10:12:34.516054: step 146, loss 1.43594, acc 0.515625\n",
"2016-10-15T10:12:34.732748: step 147, loss 1.33237, acc 0.484375\n",
"2016-10-15T10:12:34.948271: step 148, loss 1.29256, acc 0.46875\n",
"2016-10-15T10:12:35.165617: step 149, loss 1.20981, acc 0.546875\n",
"2016-10-15T10:12:35.386875: step 150, loss 1.11616, acc 0.5625\n",
"2016-10-15T10:12:35.604839: step 151, loss 1.01066, acc 0.532258\n",
"2016-10-15T10:12:35.823732: step 152, loss 1.10394, acc 0.5625\n",
"2016-10-15T10:12:36.042787: step 153, loss 0.749592, acc 0.703125\n",
"2016-10-15T10:12:36.259079: step 154, loss 1.08546, acc 0.59375\n",
"2016-10-15T10:12:36.479301: step 155, loss 0.848611, acc 0.671875\n",
"2016-10-15T10:12:36.699000: step 156, loss 1.18111, acc 0.53125\n",
"2016-10-15T10:12:36.915655: step 157, loss 0.903777, acc 0.703125\n",
"2016-10-15T10:12:37.133633: step 158, loss 0.770446, acc 0.75\n",
"2016-10-15T10:12:37.350593: step 159, loss 1.01564, acc 0.65625\n",
"2016-10-15T10:12:37.574548: step 160, loss 1.13375, acc 0.578125\n",
"2016-10-15T10:12:37.791948: step 161, loss 1.03907, acc 0.515625\n",
"2016-10-15T10:12:38.012848: step 162, loss 0.972708, acc 0.609375\n",
"2016-10-15T10:12:38.228790: step 163, loss 0.897121, acc 0.671875\n",
"2016-10-15T10:12:38.446337: step 164, loss 0.875331, acc 0.6875\n",
"2016-10-15T10:12:38.665272: step 165, loss 1.39326, acc 0.46875\n",
"2016-10-15T10:12:38.883191: step 166, loss 0.984082, acc 0.5625\n"
]
}
],
"source": [
"# Build the graph, then train with periodic dev evaluation and checkpointing.\n",
"# Everything is written under runs/<unix timestamp>/ relative to the current directory.\n",
"with tf.Graph().as_default():\n",
"    session_conf = tf.ConfigProto(\n",
"        allow_soft_placement=FLAGS.allow_soft_placement,\n",
"        log_device_placement=FLAGS.log_device_placement)\n",
"    # Using the session as a context manager installs it as the default session and\n",
"    # closes it when the block exits, even if training is interrupted.\n",
"    with tf.Session(config=session_conf) as sess:\n",
"        cnn = TextNetwork(\n",
"            sequence_length=x_train.shape[1],\n",
"            num_classes=y_train.shape[1],\n",
"            vocab_size=len(vocab_processor.vocabulary_),\n",
"            embedding_size=FLAGS.embedding_dim,\n",
"            filter_sizes=list(map(int, FLAGS.filter_sizes.split(\",\"))),\n",
"            num_filters=FLAGS.num_filters,\n",
"            l2_reg_lambda=FLAGS.l2_reg_lambda)\n",
"\n",
"        # Define Training procedure\n",
"        global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
"        optimizer = tf.train.AdamOptimizer(1e-3)\n",
"        grads_and_vars = optimizer.compute_gradients(cnn.loss)\n",
"        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)\n",
"\n",
"        # Keep track of gradient values and sparsity (optional)\n",
"        grad_summaries = []\n",
"        for g, v in grads_and_vars:\n",
"            # Variables without a gradient are skipped.\n",
"            if g is not None:\n",
"                grad_hist_summary = tf.histogram_summary(\"{}/grad/hist\".format(v.name), g)\n",
"                sparsity_summary = tf.scalar_summary(\"{}/grad/sparsity\".format(v.name), tf.nn.zero_fraction(g))\n",
"                grad_summaries.append(grad_hist_summary)\n",
"                grad_summaries.append(sparsity_summary)\n",
"        grad_summaries_merged = tf.merge_summary(grad_summaries)\n",
"\n",
"        # Output directory for models and summaries\n",
"        timestamp = str(int(time.time()))\n",
"        out_dir = os.path.abspath(os.path.join(os.path.curdir, \"runs\", timestamp))\n",
"        print(\"Writing to {}\\n\".format(out_dir))\n",
"\n",
"        # Summaries for loss and accuracy\n",
"        loss_summary = tf.scalar_summary(\"loss\", cnn.loss)\n",
"        acc_summary = tf.scalar_summary(\"accuracy\", cnn.accuracy)\n",
"\n",
"        # Train Summaries\n",
"        train_summary_op = tf.merge_summary([loss_summary, acc_summary, grad_summaries_merged])\n",
"        train_summary_dir = os.path.join(out_dir, \"summaries\", \"train\")\n",
"        train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph)\n",
"\n",
"        # Dev summaries\n",
"        dev_summary_op = tf.merge_summary([loss_summary, acc_summary])\n",
"        dev_summary_dir = os.path.join(out_dir, \"summaries\", \"dev\")\n",
"        dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph)\n",
"\n",
"        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it\n",
"        checkpoint_dir = os.path.abspath(os.path.join(out_dir, \"checkpoints\"))\n",
"        checkpoint_prefix = os.path.join(checkpoint_dir, \"model\")\n",
"        if not os.path.exists(checkpoint_dir):\n",
"            os.makedirs(checkpoint_dir)\n",
"        saver = tf.train.Saver(tf.all_variables())\n",
"\n",
"        # Write vocabulary\n",
"        vocab_processor.save(os.path.join(out_dir, \"vocab\"))\n",
"\n",
"        # Initialize all variables\n",
"        sess.run(tf.initialize_all_variables())\n",
"\n",
"        def train_step(x_batch, y_batch):\n",
"            \"\"\"\n",
"            Run one optimizer update on a single batch.\n",
"\n",
"            Feeds the batch with the configured dropout keep probability, prints the\n",
"            step, loss and accuracy, and records the train summaries for that step.\n",
"            \"\"\"\n",
"            feed_dict = {\n",
"                cnn.input_x: x_batch,\n",
"                cnn.input_y: y_batch,\n",
"                cnn.dropout_keep_prob: FLAGS.dropout_keep_prob\n",
"            }\n",
"            _, step, summaries, loss, accuracy = sess.run(\n",
"                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],\n",
"                feed_dict)\n",
"            time_str = datetime.datetime.now().isoformat()\n",
"            print(\"{}: step {}, loss {:g}, acc {:g}\".format(time_str, step, loss, accuracy))\n",
"            train_summary_writer.add_summary(summaries, step)\n",
"\n",
"        def dev_step(x_batch, y_batch, writer=None):\n",
"            \"\"\"\n",
"            Evaluate the model on the given data without updating it.\n",
"\n",
"            train_op is not run and the dropout keep probability is fed as 1.0.\n",
"            Prints the step, loss and accuracy; summaries are recorded only when\n",
"            a writer is supplied.\n",
"            \"\"\"\n",
"            feed_dict = {\n",
"                cnn.input_x: x_batch,\n",
"                cnn.input_y: y_batch,\n",
"                cnn.dropout_keep_prob: 1.0\n",
"            }\n",
"            step, summaries, loss, accuracy = sess.run(\n",
"                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],\n",
"                feed_dict)\n",
"            time_str = datetime.datetime.now().isoformat()\n",
"            print(\"{}: step {}, loss {:g}, acc {:g}\".format(time_str, step, loss, accuracy))\n",
"            if writer:\n",
"                writer.add_summary(summaries, step)\n",
"\n",
"        # Generate batches\n",
"        batches = manage_data.batch_iter(\n",
"            list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)\n",
"        try:\n",
"            # Training loop. For each batch...\n",
"            for batch in batches:\n",
"                x_batch, y_batch = zip(*batch)\n",
"                train_step(x_batch, y_batch)\n",
"                current_step = tf.train.global_step(sess, global_step)\n",
"                if current_step % FLAGS.evaluate_every == 0:\n",
"                    print(\"\\nEvaluation:\")\n",
"                    dev_step(x_dev, y_dev, writer=dev_summary_writer)\n",
"                    print(\"\")\n",
"                if current_step % FLAGS.checkpoint_every == 0:\n",
"                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)\n",
"                    print(\"Saved model checkpoint to {}\\n\".format(path))\n",
"        finally:\n",
"            # Close both event files so summaries are kept if the run is interrupted.\n",
"            train_summary_writer.close()\n",
"            dev_summary_writer.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [Root]",
"language": "python",
"name": "Python [Root]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}