tensorflow/models

View on GitHub
research/slim/nets/nasnet/pnasnet.py

Summary

Maintainability
F
6 days
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the PNASNet classification networks.

Paper: https://arxiv.org/abs/1712.00559
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.contrib import training as contrib_training

from nets.nasnet import nasnet
from nets.nasnet import nasnet_utils

arg_scope = slim.arg_scope


def large_imagenet_config():
  """Returns the hyperparameters for the large PNASNet-5 ImageNet model.

  Returns:
    A newly constructed `contrib_training.HParams` populated with the
    large-model settings listed below.
  """
  settings = {
      'stem_multiplier': 3.0,
      'dense_dropout_keep_prob': 0.5,
      'num_cells': 12,
      'filter_scaling_rate': 2.0,
      'num_conv_filters': 216,
      'drop_path_keep_prob': 0.6,
      'use_aux_head': 1,
      'num_reduction_layers': 2,
      'data_format': 'NHWC',
      'skip_reduction_layer_input': 1,
      'total_training_steps': 250000,
      'use_bounded_activation': False,
  }
  return contrib_training.HParams(**settings)


def mobile_imagenet_config():
  """Returns the hyperparameters for the mobile PNASNet-5 ImageNet model.

  Returns:
    A newly constructed `contrib_training.HParams` populated with the
    mobile-model settings listed below.
  """
  settings = {
      'stem_multiplier': 1.0,
      'dense_dropout_keep_prob': 0.5,
      'num_cells': 9,
      'filter_scaling_rate': 2.0,
      'num_conv_filters': 54,
      'drop_path_keep_prob': 1.0,
      'use_aux_head': 1,
      'num_reduction_layers': 2,
      'data_format': 'NHWC',
      'skip_reduction_layer_input': 1,
      'total_training_steps': 250000,
      'use_bounded_activation': False,
  }
  return contrib_training.HParams(**settings)


def pnasnet_large_arg_scope(weight_decay=4e-5, batch_norm_decay=0.9997,
                            batch_norm_epsilon=0.001):
  """Returns the default arg scope for the PNASNet Large ImageNet model.

  This is a thin wrapper: the three arguments are handed, in order and
  positionally, to `nasnet.nasnet_large_arg_scope`.

  Args:
    weight_decay: First positional argument given to the NASNet arg scope.
    batch_norm_decay: Second positional argument given to the NASNet arg scope.
    batch_norm_epsilon: Third positional argument given to the NASNet arg
      scope.

  Returns:
    Whatever `nasnet.nasnet_large_arg_scope` returns for these arguments.
  """
  forwarded = (weight_decay, batch_norm_decay, batch_norm_epsilon)
  return nasnet.nasnet_large_arg_scope(*forwarded)


def pnasnet_mobile_arg_scope(weight_decay=4e-5,
                             batch_norm_decay=0.9997,
                             batch_norm_epsilon=0.001):
  """Returns the default arg scope for the PNASNet Mobile ImageNet model.

  This is a thin wrapper: the three arguments are handed, in order and
  positionally, to `nasnet.nasnet_mobile_arg_scope`.

  Args:
    weight_decay: First positional argument given to the NASNet arg scope.
    batch_norm_decay: Second positional argument given to the NASNet arg scope.
    batch_norm_epsilon: Third positional argument given to the NASNet arg
      scope.

  Returns:
    Whatever `nasnet.nasnet_mobile_arg_scope` returns for these arguments.
  """
  forwarded = (weight_decay, batch_norm_decay, batch_norm_epsilon)
  return nasnet.nasnet_mobile_arg_scope(*forwarded)


