tensorflow/models

View on GitHub
official/nlp/tools/squad_evaluate_v1_1.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation of SQuAD predictions (version 1.1).

The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.

The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""

import collections
import re
import string

# pylint: disable=g-bad-import-order

from absl import logging
# pylint: enable=g-bad-import-order


def _normalize_answer(s):
  """Lowers text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
  """Token-level F1 between a predicted answer and one reference answer.

  Both strings are normalized with `_normalize_answer` and split on
  whitespace. Overlap is the size of the multiset intersection of the two
  token bags.

  Args:
    prediction: Predicted answer text.
    ground_truth: Reference answer text.

  Returns:
    The harmonic mean of token precision and recall as a float, or the int 0
    when the token bags share nothing. This includes the case where either
    side normalizes to an empty string.
  """
  pred_tokens = _normalize_answer(prediction).split()
  gold_tokens = _normalize_answer(ground_truth).split()
  overlap = sum(
      (collections.Counter(pred_tokens) &
       collections.Counter(gold_tokens)).values())
  # Also guards the divisions below against empty token lists.
  if not overlap:
    return 0
  precision = 1.0 * overlap / len(pred_tokens)
  recall = 1.0 * overlap / len(gold_tokens)
  return (2 * precision * recall) / (precision + recall)


def _exact_match_score(prediction, ground_truth):
  """Returns True iff both answers are equal after `_normalize_answer`."""
  normalized_prediction = _normalize_answer(prediction)
  normalized_truth = _normalize_answer(ground_truth)
  return normalized_prediction == normalized_truth


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Computes the max over all metric scores."""
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  """Evaluates predictions for a dataset.

  Args:
    dataset: Iterable of articles. Each article has a "paragraphs" list, and
      each paragraph has a "qas" list. Each qa entry has an "id" and an
      "answers" list of dicts carrying a "text" field.
    predictions: Mapping from question id to predicted answer text.

  Returns:
    A dict with keys "exact_match" and "final_f1". Each value is the metric
    averaged over every question in `dataset`, as a fraction in [0, 1] rather
    than a percentage. Questions missing from `predictions` count as 0 for
    both metrics. If `dataset` contains no questions, both values are 0.0
    instead of raising ZeroDivisionError.
  """
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          # The question stays in the denominator, so it effectively scores 0.
          # Lazy %-formatting also tolerates non-string ids.
          logging.error("Unanswered question %s will receive score 0.",
                        qa["id"])
          continue
        ground_truths = [entry["text"] for entry in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += _metric_max_over_ground_truths(_exact_match_score,
                                                      prediction, ground_truths)
        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
                                             ground_truths)

  if total == 0:
    # Averaging over zero questions would divide by zero.
    logging.warning("No questions found in dataset; returning zero scores.")
    return {"exact_match": 0.0, "final_f1": 0.0}

  exact_match = exact_match / total
  f1 = f1 / total

  return {"exact_match": exact_match, "final_f1": f1}