# Run model_fn under graph mode and inside the distribution strategy's scope,
# then check the shape and type of what it returns.
with context.graph_mode(), distribution.scope():
      # Invoke model_fn once per replica via the strategy's extended API; the
      # combined result is expected to contain exactly two entries.
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(2, len(result))
      # Pair each returned entry with its expected name, in order.
      # NOTE(review): presumably model_fn produces objects named "a" and "b";
      # confirm against model_fn's definition (not visible here).
      for v, name in zip(result, ["a", "b"]):
        # Each entry must be a values.DistributedValues instance rather than a
        # plain per-replica object.
        self.assertIsInstance(v, values.DistributedValues)