if fields.InputDataFields.groundtruth_track_ids in tensor_dict:
      # Only touch the entry when it is present: read the stored value, cast
      # it to tf.int32, and write the result back under the same key.
      track_ids = tensor_dict[fields.InputDataFields.groundtruth_track_ids]
      tensor_dict[fields.InputDataFields.groundtruth_track_ids] = tf.cast(
          track_ids, dtype=tf.int32)