def graph_fn():
      box_batch = [tf.constant([self._box_center, self._box_center_small])]

      classes = [
          tf.one_hot([0, 1], depth=4),