if max_total_size:
      max_total_size = tf.minimum(max_total_size, sorted_boxes.num_boxes())
      sorted_boxes = box_list_ops.gather(sorted_boxes, tf.range(max_total_size))
      num_valid_nms_boxes_cumulative = tf.where(
          max_total_size > num_valid_nms_boxes_cumulative,