thundergolfer/text-classify-with-cnn

View on GitHub
experiment.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use this file to try out the CNN Text Classifying Network on sentences outside the dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "import datetime\n",
    "import manage_data\n",
    "from text_network import TextNetwork\n",
    "from tensorflow.contrib import learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Eval Parameters\n",
    "tf.flags.DEFINE_integer(\"batch_size\", 64, \"Batch Size (default: 64)\")\n",
    "tf.flags.DEFINE_string(\"checkpoint_dir\", \"\", \"Checkpoint directory from training run\")\n",
    "tf.flags.DEFINE_boolean(\"eval_train\", False, \"Evaluate on all training data\")\n",
    "\n",
    "# Misc Parameters\n",
    "tf.flags.DEFINE_boolean(\"allow_soft_placement\", True, \"Allow device soft device placement\")\n",
    "tf.flags.DEFINE_boolean(\"log_device_placement\", False, \"Log placement of ops on devices\")\n",
    "\n",
    "\n",
    "FLAGS = tf.flags.FLAGS\n",
    "FLAGS._parse__flags()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO: Refactor this to predict only a single case\n",
    "\n",
    "\n",
    "test_sentence = \"This is a test sentence!\" # Our x data\n",
    "x_raw = test_sentence \n",
    "\n",
    "# Map data into vocabulary \n",
    "vocab_path = os.path.join(FLAGS.checkpoint_dir, \"..\", \"vocab\")\n",
    "vocab_processor = learn.processing.VocabularyProcessor.restore(vocab_path)\n",
    "\n",
    "print(\"\\nEvaluating your sentence...\\n\")\n",
    "\n",
    "checkpoint_file = tf.train.lastest_checkpoint(FLAGS.checkpoint_dir)\n",
    "graph = tf.Graph()\n",
    "\n",
    "with graph.as_default():\n",
    "    session_conf = tf.ConfigProto(\n",
    "      allow_soft_placement=FLAGS.allow_soft_placement,\n",
    "      log_device_placement=FLAGS.log_device_placement)\n",
    "    sess = tf.Session(config=session_conf)\n",
    "    with sess.as_default():\n",
    "        # Load the saved meta graph and restore variables\n",
    "        saver = tf.train.import_meta_graph(\"{}.meta\".format(checkpoint_file))\n",
    "        saver.restore(sess, checkpoint_file)\n",
    "\n",
    "        # Get the placeholders from the graph by name\n",
    "        input_x = graph.get_operation_by_name(\"input_x\").outputs[0]\n",
    "        # input_y = graph.get_operation_by_name(\"input_y\").outputs[0]\n",
    "        dropout_keep_prob = graph.get_operation_by_name(\"dropout_keep_prob\").outputs[0]\n",
    "\n",
    "        # Tensors we want to evaluate\n",
    "        predictions = graph.get_operation_by_name(\"output/predictions\").outputs[0]\n",
    "\n",
    "        # Generate batches for one epoch\n",
    "        batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)\n",
    "\n",
    "        # Collect the predictions here\n",
    "        all_predictions = []\n",
    "\n",
    "        for x_test_batch in batches:\n",
    "            batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})\n",
    "            all_predictions = np.concatenate([all_predictions, batch_predictions])\n",
    "\n",
    "def run_network_on_sentence( sent ):\n",
    "    raise NotImplementedError \n",
    "    \n",
    "\n",
    "def classification_report( sent ):\n",
    "    raise NotImplementedError "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}