# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.map()`."""
import collections
import dataclasses
import functools
import threading
import time
from typing import Callable
import warnings

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_sanitizers
from tensorflow.python import tf2
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import global_shuffle_op
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_case
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test

try:
  import attr  # pylint:disable=g-import-not-at-top
except ImportError:
  attr = None


def _test_combinations_with_mode_v1(mode):
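  """Returns TF1 test combinations exercising both `Dataset.map` and
  `Dataset.map_with_legacy_function`."""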

  def new_map_fn(dataset, *args, **kwargs):
    return dataset.map(*args, **kwargs)

  def legacy_map_fn(dataset, *args, **kwargs):
    return dataset.map_with_legacy_function(*args, **kwargs)

  new_map_combinations = combinations.combine(
      tf_api_version=1,
      mode=mode,
      apply_map=combinations.NamedObject("map_fn", new_map_fn))

  legacy_map_combinations = combinations.combine(
      tf_api_version=1,
      mode=mode,
      apply_map=combinations.NamedObject("legacy_map_fn", legacy_map_fn))

  return new_map_combinations + legacy_map_combinations


def _test_combinations_with_mode_v2(mode):
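  """Returns TF2 test combinations, which exercise only `Dataset.map`."""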

  def new_map_fn(dataset, *args, **kwargs):
    return dataset.map(*args, **kwargs)

  return combinations.combine(
      tf_api_version=2,
      mode=mode,
      apply_map=combinations.NamedObject("map_fn", new_map_fn))


def _test_combinations_with_mode(mode):
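  """Returns test combinations for the given mode across TF1 and TF2."""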
  return _test_combinations_with_mode_v1(
      mode) + _test_combinations_with_mode_v2(mode)


def _test_combinations():
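  """Returns test combinations covering both eager and graph modes."""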
  return _test_combinations_with_mode("eager") + _test_combinations_with_mode(
      "graph")


def _short_circuit_test_cases():
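  """Returns map functions that merely forward, duplicate, swap, or project
  their inputs, making them candidates for short-circuit evaluation."""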
  cases = [
      ("Identity", None, lambda x: x),
      ("Replicate", None, lambda x: (x, x)),
      ("Swap", (None, None), lambda x, y: (y, x)),
      ("Project", (None, None), lambda x, y: x)
  ]

  def reduce_fn(x, y):
    name, structure, fn = y
    return x + combinations.combine(
        structure=structure, fn=combinations.NamedObject(name, fn))

  return functools.reduce(reduce_fn, cases, [])


class Foo:
  """Dummy class used for invalid return value tests."""

  def __init__(self):
    pass


@dataclasses.dataclass
class MyDataclass:
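  """Dataclass implementing the `__tf_flatten__`/`__tf_unflatten__` protocol,
  which lets tf.data flatten it into components and reconstruct it."""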
  value1: tensor.Tensor
  value2: tensor.Tensor

  def __tf_flatten__(self):
    metadata = tuple()
    components = (self.value1, self.value2)
    return metadata, components

  @classmethod
  def __tf_unflatten__(cls, metadata, components):
    del metadata
    return cls(value1=components[0], value2=components[1])


@dataclasses.dataclass
class MaskedTensor:
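  """Dataclass carrying `mask` as metadata and `value` as a tensor component."""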
  mask: bool
  value: tensor.Tensor

  def __tf_flatten__(self):
    metadata = (self.mask,)
    components = (self.value,)
    return metadata, components

  @classmethod
  def __tf_unflatten__(cls, metadata, components):
    mask = metadata[0]
    value = components[0]
    return MaskedTensor(mask=mask, value=value)


@dataclasses.dataclass
class NestedMaskedTensor:
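  """Dataclass holding a nested `MaskedTensor`, for nested-structure tests."""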
  mask: bool
  value: MaskedTensor

  def __tf_flatten__(self):
    metadata = (self.mask,)
    components = (self.value,)
    return metadata, components

  @classmethod
  def __tf_unflatten__(cls, metadata, components):
    mask = metadata[0]
    value = components[0]
    return NestedMaskedTensor(mask=mask, value=value)

  def __eq__(self, other):
    return self.mask == other.mask and self.value == other.value


class MapTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _map_dataset_factory(self, components, apply_map, count):
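    """Builds a dataset that slices `components`, squares each element, and
    repeats `count` times, checking the inferred output shapes."""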

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = apply_map(dataset, _map_fn).repeat(count)
    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    return dataset

  @combinations.generate(_test_combinations())
  def testMapDataset(self, apply_map):
    """Test an dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    # Test single-threaded access to the iterator.
    get_next = self.getNext(
        self._map_dataset_factory(components, apply_map, count=14))
    for _ in range(14):
      for i in range(7):
        result = self.evaluate(get_next())
        for component, result_component in zip(components, result):
          self.assertAllEqual(component[i]**2, result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  # TODO(b/117581999): add eager coverage
  @combinations.generate(_test_combinations_with_mode("graph"))
  def testMapDatasetMultiThreaded(self, apply_map):
    # Test multi-threaded access to the same iterator.
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    get_next = self.getNext(
        self._map_dataset_factory(components, apply_map, count=18))
    results = []
    with self.cached_session() as sess:
      def iterator_thread():
        while True:
          try:
            results.append(sess.run(get_next()))
          except errors.OutOfRangeError:
            return
      threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # `results` will contain the elements of components**2, each repeated
      # 18 times, but in a non-deterministic order. Sort the results, and
      # assert that each element of components**2 is produced 18 times.
      results.sort(key=lambda x: x[0])
      for i in range(7):
        for j in range(18):
          for component, result_component in zip(components,
                                                 results[i * 18 + j]):
            self.assertAllEqual(component[i]**2, result_component)

  def _parallel_map_dataset_factory(self, components, apply_map, count,
                                    num_parallel_calls, buffer_size):
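    """Like `_map_dataset_factory`, but maps with `num_parallel_calls` and
    prefetches `buffer_size` elements."""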

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(buffer_size).repeat(count)

    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    return dataset

  @combinations.generate(
      combinations.times(
          _test_combinations(),
          combinations.combine(num_parallel_calls=1, buffer_size=1) +
          combinations.combine(num_parallel_calls=1, buffer_size=2) +
          combinations.combine(num_parallel_calls=2, buffer_size=2) +
          combinations.combine(num_parallel_calls=2, buffer_size=4) +
          combinations.combine(num_parallel_calls=8, buffer_size=8) +
          combinations.combine(num_parallel_calls=8, buffer_size=16)))
  def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size):
    """Test an dataset that maps a TF function across its input elements."""

    # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    # Test single-threaded access to the iterator.
    get_next = self.getNext(
        self._parallel_map_dataset_factory(components, apply_map, 14,
                                           num_parallel_calls, buffer_size))
    for _ in range(14):
      for i in range(7):
        result = self.evaluate(get_next())
        for component, result_component in zip(components, result):
          self.assertAllEqual(component[i]**2, result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  # TODO(b/117581999): add eager coverage
  @combinations.generate(
      combinations.times(
          _test_combinations_with_mode("graph"),
          combinations.combine(num_parallel_calls=1, buffer_size=1) +
          combinations.combine(num_parallel_calls=1, buffer_size=2) +
          combinations.combine(num_parallel_calls=2, buffer_size=2) +
          combinations.combine(num_parallel_calls=2, buffer_size=4) +
          combinations.combine(num_parallel_calls=8, buffer_size=8) +
          combinations.combine(num_parallel_calls=8, buffer_size=16)))
  def testParallelMapDatasetMultiThreaded(self, apply_map, num_parallel_calls,
                                          buffer_size):

    # Test multi-threaded access to the same iterator.
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    get_next = self.getNext(
        self._parallel_map_dataset_factory(components, apply_map, 18,
                                           num_parallel_calls, buffer_size))
    results = []
    with self.cached_session() as sess:

      def iterator_thread():
        while True:
          try:
            results.append(sess.run(get_next()))
          except errors.OutOfRangeError:
            return

      threads = [self.checkedThread(target=iterator_thread) for _ in range(64)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # `results` will contain the elements of components**2, each repeated
      # 18 times, but in a non-deterministic order. Sort the results, and
      # assert that each element of components**2 is produced 18 times.
      results.sort(key=lambda x: x[0])
      for i in range(7):
        for j in range(18):
          for component, result_component in zip(components,
                                                 results[i * 18 + j]):
            self.assertAllEqual(component[i]**2, result_component)

  @combinations.generate(_test_combinations())
  def testImplicitDisposeParallelMapDataset(self, apply_map):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(1000).
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    dataset = self._parallel_map_dataset_factory(components, apply_map, 1000,
                                                 100, 100)
    # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
    dataset = dataset.prefetch(100)
    get_next = self.getNext(dataset)

    for _ in range(3):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testParallelMapUnspecifiedOutputSize(self, apply_map):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = apply_map(
        dataset,
        lambda x: array_ops.check_numerics(x, "message"),
        num_parallel_calls=2)
    get_next = self.getNext(dataset)

    for _ in range(3):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testParallelMapError(self, apply_map):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = apply_map(
        dataset,
        lambda x: array_ops.check_numerics(x, "message"),
        num_parallel_calls=2)
    get_next = self.getNext(dataset)

    for _ in range(3):
      self.evaluate(get_next())
    # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())
    self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testPrefetchError(self, apply_map):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = apply_map(
        dataset, lambda x: array_ops.check_numerics(x, "message")).prefetch(2)
    get_next = self.getNext(dataset)

    for _ in range(3):
      self.evaluate(get_next())
    # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())
    self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testCaptureIterator(self, apply_map):

    def _build_ds(iterator):

      def _map_fn(x):
        get_next = iterator.get_next()
        return x * get_next

      return apply_map(dataset_ops.Dataset.range(10), _map_fn)

    def _build_graph():
      if context.executing_eagerly():
        captured_iterator = iter(dataset_ops.Dataset.range(10))
      else:
        captured_iterator = dataset_ops.make_initializable_iterator(
            dataset_ops.Dataset.range(10))
      ds = _build_ds(captured_iterator)
      return captured_iterator, ds

    captured_iter, ds = _build_graph()
    if not context.executing_eagerly():
      self.evaluate(captured_iter.initializer)
    get_next = self.getNext(ds, requires_initialization=True)
    for i in range(10):
      self.assertEqual(i * i, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testCaptureHashTable(self, apply_map):
    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    dataset = apply_map(input_sentences,
                        lambda x: string_ops.string_split([x]).values)
    dataset = apply_map(dataset, table.lookup)

    get_next = self.getNext(dataset, requires_initialization=True)

    self.evaluate(table.initializer)
    self.evaluate(get_next())
    self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations_with_mode("graph"))
  def testCaptureQueue(self, apply_map):
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1)
    dataset = apply_map(dataset, lambda _: queue.dequeue())

    get_next = self.getNext(dataset, requires_initialization=True)
    self.evaluate(enqueue_op)
    self.evaluate(close_op)

    for element in elements:
      self.assertEqual(element, self.evaluate(get_next()))
    # When the map function in `MapDataset` raises an OutOfRange error, TF1
    # and TF2 behave differently: TF1 raises an OutOfRangeError to signal the
    # end of the sequence, while TF2 raises an InvalidArgumentError. This
    # behavior is controlled by the `preserve_cardinality` argument of the
    # `map` transformation, which is `True` in TF2 and `False` in TF1 for
    # backward compatibility.
    if tf2.enabled():
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(get_next())
    else:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  # TODO(b/117581999): Possible deadlock in eager mode, debug.
  @combinations.generate(_test_combinations_with_mode_v1("graph"))
  def testCaptureSameResourceMultipleTimes(self, apply_map):
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")
    queue_2 = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")
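    # Both queues use the same `shared_name`, so they alias the same
    # underlying queue resource.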

    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(-1)
    dataset = apply_map(dataset, lambda _: (queue.dequeue(), queue_2.dequeue()))

    self.evaluate(enqueue_op)
    self.evaluate(close_op)
    get_next = self.getNext(dataset, requires_initialization=True)
    for i in range(100):
      self.assertCountEqual([elements[i * 2], elements[i * 2 + 1]],
                            self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testSeededStatefulOperatorIsProperlyStateful(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
    fn = lambda _: random_ops.random_uniform((), seed=11)
    dataset = apply_map(dataset, fn).batch(2)

    get_next = self.getNext(dataset, requires_initialization=True)
    random_values = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        random_values.extend(self.evaluate(get_next()))
    self.assertLen(random_values, 10)
    self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)

    get_next = self.getNext(dataset, requires_initialization=True)
    random_values_2 = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        random_values_2.extend(self.evaluate(get_next()))

    # Randomness is repeatable given same seed
    self.assertAllClose(random_values, random_values_2)

  @combinations.generate(_test_combinations())
  def testStatefulMapKeepsStateAcrossIterators(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
    fn = lambda _: random_ops.random_uniform((), seed=11)
    dataset = apply_map(dataset, fn).repeat(1000).batch(10)

    get_next = self.getNext(dataset)
    random_values = self.evaluate(get_next())

    # Assert that one of the next 99 batches yielded by the iterator is
    # different from the first.
    i = 0
    while i < 99:
      if np.any(random_values != self.evaluate(get_next())):
        break
      i += 1
    self.assertLess(i, 99)

  @combinations.generate(_test_combinations())
  def testStatefulOperationInShortCircuit(self, apply_map):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)

    def increment_fn(x):
      counter_var.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, increment_fn)
    options = options_lib.Options()
    options.experimental_optimization.inject_prefetch = False
    dataset = dataset.with_options(options)

    get_next = self.getNext(dataset, requires_initialization=True)

    self.evaluate(counter_var.initializer)
    for i in range(10):
      self.assertEqual(i, self.evaluate(counter_var))
      self.assertEqual(i, self.evaluate(get_next()))
    self.assertEqual(10, self.evaluate(counter_var))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(10, self.evaluate(counter_var))

  @combinations.generate(_test_combinations())
  def testMapDict(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: {"foo": x * 2, "bar": x**2})
    dataset = apply_map(dataset, lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset, expected_output=[i * 2 + i**2 for i in range(10)])

  @combinations.generate(_test_combinations())
  def testMapDataclass(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: MyDataclass(value1=x, value2=2 * x))
    dataset = apply_map(dataset, lambda x: x.value1 + x.value2)
    self.assertDatasetProduces(
        dataset,
        expected_output=[3 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapMaskedTensor(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x))
    dataset = apply_map(dataset, lambda x: 3 * x.value)
    self.assertDatasetProduces(
        dataset,
        expected_output=[3 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapDataclassWithInputAndOutput(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors(MyDataclass(value1=1, value2=2))
    dataset = apply_map(dataset, lambda x: (x.value1 * 5, x.value2))
    dataset = apply_map(
        dataset, lambda x, y: MaskedTensor(mask=True, value=x + y)
    )
    dataset = apply_map(
        dataset, lambda m: NestedMaskedTensor(mask=False, value=m)
    )
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            NestedMaskedTensor(
                mask=False, value=MaskedTensor(mask=True, value=7)
            )
        ],
    )

  @combinations.generate(_test_combinations())
  def testMapListOfDataclassObjects(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)

    # Creates a list of dataclass objects.
    dataset = apply_map(
        dataset,
        lambda x: [  # pylint: disable=g-long-lambda
            MyDataclass(value1=x, value2=1),
            MyDataclass(value1=2, value2=2 * x),
        ],
    )

    # Takes a list of dataclass objects as input.
    dataset = apply_map(dataset, lambda *x: x[0].value1 + x[1].value2)

    self.assertDatasetProduces(
        dataset,
        expected_output=[3 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapDictOfDataclassValues(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)

    # Creates a dict of {str -> dataclass}.
    dataset = apply_map(
        dataset,
        lambda x: {  # pylint: disable=g-long-lambda
            "a": MyDataclass(value1=x, value2=1),
            "b": MyDataclass(value1=2, value2=2 * x),
        },
    )
    # Takes a dict of dataclass values as input.
    dataset = apply_map(dataset, lambda x: x["a"].value1 + x["b"].value2)
    self.assertDatasetProduces(
        dataset,
        expected_output=[3 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapNestedMaskedTensorWithDataclassInput(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: MaskedTensor(mask=True, value=x))
    dataset = apply_map(
        dataset,
        # Takes a MaskedTensor as input.
        lambda x: NestedMaskedTensor(mask=False, value=x),
    )
    dataset = apply_map(dataset, lambda x: 5 * x.value.value)
    self.assertDatasetProduces(
        dataset,
        expected_output=[5 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapNestedMaskedTensorWithDataclassOutput(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(
        dataset,
        lambda x: NestedMaskedTensor(  # pylint: disable=g-long-lambda
            mask=False, value=MaskedTensor(mask=True, value=x)
        ),
    )

    # Return a MaskedTensor as the return value.
    dataset = apply_map(dataset, lambda x: x.value)
    dataset = apply_map(dataset, lambda x: 7 * x.value)
    self.assertDatasetProduces(
        dataset,
        expected_output=[7 * x for x in range(10)],
    )

  @combinations.generate(_test_combinations())
  def testMapNamedtuple(self, apply_map):
    # construct dataset of tuples
    labels = dataset_ops.Dataset.range(10)
    images = apply_map(labels, lambda l: -l)
    dataset_tuple = dataset_ops.Dataset.zip((labels, images))

    # convert dataset of tuples to dataset of namedtuples
    example = collections.namedtuple("Example", ["label", "image"])
    dataset_namedtuple = apply_map(dataset_tuple, example)

    def preprocess_tuple(label, image):
      image = 2 * image
      return label, image

    def preprocess_namedtuple(example):
      return example._replace(image=2 * example.image)

    # preprocess both datasets
    dataset_tuple = apply_map(dataset_tuple, preprocess_tuple)
    dataset_namedtuple = apply_map(dataset_namedtuple, preprocess_namedtuple)

    next_tuple = self.getNext(dataset_tuple)
    next_namedtuple = self.getNext(dataset_namedtuple)

    # make sure both datasets contain the same data
    for i in range(10):
      tuple_, namedtuple_ = self.evaluate([next_tuple(), next_namedtuple()])
      self.assertEqual(tuple_, namedtuple_)
      self.assertEqual(tuple_, (i, -2 * i))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_namedtuple())

  @combinations.generate(_test_combinations())
  def testMapAttrs(self, apply_map):
    if attr is None:
      self.skipTest("attr module is not available.")

    # construct dataset of tuples
    labels = dataset_ops.Dataset.range(10)
    images = apply_map(labels, lambda l: -l)
    dataset = dataset_ops.Dataset.zip((labels, images))

    @attr.s(cmp=True)
    class Example:
      label = attr.ib()
      image = attr.ib()

    dataset = apply_map(dataset, Example)

    def preprocess(example):
      example.image = 2 * example.image
      return example

    dataset = apply_map(dataset, preprocess)
    get_next = self.getNext(dataset)

    for i in range(10):
      data = self.evaluate(get_next())
      self.assertEqual(data, Example(i, -2 * i))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testUseStepContainerInMap(self, apply_map):
    row = np.arange(6)
    dataset = dataset_ops.Dataset.from_tensors(row)
    dataset = apply_map(dataset,
                        lambda elems: map_fn.map_fn(lambda x: x * x, elems))
    self.assertDatasetProduces(dataset, expected_output=[row**2])

  @combinations.generate(_test_combinations())
  def testCaseAndCondInMap(self, apply_map):

    def control_map_fn(x, y):

      def multiply():
        return x * 2

      def divide():
        return x // 2

      def defaults_two():
        return cond.cond(
            math_ops.equal(math_ops.mod(x, 2), 0),
            multiply,
            divide,
            name="cond_mult")

      pred_fn_pairs = [
          (math_ops.logical_or(math_ops.equal(y, 2),
                               math_ops.equal(y, 3)), defaults_two),
      ]

      return control_flow_case.case(
          pred_fn_pairs, default=multiply, exclusive=True)

    def build_dataset(row, num):
      dataset = dataset_ops.Dataset.from_tensor_slices(row)
      return apply_map(dataset, lambda x: control_map_fn(x, num))

    row = np.arange(6)
    for num in [2, 3, 4]:
      get_next = self.getNext(build_dataset(row, num))
      for i in range(6):
        self.assertEqual(
            (i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
            self.evaluate(get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testCaseInWhileInMap(self, apply_map):

    def control_map_fn(x, y):

      def multiply():
        return x * 2

      def divide():
        return x // 2

      pred_fn_pairs = [
          (math_ops.logical_or(math_ops.equal(y, 2),
                               math_ops.equal(y, 3)), divide),
      ]

      return control_flow_case.case(
          pred_fn_pairs, default=multiply, exclusive=True)

    def build_dataset(row, num):
      dataset = dataset_ops.Dataset.from_tensors(row)
      return apply_map(
          dataset,
          lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems))

    row = np.arange(6)
    for num in [2, 3, 4]:
      get_next = self.getNext(build_dataset(row, num))
      self.assertAllEqual(
          [x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
          self.evaluate(get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testCaseAndCondInWhileInMap(self, apply_map):

    def control_map_fn(x, y):

      def multiply():
        return x * 2

      def divide():
        return x // 2

      def defaults_two():
        return cond.cond(
            math_ops.equal(math_ops.mod(x, 2), 0),
            multiply,
            divide,
            name="cond_mult")

      pred_fn_pairs = [
          (math_ops.logical_or(math_ops.equal(y, 2),
                               math_ops.equal(y, 3)), defaults_two),
      ]

      return control_flow_case.case(
          pred_fn_pairs, default=multiply, exclusive=True)

    row = np.arange(6)
    num = 2
    dataset = dataset_ops.Dataset.from_tensors(row)
    dataset = apply_map(
        dataset,
        lambda elems: map_fn.map_fn(lambda x: control_map_fn(x, num), elems))
    get_next = self.getNext(dataset)

    self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
                         (num == 2 or num == 3) else x * 2 for x in row],
                        self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testNestedListMapDataset(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors([0, 1, 2]).repeat(10)
    dataset = apply_map(dataset, lambda a: ([a[1], a[0] + a[2]], a[1]))
    expected_output = [(np.array([1, 2]), 1)] * 10
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(
      combinations.times(_test_combinations(),
                         combinations.combine(buffer_size=[1, 2, 3, 4])))
  def testPrefetch(self, apply_map, buffer_size):
    # We will use this event to test that `_map_py_func()` has been invoked a
    # certain number of times (6 times, to be exact) after consuming fewer
    # elements from the iterator.
    ev = threading.Event()

    set_event_during_invocation = 5

    def _map_py_func(x):
      if x == set_event_during_invocation:
        ev.set()
      return x * x

    def _map_fn(x):
      return script_ops.py_func(_map_py_func, [x], x.dtype)

    # We can indirectly observe that varying the buffer size has the intended
    # effect by observing when `ev` is set (on the 6th invocation of
    # `_map_py_func()`).
    # NOTE(mrry): We do not test with `buffer_size ==
    # set_event_during_invocation`, because we must consume at least one element
    # to start the prefetching.
    dataset = dataset_ops.Dataset.range(100)
    dataset = apply_map(dataset, _map_fn).prefetch(buffer_size)
    get_next = self.getNext(dataset)

    event_will_be_set_after_consuming = (
        set_event_during_invocation - buffer_size + 1)

    ev.clear()
    for i in range(event_will_be_set_after_consuming):
      self.assertFalse(ev.is_set())
      self.assertEqual(i * i, self.evaluate(get_next()))
    ev.wait()
    for i in range(event_will_be_set_after_consuming, 100):
      self.assertEqual(i * i, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testReturnList(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: [x, constant_op.constant(37.0)])
    self.assertDatasetProduces(
        dataset, expected_output=[(i, 37.0) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testMultiOutputPyFunc(self, apply_map):
    # The `tf.py_func()` op returns a list of tensors for its outputs.
    def _map_fn(x_tensor):
      def _map_py_func(x):
        return x, np.array(37.0, dtype=np.float64)
      return script_ops.py_func(
          _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _map_fn)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, 37.0) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testSparse(self, apply_map):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _sparse)
    self.assertDatasetProduces(
        dataset, expected_output=[_sparse(i) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testSparseChain(self, apply_map):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    def _check(i):
      self.assertTrue(sparse_tensor.is_sparse(i))
      return sparse_ops.sparse_concat(0, [i, i])

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _sparse)
    dataset = apply_map(dataset, _check)

    self.assertDatasetProduces(
        dataset,
        expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])

  @combinations.generate(_test_combinations_with_mode("eager"))
  def testSparseMapShapeInference(self, apply_map):
    row_lengths = np.random.randint(0, 4, size=128)
    values = np.ones(np.sum(row_lengths))
    sparse = ragged_tensor.RaggedTensor.from_row_lengths(
        values, row_lengths).to_sparse()
    dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
    dataset = dataset.batch(32, drop_remainder=True)
    dataset = apply_map(dataset, lambda x: x)
    self.assertEqual((32, 3), dataset.element_spec.shape)

  @combinations.generate(_test_combinations_with_mode("eager"))
  def testSparseMapShapeInferencePartial(self, apply_map):
    row_lengths = np.random.randint(0, 4, size=128)
    values = np.ones(np.sum(row_lengths))
    sparse = ragged_tensor.RaggedTensor.from_row_lengths(
        values, row_lengths).to_sparse()
    dataset = dataset_ops.Dataset.from_tensor_slices(sparse)
    dataset = dataset.batch(32, drop_remainder=False)
    dataset = apply_map(dataset, lambda x: x)
    self.assertEqual([None, 3], dataset.element_spec.shape.as_list())

  @combinations.generate(_test_combinations())
  def testTensorArray(self, apply_map):

    def _tensor_array(i):
      i = math_ops.cast(i, dtypes.int32)
      return (
          tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
          .unstack(math_ops.range(i, dtype=dtypes.int32)))

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _tensor_array)
    self.assertDatasetProduces(
        dataset, expected_output=[list(range(i)) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testTensorArrayChain(self, apply_map):

    def _tensor_array(i):
      i = math_ops.cast(i, dtypes.int32)
      return (
          tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
          .unstack(math_ops.range(i, dtype=dtypes.int32)))

    def _check(x):
      self.assertIsInstance(x, tensor_array_ops.TensorArray)
      return x.identity()

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _tensor_array)
    dataset = apply_map(dataset, _check)

    self.assertDatasetProduces(
        dataset,
        expected_output=[list(range(i)) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testRagged(self, apply_map):

    def _ragged(i):
      return ragged_tensor.RaggedTensor.from_tensor(i * [[1]])

    dataset = dataset_ops.Dataset.range(5)
    dataset = apply_map(dataset, _ragged)
    self.assertDatasetProduces(
        dataset,
        expected_output=[ragged_factory_ops.constant([[i]]) for i in range(5)])

  @combinations.generate(_test_combinations())
  def testRaggedChain(self, apply_map):

    def _ragged(i):
      return ragged_tensor.RaggedTensor.from_tensor(i * [[1]])

    def _concat(i):
      self.assertTrue(ragged_tensor.is_ragged(i))
      return ragged_concat_ops.concat([i, i], 0)

    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, _ragged)
    dataset = apply_map(dataset, _concat)

    self.assertDatasetProduces(
        dataset,
        expected_output=[
            self.evaluate(_concat(ragged_factory_ops.constant([[i]])))
            for i in range(10)
        ])

  @combinations.generate(_test_combinations_with_mode("graph"))
  def testParallelMapOutOfRangeError(self, apply_map):

    def raising_py_func(i):
      if i == 100:
        raise StopIteration()
      else:
        return i

    dataset = dataset_ops.Dataset.range(105)
    dataset = apply_map(
        dataset,
        lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64),
        num_parallel_calls=2)
    get_next = self.getNext(dataset)
    for i in range(100):
      self.assertEqual(i, self.evaluate(get_next()))
    # When the map function in `MapDataset` raises an OutOfRange error, TF1
    # and TF2 behave differently: TF1 raises an OutOfRangeError to signal the
    # end of the sequence, while TF2 raises an InvalidArgumentError. This
    # behavior is controlled by the `preserve_cardinality` argument of the
    # `map` transformation, which is `True` in TF2 and `False` in TF1 for
    # backward compatibility.
    if tf2.enabled():
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(get_next())
    else:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testConstantOutput(self, apply_map):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_map(dataset, lambda x: [x, "hello", 10])
    self.assertDatasetProduces(dataset, [(i, b"hello", 10) for i in range(10)])

  @combinations.generate(test_base.graph_only_combinations())
  def testWarnOnSeedFromOuterGraph(self):
    with ops.Graph().as_default() as g:
      g.seed = 10
      warnings.simplefilter("always")

      def _check_warning(caught_warnings, expected_result):
        found_warning = False
        for warning in caught_warnings:
          if ("Explicitly set the seed in the function if this is not the "
              "intended behavior" in str(warning)):
            found_warning = True
            break
        self.assertEqual(found_warning, expected_result)

      # The map function doesn't use the seed, so no warning is generated.
      with warnings.catch_warnings(record=True) as w:
        _ = dataset_ops.Dataset.range(10).map(math_ops.square)
      _check_warning(w, False)

      def random_func(x):
        x = math_ops.add(x, 1)
        random_ops.random_shuffle([x, math_ops.square(x)])
        return x

      with warnings.catch_warnings(record=True) as w:
        _ = dataset_ops.Dataset.range(10).map(random_func)
      _check_warning(w, True)

      def random_func_seeded(x):
        ops.get_default_graph().seed = None
        random_ops.random_shuffle(x)
        return x

      with warnings.catch_warnings(record=True) as w:
        _ = dataset_ops.Dataset.range(10).batch(2).map(random_func_seeded)
      _check_warning(w, False)

      with warnings.catch_warnings(record=True) as w:
        _ = dataset_ops.Dataset.range(10).batch(2).map(
            lambda x: random_ops.random_shuffle(x, seed=37))
      _check_warning(w, False)

  @combinations.generate(_test_combinations())
  def testNestedDatasetMap(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
    dataset = apply_map(dataset, dataset_ops.Dataset.from_tensor_slices)
    dataset = apply_map(dataset, lambda ds: ds.batch(3)).flat_map(lambda x: x)

    self.assertDatasetProduces(dataset, expected_output=[[1.0, 2.0, 3.0]])

  @combinations.generate(_test_combinations())
  def testReturnValueError(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
    with self.assertRaisesRegex(
        TypeError, r"Unsupported return value from function passed to "
        r"Dataset.map\(\)"):
      _ = apply_map(dataset, lambda x: Foo)

  @combinations.generate(test_base.default_test_combinations())
  def testBrokenFunctionErrorOnInitialization(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])

    def broken_function(_):
      """A function deliberately designed to fail on instantiation."""
      value = []
      tensor_value = attr_value_pb2.AttrValue()
      tensor_value.tensor.CopyFrom(
          tensor_util.make_tensor_proto(
              value, dtype=dtypes.float32, shape=[0], verify_shape=False))
      dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)

      # Create a "Const" op with a `tf.float32` value and a `tf.int32` type.
      const_tensor = ops.get_default_graph().create_op(
          "Const", [], [dtypes.int32],
          attrs={
              "value": tensor_value,
              "dtype": dtype_value
          },
          name="BrokenConst").outputs[0]
      return const_tensor

    dataset = dataset.map(broken_function)
    self.assertDatasetProduces(
        dataset, expected_error=(errors.InvalidArgumentError, "Type mismatch"))

  @combinations.generate(
      combinations.times(
          _test_combinations_with_mode("graph"),
          combinations.combine(num_parallel_calls=[None, 12])))
  def testNoInterOpParallelism(self, apply_map, num_parallel_calls):
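    # With `use_inter_op_parallelism` disabled below, all ops in the map
    # function should run on the same thread, so every `py_func` call is
    # expected to observe the same thread id.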
    dataset = dataset_ops.Dataset.from_tensors(0)

    def _get_tid():
      return np.int64(threading.current_thread().ident)

    def _map_fn(_):
      tids = []
      for _ in range(10):
        tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
      return tids

    dataset = apply_map(dataset, _map_fn)
    dataset._variant_tensor.op._set_attr("use_inter_op_parallelism",
                                         attr_value_pb2.AttrValue(b=False))
    get_next = self.getNext(dataset)

    tids = self.evaluate(get_next())
    self.assertTrue(all(tids[0] == tid for tid in tids))

  @combinations.generate(
      combinations.times(_test_combinations(), _short_circuit_test_cases(),
                         combinations.combine(num_parallel_calls=[None, 12])))
  def testShortCircuit(self, apply_map, structure, fn, num_parallel_calls):
    dataset = self.structuredDataset(structure).repeat()
    dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls)
    get_next = self.getNext(dataset)

    if isinstance(structure, tuple):
      expected = fn(*self.evaluate(self.structuredElement(structure)))
    else:
      expected = fn(self.evaluate(self.structuredElement(structure)))
    self.assertEqual(expected, self.evaluate(get_next()))

  @combinations.generate(
      combinations.times(_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 12])))
  def testShortCircuitCapturedInput(self, apply_map, num_parallel_calls):
    captured_t = variables.Variable(42)
    dataset = self.structuredDataset(None).repeat()
    dataset = apply_map(
        dataset, lambda x: captured_t, num_parallel_calls=num_parallel_calls)
    self.evaluate(variables.global_variables_initializer())
    get_next = self.getNext(dataset, requires_initialization=True)

    self.assertEqual(42, self.evaluate(get_next()))

  @combinations.generate(
      combinations.combine(
          tf_api_version=2,
          mode=["eager", "graph"],
          num_parallel_calls=[None, 12]))
  def testPreserveCardinality(self, num_parallel_calls):
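    # In TF2, `map` sets `preserve_cardinality=True`, so a `StopIteration`
    # raised by the user-defined function surfaces as an
    # `InvalidArgumentError` instead of signaling end-of-sequence.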

    def py_fn(_):
      raise StopIteration()

    dataset = dataset_ops.Dataset.from_tensors(0).map(
        lambda x: script_ops.py_func(py_fn, [x], dtypes.int64),
        num_parallel_calls=num_parallel_calls)
    get_next = self.getNext(dataset)
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations_with_mode("graph"))
  def testCollectionCopy(self, apply_map):
    w = variable_scope.get_variable("w", [])
    self.assertIn(w, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))

    def func(x):
      self.assertIn(w, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
      return x

    dataset = dataset_ops.Dataset.from_tensors(constant_op.constant(1.0))
    _ = apply_map(dataset, func)

  @combinations.generate(
      combinations.times(
          _test_combinations_with_mode_v1("graph"),
          combinations.combine(num_parallel_calls=[None, 12])))
  def testMapCancellation(self, apply_map, num_parallel_calls):
    # Checks that cancellation is threaded through to the map transformation.
    queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())

    def fn(_):
      return queue.dequeue()

    dataset = dataset_ops.Dataset.range(1)
    dataset = apply_map(dataset, fn, num_parallel_calls=num_parallel_calls)
    get_next = self.getNext(dataset, requires_initialization=True)

    with self.cached_session() as sess:
      thread = self.checkedThread(self.assert_op_cancelled, args=(get_next(),))
      thread.start()
      time.sleep(0.2)
      sess.close()
      thread.join()

  # TODO(b/126553094): `map` doesn't work with a variable defined inside the
  # function in eager mode; Graph tensors possibly leak out of the
  # function-building context, as variables are created in an init_scope.
  @combinations.generate(test_base.graph_only_combinations())
  def testCreateVariableInsideFunctionWithGetter(self):

    def func(_):
      with variable_scope.variable_scope(
          "variable", reuse=variable_scope.AUTO_REUSE):
        counter_var = variable_scope.get_variable(
            "counter", (), dtypes.int32, use_resource=True)
      return counter_var.assign_add(1)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)

    if hasattr(dataset, "map_with_legacy_function"):
      # NOTE: In the legacy function, the resource is captured by value.
      with self.assertRaisesWithPredicateMatch(
          AttributeError, ".*Tensor.* object has no attribute 'assign_add'"
      ):
        dataset.map_with_legacy_function(func)

    dataset = dataset.map(func)
    self.evaluate(variables.global_variables_initializer())

    get_next = self.getNext(dataset, requires_initialization=True)

    for i in range(10):
      self.assertEqual(i + 1, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(_test_combinations())
  def testCaptureVariable(self, apply_map):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
    dataset = apply_map(dataset, lambda _: counter_var.assign_add(1))
    options = options_lib.Options()
    options.experimental_optimization.inject_prefetch = False
    dataset = dataset.with_options(options)
    get_next = self.getNext(dataset, requires_initialization=True)

    self.evaluate(counter_var.initializer)

    for i in range(10):
      self.assertEqual(i, self.evaluate(counter_var))
      self.assertEqual(i + 1, self.evaluate(get_next()))
    self.assertEqual(10, self.evaluate(counter_var))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(10, self.evaluate(counter_var))

  @combinations.generate(_test_combinations_with_mode_v1("graph"))
  def testCaptureUninitializedVariableError(self, apply_map):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
    dataset = apply_map(dataset, lambda _: counter_var.assign_add(1))

    get_next = self.getNext(dataset, requires_initialization=True)
    with self.assertRaises(errors.NotFoundError):
      self.evaluate(get_next())

  # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
  @combinations.generate(_test_combinations_with_mode_v1("graph"))
  def testCaptureConstantsWithConflictingDevices(self, apply_map):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.cached_session(config=config):
      with ops.device("/device:CPU:0"):
        a = constant_op.constant(3.0)
      with ops.device("/device:CPU:1"):
        b = constant_op.constant(5.0)

      def func(_):
        return math_ops.add(a, b)

      dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
      dataset = apply_map(dataset, func)
      expected_output = [8.0] * 10
      self.assertDatasetProduces(dataset, expected_output=expected_output)

  # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
  @combinations.generate(_test_combinations_with_mode_v1("graph"))
  def testReferenceVariablesWithMultipleDevices(self, apply_map):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.cached_session(config=config):

      with ops.device("/device:CPU:0"):
        a = variable_v1.VariableV1(3.0)
      with ops.device("/device:CPU:1"):
        b = variable_v1.VariableV1(5.0)

      def func(_):
        nonlocal a, b
        return math_ops.add(a, b)

      # NOTE: Use the legacy function implementation, as the eager function
      # will convert RefVariables to ResourceVariables.
      dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
      dataset = apply_map(dataset, func)
      self.evaluate(variables.global_variables_initializer())
      expected_output = [8.0] * 10
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  # TODO(b/121264236): add eager mode coverage when we have multi-device setup.
  @combinations.generate(_test_combinations_with_mode_v1("graph"))
  def testResourceVariablesWithMultipleDevices(self, apply_map):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})

    def func(_):
      with variable_scope.variable_scope(
          "variable", reuse=variable_scope.AUTO_REUSE):
        with ops.device("/device:CPU:0"):
          a_var = variable_scope.get_variable(
              "a", (), dtypes.int32, use_resource=True)
          a_var = math_ops.add(a_var, 1)
        with ops.device("/device:CPU:1"):
          b_var = variable_scope.get_variable(
              "b", (), dtypes.int32, use_resource=True)
      return math_ops.add(a_var, b_var)

    g = ops.Graph()
    with self.session(config=config, graph=g):
      dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)
      dataset = apply_map(dataset, func)
      self.evaluate(variables.global_variables_initializer())
      expected_output = [1] * 10
      self.assertDatasetProduces(
          dataset,
          expected_output=expected_output,
          requires_initialization=True)

  @combinations.generate(
      combinations.times(
          _test_combinations(),
          combinations.combine(
              local_determinism=[None, True, False],
              global_determinism=[True, False])))
  def testDeterminismConfiguration(self, apply_map, local_determinism,
                                   global_determinism):
    expect_determinism = local_determinism or (local_determinism is None and
                                               global_determinism)
    elements = list(range(1000))

    def dataset_fn(delay_ms):

      def sleep(x):
        time.sleep(delay_ms / 1000)
        return x

      def map_function(x):
        if math_ops.equal(x, 0):
          return script_ops.py_func(sleep, [x], x.dtype)
        else:
          return x

      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
      dataset = apply_map(
          dataset,
          map_function,
          num_parallel_calls=2,
          deterministic=local_determinism)
      opts = options_lib.Options()
      opts.deterministic = global_determinism
      dataset = dataset.with_options(opts)
      return dataset

    self.checkDeterminism(
        dataset_fn, expect_determinism, expected_elements=elements)

  @combinations.generate(_test_combinations())
  def testNoneComponent(self, apply_map):
    dataset = dataset_ops.Dataset.from_tensors((42, None))

    def map_function(x, y):
      if y is None:
        return x / 2
      return x

    dataset = apply_map(dataset, map_function)
    self.assertDatasetProduces(dataset, expected_output=[21])

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointLargeBuffer(self):
    if (pywrap_sanitizers.is_asan_enabled() or
        pywrap_sanitizers.is_tsan_enabled() or
        pywrap_sanitizers.is_msan_enabled()):
      self.skipTest("Skip to avoid OOM when using sanitizers.")
    dataset = dataset_ops.Dataset.range(10).batch(2)
    dataset = dataset.map(
        # Create tensors of size 512M.
        lambda seed: stateless_random_ops.stateless_random_uniform(
            (128, 1024, 1024), seed, dtype=dtypes.float32
        )
    )
    # Set parallelism to 5 to exceed the 2GB protobuf limit.
    dataset = dataset.map(lambda x: x * 2, num_parallel_calls=5)
    iterator = iter(dataset)
    next(iterator)  # Request an element to fill the parallel map buffer
    time.sleep(1)  # Give buffers some time to fill
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
    del dataset
    del iterator
    manager.restore_or_initialize()

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 1])))
  def testName(self, num_parallel_calls):
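    """Tests that `map` accepts a `name` argument."""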
    dataset = dataset_ops.Dataset.from_tensors(21).map(
        lambda x: x * 2, num_parallel_calls=num_parallel_calls, name="map")
    self.assertDatasetProduces(dataset, [42])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 1])))
  def testStatusMessage(self, num_parallel_calls):
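    """Tests that errors in the map function report the failing iterator."""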
    dataset = dataset_ops.Dataset.from_tensors(21).map(
        lambda x: x // 0, num_parallel_calls=num_parallel_calls, name="map")
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    get_next = self.getNext(dataset)
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r".*Error in user-defined function passed to .* transformation with "
        r"iterator: Iterator::Root::.*"):
      self.evaluate(get_next())


class MapCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                        parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              num_parallel_calls=[None, 2], symbolic_checkpoint=[False, True])))
  def testCore(self, verify_fn, num_parallel_calls, symbolic_checkpoint):
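    """Tests checkpointing a map over multi-component elements."""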

    tensor_slice_len = 7
    num_epochs = 2
    multiplier = 37.0

    def _build_ds():

      components = (
          np.arange(tensor_slice_len),
          np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
          np.array(multiplier) * np.arange(tensor_slice_len),
      )

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn, num_parallel_calls=num_parallel_calls).repeat(num_epochs)
      options = options_lib.Options()
      options.experimental_symbolic_checkpoint = symbolic_checkpoint
      return dataset.with_options(options)

    verify_fn(self, _build_ds, tensor_slice_len * num_epochs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testSaveStatefulFunction(self, num_parallel_calls):
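    """Tests that saving an iterator over a stateful map function fails."""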

    def _build_ds():

      def _map_fn(x):
        return random_ops.random_uniform(
            (), 0, 10, dtype=dtypes.int32) * math_ops.cast(x, dtypes.int32)

      return dataset_ops.Dataset.range(100).map(
          _map_fn, num_parallel_calls=num_parallel_calls)

    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testCaptureVariableInMapFn(self, num_parallel_calls):
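    """Tests that saving fails when the map function captures a variable."""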

    def _build_ds():
      counter_var = variable_scope.get_variable(
          "counter", (), dtypes.int32, use_resource=True)
      return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
          lambda _: counter_var.assign_add(1),
          num_parallel_calls=num_parallel_calls))

    self.verify_error_on_save(_build_ds, 15, errors.FailedPreconditionError)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testCaptureConstantInMapFn(self, verify_fn, num_parallel_calls):
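    """Tests checkpointing a map function that captures a constant."""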
    num_outputs = 10

    def _build_ds():
      constant_var = constant_op.constant(5)
      return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
          lambda x: x + constant_var, num_parallel_calls=num_parallel_calls))

    verify_fn(self, _build_ds, num_outputs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testCaptureDefunInMapFn(self, verify_fn, num_parallel_calls):
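    """Tests checkpointing a map whose function is a `Defun`."""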
    num_outputs = 10

    def _build_ds():

      @function.Defun(dtypes.int64)
      def defun_fn(x):
        return constant_op.constant(1000) + math_ops.cast(x, dtypes.int32)

      return dataset_ops.Dataset.range(num_outputs).map(
          defun_fn, num_parallel_calls=num_parallel_calls)

    verify_fn(self, _build_ds, num_outputs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testBuildDefunInMapFn(self, verify_fn, num_parallel_calls):
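    """Tests checkpointing a map whose `Defun` calls a nested `Defun`."""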
    num_outputs = 10

    def _build_ds():

      @function.Defun(dtypes.int64)
      def defun_fn(x):

        @function.Defun(dtypes.int32)
        def defun_fn_deep(x):
          return constant_op.constant(1000) + math_ops.cast(x, dtypes.int32)

        return constant_op.constant(11000) + defun_fn_deep(
            math_ops.cast(x, dtypes.int32))

      return dataset_ops.Dataset.range(num_outputs).map(
          defun_fn, num_parallel_calls=num_parallel_calls)

    verify_fn(self, _build_ds, num_outputs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations(),
                         combinations.combine(num_parallel_calls=[None, 2])))
  def testSparse(self, verify_fn, num_parallel_calls):
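    """Tests checkpointing a map function that produces sparse tensors."""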

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    def _build_ds(num_outputs):
      return dataset_ops.Dataset.range(num_outputs).map(
          _sparse, num_parallel_calls=num_parallel_calls)

    num_outputs = 10
    verify_fn(self, lambda: _build_ds(num_outputs), num_outputs=num_outputs)


class MapRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.v2_only_combinations(),
                         combinations.combine(index=[-1, 4, 5])))
  def testInvalidIndex(self, index):
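    """Tests that out-of-range indices raise `OutOfRangeError`."""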
    dataset = dataset_ops.Dataset.from_tensor_slices([-1, 0, 1, 2]).map(
        lambda x: x * 2)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(random_access.at(dataset, index=index))

  @combinations.generate(
      combinations.times(test_base.v2_only_combinations(),
                         combinations.combine(index=[-1, 0])))
  def testEmptyDataset(self, index):
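    """Tests that random access into an empty dataset raises an error."""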
    dataset = dataset_ops.Dataset.from_tensor_slices([]).map(lambda x: x // 2)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(random_access.at(dataset, index=index))

  @combinations.generate(test_base.v2_only_combinations())
  def testBasic(self):
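    """Tests random access into a mapped dataset."""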
    dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5]).map(
        lambda x: x * 3)
    for i in range(6):  # Cover all six elements, not just the first five.
      self.assertEqual(self.evaluate(random_access.at(dataset, index=i)), i * 3)

  @combinations.generate(
      combinations.times(
          test_base.v2_only_combinations(),
          combinations.combine(
              elements=[0, 10, 20, 40], num_parallel_calls=[None, 2])))
  def testMultipleCombinations(self, elements, num_parallel_calls):
    dataset = dataset_ops.Dataset.range(elements).map(
        lambda x: x // 2, num_parallel_calls=num_parallel_calls)
    for i in range(elements):
      self.assertEqual(
          self.evaluate(random_access.at(dataset, index=i)), i // 2)

  @combinations.generate(
      combinations.times(
          test_base.v2_only_combinations(),
          combinations.combine(
              elements=[0, 10, 20, 40], num_parallel_calls=[None, 2])))
  def testMapFnInFunction(self, elements, num_parallel_calls):
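    """Tests random access when the map function is a `tf.function`."""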

    @def_function.function
    def _map_fn(x):
      return math_ops.square(x)

    dataset = dataset_ops.Dataset.range(elements).map(
        _map_fn, num_parallel_calls=num_parallel_calls)
    for i in range(elements):
      self.assertEqual(
          self.evaluate(random_access.at(dataset, index=i)),
          self.evaluate(math_ops.square(i)))


class MapGlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.v2_only_combinations(),
          combinations.combine(
              dataset_range=[100],
              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE],
              deterministic=[True, False])))
  def testMapV2(
      self, dataset_range: int, num_parallel_calls: int, deterministic: bool):
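    """Tests global shuffling of a mapped dataset.

    The V2 `map` preserves cardinality by default, so the globally shuffled
    dataset needs no explicit `assert_cardinality`.
    """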
    dataset = dataset_ops.Dataset.range(dataset_range)
    dataset = dataset.map(
        lambda x: x * 2,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic)
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    dataset = global_shuffle_op._global_shuffle(dataset)
    # Disables optimizations (e.g. `map_parallelization`) to make sure we test
    # both `Map` and `ParallelMap`.
    # TODO(b/325112575): Support warm-start. With warm-start, prefetching uses
    # the unintended IteratorContext here:
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/data/prefetch_dataset_op.cc#L197-L199.
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_warm_start = False
    dataset = dataset.with_options(options)

    expected = list(range(0, dataset_range * 2, 2))
    dataset_output = self.getDatasetOutput(
        dataset, requires_initialization=True)
    self.assertCountEqual(dataset_output, expected)
    self.assertNotEqual(dataset_output, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              dataset_range=[100],
              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE],
              deterministic=[True, False])))
  def testMapV1AndV2(
      self, dataset_range: int, num_parallel_calls: int, deterministic: bool):
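    """Tests global shuffling after `map` in both V1 and V2.

    Cardinality is restored with `assert_cardinality` so the test also works
    with the V1 API.
    """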
    dataset = dataset_ops.Dataset.range(dataset_range)
    dataset_cardinality = dataset.cardinality()
    dataset = dataset.map(
        lambda x: x * 2,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic)
    dataset = dataset.apply(cardinality.assert_cardinality(dataset_cardinality))
    dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
    dataset = global_shuffle_op._global_shuffle(dataset)

    expected = list(range(0, dataset_range * 2, 2))
    dataset_output = self.getDatasetOutput(
        dataset, requires_initialization=True)
    self.assertCountEqual(dataset_output, expected)
    self.assertNotEqual(dataset_output, expected)


class MapGlobalShuffleCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.v2_only_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              dataset_range=[10],
              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE],
              reshuffle_each_iteration=[True, False],
              symbolic_checkpoint=[True, False])))
  def testMapV2(
      self,
      verify_fn: Callable[..., None],
      dataset_range: int,
      num_parallel_calls: int,
      reshuffle_each_iteration: bool,
      symbolic_checkpoint: bool):
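    """Tests checkpointing global shuffle after `map`.

    The V2 `map` preserves cardinality by default, which global shuffling
    relies on here.
    """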

    def _build_dataset() -> dataset_ops.Dataset:
      dataset = dataset_ops.Dataset.range(dataset_range)
      dataset = dataset.map(
          lambda x: x * 2,
          num_parallel_calls=num_parallel_calls,
          deterministic=True)
      dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
      dataset = global_shuffle_op._global_shuffle(
          dataset, seed=42, reshuffle_each_iteration=reshuffle_each_iteration)
      options = options_lib.Options()
      options.experimental_optimization.apply_default_optimizations = False
      options.experimental_warm_start = False
      options.experimental_symbolic_checkpoint = symbolic_checkpoint
      return dataset.with_options(options)

    verify_fn(
        self,
        _build_dataset,
        num_outputs=dataset_range,
        assert_items_equal=reshuffle_each_iteration)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(
              dataset_range=[10],
              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE],
              reshuffle_each_iteration=[True, False],
              symbolic_checkpoint=[True, False])))
  def testMapV1AndV2(
      self,
      verify_fn: Callable[..., None],
      dataset_range: int,
      num_parallel_calls: int,
      reshuffle_each_iteration: bool,
      symbolic_checkpoint: bool):
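    """Tests checkpointing global shuffle after `map` in both V1 and V2.

    Cardinality is restored with `assert_cardinality` so the test also works
    with the V1 API.
    """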

    def _build_dataset() -> dataset_ops.Dataset:
      dataset = dataset_ops.Dataset.range(dataset_range)
      dataset_cardinality = dataset.cardinality()
      dataset = dataset.map(
          lambda x: x * 2,
          num_parallel_calls=num_parallel_calls,
          deterministic=True)
      dataset = dataset.apply(
          cardinality.assert_cardinality(dataset_cardinality))
      dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
      dataset = global_shuffle_op._global_shuffle(
          dataset, seed=42, reshuffle_each_iteration=reshuffle_each_iteration)
      options = options_lib.Options()
      options.experimental_optimization.apply_default_optimizations = False
      options.experimental_symbolic_checkpoint = symbolic_checkpoint
      return dataset.with_options(options)

    verify_fn(
        self,
        _build_dataset,
        num_outputs=dataset_range,
        assert_items_equal=reshuffle_each_iteration)


if __name__ == "__main__":
  test.main()