# Origin: tensorflow/tensorflow on GitHub,
# tensorflow/python/kernel_tests/metrics_test.py
#
# Page summary captured with this copy: Maintainability F (5 mos);
# Test Coverage (no value shown).
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics."""

import functools
import math

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test

NAN = float('nan')


def _enqueue_vector(sess, queue, values, shape=None):
  """Runs one enqueue of `values` onto `queue` using `sess`.

  Args:
    sess: Session used to run the enqueue op.
    queue: Queue; the dtype of its first component is used for the constant.
    values: Python values enqueued as a single element.
    shape: Optional shape of the enqueued constant. When falsy, the shape
      `(1, len(values))` is used.
  """
  element_shape = shape if shape else (1, len(values))
  element_dtype = queue.dtypes[0]
  element = constant_op.constant(
      values, dtype=element_dtype, shape=element_shape)
  sess.run(queue.enqueue(element))


def _binary_2d_label_to_2d_sparse_value(labels):
  """Convert dense 2D binary indicator to sparse ID.

  Only 1 values in `labels` are included in result.

  Args:
    labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Every
      entry must be 0 or 1 (any other value trips an `assert`).

  Returns:
    `SparseTensorValue` whose dense shape is [batch_size, num_classes], i.e.
    `[len(labels), len(labels[0])]`. The `1` values of row `b` are packed
    left-aligned at indices `[b, 0], [b, 1], ...`, and each value is the
    column index of that `1` in `labels[b]`.
  """
  indices = []
  values = []
  for batch, row in enumerate(labels):
    # Packed position (second index) of the next `1` found in this row.
    xi = 0
    for label, x in enumerate(row):
      if x == 1:
        indices.append([batch, xi])
        values.append(label)
        xi += 1
      else:
        assert x == 0
  shape = [len(labels), len(labels[0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))


def _binary_2d_label_to_1d_sparse_value(labels):
  """Convert dense 2D binary indicator to sparse ID.

  Only 1 values in `labels` are included in result.

  Args:
    labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Each
      row must contain exactly 1 `1` value; every other entry must be 0 (any
      other value trips an `assert`).

  Returns:
    `SparseTensorValue` of shape [batch_size]. Values are indices of `1` values
    along the last dimension of `labels`.

  Raises:
    ValueError: if there is not exactly 1 `1` value per row of `labels`.
  """
  indices = []
  values = []
  for batch, row in enumerate(labels):
    for label, x in enumerate(row):
      if x == 1:
        indices.append([batch])
        values.append(label)
      else:
        assert x == 0
  # Exactly one `1` per row yields indices [[0], [1], ..., [batch_size - 1]];
  # a row with no `1` or with several `1`s breaks that sequence.
  if indices != [[i] for i in range(len(labels))]:
    raise ValueError('Expected 1 label/example, got %s.' % indices)
  shape = [len(labels)]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))


def _binary_3d_label_to_sparse_value(labels):
  """Convert dense 3D binary indicator tensor to sparse tensor.

  Only 1 values in `labels` are included in result.

  Args:
    labels: Dense 3D binary indicator, nested sequences indexed as
      `labels[d0][d1][class_id]`. Every entry must be 0 or 1 (any other value
      trips an `assert`).

  Returns:
    `SparseTensorValue` whose dense shape is
    `[len(labels), len(labels[0]), len(labels[0][0])]` (assumes `labels` is
    rectangular). For each `(d0, d1)`, the `1` values are packed left-aligned
    at indices `[d0, d1, 0], [d0, d1, 1], ...`, and each value is the index of
    that `1` along the last dimension of `labels`.
  """
  indices = []
  values = []
  for d0, labels_d0 in enumerate(labels):
    for d1, labels_d1 in enumerate(labels_d0):
      # Packed position (last index) of the next `1` in this innermost row.
      d2 = 0
      for class_id, label in enumerate(labels_d1):
        if label == 1:
          values.append(class_id)
          indices.append([d0, d1, d2])
          d2 += 1
        else:
          assert label == 0
  shape = [len(labels), len(labels[0]), len(labels[0][0])]
  return sparse_tensor.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64), np.array(shape, np.int64))


def _assert_nan(test_case, actual):
  """Fails `test_case` unless `actual` is NaN."""
  actual_is_nan = math.isnan(actual)
  test_case.assertTrue(actual_is_nan, 'Expected NAN, got %s.' % actual)


def _assert_metric_variables(test_case, expected):
  """Asserts that both local and metric variable collections match `expected`.

  Args:
    test_case: Test case used to report failures.
    expected: Iterable of expected variable names, e.g. `'mean/count:0'`.
      It is materialized once, so one-shot iterables (generators) work too.
  """
  # Build the set once: calling set(expected) twice would see an exhausted
  # iterator on the second comparison when `expected` is a generator.
  expected_names = set(expected)
  local_names = {v.name for v in variables.local_variables()}
  test_case.assertEqual(expected_names, local_names)
  metric_names = {
      v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)}
  test_case.assertEqual(expected_names, metric_names)


def _test_values(shape):
  """Returns consecutive float values 1, 2, 3, ... laid out in `shape`.

  Args:
    shape: Target shape (a tuple of ints, or a single int).

  Returns:
    float64 `np.ndarray` with the given shape, filled in row-major order with
    consecutive values starting at 1.
  """
  # Pass the target shape positionally: the `newshape=` keyword of
  # `np.reshape` is deprecated as of NumPy 2.1.
  return np.reshape(np.cumsum(np.ones(shape)), shape)


class MeanTest(test.TestCase):
  """Tests for `metrics.mean`, the streaming scalar mean."""

  # 1x2 rows fed through a FIFOQueue by the queue-based tests. Their overall
  # mean is 13.2 / 8 = 1.65.
  _QUEUE_ROWS = ([0, 1], [-4.2, 9.1], [6.5, 0], [-3.2, 4.0])

  def setUp(self):
    ops.reset_default_graph()

  def _dequeue_queue_rows(self, sess):
    """Returns the dequeue op of a float32 queue pre-filled with the rows."""
    rows_queue = data_flow_ops.FIFOQueue(
        len(self._QUEUE_ROWS), dtypes=dtypes_lib.float32, shapes=(1, 2))
    for row in self._QUEUE_ROWS:
      _enqueue_vector(sess, rows_queue, row)
    return rows_queue.dequeue()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean(array_ops.ones([4, 3]))
    expected_vars = ('mean/count:0', 'mean/total:0')
    _assert_metric_variables(self, expected_vars)

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    collection = '__metrics__'
    value_tensor, _ = metrics.mean(
        array_ops.ones([4, 3]), metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value_tensor])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update_tensor = metrics.mean(
        array_ops.ones([4, 3]), updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update_tensor])

  @test_util.run_deprecated_v1
  def testBasic(self):
    with self.cached_session() as sess:
      mean, update_op = metrics.mean(self._dequeue_queue_rows(sess))

      self.evaluate(variables.local_variables_initializer())
      for _ in range(len(self._QUEUE_ROWS)):
        self.evaluate(update_op)
      self.assertAlmostEqual(1.65, self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testUpdateOpsReturnsCurrentValue(self):
    with self.cached_session() as sess:
      mean, update_op = metrics.mean(self._dequeue_queue_rows(sess))

      self.evaluate(variables.local_variables_initializer())

      # Each evaluation consumes one more row; update_op yields the running
      # mean over everything consumed so far.
      running_means = (0.5, 1.475, 12.4 / 6.0, 1.65)
      for expected_mean in running_means:
        self.assertAlmostEqual(expected_mean, self.evaluate(update_op), 5)

      self.assertAlmostEqual(1.65, self.evaluate(mean), 5)

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    values = _test_values((3, 2, 4, 1))
    # Every weight tensor below is all ones (or absent), so each variant must
    # reproduce the plain mean.
    ones_shapes = (
        (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
        (1, 1, 4), (1, 1, 4, 1),
        (1, 2, 1), (1, 2, 1, 1),
        (1, 2, 4), (1, 2, 4, 1),
        (3, 1, 1), (3, 1, 1, 1),
        (3, 1, 4), (3, 1, 4, 1),
        (3, 2, 1), (3, 2, 1, 1),
        (3, 2, 4), (3, 2, 4, 1), (3, 2, 4, 1, 1))
    all_weights = [None, 1.0] + [np.ones(s) for s in ones_shapes]
    mean_results = [metrics.mean(values, weights=w) for w in all_weights]
    expected = np.mean(values)
    with self.cached_session():
      variables.local_variables_initializer().run()
      for mean, update_op in mean_results:
        self.assertAlmostEqual(expected, self.evaluate(update_op))
        self.assertAlmostEqual(expected, self.evaluate(mean))

  def _test_3d_weighted(self, values, weights):
    """Compares the weighted mean of `values` with a NumPy reference value."""
    weighted_sum = np.sum(np.multiply(weights, values))
    weight_total = np.sum(np.multiply(weights, np.ones_like(values)))
    expected = weighted_sum / weight_total
    mean, update_op = metrics.mean(values, weights=weights)
    with self.cached_session():
      variables.local_variables_initializer().run()
      self.assertAlmostEqual(expected, self.evaluate(update_op), places=5)
      self.assertAlmostEqual(expected, self.evaluate(mean), places=5)

  @test_util.run_deprecated_v1
  def test1x1x1Weighted(self):
    weights = np.asarray((5,)).reshape((1, 1, 1))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def test1x1xNWeighted(self):
    weights = np.asarray((5, 7, 11, 3)).reshape((1, 1, 4))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def test1xNx1Weighted(self):
    weights = np.asarray((5, 11)).reshape((1, 2, 1))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def test1xNxNWeighted(self):
    weights = np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def testNx1x1Weighted(self):
    weights = np.asarray((5, 7, 11)).reshape((3, 1, 1))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def testNx1xNWeighted(self):
    weights = np.asarray(
        (5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def testNxNxNWeighted(self):
    weights = np.asarray((
        5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3,
        2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4))
    self._test_3d_weighted(_test_values((3, 2, 4)), weights=weights)

  @test_util.run_deprecated_v1
  def testInvalidWeights(self):
    values_placeholder = array_ops.placeholder(dtype=dtypes_lib.float32)
    values = _test_values((3, 2, 4, 1))
    # NOTE(review): each tuple below is passed directly as the `weights`
    # *value* (a rank-1 tensor), not used as a shape for np.ones.
    invalid_weights = (
        (1,), (1, 1), (3, 2), (2, 4, 1), (4, 2, 4, 1),
        (3, 3, 4, 1), (3, 2, 5, 1), (3, 2, 4, 2), (1, 1, 1, 1, 1))
    expected_error_msg = 'weights can not be broadcast to values'
    for invalid_weight in invalid_weights:
      # Static shapes: the error is raised while building the metric.
      with self.assertRaisesRegex(ValueError, expected_error_msg):
        metrics.mean(values, invalid_weight)

      # Dynamic shapes: the error surfaces as an OpError at run time.
      with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg):
        with self.cached_session():
          _, update_op = metrics.mean(values_placeholder, invalid_weight)
          variables.local_variables_initializer().run()
          update_op.eval(feed_dict={values_placeholder: values})


class MeanTensorTest(test.TestCase):
  """Tests for `metrics.mean_tensor`, the streaming element-wise mean."""

  # 1x2 rows fed through a FIFOQueue by most tests below.
  _VALUE_ROWS = ([0, 1], [-4.2, 9.1], [6.5, 0], [-3.2, 4.0])

  # Absolute tolerance for float32 comparisons. Tolerances are passed to
  # `assertAllClose` by keyword: its third positional parameter is `rtol`, so
  # a bare trailing `5` would mean a 500% relative tolerance, not 5 places.
  _ATOL = 1e-5

  def setUp(self):
    ops.reset_default_graph()

  @staticmethod
  def _dequeue_rows(sess, rows, shape):
    """Returns the dequeue op of a float32 queue pre-filled with `rows`.

    Args:
      sess: Session used to run the enqueue ops.
      rows: Sequence of rows; each one is enqueued as a single element.
      shape: Per-element shape of the queue.
    """
    queue = data_flow_ops.FIFOQueue(
        len(rows), dtypes=dtypes_lib.float32, shapes=shape)
    for row in rows:
      _enqueue_vector(sess, queue, row)
    return queue.dequeue()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_tensor(array_ops.ones([4, 3]))
    _assert_metric_variables(self,
                             ('mean/total_tensor:0', 'mean/count_tensor:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_tensor(
        array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_tensor(
        array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testBasic(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[-0.9 / 4., 3.525]], self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testMultiDimensional(self):
    with self.cached_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
      _enqueue_vector(
          sess,
          values_queue, [[[1, 2], [1, 2]], [[1, 2], [1, 2]]],
          shape=(2, 2, 2))
      _enqueue_vector(
          sess,
          values_queue, [[[1, 2], [1, 2]], [[3, 4], [9, 10]]],
          shape=(2, 2, 2))
      values = values_queue.dequeue()

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(2):
        self.evaluate(update_op)
      self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]],
                          self.evaluate(mean))

  @test_util.run_deprecated_v1
  def testUpdateOpsReturnsCurrentValue(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))

      mean, update_op = metrics.mean_tensor(values)

      self.evaluate(variables.local_variables_initializer())

      # Running element-wise mean after each consumed row.
      self.assertAllClose(
          [[0, 1]], self.evaluate(update_op), atol=self._ATOL)
      self.assertAllClose(
          [[-2.1, 5.05]], self.evaluate(update_op), atol=self._ATOL)
      self.assertAllClose(
          [[2.3 / 3., 10.1 / 3.]], self.evaluate(update_op), atol=self._ATOL)
      self.assertAllClose(
          [[-0.9 / 4., 3.525]], self.evaluate(update_op), atol=self._ATOL)

      self.assertAllClose(
          [[-0.9 / 4., 3.525]], self.evaluate(mean), atol=self._ATOL)

  @test_util.run_deprecated_v1
  def testBinaryWeighted1d(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))
      # One 1x1 weight per row, applied to both columns of that row.
      weights = self._dequeue_rows(
          sess, ([[1]], [[0]], [[1]], [[0]]), (1, 1))

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[3.25, 0.5]], self.evaluate(mean), atol=self._ATOL)

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))
      weights = self._dequeue_rows(
          sess, ([[0.0025]], [[0.005]], [[0.01]], [[0.0075]]), (1, 1))

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      # Total weight is 0.025.
      # Column 0: (0 * .0025 - 4.2 * .005 + 6.5 * .01 - 3.2 * .0075) / .025
      #   = 0.02 / 0.025 = 0.8
      # Column 1: (1 * .0025 + 9.1 * .005 + 0 * .01 + 4 * .0075) / .025
      #   = 0.078 / 0.025 = 3.12
      self.assertAllClose([[0.8, 3.12]], self.evaluate(mean), atol=self._ATOL)

  @test_util.run_deprecated_v1
  def testWeighted2d_1(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))
      # One weight per element.
      weights = self._dequeue_rows(
          sess, ([1, 1], [1, 0], [0, 1], [0, 0]), (1, 2))

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[-2.1, 0.5]], self.evaluate(mean), atol=self._ATOL)

  @test_util.run_deprecated_v1
  def testWeighted2d_2(self):
    with self.cached_session() as sess:
      values = self._dequeue_rows(sess, self._VALUE_ROWS, (1, 2))
      # Column 0 never receives weight; its expected mean is 0 here.
      weights = self._dequeue_rows(
          sess, ([0, 1], [0, 0], [0, 1], [0, 0]), (1, 2))

      mean, update_op = metrics.mean_tensor(values, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(4):
        self.evaluate(update_op)
      self.assertAllClose([[0, 0.5]], self.evaluate(mean), atol=self._ATOL)


class AccuracyTest(test.TestCase):
  """Tests for `metrics.accuracy`."""

  def setUp(self):
    ops.reset_default_graph()

  @staticmethod
  def _dequeue_scalars(sess, dtype, scalars):
    """Returns the dequeue op of a 4-slot queue pre-filled with 1x1 scalars."""
    queue = data_flow_ops.FIFOQueue(4, dtypes=dtype, shapes=(1, 1))
    for scalar in scalars:
      _enqueue_vector(sess, queue, [scalar])
    return queue.dequeue()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        name='my_accuracy')
    expected_vars = ('my_accuracy/count:0', 'my_accuracy/total:0')
    _assert_metric_variables(self, expected_vars)

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    collection = '__metrics__'
    value_tensor, _ = metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value_tensor])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update_tensor = metrics.accuracy(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update_tensor])

  @test_util.run_deprecated_v1
  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones((10, 3))
    labels = array_ops.ones((10, 4))
    with self.assertRaises(ValueError):
      metrics.accuracy(labels, predictions)

  @test_util.run_deprecated_v1
  def testPredictionsAndWeightsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones((10, 3))
    labels = array_ops.ones((10, 3))
    mismatched_weights = array_ops.ones((9, 3))
    with self.assertRaises(ValueError):
      metrics.accuracy(labels, predictions, mismatched_weights)

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    accuracy, update_op = metrics.accuracy(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Accumulate some state first.
      for _ in range(10):
        self.evaluate(update_op)

      # Reading the value tensor repeatedly must not change it.
      first_reading = self.evaluate(accuracy)
      for _ in range(10):
        self.assertEqual(first_reading, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    with self.cached_session() as sess:
      predictions = self._dequeue_scalars(
          sess, dtypes_lib.float32, (0, 1, 2, 1))
      labels = self._dequeue_scalars(sess, dtypes_lib.float32, (0, 1, 1, 2))

      accuracy, update_op = metrics.accuracy(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(3):
        self.evaluate(update_op)
      # Only the first two of the four pairs match.
      self.assertEqual(0.5, self.evaluate(update_op))
      self.assertEqual(0.5, self.evaluate(accuracy))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizes(self):
    predictions = array_ops.ones((40, 1))
    labels = array_ops.ones((40,))
    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      for tensor in (update_op, accuracy):
        self.assertEqual(1.0, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithScalarWeight(self):
    predictions = array_ops.ones((40, 1))
    labels = array_ops.ones((40,))
    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)

      self.evaluate(variables.local_variables_initializer())
      for tensor in (update_op, accuracy):
        self.assertEqual(1.0, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithStaticShapedWeight(self):
    predictions = ops.convert_to_tensor([1, 1, 1])  # shape (3,)
    labels = array_ops.expand_dims(
        ops.convert_to_tensor([1, 0, 0]), 1)  # shape (3, 1)
    weights = array_ops.expand_dims(
        ops.convert_to_tensor([100, 1, 1]), 1)  # shape (3, 1)

    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions, weights)

      self.evaluate(variables.local_variables_initializer())
      # Without flattening the weights, an unintended broadcast would give an
      # accuracy of about 0.33333334; with flattening it exceeds .95.
      for tensor in (update_op, accuracy):
        self.assertGreater(self.evaluate(tensor), .95)

  @test_util.run_deprecated_v1
  def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self):
    predictions = ops.convert_to_tensor([1, 1, 1])  # shape (3,)
    labels = array_ops.expand_dims(
        ops.convert_to_tensor([1, 0, 0]), 1)  # shape (3, 1)

    weights_placeholder = array_ops.placeholder(
        dtype=dtypes_lib.int32, name='weights')
    feed_dict = {weights_placeholder: [[100], [1], [1]]}  # shape (3, 1)

    with self.cached_session():
      accuracy, update_op = metrics.accuracy(labels, predictions,
                                             weights_placeholder)

      self.evaluate(variables.local_variables_initializer())
      # Without flattening the weights, an unintended broadcast would give an
      # accuracy of about 0.33333334; with flattening it exceeds .95.
      for tensor in (update_op, accuracy):
        self.assertGreater(tensor.eval(feed_dict=feed_dict), .95)

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithWeightedValues(self):
    with self.cached_session() as sess:
      predictions = self._dequeue_scalars(
          sess, dtypes_lib.float32, (0, 1, 2, 1))
      labels = self._dequeue_scalars(sess, dtypes_lib.float32, (0, 1, 1, 2))
      # Only the two matching pairs carry weight.
      weights = self._dequeue_scalars(sess, dtypes_lib.int64, (1, 1, 0, 0))

      accuracy, update_op = metrics.accuracy(labels, predictions, weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(3):
        self.evaluate(update_op)
      self.assertEqual(1.0, self.evaluate(update_op))
      self.assertEqual(1.0, self.evaluate(accuracy))


class PrecisionTest(test.TestCase):
  """Tests for `metrics.precision`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.precision(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self, ('precision/false_positives/count:0',
                                    'precision/true_positives/count:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.precision(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    # `maxval` is exclusive for integer dtypes: maxval=2 yields binary 0/1
    # inputs. With maxval=1 every input would be 0, leaving the metric with no
    # positives to accumulate and making this check vacuous.
    predictions = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_precision = self.evaluate(precision)
      for _ in range(10):
        self.assertEqual(initial_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
      self.assertAlmostEqual(1.0, self.evaluate(precision), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleInputDtypes(self):
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      precision, update_op = metrics.precision(labels, predictions)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))
        self.assertAlmostEqual(0.5, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      # Each row's weight applies to every prediction in that row.
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeightedScalar_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(labels, predictions, weights=2)

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 2.0
      weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted1d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(
        labels, predictions, weights=constant_op.constant([[2], [5]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 2.0 + 5.0
      weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      # Predicted positives sit in columns 0 and 2 of both rows.
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
      self.assertAlmostEqual(expected_precision, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testWeighted2d_placeholders(self):
    predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    feed_dict = {
        predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
        labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    }
    precision, update_op = metrics.precision(
        labels,
        predictions,
        weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))

    with self.cached_session():
      variables.local_variables_initializer().run()
      weighted_tp = 3.0 + 4.0
      weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
      expected_precision = weighted_tp / weighted_positives
      self.assertAlmostEqual(
          expected_precision, update_op.eval(feed_dict=feed_dict))
      self.assertAlmostEqual(
          expected_precision, precision.eval(feed_dict=feed_dict))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs)
    labels = constant_op.constant(1 - inputs)
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertAlmostEqual(0, self.evaluate(precision))

  @test_util.run_deprecated_v1
  def testZeroTrueAndFalsePositivesGivesZeroPrecision(self):
    predictions = constant_op.constant([0, 0, 0, 0])
    labels = constant_op.constant([0, 0, 0, 0])
    precision, update_op = metrics.precision(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0.0, self.evaluate(precision))


class RecallTest(test.TestCase):
  """Tests for `metrics.recall`, i.e. streaming TP / (TP + FN)."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    """The metric creates exactly the TP and FN count variables."""
    metrics.recall(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self,
        ('recall/false_negatives/count:0', 'recall/true_positives/count:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    mean, _ = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    """Reading the value tensor repeatedly does not change the result."""
    # `random_uniform` samples integers from [minval, maxval). maxval=2 is
    # needed to draw both 0s and 1s; maxval=1 would produce all-zero inputs,
    # so the TP/FN counters would never move and the check would be vacuous.
    predictions = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_recall = self.evaluate(recall)
      for _ in range(10):
        self.assertEqual(initial_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    """Predictions identical to labels give a recall of 1."""
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertAlmostEqual(1.0, self.evaluate(recall), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleInputDtypes(self):
    """bool, int32 and float32 inputs all yield the same recall."""
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      recall, update_op = metrics.recall(labels, predictions)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        # One of the two actual positives (index 2) is predicted positive.
        self.assertAlmostEqual(0.5, self.evaluate(update_op))
        self.assertAlmostEqual(0.5, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    """Per-row weights of shape [2, 1] scale every element in the row."""
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[2], [5]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      weighted_tp = 2.0 + 5.0
      weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, self.evaluate(update_op))
      self.assertAlmostEqual(expected_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    """Per-element weights of the same shape as the inputs."""
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      weighted_tp = 3.0 + 1.0
      weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, self.evaluate(update_op))
      self.assertAlmostEqual(expected_recall, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    """Predicting the complement of every label gives a recall of 0."""
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(1 - np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0, self.evaluate(recall))

  @test_util.run_deprecated_v1
  def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self):
    """With no actual positives (TP + FN == 0), the metric reports 0."""
    predictions = array_ops.zeros((1, 4))
    labels = array_ops.zeros((1, 4))
    recall, update_op = metrics.recall(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      self.assertEqual(0, self.evaluate(recall))


class AUCTest(test.TestCase):
  """Tests for `metrics.auc` over ROC and PR curves."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    """The metric creates the four confusion-matrix variables."""
    metrics.auc(predictions=array_ops.ones((10, 1)),
                labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self,
                             ('auc/true_positives:0', 'auc/false_negatives:0',
                              'auc/false_positives:0', 'auc/true_negatives:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    mean, _ = metrics.auc(predictions=array_ops.ones((10, 1)),
                          labels=array_ops.ones((10, 1)),
                          metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.auc(predictions=array_ops.ones((10, 1)),
                               labels=array_ops.ones((10, 1)),
                               updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    """Reading the value tensor repeatedly does not change the result."""
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    # Integer `random_uniform` samples from [minval, maxval). maxval=2 is
    # needed to obtain both classes; maxval=1 would make every label 0.
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    auc, update_op = metrics.auc(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_auc = self.evaluate(auc)
      for _ in range(10):
        self.assertAlmostEqual(initial_auc, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    self.allCorrectAsExpected('ROC')

  def allCorrectAsExpected(self, curve):
    """Asserts AUC == 1 when predictions equal the (binary) labels.

    Args:
      curve: the `curve` argument forwarded to `metrics.auc` (e.g. 'ROC').
    """
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      auc, update_op = metrics.auc(labels, predictions, curve=curve)

      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1, self.evaluate(update_op))

      self.assertEqual(1, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleLabelDtypes(self):
    """bool, int32 and float32 labels all yield the same AUC."""
    with self.cached_session():
      for label_dtype in (
          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
        predictions = constant_op.constant(
            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
        labels = math_ops.cast(
            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
        auc, update_op = metrics.auc(labels, predictions)

        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, self.evaluate(update_op))

        self.assertAlmostEqual(0.5, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testWeighted1d(self):
    """A single weight of shape [1, 1] leaves the AUC unchanged."""
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([2], shape=(1, 1))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.5, self.evaluate(update_op), 5)

      self.assertAlmostEqual(0.5, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    """Per-element weights change the AUC."""
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([1, 2, 3, 4], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.7, self.evaluate(update_op), 5)

      self.assertAlmostEqual(0.7, self.evaluate(auc), 5)

  @test_util.run_deprecated_v1
  def testManualThresholds(self):
    with self.cached_session():
      # Verifies that thresholds passed in to the `thresholds` parameter are
      # used correctly.
      # The default thresholds do not split the second and third predictions.
      # Thus, when we provide manual thresholds which correctly split it, we get
      # an accurate AUC value.
      predictions = constant_op.constant(
          [0.12, 0.3001, 0.3003, 0.72], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
      weights = constant_op.constant([1, 1, 1, 1], shape=(1, 4))
      thresholds = [0.0, 0.2, 0.3002, 0.6, 1.0]
      default_auc, default_update_op = metrics.auc(labels,
                                                   predictions,
                                                   weights=weights)
      manual_auc, manual_update_op = metrics.auc(labels,
                                                 predictions,
                                                 weights=weights,
                                                 thresholds=thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
      self.assertAlmostEqual(0.875, self.evaluate(default_auc), 3)

      self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
      self.assertAlmostEqual(0.75, self.evaluate(manual_auc), 3)

  # Regarding the AUC-PR tests: note that the preferred method when
  # calculating AUC-PR is summation_method='careful_interpolation'.
  @test_util.run_deprecated_v1
  def testCorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.79726744594
      expected = 1 - math.log(1.5) / 2
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testCorrectAnotherAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.61350593198
      expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testThirdCorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='careful_interpolation')

      self.evaluate(variables.local_variables_initializer())
      # expected ~= 0.90410597584
      expected = 1 - math.log(4./3) / 3
      self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
      self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.79166, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.79166, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testAnotherIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.610317, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.610317, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testThirdIncorrectAUCPRSpecialCase(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR',
                                   summation_method='trapezoidal')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.90277, self.evaluate(update_op), delta=1e-3)

      self.assertAlmostEqual(0.90277, self.evaluate(auc), delta=1e-3)

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    """Predicting the complement of every label gives an AUC of 0."""
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
      auc, update_op = metrics.auc(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0, self.evaluate(update_op))

      self.assertAlmostEqual(0, self.evaluate(auc))

  @test_util.run_deprecated_v1
  def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
    with self.cached_session():
      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
      labels = array_ops.zeros([4])
      auc, update_op = metrics.auc(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1, self.evaluate(update_op), 6)

      self.assertAlmostEqual(1, self.evaluate(auc), 6)

  @test_util.run_deprecated_v1
  def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
    with self.cached_session():
      predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
      labels = array_ops.ones([4])
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1, self.evaluate(update_op), 6)

      self.assertAlmostEqual(1, self.evaluate(auc), 6)

  def np_auc(self, predictions, labels, weights):
    """Computes the AUC explicitly using Numpy.

    Sorts examples by descending prediction and, for every negative example,
    accumulates its weight times the fraction of positive weight ranked at or
    above it. Ties between predictions are broken by sort order rather than
    averaged.

    Args:
      predictions: an ndarray with shape [N].
      labels: an ndarray with shape [N]; entries > 0 count as positives.
      weights: an ndarray with shape [N], or None for unit weights.

    Returns:
      the area under the ROC curve.
    """
    if weights is None:
      weights = np.ones(np.size(predictions))
    is_positive = labels > 0
    num_positives = np.sum(weights[is_positive])
    num_negatives = np.sum(weights[~is_positive])

    # Sort descending:
    inds = np.argsort(-predictions)

    sorted_labels = labels[inds]
    sorted_weights = weights[inds]
    is_positive = sorted_labels > 0

    # Cumulative true-positive rate at each rank.
    tp = np.cumsum(sorted_weights * is_positive) / num_positives
    return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives

  @test_util.run_deprecated_v1
  def testWithMultipleUpdates(self):
    """Streaming AUC over queued batches approximates the Numpy reference."""
    num_samples = 1000
    batch_size = 10
    num_batches = num_samples // batch_size

    # Create the labels and data.
    labels = np.random.randint(0, 2, size=num_samples)
    noise = np.random.normal(0.0, scale=0.2, size=num_samples)
    predictions = 0.4 + 0.2 * labels + noise
    predictions[predictions > 1] = 1
    predictions[predictions < 0] = 0

    def _enqueue_as_batches(x, enqueue_ops):
      """Queues `x` in `num_batches` rows; returns the dequeue tensor."""
      x_batches = x.astype(np.float32).reshape((num_batches, batch_size))
      x_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
      # enqueue_ops[i] collects the enqueue ops for batch i across all queues.
      for batch_ops, batch in zip(enqueue_ops, x_batches):
        batch_ops.append(x_queue.enqueue(batch))
      return x_queue.dequeue()

    for weights in (None, np.ones(num_samples), np.random.exponential(
        scale=1.0, size=num_samples)):
      expected_auc = self.np_auc(predictions, labels, weights)

      with self.cached_session() as sess:
        enqueue_ops = [[] for _ in range(num_batches)]
        tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
        tf_labels = _enqueue_as_batches(labels, enqueue_ops)
        tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
                      weights is not None else None)

        for batch_ops in enqueue_ops:
          sess.run(batch_ops)

        auc, update_op = metrics.auc(tf_labels,
                                     tf_predictions,
                                     curve='ROC',
                                     num_thresholds=500,
                                     weights=tf_weights)

        self.evaluate(variables.local_variables_initializer())
        # Each run of update_op dequeues and consumes one batch.
        for _ in range(num_batches):
          self.evaluate(update_op)

        # Since this is only approximate, we can't expect a 6 digits match.
        # Although with higher number of samples/thresholds we should see the
        # accuracy improving
        self.assertAlmostEqual(expected_auc, self.evaluate(auc), 2)


class SpecificityAtSensitivityTest(test.TestCase):
  """Exercises `metrics.specificity_at_sensitivity` in graph mode."""

  def setUp(self):
    ops.reset_default_graph()
    np.random.seed(1)

  @test_util.run_deprecated_v1
  def testVars(self):
    """The metric creates the four confusion-matrix variables."""
    metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7)
    counters = ('true_positives', 'false_negatives', 'false_positives',
                'true_negatives')
    _assert_metric_variables(
        self,
        tuple('specificity_at_sensitivity/' + c + ':0' for c in counters))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """The value tensor lands in the requested metrics collection."""
    collection = '__metrics__'
    value_tensor, _ = metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7,
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value_tensor])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    """The update op lands in the requested updates collection."""
    collection = '__updates__'
    _, update_tensor = metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7,
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update_tensor])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    """Re-reading the value after many updates yields a stable result."""
    scores = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    targets = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    value, update = metrics.specificity_at_sensitivity(
        targets, scores, sensitivity=0.7)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # Accumulate statistics over several random batches first ...
      for _ in range(10):
        self.evaluate(update)
      # ... then reading the value repeatedly must not change it.
      baseline = self.evaluate(value)
      for _ in range(10):
        self.assertAlmostEqual(baseline, self.evaluate(value), 5)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    """Predictions identical to labels give a specificity of 1."""
    binary = np.random.randint(0, 2, size=(100, 1))
    scores = constant_op.constant(binary, dtype=dtypes_lib.float32)
    targets = constant_op.constant(binary)
    value, update = metrics.specificity_at_sensitivity(
        targets, scores, sensitivity=0.7)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      for tensor in (update, value):
        self.assertEqual(1, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testSomeCorrectHighSensitivity(self):
    """Specificity at a sensitivity target of 0.8."""
    scores = constant_op.constant(
        [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9],
        dtype=dtypes_lib.float32)
    targets = constant_op.constant([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    value, update = metrics.specificity_at_sensitivity(
        targets, scores, sensitivity=0.8)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      for tensor in (update, value):
        self.assertAlmostEqual(1.0, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testSomeCorrectLowSensitivity(self):
    """Specificity at a sensitivity target of 0.4."""
    scores = constant_op.constant(
        [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26],
        dtype=dtypes_lib.float32)
    targets = constant_op.constant([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    value, update = metrics.specificity_at_sensitivity(
        targets, scores, sensitivity=0.4)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      for tensor in (update, value):
        self.assertAlmostEqual(0.6, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testWeighted1d_multipleLabelDtypes(self):
    """A scalar-like weight works for bool, int32 and float32 labels."""
    score_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
    label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      scores = constant_op.constant(score_values, dtype=dtypes_lib.float32)
      targets = math_ops.cast(label_values, dtype=label_dtype)
      # One weight for all ten examples (relies on weight broadcasting); the
      # expected value matches the unweighted low-sensitivity case.
      weights = constant_op.constant([3])
      value, update = metrics.specificity_at_sensitivity(
          targets, scores, weights=weights, sensitivity=0.4)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())

        for tensor in (update, value):
          self.assertAlmostEqual(0.6, self.evaluate(tensor))

  @test_util.run_deprecated_v1
  def testWeighted2d(self):
    """Per-example weights 1..10 shift the resulting specificity."""
    scores = constant_op.constant(
        [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26],
        dtype=dtypes_lib.float32)
    targets = constant_op.constant([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    weights = constant_op.constant(list(range(1, 11)))
    value, update = metrics.specificity_at_sensitivity(
        targets, scores, weights=weights, sensitivity=0.4)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      for tensor in (update, value):
        self.assertAlmostEqual(8.0 / 15.0, self.evaluate(tensor))


class SensitivityAtSpecificityTest(test.TestCase):
  """Tests for `metrics.sensitivity_at_specificity`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    """The metric creates the four confusion-matrix variables."""
    metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7)
    _assert_metric_variables(self,
                             ('sensitivity_at_specificity/true_positives:0',
                              'sensitivity_at_specificity/false_negatives:0',
                              'sensitivity_at_specificity/false_positives:0',
                              'sensitivity_at_specificity/true_negatives:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    mean, _ = metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7,
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    """Reading the value tensor repeatedly does not change the result."""
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.7)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_sensitivity = self.evaluate(sensitivity)
      for _ in range(10):
        self.assertAlmostEqual(initial_sensitivity, self.evaluate(sensitivity),
                               5)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    """Predictions identical to labels give a sensitivity of 1."""
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
    labels = constant_op.constant(inputs)
    # The metric's value is a sensitivity; `specificity` is only the target.
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.7)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
      self.assertAlmostEqual(1.0, self.evaluate(sensitivity), 6)

  @test_util.run_deprecated_v1
  def testSomeCorrectHighSpecificity(self):
    """Sensitivity at a specificity target of 0.8."""
    predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(
        predictions_values, dtype=dtypes_lib.float32)
    labels = constant_op.constant(labels_values)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.8)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.8, self.evaluate(update_op))
      self.assertAlmostEqual(0.8, self.evaluate(sensitivity))

  @test_util.run_deprecated_v1
  def testSomeCorrectLowSpecificity(self):
    """Sensitivity at a specificity target of 0.4."""
    predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
    labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(
        predictions_values, dtype=dtypes_lib.float32)
    labels = constant_op.constant(labels_values)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.4)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(0.6, self.evaluate(update_op))
      self.assertAlmostEqual(0.6, self.evaluate(sensitivity))

  @test_util.run_deprecated_v1
  def testWeighted_multipleLabelDtypes(self):
    """Per-example weights with bool, int32 and float32 labels."""
    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions_values = [
          0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
      labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
      weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

      predictions = constant_op.constant(
          predictions_values, dtype=dtypes_lib.float32)
      labels = math_ops.cast(labels_values, dtype=label_dtype)
      weights = constant_op.constant(weights_values)
      sensitivity, update_op = metrics.sensitivity_at_specificity(
          labels, predictions, weights=weights, specificity=0.4)

      with self.cached_session():
        self.evaluate(variables.local_variables_initializer())
        self.assertAlmostEqual(0.675, self.evaluate(update_op))
        self.assertAlmostEqual(0.675, self.evaluate(sensitivity))


# TODO(nsilberman): Break this up into two sets of tests.
class PrecisionRecallThresholdsTest(test.TestCase):
  """Tests for `metrics.precision_at_thresholds` and `recall_at_thresholds`."""

  def setUp(self):
    super().setUp()
    # Re-seed after the base-class setUp so the numpy-generated fixtures
    # below stay deterministic for this class.
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0])
    _assert_metric_variables(self, (
        'precision_at_thresholds/true_positives:0',
        'precision_at_thresholds/false_positives:0',
    ))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    prec, _ = metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        metrics_collections=[my_collection_name])
    rec, _ = metrics.recall_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, precision_op = metrics.precision_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        updates_collections=[my_collection_name])
    _, recall_op = metrics.recall_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=[0, 0.5, 1.0],
        updates_collections=[my_collection_name])
    self.assertListEqual(
        ops.get_collection(my_collection_name), [precision_op, recall_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    # For integer dtypes `maxval` is exclusive, so maxval=2 draws labels from
    # {0, 1}. With maxval=1 every label would be 0 and both metrics would be
    # stuck at 0, which made this test pass vacuously.
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    thresholds = [0, 0.5, 1.0]
    prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                    thresholds)
    rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate([prec_op, rec_op])

      # Then verify idempotency: reading the value tensors, without running
      # the update ops, must not change them.
      initial_prec = self.evaluate(prec)
      initial_rec = self.evaluate(rec)
      for _ in range(10):
        self.assertAllClose(initial_prec, self.evaluate(prec))
        self.assertAllClose(initial_rec, self.evaluate(rec))

  # TODO(nsilberman): fix tests (passing but incorrect).
  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertEqual(1, self.evaluate(prec))
      self.assertEqual(1, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testSomeCorrect_multipleLabelDtypes(self):
    with self.cached_session():
      for label_dtype in (
          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
        predictions = constant_op.constant(
            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
        labels = math_ops.cast(
            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
        thresholds = [0.5]
        prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                        thresholds)
        rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                   thresholds)

        self.evaluate(variables.local_variables_initializer())
        self.evaluate([prec_op, rec_op])

        self.assertAlmostEqual(0.5, self.evaluate(prec))
        self.assertAlmostEqual(0.5, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testAllIncorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.cached_session():
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0, self.evaluate(prec))
      self.assertAlmostEqual(0, self.evaluate(rec))

  @test_util.run_deprecated_v1
  def testWeights1d(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
      # Per-row weights: the first row is ignored entirely.
      weights = constant_op.constant(
          [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
      thresholds = [0.5, 1.1]
      prec, prec_op = metrics.precision_at_thresholds(
          labels, predictions, thresholds, weights=weights)
      rec, rec_op = metrics.recall_at_thresholds(
          labels, predictions, thresholds, weights=weights)

      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      prec_low = array_ops.reshape(prec_low, shape=())
      prec_high = array_ops.reshape(prec_high, shape=())
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)
      rec_low = array_ops.reshape(rec_low, shape=())
      rec_high = array_ops.reshape(rec_high, shape=())

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)

  @test_util.run_deprecated_v1
  def testWeights2d(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
      # Element-wise weights equivalent to the per-row weights of
      # testWeights1d.
      weights = constant_op.constant(
          [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
      thresholds = [0.5, 1.1]
      prec, prec_op = metrics.precision_at_thresholds(
          labels, predictions, thresholds, weights=weights)
      rec, rec_op = metrics.recall_at_thresholds(
          labels, predictions, thresholds, weights=weights)

      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      prec_low = array_ops.reshape(prec_low, shape=())
      prec_high = array_ops.reshape(prec_high, shape=())
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)
      rec_low = array_ops.reshape(rec_low, shape=())
      rec_high = array_ops.reshape(rec_high, shape=())

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
      self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
      self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)

  @test_util.run_deprecated_v1
  def testExtremeThresholds(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
      thresholds = [-1.0, 2.0]  # lower/higher than any values
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      [prec_low, prec_high] = array_ops.split(
          value=prec, num_or_size_splits=2, axis=0)
      [rec_low, rec_high] = array_ops.split(
          value=rec, num_or_size_splits=2, axis=0)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0.75, self.evaluate(prec_low))
      self.assertAlmostEqual(0.0, self.evaluate(prec_high))
      self.assertAlmostEqual(1.0, self.evaluate(rec_low))
      self.assertAlmostEqual(0.0, self.evaluate(rec_high))

  @test_util.run_deprecated_v1
  def testZeroLabelsPredictions(self):
    with self.cached_session():
      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
      labels = array_ops.zeros([4])
      thresholds = [0.5]
      prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([prec_op, rec_op])

      self.assertAlmostEqual(0, self.evaluate(prec), 6)
      self.assertAlmostEqual(0, self.evaluate(rec), 6)

  @test_util.run_deprecated_v1
  def testWithMultipleUpdates(self):
    num_samples = 1000
    batch_size = 10
    num_batches = num_samples // batch_size

    # Create the labels and data.
    labels = np.random.randint(0, 2, size=(num_samples, 1))
    noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
    predictions = 0.4 + 0.2 * labels + noise
    predictions[predictions > 1] = 1
    predictions[predictions < 0] = 0
    thresholds = [0.3]

    # Reference confusion counts, using the same strict `>` comparison
    # against the threshold as the loop this replaces.
    predicted_positive = predictions > thresholds[0]
    actual_positive = labels == 1
    tp = np.sum(predicted_positive & actual_positive)
    fp = np.sum(predicted_positive & ~actual_positive)
    fn = np.sum(~predicted_positive & actual_positive)
    epsilon = 1e-7
    expected_prec = tp / (epsilon + tp + fp)
    expected_rec = tp / (epsilon + tp + fn)

    labels = labels.astype(np.float32)
    predictions = predictions.astype(np.float32)

    with self.cached_session() as sess:
      # Reshape the data so it's easy to queue up: column i is batch i.
      predictions_batches = predictions.reshape((batch_size, num_batches))
      labels_batches = labels.reshape((batch_size, num_batches))

      # Enqueue the data:
      predictions_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
      labels_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))

      for i in range(num_batches):
        tf_prediction = constant_op.constant(predictions_batches[:, i])
        tf_label = constant_op.constant(labels_batches[:, i])
        sess.run([
            predictions_queue.enqueue(tf_prediction),
            labels_queue.enqueue(tf_label)
        ])

      tf_predictions = predictions_queue.dequeue()
      tf_labels = labels_queue.dequeue()

      prec, prec_op = metrics.precision_at_thresholds(tf_labels, tf_predictions,
                                                      thresholds)
      rec, rec_op = metrics.recall_at_thresholds(tf_labels, tf_predictions,
                                                 thresholds)

      self.evaluate(variables.local_variables_initializer())
      # One update per enqueued batch, so every sample is consumed once.
      for _ in range(num_batches):
        self.evaluate([prec_op, rec_op])
      # Since this is only approximate, we can't expect a 6 digits match.
      # Although with higher number of samples/thresholds we should see the
      # accuracy improving
      self.assertAlmostEqual(expected_prec, self.evaluate(prec), 2)
      self.assertAlmostEqual(expected_rec, self.evaluate(rec), 2)


def _test_precision_at_k(predictions,
                         labels,
                         k,
                         expected,
                         class_id=None,
                         weights=None,
                         test_case=None):
  """Asserts that `metrics.precision_at_k` yields `expected` after one update.

  The metric is built in a fresh graph. Evaluating either the value or the
  update tensor must raise `OpError` before local variables are initialized.
  After initialization, one update is run and both tensors are compared with
  `expected`: exactly, or via `_assert_nan` when `expected` is NaN.

  Args:
    predictions: Scores, converted to a float32 constant.
    labels: Labels, passed to the metric unchanged.
    k: Passed to the metric as `k`.
    expected: Expected metric value; may be NaN.
    class_id: Optional `class_id` for the metric.
    weights: Optional weights, converted to a float32 constant when given.
    test_case: The `TestCase` providing the session and assertions.
  """
  graph = ops.Graph()
  with graph.as_default(), test_case.test_session(graph):
    weights_t = None
    if weights is not None:
      weights_t = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.precision_at_k(
        predictions=constant_op.constant(predictions, dtypes_lib.float32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights_t)

    # Neither tensor is usable until the metric's local variables exist.
    for tensor in (metric, update):
      test_case.assertRaises(errors_impl.OpError, tensor.eval)
    variables.variables_initializer(variables.local_variables()).run()

    if math.isnan(expected):
      check = functools.partial(_assert_nan, test_case)
    else:
      check = functools.partial(test_case.assertEqual, expected)
    # Step the metric once, then read back the accumulated value.
    check(update.eval())
    check(metric.eval())


def _test_precision_at_top_k(
    predictions_idx,
    labels,
    expected,
    k=None,
    class_id=None,
    weights=None,
    test_case=None):
  """Asserts that `metrics.precision_at_top_k` yields `expected`.

  The metric is built in a fresh graph. Evaluating either the value or the
  update tensor must raise `OpError` before local variables are initialized.
  After initialization, one update is run and both tensors are compared with
  `expected`: exactly, or via `_assert_nan` when `expected` is NaN.

  Args:
    predictions_idx: Predicted class indices, converted to an int32 constant.
    labels: Labels, passed to the metric unchanged.
    expected: Expected metric value; may be NaN.
    k: Optional `k` for the metric.
    class_id: Optional `class_id` for the metric.
    weights: Optional weights, converted to a float32 constant when given.
    test_case: The `TestCase` providing the session and assertions.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.precision_at_top_k(
        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      # Use the shared NaN helper, as every sibling `_test_*` helper does,
      # instead of a bare assertTrue(math.isnan(...)).
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())


def _test_average_precision_at_k(predictions,
                                 labels,
                                 k,
                                 expected,
                                 weights=None,
                                 test_case=None):
  """Asserts that `metrics.average_precision_at_k` yields `expected`.

  The metric is built in a fresh graph. Evaluating either tensor must raise
  `OpError` before local variables are initialized. After initialization, one
  update is run and both tensors are compared with `expected`: with
  `assertAlmostEqual`, or via `_assert_nan` when `expected` is NaN.

  Args:
    predictions: Scores, converted to a float32 constant.
    labels: Labels, passed to the metric unchanged.
    k: Passed to the metric as `k`.
    expected: Expected metric value; may be NaN.
    weights: Optional weights, converted to a float32 constant when given.
    test_case: The `TestCase` providing the session and assertions.
  """
  graph = ops.Graph()
  with graph.as_default(), test_case.test_session(graph):
    if weights is None:
      weights_t = None
    else:
      weights_t = constant_op.constant(weights, dtypes_lib.float32)
    predictions_t = constant_op.constant(predictions, dtypes_lib.float32)
    metric, update = metrics.average_precision_at_k(
        labels, predictions_t, k, weights=weights_t)

    # Both tensors read local variables that are not initialized yet.
    for tensor in (metric, update):
      test_case.assertRaises(errors_impl.OpError, tensor.eval)
    variables.variables_initializer(variables.local_variables()).run()

    if math.isnan(expected):
      check = functools.partial(_assert_nan, test_case)
    else:
      check = functools.partial(test_case.assertAlmostEqual, expected)
    # Update first, then confirm the value tensor reports the same result.
    check(update.eval())
    check(metric.eval())


class SingleLabelPrecisionAtKTest(test.TestCase):
  """Precision@k with exactly one label per example, in several encodings."""

  def setUp(self):
    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
    self._predictions_idx = [[3], [3]]
    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
    class_labels = (3, 2)
    # The same labels encoded four ways: sparse 1d, sparse 2d, dense 1d and
    # dense 2d. The metrics should treat every encoding identically.
    sparse_1d = _binary_2d_label_to_1d_sparse_value(indicator_labels)
    sparse_2d = _binary_2d_label_to_2d_sparse_value(indicator_labels)
    dense_1d = np.array(class_labels, dtype=np.int64)
    dense_2d = np.array([[label] for label in class_labels], dtype=np.int64)
    self._labels = (sparse_1d, sparse_2d, dense_1d, dense_2d)

    # Bind the module-level checkers to this test case.
    self._test_precision_at_k = functools.partial(
        _test_precision_at_k, test_case=self)
    self._test_precision_at_top_k = functools.partial(
        _test_precision_at_top_k, test_case=self)
    self._test_average_precision_at_k = functools.partial(
        _test_average_precision_at_k, test_case=self)

  @test_util.run_deprecated_v1
  def test_at_k1_nan(self):
    # Classes 0, 1 and 2 are never predicted; -1 and 4 are out of range.
    nan_class_ids = (-1, 0, 1, 2, 4)
    for labels in self._labels:
      for class_id in nan_class_ids:
        self._test_precision_at_k(
            self._predictions, labels, k=1, expected=NAN, class_id=class_id)
        self._test_precision_at_top_k(
            self._predictions_idx, labels, k=1, expected=NAN,
            class_id=class_id)

  @test_util.run_deprecated_v1
  def test_at_k1(self):
    half = 1.0 / 2
    for labels in self._labels:
      # Class 3: 1 label, 2 predictions, 1 correct.
      self._test_precision_at_k(
          self._predictions, labels, k=1, expected=half, class_id=3)
      self._test_precision_at_top_k(
          self._predictions_idx, labels, k=1, expected=half, class_id=3)

      # All classes: 2 labels, 2 predictions, 1 correct.
      self._test_precision_at_k(self._predictions, labels, k=1, expected=half)
      self._test_precision_at_top_k(
          self._predictions_idx, labels, k=1, expected=half)
      self._test_average_precision_at_k(
          self._predictions, labels, k=1, expected=half)


class MultiLabelPrecisionAtKTest(test.TestCase):
  """Precision@k, precision@top-k and average precision@k, multi-label case.

  Each test passes hand-computed expectations through the module-level
  `_test_*` helpers, which `setUp` binds to this test case.
  """

  def setUp(self):
    self._test_precision_at_k = functools.partial(
        _test_precision_at_k, test_case=self)
    self._test_precision_at_top_k = functools.partial(
        _test_precision_at_top_k, test_case=self)
    self._test_average_precision_at_k = functools.partial(
        _test_average_precision_at_k, test_case=self)

  @test_util.run_deprecated_v1
  def test_average_precision(self):
    # Example 1.
    # Matches example here:
    # fastml.com/what-you-wanted-to-know-about-mean-average-precision
    labels_ex1 = (0, 1, 2, 3, 4)
    labels = np.array([labels_ex1], dtype=np.int64)
    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
    predictions = (predictions_ex1,)
    # Indices of predictions_ex1 ordered by descending score.
    predictions_idx_ex1 = (5, 3, 6, 0, 1)
    # Entry i is the expected value at k = i + 1.
    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
                         (precision_ex1[1] + precision_ex1[3]) / 4)
    for i in range(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex1[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex1[i])

    # Example 2.
    labels_ex2 = (0, 2, 4, 5, 6)
    labels = np.array([labels_ex2], dtype=np.int64)
    predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2)
    predictions = (predictions_ex2,)
    # Indices of predictions_ex2 ordered by descending score.
    predictions_idx_ex2 = (1, 3, 0, 6, 5)
    # Entry i is the expected value at k = i + 1.
    precision_ex2 = (0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex2 = (0.0 / 1, 0.0 / 2, precision_ex2[2] / 3,
                         (precision_ex2[2] + precision_ex2[3]) / 4)
    for i in range(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex2[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex2[:k],), labels, k=k, expected=precision_ex2[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex2[i])

    # Both examples, we expect both precision and average precision to be the
    # average of the 2 examples.
    labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
    predictions = (predictions_ex1, predictions_ex2)
    streaming_precision = [(ex1 + ex2) / 2
                           for ex1, ex2 in zip(precision_ex1, precision_ex2)]
    streaming_average_precision = [
        (ex1 + ex2) / 2
        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
    ]
    for i in range(4):
      k = i + 1
      predictions_idx = (predictions_idx_ex1[:k], predictions_idx_ex2[:k])
      self._test_precision_at_k(
          predictions, labels, k, expected=streaming_precision[i])
      self._test_precision_at_top_k(
          predictions_idx, labels, k=k, expected=streaming_precision[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=streaming_average_precision[i])

    # Weighted examples, we expect streaming average precision to be the
    # weighted average of the 2 examples.
    weights = (0.3, 0.6)
    streaming_average_precision = [
        (weights[0] * ex1 + weights[1] * ex2) / (weights[0] + weights[1])
        for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
    ]
    for i in range(4):
      k = i + 1
      self._test_average_precision_at_k(
          predictions,
          labels,
          k,
          expected=streaming_average_precision[i],
          weights=weights)

  @test_util.run_deprecated_v1
  def test_average_precision_some_labels_out_of_range(self):
    """Tests that labels outside the [0, n_classes) range are ignored."""
    # Same data as Example 1 of test_average_precision, plus the
    # out-of-range labels -1 and 7 (predictions cover 7 classes, 0..6).
    labels_ex1 = (-1, 0, 1, 2, 3, 4, 7)
    labels = np.array([labels_ex1], dtype=np.int64)
    predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
    predictions = (predictions_ex1,)
    predictions_idx_ex1 = (5, 3, 6, 0, 1)
    precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
    avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
                         (precision_ex1[1] + precision_ex1[3]) / 4)
    for i in range(4):
      k = i + 1
      self._test_precision_at_k(
          predictions, labels, k, expected=precision_ex1[i])
      self._test_precision_at_top_k(
          (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
      self._test_average_precision_at_k(
          predictions, labels, k, expected=avg_precision_ex1[i])

  @test_util.run_deprecated_v1
  def test_average_precision_different_num_labels(self):
    """Tests the case where the numbers of labels differ across examples."""
    predictions = [[0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4]]
    # Example 0 has labels {2, 3}; example 1 has only {3}. In the dense form,
    # -1 pads the shorter row.
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 1], [0, 0, 0, 1]])
    dense_labels = np.array([[2, 3], [3, -1]], dtype=np.int64)
    predictions_idx_ex1 = np.array(((0, 1, 2, 3), (3, 2, 1, 0)))
    # Rows are per-example values; column i is the value at k = i + 1.
    precision_ex1 = ((0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4),
                     (1.0 / 1, 1.0 / 2, 1.0 / 3, 1.0 / 4))
    mean_precision_ex1 = np.mean(precision_ex1, axis=0)
    avg_precision_ex1 = (
        (0.0 / 1, 0.0 / 2, 1.0 / 3 / 2, (1.0 / 3 + 2.0 / 4) / 2),
        (1.0 / 1, 1.0 / 1, 1.0 / 1, 1.0 / 1))
    mean_avg_precision_ex1 = np.mean(avg_precision_ex1, axis=0)
    for labels in (sparse_labels, dense_labels):
      for i in range(4):
        k = i + 1
        self._test_precision_at_k(
            predictions, labels, k, expected=mean_precision_ex1[i])
        self._test_precision_at_top_k(
            predictions_idx_ex1[:, :k], labels, k=k,
            expected=mean_precision_ex1[i])
        self._test_average_precision_at_k(
            predictions, labels, k, expected=mean_avg_precision_ex1[i])

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_no_predictions(self):
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    # Top-5 indices of each predictions row, by descending score.
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    # Labels: example 0 -> {2, 7, 8}; example 1 -> {1, 2, 5}.
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
      for class_id in (-1, 1, 3, 8, 10):
        self._test_precision_at_k(
            predictions, labels, k=5, expected=NAN, class_id=class_id)
        self._test_precision_at_top_k(
            predictions_idx, labels, k=5, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_no_labels(self):
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Classes 0,4,6,9: 0 labels, >=1 prediction.
      for class_id in (0, 4, 6, 9):
        self._test_precision_at_k(
            predictions, labels, k=5, expected=0.0, class_id=class_id)
        self._test_precision_at_top_k(
            predictions_idx, labels, k=5, expected=0.0, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5(self):
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sparse_labels = _binary_2d_label_to_2d_sparse_value(
        [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
    dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)

    for labels in (sparse_labels, dense_labels):
      # Class 2: 2 labels, 2 correct predictions.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=2.0 / 2, class_id=2)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)

      # Class 5: 1 label, 1 correct prediction.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=1.0 / 1, class_id=5)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)

      # Class 7: 1 label, 1 incorrect prediction.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=0.0 / 1, class_id=7)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)

      # All classes: 10 predictions, 3 correct.
      self._test_precision_at_k(
          predictions, labels, k=5, expected=3.0 / 10)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=3.0 / 10)

  @test_util.run_deprecated_v1
  def test_three_labels_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) range are ignored."""
    predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                   [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    sp_labels = sparse_tensor.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
                 [1, 3]],
        # values -1 and 10 are outside the [0, n_classes) range and are ignored.
        values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
        dense_shape=[2, 4])

    # Class 2: 2 labels, 2 correct predictions.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=2.0 / 2, class_id=2)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=2.0 / 2, class_id=2)

    # Class 5: 1 label, 1 correct prediction.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=1.0 / 1, class_id=5)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=1.0 / 1, class_id=5)

    # Class 7: 1 label, 1 incorrect prediction.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=0.0 / 1, class_id=7)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=0.0 / 1, class_id=7)

    # All classes: 10 predictions, 3 correct.
    self._test_precision_at_k(
        predictions, sp_labels, k=5, expected=3.0 / 10)
    self._test_precision_at_top_k(
        predictions_idx, sp_labels, k=5, expected=3.0 / 10)

  @test_util.run_deprecated_v1
  def test_3d_nan(self):
    # 2x2 batch built from the two prediction rows used by the 2d tests,
    # swapped in the second outer entry.
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    # Labels: [0][0] -> {2, 7, 8}, [0][1] -> {1, 2, 5},
    #         [1][0] -> {1, 2, 5, 7}, [1][1] -> {2, 8}.
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
    for class_id in (-1, 1, 3, 8, 10):
      self._test_precision_at_k(
          predictions, labels, k=5, expected=NAN, class_id=class_id)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d_no_labels(self):
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Classes 0,4,6,9: 0 labels, >=1 prediction.
    for class_id in (0, 4, 6, 9):
      self._test_precision_at_k(
          predictions, labels, k=5, expected=0.0, class_id=class_id)
      self._test_precision_at_top_k(
          predictions_idx, labels, k=5, expected=0.0, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d(self):
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Class 2: 4 predictions, all correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=4.0 / 4, class_id=2)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=4.0 / 4, class_id=2)

    # Class 5: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2, class_id=5)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2, class_id=5)

    # Class 7: 2 predictions, 1 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 2, class_id=7)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 2, class_id=7)

    # All classes: 20 predictions, 7 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=7.0 / 20)
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=7.0 / 20)

  @test_util.run_deprecated_v1
  def test_3d_ignore_some(self):
    # Same data as test_3d. The weights below select subsets of the 2x2
    # batch: [[1], [0]] / [[0], [1]] keep one outer entry, while the 2x2
    # weight matrices keep individual inner rows.
    predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                   [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
    predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
                       [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
    labels = _binary_3d_label_to_sparse_value(
        [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
         [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

    # Class 2: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[1], [0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[1], [0]])

    # Class 2: 2 predictions, both correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[0], [1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
        weights=[[0], [1]])

    # Class 7: 1 incorrect prediction.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
        weights=[[1], [0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=0.0 / 1.0, class_id=7,
        weights=[[1], [0]])

    # Class 7: 1 correct prediction.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
        weights=[[0], [1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 1.0, class_id=7,
        weights=[[0], [1]])

    # Class 7: no predictions.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=NAN, class_id=7,
        weights=[[1, 0], [0, 1]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=NAN, class_id=7,
        weights=[[1, 0], [0, 1]])

    # Class 7: 2 predictions, 1 correct.
    self._test_precision_at_k(
        predictions, labels, k=5, expected=1.0 / 2.0, class_id=7,
        weights=[[0, 1], [1, 0]])
    self._test_precision_at_top_k(
        predictions_idx, labels, k=5, expected=1.0 / 2.0, class_id=7,
        weights=[[0, 1], [1, 0]])


def _test_recall_at_k(predictions,
                      labels,
                      k,
                      expected,
                      class_id=None,
                      weights=None,
                      test_case=None):
  """Asserts that `metrics.recall_at_k` yields `expected` in a fresh graph.

  The metric is built inside a brand-new graph and session. Before the local
  variables are initialized, evaluating either the value tensor or the update
  op must raise `errors_impl.OpError`. After initialization, a single update
  and the value read afterwards must both equal `expected`, or both be NaN
  when `expected` is NaN.

  Args:
    predictions: Prediction scores; wrapped in a float32 constant.
    labels: Labels forwarded unchanged to `metrics.recall_at_k`.
    k: Number of top predictions to consider.
    expected: Expected recall value, or NaN.
    class_id: Optional class id to restrict recall to.
    weights: Optional weights; wrapped in a float32 constant when given.
    test_case: Test case supplying the session and assertion methods. Despite
      the `None` default it must be provided.
  """
  graph = ops.Graph()
  with graph.as_default(), test_case.test_session(graph):
    weights_tensor = None
    if weights is not None:
      weights_tensor = constant_op.constant(weights, dtypes_lib.float32)
    predictions_tensor = constant_op.constant(predictions, dtypes_lib.float32)
    value_op, update_op = metrics.recall_at_k(
        predictions=predictions_tensor,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights_tensor)

    # Neither tensor may be readable until the local variables exist.
    for op in (value_op, update_op):
      test_case.assertRaises(errors_impl.OpError, op.eval)
    init_op = variables.variables_initializer(variables.local_variables())
    init_op.run()

    # One update step, then the value must agree with it.
    if math.isnan(expected):
      check = functools.partial(_assert_nan, test_case)
    else:
      check = functools.partial(test_case.assertEqual, expected)
    check(update_op.eval())
    check(value_op.eval())


def _test_recall_at_top_k(
    predictions_idx,
    labels,
    expected,
    k=None,
    class_id=None,
    weights=None,
    test_case=None):
  """Asserts that `metrics.recall_at_top_k` yields `expected` in a new graph.

  Mirrors `_test_recall_at_k` but feeds precomputed top-k prediction indices
  instead of raw scores. Both the value tensor and the update op must fail
  before local variables are initialized. Afterwards, one update and the
  subsequent value must both equal `expected` (or both be NaN).

  Args:
    predictions_idx: Top-k class indices; wrapped in an int32 constant.
    labels: Labels forwarded unchanged to `metrics.recall_at_top_k`.
    expected: Expected recall value, or NaN.
    k: Optional number of leading indices to consider.
    class_id: Optional class id to restrict recall to.
    weights: Optional weights; wrapped in a float32 constant when given.
    test_case: Test case supplying the session and assertion methods. Despite
      the `None` default it must be provided.
  """
  graph = ops.Graph()
  with graph.as_default(), test_case.test_session(graph):
    weights_tensor = None
    if weights is not None:
      weights_tensor = constant_op.constant(weights, dtypes_lib.float32)
    indices_tensor = constant_op.constant(predictions_idx, dtypes_lib.int32)
    value_op, update_op = metrics.recall_at_top_k(
        predictions_idx=indices_tensor,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights_tensor)

    # Neither tensor may be readable until the local variables exist.
    for op in (value_op, update_op):
      test_case.assertRaises(errors_impl.OpError, op.eval)
    init_op = variables.variables_initializer(variables.local_variables())
    init_op.run()

    # One update step, then the value must agree with it.
    if math.isnan(expected):
      check = functools.partial(_assert_nan, test_case)
    else:
      check = functools.partial(test_case.assertEqual, expected)
    check(update_op.eval())
    check(value_op.eval())


class SingleLabelRecallAtKTest(test.TestCase):
  """Recall@k tests where every example carries exactly one label.

  Each expectation is checked twice: once via `metrics.recall_at_k` on raw
  scores and once via `metrics.recall_at_top_k` on the precomputed top-1
  indices. Both are checked for each of four equivalent label encodings.
  """

  def setUp(self):
    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
    # Top-1 class index of each row of `self._predictions`.
    self._predictions_idx = [[3], [3]]
    indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
    class_labels = (3, 2)
    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
    sparse_1d = _binary_2d_label_to_1d_sparse_value(indicator_labels)
    sparse_2d = _binary_2d_label_to_2d_sparse_value(indicator_labels)
    dense_1d = np.array(class_labels, dtype=np.int64)
    dense_2d = np.array([[label] for label in class_labels], dtype=np.int64)
    self._labels = (sparse_1d, sparse_2d, dense_1d, dense_2d)
    self._test_recall_at_k = functools.partial(
        _test_recall_at_k, test_case=self)
    self._test_recall_at_top_k = functools.partial(
        _test_recall_at_top_k, test_case=self)

  def _check_recall(self, labels, expected, **kwargs):
    """Runs the recall_at_k and recall_at_top_k checks with shared args."""
    self._test_recall_at_k(
        self._predictions, labels, expected=expected, **kwargs)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, expected=expected, **kwargs)

  @test_util.run_deprecated_v1
  def test_at_k1_nan(self):
    # Classes 0 and 1 have no labels and no predictions; classes -1 and 4 are
    # out of range.
    for labels in self._labels:
      for class_id in (-1, 0, 1, 4):
        self._check_recall(labels, NAN, k=1, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_at_k1_no_predictions(self):
    # Class 2: 1 label, 0 predictions.
    for labels in self._labels:
      self._check_recall(labels, 0.0, k=1, class_id=2)

  @test_util.run_deprecated_v1
  def test_one_label_at_k1(self):
    for labels in self._labels:
      # Class 3: 1 label, 2 predictions, 1 correct.
      self._check_recall(labels, 1.0 / 1, k=1, class_id=3)
      # All classes: 2 labels, 2 predictions, 1 correct.
      self._check_recall(labels, 1.0 / 2, k=1)

  @test_util.run_deprecated_v1
  def test_one_label_at_k1_weighted_class_id3(self):
    # Class 3: 1 label (row 0), 2 predictions, 1 correct.
    # Each case is (weights, expected recall).
    cases = (
        ((0.0,), NAN),
        ((1.0,), 1.0 / 1),
        ((2.0,), 1.0 / 1),
        ((0.0, 1.0), NAN),
        ((1.0, 0.0), 1.0 / 1),
        ((2.0, 3.0), 2.0 / 2),
    )
    for labels in self._labels:
      for weights, expected in cases:
        self._check_recall(
            labels, expected, k=1, class_id=3, weights=weights)

  @test_util.run_deprecated_v1
  def test_one_label_at_k1_weighted(self):
    # All classes: 2 labels, 2 predictions, 1 correct (row 0).
    # Each case is (weights, expected recall).
    cases = (
        ((0.0,), NAN),
        ((1.0,), 1.0 / 2),
        ((2.0,), 1.0 / 2),
        ((1.0, 0.0), 1.0 / 1),
        ((0.0, 1.0), 0.0 / 1),
        ((2.0, 3.0), 2.0 / 5),
    )
    for labels in self._labels:
      for weights, expected in cases:
        self._check_recall(labels, expected, k=1, weights=weights)


class MultiLabel2dRecallAtKTest(test.TestCase):
  """Recall@5 tests with 2d predictions and multiple labels per example.

  Row 0 labels are {2, 7, 8} with top-5 predictions (9, 4, 6, 2, 0). Row 1
  labels are {1, 2, 5} with top-5 predictions (5, 7, 2, 9, 6).
  """

  def setUp(self):
    self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
                         (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6))
    # Top-5 class indices of each row of `self._predictions`, best first.
    self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6))
    indicator_labels = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0),
                        (0, 1, 1, 0, 0, 1, 0, 0, 0, 0))
    class_labels = ((2, 7, 8), (1, 2, 5))
    # Sparse vs dense labels should be handled the same.
    self._labels = (_binary_2d_label_to_2d_sparse_value(indicator_labels),
                    np.array(
                        class_labels, dtype=np.int64))
    self._test_recall_at_k = functools.partial(
        _test_recall_at_k, test_case=self)
    self._test_recall_at_top_k = functools.partial(
        _test_recall_at_top_k, test_case=self)

  @test_util.run_deprecated_v1
  def test_at_k5_nan(self):
    for labels in self._labels:
      # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
      for class_id in (0, 3, 4, 6, 9, 10):
        self._test_recall_at_k(
            self._predictions, labels, k=5, expected=NAN, class_id=class_id)
        self._test_recall_at_top_k(
            self._predictions_idx, labels, k=5, expected=NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_at_k5_no_predictions(self):
    for labels in self._labels:
      # Class 8: 1 label, no predictions.
      self._test_recall_at_k(
          self._predictions, labels, k=5, expected=0.0 / 1, class_id=8)
      self._test_recall_at_top_k(
          self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=8)

  @test_util.run_deprecated_v1
  def test_at_k5(self):
    for labels in self._labels:
      # Class 2: 2 labels, both correct.
      self._test_recall_at_k(
          self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
      self._test_recall_at_top_k(
          self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)

      # Class 5: 1 label (row 1), correct -- 5 is row 1's top prediction.
      self._test_recall_at_k(
          self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
      self._test_recall_at_top_k(
          self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)

      # Class 7: 1 label (row 0), incorrect -- 7 is not in row 0's top 5.
      self._test_recall_at_k(
          self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
      self._test_recall_at_top_k(
          self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)

      # All classes: 6 labels, 3 correct.
      self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 6)
      self._test_recall_at_top_k(
          self._predictions_idx, labels, k=5, expected=3.0 / 6)

  @test_util.run_deprecated_v1
  def test_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) count in denominator."""
    labels = sparse_tensor.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
                 [1, 3]],
        # values -1 and 10 are outside the [0, n_classes) range.
        values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
        dense_shape=[2, 4])

    # Class 2: 2 labels, both correct.
    self._test_recall_at_k(
        self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)

    # Class 5: 1 label (row 1), correct -- 5 is row 1's top prediction.
    self._test_recall_at_k(
        self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)

    # Class 7: 1 label (row 0), incorrect -- 7 is not in row 0's top 5.
    self._test_recall_at_k(
        self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)

    # All classes: 8 labels (including the 2 out-of-range ones), 3 correct.
    self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 8)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, k=5, expected=3.0 / 8)


class MultiLabel3dRecallAtKTest(test.TestCase):
  """Recall@5 tests with rank-3 predictions and sparse multi-label targets.

  The two score rows below are reused in swapped order for the second
  example, so the batch has shape (2, 2, 10).
  """

  def setUp(self):
    scores_a = (0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9)
    scores_b = (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)
    # Top-5 class indices of `scores_a` / `scores_b`, best first.
    top5_a = (9, 4, 6, 2, 0)
    top5_b = (5, 7, 2, 9, 6)
    self._predictions = ((scores_a, scores_b), (scores_b, scores_a))
    self._predictions_idx = ((top5_a, top5_b), (top5_b, top5_a))
    # Note: We don't test dense labels here, since examples have different
    # numbers of labels.
    self._labels = _binary_3d_label_to_sparse_value(
        (((0, 0, 1, 0, 0, 0, 0, 1, 1, 0), (0, 1, 1, 0, 0, 1, 0, 0, 0, 0)),
         ((0, 1, 1, 0, 0, 1, 0, 1, 0, 0), (0, 0, 1, 0, 0, 0, 0, 0, 1, 0))))
    self._test_recall_at_k = functools.partial(
        _test_recall_at_k, test_case=self)
    self._test_recall_at_top_k = functools.partial(
        _test_recall_at_top_k, test_case=self)

  def _check_recall_at_5(self, expected, **kwargs):
    """Runs recall_at_k and recall_at_top_k (k=5) with shared arguments."""
    self._test_recall_at_k(
        self._predictions, self._labels, k=5, expected=expected, **kwargs)
    self._test_recall_at_top_k(
        self._predictions_idx, self._labels, k=5, expected=expected, **kwargs)

  @test_util.run_deprecated_v1
  def test_3d_nan(self):
    # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
    for class_id in (0, 3, 4, 6, 9, 10):
      self._check_recall_at_5(NAN, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d_no_predictions(self):
    # Classes 1,8 have 0 predictions, >=1 label.
    for class_id in (1, 8):
      self._check_recall_at_5(0.0, class_id=class_id)

  @test_util.run_deprecated_v1
  def test_3d(self):
    # Each case is (class_id, expected recall).
    per_class = (
        (2, 4.0 / 4),  # 4 labels, all correct.
        (5, 2.0 / 2),  # 2 labels, both correct.
        (7, 1.0 / 2),  # 2 labels, 1 incorrect.
    )
    for class_id, expected in per_class:
      self._check_recall_at_5(expected, class_id=class_id)

    # All classes: 12 labels, 7 correct.
    self._check_recall_at_5(7.0 / 12)

  @test_util.run_deprecated_v1
  def test_3d_ignore_all(self):
    # With every weight zero there is nothing to count, so recall is NaN.
    all_zero_weights = ([[0], [0]], [[0, 0], [0, 0]])
    for class_id in range(10):
      for weights in all_zero_weights:
        self._check_recall_at_5(NAN, class_id=class_id, weights=weights)
    for weights in all_zero_weights:
      self._check_recall_at_5(NAN, weights=weights)

  @test_util.run_deprecated_v1
  def test_3d_ignore_some(self):
    # Each case is (class_id, weights, expected recall).
    cases = (
        (2, [[1], [0]], 2.0 / 2.0),  # Class 2: 2 labels, both correct.
        (2, [[0], [1]], 2.0 / 2.0),  # Class 2: 2 labels, both correct.
        (7, [[0], [1]], 1.0 / 1.0),  # Class 7: 1 label, correct.
        (7, [[1], [0]], 0.0 / 1.0),  # Class 7: 1 label, incorrect.
        (7, [[1, 0], [1, 0]], 1.0 / 2.0),  # Class 7: 2 labels, 1 correct.
        (7, [[0, 1], [0, 1]], NAN),  # Class 7: No labels.
    )
    for class_id, weights, expected in cases:
      self._check_recall_at_5(expected, class_id=class_id, weights=weights)


class MeanAbsoluteErrorTest(test.TestCase):
  """Graph-mode tests for `metrics.mean_absolute_error`."""

  def setUp(self):
    ops.reset_default_graph()

  @staticmethod
  def _ones_10x1():
    """Builds a fresh (10, 1) tensor of ones."""
    return array_ops.ones((10, 1))

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_absolute_error(
        predictions=self._ones_10x1(), labels=self._ones_10x1())
    expected_names = ('mean_absolute_error/count:0',
                      'mean_absolute_error/total:0')
    _assert_metric_variables(self, expected_names)

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    collection = '__metrics__'
    value_tensor, _ = metrics.mean_absolute_error(
        predictions=self._ones_10x1(),
        labels=self._ones_10x1(),
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value_tensor])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update_tensor = metrics.mean_absolute_error(
        predictions=self._ones_10x1(),
        labels=self._ones_10x1(),
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update_tensor])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_normal((10, 3), seed=1)
    labels = random_ops.random_normal((10, 3), seed=2)
    error, update_op = metrics.mean_absolute_error(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Accumulate ten random batches.
      for _ in range(10):
        self.evaluate(update_op)

      # Reading the value tensor repeatedly must not change it.
      reads = [self.evaluate(error) for _ in range(11)]
      for value in reads[1:]:
        self.assertEqual(reads[0], value)

  @test_util.run_deprecated_v1
  def testSingleUpdateWithErrorAndWeights(self):
    shape = (1, 4)
    predictions = constant_op.constant(
        [2, 4, 6, 8], shape=shape, dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        [1, 3, 2, 3], shape=shape, dtype=dtypes_lib.float32)
    # Only elements 1 and 3 count: (|4 - 3| + |8 - 3|) / 2 = 3.
    weights = constant_op.constant([0, 1, 0, 1], shape=shape)

    error, update_op = metrics.mean_absolute_error(labels, predictions, weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(3, self.evaluate(update_op))
      self.assertEqual(3, self.evaluate(error))


class MeanRelativeErrorTest(test.TestCase):
  """Graph-mode tests for `metrics.mean_relative_error`."""

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_relative_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        normalizer=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self, ('mean_relative_error/count:0', 'mean_relative_error/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_relative_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        normalizer=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_relative_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        normalizer=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_normal((10, 3), seed=1)
    labels = random_ops.random_normal((10, 3), seed=2)
    normalizer = random_ops.random_normal((10, 3), seed=3)
    error, update_op = metrics.mean_relative_error(labels, predictions,
                                                   normalizer)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_error = self.evaluate(error)
      for _ in range(10):
        self.assertEqual(initial_error, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateNormalizedByLabels(self):
    np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
    np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32)
    expected_error = np.mean(
        np.divide(np.absolute(np_predictions - np_labels), np_labels))

    predictions = constant_op.constant(
        np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
    labels = constant_op.constant(np_labels, shape=(1, 4))

    error, update_op = metrics.mean_relative_error(
        labels, predictions, normalizer=labels)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # `expected_error` is computed independently in NumPy float32; TF may
      # accumulate in a different order, so compare with a tolerance rather
      # than exact equality.
      self.assertAlmostEqual(expected_error, self.evaluate(update_op), 5)
      self.assertAlmostEqual(expected_error, self.evaluate(error), 5)

  @test_util.run_deprecated_v1
  def testSingleUpdateNormalizedByZeros(self):
    np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)

    predictions = constant_op.constant(
        np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)

    # An all-zero normalizer is expected to yield a relative error of 0.
    error, update_op = metrics.mean_relative_error(
        labels, predictions, normalizer=array_ops.zeros_like(labels))

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(0.0, self.evaluate(update_op))
      self.assertEqual(0.0, self.evaluate(error))


class MeanSquaredErrorTest(test.TestCase):
  """Graph-mode tests for `metrics.mean_squared_error`."""

  def setUp(self):
    ops.reset_default_graph()

  def _dequeue_rows(self, sess, rows):
    """Returns a dequeue tensor over a FIFOQueue pre-filled with `rows`.

    Each row is enqueued as one float32 element of shape (1, 3), so each
    evaluation of the returned tensor yields the next row in order.

    Args:
      sess: Session used to run the enqueue ops.
      rows: Sequence of 3-element rows; the queue capacity equals its length.

    Returns:
      The `dequeue()` tensor of the filled queue.
    """
    queue = data_flow_ops.FIFOQueue(
        len(rows), dtypes=dtypes_lib.float32, shapes=(1, 3))
    for row in rows:
      _enqueue_vector(sess, queue, row)
    return queue.dequeue()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_squared_error(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self, ('mean_squared_error/count:0', 'mean_squared_error/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_squared_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_squared_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_normal((10, 3), seed=1)
    labels = random_ops.random_normal((10, 3), seed=2)
    error, update_op = metrics.mean_squared_error(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_error = self.evaluate(error)
      for _ in range(10):
        self.assertEqual(initial_error, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateZeroError(self):
    predictions = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)
    labels = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_squared_error(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(0, self.evaluate(update_op))
      self.assertEqual(0, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateWithError(self):
    predictions = constant_op.constant(
        [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_squared_error(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # (1 + 1 + 16) / 3 = 6.
      self.assertEqual(6, self.evaluate(update_op))
      self.assertEqual(6, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateWithErrorAndWeights(self):
    predictions = constant_op.constant(
        [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
    weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))

    error, update_op = metrics.mean_squared_error(labels, predictions, weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # Only elements 1 and 3 count: (1 + 25) / 2 = 13.
      self.assertEqual(13, self.evaluate(update_op))
      self.assertEqual(13, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testMultipleBatchesOfSizeOne(self):
    with self.cached_session() as sess:
      predictions = self._dequeue_rows(sess, [[10, 8, 6], [-4, 3, -1]])
      labels = self._dequeue_rows(sess, [[1, 3, 2], [2, 4, 6]])

      error, update_op = metrics.mean_squared_error(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)
      # Squared errors: (81 + 25 + 16) + (36 + 1 + 49) = 208 over 6 values.
      self.assertAlmostEqual(208.0 / 6, self.evaluate(update_op), 5)

      self.assertAlmostEqual(208.0 / 6, self.evaluate(error), 5)

  @test_util.run_deprecated_v1
  def testMetricsComputedConcurrently(self):
    with self.cached_session() as sess:
      # One set of predictions and labels per metric.
      predictions0 = self._dequeue_rows(sess, [[10, 8, 6], [-4, 3, -1]])
      predictions1 = self._dequeue_rows(sess, [[0, 1, 1], [1, 1, 0]])
      labels0 = self._dequeue_rows(sess, [[1, 3, 2], [2, 4, 6]])
      labels1 = self._dequeue_rows(sess, [[-5, -3, -1], [5, 4, 3]])

      mse0, update_op0 = metrics.mean_squared_error(
          labels0, predictions0, name='msd0')
      mse1, update_op1 = metrics.mean_squared_error(
          labels1, predictions1, name='msd1')

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([update_op0, update_op1])
      self.evaluate([update_op0, update_op1])

      mse0, mse1 = self.evaluate([mse0, mse1])
      # mse0: (81 + 25 + 16) + (36 + 1 + 49) = 208 over 6 values.
      self.assertAlmostEqual(208.0 / 6, mse0, 5)
      # mse1: (25 + 16 + 4) + (16 + 9 + 9) = 79 over 6 values.
      self.assertAlmostEqual(79.0 / 6, mse1, 5)

  @test_util.run_deprecated_v1
  def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
    with self.cached_session() as sess:
      predictions = self._dequeue_rows(sess, [[10, 8, 6], [-4, 3, -1]])
      labels = self._dequeue_rows(sess, [[1, 3, 2], [2, 4, 6]])

      mae, ma_update_op = metrics.mean_absolute_error(labels, predictions)
      mse, ms_update_op = metrics.mean_squared_error(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate([ma_update_op, ms_update_op])
      self.evaluate([ma_update_op, ms_update_op])

      # Absolute errors: (9 + 5 + 4) + (6 + 1 + 7) = 32 over 6 values.
      self.assertAlmostEqual(32.0 / 6, self.evaluate(mae), 5)
      self.assertAlmostEqual(208.0 / 6, self.evaluate(mse), 5)


class RootMeanSquaredErrorTest(test.TestCase):
  """Tests for `metrics.root_mean_squared_error`."""

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.root_mean_squared_error(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self,
        ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.root_mean_squared_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.root_mean_squared_error(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_normal((10, 3), seed=1)
    labels = random_ops.random_normal((10, 3), seed=2)
    error, update_op = metrics.root_mean_squared_error(labels, predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_error = self.evaluate(error)
      for _ in range(10):
        self.assertEqual(initial_error, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateZeroError(self):
    with self.cached_session():
      predictions = constant_op.constant(
          0.0, shape=(1, 3), dtype=dtypes_lib.float32)
      labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)

      rmse, update_op = metrics.root_mean_squared_error(labels, predictions)

      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(0, self.evaluate(update_op))

      self.assertEqual(0, self.evaluate(rmse))

  @test_util.run_deprecated_v1
  def testSingleUpdateWithError(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
      labels = constant_op.constant(
          [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)

      rmse, update_op = metrics.root_mean_squared_error(labels, predictions)

      # Squared errors are 1, 1 and 16, so the mean squared error is 6.
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(math.sqrt(6), self.evaluate(update_op), 5)
      self.assertAlmostEqual(math.sqrt(6), self.evaluate(rmse), 5)

  @test_util.run_deprecated_v1
  def testSingleUpdateWithErrorAndWeights(self):
    with self.cached_session():
      predictions = constant_op.constant(
          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant(
          [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
      weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))

      rmse, update_op = metrics.root_mean_squared_error(labels, predictions,
                                                        weights)

      # Only elements 1 and 3 are weighted: squared errors 1 and 25, mean 13.
      # The metric is float32, so compare to 5 places like the other
      # assertions in this class (the default of 7 is at the edge of float32
      # precision for sqrt(13)).
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(math.sqrt(13), self.evaluate(update_op), 5)

      self.assertAlmostEqual(math.sqrt(13), self.evaluate(rmse), 5)


def _reweight(predictions, labels, weights):
  return (np.concatenate([[p] * int(w) for p, w in zip(predictions, weights)]),
          np.concatenate([[l] * int(w) for l, w in zip(labels, weights)]))


class MeanCosineDistanceTest(test.TestCase):
  """Tests for `metrics.mean_cosine_distance`.

  Fixtures are plain `np.array` literals (not the discouraged `np.matrix`);
  `constant_op.constant(..., shape=...)` reshapes them to the requested rank.
  """

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_cosine_distance(
        predictions=array_ops.ones((10, 3)),
        labels=array_ops.ones((10, 3)),
        dim=1)
    _assert_metric_variables(self, (
        'mean_cosine_distance/count:0',
        'mean_cosine_distance/total:0',
    ))

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.mean_cosine_distance(
        predictions=array_ops.ones((10, 3)),
        labels=array_ops.ones((10, 3)),
        dim=1,
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_cosine_distance(
        predictions=array_ops.ones((10, 3)),
        labels=array_ops.ones((10, 3)),
        dim=1,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_normal((10, 3), seed=1)
    labels = random_ops.random_normal((10, 3), seed=2)
    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_error = self.evaluate(error)
      for _ in range(10):
        self.assertEqual(initial_error, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateZeroError(self):
    np_labels = np.array([[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]])

    predictions = constant_op.constant(
        np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(0, self.evaluate(update_op))
      self.assertEqual(0, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateWithError1(self):
    np_labels = np.array([[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]])
    np_predictions = np.array([[1, 0, 0],
                               [0, 0, -1],
                               [1, 0, 0]])

    predictions = constant_op.constant(
        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)

    # Row-wise 1 - dot(label, prediction) is 0, 2 and 1, whose mean is 1.
    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1, self.evaluate(update_op), 5)
      self.assertAlmostEqual(1, self.evaluate(error), 5)

  @test_util.run_deprecated_v1
  def testSingleUpdateWithError2(self):
    np_predictions = np.array(
        [[0.819031913261206, 0.567041924552012, 0.087465312324590],
         [-0.665139432070255, -0.739487441769973, -0.103671883216994],
         [0.707106781186548, -0.707106781186548, 0]])
    np_labels = np.array(
        [[0.819031913261206, 0.567041924552012, 0.087465312324590],
         [0.665139432070255, 0.739487441769973, 0.103671883216994],
         [0.707106781186548, 0.707106781186548, 0]])

    predictions = constant_op.constant(
        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)

    # Rows are identical, negated and orthogonal unit vectors respectively,
    # giving distances 0, 2 and 1 (mean 1).
    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAlmostEqual(1.0, self.evaluate(update_op), 5)
      self.assertAlmostEqual(1.0, self.evaluate(error), 5)

  @test_util.run_deprecated_v1
  def testSingleUpdateWithErrorAndWeights1(self):
    np_predictions = np.array([[1, 0, 0],
                               [0, 0, -1],
                               [1, 0, 0]])
    np_labels = np.array([[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]])

    predictions = constant_op.constant(
        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    weights = constant_op.constant(
        [1, 0, 0], shape=(3, 1, 1), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_cosine_distance(
        labels, predictions, dim=2, weights=weights)

    # Only the first (zero-distance) row is weighted.
    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(0, self.evaluate(update_op))
      self.assertEqual(0, self.evaluate(error))

  @test_util.run_deprecated_v1
  def testSingleUpdateWithErrorAndWeights2(self):
    np_predictions = np.array([[1, 0, 0],
                               [0, 0, -1],
                               [1, 0, 0]])
    np_labels = np.array([[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]])

    predictions = constant_op.constant(
        np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
    weights = constant_op.constant(
        [0, 1, 1], shape=(3, 1, 1), dtype=dtypes_lib.float32)

    error, update_op = metrics.mean_cosine_distance(
        labels, predictions, dim=2, weights=weights)

    # Rows 2 and 3 are weighted: (2 + 1) / 2 = 1.5.
    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1.5, self.evaluate(update_op))
      self.assertEqual(1.5, self.evaluate(error))


class PcntBelowThreshTest(test.TestCase):
  """Tests for `metrics.percentage_below`."""

  def setUp(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.percentage_below(values=array_ops.ones((10,)), threshold=2)
    expected_names = (
        'percentage_below_threshold/count:0',
        'percentage_below_threshold/total:0',
    )
    _assert_metric_variables(self, expected_names)

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    collection = '__metrics__'
    pcnt, _ = metrics.percentage_below(
        values=array_ops.ones((10,)),
        threshold=2,
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [pcnt])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update = metrics.percentage_below(
        values=array_ops.ones((10,)),
        threshold=2,
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update])

  @test_util.run_deprecated_v1
  def testOneUpdate(self):
    with self.cached_session():
      values = constant_op.constant(
          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)

      # Thresholds 100, 7 and 1 leave 4/4, 3/4 and 0/4 values below them.
      metric_pairs = [
          metrics.percentage_below(values, threshold, name=scope)
          for threshold, scope in ((100, 'high'), (7, 'medium'), (1, 'low'))
      ]
      pcnts = [pair[0] for pair in metric_pairs]
      updates = [pair[1] for pair in metric_pairs]

      self.evaluate(variables.local_variables_initializer())
      self.evaluate(updates)

      for expected, actual in zip((1.0, 0.75, 0.0), self.evaluate(pcnts)):
        self.assertAlmostEqual(expected, actual, 5)

  @test_util.run_deprecated_v1
  def testSomePresentOneUpdate(self):
    with self.cached_session():
      values = constant_op.constant(
          [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
      weights = constant_op.constant(
          [1, 0, 0, 1], shape=(1, 4), dtype=dtypes_lib.float32)

      # Only the values 2 and 8 carry weight.
      metric_pairs = [
          metrics.percentage_below(
              values, threshold, weights=weights, name=scope)
          for threshold, scope in ((100, 'high'), (7, 'medium'), (1, 'low'))
      ]
      pcnts = [pair[0] for pair in metric_pairs]
      updates = [pair[1] for pair in metric_pairs]

      self.evaluate(variables.local_variables_initializer())
      self.assertListEqual([1.0, 0.5, 0.0], self.evaluate(updates))

      for expected, actual in zip((1.0, 0.5, 0.0), self.evaluate(pcnts)):
        self.assertAlmostEqual(expected, actual, 5)


class MeanIOUTest(test.TestCase):
  """Tests for `metrics.mean_iou`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_iou(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2)
    _assert_metric_variables(self, ('mean_iou/total_confusion_matrix:0',))

  @test_util.run_deprecated_v1
  def testMetricsCollections(self):
    my_collection_name = '__metrics__'
    mean_iou, _ = metrics.mean_iou(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean_iou])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_iou(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10, 3])
    labels = array_ops.ones([10, 4])
    with self.assertRaises(ValueError):
      metrics.mean_iou(labels, predictions, num_classes=2)

  @test_util.run_deprecated_v1
  def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10])
    labels = array_ops.ones([10])
    weights = array_ops.zeros([9])
    with self.assertRaises(ValueError):
      metrics.mean_iou(labels, predictions, num_classes=2, weights=weights)

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    num_classes = 3
    predictions = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    mean_iou, update_op = metrics.mean_iou(
        labels, predictions, num_classes=num_classes)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_mean_iou = self.evaluate(mean_iou)
      for _ in range(10):
        self.assertEqual(initial_mean_iou, self.evaluate(mean_iou))

  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [2])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [2])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      # One update per enqueued element.
      for _ in range(5):
        self.evaluate(update_op)
      desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.])
      self.assertEqual(desired_output, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithWeights(self):
    num_classes = 2
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      weights = weights_queue.dequeue()

      mean_iou, update_op = metrics.mean_iou(
          labels, predictions, num_classes, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(6):
        self.evaluate(update_op)
      desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_iou))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithMissingClass(self):
    # Test the case where there are no predictions and labels for
    # one class, and thus there is one row and one column with
    # zero entries in the confusion matrix.
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      # There is no prediction for class 2.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      # There is no label for class 2.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(5):
        self.evaluate(update_op)
      # The empty class 2 is excluded from the mean.
      desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0])
      self.assertAlmostEqual(desired_output, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
    predictions = array_ops.concat(
        [constant_op.constant(0, shape=[5]),
         constant_op.constant(1, shape=[5])],
        0)
    labels = array_ops.concat(
        [constant_op.constant(0, shape=[3]),
         constant_op.constant(1, shape=[7])],
        0)
    num_classes = 2
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      confusion_matrix = self.evaluate(update_op)
      self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
      desired_miou = np.mean([3. / 5., 5. / 7.])
      self.assertAlmostEqual(desired_miou, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.zeros([40])
    num_classes = 1
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(40, self.evaluate(update_op)[0])
      self.assertEqual(1.0, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testAllWrong(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.ones([40])
    num_classes = 2
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([[0, 0], [40, 0]], update_op)
      self.assertEqual(0., self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testResultsWithSomeMissing(self):
    predictions = array_ops.concat(
        [constant_op.constant(0, shape=[5]),
         constant_op.constant(1, shape=[5])],
        0)
    labels = array_ops.concat(
        [constant_op.constant(0, shape=[3]),
         constant_op.constant(1, shape=[7])],
        0)
    num_classes = 2
    # Zero weight on the first and last element drops them from the matrix.
    weights = array_ops.concat(
        [constant_op.constant(0, shape=[1]),
         constant_op.constant(1, shape=[8]),
         constant_op.constant(0, shape=[1])],
        0)
    with self.cached_session():
      miou, update_op = metrics.mean_iou(
          labels, predictions, num_classes, weights=weights)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([[2, 0], [2, 4]], update_op)
      desired_miou = np.mean([2. / 4., 4. / 6.])
      self.assertAlmostEqual(desired_miou, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testMissingClassInLabels(self):
    labels = constant_op.constant([
        [[0, 0, 1, 1, 0, 0],
         [1, 0, 0, 0, 0, 1]],
        [[1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0]]])
    predictions = constant_op.constant([
        [[0, 0, 2, 1, 1, 0],
         [0, 1, 2, 2, 0, 1]],
        [[0, 0, 2, 1, 1, 1],
         [1, 1, 2, 0, 0, 0]]])
    num_classes = 3
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op)
      self.assertAlmostEqual(
          1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
          self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testMissingClassOverallSmall(self):
    labels = constant_op.constant([0])
    predictions = constant_op.constant([0])
    num_classes = 2
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([[1, 0], [0, 0]], update_op)
      self.assertAlmostEqual(1, self.evaluate(miou))

  @test_util.run_deprecated_v1
  def testMissingClassOverallLarge(self):
    labels = constant_op.constant([
        [[0, 0, 1, 1, 0, 0],
         [1, 0, 0, 0, 0, 1]],
        [[1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0]]])
    predictions = constant_op.constant([
        [[0, 0, 1, 1, 0, 0],
         [1, 1, 0, 0, 1, 1]],
        [[0, 0, 0, 1, 1, 1],
         [1, 1, 1, 0, 0, 0]]])
    num_classes = 3
    with self.cached_session():
      miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op)
      self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)),
                             self.evaluate(miou))


class MeanPerClassAccuracyTest(test.TestCase):
  """Tests for `metrics.mean_per_class_accuracy`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2)
    _assert_metric_variables(self, ('mean_accuracy/count:0',
                                    'mean_accuracy/total:0'))

  @test_util.run_deprecated_v1
  def testMetricsCollections(self):
    my_collection_name = '__metrics__'
    mean_accuracy, _ = metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        metrics_collections=[my_collection_name])
    self.assertListEqual(
        ops.get_collection(my_collection_name), [mean_accuracy])

  @test_util.run_deprecated_v1
  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.mean_per_class_accuracy(
        predictions=array_ops.ones([10, 1]),
        labels=array_ops.ones([10, 1]),
        num_classes=2,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  @test_util.run_deprecated_v1
  def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10, 3])
    labels = array_ops.ones([10, 4])
    with self.assertRaises(ValueError):
      metrics.mean_per_class_accuracy(labels, predictions, num_classes=2)

  @test_util.run_deprecated_v1
  def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
    predictions = array_ops.ones([10])
    labels = array_ops.ones([10])
    weights = array_ops.zeros([9])
    with self.assertRaises(ValueError):
      metrics.mean_per_class_accuracy(
          labels, predictions, num_classes=2, weights=weights)

  @test_util.run_deprecated_v1
  def testValueTensorIsIdempotent(self):
    num_classes = 3
    predictions = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
    mean_accuracy, update_op = metrics.mean_per_class_accuracy(
        labels, predictions, num_classes=num_classes)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        self.evaluate(update_op)

      # Then verify idempotency.
      initial_mean_accuracy = self.evaluate(mean_accuracy)
      for _ in range(10):
        self.assertEqual(initial_mean_accuracy, self.evaluate(mean_accuracy))

  # This body was previously folded into testValueTensorIsIdempotent because
  # its method header was missing; it is a standalone test with a fresh graph.
  @test_util.run_deprecated_v1
  def testMultipleUpdates(self):
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [2])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [2])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(5):
        self.evaluate(update_op)
      # Class 0: 1/1 correct, class 1: 1/3 correct, class 2: 0/1 correct.
      desired_output = np.mean([1.0, 1.0 / 3.0, 0.0])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithWeights(self):
    num_classes = 2
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      preds_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      labels_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          6, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.5])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      _enqueue_vector(sess, weights_queue, [1.0])
      _enqueue_vector(sess, weights_queue, [0.0])
      weights = weights_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes, weights=weights)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(6):
        self.evaluate(update_op)
      desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testMultipleUpdatesWithMissingClass(self):
    # Test the case where there are no predictions and labels for
    # one class, and thus there is one row and one column with
    # zero entries in the confusion matrix.
    num_classes = 3
    with self.cached_session() as sess:
      # Create the queue that populates the predictions.
      # There is no prediction for class 2.
      preds_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, preds_queue, [0])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [1])
      _enqueue_vector(sess, preds_queue, [0])
      predictions = preds_queue.dequeue()

      # Create the queue that populates the labels.
      # There is no label for class 2.
      labels_queue = data_flow_ops.FIFOQueue(
          5, dtypes=dtypes_lib.int32, shapes=(1, 1))
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [1])
      _enqueue_vector(sess, labels_queue, [0])
      _enqueue_vector(sess, labels_queue, [1])
      labels = labels_queue.dequeue()

      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)

      self.evaluate(variables.local_variables_initializer())
      for _ in range(5):
        self.evaluate(update_op)
      desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
      self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.zeros([40])
    num_classes = 1
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual(1.0, self.evaluate(update_op)[0])
      self.assertEqual(1.0, self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testAllWrong(self):
    predictions = array_ops.zeros([40])
    labels = array_ops.ones([40])
    num_classes = 2
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual([0.0, 0.0], update_op)
      self.assertEqual(0., self.evaluate(mean_accuracy))

  @test_util.run_deprecated_v1
  def testResultsWithSomeMissing(self):
    predictions = array_ops.concat([
        constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
    ], 0)
    labels = array_ops.concat([
        constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
    ], 0)
    num_classes = 2
    # Zero weight on the first and last element drops them from the counts.
    weights = array_ops.concat([
        constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
        constant_op.constant(0, shape=[1])
    ], 0)
    with self.cached_session():
      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes, weights=weights)
      self.evaluate(variables.local_variables_initializer())
      desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
      self.assertAllEqual(desired_accuracy, update_op)
      desired_mean_accuracy = np.mean(desired_accuracy)
      self.assertAlmostEqual(desired_mean_accuracy,
                             self.evaluate(mean_accuracy))


class FalseNegativesTest(test.TestCase):
  """Tests for `metrics.false_negatives` on a fixed 4x5 binary fixture."""

  # Shared fixture. Counting positions with label 1 and prediction 0 row by
  # row gives 1, 0, 2 and 0 false negatives (3 in total).
  _LABELS = ((0, 1, 0, 1, 0),
             (0, 0, 1, 1, 1),
             (1, 1, 1, 1, 0),
             (0, 0, 0, 0, 1))
  _PREDICTIONS = ((0, 0, 1, 1, 0),
                  (1, 1, 1, 1, 1),
                  (0, 1, 0, 1, 0),
                  (1, 1, 1, 1, 1))

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.false_negatives(
        labels=(0, 1, 0, 1),
        predictions=(0, 0, 1, 1))
    _assert_metric_variables(self, ('false_negatives/count:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    fn, fn_update_op = metrics.false_negatives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # Expected to read 0 until the update op has been run once.
      self.assertAllClose(0., fn)
      self.assertAllClose(3., fn_update_op)
      self.assertAllClose(3., fn)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    # The expected total matches each weight scaling one row:
    # 1 * 1. + 0 * 1.5 + 2 * 2. + 0 * 2.5 == 5.
    weights = constant_op.constant((1., 1.5, 2., 2.5))
    fn, fn_update_op = metrics.false_negatives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fn)
      self.assertAllClose(5., fn_update_op)
      self.assertAllClose(5., fn)


class FalseNegativesAtThresholdsTest(test.TestCase):
  """Tests for `metrics.false_negatives_at_thresholds`.

  The three positive labels in the fixture carry scores 0.2, 0.8 and 0.2;
  the expected counts are how many of those are not above each threshold.
  """

  _SCORES = ((0.9, 0.2, 0.8, 0.1),
             (0.2, 0.9, 0.7, 0.6),
             (0.1, 0.2, 0.4, 0.3))
  _TRUTH = ((0, 1, 1, 0),
            (1, 0, 0, 0),
            (0, 0, 0, 0))
  _THRESHOLDS = (0.15, 0.5, 0.85)

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _make_metric(self, **extra_kwargs):
    """Builds the fixture constants and returns (value, update_op)."""
    scores = constant_op.constant(self._SCORES)
    truth = constant_op.constant(self._TRUTH)
    return metrics.false_negatives_at_thresholds(
        predictions=scores,
        labels=truth,
        thresholds=list(self._THRESHOLDS),
        **extra_kwargs)

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.false_negatives_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=list(self._THRESHOLDS))
    _assert_metric_variables(self, ('false_negatives/false_negatives:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    counts, update = self._make_metric()

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0, 0, 0), counts)
      self.assertAllEqual((0, 2, 3), update)
      self.assertAllEqual((0, 2, 3), counts)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    # One weight per row: 3 for row 0, 5 for row 1, 7 for row 2.
    counts, update = self._make_metric(weights=((3.0,), (5.0,), (7.0,)))

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0.0, 0.0, 0.0), counts)
      self.assertAllEqual((0.0, 8.0, 11.0), update)
      self.assertAllEqual((0.0, 8.0, 11.0), counts)


class FalsePositivesTest(test.TestCase):
  """Tests for `metrics.false_positives` on a fixed 4x5 binary fixture."""

  # Shared fixture. Counting positions with label 0 and prediction 1 row by
  # row gives 1, 2, 0 and 4 false positives (7 in total).
  _LABELS = ((0, 1, 0, 1, 0),
             (0, 0, 1, 1, 1),
             (1, 1, 1, 1, 0),
             (0, 0, 0, 0, 1))
  _PREDICTIONS = ((0, 0, 1, 1, 0),
                  (1, 1, 1, 1, 1),
                  (0, 1, 0, 1, 0),
                  (1, 1, 1, 1, 1))

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.false_positives(
        labels=(0, 1, 0, 1),
        predictions=(0, 0, 1, 1))
    _assert_metric_variables(self, ('false_positives/count:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    fp, fp_update_op = metrics.false_positives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # Expected to read 0 until the update op has been run once.
      self.assertAllClose(0., fp)
      self.assertAllClose(7., fp_update_op)
      self.assertAllClose(7., fp)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    # The expected total matches each weight scaling one row:
    # 1 * 1. + 2 * 1.5 + 0 * 2. + 4 * 2.5 == 14.
    weights = constant_op.constant((1., 1.5, 2., 2.5))
    fp, fp_update_op = metrics.false_positives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., fp)
      self.assertAllClose(14., fp_update_op)
      self.assertAllClose(14., fp)


class FalsePositivesAtThresholdsTest(test.TestCase):
  """Tests for `metrics.false_positives_at_thresholds`.

  The nine negative labels in the fixture carry scores 0.9, 0.1, 0.9, 0.7,
  0.6, 0.1, 0.2, 0.4 and 0.3; the expected counts are how many of those lie
  above each threshold.
  """

  _SCORES = ((0.9, 0.2, 0.8, 0.1),
             (0.2, 0.9, 0.7, 0.6),
             (0.1, 0.2, 0.4, 0.3))
  _TRUTH = ((0, 1, 1, 0),
            (1, 0, 0, 0),
            (0, 0, 0, 0))
  _THRESHOLDS = (0.15, 0.5, 0.85)

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _make_metric(self, **extra_kwargs):
    """Builds the fixture constants and returns (value, update_op)."""
    scores = constant_op.constant(self._SCORES)
    truth = constant_op.constant(self._TRUTH)
    return metrics.false_positives_at_thresholds(
        predictions=scores,
        labels=truth,
        thresholds=list(self._THRESHOLDS),
        **extra_kwargs)

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.false_positives_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=list(self._THRESHOLDS))
    _assert_metric_variables(self, ('false_positives/false_positives:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    counts, update = self._make_metric()

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0, 0, 0), counts)
      self.assertAllEqual((7, 4, 2), update)
      self.assertAllEqual((7, 4, 2), counts)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    # A distinct weight for every element of the 3x4 fixture.
    element_weights = ((1.0, 2.0, 3.0, 5.0),
                       (7.0, 11.0, 13.0, 17.0),
                       (19.0, 23.0, 29.0, 31.0))
    counts, update = self._make_metric(weights=element_weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0.0, 0.0, 0.0), counts)
      self.assertAllEqual((125.0, 42.0, 12.0), update)
      self.assertAllEqual((125.0, 42.0, 12.0), counts)


class TrueNegativesTest(test.TestCase):
  """Checks `metrics.true_negatives` against a hand-counted 4x5 fixture.

  Positions with label 0 and prediction 0 number 2, 0, 1 and 0 per row
  (3 in total).
  """

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _fixture(self):
    """Returns the (labels, predictions) constants shared by the tests."""
    truth = constant_op.constant(((0, 1, 0, 1, 0),
                                  (0, 0, 1, 1, 1),
                                  (1, 1, 1, 1, 0),
                                  (0, 0, 0, 0, 1)))
    guesses = constant_op.constant(((0, 0, 1, 1, 0),
                                    (1, 1, 1, 1, 1),
                                    (0, 1, 0, 1, 0),
                                    (1, 1, 1, 1, 1)))
    return truth, guesses

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.true_negatives(labels=(0, 1, 0, 1), predictions=(0, 0, 1, 1))
    _assert_metric_variables(self, ('true_negatives/count:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    truth, guesses = self._fixture()
    total, update = metrics.true_negatives(labels=truth, predictions=guesses)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., total)
      self.assertAllClose(3., update)
      self.assertAllClose(3., total)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    truth, guesses = self._fixture()
    # Expected total matches one weight per row: 2 * 1. + 1 * 2. == 4.
    row_weights = constant_op.constant((1., 1.5, 2., 2.5))
    total, update = metrics.true_negatives(
        labels=truth, predictions=guesses, weights=row_weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., total)
      self.assertAllClose(4., update)
      self.assertAllClose(4., total)


class TrueNegativesAtThresholdsTest(test.TestCase):
  """Tests for `metrics.true_negatives_at_thresholds`.

  Of the nine negative labels in the fixture, the expected counts are how
  many have a score that is not above each threshold.
  """

  _SCORES = ((0.9, 0.2, 0.8, 0.1),
             (0.2, 0.9, 0.7, 0.6),
             (0.1, 0.2, 0.4, 0.3))
  _TRUTH = ((0, 1, 1, 0),
            (1, 0, 0, 0),
            (0, 0, 0, 0))
  _THRESHOLDS = (0.15, 0.5, 0.85)

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _make_metric(self, **extra_kwargs):
    """Builds the fixture constants and returns (value, update_op)."""
    scores = constant_op.constant(self._SCORES)
    truth = constant_op.constant(self._TRUTH)
    return metrics.true_negatives_at_thresholds(
        predictions=scores,
        labels=truth,
        thresholds=list(self._THRESHOLDS),
        **extra_kwargs)

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.true_negatives_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=list(self._THRESHOLDS))
    _assert_metric_variables(self, ('true_negatives/true_negatives:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    counts, update = self._make_metric()

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0, 0, 0), counts)
      self.assertAllEqual((2, 5, 7), update)
      self.assertAllEqual((2, 5, 7), counts)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    # A single row of weights; the expected values match each entry
    # weighting its whole column (column 0 contributes nothing).
    counts, update = self._make_metric(weights=((0.0, 2.0, 3.0, 5.0),))

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0.0, 0.0, 0.0), counts)
      self.assertAllEqual((5.0, 15.0, 23.0), update)
      self.assertAllEqual((5.0, 15.0, 23.0), counts)


class TruePositivesTest(test.TestCase):
  """Tests for `metrics.true_positives` on a fixed 4x5 binary fixture."""

  # Shared fixture. Counting positions with label 1 and prediction 1 row by
  # row gives 1, 3, 2 and 1 true positives (7 in total).
  _LABELS = ((0, 1, 0, 1, 0),
             (0, 0, 1, 1, 1),
             (1, 1, 1, 1, 0),
             (0, 0, 0, 0, 1))
  _PREDICTIONS = ((0, 0, 1, 1, 0),
                  (1, 1, 1, 1, 1),
                  (0, 1, 0, 1, 0),
                  (1, 1, 1, 1, 1))

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.true_positives(
        labels=(0, 1, 0, 1),
        predictions=(0, 0, 1, 1))
    _assert_metric_variables(self, ('true_positives/count:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    tp, tp_update_op = metrics.true_positives(
        labels=labels, predictions=predictions)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      # Expected to read 0 until the update op has been run once.
      self.assertAllClose(0., tp)
      self.assertAllClose(7., tp_update_op)
      self.assertAllClose(7., tp)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    labels = constant_op.constant(self._LABELS)
    predictions = constant_op.constant(self._PREDICTIONS)
    # The expected total matches each weight scaling one row:
    # 1 * 1. + 3 * 1.5 + 2 * 2. + 1 * 2.5 == 12.
    weights = constant_op.constant((1., 1.5, 2., 2.5))
    tp, tp_update_op = metrics.true_positives(
        labels=labels, predictions=predictions, weights=weights)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllClose(0., tp)
      self.assertAllClose(12., tp_update_op)
      self.assertAllClose(12., tp)


class TruePositivesAtThresholdsTest(test.TestCase):
  """Tests for `metrics.true_positives_at_thresholds`.

  The three positive labels in the fixture carry scores 0.2, 0.8 and 0.2;
  the expected counts are how many of those lie above each threshold.
  """

  _SCORES = ((0.9, 0.2, 0.8, 0.1),
             (0.2, 0.9, 0.7, 0.6),
             (0.1, 0.2, 0.4, 0.3))
  _TRUTH = ((0, 1, 1, 0),
            (1, 0, 0, 0),
            (0, 0, 0, 0))
  _THRESHOLDS = (0.15, 0.5, 0.85)

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _make_metric(self, **extra_kwargs):
    """Builds the fixture constants and returns (value, update_op)."""
    scores = constant_op.constant(self._SCORES)
    truth = constant_op.constant(self._TRUTH)
    return metrics.true_positives_at_thresholds(
        predictions=scores,
        labels=truth,
        thresholds=list(self._THRESHOLDS),
        **extra_kwargs)

  @test_util.run_deprecated_v1
  def testVars(self):
    metrics.true_positives_at_thresholds(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        thresholds=list(self._THRESHOLDS))
    _assert_metric_variables(self, ('true_positives/true_positives:0',))

  @test_util.run_deprecated_v1
  def testUnweighted(self):
    counts, update = self._make_metric()

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0, 0, 0), counts)
      self.assertAllEqual((3, 1, 0), update)
      self.assertAllEqual((3, 1, 0), counts)

  @test_util.run_deprecated_v1
  def testWeighted(self):
    # A scalar weight scales every count: (3, 1, 0) * 37.
    counts, update = self._make_metric(weights=37.0)

    with self.cached_session():
      self.evaluate(variables.local_variables_initializer())
      self.assertAllEqual((0.0, 0.0, 0.0), counts)
      self.assertAllEqual((111.0, 37.0, 0.0), update)
      self.assertAllEqual((111.0, 37.0, 0.0), counts)


if __name__ == '__main__':
  # Hand control to the TensorFlow test runner imported as `test` at the top
  # of the file when this module is executed directly.
  test.main()