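# Source: tensorflow/tensorflow (View on GitHub)
# File: tensorflow/python/distribute/parallel_device/parallel_device_test.py
#
# Code-quality summary (from the hosting page):
#   Maintainability: F (estimated 4 days of remediation)
#   Test Coverage: (not reported)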
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import threading

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as tracking
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_switch_case
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateful_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.util import nest

# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since the op is
# replicated with group/instance key intact, the replicated nodes will
# communicate.
# TODO(allenl): Switch to using a collective manager.
#
# `_COUNTER` holds the next key handed out by `_collective_reduce`, which uses
# the same value as both the collective group key and the instance key.
# `_COUNTER_LOCK` guards the read-and-increment of that shared counter.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 100


def _collective_reduce(inputs, operation, num_replicas):
  """Applies a collective all-reduce to every tensor in a nested structure.

  Each leaf tensor receives a fresh key taken from the module-level
  `_COUNTER`. That key is used as both the collective group key and the
  instance key, so separate reductions issued through this helper never share
  keys.

  Args:
    inputs: A nested structure of tensors (anything `nest.map_structure`
      accepts).
    operation: Merge op name forwarded to `all_reduce_v2` as `merge_op`
      (e.g. "Add").
    num_replicas: Number of participants, forwarded as `group_size`.

  Returns:
    A structure matching `inputs` holding the reduced tensors.
  """

  def _reserve_key():
    # The counter is shared module state; read and bump it atomically.
    global _COUNTER
    with _COUNTER_LOCK:
      reserved = _COUNTER
      _COUNTER += 1
    return reserved

  def _all_reduce_leaf(leaf):
    key = _reserve_key()
    return collective_ops.all_reduce_v2(
        t=leaf,
        group_size=num_replicas,
        merge_op=operation,
        group_key=key,
        instance_key=key)

  return nest.map_structure(_all_reduce_leaf, inputs)


def _collective_sum(inputs, num_replicas):
  """Sums every tensor in `inputs` across `num_replicas` collective members.

  Thin wrapper over `_collective_reduce` with the "Add" merge op.
  """
  summed = _collective_reduce(
      inputs=inputs, num_replicas=num_replicas, operation="Add")
  return summed


class _Dense(module.Module):
  """Minimal fully-connected layer used by the tests in this file.

  The kernel and bias are created lazily on the first call, once the input
  feature size is known. Both are initialized to ones, so for an input row `x`
  every output unit equals `sum(x) + 1`.
  """

  def __init__(self, output_size):
    # tf.Module.__init__ sets up the module's name and name scope; skipping it
    # leaves attributes such as `name` / `name_scope` undefined.
    super(_Dense, self).__init__()
    self.output_size = output_size
    self.kernel = None
    self.bias = None

  def __call__(self, x):
    """Returns `x @ kernel^T + bias`, creating the variables on first use.

    Args:
      x: Input tensor whose last dimension is the input feature size; the
        kernel is shaped `[output_size, shape(x)[-1]]`. `matmul` presumably
        requires `x` to be at least rank 2 (the tests pass `[[value]]`).

    Returns:
      The layer output with last dimension `output_size`.
    """
    if self.kernel is None:
      self.kernel = variables.Variable(
          array_ops.ones(
              array_ops_stack.stack([self.output_size,
                                     array_ops.shape(x)[-1]])))
      self.bias = variables.Variable(array_ops.ones([self.output_size]))
    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias


class _VirtualDeviceTestCase(test.TestCase):
  """Base test case that exposes `self.device`, a two-component ParallelDevice.

  `setUp` picks the first available device type in the order TPU, GPU, CPU.
  For GPU and CPU it splits the first physical device into two logical
  devices, so both parallel components exist even on a single-device host.
  """

  def setUp(self):
    super().setUp()
    ctx = context.context()
    if ctx.list_physical_devices("TPU"):
      self.device_type = "TPU"
      tpu_cluster_resolver.initialize_tpu_system()
    else:
      if ctx.list_physical_devices("GPU"):
        self.device_type = "GPU"
        # Keep each virtual GPU small so two fit on one physical device.
        logical_configs = [
            context.LogicalDeviceConfiguration(memory_limit=100)
            for _ in range(2)
        ]
      else:
        self.device_type = "CPU"
        logical_configs = [
            context.LogicalDeviceConfiguration() for _ in range(2)
        ]
      first_physical = ctx.list_physical_devices(self.device_type)[0]
      ctx.set_logical_device_configuration(first_physical, logical_configs)

    # One fully-qualified and one short device name, so both spellings are
    # exercised.
    self.device = parallel_device.ParallelDevice(components=[
        f"/job:localhost/device:{self.device_type}:0",
        f"{self.device_type}:1",
    ])
    for index, component in enumerate(self.device.components):
      self.assertIn(f"{self.device_type}:{index}", component)


class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
  """Tests for executing ops, functions and collectives on a ParallelDevice."""

  def test_register_parallel_device(self):
    """Ops in the device scope run per component and stay on that component."""
    with self.device:
      c = constant_op.constant(1.)
      d = constant_op.constant(2.)
      e = c + d
      outputs = self.device.unpack(e)
    self.assertAllClose([3., 3.], outputs)

    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_no_implicit_copyon(self):
    """Non-parallel tensors used in the device scope raise instead of copying."""
    a1 = constant_op.constant(1.)
    a2 = constant_op.constant(2.)

    with self.device:
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "First pack non-parallel tensors for each device"):
        a1 + a2  # pylint:disable=pointless-statement

  def test_error_message_length(self):
    """The error for a non-parallel input summarizes the tensor briefly.

    The regex allows at most 150 characters between `TensorHandle(` and the
    `[...], shape=` suffix.
    """
    x = array_ops.ones([3, 3, 3, 3, 3, 3])

    with self.device:
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"TensorHandle\((.|\n){1,150}\[...\], shape="):
        array_ops.identity(x)

  def test_one_replica_eager_control_flow(self):
    """A one-component parallel tensor can drive a Python `if` eagerly."""
    device = parallel_device.ParallelDevice(components=[
        "/job:localhost/device:{}:0".format(self.device_type),
    ])
    x = constant_op.constant([2, 3, 4])
    with device:
      x = device.pack([x])
      if math_ops.reduce_any(math_ops.equal(x, constant_op.constant(4))):
        y = constant_op.constant(1)
      else:
        y = constant_op.constant(2)
    self.assertAllEqual([1], device.unpack(y))

  @parameterized.named_parameters(
      ("variable", variables.Variable),
      ("tensor", lambda x: x))
  def test_string_representation(self, transform):
    """`str`/`repr` of a parallel value show component values and devices."""
    x = self.device.pack(
        [constant_op.constant([5., 6.]),
         constant_op.constant([6., 7.])])
    with self.device:
      x = transform(x)
    parallel_str = str(x)
    self.assertIn("5", parallel_str)
    self.assertIn("7", parallel_str)
    self.assertIn(self.device_type + ":0", parallel_str)
    self.assertIn(self.device_type + ":1", parallel_str)
    parallel_repr = repr(x)
    self.assertIn("5", parallel_repr)
    self.assertIn("7", parallel_repr)
    self.assertIn(self.device_type + ":0", parallel_repr)
    self.assertIn(self.device_type + ":1", parallel_repr)

  def test_device_id(self):
    """`device_ids` unpacks to `[0, 1]`, each on its own component."""
    device_ids = self.device.unpack(self.device.device_ids)
    self.assertAllClose([0, 1], device_ids)
    # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
    # Currently backing_device is CPU.
    self.assertIn(self.device.components[0], device_ids[0].device)
    self.assertIn(self.device.components[1], device_ids[1].device)

  def test_zeros(self):
    """`zeros` with a shape computed on the parallel device works per part."""
    with self.device:
      x = array_ops.zeros([array_ops.identity(constant_op.constant(10))])
    for component in self.device.unpack(x):
      self.assertAllClose([0.] * 10, component)

  def test_generator(self):
    """Same-seed generators agree across components; per-device seeds differ."""
    with self.device:
      g_same = stateful_random_ops.Generator.from_seed(0)
      # Seeding with device_ids gives each component a different seed.
      g_different = stateful_random_ops.Generator.from_seed(
          self.device.device_ids)
      same = g_same.normal([10])
      different = g_different.normal([10])
    same_unpacked = self.device.unpack(same)
    different_unpacked = self.device.unpack(different)
    # Compare every component after the first against component 0.
    for same_component, different_component in zip(same_unpacked[1:],
                                                   different_unpacked[1:]):
      self.assertAllClose(same_component, same_unpacked[0])
      self.assertNotAllClose(different_component, different_unpacked[0])

  def test_collective_reduce(self):
    """An all-reduce sum gives every component the total (-1.5 + 3.5)."""
    x = self.device.pack(
        [constant_op.constant(-1.5),
         constant_op.constant(3.5)])
    with self.device:
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_in_function(self):
    """Same as test_collective_reduce, with the reduction in a tf.function."""
    x = self.device.pack(
        [constant_op.constant(-1.5),
         constant_op.constant(3.5)])
    with self.device:

      @def_function.function
      def reduce(t):
        return _collective_sum(t, num_replicas=2)

      reduced = reduce(x)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_scope(self):
    """Collective reduction inside `context.async_scope()` does not deadlock."""
    # Note that ops on the parallel device currently don't execute
    # asynchronously. The test is just that we don't get deadlocks.
    x = self.device.pack(
        [constant_op.constant(-1.5),
         constant_op.constant(3.5)])
    with context.async_scope(), self.device:
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_context(self):
    """Collective reduction with async execution enabled context-wide.

    The eager context is reset before and after the test, and the previous
    synchronous-execution setting is restored in `finally`.
    """
    previous = config.get_synchronous_execution()
    try:
      context._reset_context()
      config.set_synchronous_execution(False)
      # The context was reset, so rebuild the logical devices and
      # `self.device` against the new one.
      self.setUp()
      # Note that ops on the parallel device currently don't execute
      # asynchronously. The test is just that we don't get deadlocks.
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      with self.device:
        reduced = _collective_sum(x, num_replicas=2)
        outputs = self.device.unpack(reduced)
      self.assertAllClose([2., 2.], outputs)
      self.assertIn(self.device.components[0], outputs[0].backing_device)
      self.assertIn(self.device.components[1], outputs[1].backing_device)
    finally:
      context._reset_context()
      config.set_synchronous_execution(previous)

  def test_collective_broadcast_in_function(self):
    """Component 0 broadcasts `c * 3` and component 1 receives it.

    `switch_case` on `device_ids` selects the send branch (which returns `c`
    itself) on component 0 and the receive branch on component 1.
    """
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice broadcast collectives on TPUs need work")

    @def_function.function
    def broadcast_send_recv(device_id):
      c = constant_op.constant([2])

      @def_function.function
      def send():
        s0 = collective_ops.broadcast_send(
            c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        # Return `c` only after the send has run.
        with ops.control_dependencies([s0.op]):
          return array_ops.identity(c)

      @def_function.function
      def recv():
        r0 = collective_ops.broadcast_recv(
            c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        return r0

      return control_flow_switch_case.switch_case(
          device_id, branch_fns={
              0: send,
              1: recv
          })

    with self.device:
      result = broadcast_send_recv(self.device.device_ids)
    self.assertAllClose([[2], [6]], self.device.unpack(result))

  def test_use_in_graph_error_is_informative(self):
    """Entering the parallel device inside tf.function raises a clear error."""
    @def_function.function
    def uses_parallel():
      with self.device:
        return self.device.unpack(array_ops.ones([]))

    with self.assertRaisesRegex(NotImplementedError, "inside `tf.function`"):
      uses_parallel()

  def test_checkpointing(self):
    """Saves and restores a parallel variable (currently skipped).

    Restoring into a single-device variable picks up component 0 and leaves
    `parallel_component_1` unconsumed.
    """
    self.skipTest("b/216201668: revisit parallel device and checkpointing.")

    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    different_values = self.device.pack(
        [constant_op.constant(-1.),
         constant_op.constant(3.)])
    with self.device:
      v = variables.Variable(different_values)
      checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(prefix)
    with self.device:
      v.assign(constant_op.constant(0.))
    checkpoint.restore(save_path).assert_consumed()
    with self.device:
      outputs = self.device.unpack(v)
    self.assertAllClose([-1., 3.], outputs)

    with self.device:
      restore_on_create = tracking.Checkpoint()
      restore_on_create.restore(save_path)
      restore_on_create.v = variables.Variable(0.)
      outputs = self.device.unpack(restore_on_create.v)
    self.assertAllClose([-1., 3.], outputs)

    # Changing the number of devices / restoring into a single-device copy is OK
    single_device = tracking.Checkpoint(v=variables.Variable(0.))
    status = single_device.restore(save_path)
    status.assert_existing_objects_matched()
    self.assertAllClose(-1., single_device.v)
    with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
      # There are parts of the variable that aren't restored into a
      # single-device copy.
      status.assert_consumed()

  def test_pack_composite(self):
    """Packing per-component datasets gives independently advancing iterators.

    Component i's dataset is `[i + 1, 2 * (i + 1), 3 * (i + 1)]`. CPU only.
    """
    if self.device_type != "CPU":
      self.skipTest("Iterator GetNext doesn't work on accelerators.")
    datasets = [
        dataset_ops.Dataset.from_tensor_slices(
            [i + 1, (i + 1) * 2, (i + 1) * 3])
        for i in range(len(self.device.components))]
    parallel_dataset = self.device.pack(datasets)
    with self.device:
      iterator = iter(parallel_dataset)
      parallel_sample = next(iterator)
    component_iterators = self.device.unpack(iterator)
    # The parallel `next` already consumed the first element, so each
    # component iterator continues from its second element.
    self.assertEqual(2, next(component_iterators[0]).numpy())
    self.assertEqual(1, self.device.unpack(parallel_sample)[0].numpy())
    self.assertEqual(4, next(component_iterators[1]).numpy())
    self.assertEqual(2, self.device.unpack(parallel_sample)[1].numpy())

  def test_pack_structure(self):
    """`pack`/`unpack` accept nested structures (here, dicts of tensors)."""
    x_parts = [{"a": constant_op.constant(float(i))}
               for i in range(len(self.device.components))]
    x = self.device.pack(x_parts)
    self.assertAllClose([{"a": 0.}, {"a": 1.}], self.device.unpack(x))

  def test_pack_variable_value(self):
    """Packing variables, inside or outside the scope, reads their values."""
    x_parts = [variables.Variable(i)
               for i in range(len(self.device.components))]
    x = self.device.pack(x_parts)
    with self.device:
      x1 = self.device.pack(x_parts)
    for v in x_parts:
      v.assign(-10)  # Mutating the variable does not affect previous reads.
    self.assertAllClose([0, 1], self.device.unpack(x))
    self.assertAllClose([0, 1], self.device.unpack(x1))

  def test_unpack_variable_value(self):
    """Unpacking a parallel variable reads its value at unpack time."""
    x_parts = [constant_op.constant(i)
               for i in range(len(self.device.components))]
    x = self.device.pack(x_parts)
    with self.device:
      v = variables.Variable(x)
      v_unpacked = self.device.unpack(v)
      v.assign(-10)  # Mutating the variable does not affect previous reads.
    self.assertAllClose([0, 1], v_unpacked)

  def test_saved_model(self):
    """Saves a module holding a parallel variable and reloads it (skipped).

    Loading outside the device scope gives single-device values from
    component 0; loading inside the scope restores both components.
    """
    self.skipTest("b/216201668: revisit parallel device and saved model")

    different_values = self.device.pack(
        [constant_op.constant(-1.),
         constant_op.constant(3.)])
    with self.device:
      m = module.Module()
      m.v = variables.Variable(different_values)
      m.f = def_function.function(lambda: m.v * 2.)
      self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
    saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(m, saved_model_path)

    # Start from a fresh context so loading does not reuse live objects.
    context._reset_context()
    self.setUp()

    single_device_loaded = load.load(saved_model_path)
    self.assertAllClose(-2., single_device_loaded.f())
    assign_value = self.device.pack(
        [constant_op.constant(.1), constant_op.constant(.2)])
    with self.device:
      parallel_loaded = load.load(saved_model_path)
      self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
      self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
      parallel_loaded.v.assign(assign_value)
      self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))

  def _assert_close_to_non_parallel(self, computation):
    """Asserts that replication of `computation` works and is equivalent.

    Runs `computation` once inside the parallel device scope and once outside
    it. It then checks that the two results have the same nested structure,
    that every parallel leaf lives on the parallel device, and that each of
    its components matches the non-parallel leaf.

    Args:
      computation: Zero-argument callable returning a (nested) structure of
        tensors; it must return at least one tensor.
    """
    with self.device:
      parallel_result = computation()
    non_parallel_result = computation()
    # The computations should have the same number and structure of Tensor
    # objects, even though the tensors themselves will be on different devices
    # and represent different numbers of values.
    nest.assert_same_structure(parallel_result, non_parallel_result)
    non_parallel_flat = nest.flatten(non_parallel_result)
    parallel_flat = nest.flatten(parallel_result)
    self.assertGreater(len(parallel_flat), 0)
    for non_parallel, parallel in zip(non_parallel_flat, parallel_flat):
      self.assertEqual(self.device._name, parallel.device)
      self.assertNotEqual(self.device._name, non_parallel.device)
      for parallel_component in self.device.unpack(parallel):
        self.assertAllClose(non_parallel, parallel_component)

  def test_capturing(self):
    """A tf.function can capture a parallel tensor from the device scope."""
    with self.device:
      x = constant_op.constant([1., 2.])
      x = array_ops.identity(x)

      @def_function.function
      def f(y):
        return x + y

      y = array_ops.ones([2])
      parallel_result = f(y)
    self.assertAllClose([[2., 3.]] * 2, self.device.unpack(parallel_result))

  def test_euclidean_norm(self):
    """reduce_euclidean_norm and its gradient match the non-parallel result."""
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_euclidean_norm(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_reduce_sum(self):
    """reduce_sum and its gradient match the non-parallel result."""
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_sum(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_variable_created_in_function(self):
    """Variables created inside a tf.function work on the parallel device.

    The variables are initialized in four ways: inside `init_scope`, from a
    captured tensor, with a callable initializer, and from an in-function
    constant. Each is 2., so `1 * 2 * 2 * 2 * 2 == 16` on every component.
    """
    captured_value = constant_op.constant(2.)

    class M(module.Module):

      def __init__(self):
        self.v = None
        self.w = None
        self.x = None
        self.z = None

      @def_function.function(autograph=False)
      def __call__(self, x):
        # Create the variables only on the first trace.
        if self.v is None:
          with ops.init_scope():
            initial_value = constant_op.constant(2.)
            self.z = variables.Variable(initial_value)
          self.x = variables.Variable(captured_value)
          self.w = variables.Variable(lambda: constant_op.constant(2.))
          self.v = variables.Variable(constant_op.constant(2.))
        return x * self.v * self.w * self.x * self.z

    with self.device:
      m = M()
      packed_outputs = m(array_ops.ones([]))
      outputs = self.device.unpack(packed_outputs)
    self.assertAllClose([16., 16.], outputs)

  def test_different_shapes(self):
    """Components may differ in shape but must share the same rank.

    With equal ranks but different sizes, the parallel shape reports `[None]`.
    With different ranks, reading `.shape` raises.
    """
    x = self.device.pack(
        [constant_op.constant([1., 2.]),
         constant_op.constant([5.])])
    with self.device:
      y = x * 2.
    self.assertEqual([None], y.shape.as_list())
    self.assertAllClose([[2., 4.], [10.]], self.device.unpack(y))

    different_axes = self.device.pack(
        [constant_op.constant([1., 2.]),
         constant_op.constant([[5.]])])
    with self.assertRaisesRegex(Exception,
                                "components do not all have the same rank"):
      different_axes.shape  # pylint: disable=pointless-statement


class LayerTests(_VirtualDeviceTestCase):
  """Tests running a small `_Dense` layer replicated on the parallel device."""

  def _pack_inputs(self):
    """Returns a parallel input: `[[-0.5]]` on component 0, `[[0.5]]` on 1."""
    return self.device.pack(
        [constant_op.constant([[-0.5]]),
         constant_op.constant([[0.5]])])

  def _sgd_step(self, layer, x, reduce_gradients):
    """Runs one plain-SGD step (learning rate 0.01) of `layer` on `x`.

    The loss is the elementwise squared error against `range(5.)`. Gradients
    are computed per component on the parallel device. When
    `reduce_gradients` is True they are summed across all components with a
    collective before being applied; otherwise each component applies its own
    gradients.

    Args:
      layer: A `_Dense(5)` instance; its output is compared to `range(5.)`.
      x: Input for `layer`, typically a packed parallel tensor.
      reduce_gradients: Whether to all-reduce gradients across components.
    """
    with self.device:
      with backprop.GradientTape() as tape:
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      gradients = tape.gradient(loss, parameters)
      if reduce_gradients:
        gradients = _collective_sum(
            gradients, num_replicas=len(self.device.components))
      for grad, param in zip(gradients, parameters):
        param.assign_sub(0.01 * grad)

  def test_layer_forward(self):
    """Forward pass replicates a layer and respects per-component inputs."""
    with self.device:
      layer = _Dense(5)
      x = constant_op.constant([[2.]])
      y = layer(x)
      outputs = self.device.unpack(y)
    # Kernel and bias start at ones: 2. * 1. + 1. == 3.
    self.assertAllClose([[3.] * 5], outputs[0])
    self.assertAllClose([[3.] * 5], outputs[1])
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

    # With different Layer inputs we get different outputs
    x = self._pack_inputs()
    with self.device:
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertGreater(
        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_layer_sync_training(self):
    """All-reducing gradients keeps the replicated weights identical."""
    x = self._pack_inputs()
    with self.device:
      layer = _Dense(5)
    self._sgd_step(layer, x, reduce_gradients=True)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    # Each component's first output is x + 1, so the summed bias gradient is
    # 2 * (1.5 - range) + 2 * (0.5 - range).
    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
    self.assertAllClose(expected_bias, final_bias[0], rtol=1e-4, atol=1e-4)
    self.assertAllClose(expected_bias, final_bias[1], rtol=1e-4, atol=1e-4)
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_layer_divergent_buffer_training(self):
    """Without gradient reduction, each replica's weights diverge."""
    x = self._pack_inputs()
    with self.device:
      layer = _Dense(5)
    self._sgd_step(layer, x, reduce_gradients=False)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertNotAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
                        final_bias[0])
    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
                        final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_training_loop(self):
    """Synchronous training with periodic checkpointing (currently skipped)."""
    self.skipTest("b/216201668: revisit parallel device and checkpointing")
    for _ in range(5):
      layer = _Dense(5)
      checkpoint = tracking.Checkpoint(layer=layer)
      manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
      manager.restore_or_initialize()

      for _ in range(10):
        x = self._pack_inputs()
        self._sgd_step(layer, x, reduce_gradients=True)
        manager.save()
        manager.save()


if __name__ == "__main__":
  # The tests drive the parallel device eagerly (entering it inside
  # `tf.function` is expected to raise; see
  # test_use_in_graph_error_is_informative), so enable eager mode first.
  ops.enable_eager_execution()
  test.main()