# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

import functools
import sys
import time

import six

from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond as tf_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import distribute as distribute_types
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc


_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler(
    "/tensorflow/api/distribution_strategy/"
    "distributed_dataset_initialization_time_milliseconds",
    monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26),
    "Track the time (in milliseconds) to initialize distributed datasets.",
    "strategy", "workers")

_distributed_dataset_from_function_initialization_time_milliseconds = (
    monitoring.Sampler(
        "/tensorflow/api/distribution_strategy/"
        "distributed_dataset_from_function_initialization_time_milliseconds",
        monitoring.ExponentialBuckets(
            scale=1, growth_factor=2, bucket_count=26),
        "Track the time (in milliseconds) to initialize distributed datasets "
        "from function.",
        "strategy", "workers"))


def get_iterator_spec_from_dataset(strategy, dataset):
  """Returns an iterator spec from dataset function.

  This function constructs type spec for iterator obtained from
  iter(dataset).

  Args:
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    dataset: A tf.data.Dataset instance. If using a function that returns a
      tf.data.Dataset instance, pass dataset_fn.structured_outputs.

  Returns:
    A type_spec for iterator for dataset instance.

  """
  # pylint: disable=protected-access
  output_element_spec = dataset.element_spec
  if isinstance(dataset._type_spec,
                (DistributedDatasetSpec,
                 DistributedDatasetsFromFunctionSpec)):
    iterator_type_spec = DistributedIteratorSpec(
        strategy.extended._input_workers_with_options(),
        output_element_spec,
        strategy.extended._container_strategy(),
        options=None,
        cardinality=dataset.cardinality,
        enable_get_next_as_optional=True)
  else:
    if strategy.extended._num_gpus_per_worker:
      logging.warning(
          f"{strategy.extended._num_gpus_per_worker} GPUs "
          "are allocated per worker. Please use DistributedDataset by "
          "calling strategy.experimental_distribute_dataset or strategy."
          "distribute_datasets_from_function to make best use of GPU "
          "resources"
      )
    iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec)
  return iterator_type_spec
  # pylint: enable=protected-access
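

# Illustrative usage sketch (assumes `strategy` and a distributed dataset
# `dist_dataset` already exist; the returned spec can be used, for example, as
# a `tf.function` input signature):
#
#   iterator_spec = get_iterator_spec_from_dataset(strategy, dist_dataset)
#
#   @tf.function(input_signature=[iterator_spec])
#   def run_one_step(iterator):
#     ...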


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  # TODO(ishark): Remove option canonicalize_devices and make all the callers
  # pass canonicalized or raw device strings as relevant from strategy.
  def __init__(self,
               worker_device_pairs,
               canonicalize_devices=True):
    """Initialize an `InputWorkers` object.

    Args:
      worker_device_pairs: A sequence of pairs: `(input device, a tuple of
        compute devices fed by that input device)`.
      canonicalize_devices: Whether to canonicalize devices for workers fully or
        partially. If False, it will partially canonicalize devices by removing
        job and task.
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._canonicalize_devices = canonicalize_devices
    if canonicalize_devices:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize(d)
                for d in f)
          for _, f in self._worker_device_pairs)
    else:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize_without_job_and_task(d)
                for d in f)
          for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    return (self._worker_device_pairs, self._canonicalize_devices)

  def deserialize(self, serialized):
    return InputWorkers(*serialized)
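

# Illustrative sketch of the mapping `InputWorkers` represents (hypothetical
# device strings; a single input worker feeding two GPUs):
#
#   input_workers = InputWorkers(
#       worker_device_pairs=[
#           ("/job:worker/task:0/device:CPU:0",
#            ("/job:worker/task:0/device:GPU:0",
#             "/job:worker/task:0/device:GPU:1")),
#       ])
#   input_workers.num_workers                    # -> 1
#   input_workers.compute_devices_for_worker(0)  # -> canonicalized GPU devices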


def _calculate_replicas_with_values(strategy, input_workers, optional_list):
  """Calcualates the number of replicas that have values.

  Args:
    strategy: the `tf.distribute.Strategy`.
    input_workers: the `InputWorkers`.
    optional_list: a list of lists of `tf.experimental.Optional`. The values
      from each compute device, grouped by the input device.

  Returns:
    A scalar Tensor.
  """
  worker_has_values = []
  for worker, optionals in zip(input_workers.worker_devices, optional_list):
    with ops.device(worker):
      device_has_values = [
          math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
      ]
      worker_has_values.append(
          math_ops.reduce_sum(device_has_values, keepdims=True))
  client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
  if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
    global_has_values = strategy.reduce(
        reduce_util.ReduceOp.SUM, client_has_values, axis=None)
    return array_ops.reshape(global_has_values, [])
  else:
    return array_ops.reshape(client_has_values, [])


def _is_statically_shaped(element_spec):
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension.

  Args:
    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
      dataset of the iterator.

  Returns:
    True if the shape is static, False otherwise.
  """

  for spec in nest.flatten(element_spec):
    if isinstance(
        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
      # For a sparse or ragged tensor, we only check the first (batch)
      # dimension in order to decide whether to use get_next_as_optional.
      # This is because when these tensors get batched by the dataset, only
      # the batch dimension is set.
      if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
        return False
    else:
      for component in spec._flat_tensor_specs:  # pylint: disable=protected-access
        if not component.shape.is_fully_defined():
          return False
  return True
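

# Illustrative sketch of the distinction checked above (hypothetical specs):
#
#   _is_statically_shaped(tf.TensorSpec([32, 8], tf.float32))    # -> True
#   _is_statically_shaped(tf.TensorSpec([None, 8], tf.float32))  # -> False
#   # For ragged (and sparse) specs only the batch dimension is inspected:
#   _is_statically_shaped(
#       tf.RaggedTensorSpec([32, None], tf.int64, ragged_rank=1))  # -> True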


class DistributedIteratorBase(collections_abc.Iterator,
                              distribute_types.DistributedIteratorInterface):
  """Common implementation for all input iterators."""

  # pylint: disable=super-init-not-called
  def __init__(
      self,
      input_workers,
      iterators,
      strategy,
      cardinality,
      enable_get_next_as_optional,
      replica_order=None,
  ):
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional
    self._replica_order = replica_order

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    # Ideally get_next_as_optional() should be consistent with get_next(), but
    # we used to always do partial batch handling in get_next_as_optional(). We
    # are keeping this behavior for now until we understand the impact.

    # Skip partial batch handling when the dataset is infinite or empty, as
    # there won't be any partial batches in those cases. This gives the user
    # more static shapes as it avoids the tf.cond. Note that for empty datasets,
    # we can only skip in single client mode, as the dataset can be non-empty on
    # other workers.
    if self._cardinality == cardinality_lib.INFINITE:
      return optional_ops.Optional.from_value(
          self._get_next_no_partial_batch_handling())
    if (self._cardinality == 0 and
        not self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return optional_ops.Optional.empty(self._element_spec)

    optional_list = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        optional_list.append(self._iterators[i].get_next_as_optional_list())

    def _create_optional_with_dummy():
      value_list = _get_value_or_dummy(
          self._input_workers, optional_list, produce_dummy=True)

      if self._replica_order is not None:
        value_list = self._reorder_replicas(value_list)

      per_replica = _create_per_replica(value_list, self._strategy)
      return optional_ops.Optional.from_value(per_replica)

    def _create_empty_optional():
      return optional_ops.Optional.empty(self._element_spec)

    num_replicas_with_values = _calculate_replicas_with_values(
        self._strategy, self._input_workers, optional_list)

    return tf_cond.cond(
        num_replicas_with_values > 0,
        _create_optional_with_dummy,
        _create_empty_optional,
        strict=True)

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    with distribute_lib.enter_or_assert_strategy(
        self._strategy):
      if distribute_lib.get_replica_context() is not None:
        raise ValueError("next(iterator) should be called from outside of "
                         "replica_fn. e.g. strategy.run(replica_fn, "
                         "args=(next(iterator),))")

    if not self._enable_get_next_as_optional:
      return self._get_next_no_partial_batch_handling(name)

    optional_list = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        optional_list.append(self._iterators[i].get_next_as_optional_list())
    num_replicas_with_values = _calculate_replicas_with_values(
        self._strategy, self._input_workers, optional_list)

    def _value_or_dummy():
      value_list = _get_value_or_dummy(
          self._input_workers, optional_list, produce_dummy=True)

      if self._replica_order is not None:
        value_list = self._reorder_replicas(value_list)

      return _create_per_replica(value_list, self._strategy)

    def _eof():
      # Optional.get_value raises InvalidArgumentError when there's no value,
      # so we need to call GetNext to raise EOFError.
      return self._get_next_no_partial_batch_handling()

    return tf_cond.cond(
        num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)

  def _get_next_no_partial_batch_handling(self, name=None):
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(self._iterators[i].get_next_as_list(new_name))

    if self._replica_order is not None:
      replicas = self._reorder_replicas(replicas)

    return _create_per_replica(replicas, self._strategy)

  def _reorder_replicas(self, replicas):
    assert len(self._replica_order) == len(
        replicas
    ), "replica order size ({}) != replicas size ({})!".format(
        len(self._replica_order), len(replicas)
    )
    return [replicas[i] for i in self._replica_order]
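

# Illustrative usage sketch of the iterator protocol implemented above
# (assumes `strategy`, a distributed dataset `dist_dataset`, and a hypothetical
# per-replica `step_fn` are already in scope):
#
#   dist_iterator = iter(dist_dataset)
#
#   @tf.function
#   def run_one_step(iterator):
#     per_replica_batch = next(iterator)  # or iterator.get_next()
#     strategy.run(step_fn, args=(per_replica_batch,))
#
#   # Partial-batch-aware variant (eager): get_next_as_optional() returns a
#   # tf.experimental.Optional that is empty once all replicas run out of data.
#   optional_batch = dist_iterator.get_next_as_optional()
#   if optional_batch.has_value():
#     strategy.run(step_fn, args=(optional_batch.get_value(),))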


class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
  """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction."""

  __slots__ = [
      "_input_workers", "_element_spec", "_strategy", "_cardinality",
      "_enable_get_next_as_optional", "_options", "_canonicalize_devices"
  ]

  def __init__(
      self,
      input_workers,
      element_spec,
      strategy,
      options,
      cardinality=cardinality_lib.UNKNOWN,
      enable_get_next_as_optional=None,
      replica_order=None,
  ):
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only place where
    # _deserialize is called is when we save/restore using SavedModel.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    else:
      self._input_workers = input_workers
      self._element_spec = element_spec
      self._strategy = strategy
      self._cardinality = cardinality
      self._enable_get_next_as_optional = enable_get_next_as_optional
      self._options = options
      if self._strategy:
        self._canonicalize_devices = getattr(self._strategy,
                                             "_canonicalize_devices", True)
      else:
        self._canonicalize_devices = True
      self._replica_order = replica_order

  def _serialize(self):
    # We cannot serialize the strategy object so we convert it to an id that we
    # can use for comparison.
    return (self._input_workers.serialize(), self._element_spec,
            id(self._strategy), id(self._options))

  def _deserialize(self):
    raise ValueError(
        f"Deserialization is currently unsupported for {type(self)}.")

  def sanity_check_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if self._input_workers.serialize() != other._input_workers.serialize():
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))

  def is_subtype_of(self, other):
    """Returns True if `self` is subtype of `other`.

    Args:
      other: A `TypeSpec`.
    """
    try:
      self.sanity_check_type(other)
      nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
    except (TypeError, ValueError):
      return False

    self_elements = nest.flatten(self._element_spec)
    other_elements = nest.flatten(other._element_spec)  # pylint: disable=protected-access

    return all(
        self_element.is_subtype_of(other_element)
        for (self_element, other_element) in zip(self_elements, other_elements))

  def most_specific_common_supertype(self, others):
    """Returns the most specific supertype of `self` and `others`.

    Args:
      others: A Sequence of `TypeSpec`.

    Returns:
      The most specific supertype, or `None` if a supertype does not exist.
    """
    try:
      for other in others:
        self.sanity_check_type(other)
        nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
    except (TypeError, ValueError):
      return None

    self_elements = nest.flatten(self._element_spec)
    others_elements = [nest.flatten(other._element_spec) for other in others]  # pylint: disable=protected-access
    common_elements = [None] * len(self_elements)

    for i, self_element in enumerate(self_elements):
      common_elements[i] = self_element.most_specific_common_supertype(
          [other_elements[i] for other_elements in others_elements])
      if common_elements[i] is None:
        return None
    common_element_spec = nest.pack_sequence_as(self._element_spec,
                                                common_elements)
    return type(self)(
        self._input_workers,
        common_element_spec,
        self._strategy,
        self._options,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional)

  def _with_tensor_ranks_only(self):
    element_spec = nest.map_structure(
        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
        self._element_spec)
    return type(self)(
        self._input_workers,
        element_spec,
        self._strategy,
        self._options,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional)

  # TODO(b/206014848): Remove once names are not used.
  def _without_tensor_names(self):
    element_spec = nest.map_structure(
        lambda s: s._without_tensor_names(),  # pylint: disable=protected-access
        self._element_spec)
    return type(self)(
        self._input_workers,
        element_spec,
        self._strategy,
        self._options,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional)


class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedIterator`."""

  @property
  def value_type(self):
    return DistributedIterator

  @property
  def _component_specs(self):
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(
          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
                                           element_spec, self._options,
                                           self._canonicalize_devices))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(
        input_workers=self._input_workers,
        iterators=None,
        components=components,
        element_spec=self._element_spec,
        strategy=self._strategy,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(
        value._input_workers,
        value._element_spec,
        value._strategy,
        value._options,
        cardinality=value._cardinality,
        enable_get_next_as_optional=value._enable_get_next_as_optional)


class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(
      self,
      input_workers=None,
      iterators=None,
      strategy=None,
      components=None,
      element_spec=None,
      cardinality=cardinality_lib.UNKNOWN,
      enable_get_next_as_optional=False,
      options=None,
      replica_order=None,
  ):
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    error_message = ("Either `input_workers` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")
    self._options = options

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      self._strategy = strategy
      self._cardinality = cardinality
      self._enable_get_next_as_optional = enable_get_next_as_optional
      self._replica_order = replica_order
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator, self).__init__(
          input_workers,
          iterators,
          strategy,
          cardinality,
          enable_get_next_as_optional,
          replica_order,
      )

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(
        self._input_workers,
        self._element_spec,
        self._strategy,
        self._options,
        self._cardinality,
        self._enable_get_next_as_optional,
        self._replica_order,
    )


class _IterableInput(collections_abc.Iterable,
                     distribute_types.DistributedDatasetInterface):
  """Base class for iterable inputs for distribution strategies."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input."""
    iterator = iter(self)
    optional_data = iterator.get_next_as_optional()

    def cond(optional_data, state):
      del state  # Unused.
      return optional_data.has_value()

    def loop_body(optional_data, state):
      """Executes `reduce_fn` in a loop till the dataset is empty."""
      state = reduce_fn(state, optional_data.get_value())
      optional_data = iterator.get_next_as_optional()
      return optional_data, state

    optional_data, final_state = while_loop.while_loop(
        cond,
        loop_body, [optional_data, initial_state],
        parallel_iterations=1,
        return_same_structure=True)
    return final_state
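

# Illustrative sketch of `reduce` (assumes `dist_dataset` is a distributed
# dataset whose per-replica elements are scalars and `strategy` is in scope;
# per-replica values are combined with a cross-replica sum at every step):
#
#   total = dist_dataset.reduce(
#       tf.constant(0, dtype=tf.int64),
#       lambda state, per_replica: state + strategy.reduce(
#           tf.distribute.ReduceOp.SUM, per_replica, axis=None))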


class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedDataset."""

  @property
  def value_type(self):
    return DistributedDataset

  @property
  def _component_specs(self):
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, _ in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(dataset_ops.DatasetSpec(element_spec))
    return specs

  def _to_components(self, value):
    return value._cloned_datasets  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedDataset(
        input_workers=self._input_workers,
        strategy=self._strategy,
        components=components,
        element_spec=self._element_spec,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedDatasetSpec(
        value._input_workers,
        value._element_spec,
        value._strategy,
        value._options,
        enable_get_next_as_optional=value._enable_get_next_as_optional)
    # pylint: enable=protected-access


class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(
      self,
      input_workers,
      strategy,
      dataset=None,
      num_replicas_in_sync=None,
      input_context=None,
      components=None,
      element_spec=None,
      enable_get_next_as_optional=None,
      build=True,
      options=None,
      replica_order=None,
  ):
    """Distribute the dataset on all workers.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      dataset: `tf.data.Dataset` that will be used as the input source. Either
        dataset or components field should be passed when constructing
        DistributedDataset. Use this when constructing DistributedDataset from
        a new `tf.data.Dataset`. Use components when constructing using
        DistributedDatasetSpec.
      num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that the
        total batch size for each step (across all workers and replicas) adds up
        to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
      components: datasets when DistributedDataset is constructed from
        DistributedDatasetSpec. Either field dataset or components should be
        passed.
      element_spec: element spec for DistributedDataset when constructing from
        DistributedDatasetSpec. This will be used to set the element_spec for
        DistributedDataset and verified against element_spec from components.
      enable_get_next_as_optional: this is required when components is passed
        instead of dataset.
      build: whether to build underlying datasets when this object is created.
        This is only useful for `ParameterServerStrategy` now.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
      replica_order: the order of the replicas, which will be used to reorder
        the iterators to match the device order.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)
    if input_workers is None or strategy is None:
      raise ValueError("input_workers and strategy are required arguments")
    if dataset is not None and components is not None:
      raise ValueError("Only one of dataset or components should be present")
    if dataset is None and components is None:
      raise ValueError("At least one of dataset or components should be passed")

    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options
    self._input_context = input_context
    self._num_replicas_in_sync = num_replicas_in_sync
    self._replica_order = replica_order

    if dataset is not None:
      self._original_dataset = dataset
      self._built = False
      if build:
        self.build()
    else:
      if not build:
        raise ValueError(
            "When constructing DistributedDataset with components, build "
            "should not be False. This is an internal error. Please file a "
            "bug.")
      if enable_get_next_as_optional is None:
        raise ValueError(
            "When constructing DistributedDataset with components, " +
            "enable_get_next_as_optional should also be passed")
      self._cloned_datasets = components
      self._cardinality = _cardinality(self._cloned_datasets[0])
      self._enable_get_next_as_optional = enable_get_next_as_optional

      assert element_spec is not None
      if element_spec != _create_distributed_tensor_spec(
          self._strategy, self._cloned_datasets[0].element_spec):
        raise ValueError("Mismatched element_spec from the passed components")
      self._element_spec = element_spec

      self._built = True

  def build(self, dataset_to_replace=None):
    assert not self._built
    dataset = dataset_to_replace or self._original_dataset
    self._cardinality = _cardinality(dataset)
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, dataset, self._cardinality)
    distribute_start_time_ns = time.time_ns()
    self._create_cloned_datasets_from_dataset(dataset, self._input_context,
                                              self._input_workers,
                                              self._strategy,
                                              self._num_replicas_in_sync)
    if context.executing_eagerly():
      # Records the time to initialize the distributed dataset.
      context.async_wait()
      distribute_duration_ms = (time.time_ns() -
                                distribute_start_time_ns) // 1_000_000
      _distributed_dataset_initialization_time_milliseconds.get_cell(
          self._strategy.__class__.__name__,
          str(self._input_workers.num_workers)).add(distribute_duration_ms)
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, self._cloned_datasets[0].element_spec)
    self._built = True

  def auto_shard(self, num_shards, shard_ix):
    assert (
        len(self._cloned_datasets) == len(self._input_workers.worker_devices)
    ), (
        f"datasets: {len(self._cloned_datasets)}, "
        f"input workers: {len(self._input_workers.worker_devices)}"
    )
    sharded_datasets = []
    for i in range(len(self._input_workers.worker_devices)):
      with ops.colocate_with(self._cloned_datasets[i]._variant_tensor):  # pylint:disable=protected-access
        sharded_datasets.append(
            input_ops.auto_shard_dataset(
                self._cloned_datasets[i], num_shards, shard_ix,
                self._num_replicas_in_sync
                ))
    return DistributedDataset(
        self._input_workers,
        self._strategy,
        components=sharded_datasets,
        element_spec=self._element_spec,
        options=self._options,
        enable_get_next_as_optional=self._enable_get_next_as_optional)

  @property
  def cardinality(self):
    if not self._built:
      raise ValueError(
          "Cannot get the cardinality of a dataset that is not built")
    return self._cardinality

  def _create_cloned_datasets_from_dataset(self, dataset, input_context,
                                           input_workers, strategy,
                                           num_replicas_in_sync):
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker runs the entire preprocessing
    # pipeline and only receives its own shard of the dataset.

    # Additionally, we rebatch the dataset on each worker into
    # `num_replicas_in_sync` smaller batches to be distributed among that
    # worker's replicas, so that the batch size for a global step (across all
    # workers and replicas) adds up to the original dataset's batch size.
    if num_replicas_in_sync is not None and num_replicas_in_sync > 1:
      num_workers = input_context.num_input_pipelines if input_context else len(
          input_workers.worker_devices)
      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                         num_replicas_in_sync)
    else:
      rebatch_fn = None
    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      if rebatch_fn is not None:
        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id,
                                             num_replicas_in_sync)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          if rebatch_fn is not None:
            cloned_dataset = rebatch_fn(cloned_dataset, i)
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i,
              num_replicas_in_sync)
          self._cloned_datasets.append(cloned_dataset)

  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
    """Returns a callable that rebatches the input dataset.

    Args:
      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
      num_workers: An integer representing the number of workers to distribute
        `dataset` among.
      num_replicas_in_sync: An integer representing the number of replicas in
        sync across all workers.
    """
    if num_replicas_in_sync % num_workers:
      raise ValueError(
          "tf.distribute expects every worker to have the same number of "
          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
          "cannot be divided by `num_workers` ({})".format(
              num_replicas_in_sync, num_workers))

    num_replicas_per_worker = num_replicas_in_sync // num_workers
    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
      batch_size = distribute.compute_batch_size(dataset)

    def rebatch_fn(dataset, worker_index):
      try:

        def apply_rebatch():
          batch_sizes = distribute.batch_sizes_for_worker(
              batch_size, num_workers, num_replicas_per_worker, worker_index)
          return dataset.rebatch(batch_sizes).prefetch(num_replicas_per_worker)

        # pylint: disable=protected-access
        def apply_legacy_rebatch():
          return distribute._LegacyRebatchDataset(
              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)

        with ops.colocate_with(dataset._variant_tensor):
          return tf_cond.cond(
              math_ops.not_equal(batch_size, -1),
              true_fn=apply_rebatch,
              false_fn=apply_legacy_rebatch)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please see "
                  "the tf.distribute.Strategy guide. {}".format(
                      num_replicas_in_sync, e)),
              sys.exc_info()[2])
        else:
          raise

    return rebatch_fn

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")
    if not self._built:
      raise ValueError("To use this dataset, you need to pass this dataset to "
                       "ClusterCoordinator.create_per_worker_dataset.")

    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
                                   True)

    worker_iterators = _create_iterators_per_worker(
        self._cloned_datasets,
        self._input_workers,
        options=self._options,
        canonicalize_devices=canonicalize_devices)
    iterator = DistributedIterator(
        self._input_workers,
        worker_iterators,
        self._strategy,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before being passed to a multi-device function; add a sync
    # point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    return DistributedDatasetSpec(
        self._input_workers,
        self._element_spec,
        self._strategy,
        self._options,
        enable_get_next_as_optional=self._enable_get_next_as_optional)
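

# Illustrative sketch of the public entry point that yields a
# `DistributedDataset` (assumes `strategy` is a `tf.distribute.MirroredStrategy`
# and `step_fn` is a hypothetical per-replica function):
#
#   dataset = tf.data.Dataset.range(128).batch(16)
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#   for per_replica_batch in dist_dataset:
#     strategy.run(step_fn, args=(per_replica_batch,))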


class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedDatasetsFromFunction."""

  @property
  def value_type(self):
    return DistributedDatasetsFromFunction

  @property
  def _component_specs(self):
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, _ in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(dataset_ops.DatasetSpec(element_spec))
    return specs

  def _to_components(self, value):
    return value._datasets  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedDatasetsFromFunction(
        input_workers=self._input_workers,
        strategy=self._strategy,
        components=components,
        element_spec=self._element_spec,
        options=self._options)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedDatasetsFromFunctionSpec(
        input_workers=value._input_workers,
        element_spec=value._element_spec,
        strategy=value._strategy,
        options=value._options)


# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput,
                                      composite_tensor.CompositeTensor):
  """Inputs created from dataset function."""

  def __init__(
      self,
      input_workers,
      strategy,
      input_contexts=None,
      dataset_fn=None,
      options=None,
      components=None,
      element_spec=None,
      build=True,
      replica_order=None,
  ):
    """Makes an iterable from datasets created by the given function.

    Args:
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
        Either dataset_fn or components should be passed to construct
        DistributedDatasetsFromFunction. Use this when constructing
        DistributedDataset using a function. Use components when constructing
        using DistributedDatasetsFromFunctionSpec.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
      components: datasets when DistributedDatasetsFromFunction is constructed
        from DistributedDatasetsFromFunctionSpec. Only one of dataset_fn or
        components should be passed.
      element_spec: element spec for DistributedDatasetsFromFunction when
        constructing from DistributedDatasetsFromFunctionSpec. This will be
        used to set the element_spec for DistributedDatasetsFromFunction and
        verified against element_spec from components.
      build: whether to build underlying datasets when this object is created.
        This is only useful for `ParameterServerStrategy` now.
      replica_order: the order of the replicas, which will be used to reorder
        the iterators to match the device order.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)
    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options
    self._replica_order = replica_order
    if dataset_fn is not None and components is not None:
      raise ValueError("Only one of dataset_fn or components should be set")
    if dataset_fn is None and components is None:
      raise ValueError("At least one of dataset_fn or components should be set")

    if dataset_fn is not None:
      if input_workers.num_workers != len(input_contexts):
        raise ValueError(
            "Number of input workers (%d) is not the same as the number of "
            "input_contexts (%d)" %
            (input_workers.num_workers, len(input_contexts)))
      self._input_contexts = input_contexts
      self._num_replicas_in_sync = self._input_contexts[0].num_replicas_in_sync
      self._dataset_fn = dataset_fn
      self._built = False
      if build:
        self.build()
    else:
      if element_spec is None:
        raise ValueError(
            "element_spec should also be passed when passing components")
      if not build:
        raise ValueError(
            "When constructing DistributedDatasetsFromFunction with "
            "components, build should not be False. This is an internal "
            "error. Please file a bug.")
      self._element_spec = element_spec
      self._datasets = components
      self._num_replicas_in_sync = None
      self._built = True
      self._cardinality = _cardinality(self._datasets[0])
      self._enable_get_next_as_optional = _enable_get_next_as_optional(
          self._strategy, self._datasets[0], self._cardinality)

  def build(self):
    assert not self._built
    distribute_start_time_ns = time.time_ns()
    self._datasets, element_spec = (
        _create_datasets_from_function_with_input_context(
            self._input_contexts, self._input_workers, self._dataset_fn))
    if context.executing_eagerly():
      # Records the time to initialize the distributed dataset.
      context.async_wait()
      distribute_duration_ms = (time.time_ns() -
                                distribute_start_time_ns) // 1_000_000
      _distributed_dataset_from_function_initialization_time_milliseconds.get_cell(
          self._strategy.__class__.__name__,
          str(self._input_workers.num_workers)).add(distribute_duration_ms)

    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, element_spec)
    self._cardinality = _cardinality(self._datasets[0])
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, self._datasets[0], self._cardinality)
    self._built = True

  def auto_shard(self, num_shards, shard_ix):
    assert (
        len(self._datasets) == len(self._input_workers.worker_devices)
    ), (
        f"datasets: {len(self._datasets)}, "
        f"input workers: {len(self._input_workers.worker_devices)}"
    )
    sharded_datasets = []
    for i in range(len(self._input_workers.worker_devices)):
      with ops.colocate_with(self._datasets[i]._variant_tensor):  # pylint: disable=protected-access
        sharded_datasets.append(
            input_ops.auto_shard_dataset(
                self._datasets[i], num_shards, shard_ix,
                self._num_replicas_in_sync
            )
        )
    return DistributedDatasetsFromFunction(self._input_workers, self._strategy,
                                           components=sharded_datasets,
                                           element_spec=self._element_spec,
                                           options=self._options)

  @property
  def cardinality(self):
    if not self._built:
      raise ValueError(
          "Cannot get the cardinality of a dataset that is not built")
    return self._cardinality

  def __iter__(self):
    if not (ops.executing_eagerly_outside_functions() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    if not self._built:
      raise ValueError("You need to use this dataset in "
                       "ClusterCoordinator.create_per_worker_dataset.")

    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
                                   True)

    iterators = _create_iterators_per_worker(
        self._datasets,
        self._input_workers,
        options=self._options,
        canonicalize_devices=canonicalize_devices)
    iterator = DistributedIterator(
        input_workers=self._input_workers,
        iterators=iterators,
        strategy=self._strategy,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before being passed to a multi-device function; add a
    # sync point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    return DistributedDatasetsFromFunctionSpec(self._input_workers,
                                               self._element_spec,
                                               self._strategy, self._options)
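

# Illustrative sketch of the public entry point that yields a
# `DistributedDatasetsFromFunction` (assumes `strategy` and a hypothetical
# per-replica `step_fn` are in scope):
#
#   def dataset_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(64)
#     ds = tf.data.Dataset.range(1000).batch(batch_size)
#     return ds.shard(input_context.num_input_pipelines,
#                     input_context.input_pipeline_id)
#
#   dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
#   for per_replica_batch in dist_dataset:
#     strategy.run(step_fn, args=(per_replica_batch,))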


def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if hasattr(spec, "_create_empty_value"):
      # Type spec may overwrite default dummy values behavior by declaring the
      # `_create_empty_value(self)` method. This method must return a value
      # compatible with the type spec with batch dimensions set to 0 or fail if
      # such a value does not exist. This allows a composite tensor to customize
      # dummy values creation as, in general, its dummy value is not composed
      # from dummy components (e.g. `row_splits` tensor of a RaggedTensor is
      # never allowed to be empty). See b/183969859 for more discussions.
      # TODO(b/186079336): reconsider CompositeTensor support.
      return spec._create_empty_value()  # pylint: disable=protected-access

    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = spec._shape[:1].concatenate(
          spec._shape[(1 + spec._ragged_rank):])
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = spec.shape
      feature_type = spec.dtype
    # Ideally we should set the batch dimension to 0. However, since in
    # DistributionStrategy we don't know the batch dimension, we try to guess
    # it as well as possible. If the feature has unknown dimensions, we set
    # them to 0. If the feature shape is already static, we guess that the
    # first dimension is the batch dimension and set it to 0.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)
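

# For example (illustrative): a `tf.TensorSpec([None, 5], tf.float32)` yields a
# dummy tensor of shape (0, 5); a fully static `tf.TensorSpec([8, 5])` also
# yields shape (0, 5), because the first dimension is assumed to be the batch
# dimension and is set to 0.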


def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
  """Returns the value of the optionals or dummy values.

  Args:
    input_workers: the `InputWorkers`.
    optional_list: a list of lists of `tf.experimental.Optional`. The values
      from each compute device, grouped by the input device.
    produce_dummy: a bool. Whether to produce dummy tensors when the optional
      doesn't have a value.

  Returns:
    A flattened list of Tensors.

  """
  value_list = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      devices = input_workers.compute_devices_for_worker(i)
      for j, device in enumerate(devices):
        with ops.device(device):
          if produce_dummy:
            # pylint: disable=cell-var-from-loop
            value_list.append(
                tf_cond.cond(
                    optional_list[i][j].has_value(),
                    lambda: optional_list[i][j].get_value(),  # pylint: disable=unnecessary-lambda
                    lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
                    strict=True,
                ))
            # pylint: enable=cell-var-from-loop
          else:
            value_list.append(optional_list[i][j].get_value())
  return value_list


class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices, options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    A `MultiDeviceIterator`  or `OwnedMultiDeviceIterator` is used to prefetch
    input to the devices on the given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      options: options.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._options = options
    self._make_iterator()

  def _make_iterator(self):
    raise NotImplementedError("must be implemented in descendants")

  def _format_data_list_with_options(self, data_list):
    """Change the data in to a list type if required.

    The OwnedMultiDeviceIterator returns the list data type,
    while the PER_REPLICA iterator (when used with prefetch disabled)
    returns without the enclosed list. This is to fix the inconsistency.
    Args:
      data_list: data_list
    Returns:
      list
    """
    if (self._options and self._options.experimental_replication_mode ==
        InputReplicationMode.PER_REPLICA and
        not self._options.experimental_fetch_to_device):
      return [data_list]
    else:
      return data_list

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      if _should_use_multi_device_iterator(self._options):
        return self._iterator.get_next(device)
      else:
        return self._iterator.get_next()

  def get_next_as_list(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than
    get_next_as_optional_list(), but it raises EOFError if any of the devices
    doesn't get any data.

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      return self._format_data_list_with_options(self._iterator.get_next())

  def get_next_as_optional_list(self):
    with ops.device(self._worker):
      return self._format_data_list_with_options(
          self._iterator.get_next_as_optional())


class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = [
      "_worker", "_devices", "_element_spec", "_options",
      "_canonicalize_devices"
  ]

  def __init__(self, worker, devices, element_spec, options,
               canonicalize_devices=True):
    self._worker = worker
    if canonicalize_devices:
      self._devices = tuple(device_util.canonicalize(d) for d in devices)
    else:
      self._devices = tuple(
          device_util.canonicalize_without_job_and_task(d) for d in devices)
    self._element_spec = element_spec
    # `self._options` is intentionally non-`None` for proper serialization.
    self._options = (options if options is not None else
                     distribute_lib.InputOptions())
    self._canonicalize_devices = canonicalize_devices

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    return (self._worker, self._devices, self._element_spec, self._options,
            self._canonicalize_devices)

  def _get_multi_device_iterator_spec(self, specs):
    device_scope = device_util.canonicalize(self._worker, device_util.current())
    host_device = device_util.get_host_for_device(device_scope)
    # The `source_device` used when creating the iterator governs the worker
    # device recorded in the iterator spec.
    worker = host_device
    specs.append(
        multi_device_iterator_ops.MultiDeviceIteratorSpec(
            self._devices, worker, element_spec=self._element_spec))

  @property
  def _component_specs(self):
    specs = []
    if _should_use_multi_device_iterator(self._options):
      self._get_multi_device_iterator_spec(specs)
    else:
      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
    return specs

  def _to_components(self, value):
    return [value._iterator]  # pylint: disable=protected-access

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec,
        options=self._options,
        canonicalize_devices=self._canonicalize_devices)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
                                            value._element_spec, value._options,
                                            value._canonicalize_devices)


class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self,
               dataset=None,
               worker=None,
               devices=None,
               components=None,
               element_spec=None,
               options=None,
               canonicalize_devices=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    python object. Once we go out of scope of the python object or return from
    a tf.function the underlying iterator resource is deleted.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents
        the type specification of elements of the iterator.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
      canonicalize_devices: Whether to canonicalize devices for workers fully
        or partially. If False, it will partially canonicalize devices by
        removing job and task.
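
    For example (an illustrative sketch; the device strings below are
    hypothetical):

      iterator = _SingleWorkerOwnedDatasetIterator(
          dataset=dataset,
          worker="/job:worker/replica:0/task:0",
          devices=("/job:worker/replica:0/task:0/device:GPU:0",))

    Alternatively, pass `components` together with `element_spec` (and no
    `dataset`) to rebuild an iterator from its serialized parts, which is what
    `_SingleWorkerDatasetIteratorSpec._from_components` does.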
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    self._options = options
    self._canonicalize_devices = canonicalize_devices
    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator,
            self).__init__(dataset, worker, devices, self._options)

  def _create_owned_multi_device_iterator(self):
    # If the worker devices are already canonicalized, canonicalizing again
    # would have no impact.
    # For strategies running on remote workers, such as PS Strategy, the device
    # scope will be derived from the current worker if used under init_scope().
    if not ops.inside_function():
      device_scope = device_util.canonicalize(self._worker,
                                              device_util.current())
      host_device = device_util.get_host_for_device(device_scope)
    else:
      # In general, iterators should not be created within tf.functions. For
      # exact visitation guarantee solutions for parameter server training,
      # however, we do create iterators within the tf.functions that are
      # dispatched to workers. In these cases, the traced device must match the
      # runtime device. Since tracing occurs on the chief, we do not want to use
      # the current device scope, which would be the chief, but rather use the
      # relative worker device scope explicitly.
      device_scope, host_device = self._worker, self._worker
    with ops.device(device_scope):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset,
            self._devices,
            source_device=host_device,
            max_buffer_size=self._options
            .experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset, self._devices, source_device=host_device)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    if _should_use_multi_device_iterator(self._options):
      self._create_owned_multi_device_iterator()
    else:
      with ops.device(self._worker):
        self._iterator = iter(self._dataset)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec, self._options,
                                            self._canonicalize_devices)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)


def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 options=None,
                                 canonicalize_devices=False):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerOwnedDatasetIterator(
          dataset=worker_datasets[i],
          worker=worker,
          devices=worker_devices,
          options=options,
          canonicalize_devices=canonicalize_devices)
      iterators.append(iterator)
  return iterators


def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      datasets.append(dataset)
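  # All per-worker datasets are assumed to produce elements with the same
  # structure, so the element spec of the last dataset created is returned for
  # the whole group.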
  return datasets, dataset.element_spec


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
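  # E.g. (illustrative) `ds.map(f).batch(32).prefetch(2)` unwraps to the
  # underlying `BatchDataset`, whereas `ds.batch(32).map(f)` hits the
  # ValueError below, because `map` follows the batch operation.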
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` "
      "`map_and_batch` need to be the last operations on the dataset. "
      "The batch operations can be followed by a prefetch.")


def _get_batched_dataset_attributes(d):
  """Get `batch_size`, `drop_remainder` of dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tf_type(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tf_type(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need
  # to walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer should be obtained from the original dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


def _should_use_multi_device_iterator(options):
  """Determine whether to use multi_device_iterator_ops."""
  if (options is None or
      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
      or
      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
       and options.experimental_fetch_to_device)):
    return True
  return False


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing
  non-tensor outputs. In the future it will be augmented to support other use
  cases such as outputting every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context; optional in a cross-replica context. When present, the
        outputs from all the replicas are reduced using the current
        distribution strategy's `reduce` method. Hence, the type of `output`
        must be what's supported by the corresponding `reduce` method. For
        example, if using MirroredStrategy and reduction is set, output must
        be a `PerReplica` value. The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary so that it can later be
        determined whether the outputs were already reduced.
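
    For example (an illustrative sketch; `ctx` is this context object and
    `loss` is a per-replica value produced inside the step function):

      ctx.set_last_step_output(
          "mean_loss", loss, reduce_op=tf.distribute.ReduceOp.MEAN)

    After the last step, `ctx.last_step_outputs["mean_loss"]` holds the value
    reduced across replicas.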
    """
    if distribute_lib.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribute_lib.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribute_lib.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribute_lib.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the values
        # in a list, as reduction doesn't make sense on non-tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribute_lib.get_replica_context().merge_call(
          merge_fn, args=(output,))


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
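
  For example (illustrative), with a two-replica `MirroredStrategy` and
  `tensor_spec = tf.TensorSpec([None, 3], tf.float32)`, the result is a
  `PerReplicaSpec` holding one copy of that `tf.TensorSpec` per replica.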
  """
  num_replicas = len(strategy.extended.worker_devices)

  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
  # the tensor_spec as is, since we don't wrap the output with PerReplica in
  # this case.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  # For other cases we assume the input to tf.function is a per replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)


def _replace_per_replica_spec(spec, i):
  """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
  if isinstance(spec, values.PerReplicaSpec):
    return spec._value_specs[i]  # pylint: disable=protected-access
  else:
    return spec


def _cardinality(dataset):
  """Returns the cardinality of the dataset."""
  if context.executing_eagerly():
    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
      return dataset.cardinality().numpy()
  return cardinality_lib.UNKNOWN


def _enable_get_next_as_optional(strategy, dataset, cardinality):
  """Returns whether to enable using partial batch handling."""
  # TODO(b/133073708): we currently need a flag to control the usage because
  # there is a performance difference between get_next() and
  # get_next_as_optional(). And we only enable get_next_as_optional when the
  # output shapes are not static.
  #
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when user passed input_fn instead of dataset.
  if not getattr(
      strategy.extended, "enable_partial_batch_handling",
      getattr(strategy.extended, "experimental_enable_get_next_as_optional",
              False)):
    return False

  # If the dataset is infinite, we don't need to enable last partial batch
  # support. Note that we can only evaluate the cardinality of the dataset in
  # eager mode.
  if cardinality == cardinality_lib.INFINITE:
    return False

  return not _is_statically_shaped(
      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic in multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.

  For single-client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.

  Returns:
    a structure of PerReplica.
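
  For example (illustrative), with two replicas `value_list = [t0, t1]` is
  wrapped into a single `PerReplica` containing both tensors, whereas a
  single-replica, single-worker strategy returns `t0` unwrapped (see
  `_always_wrap`).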

  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas


def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues."""
  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
      strategy.extended.worker_devices) > 1


def _rebatch_as_dynamic(per_replica_spec):
  """Rebatch the spec to have a dynamic batch dimension."""
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _rebatch(spec):
    # Rebatch if possible.
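    # E.g. (illustrative): a `tf.TensorSpec([32, 3])` becomes
    # `tf.TensorSpec([None, 3])`; specs that cannot be unbatched are returned
    # unchanged.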
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      pass
    return spec

  return values.PerReplicaSpec(
      *nest.map_structure(_rebatch, per_replica_spec._value_specs))
  # pylint: enable=protected-access


def _ag_enumerate_not_implemented(s, unused_start):
  msg = (
      f"enumerate not supported with {s.__class__.__name__} types within "
      "tf.functions. Use a for loop over the dataset and keep a separate "
      "counter instead."
  )
  raise NotImplementedError(msg)


py_builtins.enumerate_registry.register(
    DistributedIterator, _ag_enumerate_not_implemented
)
py_builtins.enumerate_registry.register(
    DistributedDataset, _ag_enumerate_not_implemented
)