tensorflow/models

View on GitHub
research/autoaugment/wrn.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Builds the Wide-ResNet Model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import custom_ops as ops
import numpy as np
import tensorflow as tf



def residual_block(
    x, in_filter, out_filter, stride, activate_before_residual=False):
  """Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv.

  Args:
    x: Tensor that is the output of the previous layer in the model.
    in_filter: Number of filters `x` has.
    out_filter: Number of filters that the output of this layer will have.
    stride: Integer that specified what stride should be applied to `x`.
    activate_before_residual: Boolean on whether a BN->ReLU should be applied
      to x before the convolution is applied.

  Returns:
    A Tensor that is the result of applying two sequences of BN->ReLU->3x3 Conv
    and then adding that Tensor to `x`.
  """

  if activate_before_residual:  # Pass up RELU and BN activation for resnet
    with tf.variable_scope('shared_activation'):
      x = ops.batch_norm(x, scope='init_bn')
      x = tf.nn.relu(x)
      orig_x = x
  else:
    orig_x = x

  block_x = x
  if not activate_before_residual:
    with tf.variable_scope('residual_only_activation'):
      block_x = ops.batch_norm(block_x, scope='init_bn')
      block_x = tf.nn.relu(block_x)

  with tf.variable_scope('sub1'):
    block_x = ops.conv2d(
        block_x, out_filter, 3, stride=stride, scope='conv1')

  with tf.variable_scope('sub2'):
    block_x = ops.batch_norm(block_x, scope='bn2')
    block_x = tf.nn.relu(block_x)
    block_x = ops.conv2d(
        block_x, out_filter, 3, stride=1, scope='conv2')

  with tf.variable_scope(
      'sub_add'):  # If number of filters do not agree then zero pad them
    if in_filter != out_filter:
      orig_x = ops.avg_pool(orig_x, stride, stride)
      orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
  x = orig_x + block_x
  return x


def _res_add(in_filter, out_filter, stride, x, orig_x):
  """Adds `x` with `orig_x`, both of which are layers in the model.

  Args:
    in_filter: Number of filters in `orig_x`.
    out_filter: Number of filters in `x`.
    stride: Integer specifying the stide that should be applied `orig_x`.
    x: Tensor that is the output of the previous layer.
    orig_x: Tensor that is the output of an earlier layer in the network.

  Returns:
    A Tensor that is the result of `x` and `orig_x` being added after
    zero padding and striding are applied to `orig_x` to get the shapes
    to match.
  """
  if in_filter != out_filter:
    orig_x = ops.avg_pool(orig_x, stride, stride)
    orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
  x = x + orig_x
  orig_x = x
  return x, orig_x


def build_wrn_model(images, num_classes, wrn_size):
  """Builds the WRN model.

  Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.

  Args:
    images: Tensor of images that will be fed into the Wide ResNet Model.
    num_classes: Number of classed that the model needs to predict.
    wrn_size: Parameter that scales the number of filters in the Wide ResNet
      model.

  Returns:
    The logits of the Wide ResNet model.
  """
  kernel_size = wrn_size
  filter_size = 3
  num_blocks_per_resnet = 4
  filters = [
      min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
  ]
  strides = [1, 2, 2]  # stride for each resblock

  # Run the first conv
  with tf.variable_scope('init'):
    x = images
    output_filters = filters[0]
    x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')

  first_x = x  # Res from the beginning
  orig_x = x  # Res from previous block

  for block_num in range(1, 4):
    with tf.variable_scope('unit_{}_0'.format(block_num)):
      activate_before_residual = True if block_num == 1 else False
      x = residual_block(
          x,
          filters[block_num - 1],
          filters[block_num],
          strides[block_num - 1],
          activate_before_residual=activate_before_residual)
    for i in range(1, num_blocks_per_resnet):
      with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
        x = residual_block(
            x,
            filters[block_num],
            filters[block_num],
            1,
            activate_before_residual=False)
    x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
                         strides[block_num - 1], x, orig_x)
  final_stride_val = np.prod(strides)
  x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)
  with tf.variable_scope('unit_last'):
    x = ops.batch_norm(x, scope='final_bn')
    x = tf.nn.relu(x)
    x = ops.global_avg_pool(x)
    logits = ops.fc(x, num_classes)
  return logits