# Repository: tensorflow/tensorflow
# File: tensorflow/python/distribute/vars_test.py (View on GitHub)
# Code-quality page summary at capture time: Maintainability F (1 mo);
# Test Coverage.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

import itertools
import os
import uuid

from absl.testing import parameterized

from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_management as ckpt_manager
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.util import variable_utils


def strategy_and_run_tf_function_combinations():
  """Returns the combinations used by the assign/read-value tests.

  Each combination pairs one of the TPU strategies with a `mode` ("graph" or
  "eager"), an `experimental_run_tf_function` flag that says whether the
  replica function is wrapped in a `tf.function` before being passed to
  `strategy.run` (currently only True is generated), and a `use_var_policy`
  flag (True or False).
  """
  # TODO(b/197981388): re-enable MWMS test
  # return combinations.combine(
  #     distribution=[
  #         strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
  #     ],
  #     mode=["graph", "eager"],
  #     experimental_run_tf_function=[True, False],
  #     use_var_policy=[True, False]) +
  tpu_strategies = [
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
  ]
  return combinations.combine(
      distribution=tpu_strategies,
      mode=["graph", "eager"],
      experimental_run_tf_function=[True],
      use_var_policy=[True, False])


def strategy_with_var_policy():
  """Returns mirrored/TPU strategy combinations over mode and var policy.

  The strategies are crossed with `mode` ("graph", "eager") and
  `use_var_policy` (True, False).
  """
  strategies = [
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      # TODO(b/197981388): re-enable MWMS test
      # strategy_combinations.multi_worker_mirrored_2x1_cpu,
      # strategy_combinations.multi_worker_mirrored_2x1_gpu,
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
  ]
  return combinations.combine(
      distribution=strategies,
      mode=["graph", "eager"],
      use_var_policy=[True, False])


