tensorflow/tensorflow

View on GitHub
tensorflow/python/keras/engine/training_utils_v1.py

Summary

Maintainability
F
1 wk
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related utilities."""

import abc
import atexit
import collections
import functools
import multiprocessing.pool
import threading
import time

import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import data as data_types
from tensorflow.python.util import nest


def is_composite_or_composite_value(tensor):
  """Checks whether `tensor` is a composite tensor or a composite value.

  Args:
    tensor: Any object.

  Returns:
    True if `tensor` is a `CompositeTensor`, a `SparseTensorValue`, or a
    `RaggedTensorValue`; False otherwise.
  """
  # TODO(b/125094323): Collapse this into isinstance(CompositeTensor) /
  # isinstance(CompositeTensorValue) once a common value base class exists.
  composite_like_types = (
      composite_tensor.CompositeTensor,
      sparse_tensor.SparseTensorValue,
      ragged_tensor_value.RaggedTensorValue,
  )
  return isinstance(tensor, composite_like_types)


class Aggregator(object, metaclass=abc.ABCMeta):
  """Abstract base for objects that accumulate the per-batch outputs of a loop.

  Subclasses implement `create`, `aggregate` and `finalize`; the accumulated
  value is kept in `self.results`.

  Attributes:
    use_steps: Whether the loop is driven by `steps` (True) or by
      `batch_size` (False).
    num_samples: Total number of samples, i.e. `batch_size * num_batches`.
    steps: Total number of steps.
    batch_size: Batch size; used to validate model outputs against inputs.
    results: The value handed back once aggregation is complete.
  """

  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
    self.use_steps, self.num_samples = use_steps, num_samples
    self.steps, self.batch_size = steps, batch_size
    self.results = []

  @abc.abstractmethod
  def create(self, batch_outs):
    """Builds the initial `results` from the outputs of the first batch.

    Args:
      batch_outs: The list of outputs produced for the first batch.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Folds one batch of outputs into the running `results`.

    Args:
      batch_outs: The list of outputs produced for this batch.
      batch_start: Index of the first sample of this batch; always `None`
        when `use_steps` is True.
      batch_end: Index one past the last sample of this batch; always `None`
        when `use_steps` is True.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def finalize(self):
    """Turns the accumulated state into the value exposed as `results`."""
    raise NotImplementedError('Must be implemented in subclasses.')


class MetricsAggregator(Aggregator):
  """Accumulates the loss and metric values reported by each batch.

  The loss (the first batch output) is summed over batches and divided in
  `finalize` by `num_samples` (or `steps` when `num_samples` is falsy). When
  the loop is not step-based, each batch loss is first weighted by the
  number of samples in the batch. Metrics are stateful, so only the most
  recent values are kept.

  Attributes:
    use_steps: Whether the loop is driven by `steps` rather than `batch_size`.
    num_samples: Total number of samples: `batch_size*num_batches`.
    steps: Total number of steps, i.e. the number of iterations needed to
      cover every sample of the dataset.
  """

  def __init__(self, use_steps, num_samples=None, steps=None):
    super(MetricsAggregator, self).__init__(
        use_steps=use_steps, num_samples=num_samples, steps=steps,
        batch_size=None)

  def create(self, batch_outs):
    # One zero-initialized slot per output (loss first, then metrics).
    self.results = len(batch_outs) * [0.]

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    batch_loss = batch_outs[0]
    if not self.use_steps:
      # Weight by batch size so that `finalize` yields a per-sample mean.
      batch_loss = batch_loss * (batch_end - batch_start)
    self.results[0] += batch_loss
    # Metrics are stateful: the latest batch already holds running values.
    self.results[1:] = batch_outs[1:]

  def finalize(self):
    if not self.results:
      raise ValueError('Empty training data.')
    denominator = self.num_samples or self.steps
    self.results[0] /= denominator


def _append_sparse_tensor_value(target, to_append):
  """Concatenates two `SparseTensorValue`s along the batch (0th) axis.

  The entries of `to_append` are placed after the rows of `target`: the 0th
  coordinate of each of its indices is shifted by `target.dense_shape[0]`.
  Neither input is modified.

  Args:
    target: The `SparseTensorValue` that comes first along axis 0.
    to_append: The `SparseTensorValue` appended after `target`.

  Returns:
    A new `SparseTensorValue` whose dense shape equals `target.dense_shape`
    with the 0th dimension replaced by the sum of both inputs' 0th
    dimensions.

  Raises:
    RuntimeError: If the inputs differ in rank or in any non-batch dimension.
  """
  # Make sure the sparse tensors are of the same size (except for the 0th dim).
  if len(target.dense_shape) != len(to_append.dense_shape):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'have the same number of dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape, to_append.dense_shape))

  # Compare as tuples: `dense_shape` is frequently a NumPy array, for which
  # `!=` is elementwise and makes the `if` ambiguous once rank >= 3.
  if tuple(target.dense_shape[1:]) != tuple(to_append.dense_shape[1:]):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'match inner dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))

  base_dim0_value = target.dense_shape[0]

  # Shift a *copy* of the appended indices (e.g. [0, 0, 1] appended to a
  # tensor with 0th dim 3 becomes [3, 0, 1]). Updating the rows of
  # `to_append.indices` in place would corrupt the caller's value.
  shifted_indices = np.array(to_append.indices, copy=True)
  if shifted_indices.size:
    shifted_indices[:, 0] += base_dim0_value
    new_indices = np.concatenate((target.indices, shifted_indices), axis=0)
  else:
    new_indices = target.indices

  # Values follow the same order as the indices above.
  new_values = np.concatenate((target.values, to_append.values), axis=0)

  # Concatenating along axis 0 adds the 0th dimensions. Inferring it from the
  # largest index would drop trailing all-empty rows of `to_append`.
  new_dense_shape = list(target.dense_shape)
  new_dense_shape[0] = base_dim0_value + to_append.dense_shape[0]
  new_dense_shape = tuple(new_dense_shape)

  return sparse_tensor.SparseTensorValue(
      indices=new_indices, values=new_values, dense_shape=new_dense_shape)


def _append_ragged_tensor_value(target, to_append):
  """Concatenates two `RaggedTensorValue`s along the outermost axis.

  Nested ragged values are merged recursively; flat values are joined with
  `np.concatenate`.

  Args:
    target: The `RaggedTensorValue` whose rows come first.
    to_append: The `RaggedTensorValue` whose rows are appended.

  Returns:
    A new `RaggedTensorValue` holding the rows of `target` followed by the
    rows of `to_append`.

  Raises:
    RuntimeError: If the values differ in rank or in any non-outer dimension.
  """
  # Everything except the 0th dimension has to agree.
  if (len(target.shape) != len(to_append.shape) or
      target.shape[1:] != to_append.shape[1:]):
    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))

  # Offset the appended splits by the number of values already in `target`
  # and drop their leading split point so both vectors join seamlessly.
  values_offset = target.row_splits[-1]
  shifted_splits = to_append.row_splits[1:] + values_offset
  merged_splits = np.append(target.row_splits, shifted_splits)

  head_values, tail_values = target.values, to_append.values
  if isinstance(head_values, ragged_tensor_value.RaggedTensorValue):
    merged_values = _append_ragged_tensor_value(head_values, tail_values)
  else:
    merged_values = np.concatenate((head_values, tail_values), axis=0)

  return ragged_tensor_value.RaggedTensorValue(merged_values, merged_splits)


def _append_composite_tensor(target, to_append):
  """Concatenates two composite tensors (or their value objects) on axis 0.

  Batching inside fit/evaluate/predict requires aggregating CompositeTensors,
  which the CT API does not make easy — particularly in V1 mode, where the
  objects are CompositeTensor Value objects disconnected from the tensors
  that produced them.

  Args:
    target: CompositeTensor or CompositeTensor value object that will be
      appended to.
    to_append: CompositeTensor or CompositeTensor value object appended to
      `target`; must be of exactly the same type.

  Returns:
    A CompositeTensor or CompositeTensor value object.

  Raises:
    RuntimeError: If the two objects have different types, or the type is
      not supported.
  """
  if type(target) is not type(to_append):
    raise RuntimeError('Unable to concatenate %s and %s' %
                       (type(target), type(to_append)))

  # TODO(b/125094323): Replace with a single target.append() call that works
  # for every class below.
  #
  # A CompositeTensor here means we are in eager mode (otherwise it would
  # already have been evaluated to a Value object), so TF concat ops can be
  # used directly. Value objects live outside the graph and are merged by
  # hand. SparseTensors need `sparse_concat`; `tf.concat` rejects them.
  concat_dispatch = (
      (sparse_tensor.SparseTensor,
       lambda: sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)),
      (ragged_tensor.RaggedTensor,
       lambda: array_ops.concat([target, to_append], axis=0)),
      (sparse_tensor.SparseTensorValue,
       lambda: _append_sparse_tensor_value(target, to_append)),
      (ragged_tensor_value.RaggedTensorValue,
       lambda: _append_ragged_tensor_value(target, to_append)),
  )
  for handled_type, concat_fn in concat_dispatch:
    if isinstance(target, handled_type):
      return concat_fn()

  raise RuntimeError('Attempted to concatenate unsupported object %s.' %
                     type(target))