def _build_pnasnet_base(images,
                        normal_cell,
                        num_classes,
                        hparams,
                        is_training,
                        final_endpoint=None):
  """Constructs a PNASNet image model.

  The network is built as: ImageNet stem -> `hparams.num_cells` invocations of
  `normal_cell` (with stride 2 at the reduction indices) -> activation ->
  global average pool -> dropout -> fully connected logits -> softmax. Every
  stage is recorded in `end_points`; construction stops as soon as the stage
  named by `final_endpoint` has been built.

  Args:
    images: Input tensor handed to `nasnet._imagenet_stem`.
    normal_cell: Callable invoked as `normal_cell(net, scope=..., 
      filter_scaling=..., stride=..., prev_layer=..., cell_num=...)` for every
      cell; it is also passed to the stem builder.
    num_classes: Number of output classes. If falsy, no auxiliary head is
      built and the function returns right after global pooling.
    hparams: Hyperparameter object. The attributes read here are `num_cells`,
      `num_reduction_layers`, `use_bounded_activation`, `filter_scaling_rate`,
      `skip_reduction_layer_input`, `use_aux_head` and
      `dense_dropout_keep_prob`.
    is_training: Within this function it only gates construction of the
      auxiliary head.
    final_endpoint: Optional name of the last endpoint to build. The names
      produced here are 'Stem', 'Cell_<i>', 'global_pool', 'Logits' and
      'Predictions'. When None (or otherwise falsy) the full network is built.

  Returns:
    A `(net, end_points)` tuple. `end_points` maps endpoint names to tensors.
    When `final_endpoint` matches a stage, `net` is the tensor stored under
    that name. Otherwise `net` is the pooled feature tensor if `num_classes`
    is falsy, or the logits tensor for the complete network.
  """

  end_points = {}

  def add_and_check_endpoint(endpoint_name, net):
    """Stores `net` in `end_points`; truthy iff this is the final endpoint."""
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)

  # pylint: disable=protected-access
  net, cell_outputs = nasnet._imagenet_stem(images, hparams, normal_cell)
  # pylint: enable=protected-access
  if add_and_check_endpoint('Stem', net):
    return net, end_points

  # Setup for building in the auxiliary head: it hangs off the cell just
  # before the second reduction, when there is one.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2
  activation_fn = tf.nn.relu6 if hparams.use_bounded_activation else tf.nn.relu
  for cell_num in range(hparams.num_cells):
    is_reduction = cell_num in reduction_indices
    stride = 2 if is_reduction else 1
    if is_reduction:
      filter_scaling *= hparams.filter_scaling_rate
    # When this branch is skipped, `prev_layer` deliberately keeps the value
    # it had on the previous iteration.
    if hparams.skip_reduction_layer_input or not is_reduction:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)
    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1
    cell_outputs.append(net)

    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = activation_fn(net)
      # pylint: disable=protected-access
      nasnet._build_aux_head(aux_net, end_points, num_classes, hparams,
                             scope='aux_{}'.format(cell_num))
      # pylint: enable=protected-access

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = activation_fn(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or not num_classes:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    # Return the requested endpoint itself, as every earlier exit does,
    # rather than the pre-logit features.
    if add_and_check_endpoint('Logits', logits):
      return logits, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return predictions, end_points
  return logits, end_points


def build_pnasnet_large(images,
                        num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None):
  """Builds the PNASNet Large model for the ImageNet dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the configured data format is 'NCHW', so it is presumably supplied in
      NHWC layout -- confirm against callers.
    num_classes: Forwarded to `_build_pnasnet_base`.
    is_training: Forwarded to `nasnet._update_hparams`, to the arg scope of
      dropout / drop_path / batch_norm, and to `_build_pnasnet_base`.
    final_endpoint: Forwarded to `_build_pnasnet_base`.
    config: Optional hyperparameters. A truthy value is deep-copied and used;
      otherwise `large_imagenet_config()` supplies the defaults.

  Returns:
    The `(net, end_points)` pair produced by `_build_pnasnet_base`.
  """
  # Copy a caller-supplied config before handing it to `_update_hparams`.
  if config:
    hparams = copy.deepcopy(config)
  else:
    hparams = large_imagenet_config()
  nasnet._update_hparams(hparams, is_training)  # pylint: disable=protected-access

  data_format = hparams.data_format

  if tf.test.is_gpu_available():
    if data_format == 'NHWC':
      tf.logging.info(
          'A GPU is available on the machine, consider using NCHW '
          'data format for increased speed on GPU.')

  if data_format == 'NCHW':
    images = tf.transpose(a=images, perm=[0, 3, 1, 2])

  # Calculate the total number of cells in the network.
  # There is no distinction between reduction and normal cells in PNAS so the
  # total number of cells is equal to the number normal cells plus the number
  # of stem cells (two by default).
  num_stem_cells = 2
  total_num_cells = hparams.num_cells + num_stem_cells

  normal_cell = PNasNetNormalCell(
      num_conv_filters=hparams.num_conv_filters,
      drop_path_keep_prob=hparams.drop_path_keep_prob,
      total_num_cells=total_num_cells,
      total_training_steps=hparams.total_training_steps,
      use_bounded_activation=hparams.use_bounded_activation)

  # Ops that receive the `is_training` flag.
  training_scoped_ops = [
      slim.dropout,
      nasnet_utils.drop_path,
      slim.batch_norm,
  ]
  # Ops that receive the configured data format.
  format_scoped_ops = [
      slim.avg_pool2d,
      slim.max_pool2d,
      slim.conv2d,
      slim.batch_norm,
      slim.separable_conv2d,
      nasnet_utils.factorized_reduction,
      nasnet_utils.global_avg_pool,
      nasnet_utils.get_channel_index,
      nasnet_utils.get_channel_dim,
  ]
  with arg_scope(training_scoped_ops, is_training=is_training):
    with arg_scope(format_scoped_ops, data_format=data_format):
      return _build_pnasnet_base(
          images,
          normal_cell=normal_cell,
          num_classes=num_classes,
          hparams=hparams,
          is_training=is_training,
          final_endpoint=final_endpoint)
build_pnasnet_large.default_image_size = 331


def build_pnasnet_mobile(images,
                         num_classes,
                         is_training=True,
                         final_endpoint=None,
                         config=None):
  """Builds the PNASNet Mobile model for the ImageNet dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the configured data format is 'NCHW', so it is presumably supplied in
      NHWC layout -- confirm against callers.
    num_classes: Forwarded to `_build_pnasnet_base`.
    is_training: Forwarded to `nasnet._update_hparams`, to the arg scope of
      dropout / drop_path / batch_norm, and to `_build_pnasnet_base`.
    final_endpoint: Forwarded to `_build_pnasnet_base`.
    config: Optional hyperparameters. A truthy value is deep-copied and used;
      otherwise `mobile_imagenet_config()` supplies the defaults.

  Returns:
    The `(net, end_points)` pair produced by `_build_pnasnet_base`.
  """
  # Copy a caller-supplied config before handing it to `_update_hparams`.
  if config:
    hparams = copy.deepcopy(config)
  else:
    hparams = mobile_imagenet_config()
  nasnet._update_hparams(hparams, is_training)  # pylint: disable=protected-access

  data_format = hparams.data_format

  if tf.test.is_gpu_available():
    if data_format == 'NHWC':
      tf.logging.info(
          'A GPU is available on the machine, consider using NCHW '
          'data format for increased speed on GPU.')

  if data_format == 'NCHW':
    images = tf.transpose(a=images, perm=[0, 3, 1, 2])

  # Calculate the total number of cells in the network.
  # There is no distinction between reduction and normal cells in PNAS so the
  # total number of cells is equal to the number normal cells plus the number
  # of stem cells (two by default).
  num_stem_cells = 2
  total_num_cells = hparams.num_cells + num_stem_cells

  normal_cell = PNasNetNormalCell(
      num_conv_filters=hparams.num_conv_filters,
      drop_path_keep_prob=hparams.drop_path_keep_prob,
      total_num_cells=total_num_cells,
      total_training_steps=hparams.total_training_steps,
      use_bounded_activation=hparams.use_bounded_activation)

  # Ops that receive the `is_training` flag.
  training_scoped_ops = [
      slim.dropout,
      nasnet_utils.drop_path,
      slim.batch_norm,
  ]
  # Ops that receive the configured data format.
  format_scoped_ops = [
      slim.avg_pool2d,
      slim.max_pool2d,
      slim.conv2d,
      slim.batch_norm,
      slim.separable_conv2d,
      nasnet_utils.factorized_reduction,
      nasnet_utils.global_avg_pool,
      nasnet_utils.get_channel_index,
      nasnet_utils.get_channel_dim,
  ]
  with arg_scope(training_scoped_ops, is_training=is_training):
    with arg_scope(format_scoped_ops, data_format=data_format):
      return _build_pnasnet_base(
          images,
          normal_cell=normal_cell,
          num_classes=num_classes,
          hparams=hparams,
          is_training=is_training,
          final_endpoint=final_endpoint)


build_pnasnet_mobile.default_image_size = 224


class PNasNetNormalCell(nasnet_utils.NasNetABaseCell):
  """PNASNet Normal Cell.

  Hard-codes the PNASNet-5 operation list and hidden-state wiring, and leaves
  everything else to `nasnet_utils.NasNetABaseCell`.
  """

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps, use_bounded_activation=False):
    """Initializes the cell with the fixed PNASNet-5 configuration.

    All arguments are passed through unchanged to the base-class constructor,
    together with the operation and hidden-state tables defined below.

    Args:
      num_conv_filters: Passed through to the base cell.
      drop_path_keep_prob: Passed through to the base cell.
      total_num_cells: Passed through to the base cell.
      total_training_steps: Passed through to the base cell.
      use_bounded_activation: Passed through to the base cell.
    """
    # Configuration for the PNASNet-5 model.
    # NOTE(review): the two tables below are laid out two entries per row on
    # the assumption that the base cell consumes them pairwise -- confirm in
    # nasnet_utils.NasNetABaseCell.
    operations = [
        'separable_5x5_2', 'max_pool_3x3',
        'separable_7x7_2', 'max_pool_3x3',
        'separable_5x5_2', 'separable_3x3_2',
        'separable_3x3_2', 'max_pool_3x3',
        'separable_3x3_2', 'none',
    ]
    hiddenstate_indices = [
        1, 1,
        0, 0,
        0, 0,
        4, 0,
        1, 0,
    ]
    used_hiddenstates = [1, 1, 0, 0, 0, 0, 0]

    base_args = (
        num_conv_filters,
        operations,
        used_hiddenstates,
        hiddenstate_indices,
        drop_path_keep_prob,
        total_num_cells,
        total_training_steps,
        use_bounded_activation,
    )
    super(PNasNetNormalCell, self).__init__(*base_args)