research/slim/nets/mobilenet_v1_train.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build and train mobilenet_v1 with options for quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.contrib import quantize as contrib_quantize

from datasets import dataset_factory
from nets import mobilenet_v1
from preprocessing import preprocessing_factory

flags = tf.app.flags

flags.DEFINE_string('master', '', 'Session master')
flags.DEFINE_integer('task', 0, 'Task')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps')
flags.DEFINE_integer('batch_size', 64, 'Batch size')
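# 1001 classes: the 1000 ImageNet classes plus slim's extra background class.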
flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish')
flags.DEFINE_integer('number_of_steps', None,
                     'Number of training steps to perform before stopping')
flags.DEFINE_integer('image_size', 224, 'Input image resolution')
flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet')
flags.DEFINE_bool('quantize', False, 'Quantize training')
flags.DEFINE_string('fine_tune_checkpoint', '',
                    'Checkpoint from which to start fine-tuning.')
flags.DEFINE_string('checkpoint_dir', '',
                    'Directory for writing training checkpoints and logs')
flags.DEFINE_string('dataset_dir', '', 'Location of dataset')
flags.DEFINE_integer('log_every_n_steps', 100, 'Number of steps per log')
flags.DEFINE_integer('save_summaries_secs', 100,
                     'How often to save summaries, secs')
flags.DEFINE_integer('save_interval_secs', 100,
                     'How often to save checkpoints, secs')

FLAGS = flags.FLAGS
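
# Multiplier applied to the learning rate at each decay interval by
# tf.train.exponential_decay in build_model() below.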
_LEARNING_RATE_DECAY_FACTOR = 0.94


def get_learning_rate():
  if FLAGS.fine_tune_checkpoint:
    # If we are fine-tuning a checkpoint we need to start at a lower learning
    # rate since we are further along in training.
    return 1e-4
  else:
    return 0.045


def get_quant_delay():
  if FLAGS.fine_tune_checkpoint:
    # We can start quantizing immediately when fine-tuning.
    return 0
  else:
    # When training from scratch, wait for the float model to train for a
    # while before starting to quantize.
    return 250000


def imagenet_input(is_training):
  """Data reader for imagenet.

  Reads in imagenet data and performs pre-processing on the images.

  Args:
    is_training: bool specifying if train or validation dataset is needed.

  Returns:
    A batch of images and labels.
  """
  if is_training:
    dataset = dataset_factory.get_dataset('imagenet', 'train',
                                          FLAGS.dataset_dir)
  else:
    dataset = dataset_factory.get_dataset('imagenet', 'validation',
                                          FLAGS.dataset_dir)

  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=is_training,
      common_queue_capacity=2 * FLAGS.batch_size,
      common_queue_min=FLAGS.batch_size)
  [image, label] = provider.get(['image', 'label'])

  image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      'mobilenet_v1', is_training=is_training)

  image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
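  # tf.train.batch enqueues examples via TF1 queue runners; slim.learning.train
  # (used in train_model below) starts those queue runners automatically.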
  images, labels = tf.train.batch([image, label],
                                  batch_size=FLAGS.batch_size,
                                  num_threads=4,
                                  capacity=5 * FLAGS.batch_size)
  labels = slim.one_hot_encoding(labels, FLAGS.num_classes)
  return images, labels


def build_model():
  """Builds graph for model to train with rewrites for quantization.

  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
      training quantized weights.
    train_tensor: Train op for execution during training.
  """
  g = tf.Graph()
  with g.as_default(), tf.device(
      tf.train.replica_device_setter(FLAGS.ps_tasks)):
    inputs, labels = imagenet_input(is_training=True)
    with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
      logits, _ = mobilenet_v1.mobilenet_v1(
          inputs,
          is_training=True,
          depth_multiplier=FLAGS.depth_multiplier,
          num_classes=FLAGS.num_classes)

    tf.losses.softmax_cross_entropy(labels, logits)

    # Call the rewriter to produce a graph with fake quant ops and folded
    # batch norms. quant_delay postpones quantization until quant_delay
    # training steps have passed, which improves final model accuracy.
    if FLAGS.quantize:
      contrib_quantize.create_training_graph(quant_delay=get_quant_delay())

    total_loss = tf.losses.get_total_loss(name='total_loss')

    # Configure the learning rate using an exponential decay.
    num_epochs_per_decay = 2.5
    imagenet_size = 1271167
    decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay)
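    # For example, with the default batch_size=64 this gives
    # int(1271167 / 64 * 2.5) == 49654 steps between decays.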
    learning_rate = tf.train.exponential_decay(
        get_learning_rate(),
        tf.train.get_or_create_global_step(),
        decay_steps,
        _LEARNING_RATE_DECAY_FACTOR,
        staircase=True)
    opt = tf.train.GradientDescentOptimizer(learning_rate)

    train_tensor = slim.learning.create_train_op(
        total_loss,
        optimizer=opt)
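    # create_train_op returns a tensor that, when evaluated, computes the
    # total loss, applies the gradients, and increments the global step.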
    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')

  return g, train_tensor


def get_checkpoint_init_fn():
  """Returns the checkpoint init_fn if the checkpoint is provided."""
  if FLAGS.fine_tune_checkpoint:
    variables_to_restore = slim.get_variables_to_restore()
    global_step_reset = tf.assign(
        tf.train.get_or_create_global_step(), 0)
    # When restoring from a floating point model, the min/max values for
    # quantized weights and activations are not present. We instruct slim to
    # ignore variables that are missing during restoration by setting
    # ignore_missing_vars=True.
    slim_init_fn = slim.assign_from_checkpoint_fn(
        FLAGS.fine_tune_checkpoint,
        variables_to_restore,
        ignore_missing_vars=True)

    def init_fn(sess):
      slim_init_fn(sess)
      # If we are restoring from a floating point model, we need to initialize
      # the global step to zero for the exponential decay to result in
      # reasonable learning rates.
      sess.run(global_step_reset)
    return init_fn
  else:
    return None


def train_model():
  """Trains mobilenet_v1."""
  g, train_tensor = build_model()
  with g.as_default():
    slim.learning.train(
        train_tensor,
        FLAGS.checkpoint_dir,
        is_chief=(FLAGS.task == 0),
        master=FLAGS.master,
        log_every_n_steps=FLAGS.log_every_n_steps,
        graph=g,
        number_of_steps=FLAGS.number_of_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        init_fn=get_checkpoint_init_fn(),
        global_step=tf.train.get_global_step())


def main(unused_arg):
  train_model()


if __name__ == '__main__':
  tf.app.run(main)