def graph_fn():
      regressed_keypoint_feature_map = tf.constant(
          regressed_keypoint_feature_map_np, dtype=tf.float32)

      gathered_regressed_keypoints = (