class ConcatAggregator(Aggregator):
  """Collects batch outputs and concatenates them once the loop ends.

  Used for tensor-likes that cannot be merged on the fly. Aggregates a single
  tensor-like (NumPy array or composite tensor/value), not a nested
  structure.
  """

  def __init__(self, batch_size):
    self.composite = None
    super(ConcatAggregator, self).__init__(
        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)

  def create(self, batch_element):
    # Remember whether `finalize` needs composite-aware concatenation.
    self.composite = is_composite_or_composite_value(batch_element)

  def aggregate(self, batch_element, batch_start=None, batch_end=None):
    # TODO(psv): Add num_samples check here to detect when output batch
    # #samples is < batch size and != input batch #samples.
    if self.batch_size and self.batch_size < batch_element.shape[0]:
      expected_shape = (self.batch_size,) + batch_element.shape[1:]
      raise ValueError(
          'Mismatch between expected batch size and model output batch size. '
          'Output shape = {}, expected output shape = shape {}'.format(
              batch_element.shape, expected_shape))
    self.results.append(batch_element)

  def finalize(self):
    if len(self.results) == 1:
      # Single-batch inference: reuse the only element and skip a copy.
      self.results = self.results[0]
      return

    if not self.composite:
      self.results = np.concatenate(self.results, axis=0)
      return

    # TODO(taylorrobie): efficiently concatenate.
    merged = self.results[0]
    for piece in self.results[1:]:
      merged = _append_composite_tensor(merged, piece)
    self.results = merged


_COPY_THREADS = 4
_COPY_POOL = None


def get_copy_pool():
  """Returns the process-wide thread pool used for array copies.

  Creating a pool costs about 2ms, so one lazily created pool is shared by
  every SliceAggregator instead of building a pool per aggregator. The pool
  is closed at interpreter exit.

  Returns:
    The global `multiprocessing.pool.ThreadPool` with `_COPY_THREADS` workers.
  """
  global _COPY_POOL
  if _COPY_POOL is not None:
    return _COPY_POOL
  _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
  atexit.register(_COPY_POOL.close)
  return _COPY_POOL


class SliceAggregator(Aggregator):
  """Writes batch outputs into a preallocated array of known total size.

  Aggregates a single tensor-like, not a nested structure. `create` allocates
  the full result array once; each batch is then assigned to its
  `[batch_start:batch_end]` slice.

  Large slice assignments are handed to the shared copy thread pool. NumPy
  does the heavy lifting of a copy in C without needing the GIL, and the
  writes need no lock: aggregation guarantees that either batches target
  disjoint indices or the aggregator raises in `finalize`. Since batches are
  split along the slowest-varying axis, each write is a contiguous block of
  memory, which keeps contention low.

  Handing work to a thread has scheduling and context-switch overhead, so
  batches below `_BINARY_SIZE_THRESHOLD` (2 ** 14) elements are copied on the
  calling thread. The best crossover varies by system, but timing is not very
  sensitive to its exact value.
  """

  _BINARY_SIZE_THRESHOLD = 2 ** 14
  _MAX_COPY_SECONDS = 300

  def __init__(self, num_samples, batch_size):
    self._async_copies = []  # One threading.Event per background copy.
    self._pool = get_copy_pool()
    self._errors = []  # Exceptions captured on worker threads.
    super(SliceAggregator, self).__init__(
        use_steps=False, num_samples=num_samples, steps=None,
        batch_size=batch_size)

  def create(self, batch_element):
    # Not pipelined: allocating an uninitialized NumPy array is effectively
    # instantaneous.
    self.results = np.empty(
        shape=(self.num_samples,) + batch_element.shape[1:],
        dtype=batch_element.dtype)

  def aggregate(self, batch_element, batch_start, batch_end):
    # Surface a failure from an earlier background copy as early as possible.
    if self._errors:
      raise self._errors[0]

    if batch_end - batch_start == self.num_samples:
      # The batch spans every sample: adopt it directly instead of copying.
      if self.num_samples != batch_element.shape[0]:
        raise ValueError(
            'Mismatch between expected batch size and model output batch size. '
            'Output shape = {}, expected output shape = shape {}'.format(
                batch_element.shape, self.results.shape))
      self.results = batch_element
      return

    # Approximate size heuristic; bytes per element are deliberately ignored.
    if np.prod(batch_element.shape) >= self._BINARY_SIZE_THRESHOLD:
      copy_done = threading.Event()
      self._pool.apply_async(
          self._slice_assign,
          args=(batch_element, batch_start, batch_end, copy_done))
      self._async_copies.append(copy_done)
    else:
      self.results[batch_start:batch_end] = batch_element

  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
    """Copies `batch_element` into `self.results[batch_start:batch_end]`.

    Meant to run on a pool thread. Exceptions raised there would not reach
    the main thread, so any exception is recorded in `self._errors` (to be
    re-raised by `aggregate` or `finalize`). `is_finished` is always set.
    """
    try:
      self.results[batch_start:batch_end] = batch_element
    except Exception as e:  # pylint: disable=broad-except
      self._errors.append(e)
    finally:
      is_finished.set()

  def finalize(self):
    # All pending copies share one overall budget of _MAX_COPY_SECONDS.
    start_time = time.time()
    for copy_done in self._async_copies:
      elapsed = time.time() - start_time
      remaining = max(0., self._MAX_COPY_SECONDS - elapsed)
      if not copy_done.wait(remaining):
        raise ValueError('Timed out waiting for copy to complete.')

    if self._errors:
      raise self._errors[0]


class OutputsAggregator(Aggregator):
  """Aggregates a (possibly nested) structure of per-batch model outputs.

  Every leaf of the output structure gets its own child aggregator, and
  composite tensors/values count as leaves. `finalize` packs the children's
  results back into the original structure.
  """

  _structure = None

  def create(self, batch_outs):
    # SparseTensorValue is a namedtuple that `nest` would flatten; stop the
    # traversal at composite tensors/values so they stay whole.
    self._structure = nest.get_traverse_shallow_structure(
        lambda x: not is_composite_or_composite_value(x), batch_outs)
    leaves = nest.flatten_up_to(self._structure, batch_outs)

    for leaf in leaves:
      self.results.append(self._make_leaf_aggregator(leaf))
      self.results[-1].create(leaf)

  def _make_leaf_aggregator(self, leaf):
    """Returns a new child aggregator suited to one flattened output."""
    if is_composite_or_composite_value(leaf):
      # Composite tensors and their Value objects cannot be stored in a
      # preallocated array; collect them and concatenate at the end.
      return ConcatAggregator(self.batch_size)
    if isinstance(leaf, np.ndarray):
      if self.use_steps:
        return ConcatAggregator(self.batch_size)
      return SliceAggregator(self.num_samples, self.batch_size)
    # Not an ndarray, a CompositeTensor, or a CompositeTensorValue: fail fast
    # rather than trying to concatenate it.
    raise RuntimeError('Attempted to aggregate unsupported object {}.'
                       .format(leaf))

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    leaves = nest.flatten_up_to(self._structure, batch_outs)
    for leaf, child in zip(leaves, self.results):
      child.aggregate(leaf, batch_start, batch_end)

  def finalize(self):
    for child in self.results:
      child.finalize()
    self.results = [child.results for child in self.results]
    self.results = nest.pack_sequence_as(self._structure, self.results)


def get_progbar(model, count_mode, include_metrics=True):
  """Builds a `ProgbarLogger` callback for `model`.

  Args:
    model: The model whose `metrics_names` (if present) are read.
    count_mode: Passed through unchanged to `ProgbarLogger`.
    include_metrics: Whether to pass the model's metric names (every entry
      of `metrics_names` after the first) as stateful metrics.

  Returns:
    A `cbks.ProgbarLogger` instance.
  """
  stateful_metric_names = None
  if include_metrics:
    stateful_metric_names = getattr(model, 'metrics_names', None)
    if stateful_metric_names:
      # The first name is `loss`, which is not reported as a stateful metric.
      stateful_metric_names = stateful_metric_names[1:]
  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)


def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
  """Determines how many samples a training/evaluation loop will process.

  When the loop is driven by `steps`, the number of samples is undefined and
  `None` is returned.

  Args:
    ins: List of tensors to be fed to the Keras function.
    batch_size: Integer batch size, or `None` if not defined. Mutually
      exclusive with `steps`.
    steps: Total number of steps (batches of samples) before the loop is
      considered finished, or `None`.
    steps_name: The public API's parameter name for `steps`, used in error
      messages.

  Returns:
    `None` if `check_steps_argument` reports that steps are in use, or if
    `ins[0]` has no `shape` attribute; otherwise `int(ins[0].shape[0])`.

  Raises:
    ValueError: If both `steps` and `batch_size` are set. Errors raised by
      `check_steps_argument` propagate unchanged.
  """
  if steps is not None and batch_size is not None:
    raise ValueError('If ' + steps_name +
                     ' is set, the `batch_size` must be None.')
  if check_steps_argument(ins, steps, steps_name):
    return None

  first_input = ins[0]
  if not hasattr(first_input, 'shape'):
    return None  # Edge case where ins == [static_learning_phase]
  return int(first_input.shape[0])


def standardize_single_array(x, expected_shape=None):
  """Promotes a rank-1 array of shape (n,) to (n, 1) unless rank 1 is expected.

  Args:
    x: A NumPy array, TF tensor, composite tensor/value, or `None`.
    expected_shape: Optional expected shape. If it has exactly one dimension,
      a rank-1 `x` is left as is.

  Returns:
    `None` if `x` is `None`; `x` unchanged if it is a composite tensor/value,
    has no known shape, or is not rank 1 (or rank 1 is expected); otherwise
    `x` expanded along axis 1.

  Raises:
    ValueError: If `x` is a Python `int`.
  """
  if x is None or is_composite_or_composite_value(x):
    return x

  if isinstance(x, int):
    raise ValueError(
        'Expected an array data type but received an integer: {}'.format(x))

  # Guard clauses keep the original short-circuit order: `expected_shape` is
  # only inspected once `x` is known to be rank 1.
  if x.shape is None or len(x.shape) != 1:
    return x
  if expected_shape is not None and len(expected_shape) == 1:
    return x

  if tensor_util.is_tf_type(x):
    return array_ops.expand_dims(x, axis=1)
  return np.expand_dims(x, 1)


