track_match_flags = [
      tf.constant([1.0], dtype=tf.float32),
      tf.constant([1.0], dtype=tf.float32),