def graph_fn():
      predicted_embedding_feature_map = tf.constant(
          predicted_embedding_feature_map_np, dtype=tf.float32)

      gathered_predicted_embeddings = (