def get_composite_shape(tensor):
  """Returns the shape of a composite tensor or composite tensor value.

  `SparseTensorValue` exposes its shape as `dense_shape`; every other input
  is expected to expose `shape`.
  """
  shape_attr = ('dense_shape'
                if isinstance(tensor, sparse_tensor.SparseTensorValue)
                else 'shape')
  return getattr(tensor, shape_attr)


def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
  """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations. Objects whose
  class is named `DataFrame` are converted through their `.values`.

  Args:
      data: User-provided input data (polymorphic).
      names: List of expected array names.
      shapes: Optional list of expected array shapes, one per name.
      check_batch_axis: Boolean; whether to check that the batch axis of the
        arrays matches the expected value found in `shapes`.
      exception_prefix: String prefix used for exception formatting.

  Returns:
      List of standardized input arrays (one array per model input).

  Raises:
      ValueError: in case of improperly formatted user-provided data: data
        given when none is expected, missing dict keys, the wrong number of
        arrays, or rank/shape mismatches.
      TypeError: if a single object without a `shape` is passed where a
        list of arrays is expected.
  """

  def _dataframe_to_array(x):
    # DataFrames are recognized by class name so pandas need not be imported.
    return x.values if x.__class__.__name__ == 'DataFrame' else x

  try:
    data_len = len(data)
  except TypeError:
    # For instance if data is `None` or a symbolic Tensor.
    data_len = None

  if not names:
    if data_len and not isinstance(data, dict):
      # Build one message string; passing `data` as a second argument to
      # ValueError would turn the exception message into a tuple.
      raise ValueError('Error when checking model ' + exception_prefix + ': '
                       'expected no data, but got: ' + str(data)[:200])
    return []
  if data is None:
    return [None] * len(names)

  if isinstance(data, dict):
    try:
      data = [_dataframe_to_array(data[x]) for x in names]
    except KeyError as e:
      raise ValueError('No data provided for "' + str(e.args[0]) +
                       '". Need data for each key in: ' + str(names)) from e
  elif isinstance(data, (list, tuple)):
    # Guard `data[0]` so that an empty list/tuple reaches the length checks
    # below instead of failing with an IndexError.
    if data and isinstance(data[0], (list, tuple)):
      data = [np.asarray(d) for d in data]
    elif len(names) == 1 and data and isinstance(data[0], (float, int)):
      data = [np.asarray(data)]
    else:
      data = [_dataframe_to_array(x) for x in data]
  else:
    data = [_dataframe_to_array(data)]

  if shapes is not None:
    # Pad `shapes` with None so that surplus arrays are kept (and reported by
    # the length check below) instead of being silently dropped by `zip`.
    padded_shapes = list(shapes) + [None] * (len(data) - len(shapes))
    data = [
        standardize_single_array(x, shape)
        for x, shape in zip(data, padded_shapes)
    ]
  else:
    data = [standardize_single_array(x) for x in data]

  if len(data) != len(names):
    if data and hasattr(data[0], 'shape'):
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': the list of Numpy arrays that you are passing to '
                       'your model is not the size the model expected. '
                       'Expected to see ' + str(len(names)) + ' array(s), ' +
                       'for inputs ' + str(names) + ' but instead got the '
                       'following list of ' + str(len(data)) + ' arrays: ' +
                       str(data)[:200] + '...')
    elif len(names) > 1:
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': you are passing a list as input to your model, '
                       'but the model expects a list of ' + str(len(names)) +
                       ' Numpy arrays instead. The list you passed was: ' +
                       str(data)[:200])
    elif len(data) == 1 and not hasattr(data[0], 'shape'):
      raise TypeError('Error when checking model ' + exception_prefix +
                      ': data should be a Numpy array, or list/dict of '
                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
    elif len(names) == 1:
      data = [np.asarray(data)]

  # Check shapes compatibility.
  if shapes:
    for i, name in enumerate(names):
      shape = shapes[i]
      if shape is None:
        continue
      x = data[i]
      if tensor_util.is_tf_type(x):
        tensorshape = x.shape
        if not tensorshape:
          # Falsy shape (presumably unknown rank): nothing to compare.
          continue
        data_shape = tuple(tensorshape.as_list())
      elif is_composite_or_composite_value(x):
        composite_shape = get_composite_shape(x)
        # `SparseTensorValue.dense_shape` / `RaggedTensorValue.shape` need not
        # be a TensorShape, so only call `as_list()` when it exists.
        if hasattr(composite_shape, 'as_list'):
          data_shape = tuple(composite_shape.as_list())
        else:
          data_shape = tuple(composite_shape)
      else:
        data_shape = x.shape

      if len(data_shape) != len(shape):
        raise ValueError('Error when checking ' + exception_prefix +
                         ': expected ' + name + ' to have ' +
                         str(len(shape)) + ' dimensions, but got array '
                         'with shape ' + str(data_shape))
      if not check_batch_axis:
        data_shape = data_shape[1:]
        shape = shape[1:]
      for dim, ref_dim in zip(data_shape, shape):
        # `None` on either side means "any size" and always matches.
        if ref_dim != dim and ref_dim is not None and dim is not None:
          raise ValueError('Error when checking ' + exception_prefix +
                           ': expected ' + name + ' to have shape ' +
                           str(shape) + ' but got array with shape ' +
                           str(data_shape))
  return data


def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
  """Maps `sample_weight` or `class_weight` to model outputs.

  Args:
      x_weight: User-provided `sample_weight` or `class_weight` argument.
      output_names: List of output names (strings) in the model.
      weight_type: A string used purely for exception printing.

  Returns:
      A sequence with exactly one element per model output (`None` for
      outputs without a weight). A user-provided list/tuple of the right
      length is returned as is.

  Raises:
      ValueError: If a multi-output model receives a list/tuple whose length
        differs from the number of outputs.
      TypeError: If a multi-output model receives something that is neither
        a list/tuple nor a mapping.
      Mappings are additionally validated by
      `generic_utils.check_for_unexpected_keys`, whose errors propagate.
  """
  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
    return [None for _ in output_names]

  if len(output_names) == 1:
    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
      return x_weight
    if isinstance(x_weight, dict) and output_names[0] in x_weight:
      return [x_weight[output_names[0]]]
    # Any other value (e.g. one array, or a class-weight dict keyed by class)
    # applies to the single output.
    return [x_weight]

  if isinstance(x_weight, (list, tuple)):
    if len(x_weight) != len(output_names):
      raise ValueError('Provided `' + weight_type + '` was a list of ' +
                       str(len(x_weight)) + ' elements, but the model has ' +
                       str(len(output_names)) + ' outputs. '
                       'You should provide one `' + weight_type + '` '
                       'array per model output.')
    return x_weight

  if isinstance(x_weight, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
    # Outputs missing from the mapping get no weight.
    return [x_weight.get(name) for name in output_names]

  raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
                  'should be either a list or a dict. '
                  'Provided `' + weight_type + '` type not understood: ' +
                  str(x_weight))


def standardize_class_weights(class_weight, output_names):
  """Maps a user-provided `class_weight` onto the model's outputs.

  Thin wrapper around `standardize_sample_or_class_weights` with
  `weight_type='class_weight'`.
  """
  return standardize_sample_or_class_weights(
      x_weight=class_weight,
      output_names=output_names,
      weight_type='class_weight')


def standardize_sample_weights(sample_weight, output_names):
  """Maps a user-provided `sample_weight` onto the model's outputs.

  Thin wrapper around `standardize_sample_or_class_weights` with
  `weight_type='sample_weight'`.
  """
  return standardize_sample_or_class_weights(
      x_weight=sample_weight,
      output_names=output_names,
      weight_type='sample_weight')


def check_array_lengths(inputs, targets, weights=None):
  """Does user input validation for numpy arrays.

  Only entries that are not `None` and are neither TF tensors nor composite
  tensors/values are checked; for those, `shape[0]` (the number of samples)
  is compared.

  Args:
      inputs: list of Numpy arrays of inputs.
      targets: list of Numpy arrays of targets.
      weights: list of Numpy arrays of sample weights.

  Raises:
      ValueError: if the arrays within `inputs`, `targets` or `weights`
        disagree on the number of samples, or if the inputs or the sample
        weights disagree with the targets.
  """

  def is_tensor_or_composite_tensor(x):
    return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x)

  def set_of_lengths(x):
    # Distinct sample counts among the checkable arrays (empty for `None`).
    if x is None:
      return set()
    return {
        y.shape[0]
        for y in x
        if y is not None and not is_tensor_or_composite_tensor(y)
    }

  def describe_shapes(x):
    # Shapes for error messages. Entries without a `shape` (e.g. `None`) are
    # shown as None instead of raising AttributeError while reporting.
    return str([getattr(y, 'shape', None) for y in x])

  def sample_count(lengths):
    # Only called once `lengths` is known to hold exactly one value.
    return next(iter(lengths))

  set_x = set_of_lengths(inputs)
  set_y = set_of_lengths(targets)
  set_w = set_of_lengths(weights)
  if len(set_x) > 1:
    raise ValueError('All input arrays (x) should have '
                     'the same number of samples. Got array shapes: ' +
                     describe_shapes(inputs))
  if len(set_y) > 1:
    raise ValueError('All target arrays (y) should have '
                     'the same number of samples. Got array shapes: ' +
                     describe_shapes(targets))
  if set_x and set_y and sample_count(set_x) != sample_count(set_y):
    raise ValueError('Input arrays should have '
                     'the same number of samples as target arrays. '
                     'Found ' + str(sample_count(set_x)) + ' input samples '
                     'and ' + str(sample_count(set_y)) + ' target samples.')
  if len(set_w) > 1:
    raise ValueError('All sample_weight arrays should have '
                     'the same number of samples. Got array shapes: ' +
                     describe_shapes(weights))
  if set_y and set_w and sample_count(set_y) != sample_count(set_w):
    # Label each count by where it came from: targets vs. sample weights.
    raise ValueError('Sample_weight arrays should have '
                     'the same number of samples as target arrays. Got ' +
                     str(sample_count(set_y)) + ' target samples and ' +
                     str(sample_count(set_w)) + ' sample_weight samples.')


