research/slim/nets/nasnet/nasnet_utils.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A custom module for some common operations used by NASNet.
Functions exposed in this file:
- calc_reduction_layers
- get_channel_index
- get_channel_dim
- global_avg_pool
- factorized_reduction
- drop_path
Classes exposed in this file:
- NasNetABaseCell
- NasNetANormalCell
- NasNetAReductionCell
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tf_slim as slim
arg_scope = slim.arg_scope
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
INVALID = 'null'
# The cap for tf.clip_by_value; the observed activation distribution suggests
# that the majority of activation values lie in the range [-6, 6].
CLIP_BY_VALUE_CAP = 6
def calc_reduction_layers(num_cells, num_reduction_layers):
"""Figure out what layers should have reductions."""
reduction_layers = []
for pool_num in range(1, num_reduction_layers + 1):
layer_num = (float(pool_num) / (num_reduction_layers + 1)) * num_cells
layer_num = int(layer_num)
reduction_layers.append(layer_num)
return reduction_layers
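
# Illustrative example (not part of the library): with 18 cells and 2
# reduction layers, reductions land at one-third and two-thirds of the
# network depth:
#   calc_reduction_layers(18, 2)  # -> [6, 12]
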
@slim.add_arg_scope
def get_channel_index(data_format=INVALID):
assert data_format != INVALID
axis = 3 if data_format == 'NHWC' else 1
return axis
@slim.add_arg_scope
def get_channel_dim(shape, data_format=INVALID):
assert data_format != INVALID
assert len(shape) == 4
if data_format == 'NHWC':
return int(shape[3])
elif data_format == 'NCHW':
return int(shape[1])
else:
    raise ValueError('Not a valid data_format: %s' % data_format)
@slim.add_arg_scope
def global_avg_pool(x, data_format=INVALID):
"""Average pool away the height and width spatial dimensions of x."""
assert data_format != INVALID
assert data_format in ['NHWC', 'NCHW']
assert x.shape.ndims == 4
if data_format == 'NHWC':
return tf.reduce_mean(input_tensor=x, axis=[1, 2])
else:
return tf.reduce_mean(input_tensor=x, axis=[2, 3])
@slim.add_arg_scope
def factorized_reduction(net, output_filters, stride, data_format=INVALID):
"""Reduces the shape of net without information loss due to striding."""
assert data_format != INVALID
if stride == 1:
net = slim.conv2d(net, output_filters, 1, scope='path_conv')
net = slim.batch_norm(net, scope='path_bn')
return net
if data_format == 'NHWC':
stride_spec = [1, stride, stride, 1]
else:
stride_spec = [1, 1, stride, stride]
# Skip path 1
path1 = tf.nn.avg_pool2d(
net,
ksize=[1, 1, 1, 1],
strides=stride_spec,
padding='VALID',
data_format=data_format)
path1 = slim.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')
# Skip path 2
  # First pad with 0's on the right and bottom, then crop one pixel from the
  # top and left so the stride-2 pooling samples the alternate pixel grid.
if data_format == 'NHWC':
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
path2 = tf.pad(tensor=net, paddings=pad_arr)[:, 1:, 1:, :]
concat_axis = 3
else:
pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
path2 = tf.pad(tensor=net, paddings=pad_arr)[:, :, 1:, 1:]
concat_axis = 1
path2 = tf.nn.avg_pool2d(
path2,
ksize=[1, 1, 1, 1],
strides=stride_spec,
padding='VALID',
data_format=data_format)
# If odd number of filters, add an additional one to the second path.
final_filter_size = int(output_filters / 2) + int(output_filters % 2)
path2 = slim.conv2d(path2, final_filter_size, 1, scope='path2_conv')
# Concat and apply BN
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
final_path = slim.batch_norm(final_path, scope='final_path_bn')
return final_path
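
# Shape sketch (illustrative): for an NHWC input of [8, 32, 32, 64] with
# output_filters=96 and stride=2, path1 pools the even-indexed pixels and
# path2 the odd-indexed (shifted) pixels; each path yields [8, 16, 16, 48],
# and the channel concat plus batch norm gives [8, 16, 16, 96]:
#   net = factorized_reduction(net, 96, 2, data_format='NHWC')
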
@slim.add_arg_scope
def drop_path(net, keep_prob, is_training=True):
"""Drops out a whole example hiddenstate with the specified probability."""
if is_training:
batch_size = tf.shape(input=net)[0]
noise_shape = [batch_size, 1, 1, 1]
random_tensor = keep_prob
random_tensor += tf.random.uniform(noise_shape, dtype=tf.float32)
binary_tensor = tf.cast(tf.floor(random_tensor), net.dtype)
keep_prob_inv = tf.cast(1.0 / keep_prob, net.dtype)
net = net * keep_prob_inv * binary_tensor
return net
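
# How the mask works (illustrative): adding keep_prob to Uniform[0, 1) noise
# and flooring yields a Bernoulli(keep_prob) 0/1 sample; e.g. for
# keep_prob=0.7 the sum lies in [0.7, 1.7) and floors to 1 with probability
# 0.7. The [batch_size, 1, 1, 1] noise shape drops entire examples, and the
# 1/keep_prob rescaling keeps the expected activation magnitude unchanged:
#   net = drop_path(net, keep_prob=0.7, is_training=True)
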
def _operation_to_filter_shape(operation):
splitted_operation = operation.split('x')
filter_shape = int(splitted_operation[0][-1])
assert filter_shape == int(
splitted_operation[1][0]), 'Rectangular filters not supported.'
return filter_shape
def _operation_to_num_layers(operation):
splitted_operation = operation.split('_')
if 'x' in splitted_operation[-1]:
return 1
return int(splitted_operation[-1])
def _operation_to_info(operation):
"""Takes in operation name and returns meta information.
An example would be 'separable_3x3_4' -> (3, 4).
Args:
operation: String that corresponds to convolution operation.
Returns:
Tuple of (filter shape, num layers).
"""
num_layers = _operation_to_num_layers(operation)
filter_shape = _operation_to_filter_shape(operation)
return num_layers, filter_shape
def _stacked_separable_conv(net, stride, operation, filter_size,
use_bounded_activation):
"""Takes in an operations and parses it to the correct sep operation."""
num_layers, kernel_size = _operation_to_info(operation)
activation_fn = tf.nn.relu6 if use_bounded_activation else tf.nn.relu
for layer_num in range(num_layers - 1):
net = activation_fn(net)
net = slim.separable_conv2d(
net,
filter_size,
kernel_size,
depth_multiplier=1,
scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
stride=stride)
net = slim.batch_norm(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
stride = 1
net = activation_fn(net)
net = slim.separable_conv2d(
net,
filter_size,
kernel_size,
depth_multiplier=1,
scope='separable_{0}x{0}_{1}'.format(kernel_size, num_layers),
stride=stride)
net = slim.batch_norm(
net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, num_layers))
return net
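
# Illustrative trace: operation='separable_5x5_2' with stride=2 builds two
# ReLU -> separable conv -> batch norm stacks under the scopes
# separable_5x5_1 / bn_sep_5x5_1 and separable_5x5_2 / bn_sep_5x5_2, with the
# stride applied only in the first conv and reset to 1 afterwards.
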
def _operation_to_pooling_type(operation):
"""Takes in the operation string and returns the pooling type."""
splitted_operation = operation.split('_')
return splitted_operation[0]
def _operation_to_pooling_shape(operation):
"""Takes in the operation string and returns the pooling kernel shape."""
splitted_operation = operation.split('_')
shape = splitted_operation[-1]
assert 'x' in shape
filter_height, filter_width = shape.split('x')
assert filter_height == filter_width
return int(filter_height)
def _operation_to_pooling_info(operation):
"""Parses the pooling operation string to return its type and shape."""
pooling_type = _operation_to_pooling_type(operation)
pooling_shape = _operation_to_pooling_shape(operation)
return pooling_type, pooling_shape
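
# Illustrative parse: 'avg_pool_3x3' -> ('avg', 3) and
# 'max_pool_3x3' -> ('max', 3); rectangular pooling windows are rejected by
# the assertion in _operation_to_pooling_shape.
#   _operation_to_pooling_info('avg_pool_3x3')  # -> ('avg', 3)
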
def _pooling(net, stride, operation, use_bounded_activation):
"""Parses operation and performs the correct pooling operation on net."""
padding = 'SAME'
pooling_type, pooling_shape = _operation_to_pooling_info(operation)
if use_bounded_activation:
net = tf.nn.relu6(net)
if pooling_type == 'avg':
net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
elif pooling_type == 'max':
net = slim.max_pool2d(net, pooling_shape, stride=stride, padding=padding)
else:
    raise NotImplementedError('Unimplemented pooling type: %s' % pooling_type)
return net
class NasNetABaseCell(object):
"""NASNet Cell class that is used as a 'layer' in image architectures.
Args:
num_conv_filters: The number of filters for each convolution operation.
operations: List of operations that are performed in the NASNet Cell in
order.
used_hiddenstates: Binary array that signals if the hiddenstate was used
within the cell. This is used to determine what outputs of the cell
should be concatenated together.
hiddenstate_indices: Determines what hiddenstates should be combined
together with the specified operations to create the NASNet cell.
use_bounded_activation: Whether or not to use bounded activations. Bounded
activations better lend themselves to quantized inference.
"""
def __init__(self, num_conv_filters, operations, used_hiddenstates,
hiddenstate_indices, drop_path_keep_prob, total_num_cells,
total_training_steps, use_bounded_activation=False):
self._num_conv_filters = num_conv_filters
self._operations = operations
self._used_hiddenstates = used_hiddenstates
self._hiddenstate_indices = hiddenstate_indices
self._drop_path_keep_prob = drop_path_keep_prob
self._total_num_cells = total_num_cells
self._total_training_steps = total_training_steps
self._use_bounded_activation = use_bounded_activation
def _reduce_prev_layer(self, prev_layer, curr_layer):
"""Matches dimension of prev_layer to the curr_layer."""
    # Set the prev layer to the current layer if it is None.
if prev_layer is None:
return curr_layer
curr_num_filters = self._filter_size
prev_num_filters = get_channel_dim(prev_layer.shape)
curr_filter_shape = int(curr_layer.shape[2])
prev_filter_shape = int(prev_layer.shape[2])
activation_fn = tf.nn.relu6 if self._use_bounded_activation else tf.nn.relu
if curr_filter_shape != prev_filter_shape:
prev_layer = activation_fn(prev_layer)
prev_layer = factorized_reduction(
prev_layer, curr_num_filters, stride=2)
elif curr_num_filters != prev_num_filters:
prev_layer = activation_fn(prev_layer)
prev_layer = slim.conv2d(
prev_layer, curr_num_filters, 1, scope='prev_1x1')
prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
return prev_layer
def _cell_base(self, net, prev_layer):
"""Runs the beginning of the conv cell before the predicted ops are run."""
num_filters = self._filter_size
    # Make sure the previous layer's shape is compatible with this layer's.
prev_layer = self._reduce_prev_layer(prev_layer, net)
net = tf.nn.relu6(net) if self._use_bounded_activation else tf.nn.relu(net)
net = slim.conv2d(net, num_filters, 1, scope='1x1')
net = slim.batch_norm(net, scope='beginning_bn')
    # Wrap in a list so the hiddenstates computed below can be appended.
net = [net]
net.append(prev_layer)
return net
def __call__(self, net, scope=None, filter_scaling=1, stride=1,
prev_layer=None, cell_num=-1, current_step=None):
"""Runs the conv cell."""
self._cell_num = cell_num
self._filter_scaling = filter_scaling
self._filter_size = int(self._num_conv_filters * filter_scaling)
i = 0
with tf.variable_scope(scope):
net = self._cell_base(net, prev_layer)
for iteration in range(5):
with tf.variable_scope('comb_iter_{}'.format(iteration)):
left_hiddenstate_idx, right_hiddenstate_idx = (
self._hiddenstate_indices[i],
self._hiddenstate_indices[i + 1])
original_input_left = left_hiddenstate_idx < 2
original_input_right = right_hiddenstate_idx < 2
h1 = net[left_hiddenstate_idx]
h2 = net[right_hiddenstate_idx]
operation_left = self._operations[i]
operation_right = self._operations[i+1]
i += 2
# Apply conv operations
with tf.variable_scope('left'):
h1 = self._apply_conv_operation(h1, operation_left,
stride, original_input_left,
current_step)
with tf.variable_scope('right'):
h2 = self._apply_conv_operation(h2, operation_right,
stride, original_input_right,
current_step)
# Combine hidden states using 'add'.
with tf.variable_scope('combine'):
h = h1 + h2
if self._use_bounded_activation:
h = tf.nn.relu6(h)
# Add hiddenstate to the list of hiddenstates we can choose from
net.append(h)
with tf.variable_scope('cell_output'):
net = self._combine_unused_states(net)
return net
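
  # Illustrative walk-through: the normal cell below uses
  # hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0], consumed two at a
  # time over the five iterations above. Indices 0 and 1 are the cell inputs
  # (net and prev_layer); each iteration appends its combined output, so
  # index 2 would refer to the first combination's result, and so on.
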
def _apply_conv_operation(self, net, operation,
stride, is_from_original_input, current_step):
"""Applies the predicted conv operation to net."""
    # Don't stride if this is not one of the original hiddenstates.
if stride > 1 and not is_from_original_input:
stride = 1
input_filters = get_channel_dim(net.shape)
filter_size = self._filter_size
if 'separable' in operation:
net = _stacked_separable_conv(net, stride, operation, filter_size,
self._use_bounded_activation)
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
elif operation in ['none']:
if self._use_bounded_activation:
net = tf.nn.relu6(net)
# Check if a stride is needed, then use a strided 1x1 here
if stride > 1 or (input_filters != filter_size):
if not self._use_bounded_activation:
net = tf.nn.relu(net)
net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
elif 'pool' in operation:
net = _pooling(net, stride, operation, self._use_bounded_activation)
if input_filters != filter_size:
net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
net = slim.batch_norm(net, scope='bn_1')
if self._use_bounded_activation:
net = tf.clip_by_value(net, -CLIP_BY_VALUE_CAP, CLIP_BY_VALUE_CAP)
else:
raise ValueError('Unimplemented operation', operation)
if operation != 'none':
net = self._apply_drop_path(net, current_step=current_step)
return net
def _combine_unused_states(self, net):
"""Concatenate the unused hidden states of the cell."""
used_hiddenstates = self._used_hiddenstates
final_height = int(net[-1].shape[2])
final_num_filters = get_channel_dim(net[-1].shape)
assert len(used_hiddenstates) == len(net)
for idx, used_h in enumerate(used_hiddenstates):
curr_height = int(net[idx].shape[2])
curr_num_filters = get_channel_dim(net[idx].shape)
      # Determine if a reduction should be applied to make the number of
      # filters and the spatial dimensions match.
should_reduce = final_num_filters != curr_num_filters
should_reduce = (final_height != curr_height) or should_reduce
should_reduce = should_reduce and not used_h
if should_reduce:
stride = 2 if final_height != curr_height else 1
with tf.variable_scope('reduction_{}'.format(idx)):
net[idx] = factorized_reduction(
net[idx], final_num_filters, stride)
states_to_combine = (
[h for h, is_used in zip(net, used_hiddenstates) if not is_used])
# Return the concat of all the states
concat_axis = get_channel_index()
net = tf.concat(values=states_to_combine, axis=concat_axis)
return net
@slim.add_arg_scope # No public API. For internal use only.
def _apply_drop_path(self, net, current_step=None,
use_summaries=False, drop_connect_version='v3'):
"""Apply drop_path regularization.
Args:
net: the Tensor that gets drop_path regularization applied.
current_step: a float32 Tensor with the current global_step value,
to be divided by hparams.total_training_steps. Usually None, which
defaults to tf.train.get_or_create_global_step() properly casted.
use_summaries: a Python boolean. If set to False, no summaries are output.
drop_connect_version: one of 'v1', 'v2', 'v3', controlling whether
the dropout rate is scaled by current_step (v1), layer (v2), or
both (v3, the default).
Returns:
The dropped-out value of `net`.
"""
drop_path_keep_prob = self._drop_path_keep_prob
if drop_path_keep_prob < 1.0:
assert drop_connect_version in ['v1', 'v2', 'v3']
if drop_connect_version in ['v2', 'v3']:
# Scale keep prob by layer number
assert self._cell_num != -1
# The added 2 is for the reduction cells
num_cells = self._total_num_cells
layer_ratio = (self._cell_num + 1)/float(num_cells)
if use_summaries:
with tf.device('/cpu:0'):
tf.summary.scalar('layer_ratio', layer_ratio)
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
if drop_connect_version in ['v1', 'v3']:
# Decrease the keep probability over time
if current_step is None:
current_step = tf.train.get_or_create_global_step()
current_step = tf.cast(current_step, tf.float32)
drop_path_burn_in_steps = self._total_training_steps
current_ratio = current_step / drop_path_burn_in_steps
current_ratio = tf.minimum(1.0, current_ratio)
if use_summaries:
with tf.device('/cpu:0'):
tf.summary.scalar('current_ratio', current_ratio)
drop_path_keep_prob = (1 - current_ratio * (1 - drop_path_keep_prob))
if use_summaries:
with tf.device('/cpu:0'):
tf.summary.scalar('drop_path_keep_prob', drop_path_keep_prob)
net = drop_path(net, drop_path_keep_prob)
return net
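
  # Worked schedule example (illustrative, v3): with drop_path_keep_prob=0.6
  # and cell 5 of 12 (layer_ratio = 6/12 = 0.5), the layer scaling first gives
  # 1 - 0.5 * 0.4 = 0.8; halfway through training (current_ratio = 0.5) this
  # becomes 1 - 0.5 * 0.2 = 0.9, so earlier layers and earlier steps keep more.
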
class NasNetANormalCell(NasNetABaseCell):
"""NASNetA Normal Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps, use_bounded_activation=False):
operations = ['separable_5x5_2',
'separable_3x3_2',
'separable_5x5_2',
'separable_3x3_2',
'avg_pool_3x3',
'none',
'avg_pool_3x3',
'avg_pool_3x3',
'separable_3x3_2',
'none']
used_hiddenstates = [1, 0, 0, 0, 0, 0, 0]
hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0]
super(NasNetANormalCell, self).__init__(num_conv_filters, operations,
used_hiddenstates,
hiddenstate_indices,
drop_path_keep_prob,
total_num_cells,
total_training_steps,
use_bounded_activation)
class NasNetAReductionCell(NasNetABaseCell):
"""NASNetA Reduction Cell."""
def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
total_training_steps, use_bounded_activation=False):
operations = ['separable_5x5_2',
'separable_7x7_2',
'max_pool_3x3',
'separable_7x7_2',
'avg_pool_3x3',
'separable_5x5_2',
'none',
'avg_pool_3x3',
'separable_3x3_2',
'max_pool_3x3']
used_hiddenstates = [1, 1, 1, 0, 0, 0, 0]
hiddenstate_indices = [0, 1, 0, 1, 0, 1, 3, 2, 2, 0]
super(NasNetAReductionCell, self).__init__(num_conv_filters, operations,
used_hiddenstates,
hiddenstate_indices,
drop_path_keep_prob,
total_num_cells,
total_training_steps,
use_bounded_activation)
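
# Usage sketch (illustrative; the hyperparameter values and the `images`
# tensor are assumptions for this example, not the published NASNet-A
# settings). The data_format consumed by the helpers above is typically
# supplied through an arg scope, since they are decorated with
# slim.add_arg_scope:
#
#   with slim.arg_scope(
#       [get_channel_index, get_channel_dim, global_avg_pool,
#        factorized_reduction],
#       data_format=DATA_FORMAT_NHWC):
#     cell = NasNetANormalCell(
#         num_conv_filters=44, drop_path_keep_prob=0.7,
#         total_num_cells=12, total_training_steps=250000)
#     net = cell(images, scope='cell_0', filter_scaling=1.0, stride=1,
#                prev_layer=None, cell_num=0)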