def graph_fn():
      box_masks = tf.constant([[[4, 4],
                                [4, 4]]], dtype=mask_dtype)
      boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]], dtype=tf.float32)
      image_masks = ops.reframe_box_masks_to_image_masks(