def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
  """Validates that each target array suits the loss applied to its output.

  The check exists only to surface clear error messages early (UX). Entries
  whose target or loss is `None`, and targets that are already TensorFlow
  tensors, are skipped.

  Args:
      targets: list of Numpy arrays of targets.
      loss_fns: list of loss functions.
      output_shapes: list of shapes of model outputs.

  Raises:
      ValueError: if `categorical_crossentropy` is used with targets whose last
          dimension is 1, or if a mean squared error / binary crossentropy /
          categorical crossentropy loss receives targets whose non-batch
          dimensions differ from the known output dimensions.
  """
  # Losses that require targets shaped like the output. They are recognized
  # either as plain functions wrapped in `LossFunctionWrapper` or as `Loss`
  # subclasses.
  shape_matching_fns = {
      losses.mean_squared_error, losses.binary_crossentropy,
      losses.categorical_crossentropy
  }
  shape_matching_classes = (losses.MeanSquaredError,
                            losses.BinaryCrossentropy,
                            losses.CategoricalCrossentropy)

  for y, loss, shape in zip(targets, loss_fns, output_shapes):
    if y is None or loss is None or tensor_util.is_tf_type(y):
      continue

    if losses.is_categorical_crossentropy(loss) and y.shape[-1] == 1:
      raise ValueError('You are passing a target array of shape ' +
                       str(y.shape) +
                       ' while using as loss `categorical_crossentropy`. '
                       '`categorical_crossentropy` expects '
                       'targets to be binary matrices (1s and 0s) '
                       'of shape (samples, classes). '
                       'If your targets are integer classes, '
                       'you can convert them to the expected format via:\n'
                       '```\n'
                       'from keras.utils import to_categorical\n'
                       'y_binary = to_categorical(y_int)\n'
                       '```\n'
                       '\n'
                       'Alternatively, you can use the loss function '
                       '`sparse_categorical_crossentropy` instead, '
                       'which does expect integer targets.')

    wrapped = isinstance(loss, losses.LossFunctionWrapper)
    needs_same_shape = (isinstance(loss, shape_matching_classes) or
                        (wrapped and loss.fn in shape_matching_fns))
    if not needs_same_shape:
      continue

    # The batch dimension is skipped; unknown output dims (`None`) match
    # anything.
    for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
      if out_dim is not None and target_dim != out_dim:
        loss_name = loss.name
        if loss_name is None:
          loss_name = (loss.fn if wrapped else type(loss)).__name__
        raise ValueError('A target array with shape ' + str(y.shape) +
                         ' was passed for an output of shape ' + str(shape) +
                         ' while using as loss `' + loss_name + '`. '
                         'This loss expects targets to have the same shape '
                         'as the output.')


def collect_per_output_metric_info(metrics,
                                   output_names,
                                   output_shapes,
                                   loss_fns,
                                   from_serialized=False,
                                   is_weighted=False):
  """Maps metric names and functions to model outputs.

  Args:
      metrics: a list or a list of lists or a dict of metric functions.
      output_names: a list of the names (strings) of model outputs.
      output_shapes: a list of the shapes of model outputs.
      loss_fns: a list of the loss functions corresponding to the model outputs.
      from_serialized: whether the model the metrics are being sourced from is
        being initialized from a serialized format.
      is_weighted: Boolean indicating whether the given metrics are weighted.

  Returns:
      A list (one entry per model output) of dicts mapping metric names to
      `metrics_module.Metric` objects. Stateless metric functions are wrapped
      in `metrics_module.MeanMetricWrapper`.
      For instance, if the model has 2 outputs, and for the first output
      we want to compute "binary_accuracy" and "binary_crossentropy",
      and just "binary_accuracy" for the second output,
      the list would look like: `[{
          'acc': binary_accuracy(),
          'ce': binary_crossentropy(),
        }, {
          'acc': binary_accuracy(),
        }]`

  Raises:
      TypeError: if an incorrect type is passed for the `metrics` argument.
      ValueError: if `metrics` is a list of lists whose length differs from
        the number of model outputs.
  """
  if not metrics:
    return [{} for _ in output_names]

  if isinstance(metrics, list):
    any_sub_list = any(isinstance(m, list) for m in metrics)
    if any_sub_list:
      if len(metrics) != len(output_names):
        raise ValueError('When passing a list of lists as `metrics`, '
                         'it should have one entry per model output. '
                         'The model has ' + str(len(output_names)) +
                         ' outputs, but you passed metrics=' + str(metrics))
      # User has provided a list of len = len(outputs).
      nested_metrics = [generic_utils.to_list(m) for m in metrics]
    elif len(output_names) > 1:
      # A flat list applies every metric to every output; each output gets
      # its own `clone_metric` copy.
      nested_metrics = [[metrics_module.clone_metric(m) for m in metrics]
                        for _ in output_names]
    else:
      nested_metrics = [metrics]
  elif isinstance(metrics, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
    nested_metrics = []
    for name in output_names:
      output_metrics = generic_utils.to_list(metrics.get(name, []))
      nested_metrics.append(output_metrics)
  else:
    raise TypeError('Type of `metrics` argument not understood. '
                    'Expected a list or dictionary, found: ' + str(metrics))

  per_output_metrics = []
  for i, output_metrics in enumerate(nested_metrics):
    metrics_dict = collections.OrderedDict()
    for metric in output_metrics:
      metric_name = get_metric_name(metric, is_weighted)
      metric_fn = get_metric_function(
          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])

      if isinstance(metric_fn, metrics_module.Metric):
        metric_fn._from_serialized = from_serialized  # pylint: disable=protected-access
      else:
        # If the metric function is not stateful, we create a stateful
        # version. The flag is set only on the wrapper. Tagging the raw
        # callable would mutate shared module-level metric functions, and it
        # fails for callables that reject attribute assignment (e.g. bound
        # methods).
        metric_fn = metrics_module.MeanMetricWrapper(
            metric_fn, name=metric_name)
        # If the metric is being revived from something stateless, such as a
        # string (e.g. "accuracy"), we may need to later reapply
        # transformations such as renaming.
        metric_fn._from_serialized = False  # pylint: disable=protected-access
      metrics_dict[metric_name] = metric_fn
    per_output_metrics.append(metrics_dict)

  return per_output_metrics


def batch_shuffle(index_array, batch_size):
  """Shuffles an array in a batch-wise fashion.

  Useful for shuffling HDF5 arrays
  (where one cannot access arbitrary indices).

  The array is cut into `len(index_array) // batch_size` contiguous batches.
  Their order is shuffled with `np.random.shuffle`, while items inside a batch
  keep their relative order. Trailing items that do not fill a complete batch
  stay, unshuffled, at the end.

  Args:
      index_array: 1D array of indices to be shuffled.
      batch_size: integer.

  Returns:
      The `index_array` array, shuffled in a batch-wise fashion.
  """
  # Integer division: `int(a / b)` goes through a float and can lose precision
  # for very large lengths.
  batch_count = len(index_array) // batch_size
  # To reshape we need to be cleanly divisible by batch size, so we stash the
  # extra items and re-append them after shuffling.
  cutoff = batch_count * batch_size
  last_batch = index_array[cutoff:]
  batches = index_array[:cutoff].reshape((batch_count, batch_size))
  # Shuffles rows only, i.e. whole batches.
  np.random.shuffle(batches)
  return np.append(batches.flatten(), last_batch)


