text_network.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np"
]
},
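{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below defines the model: an embedding layer feeds parallel convolution + max-pooling branches (one per filter size), whose pooled outputs are concatenated, passed through dropout, and scored by a softmax layer. As an illustrative example (values assumed, not prescribed): with `filter_sizes=[3, 4, 5]` and `num_filters=128`, each branch max-pools to a single 128-dimensional vector, so the concatenated feature vector has `num_filters_total = 128 * 3 = 384` dimensions."
]
},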
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class TextNetwork(object):\n",
" \"\"\"\n",
" Convolutional Neural Network (CNN) for text classification\n",
" Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.\n",
" \"\"\"\n",
" def __init__(\n",
" self, sequence_length, num_classes, vocab_size,\n",
" embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):\n",
"\n",
" # Placeholders for input, output and dropout\n",
" self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name=\"input_x\")\n",
" self.input_y = tf.placeholder(tf.float32, [None, num_classes], name=\"input_y\")\n",
" self.dropout_keep_prob = tf.placeholder(tf.float32, name=\"dropout_keep_prob\")\n",
"\n",
" # Keeping track of l2 regularization loss (optional)\n",
" l2_loss = tf.constant(0.0)\n",
"\n",
" # Embedding layer\n",
" with tf.device('/cpu:0'), tf.name_scope(\"embedding\"):\n",
" W = tf.Variable(\n",
" tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),\n",
" name=\"W\")\n",
" self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)\n",
" self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)\n",
"\n",
" # Create a convolution + maxpool layer for each filter size\n",
" pooled_outputs = []\n",
" for i, filter_size in enumerate(filter_sizes):\n",
" with tf.name_scope(\"conv-maxpool-%s\" % filter_size):\n",
" # Convolution Layer\n",
" filter_shape = [filter_size, embedding_size, 1, num_filters]\n",
" W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name=\"W\")\n",
" b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name=\"b\")\n",
" conv = tf.nn.conv2d(\n",
" self.embedded_chars_expanded,\n",
" W,\n",
" strides=[1, 1, 1, 1],\n",
" padding=\"VALID\",\n",
" name=\"conv\")\n",
" # Apply nonlinearity\n",
" h = tf.nn.relu(tf.nn.bias_add(conv, b), name=\"relu\")\n",
" # Maxpooling over the outputs\n",
" pooled = tf.nn.max_pool(\n",
" h,\n",
" ksize=[1, sequence_length - filter_size + 1, 1, 1],\n",
" strides=[1, 1, 1, 1],\n",
" padding='VALID',\n",
" name=\"pool\")\n",
" pooled_outputs.append(pooled)\n",
"\n",
" # Combine all the pooled features\n",
" num_filters_total = num_filters * len(filter_sizes)\n",
" self.h_pool = tf.concat(3, pooled_outputs)\n",
" self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])\n",
"\n",
" # Add dropout\n",
" with tf.name_scope(\"dropout\"):\n",
" self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)\n",
"\n",
" # Final (unnormalized) scores and predictions\n",
" with tf.name_scope(\"output\"):\n",
" W = tf.get_variable(\n",
" \"W\",\n",
" shape=[num_filters_total, num_classes],\n",
" initializer=tf.contrib.layers.xavier_initializer())\n",
" b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name=\"b\")\n",
" l2_loss += tf.nn.l2_loss(W)\n",
" l2_loss += tf.nn.l2_loss(b)\n",
" self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name=\"scores\")\n",
" self.predictions = tf.argmax(self.scores, 1, name=\"predictions\")\n",
"\n",
" # CalculateMean cross-entropy loss\n",
" with tf.name_scope(\"loss\"):\n",
" losses = tf.nn.softmax_cross_entropy_with_logits(self.scores, self.input_y)\n",
" self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss\n",
"\n",
" # Accuracy\n",
" with tf.name_scope(\"accuracy\"):\n",
" correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))\n",
" self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, \"float\"), name=\"accuracy\")"
]
}
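,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch, assuming TensorFlow 1.x (matching the `tf.placeholder` / `tf.contrib` API used above): instantiate `TextNetwork` with hypothetical hyperparameters and run one evaluation step on random data. None of the values below come from the original notebook; they are illustrative assumptions chosen only to exercise the graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: hyperparameter values here are assumptions, not tuned settings.\n",
"cnn = TextNetwork(\n",
"    sequence_length=56,      # assumed padded sentence length\n",
"    num_classes=2,           # e.g. positive / negative\n",
"    vocab_size=10000,        # assumed vocabulary size\n",
"    embedding_size=128,\n",
"    filter_sizes=[3, 4, 5],\n",
"    num_filters=128,\n",
"    l2_reg_lambda=0.1)\n",
"\n",
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    # Random token ids and one-hot labels stand in for real data\n",
"    x_batch = np.random.randint(0, 10000, size=(4, 56))\n",
"    y_batch = np.eye(2)[np.random.randint(0, 2, size=4)]\n",
"    loss, acc = sess.run(\n",
"        [cnn.loss, cnn.accuracy],\n",
"        feed_dict={cnn.input_x: x_batch,\n",
"                   cnn.input_y: y_batch,\n",
"                   cnn.dropout_keep_prob: 1.0})  # keep all units at evaluation time\n",
"    print(\"loss: %.4f, accuracy: %.2f\" % (loss, acc))"
]
}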
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [Root]",
"language": "python",
"name": "Python [Root]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}