track_ids = [
      tf.constant([2], dtype=tf.int32),
      tf.constant([1], dtype=tf.int32),