# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from math import log

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow.compat.v1 as tf
import tf_slim as slim


def _validate_image_inputs(inputs):
  inputs.get_shape().assert_has_rank(4)
  inputs.get_shape()[1:3].assert_is_fully_defined()
  if inputs.get_shape()[1] != inputs.get_shape()[2]:
    raise ValueError('Input tensor does not have equal width and height: %s' %
                     inputs.get_shape()[1:3])
  width = inputs.get_shape().as_list()[1]
  if log(width, 2) != int(log(width, 2)):
    raise ValueError('Input tensor `width` is not a power of 2: %i' % width)
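
# For example (an illustrative sketch, not part of the original checks):
#
#   _validate_image_inputs(tf.zeros([8, 32, 32, 3]))  # OK: square, 32 == 2**5
#   _validate_image_inputs(tf.zeros([8, 32, 48, 3]))  # raises: not square
#   _validate_image_inputs(tf.zeros([8, 48, 48, 3]))  # raises: 48 != 2**n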


# TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
# setups need the gradient of gradient FusedBatchNormGrad.
def discriminator(inputs,
                  depth=64,
                  is_training=True,
                  reuse=None,
                  scope='Discriminator',
                  fused_batch_norm=False):
  """Discriminator network for DCGAN.

  Construct discriminator network from inputs to the final endpoint.

  Args:
    inputs: A tensor of size [batch_size, height, width, channels]. Must be
      floating point.
    depth: Number of channels in first convolution layer.
    is_training: Whether the network is for training or not.
    reuse: Whether or not the network variables should be reused. `scope`
      must be given to be reused.
    scope: Optional variable_scope.
    fused_batch_norm: If `True`, use a faster, fused implementation of
      batch norm.

  Returns:
    logits: The pre-sigmoid activations, a tensor of size [batch_size, 1]
    end_points: a dictionary from components of the network to their activation.

  Raises:
    ValueError: If the input image shape is not 4-dimensional, if the spatial
      dimensions aren't defined at graph construction time, if the spatial
      dimensions aren't square, or if the spatial dimensions aren't a power of
      two.
  """

  normalizer_fn = slim.batch_norm
  normalizer_fn_args = {
      'is_training': is_training,
      'zero_debias_moving_mean': True,
      'fused': fused_batch_norm,
  }

  _validate_image_inputs(inputs)
  inp_shape = inputs.get_shape().as_list()[1]

  end_points = {}
  with tf.variable_scope(
      scope, values=[inputs], reuse=reuse) as scope:
    with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
      with slim.arg_scope([slim.conv2d],
                          stride=2,
                          kernel_size=4,
                          activation_fn=tf.nn.leaky_relu):
        net = inputs
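        # Each stride-2, 4x4 conv below halves the spatial size, so after
        # log2(width) layers the feature map is 1x1; the 1x1 conv at the
        # end then yields a single logit per example.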
        for i in xrange(int(log(inp_shape, 2))):
          scope = 'conv%i' % (i + 1)
          current_depth = depth * 2**i
          normalizer_fn_ = None if i == 0 else normalizer_fn
          net = slim.conv2d(
              net, current_depth, normalizer_fn=normalizer_fn_, scope=scope)
          end_points[scope] = net

        logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID',
                             normalizer_fn=None, activation_fn=None)
        logits = tf.reshape(logits, [-1, 1])
        end_points['logits'] = logits

        return logits, end_points
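
# Example usage (a minimal sketch; the batch and image sizes below are
# illustrative, not mandated by the module):
#
#   images = tf.random.uniform([16, 32, 32, 3])
#   logits, end_points = discriminator(images)
#   # logits: [16, 1]; end_points maps 'conv1'..'conv5' and 'logits'
#   # to the corresponding activations.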


# TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
# setups need the gradient of gradient FusedBatchNormGrad.
def generator(inputs,
              depth=64,
              final_size=32,
              num_outputs=3,
              is_training=True,
              reuse=None,
              scope='Generator',
              fused_batch_norm=False):
  """Generator network for DCGAN.

  Construct generator network from inputs to the final endpoint.

  Args:
    inputs: A 2-D tensor of shape [batch_size, N], for any N.
    depth: Number of channels in last deconvolution layer.
    final_size: The spatial height and width of the final output. Must be a
      power of 2 and at least 8.
    num_outputs: Number of output features. For images, this is the number of
      channels.
    is_training: Whether the network is for training or not.
    reuse: Whether or not the network variables should be reused. `scope`
      must be given to be reused.
    scope: Optional variable_scope.
    fused_batch_norm: If `True`, use a faster, fused implementation of
      batch norm.

  Returns:
    logits: the unnormalized network outputs, a tensor of size
      [batch_size, final_size, final_size, num_outputs]
    end_points: a dictionary from components of the network to their activation.

  Raises:
    ValueError: If `inputs` is not 2-dimensional.
    ValueError: If `final_size` isn't a power of 2 or is less than 8.
  """
  normalizer_fn = slim.batch_norm
  normalizer_fn_args = {
      'is_training': is_training,
      'zero_debias_moving_mean': True,
      'fused': fused_batch_norm,
  }

  inputs.get_shape().assert_has_rank(2)
  if log(final_size, 2) != int(log(final_size, 2)):
    raise ValueError('`final_size` (%i) must be a power of 2.' % final_size)
  if final_size < 8:
    raise ValueError('`final_size` (%i) must be at least 8.' % final_size)

  end_points = {}
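  # The first deconv maps the 1x1 input to 4x4, and each of the remaining
  # num_layers - 1 stride-2 deconvs doubles the spatial size, so
  # 4 * 2**(num_layers - 1) == final_size.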
  num_layers = int(log(final_size, 2)) - 1
  with tf.variable_scope(
      scope, values=[inputs], reuse=reuse) as scope:
    with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
      with slim.arg_scope([slim.conv2d_transpose],
                          normalizer_fn=normalizer_fn,
                          stride=2,
                          kernel_size=4):
        net = tf.expand_dims(tf.expand_dims(inputs, 1), 1)
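        # `net` is now [batch_size, 1, 1, N]: a 1x1 "image" that the
        # deconvolutions below upsample to final_size x final_size.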

        # First upscaling is different because it takes the input vector.
        current_depth = depth * 2 ** (num_layers - 1)
        scope = 'deconv1'
        net = slim.conv2d_transpose(
            net, current_depth, stride=1, padding='VALID', scope=scope)
        end_points[scope] = net

        for i in xrange(2, num_layers):
          scope = 'deconv%i' % (i)
          current_depth = depth * 2 ** (num_layers - i)
          net = slim.conv2d_transpose(net, current_depth, scope=scope)
          end_points[scope] = net

        # Last layer has different normalizer and activation.
        scope = 'deconv%i' % (num_layers)
        net = slim.conv2d_transpose(
            net, depth, normalizer_fn=None, activation_fn=None, scope=scope)
        end_points[scope] = net

        # Convert to proper channels.
        scope = 'logits'
        logits = slim.conv2d(
            net,
            num_outputs,
            normalizer_fn=None,
            activation_fn=None,
            kernel_size=1,
            stride=1,
            padding='VALID',
            scope=scope)
        end_points[scope] = logits

        logits.get_shape().assert_has_rank(4)
        logits.get_shape().assert_is_compatible_with(
            [None, final_size, final_size, num_outputs])

        return logits, end_points
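
# Example usage (a minimal sketch; the noise dimension and batch size are
# illustrative, not mandated by the module):
#
#   noise = tf.random.normal([16, 64])
#   fake_images, _ = generator(noise, final_size=32)
#   fake_logits, _ = discriminator(fake_images)
#   # fake_images: [16, 32, 32, 3]; fake_logits: [16, 1].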