def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
  """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array. If both `sample_weight` and `class_weight` are provided,
  the weights are multiplied.

  Args:
      y: Numpy array or Tensor of model targets to be weighted.
      sample_weight: User-provided `sample_weight` argument. A 1-tuple (as an
        iterator may yield) is unwrapped to its single element.
      class_weight: User-provided `class_weight` argument. Only a dict mapping
        class to weight is applied; any other value is ignored.
      sample_weight_mode: One of `None` or `"temporal"` (`"samplewise"` is
        treated like `None`). `"temporal"` indicates that we expect 2D weight
        data that will be applied to the first 2 dimensions of the targets
        (samples and timesteps, i.e. we are weighting timesteps, not
        samples).

  Returns:
      The standardized weights (a NumPy array or a tensor, depending on the
      inputs), one entry per sample to weight, or `None` when neither
      `sample_weight` nor a dict `class_weight` is provided.

  Raises:
      ValueError: In case of invalid user-provided arguments.
  """
  # Iterator may return sample_weight as 1-tuple
  if isinstance(sample_weight, tuple):
    sample_weight = sample_weight[0]
  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
    if sample_weight_mode != 'temporal':
      raise ValueError('`sample_weight_mode` '
                       'should be None or "temporal". '
                       'Found: ' + str(sample_weight_mode))
    if len(y.shape) < 3:
      raise ValueError('Found a sample_weight array for '
                       'an input with shape ' + str(y.shape) + '. '
                       'Timestep-wise sample weighting (use of '
                       'sample_weight_mode="temporal") is restricted to '
                       'outputs that are at least 3D, i.e. that have '
                       'a time dimension.')
    if sample_weight is not None and len(sample_weight.shape) != 2:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weighting, '
                       'you should pass a 2D sample_weight array.')
  else:
    if sample_weight is not None and len(sample_weight.shape) != 1:
      raise ValueError(
          'Found a sample_weight array with shape {}. In order to '
          'use timestep-wise sample weights, you should specify '
          'sample_weight_mode="temporal" in compile(); found "{}" '
          'instead. If you just mean to use sample-wise weights, '
          'make sure your sample_weight array is 1D.'.format(
              sample_weight.shape, sample_weight_mode))

  if sample_weight is not None:
    if len(sample_weight.shape) > len(y.shape):
      raise ValueError('Found a sample_weight with shape ' +
                       str(sample_weight.shape) + '. '
                       'Expected sample_weight with rank '
                       'less than or equal to ' + str(len(y.shape)))

    # The weights must match the leading dimensions of the targets.
    if (not tensor_util.is_tf_type(sample_weight) and
        y.shape[:sample_weight.ndim] != sample_weight.shape):
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + ' for an input with shape ' +
                       str(y.shape) + '. '
                       'sample_weight cannot be broadcast.')

  # Class weights applied per-sample.
  class_sample_weight = None
  if isinstance(class_weight, dict):
    if len(y.shape) > 2:
      raise ValueError('`class_weight` not supported for '
                       '3+ dimensional targets.')

    if tensor_util.is_tf_type(y):
      # Few classes are expected, so densifying is reasonable.
      keys = np.array(sorted(class_weight.keys()))
      values = np.array([class_weight[i] for i in keys])
      weight_vector = np.zeros(np.max(keys) + 1)
      # Classes without a weight stay NaN so check_numerics below flags them.
      weight_vector[:] = np.nan
      weight_vector[keys] = values

      y_classes = smart_cond.smart_cond(
          len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
          lambda: backend.argmax(y, axis=1),
          lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
      class_sample_weight = array_ops.gather(weight_vector, y_classes)
      gen_array_ops.check_numerics(
          class_sample_weight,
          'Invalid classes or class weights detected. NaN values indicate that '
          'an appropriate class weight could not be determined.')
      class_sample_weight = math_ops.cast(class_sample_weight, backend.floatx())
      if sample_weight is not None:
        sample_weight = math_ops.cast(
            tensor_conversion.convert_to_tensor_v2_with_dispatch(sample_weight),
            backend.floatx(),
        )
    else:
      y_classes = y
      if len(y.shape) == 2:
        if y.shape[1] > 1:
          # One-hot targets: recover the class index.
          y_classes = np.argmax(y, axis=1)
        elif y.shape[1] == 1:
          y_classes = np.reshape(y, y.shape[0])

      class_sample_weight = np.asarray(
          [class_weight[cls] for cls in y_classes if cls in class_weight])

      if len(class_sample_weight) != len(y_classes):
        # subtract the sets to pick all missing classes
        existing_classes = set(y_classes)
        existing_class_weight = set(class_weight.keys())
        raise ValueError(
            '`class_weight` must contain all classes in the data.'
            ' The classes %s exist in the data but not in '
            '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
  if sample_weight is not None:
    return sample_weight
  if class_sample_weight is not None:
    return class_sample_weight
  return None


def has_symbolic_tensors(ls):
  """Returns True if `ls` holds (non-ragged) tensors while in graph mode.

  When executing eagerly no tensor is symbolic, so the answer is always False.
  """
  return not context.executing_eagerly() and has_tensors(ls)


def has_tensors(ls):
  """Returns true if `ls` contains tensors.

  Lists, tuples and dict values are inspected one level deep; any other value
  is tested directly. Ragged tensors are deliberately not counted.
  """
  # Note: at some point in time ragged tensors didn't count as tensors, so this
  # returned false for ragged tensors. Making this return true fails some tests
  # which would then require a steps_per_epoch argument.
  def _is_non_ragged_tensor(value):
    return (tensor_util.is_tf_type(value) and
            not isinstance(value, ragged_tensor.RaggedTensor))

  if isinstance(ls, (list, tuple)):
    return any(_is_non_ragged_tensor(v) for v in ls)
  if isinstance(ls, dict):
    return any(_is_non_ragged_tensor(v) for v in ls.values())
  return _is_non_ragged_tensor(ls)


def get_metric_name(metric, weighted=False):
  """Derives the display name for a metric given to `compile`.

  Under TF2 a string is used verbatim; otherwise the resolved metric's `name`
  (or `__name__`) is used. Under TF1 the aliases 'accuracy'/'acc' become
  'acc', 'crossentropy'/'ce' become 'ce', and weighted metrics get a
  'weighted_' prefix.

  Args:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted.

  Returns:
      The metric name.
  """
  if not tf2.enabled():
    prefix = 'weighted_' if weighted else ''
    if metric in ('accuracy', 'acc'):
      return prefix + 'acc'
    if metric in ('crossentropy', 'ce'):
      return prefix + 'ce'
    resolved = metrics_module.get(metric)
    if hasattr(resolved, 'name'):
      return prefix + resolved.name
    return prefix + resolved.__name__

  # We keep the string that the user has set in compile as the metric name.
  if isinstance(metric, str):
    return metric
  resolved = metrics_module.get(metric)
  return resolved.name if hasattr(resolved, 'name') else resolved.__name__


def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Resolves a metric identifier to a metric function.

  The generic names 'accuracy'/'acc' and 'crossentropy'/'ce' are specialized:
  - binary, when the last output dimension is 1 or the loss is binary
    crossentropy;
  - sparse categorical, when the loss is sparse categorical crossentropy;
  - categorical otherwise.

  Any other identifier is looked up with `metrics_module.get`.

  Args:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric will be calculated
        for. Only read for the generic names above.
      loss_fn: The loss function used.

  Returns:
      The metric function.
  """
  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
    return metrics_module.get(metric)

  wraps_fn = isinstance(loss_fn, losses.LossFunctionWrapper)
  uses_sparse_cce = (
      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
      (wraps_fn and loss_fn.fn == losses.sparse_categorical_crossentropy))
  uses_bce = (
      isinstance(loss_fn, losses.BinaryCrossentropy) or
      (wraps_fn and loss_fn.fn == losses.binary_crossentropy))

  wants_accuracy = metric in ['accuracy', 'acc']
  if output_shape[-1] == 1 or uses_bce:
    if wants_accuracy:
      return metrics_module.binary_accuracy
    return metrics_module.binary_crossentropy
  if uses_sparse_cce:
    if wants_accuracy:
      return metrics_module.sparse_categorical_accuracy
    return metrics_module.sparse_categorical_crossentropy
  # A last dimension other than 1 means categorical output. Sparse targets
  # are assumed only when the loss explicitly says so (handled above).
  if wants_accuracy:
    return metrics_module.categorical_accuracy
  return metrics_module.categorical_crossentropy


def call_metric_function(metric_fn,
                         y_true,
                         y_pred=None,
                         weights=None,
                         mask=None):
  """Invokes metric function and returns the metric result tensor.

  A `mask`, when given, is cast to `y_pred.dtype` and folded into the sample
  weights. It is used directly if `weights` is None; otherwise it is aligned
  with `weights` via `losses_utils.squeeze_or_expand_dimensions` and
  multiplied in. When `y_pred` is None the metric is called with `y_true`
  only, as for `Mean`-style metrics.
  """
  if mask is not None:
    mask = math_ops.cast(mask, y_pred.dtype)
    if weights is not None:
      # Update dimensions of weights to match with mask.
      weights = math_ops.cast(weights, dtype=y_pred.dtype)
      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
          mask, sample_weight=weights)
      weights *= mask
    else:
      weights = mask

  if y_pred is None:
    # `Mean` metric only takes a single value.
    return metric_fn(y_true, sample_weight=weights)
  return metric_fn(y_true, y_pred, sample_weight=weights)


def get_loss_function(loss):
  """Returns the loss corresponding to the loss input in `compile` API.

  Args:
    loss: `None`, a `losses.Loss` instance, a loss config mapping, a callable,
      or any identifier accepted by `losses.get` (e.g. a loss name).

  Returns:
    The following are returned unchanged:
    - `None` and `losses.Loss` instances;
    - callables without a `__name__`, treated as custom callable class
      instances.

    Anything else is resolved with `losses.get` and wrapped in a
    `losses.LossFunctionWrapper` with `SUM_OVER_BATCH_SIZE` reduction.

  Raises:
    ValueError: if `loss` is an uninstantiated `losses.Loss` subclass.
  """
  if loss is None or isinstance(loss, losses.Loss):
    return loss

  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
    # It is not safe to assume that the loss takes no constructor arguments.
    raise ValueError(
        'Received uninstantiated Loss class: {}\nPlease call loss classes '
        'before passing them to Model.compile.'.format(loss))

  # Deserialize loss configuration, if needed.
  if isinstance(loss, collections.abc.Mapping):
    loss = losses.get(loss)

  # Custom callable class.
  if callable(loss) and not hasattr(loss, '__name__'):
    return loss

  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
  # in `LossFunctionWrapper` class.
  loss_fn = losses.get(loss)

  # For losses which are given as strings/functions in the compile API,
  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
  # (both in distribution strategy context and otherwise).
  return losses.LossFunctionWrapper(
      loss_fn,
      name=loss_fn.__name__,
      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)


def validate_dataset_input(x, y, sample_weight, validation_split=None):
  """Validates user input arguments when a dataset iterator is passed.

  Args:
    x: Input data. A `tf.data` dataset or iterator.
    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
      Expected to be `None` when `x` is a dataset iterator.
    sample_weight: An optional sample-weight array passed by the user to weight
      the importance of each sample in `x`. Expected to be `None` when `x` is a
      dataset iterator
    validation_split: Float between 0 and 1. Fraction of the training data to be
      used as validation data. Expected to be `None` (or `0.0`) when `x` is a
      dataset iterator.

  Raises:
    ValueError: if argument `y` or `sample_weight` or a non-zero
        `validation_split` are provided by user.
  """
  if y is not None:
    raise ValueError('You passed a dataset or dataset iterator (%s) as '
                     'input `x` to your model. In that case, you should '
                     'not specify a target (`y`) argument, since the dataset '
                     'or dataset iterator generates both input data and '
                     'target data. '
                     'Received: %s' % (x, y))
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when input '
                     '`x` is a dataset or a dataset iterator. Instead, you '
                     'can provide sample_weight as the third element of your '
                     'dataset, i.e. (inputs, targets, sample_weight). '
                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
  if validation_split is not None and validation_split != 0.0:
    raise ValueError(
        '`validation_split` argument is not supported when '
        'input `x` is a dataset or a dataset iterator. '
        'Received: x=%s, validation_split=%f' % (x, validation_split))


def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
  """Checks that `inp` (model inputs or targets) has a supported type.

  Accepted forms:
  - a NumPy array or TF tensor;
  - a list/tuple made only of those;
  - a dict, when `allow_dict` is True (dict values are not inspected).

  `orig_inp` and `field_name` are only used to build the error message.

  Raises:
    ValueError: for any other type, or for a dict when `allow_dict` is False.
  """
  def _is_array_like(value):
    return isinstance(value, np.ndarray) or tensor_util.is_tf_type(value)

  def _bad_input_error(shown):
    return ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, shown))

  if isinstance(inp, (list, tuple)):
    if any(not _is_array_like(v) for v in inp):
      raise _bad_input_error(str(orig_inp))
  elif isinstance(inp, dict):
    if not allow_dict:
      raise ValueError(
          'You cannot pass a dictionary as model {}.'.format(field_name))
  elif not _is_array_like(inp):
    raise _bad_input_error(orig_inp)


def check_generator_arguments(y=None, sample_weight=None,
                              validation_split=None):
  """Validates arguments passed when using a generator.

  A generator or `Sequence` supplies targets and sample weights itself, so
  `y`, `sample_weight` and a non-zero `validation_split` are rejected.

  Args:
    y: Must be `None`.
    sample_weight: Must be `None`.
    validation_split: Must be falsy (e.g. `None` or `0`).

  Raises:
    ValueError: if any of the arguments above is set.
  """
  if y is not None:
    raise ValueError('`y` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass targets'
                     ' as the second element of the generator.')
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass sample'
                     ' weights as the third element of the generator.')
  if validation_split:
    raise ValueError('If your data is in the form of a Python generator, '
                     'you cannot use `validation_split`.')


def check_steps_argument(input_data, steps, steps_name):
  """Decides whether `steps` is required for `input_data` and validates it.

  `steps` must be supplied when:
    1. the input is a dataset iterator;
    2. the model was built on symbolic tensors, so the input is `None` (or an
       empty list);
    3. the input contains symbolic tensors.

  Args:
      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
        tf.data.Dataset iterator or `None`.
      steps: Integer or `None`. Total number of steps (batches of samples) to
        execute.
      steps_name: The public API's parameter name for `steps`.

  Returns:
    boolean, True if `steps` argument is required (this also covers `tf.data`
    Dataset inputs), else False.

  Raises:
      ValueError: if `steps` argument is required for given input data type
        but not provided.
  """
  iterator_types = (iterator_ops.Iterator, iterator_ops.IteratorBase)
  is_x_iterator = isinstance(input_data, iterator_types)
  requires_steps = (
      input_data is None or is_x_iterator or
      has_symbolic_tensors(input_data) or
      (isinstance(input_data, list) and not input_data))
  if requires_steps:
    if steps is None:
      input_type_str = ('a Dataset iterator' if is_x_iterator
                        else 'data tensors')
      raise ValueError('When using {input_type} as input to a model, you should'
                       ' specify the `{steps_name}` argument.'.format(
                           input_type=input_type_str, steps_name=steps_name))
    return True

  if isinstance(input_data, (data_types.DatasetV1, data_types.DatasetV2)):
    return True

  if steps is None:
    return False

  # In-memory arrays are expected to be batched via `batch_size`; warn if
  # `steps` was passed alongside them.
  array_like = (np.ndarray, list, tuple)
  got_arrays = isinstance(input_data, array_like) or (
      isinstance(input_data, dict) and
      any(isinstance(v, array_like) for v in input_data.values()))
  if got_arrays:
    logging.warning('When passing input data as arrays, do not specify '
                    '`steps_per_epoch`/`steps` argument. '
                    'Please use `batch_size` instead.')
  return False


def cast_single_tensor(x, dtype=None):
  """Casts a floating-point `x` to `dtype` (default: `backend.floatx()`).

  NumPy arrays are converted to tensors first. Non-floating inputs are
  returned without a cast.
  """
  if isinstance(x, np.ndarray):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x)
  target_dtype = dtype or backend.floatx()
  return math_ops.cast(x, dtype=target_dtype) if x.dtype.is_floating else x


def cast_if_floating_dtype_and_mismatch(targets, outputs):
  """Aligns the dtype of each target with its corresponding output.

  Targets (NumPy arrays are converted to tensors first) whose dtype differs
  from the paired output are passed through `cast_single_tensor`, which only
  casts floating-point tensors. Matching targets are returned as-is.

  Args:
    targets: tensor or list of targets.
    outputs: tensor or list of outputs.

  Returns:
    A single tensor if `targets` was a tensor, otherwise a list of targets.
  """
  if tensor_util.is_tf_type(targets):
    # There is one target, so output[0] should be the only output.
    return cast_single_tensor(targets, dtype=outputs[0].dtype)

  def _align(target, out):
    if isinstance(target, np.ndarray):
      target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target)
    if target.dtype != out.dtype:
      return cast_single_tensor(target, dtype=out.dtype)
    return target

  return [_align(target, out) for target, out in zip(targets, outputs)]


def cast_if_floating_dtype(x, dtype=None):
  """Applies `cast_single_tensor` to every tensor in a nested structure.

  Only floating-point tensors are cast; others pass through unchanged.

  Args:
    x: tensor or list/tuple of tensors.
    dtype: The dtype to which Tensors should be cast (default:
      `backend.floatx()`).

  Returns:
    Converted input.
  """
  cast_fn = functools.partial(cast_single_tensor, dtype=dtype)
  return nest.map_structure(cast_fn, x)


def cast_to_model_input_dtypes(x, model):
  """Casts each tensor in `x` to the dtype of the matching model input.

  Args:
    x: tensor or list/tuple of tensors, nested like `model.inputs`.
    model: The model whose `inputs` provide the target dtypes.

  Returns:
    `x` with every tensor cast (via `math_ops.cast`) to the dtype of the
    corresponding entry in `model.inputs`.
  """
  def _dtype_of(tensor):
    return tensor.dtype

  target_dtypes = nest.map_structure(_dtype_of, model.inputs)
  return nest.map_structure(math_ops.cast, x, target_dtypes)


def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
  """Prepares sample weight modes for the model.

  Sets `sample_weight_mode` on every endpoint whose
  `should_skip_target_weights()` returns False; skipped endpoints are left
  untouched.

  Args:
    training_endpoints: List of model _TrainingEndpoints.
    sample_weight_mode: sample weight mode user input passed from compile API.
      Accepted forms:
      - a mapping keyed by output name;
      - a list/tuple with one entry per endpoint;
      - a single value applied to every endpoint.

  Raises:
    ValueError: if a mapping lacks an entry for a non-skipped output, or a
      list/tuple's length differs from the number of endpoints. Unexpected
      mapping keys are reported by `generic_utils.check_for_unexpected_keys`.
  """

  if isinstance(sample_weight_mode, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'sample_weight_mode', sample_weight_mode,
        [e.output_name for e in training_endpoints])

    for end_point in training_endpoints:
      if not end_point.should_skip_target_weights():
        if end_point.output_name not in sample_weight_mode:
          raise ValueError('Output ' + end_point.output_name +
                           ' missing from `sample_weight_mode` dictionary')
        end_point.sample_weight_mode = sample_weight_mode.get(
            end_point.output_name)
  elif isinstance(sample_weight_mode, (list, tuple)):
    if len(sample_weight_mode) != len(training_endpoints):
      raise ValueError('When passing a list as sample_weight_mode, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' sample_weight_modes.')
    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = mode
  else:
    for endpoint in training_endpoints:
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = sample_weight_mode


def prepare_loss_functions(loss, output_names):
  """Expands the `loss` given to `compile` into one loss per model output.

  Each loss is resolved through `get_loss_function`. Dispatch by type of
  `loss`:
  - mapping: looked up by output name; a missing name yields `None` and a
    warning;
  - string: one instance per output;
  - any other sequence: must have one entry per output and is mapped
    element-wise with `nest.map_structure`;
  - any other value: applied to every output.

  Args:
      loss: String (name of objective function), objective function or
        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
        outputs, you can use a different loss on each output by passing a
        dictionary or a list of losses. The loss value that will be minimized by
        the model will then be the sum of all individual losses.
      output_names: List of model output names.

  Returns:
      A list of loss objective functions.

  Raises:
      ValueError: If loss is a dict with keys not in model output names,
          or if loss is a list with len not equal to model outputs.
  """
  if isinstance(loss, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('loss', loss, output_names)
    per_output = []
    for output_name in output_names:
      if output_name not in loss:
        logging.warning(
            'Output {0} missing from loss dictionary. We assume '
            'this was done on purpose. The fit and evaluate APIs will not be '
            'expecting any data to be passed to {0}.'.format(output_name))
      per_output.append(get_loss_function(loss.get(output_name, None)))
    return per_output

  # Strings are sequences too, so they must be handled before the generic
  # sequence branch.
  if isinstance(loss, str):
    return [get_loss_function(loss) for _ in output_names]

  if isinstance(loss, collections.abc.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('When passing a list as loss, it should have one entry '
                       'per model outputs. The model has {} outputs, but you '
                       'passed loss={}'.format(len(output_names), loss))
    return nest.map_structure(get_loss_function, loss)

  return [get_loss_function(loss) for _ in range(len(output_names))]


def prepare_loss_weights(training_endpoints, loss_weights=None):
  """Converts loss weights to a list of loss weights.

  The result loss weights will be populated on the training endpoint.

  Args:
      training_endpoints: List of model training endpoints.
      loss_weights: Optional list, tuple or dictionary specifying scalar
        coefficients (Python floats) to weight the loss contributions of
        different model outputs. The loss value that will be minimized by the
        model will then be the *weighted sum* of all individual losses,
        weighted by the `loss_weights` coefficients.
        - A list/tuple is expected to have a 1:1 mapping to the model's
          outputs.
        - A dict is expected to map output names (strings) to scalar
          coefficients; outputs missing from it get weight 1.
        - `None` gives every output weight 1.

  Raises:
      ValueError: If loss_weights is a list/tuple with len not equal to the
          number of model outputs. Unexpected dict keys are reported by
          `generic_utils.check_for_unexpected_keys`.
      TypeError: If loss_weights is neither None, a dict, a list nor a tuple.
  """
  if loss_weights is None:
    for e in training_endpoints:
      e.loss_weight = 1.
  elif isinstance(loss_weights, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'loss_weights', loss_weights,
        [e.output_name for e in training_endpoints])
    for e in training_endpoints:
      e.loss_weight = loss_weights.get(e.output_name, 1.)
  elif isinstance(loss_weights, (list, tuple)):
    if len(loss_weights) != len(training_endpoints):
      raise ValueError('When passing a list as loss_weights, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed loss_weights=' +
                       str(loss_weights))
    for w, e in zip(loss_weights, training_endpoints):
      e.loss_weight = w
  else:
    raise TypeError('Could not interpret loss_weights argument: ' +
                    str(loss_weights) + ' - expected a list or a dict.')


# TODO(rohanj): This is a hack to avoid depending on feature_column, which would
# create a cyclical dependency. Figure out a cleaner solution.
def is_feature_layer(layer):
  """Returns the layer's `_is_feature_layer` marker, or False if it is absent.

  Duck-typing on the private marker attribute avoids importing the
  FeatureLayer class itself.
  """
  try:
    return layer._is_feature_layer  # pylint: disable=protected-access
  except AttributeError:
    return False


def is_eager_dataset_or_iterator(data):
  """Returns True if running eagerly and `data` is a dataset or iterator.

  Accepted types are `DatasetV1`, `DatasetV2` and `IteratorBase`. Outside of
  eager execution this always returns False.
  """
  if not context.executing_eagerly():
    return False
  eager_types = (data_types.DatasetV1, data_types.DatasetV2,
                 iterator_ops.IteratorBase)
  return isinstance(data, eager_types)


# pylint: disable=protected-access
def get_dataset_graph_def(dataset):
  """Parses the serialized graph of `dataset` into a `graph_pb2.GraphDef`.

  In eager mode the serialized-graph tensor is read with `.numpy()`. Otherwise
  it is evaluated through `backend.get_value`.
  """
  eager = context.executing_eagerly()
  serialized_graph = dataset._as_serialized_graph()
  if eager:
    graph_def_bytes = serialized_graph.numpy()
  else:
    graph_def_bytes = backend.get_value(serialized_graph)
  return graph_pb2.GraphDef.FromString(graph_def_bytes)


def verify_dataset_shuffled(x):
  """Checks whether a shuffle op appears anywhere in the dataset's graph.

  The top-level nodes of the serialized graph are scanned first. Then the
  nodes of every function in the graph library are scanned; that is where
  e.g. `interleave` / `flat_map` bodies end up. An op counts as a shuffle if
  its name starts with 'ShuffleDataset'. If none is found, a warning is
  logged.

  Args:
    x: Dataset passed as an input to the model.

  Returns:
    boolean, whether the input dataset is shuffled or not.
  """
  assert isinstance(x, data_types.DatasetV2)
  graph_def = get_dataset_graph_def(x)

  def _is_shuffle(node):
    return node.op.startswith('ShuffleDataset')

  if any(_is_shuffle(node) for node in graph_def.node):
    return True

  library_nodes = (
      node for fn in graph_def.library.function for node in fn.node_def)
  if any(_is_shuffle(node) for node in library_nodes):
    return True

  logging.warning('Expected a shuffled dataset but input dataset `x` is '
                  'not shuffled. Please invoke `shuffle()` on input dataset.')
  return False


def is_dataset_or_iterator(data):
  """Returns True if `data` is a tf.data dataset (V1/V2) or dataset iterator.

  Unlike `is_eager_dataset_or_iterator`, this does not depend on the
  execution mode.
  """
  accepted_types = (
      data_types.DatasetV1,
      data_types.DatasetV2,
      iterator_ops.Iterator,
      iterator_ops.IteratorBase,
  )
  return isinstance(data, accepted_types)


def get_iterator(dataset):
  """Creates an iterator over `dataset` and initializes it.

  In eager mode a one-shot iterator is built. Otherwise an initializable
  iterator is built. In both cases `initialize_iterator` is then called on
  the result.
  """
  if context.executing_eagerly():
    make_iterator = dataset_ops.make_one_shot_iterator
  else:
    make_iterator = dataset_ops.make_initializable_iterator
  iterator = make_iterator(dataset)
  initialize_iterator(iterator)
  return iterator


def initialize_iterator(iterator):
  """Runs `iterator.initializer` in the backend session (graph mode only).

  This is a no-op when executing eagerly.
  """
  if context.executing_eagerly():
    return
  initializer = iterator.initializer
  session = backend.get_session((initializer,))
  session.run(initializer)


def extract_tensors_from_dataset(dataset):
  """Pulls one `(inputs, targets, sample_weight)` element out of a dataset.

  An initialized iterator is created via `get_iterator`. Its next element is
  split by `unpack_iterator_input`.

  Args:
    dataset: Dataset instance.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
  """
  return unpack_iterator_input(get_iterator(dataset))


def unpack_iterator_input(iterator):
  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

  The next element of `iterator` is interpreted as follows:
  - a 2-item list/tuple is `(x, y)`;
  - a 3-item list/tuple is `(x, y, weights)`;
  - anything that is not a list/tuple is `x` alone.

  Args:
    iterator: Instance of a dataset iterator.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.

  Raises:
    RuntimeError: If `iterator.get_next()` raises `OutOfRangeError`.
    ValueError: If the element is a list or tuple whose length is not 2 or 3.
  """
  try:
    next_element = iterator.get_next()
  except errors.OutOfRangeError as e:
    raise RuntimeError('Your dataset iterator ran out of data; '
                       'Make sure that your dataset can generate '
                       'required number of samples.') from e

  if not isinstance(next_element, (list, tuple)):
    # A bare element carries inputs only.
    return next_element, None, None

  if len(next_element) == 2:
    x, y = next_element
    return x, y, None
  if len(next_element) == 3:
    x, y, weights = next_element
    return x, y, weights

  # Wrap in a 1-tuple: `'%s' % some_tuple` would spread the tuple across the
  # format arguments, raising TypeError (or mis-rendering a 1-tuple) instead
  # of reporting the element itself.
  raise ValueError(
      'Please provide model inputs as a list or tuple of 2 or 3 '
      'elements: (input, target) or (input, target, sample_weights) '
      'Received %s' % (next_element,))


def infer_steps_for_dataset(model,
                            dataset,
                            steps,
                            epochs=1,
                            steps_name='steps'):
  """Works out how many steps per epoch to draw from `dataset`.

  Args:
      model: Keras model instance.
      dataset: Input data of type tf.data.Dataset.
      steps: Number of steps to draw from the dataset (may be None if unknown).
      epochs: Number of times to iterate over the dataset.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.

  Returns:
    Integer or `None`. If `steps` is given, it is returned unchanged (after
    validation against a known dataset size). Otherwise the dataset
    cardinality is returned when it is known. `None` is returned in two
    cases: (1) the size of the dataset is unknown and `steps` was not
    specified, or (2) this is multi-worker training and auto sharding is
    enabled.

  Raises:
    ValueError: If the dataset is infinite and `steps` is None, or if
      `steps * epochs` exceeds the known number of batches.
  """
  assert isinstance(dataset, data_types.DatasetV2)
  if model._in_multi_worker_mode():
    shard_policy = dataset.options().experimental_distribute.auto_shard_policy
    if shard_policy != options_lib.AutoShardPolicy.OFF:
      # If the dataset would be auto-sharded, we should not infer a local
      # steps_per_epoch due to possibly imbalanced sharding between workers.
      return None

  size = backend.get_value(cardinality.cardinality(dataset))
  size_is_known = size >= 0

  if steps is None:
    if size == cardinality.INFINITE:
      raise ValueError('When passing an infinitely repeating dataset, you '
                       'must specify the `%s` argument.' % (steps_name,))
    return size if size_is_known else None

  if size_is_known and steps * epochs > size:
    if epochs > 1:
      raise ValueError('The dataset you passed contains %s batches, but you '
                       'passed `epochs=%s` and `%s=%s`, which is a total of '
                       '%s steps. We cannot draw that many steps from this '
                       'dataset. We suggest to set `%s=%s`.' %
                       (size, epochs, steps_name, steps, steps * epochs,
                        steps_name, size // epochs))
    raise ValueError('The dataset you passed contains %s batches, but you '
                     'passed `%s=%s`. We cannot draw that many steps from '
                     'this dataset. We suggest to set `%s=%s`.' %
                     (size, steps_name, steps, steps_name, size))
  return steps


class ModelInputs(object):
  """Flattens model inputs into a named list while remembering their structure.

  Dict inputs are ordered by sorted key, and the keys serve as names. Any
  other input (list, tuple or a single entry) is flattened with
  `nest.flatten` and named 'input_1', 'input_2', ... .
  """

  def __init__(self, inputs):
    self._inputs = inputs
    self._is_dict = isinstance(inputs, dict)
    self._is_single_input = not isinstance(inputs, (list, tuple, dict))

    if self._is_dict:
      self._input_names = sorted(inputs.keys())
      self._flattened_inputs = [inputs[name] for name in self._input_names]
    else:
      self._flattened_inputs = nest.flatten(inputs)
      self._input_names = [
          'input_%d' % position
          for position in range(1, len(self._flattened_inputs) + 1)
      ]

  def get_input_names(self):
    """Returns the names the flattened inputs are keyed by.

    For dict inputs these are the sorted keys. For a list, tuple or single
    entry, the names are generated as 'input_%d' (1-based).
    """
    return self._input_names

  def get_symbolic_inputs(self, return_single_as_list=False):
    """Returns inputs to be set as self.inputs for a model.

    Lists and Python scalars are first converted to NumPy arrays; 1-D arrays
    gain a trailing axis. Every NumPy array and every `TensorSpec` is then
    replaced by a `backend.placeholder` named after the input. Its shape keeps
    all dimensions except the batch one, which becomes `None`. Floating
    NumPy dtypes are replaced by `backend.floatx()`. Other values are kept
    as-is.

    Args:
      return_single_as_list: If True, a single (non-container) input is
        returned inside a list rather than bare.

    Returns:
      A dict keyed by input name for dict inputs. For a single input, the
      input itself (unless `return_single_as_list`). Otherwise a list.
    """
    # TODO(karmel): There is a side-effect here where what you get
    # with as_list and as_dict depends on whether you have called this
    # method first, since it modifies in place.

    def _placeholder_shape(trailing_dims):
      # The placeholder shape is fixed except for the batch size. This is
      # suboptimal, but it is the best we can do with the info we have. The
      # user should call `model._set_inputs(placeholders)` to specify custom
      # placeholders if the need arises.
      shape = (None,) + tuple(trailing_dims)
      return (None, 1) if shape == (None,) else shape

    for idx, name in enumerate(self._input_names):
      value = self._flattened_inputs[idx]
      if isinstance(value, (list, float, int)):
        value = np.asarray(value)
        if value.ndim == 1:
          value = np.expand_dims(value, 1)

      if isinstance(value, np.ndarray):
        dtype = dtypes.as_dtype(value.dtype)
        value = backend.placeholder(
            shape=_placeholder_shape(value.shape[1:]),
            name=name,
            dtype=backend.floatx() if dtype.is_floating else dtype)
      elif isinstance(value, tensor_spec.TensorSpec):
        value = backend.placeholder(
            shape=_placeholder_shape(value.shape.as_list()[1:]),
            name=name,
            dtype=value.dtype)

      self._flattened_inputs[idx] = value

    if self._is_dict:
      return dict(zip(self._input_names, self._flattened_inputs))
    if not self._is_single_input or return_single_as_list:
      return self._flattened_inputs
    return self._flattened_inputs[0]

  def as_dict(self):
    """Yields `(name, input)` pairs in flattened order."""
    yield from zip(self._input_names, self._flattened_inputs)

  def as_list(self):
    """Returns the flattened inputs as a list."""
    return self._flattened_inputs


# Allow use of methods not exposed to the user.
# pylint: disable=protected-access


# pylint: enable=protected-access


def generic_output_names(outputs_list):
  """Returns default names 'output_1', ..., 'output_N' for N outputs."""
  return ['output_%d' % position
          for position in range(1, len(outputs_list) + 1)]


def should_run_validation(validation_freq, epoch):
  """Checks if validation should be run this epoch.

  Args:
    validation_freq: Integer or list. If an integer, specifies how many training
      epochs to run before a new validation run is performed. If a list,
      specifies the epochs on which to run validation. NumPy integer scalars
      are accepted as integers.
    epoch: Integer, the number of the training epoch just completed.

  Returns:
    Bool, True if validation should be run.

  Raises:
    ValueError: if `validation_freq` is an Integer and less than 1, or if
    it is neither an Integer nor a Container.
  """
  # `epoch` is 0-indexed internally but 1-indexed in the public API.
  one_indexed_epoch = epoch + 1

  # NumPy integer scalars (e.g. `np.int64`) are not `int` instances. Accept
  # them explicitly instead of rejecting them as a non-Container below.
  if isinstance(validation_freq, (int, np.integer)):
    if validation_freq < 1:
      raise ValueError('`validation_freq` can not be less than 1.')
    return bool(one_indexed_epoch % validation_freq == 0)

  if not isinstance(validation_freq, collections.abc.Container):
    raise ValueError('`validation_freq` must be an Integer or '
                     '`collections.abc.Container` (e.g. list, tuple, etc.)')
  return one_indexed_epoch in validation_freq


def split_training_and_validation_data(x, y, sample_weights, validation_split):
  """Splits array inputs into a training head and a validation tail.

  The number of samples is read from the first entry of `x` (its
  `shape[0]` if it has a `shape`, otherwise its `len`). The first
  `int(num_samples * (1 - validation_split))` samples go to training and the
  rest go to validation. `sample_weights` is split only when truthy;
  otherwise the validation weights are None.

  Returns:
    Tuple `(x, y, sample_weights, val_x, val_y, val_sample_weights)`.

  Raises:
    ValueError: If `x` contains symbolic tensors.
  """
  if has_symbolic_tensors(x):
    raise ValueError('If your data is in the form of symbolic tensors, '
                     'you cannot use `validation_split`.')

  first_input = x[0]
  if hasattr(first_input, 'shape'):
    num_samples = first_input.shape[0]
  else:
    num_samples = len(first_input)
  split_at = int(num_samples * (1. - validation_split))

  def _split(arrays):
    head = generic_utils.slice_arrays(arrays, 0, split_at)
    tail = generic_utils.slice_arrays(arrays, split_at)
    return head, tail

  x, val_x = _split(x)
  y, val_y = _split(y)
  val_sample_weights = None
  if sample_weights:
    sample_weights, val_sample_weights = _split(sample_weights)
  return x, y, sample_weights, val_x, val_y, val_sample_weights


def unpack_validation_data(validation_data, raise_if_ambiguous=True):
  """Unpack validation data based on its input type.

  The validation data is returned untouched (as `val_x`) in two cases:
  - it is a dataset, a dataset iterator or a `data_utils.Sequence`;
  - it has no `__len__`.
  Other inputs (e.g. a tuple of Numpy arrays or tensors) are unpacked into a
  tuple of 3: x, y and sample weights.

  Args:
    validation_data: dataset, dataset iterator, or numpy, tensor tuple.
    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
      parsed. Otherwise simply return validation_data, None, None and defer the
      decision to the caller.

  Returns:
    tuple of 3, (x, y, sample_weights) for numpy and tensor input.

  Raises:
    ValueError: If `raise_if_ambiguous` is True and `validation_data` has a
      length other than 2 or 3.
  """
  if (isinstance(validation_data, (iterator_ops.Iterator,
                                   iterator_ops.IteratorBase,
                                   data_types.DatasetV2,
                                   data_utils.Sequence))
      or not hasattr(validation_data, '__len__')):
    val_x = validation_data
    val_y = None
    val_sample_weight = None
  elif len(validation_data) == 2:
    try:
      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
      val_sample_weight = None
    except ValueError:
      val_x, val_y, val_sample_weight = validation_data, None, None
  elif len(validation_data) == 3:
    try:
      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
    except ValueError:
      val_x, val_y, val_sample_weight = validation_data, None, None
  else:
    if raise_if_ambiguous:
      # Wrap in a 1-tuple: `validation_data` is typically itself a tuple here,
      # and `'%s' % some_tuple` would raise TypeError instead of this error.
      raise ValueError(
          'When passing a `validation_data` argument, '
          'it must contain either 2 items (x_val, y_val), '
          'or 3 items (x_val, y_val, val_sample_weights), '
          'or alternatively it could be a dataset or a '
          'dataset iterator. '
          'However we received `validation_data=%s`' % (validation_data,))
    val_x, val_y, val_sample_weight = validation_data, None, None
  return val_x, val_y, val_sample_weight


class TrainingLoop:
  """Base interface wrapping the fit / evaluate / predict training logic.

  Concrete subclasses encapsulate how fit/eval/predict behave for a given
  kind of data input and model condition. Every method here raises
  `NotImplementedError`.

  A TrainingLoop is meant to be stateless: it holds no internal fields, so
  one instance can be reused with different models and inputs.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Trains `model` on the given inputs and targets.

    Subclasses must override this.
    """
    raise NotImplementedError

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Computes the loss and metric values of `model` in test mode.

    Subclasses must override this.
    """
    raise NotImplementedError

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs `model` on the inputs `x` in inference mode.

    Subclasses must override this.
    """
    raise NotImplementedError