class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
  """Tests for ON_WRITE (mirrored) variables created under a strategy scope."""

  def _assign_and_gather(self, distribution, experimental_run_tf_function,
                         fn, v, update_value, cross_replica):
    """Applies `v.<fn>(update_value)` in cross-replica or replica context.

    Args:
      distribution: The strategy `v` was created under.
      experimental_run_tf_function: In replica context, whether to wrap the
        update in a `def_function.function` before passing it to
        `distribution.run`.
      fn: Name of the update method to call, e.g. "assign" or "assign_add".
      v: The distributed variable to update.
      update_value: The argument passed to the update method.
      cross_replica: If True, call the update directly from the caller's
        context; otherwise run it on every replica via `distribution.run`.

    Returns:
      The update's return value when `cross_replica` is True, otherwise the
      per-replica results collected with `test_util.gather`.
    """

    def update_fn():
      return getattr(v, fn)(update_value)

    if cross_replica:
      return update_fn()
    replica_fn = (def_function.function(update_fn)
                  if experimental_run_tf_function else update_fn)
    return test_util.gather(distribution, distribution.run(replica_fn))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):
    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for update, aggregation, cross_replica in itertools.product(
        updates, aggregations, [True, False]):
      # Assigning a plain (non-distributed) value such as 1. in replica
      # context with aggregation SUM is skipped: it makes little sense (the
      # caller could just use value * num_replicas), and it is rejected with
      # an error saying the value is not a distributed value and is
      # unsupported for aggregation SUM.
      if (not cross_replica and
          aggregation == variables_lib.VariableAggregation.SUM):
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(
          self._assign_and_gather(distribution, experimental_run_tf_function,
                                  fn, v, update_value, cross_replica))
      # Starting from 0., assign(1.), assign_add(1.) and assign_sub(-1.) all
      # leave 1.0 in every component.
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):

    with distribution.scope():
      v_to_assign = variable_v1.VariableV1(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_v1.VariableV1(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for update, aggregation, cross_replica in itertools.product(
        updates, aggregations, [True, False]):
      # Aggregation SUM is skipped here in both replica and cross-replica
      # context (see testAssign for why SUM is unsupported in replica
      # context).
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(
          self._assign_and_gather(distribution, experimental_run_tf_function,
                                  fn, v, update_value, cross_replica))
      # Starting from 0., assigning 2., adding 2. or subtracting -2. all
      # leave 2.0 in every component.
      for component in v._values:
        self.assertAllEqual(2.0, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):

    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")

    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])
      per_replica_sub_value = values.PerReplica(
          [constant_op.constant(-2.0),
           constant_op.constant(-2.0)])

    updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
               ("assign_sub", per_replica_sub_value)]
    # We don't support assigning PerReplica values to vars in replica context
    # with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    # We don't support assigning PerReplica values to MirroredVariables in
    # cross-replica context, so only replica context is exercised.
    for update, aggregation in itertools.product(updates, aggregations):
      with distribution.scope():
        v = variable_v1.VariableV1(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(
          self._assign_and_gather(distribution, experimental_run_tf_function,
                                  fn, v, update_value, cross_replica=False))
      # Both per-replica components are 2.0 (or -2.0 for assign_sub), so SUM
      # aggregates to 4.0 while MEAN and ONLY_FIRST_REPLICA give 2.0.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContext(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  # NOTE(review): this body is currently identical to
  # testValueInReplicaContext and `use_var_policy` is unused; confirm the
  # intended "assign direct value" scenario.
  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContextAssignDirectValue(self, distribution,
                                                 use_var_policy):
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      # Each replica should read back its own component's value.
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(2., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value

      # A single cross-replica read must match every component.
      results = read_var_fn()
      for component in v._values:
        self.assertEqual(self.evaluate(component.read_value()),
                         self.evaluate(results))

  @combinations.generate(strategy_with_var_policy())
  def testAssignOutOfScope(self, distribution):
    with distribution.scope():
      mirrored = variables_lib.Variable(1.)
    self.evaluate(mirrored.assign(3.))
    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
    for component in mirrored.values:
      self.assertEqual(self.evaluate(component.read_value()), 3.)

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("eager only test")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")
    v = [None]

    @def_function.function
    def step():

      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(random_ops.random_normal([]))

      distribution.run(f)

    context.set_global_seed(None)
    step()
    # The randomly initialized components must still agree.
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testAggregationOnlyFirstReplica(self, distribution):
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212945803")
    with distribution.scope():
      v = variable_v1.VariableV1(
          15.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return v.assign(math_ops.cast(replica_id, dtypes.float32))

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    # Every replica should see the first replica's value (replica 0 assigns
    # 0.).
    self.assertAllEqual(
        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
        per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testInitScope(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("eager only")

    class C(object):
      pass

    obj = C()
    obj.w = None
    obj.v = None

    @def_function.function
    def assign():
      # Variables are created lazily, on the first trace, outside the
      # function graph.
      with ops.init_scope():
        if obj.w is None:
          obj.w = variables_lib.Variable(
              0., aggregation=variables_lib.VariableAggregation.MEAN)
          obj.v = variables_lib.Variable(
              obj.w.read_value(),
              aggregation=variables_lib.VariableAggregation.MEAN)
          self.evaluate(variables_lib.global_variables_initializer())

      return obj.v.assign_add(2.)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    self.assertAllEqual([2.] * distribution.num_replicas_in_sync,
                        per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):

    if not context.executing_eagerly() and isinstance(
        distribution.extended,
        collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212954197")

    with distribution.scope():
      v = variable_v1.VariableV1(
          1, aggregation=variables_lib.VariableAggregation.SUM)
      self.evaluate(variables_lib.global_variables_initializer())

    self.assertEqual(2, self.evaluate(v + 1))

    @def_function.function
    def add():
      return v + 1

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(add)))
    self.assertAllEqual([2] * distribution.num_replicas_in_sync,
                        per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnWrite(self, strategy):
    aggregation = [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(3.0)
      with strategy.scope():
        v_on_write = variables_lib.Variable(2.0, aggregation=agg)

        # Save ON_WRITE, restore ON_WRITE.
        # Checkpoints go under the test's temp dir (not a hard-coded /tmp) so
        # they are cleaned up and portable; the uuid keeps each save unique.
        ckpt = trackable_utils.Checkpoint(var=v_on_write)
        manager = ckpt_manager.CheckpointManager(
            ckpt,
            os.path.join(self.get_temp_dir(), "ckpt_" + str(uuid.uuid4())),
            max_to_keep=None)
        manager.save()
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_on_write.read_value()))

        # Save ON_WRITE, restore into a normal variable. The ON_WRITE
        # checkpoint already exists, so only the restore is needed.
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_normal_restore.read_value()))

        # Save a normal variable, restore into ON_WRITE.
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager_2 = ckpt_manager.CheckpointManager(
            ckpt,
            os.path.join(self.get_temp_dir(),
                         "ckpt_normal_" + str(uuid.uuid4())),
            max_to_keep=None)
        manager_2.save()
        ckpt_on_write = trackable_utils.Checkpoint(var=v_on_write)
        ckpt_on_write.restore(manager_2.latest_checkpoint)
        self.assertEqual(3.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(3.0, self.evaluate(v_on_write.read_value()))


# Combinations used by the scatter tests below: the
# `mirrored_strategy_with_gpu_and_cpu` strategy, and the
# `tpu_strategy_packed_var` strategy, each run in graph and eager mode.
ms_combination = combinations.combine(
    distribution=[
        strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
    ],
    mode=["graph", "eager"],
)
tpu_combination = combinations.combine(
    distribution=[
        strategy_combinations.tpu_strategy_packed_var,
    ],
    mode=["graph", "eager"],
)


class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
  """Scatter-op tests for ON_WRITE variables under a distribution strategy.

  The expected per-replica results below are written for two replicas
  (replica ids 0 and 1).
  """

  def _run_and_gather(self, distribution, fn, args=()):
    """Runs `fn` on every replica; returns the evaluated, gathered results."""
    return self.evaluate(
        test_util.gather(distribution, distribution.run(fn, args=args)))

  @combinations.generate(ms_combination)
  def testScatterSub(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_sub():
      rid = distribute_lib.get_replica_context().replica_id_in_sync_group
      # Replica r subtracts [r, r + 1] at indices [r, r + 1].
      deltas = [
          math_ops.cast(rid, dtypes.float32),
          math_ops.cast(rid + 1, dtypes.float32),
      ]
      slices = indexed_slices.IndexedSlices(
          values=array_ops_stack.stack(deltas),
          indices=array_ops_stack.stack([rid, rid + 1]),
          dense_shape=(3,))
      return v.scatter_sub(slices)

    local_results = distribution.experimental_local_results(
        distribution.run(scatter_sub))
    # The expected values are consistent with the replicas' slices being
    # averaged (MEAN) before they are applied.
    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]],
                        self.evaluate(local_results))

  @combinations.generate(ms_combination)
  def testScatterAdd(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_add():
      rid = distribute_lib.get_replica_context().replica_id_in_sync_group
      # Replica r adds [r, r + 1] at indices [r, r + 1].
      slices = indexed_slices.IndexedSlices(
          values=array_ops_stack.stack([rid, rid + 1]),
          indices=array_ops_stack.stack([rid, rid + 1]),
          dense_shape=(3,))
      return v.scatter_add(slices)

    gathered = self._run_and_gather(distribution, scatter_add)
    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], gathered)

  @combinations.generate(ms_combination)
  def testScatterDiv(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_div():
      rid = distribute_lib.get_replica_context().replica_id_in_sync_group
      # Replica r divides element r by r + 2.
      slices = indexed_slices.IndexedSlices(
          values=array_ops.reshape(rid + 2, [1]),
          indices=array_ops.reshape(rid, [1]),
          dense_shape=(3,))
      return v.scatter_div(slices)

    gathered = self._run_and_gather(distribution, scatter_div)
    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], gathered)

  @combinations.generate(ms_combination)
  def testScatterMul(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_mul():
      rid = distribute_lib.get_replica_context().replica_id_in_sync_group
      # Replica r multiplies element r by r + 2.
      factor = math_ops.cast(rid + 2, dtypes.float32)
      slices = indexed_slices.IndexedSlices(
          values=array_ops.reshape(factor, [1]),
          indices=array_ops.reshape(rid, [1]),
          dense_shape=(3,))
      return v.scatter_mul(slices)

    gathered = self._run_and_gather(distribution, scatter_mul)
    # The expected values are consistent with the multipliers being averaged
    # across replicas (MEAN) before they are applied.
    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], gathered)

  @combinations.generate(ms_combination)
  def testScatterMin(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 2, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_min(v):
      return v.scatter_min(
          indexed_slices.IndexedSlices(
              values=array_ops.identity([1]),
              indices=array_ops.identity([1]),
              dense_shape=(3,)))

    # scatter_min in replica context raises for aggregation SUM.
    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
      self._run_and_gather(distribution, scatter_min, args=(sum_var,))

    gathered = self._run_and_gather(
        distribution, scatter_min, args=(first_replica_var,))
    self.assertAllClose([[0, 1, 0], [0, 1, 0]], gathered)

  @combinations.generate(ms_combination)
  def testScatterMax(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_max(v):
      return v.scatter_max(
          indexed_slices.IndexedSlices(
              values=array_ops.identity([1]),
              indices=array_ops.identity([0]),
              dense_shape=(3,)))

    # scatter_max in replica context raises for aggregation SUM.
    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
      self._run_and_gather(distribution, scatter_max, args=(sum_var,))

    gathered = self._run_and_gather(
        distribution, scatter_max, args=(first_replica_var,))
    self.assertAllClose([[1, 0, 0], [1, 0, 0]], gathered)

  @combinations.generate(ms_combination)
  def testScatterUpdate(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_update(v):
      return v.scatter_update(
          indexed_slices.IndexedSlices(
              values=array_ops.identity([3]),
              indices=array_ops.identity([1]),
              dense_shape=(3,)))

    # scatter_update in replica context raises for aggregation SUM.
    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
      self._run_and_gather(distribution, scatter_update, args=(sum_var,))

    gathered = self._run_and_gather(
        distribution, scatter_update, args=(first_replica_var,))
    self.assertAllClose([[0, 3, 0], [0, 3, 0]], gathered)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsWithNoneAggregation(self, distribution):

    def check_scatter(v, op_name, delta, expected):
      """Runs `v.<op_name>(delta)` on each replica and checks the result."""
      bound_op = getattr(v, op_name)

      @def_function.function
      def scatter_xxx():
        return bound_op(delta)

      local_results = distribution.experimental_local_results(
          distribution.run(scatter_xxx))
      as_tensors = self.evaluate(
          variable_utils.convert_variables_to_tensors(local_results))
      self.assertAllClose([expected, expected], as_tensors)

    with distribution.scope():
      v = variables_lib.Variable(
          [4.], aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())

    delta = indexed_slices.IndexedSlices(
        values=array_ops.identity([2.]),
        indices=array_ops.identity([0]),
        dense_shape=(1,))

    # The ops run in this order, each starting from the value the previous
    # one left behind (the variable starts at 4.).
    expectations = [
        ("scatter_sub", [2.]),  # 4 - 2
        ("scatter_add", [4.]),  # 2 + 2
        ("scatter_max", [4.]),  # max(4, 2)
        ("scatter_min", [2.]),  # min(4, 2)
        ("scatter_mul", [4.]),  # 2 * 2
        ("scatter_div", [2.]),  # 4 / 2
        ("scatter_update", [2.]),  # set to 2
    ]
    for op_name, expected in expectations:
      check_scatter(v, op_name, delta, expected)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsInCrossReplicaContext(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
      default_var = variables_lib.Variable([1, 1, 1])
    self.evaluate(variables_lib.global_variables_initializer())

    slices = indexed_slices.IndexedSlices(
        values=array_ops.identity([2]),
        indices=array_ops.identity([0]),
        dense_shape=(3,))
    with distribution.scope():
      self.evaluate(sum_var.scatter_add(slices))
      self.assertAllEqual([3, 1, 1], self.evaluate(sum_var.read_value()))

      # min(1, 2) leaves element 0 unchanged.
      self.evaluate(default_var.scatter_min(slices))
      self.assertAllEqual([1, 1, 1], self.evaluate(default_var.read_value()))


class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
  """Tests reads and updates of distributed variables, mostly ON_READ ones.

  Most tests create variables with
  `synchronization=VariableSynchronization.ON_READ` inside
  `distribution.scope()` and then inspect the per-replica components stored
  in `v._values`.
  """

  def _apply_update(self, distribution, experimental_run_tf_function, fn, v,
                    update_value, cross_replica):
    """Calls `getattr(v, fn)(update_value)` in the requested context.

    Args:
      distribution: the strategy whose replicas run a replica-context update.
      experimental_run_tf_function: if True, a replica-context update is
        wrapped in `def_function.function` before `distribution.run`.
      fn: name of the update method on `v`, e.g. "assign" or "assign_add".
      v: the variable to update.
      update_value: the single argument passed to the update method.
      cross_replica: if True, call the update method directly from the
        caller's context instead of through `distribution.run`.

    Returns:
      The update method's return value when `cross_replica` is True;
      otherwise the per-replica results combined with `test_util.gather`.
    """
    update_fn = lambda: getattr(v, fn)(update_value)
    if cross_replica:
      return update_fn()
    if experimental_run_tf_function:
      update_fn = def_function.function(update_fn)
    return test_util.gather(distribution, distribution.run(update_fn))

  def _check_updates_set_components_to_one(self, distribution,
                                           experimental_run_tf_function,
                                           updates):
    """Checks that each update takes a zero-valued ON_READ variable to 1.

    Every (update, aggregation, cross_replica) combination is run against a
    fresh ON_READ variable initialized to 0. Afterwards, every replica
    component must read as 1.

    Args:
      distribution: the strategy under test.
      experimental_run_tf_function: forwarded to `_apply_update`.
      updates: list of `(method_name, update_value)` pairs.
    """
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for update, aggregation, cross_replica in itertools.product(
        updates, aggregations, [True, False]):
      # VariableAggregation.SUM in cross-replica mode is tested in
      # testAssignWithAggregationSum; VariableAggregation.NONE in
      # cross-replica mode is not supported.
      if cross_replica and aggregation in [
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.NONE,
      ]:
        continue
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, fn,
                             v, update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  def _new_checkpoint_dir(self):
    """Returns a unique checkpoint directory path under the test temp dir."""
    import os  # pylint: disable=g-import-not-at-top
    return os.path.join(self.get_temp_dir(), "ckpt_" + str(uuid.uuid4()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):
    """Float assign/assign_add/assign_sub set each ON_READ component to 1."""
    self._check_updates_set_components_to_one(
        distribution, experimental_run_tf_function,
        [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)])

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnReadVar(self, distribution, experimental_run_tf_function):
    """Updates variables using other MEAN-aggregated variables as values."""

    with distribution.scope():
      v_to_assign = variable_v1.VariableV1(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_v1.VariableV1(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    expected_cross_replica = {
        variables_lib.VariableAggregation.SUM: 1.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    expected_replica = {
        variables_lib.VariableAggregation.SUM: 2.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    # aggregation=NONE is not supported for OnReadVariables.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for update, aggregation, cross_replica in itertools.product(
        updates, aggregations, [True, False]):
      # SUM is skipped. Assigning a value that is not a distributed value in
      # replica context is unsupported for SUM aggregation; the intended
      # result could simply be written as value * num_replicas.
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      # NOTE(review): `v` is created with the default synchronization, not
      # VariableSynchronization.ON_READ -- confirm this is intended.
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, fn,
                             v, update_value, cross_replica))
      expected_values = (
          expected_cross_replica if cross_replica else expected_replica)
      expected = expected_values.get(aggregation)
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
    """Assigns a PerReplica value to ON_READ variables (currently skipped)."""

    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")

    self.skipTest("We don't support assiging PerReplica values in cross "
                  "replica context or replica context. see error in "
                  "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")

    # Everything below is unreachable until the unconditional skip above is
    # removed. It is kept as the intended test once PerReplica assignment is
    # supported.
    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])

    updates = [("assign", per_replica_value)]
    # We don't support assigning PerReplica values to vars in replica context
    # with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for update, aggregation, cross_replica in itertools.product(
        updates, aggregations, [True, False]):
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      # NOTE: presumably this currently fails with
      # ValueError("Attempt to convert a value ..."); see the skip above.
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, fn,
                             v, update_value, cross_replica))
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignDtypeConversion(self, distribution,
                                experimental_run_tf_function):
    """Integer update values are converted to the float variable's dtype."""
    self._check_updates_set_components_to_one(
        distribution, experimental_run_tf_function,
        [("assign", 1), ("assign_add", 1), ("assign_sub", -1)])

  @combinations.generate(strategy_with_var_policy())
  def testAssignWithAggregationSum(self, distribution):
    """Cross-replica assign with SUM aggregation splits the value evenly."""
    with distribution.scope():
      v = variable_v1.VariableV1(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    # Assigning num_replicas cross-replica leaves 1. in every component.
    self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
    for component in v._values:
      self.assertAllEqual(self.evaluate(component.read_value()),
                          self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_with_var_policy())
  def testAssignAddSubWithAggregationSum(self, distribution):
    """Cross-replica assign_add/assign_sub with SUM aggregation must raise."""
    with distribution.scope():
      v = variable_v1.VariableV1(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    with self.assertRaisesRegex(
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_add(1.))
    with self.assertRaisesRegex(
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_sub(1.))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    """In replica context, read_value returns each replica's own component."""
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    """Cross-replica reads aggregate the per-replica values."""
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      # NOTE(review): the TPU system is re-initialized on every iteration,
      # presumably to reset device state between variables -- confirm.
      if strategy_test_lib.is_tpu_strategy(distribution):
        resolver = tpu_cluster_resolver.TPUClusterResolver("")
        tpu_cluster_resolver.initialize_tpu_system(resolver)
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      def assign(v=v):
        ctx = distribute_lib.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return v.assign(math_ops.cast(replica_id, dtypes.float32))

      if experimental_run_tf_function:
        assign = def_function.function(assign)

      self.evaluate(test_util.gather(distribution, distribution.run(assign)))
      num_replicas = distribution.num_replicas_in_sync
      # Replica i holds i, so the SUM is 0 + 1 + ... + (n - 1).
      sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = sum_of_replica_values
      elif aggregation == variables_lib.VariableAggregation.MEAN:
        expected = sum_of_replica_values / num_replicas
      else:
        # ONLY_FIRST_REPLICA reads replica 0's value.
        expected = 0
      self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
      self.assertEqual(expected, self.evaluate(v.value()), aggregation)
      self.assertEqual(expected, self.evaluate(v), aggregation)
      self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
                       aggregation)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAllReduce(self, distribution, experimental_run_tf_function):
    """A SUM all-reduce of a mirrored variable inside distribution.run."""
    with distribution.scope():
      v = variable_v1.VariableV1(
          2.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(variables_lib.global_variables_initializer())

    def all_reduce():
      ctx = distribute_lib.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
                                                      dtypes.float32)

    if experimental_run_tf_function:
      all_reduce = def_function.function(all_reduce)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(all_reduce)))
    num_replicas = distribution.num_replicas_in_sync
    # Every replica sums 2.0 from each replica, then adds its own id.
    expected_result = tuple(
        2.0 * num_replicas + 1.0 * i for i in range(num_replicas))
    self.assertAllEqual(per_replica_results, expected_result)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaBeforeRead(self, distribution,
                                     experimental_run_tf_function):
    """Replica-context assign returns each replica's newly assigned value."""
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_v1.VariableV1(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      def assign(var=v):
        ctx = distribute_lib.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return var.assign(math_ops.cast(replica_id, dtypes.float32))

      if experimental_run_tf_function:
        assign = def_function.function(assign)

      per_replica_results = self.evaluate(
          test_util.gather(distribution, distribution.run(assign)))
      expected_result = tuple(
          1.0 * i for i in range(distribution.num_replicas_in_sync))
      self.assertAllEqual(per_replica_results, expected_result)

  @combinations.generate(strategy_with_var_policy())
  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
    """Cross-replica read of a NONE-aggregated ON_READ variable must raise."""
    with distribution.scope():
      v = variable_v1.VariableV1(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())
    with self.assertRaisesRegex(
        ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
      self.evaluate(v.read_value())

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    """A random initializer yields the same value in every replica."""
    if not context.executing_eagerly():
      self.skipTest("eager only")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")

    # A list cell lets the nested function create the variable only once.
    v = [None]
    @def_function.function
    def step():
      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(
              random_ops.random_normal([]),
              synchronization=variables_lib.VariableSynchronization.ON_READ)

      distribution.run(f)

    context.set_global_seed(None)
    step()
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):
    """Operators read the aggregated value cross-replica, the local one in."""

    with distribution.scope():
      v = variable_v1.VariableV1(
          0.0,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def assign():
        ctx = distribute_lib.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return v.assign(math_ops.cast(replica_id, dtypes.float32))

      # Assign different replicas with different values.
      self.evaluate(test_util.gather(distribution, distribution.run(assign)))
      # Cross-replica: with replicas holding 0 and 1, the MEAN is 0.5.
      self.assertEqual(1.5, self.evaluate(v + 1))

      @def_function.function
      def add():
        return v + 1

      # Replica context: each replica adds 1 to its own component.
      per_replica_results = self.evaluate(
          test_util.gather(distribution, distribution.run(add)))
      self.assertAllEqual([1, 2], per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnRead(self, strategy):
    """Round-trips checkpoints between ON_READ and plain variables."""
    aggregation = [variable_scope.VariableAggregation.SUM,
                   variable_scope.VariableAggregation.MEAN]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(2.0)

      with strategy.scope():
        v_on_read = variables_lib.Variable(
            1.0, synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=agg)

        @def_function.function
        def assign_fn():
          cluster_resolver = strategy.cluster_resolver
          replica_ctx = distribute_lib.get_replica_context()
          if ((cluster_resolver and cluster_resolver.task_type == "worker") or
              math_ops.equal(replica_ctx.replica_id_in_sync_group,
                             constant_op.constant(1))):
            v_on_read.assign(3.)  # pylint:disable=cell-var-from-loop
          else:
            v_on_read.assign(4.)  # pylint:disable=cell-var-from-loop

        strategy.run(assign_fn)

        # Save ONREAD, restore ONREAD
        # Saves v[0] + v[1] = 7 for SUM and 3.5 for MEAN.
        ckpt = trackable_utils.Checkpoint(var=v_on_read)
        manager = ckpt_manager.CheckpointManager(
            ckpt, self._new_checkpoint_dir(), max_to_keep=None)
        manager.save()
        # Restores a value of 7/2 = 3.5 for SUM and 3.5 for MEAN.
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(3.5, self.evaluate(v_on_read._values[0]))

        # Save ONREAD, restore normal
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        if agg == variable_scope.VariableAggregation.SUM:
          self.assertEqual(7.0, self.evaluate(v_normal_restore.read_value()))
        else:
          self.assertEqual(3.5, self.evaluate(v_normal_restore.read_value()))

        # Save normal, restore ONREAD
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager = ckpt_manager.CheckpointManager(
            ckpt, self._new_checkpoint_dir(), max_to_keep=None)
        manager.save()
        # Restores a value of 2/2 = 1.0 for SUM and 2.0 for MEAN.
        ckpt_on_read = trackable_utils.Checkpoint(var=v_on_read)
        ckpt_on_read.restore(manager.latest_checkpoint)
        if agg == variable_scope.VariableAggregation.SUM:
          self.assertEqual(1.0, self.evaluate(v_on_read._values[0]))
        else:
          self.assertEqual(2.0, self.evaluate(v_on_read._values[0]))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
  """Scatter ops on ON_READ variables must raise in replica context.

  Each test builds an ON_READ variable and feeds a `PerReplica` of
  `IndexedSlices` (one per replica) to one scatter method through
  `distribution.run`. It expects `NotImplementedError`.
  """

  def _assert_scatter_not_implemented(self, distribution, aggregation,
                                      method_name, initial_value,
                                      replica_updates):
    """Asserts that `v.<method_name>` raises NotImplementedError per replica.

    Args:
      distribution: the strategy under test.
      aggregation: the `VariableAggregation` of the ON_READ variable.
      method_name: name of the scatter method, e.g. "scatter_add".
      initial_value: 1-D list used as the variable's initial value; its length
        is used as the `dense_shape` of every `IndexedSlices`.
      replica_updates: one `(values, indices)` pair per replica, each turned
        into an `IndexedSlices` element of the `PerReplica` delta.
    """
    with distribution.scope():
      v = variables_lib.Variable(
          initial_value,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    dense_shape = (len(initial_value),)
    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=slice_values, indices=slice_indices,
            dense_shape=dense_shape)
        for slice_values, slice_indices in replica_updates
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(getattr(v, method_name), args=(delta,)))

  def testScatterSub(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_sub",
        initial_value=[1., 1., 1.],
        replica_updates=[([[0.], [1.]], [0, 1]), ([[1.], [2.]], [1, 2])])

  def testScatterAdd(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_add",
        initial_value=[1., 1., 1.],
        replica_updates=[([[0.], [1.]], [0, 1]), ([[1.], [2.]], [1, 2])])

  def testScatterDiv(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_div",
        initial_value=[2., 6., 1.],
        replica_updates=[([[2.], [2.]], [0, 1]), ([[3.], [3.]], [1, 2])])

  def testScatterMul(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_mul",
        initial_value=[2., 1., 1.],
        replica_updates=[([[2.], [3.]], [0, 1]), ([[4.], [5.]], [1, 2])])

  def testScatterMin(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_min",
        initial_value=[3., 4., 5.],
        replica_updates=[([[1.], [8.]], [0, 1]), ([[9.], [2.]], [1, 2])])

  def testScatterMax(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_max",
        initial_value=[3., 4., 5.],
        replica_updates=[([[1.], [8.]], [0, 1]), ([[9.], [2.]], [1, 2])])

  def testScatterUpdate(self, distribution, aggregation):
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_update",
        initial_value=[0., 0., 0.],
        replica_updates=[([[1.], [2.]], [0, 1]), ([[3.], [4.]], [1, 2])])


if __name__ == "__main__":
  # Uses the distribute-specific test_util.main() rather than test.main();
  # presumably so the multi_worker_mirrored_2x1_* combinations can launch
  # their worker processes -- confirm in distribute/test_util.
  test_util.main()