tensorflow/models

View on GitHub
official/projects/triviaqa/preprocess.py

Summary

Maintainability: D (2 days)
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for preprocessing TriviaQA data."""
import bisect
import json
import operator
import os
import re
import string
from typing import Any, Dict, Generator, List, Optional, Set, Text, Tuple

from absl import logging
import apache_beam as beam
from apache_beam import metrics
import dataclasses
import nltk
import numpy as np
import tensorflow.io.gfile as gfile

import sentencepiece as spm
from official.projects.triviaqa import evaluation
from official.projects.triviaqa import sentencepiece_pb2


@dataclasses.dataclass
class Question:
  """A single TriviaQA question."""

  # Question identifier (the 'QuestionId' field in the dataset JSON).
  id: Text
  # Question text (the 'Question' field in the dataset JSON).
  value: Text

@dataclasses.dataclass
class EvidenceInfo:
  """Reference to one evidence document attached to a question."""

  # Evidence file name, relative to the evidence directory of its source.
  id: Text
  # Dataset key the document was listed under, e.g. 'EntityPages'.
  source: Text
  # Document title as given in the dataset JSON.
  title: Text


@dataclasses.dataclass
class Evidence:
  """An evidence document together with its full text."""

  # Where the document came from (id, source, title).
  info: EvidenceInfo
  # Full document text.
  text: Text


@dataclasses.dataclass
class Answer:
  """Gold answer of a question as provided by the dataset."""

  # Canonical answer string ('Value').
  value: Text
  # Alternative surface forms ('Aliases').
  aliases: List[Text]
  # Dataset-provided normalized forms ('NormalizedAliases').
  normalized_aliases: List[Text]


@dataclasses.dataclass
class QuestionAnswer:
  """A question, the evidence documents it refers to, and its answer."""

  question: Question
  # All evidence documents listed for the question.
  evidence_info: List[EvidenceInfo]
  # None when the dataset entry has no 'Answer' (e.g. unlabeled splits).
  answer: Optional[Answer] = None


@dataclasses.dataclass
class QuestionAnswerEvidence:
  """A question paired with exactly one loaded evidence document."""

  question: Question
  # The single evidence document (with text) for this record.
  evidence: Evidence
  # Gold answer, or None when unavailable.
  answer: Optional[Answer] = None


@dataclasses.dataclass
class Features:
  """Model inputs for one window ("stride") over an evidence document."""

  # '<question id>--<evidence id>'.
  id: Text
  # Position of this window within its document.
  stride_index: int
  question_id: Text
  # Question text.
  question: Text
  # UTF-8 bytes of the window's sentences joined by single spaces.
  context: bytes
  # Question token ids, separator, context piece ids, trailing null token.
  token_ids: List[int]
  # Per-token byte offset into `context`; -1 for question/separator tokens.
  token_offsets: List[int]
  # Ids of the global tokens of the window.
  global_token_ids: List[int]
  # Per-token index of the global token the token belongs to.
  segment_ids: List[int]


@dataclasses.dataclass
class Paragraph:
  """A paragraph split into sentences and tokenized with SentencePiece."""

  # One SentencePieceText per sentence, in order.
  sentences: List[sentencepiece_pb2.SentencePieceText]
  # Total number of pieces over all sentences.
  size: int


@dataclasses.dataclass
class AnswerSpan:
  """An answer occurrence, either byte-level or token-level.

  Byte-level spans (from `find_answer_spans`) hold `match.start()` and
  `match.end()` byte offsets into `Features.context`, so `end` is EXCLUSIVE.
  Token-level spans (from `realign_answer_span`) hold the first and last
  token indices into `Features.token_ids`, both INCLUSIVE.
  """

  # Start position (byte offset or token index; inclusive in both cases).
  begin: int
  # End position: exclusive byte offset, or inclusive token index.
  end: int
  # The answer text as matched in the context.
  text: Text


def make_paragraph(
    sentence_tokenizer: nltk.tokenize.api.TokenizerI,
    processor: spm.SentencePieceProcessor,
    text: Text,
    paragraph_metric: Optional[metrics.Metrics.DelegatingDistribution] = None,
    sentence_metric: Optional[metrics.Metrics.DelegatingDistribution] = None
) -> Paragraph:
  """Splits `text` into sentences and SentencePiece-encodes each one.

  Args:
    sentence_tokenizer: Splits the paragraph into sentence strings.
    processor: SentencePiece model used to encode every sentence.
    text: Paragraph text.
    paragraph_metric: If given, receives the paragraph's total piece count.
    sentence_metric: If given, receives each sentence's piece count.

  Returns:
    A Paragraph with the encoded sentences and their total piece count.
  """
  encoded_sentences = []
  for sentence_text in sentence_tokenizer.tokenize(text):
    serialized = processor.EncodeAsSerializedProto(sentence_text)
    encoded = sentencepiece_pb2.SentencePieceText.FromString(serialized)
    encoded_sentences.append(encoded)
    if sentence_metric:
      sentence_metric.update(len(encoded.pieces))
  total_pieces = sum(len(s.pieces) for s in encoded_sentences)
  if paragraph_metric:
    paragraph_metric.update(total_pieces)
  return Paragraph(sentences=encoded_sentences, size=total_pieces)


def read_question_answers(json_path: Text) -> List[QuestionAnswer]:
  """Parses a TriviaQA JSON file into QuestionAnswer records.

  Args:
    json_path: Path of a TriviaQA JSON file whose top-level 'Data' key holds
      the list of question entries.

  Returns:
    One QuestionAnswer per entry. `answer` is None for entries without an
    'Answer'. Evidence from 'EntityPages' precedes 'SearchResults', and each
    EvidenceInfo's `source` is the key it was listed under.
  """
  with gfile.GFile(json_path) as f:
    entries = json.load(f)['Data']

  def to_question_answer(entry):
    question = Question(id=entry['QuestionId'], value=entry['Question'])
    answer = None
    if 'Answer' in entry:
      raw_answer = entry['Answer']
      answer = Answer(
          value=raw_answer['Value'],
          aliases=raw_answer['Aliases'],
          normalized_aliases=raw_answer['NormalizedAliases'])
    evidence_info = [
        EvidenceInfo(id=doc['Filename'], title=doc['Title'], source=source)
        for source in ('EntityPages', 'SearchResults')
        for doc in entry.get(source, [])
    ]
    return QuestionAnswer(
        question=question, evidence_info=evidence_info, answer=answer)

  return [to_question_answer(entry) for entry in entries]


def alias_answer(answer: Text, include=None):
  """Lower-cases `answer` and blanks out punctuation.

  Underscores become spaces. Every character of `string.punctuation`, plus the
  quote-like characters ‘ ’ ´ `, becomes a space unless it is in `include`.
  Whitespace runs are then collapsed to single spaces and the ends trimmed.

  Args:
    answer: Raw answer or alias text.
    include: Optional collection of punctuation characters to keep.

  Returns:
    The normalized alias string.
  """
  kept = include or []
  punctuation = set(string.punctuation) | {'‘', '’', '´', '`'}
  lowered = answer.replace('_', ' ').lower()
  blanked = ''.join(
      ' ' if (ch in punctuation and ch not in kept) else ch for ch in lowered)
  return ' '.join(blanked.split())


def make_answer_set(answer: Answer) -> Set[Text]:
  """Apply less aggressive normalization to the answer aliases.

  Each of the answer value and its aliases is normalized with `alias_answer`
  under several punctuation-keeping variants. The results are unioned with
  the dataset's own normalized aliases.

  Args:
    answer: Gold answer of a question.

  Returns:
    The set of accepted answer strings.
  """
  keep_variants = ([], [',', '.'], ['-'], [',', '.', '-'], string.punctuation)
  accepted = set(answer.normalized_aliases)
  for alias in [answer.value] + answer.aliases:
    accepted.update(alias_answer(alias, keep) for keep in keep_variants)
  return accepted


def find_answer_spans(text: bytes, answer_set: Set[Text]) -> List[AnswerSpan]:
  """Finds every occurrence of any candidate answer in `text`.

  Answers are matched literally (regex-escaped) with ASCII case-insensitive
  matching, except that a space in an answer also matches a hyphen, so
  'new york' matches b'New-York'.

  Args:
    text: UTF-8 encoded context to search.
    answer_set: Candidate answer strings. Empty strings are skipped: an empty
      pattern would yield a zero-length match at every byte position.

  Returns:
    One AnswerSpan per match, sorted by (begin, end) so the order does not
    depend on set iteration order. `begin`/`end` are `match.start()` and
    `match.end()` byte offsets into `text` (so `end` is exclusive), and
    `text` is the matched bytes decoded as UTF-8.
  """
  spans = []
  for answer in answer_set:
    if not answer:
      continue
    # re.escape turns ' ' into '\ '; widen each escaped space to '[ -]'.
    answer_regex = re.compile(
        re.escape(answer).encode('utf-8').replace(b'\\ ', b'[ -]'),
        flags=re.IGNORECASE)
    for match in answer_regex.finditer(text):
      spans.append(
          AnswerSpan(
              begin=match.start(),
              end=match.end(),
              text=match.group(0).decode('utf-8')))
  return sorted(spans, key=operator.attrgetter('begin', 'end'))


def realign_answer_span(features: Features, answer_set: Optional[Set[Text]],
                        processor: spm.SentencePieceProcessor,
                        span: AnswerSpan) -> Optional[AnswerSpan]:
  """Align answer span to text with given tokens.

  Maps a byte-level span over `features.context` onto token indices of
  `features.token_ids`, using `features.token_offsets` as token start
  positions.

  Args:
    features: Tokenized window whose `context` the span indexes into.
    answer_set: Accepted answer strings. When not None, the text covered by
      the aligned tokens, after `evaluation.normalize_answer`, must be in this
      set or the span is rejected. When None, no check is performed.
    processor: SentencePiece model, used to look up piece strings by id.
    span: Span whose `begin` is a byte offset into `features.context` and
      whose `text` is the matched answer text. `span.end` is not read.

  Returns:
    An AnswerSpan whose `begin`/`end` are the first/last (inclusive) token
    indices covering the answer, with the original `span.text`. Returns None
    if `answer_set` is given and the token-aligned text is not in it.
  """
  # Locate the token containing span.begin: bisect_left gives the first
  # offset >= span.begin; step back one unless a token starts exactly there.
  i = bisect.bisect_left(features.token_offsets, span.begin)
  if i == len(features.token_offsets) or span.begin < features.token_offsets[i]:
    i -= 1
  j = i + 1
  # Exclusive byte offset of the answer end, derived from the text length
  # rather than from span.end.
  answer_end = span.begin + len(span.text.encode('utf-8'))
  # Move j past every token starting before the answer ends, then step back
  # so that j is the last token overlapping the answer.
  while (j < len(features.token_offsets) and
         features.token_offsets[j] < answer_end):
    j += 1
  j -= 1
  # Bytes covered by tokens i..j (through the end of context if j is last).
  sp_answer = (
      features.context[features.token_offsets[i]:features.token_offsets[j + 1]]
      if j + 1 < len(features.token_offsets) else
      features.context[features.token_offsets[i]:])
  # For a word-initial ('▁') piece not at offset 0, drop the first byte --
  # presumably the whitespace preceding the word (NOTE(review): relies on the
  # offsets produced in MakeFeatures / SentencePiece; confirm).
  if (processor.IdToPiece(features.token_ids[i]).startswith('▁') and
      features.token_offsets[i] > 0):
    sp_answer = sp_answer[1:]
  sp_answer = evaluation.normalize_answer(sp_answer.decode('utf-8'))
  if answer_set is not None and sp_answer not in answer_set:
    # No need to warn if the cause was breaking word boundaries.
    # (Only warn when the aligned text is non-empty and no longer than the
    # normalized matched text; longer text means the tokens ran past the
    # match.)
    if len(sp_answer) and not len(sp_answer) > len(
        evaluation.normalize_answer(span.text)):
      logging.warning('%s: "%s" not in %s.', features.question_id, sp_answer,
                      answer_set)
    return None
  return AnswerSpan(begin=i, end=j, text=span.text)


def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path` via tf.io.gfile.

  Args:
    path: Location of the serialized model proto.

  Returns:
    A SentencePieceProcessor initialized from the model bytes.
  """
  processor = spm.SentencePieceProcessor()
  with gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor.LoadFromSerializedProto(serialized_model)
  return processor


class ReadEvidence(beam.DoFn):
  """Loads the evidence documents referenced by each QuestionAnswer.

  For every EvidenceInfo of an input QuestionAnswer, reads
  `<base_dir>/<info.id>` as UTF-8 and yields one QuestionAnswerEvidence.
  Wikipedia evidence ('EntityPages') is read from `wikipedia_dir`, and web
  evidence ('SearchResults') from `web_dir`.
  """

  def __init__(self, wikipedia_dir: Text, web_dir: Text):
    """Initializes the reader.

    Args:
      wikipedia_dir: Directory holding the Wikipedia evidence files.
      web_dir: Directory holding the web evidence files.
    """
    self._wikipedia_dir = wikipedia_dir
    self._web_dir = web_dir

  def process(
      self, question_answer: QuestionAnswer
  ) -> Generator[QuestionAnswerEvidence, None, None]:
    """Yields one QuestionAnswerEvidence per evidence document.

    Args:
      question_answer: Question with its evidence references.

    Yields:
      The question paired with each loaded evidence document.

    Raises:
      ValueError: If an EvidenceInfo has an unrecognized `source`.
    """
    for info in question_answer.evidence_info:
      if info.source == 'EntityPages':
        base_dir = self._wikipedia_dir
      elif info.source in ('SearchResults', 'SearchResult'):
        # `read_question_answers` tags web evidence with its JSON key,
        # 'SearchResults'; the singular spelling is accepted as well.
        base_dir = self._web_dir
      else:
        raise ValueError(f'Unknown evidence source: {info.source}.')
      evidence_path = os.path.join(base_dir, info.id)
      with gfile.GFile(evidence_path, 'rb') as f:
        text = f.read().decode('utf-8')
      metrics.Metrics.counter('_', 'documents').inc()
      yield QuestionAnswerEvidence(
          question=question_answer.question,
          evidence=Evidence(info=info, text=text),
          answer=question_answer.answer)


# Special SentencePiece pieces, converted to ids with PieceToId in
# MakeFeatures.
# First global token of every window.
_CLS_PIECE: Text = '<ans>'
# Global token appended after each context sentence.
_EOS_PIECE: Text = '</s>'
# Appended to the question token ids.
_SEP_PIECE: Text = '<sep_0>'
# Note: '<sep_1>' (a paragraph separator, _PARAGRAPH_SEP_PIECE) is not used.
# Trailing token of every window; its offset is len(context).
_NULL_PIECE: Text = '<empty>'
# Global token emitted once per question token.
_QUESTION_PIECE: Text = '<unused_34>'


class MakeFeatures(beam.DoFn):
  """Function to make features.

  Turns each QuestionAnswerEvidence into one or more `Features`, one per
  window ("stride") over the evidence document. A window contains the
  question token ids (with separator), then as many whole sentences of the
  document, starting at a paragraph boundary, as fit within the token and
  global-token budgets, then a trailing null token. Consecutive windows may
  overlap.
  """

  def __init__(self, sentencepiece_model_path: Text, max_num_tokens: int,
               max_num_global_tokens: int, stride: int):
    """Stores configuration; the models themselves are loaded in `setup`.

    Args:
      sentencepiece_model_path: Path of the serialized SentencePiece model.
      max_num_tokens: Token budget per window. A sentence is added only if
        question ids + context tokens so far + the sentence + the null token
        stays strictly below this value.
      max_num_global_tokens: A sentence is added only while the number of
        global tokens so far is strictly below this value.
      stride: If > 0 and a document does not fit in one window, the next
        window starts after skipping the current window's first paragraph plus
        any following already-tokenized paragraphs, as long as their combined
        piece count stays <= `stride`. If <= 0, only the first window is
        produced.
    """
    self._sentencepiece_model_path = sentencepiece_model_path
    self._max_num_tokens = max_num_tokens
    self._max_num_global_tokens = max_num_global_tokens
    self._stride = stride

  def setup(self):
    # Load NLTK's English Punkt sentence splitter and the SentencePiece model.
    self._sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)

  def _make_features(self, stride_index: int, paragraph_texts: List[Text],
                     paragraphs: List[Paragraph],
                     question_answer_evidence: QuestionAnswerEvidence,
                     ids: List[int],
                     paragraph_offset: int) -> Tuple[int, Features]:
    """Builds the Features for one window starting at `paragraph_offset`.

    Args:
      stride_index: Index of this window within the document.
      paragraph_texts: Paragraph strings of the document.
      paragraphs: Cache of tokenized paragraphs, indexed like
        `paragraph_texts`. Extended in place when a paragraph is tokenized
        for the first time.
      question_answer_evidence: Source record (question and evidence).
      ids: Question token ids followed by the separator id.
      paragraph_offset: Index of the first paragraph of this window.

    Returns:
      (next_paragraph_index, features): where the next window should start
      (len(paragraph_texts) when there is no next window) and this window's
      features.
    """
    # Global tokens: CLS, then one question global token per entry of `ids`.
    global_ids = (
        [self._sentencepiece_processor.PieceToId(_CLS_PIECE)] +
        [self._sentencepiece_processor.PieceToId(_QUESTION_PIECE)] * len(ids))
    segment_ids = [i + 1 for i in range(len(ids))]  # offset for CLS token
    token_ids, sentences = [], []
    # Question tokens have no position in `context`, hence offset -1.
    offsets, offset, full_text = [-1] * len(ids), 0, True
    for i in range(paragraph_offset, len(paragraph_texts)):
      # Reuse a paragraph tokenized by an earlier window, otherwise tokenize
      # and cache it. Relies on `paragraphs` being filled in order (as
      # `process` does), so an uncached index is always the next to append.
      if i < len(paragraphs):
        paragraph = paragraphs[i]
      else:
        paragraphs.append(
            make_paragraph(
                self._sentence_tokenizer,
                self._sentencepiece_processor,
                paragraph_texts[i],
                paragraph_metric=metrics.Metrics.distribution(
                    '_', 'paragraphs'),
                sentence_metric=metrics.Metrics.distribution('_', 'sentences')))
        paragraph = paragraphs[-1]
      for sentence in paragraph.sentences:
        # Stop at the first sentence that does not fit; `+ 1` reserves room
        # for the trailing null token.
        if (len(ids) + len(token_ids) + len(sentence.pieces) + 1 >=
            self._max_num_tokens or
            len(global_ids) >= self._max_num_global_tokens):
          full_text = False
          break
        for j, piece in enumerate(sentence.pieces):
          token_ids.append(piece.id)
          # Each context token belongs to the EOS global token appended for
          # its sentence just below.
          segment_ids.append(len(global_ids))
          # Byte offset in `context`: sentence start + piece start within
          # the sentence.
          offsets.append(offset + piece.begin)
          if j == 0 and sentences:
            # Sentences are joined by one space; move the first piece of
            # every non-first sentence back one byte onto that space.
            offsets[-1] -= 1
        # Advance past this sentence's UTF-8 bytes plus the joining space.
        offset += len(sentence.text.encode('utf-8')) + 1
        global_ids.append(self._sentencepiece_processor.PieceToId(_EOS_PIECE))
        sentences.append(sentence.text)
      if not full_text:
        break
    context = ' '.join(sentences).encode('utf-8')
    # Trailing null token: offset at the end of the context, attached to the
    # CLS global token (segment 0).
    token_ids.append(self._sentencepiece_processor.PieceToId(_NULL_PIECE))
    offsets.append(len(context))
    segment_ids.append(0)
    # By default there is no next window; only a truncated document with a
    # positive stride gets one.
    next_paragraph_index = len(paragraph_texts)
    if not full_text and self._stride > 0:
      # Always skip this window's first paragraph, then keep skipping cached
      # paragraphs while the total skipped piece count stays within stride.
      shift = paragraphs[paragraph_offset].size
      next_paragraph_index = paragraph_offset + 1
      while (next_paragraph_index < len(paragraphs) and
             shift + paragraphs[next_paragraph_index].size <= self._stride):
        shift += paragraphs[next_paragraph_index].size
        next_paragraph_index += 1
    return next_paragraph_index, Features(
        id='{}--{}'.format(question_answer_evidence.question.id,
                           question_answer_evidence.evidence.info.id),
        stride_index=stride_index,
        question_id=question_answer_evidence.question.id,
        question=question_answer_evidence.question.value,
        context=context,
        token_ids=ids + token_ids,
        global_token_ids=global_ids,
        segment_ids=segment_ids,
        token_offsets=offsets)

  def process(
      self, question_answer_evidence: QuestionAnswerEvidence
  ) -> Generator[Features, None, None]:
    """Yields one Features per window over the evidence document."""
    # Tokenize question which is shared among all examples.
    ids = (
        self._sentencepiece_processor.EncodeAsIds(
            question_answer_evidence.question.value) +
        [self._sentencepiece_processor.PieceToId(_SEP_PIECE)])
    # Paragraphs are the non-empty, stripped lines of the evidence text.
    paragraph_texts = list(
        filter(
            lambda p: p,
            map(lambda p: p.strip(),
                question_answer_evidence.evidence.text.split('\n'))))
    # `paragraphs` caches tokenized paragraphs across windows of this
    # document; _make_features fills it lazily.
    stride_index, paragraphs, paragraph_index = 0, [], 0
    while paragraph_index < len(paragraph_texts):
      paragraph_index, features = self._make_features(stride_index,
                                                      paragraph_texts,
                                                      paragraphs,
                                                      question_answer_evidence,
                                                      ids, paragraph_index)
      stride_index += 1
      yield features


def _handle_exceptional_examples(
    features: Features,
    processor: spm.SentencePieceProcessor) -> List[AnswerSpan]:
  """Special cases in data.

  Hand-written answer spans for known examples that the generic matcher
  misses. Used by FindAnswerSpans as a fallback when `find_answer_spans` plus
  `realign_answer_span` produce nothing for a feature. Each case is keyed by
  `features.id` and locates a hard-coded byte pattern in `features.context`.
  Spans are aligned with `realign_answer_span(..., None, ...)`, which skips
  the answer-set membership check.

  Args:
    features: Tokenized window to label.
    processor: SentencePiece model used for alignment.

  Returns:
    Token-level answer spans, or [] if `features.id` is not a special case or
    its pattern does not occur in this window's context.
  """
  if features.id == 'qw_6687--Viola.txt':
    pattern = 'three strings in common—G, D, and A'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      # The answer is the final 'A' of the pattern.
      span = AnswerSpan(i + len(pattern) - 1, i + len(pattern), 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_26183--Vitamin_A.txt':
    pattern = ('Vitamin A is a group of unsaturated nutritional organic '
               'compounds that includes retinol').encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      a_index = i + pattern.find(b'A')
      span = AnswerSpan(a_index, a_index + 1, 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans = [span]
      span = AnswerSpan(i, a_index + 1, 'Vitamin A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return spans + [span]
  if features.id == 'odql_292--Colombia.txt':
    pattern = b'Colombia is the third-most populous country in Latin America'
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(b'Colombia'), 'Colombia')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'tc_1648--Vietnam.txt':
    pattern = 'Bảo Đại'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(pattern), 'Bảo Đại')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_22225--Irish_mythology.txt':
    pattern = 'Tír na nÓg'.encode('utf-8')
    spans = []
    i = features.context.find(pattern)
    while i != -1:
      span = AnswerSpan(i, i + len(pattern), 'Tír na nÓg')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans.append(span)
      # Resume after this match. Restarting from offset 0 would find the
      # first occurrence forever whenever the pattern occurs twice or more.
      i = features.context.find(pattern, i + len(pattern))
    return spans
  return []


class FindAnswerSpans(beam.DoFn):
  """Labels every feature of a question with token-level answer spans.

  Input elements are (question_id, features) groups. The `answer_sets` side
  input maps question ids to accepted answer strings. Every feature is
  yielded together with its (possibly empty) list of answer spans.
  """

  def __init__(self, sentencepiece_model_path: Text):
    self._sentencepiece_model_path = sentencepiece_model_path

  def setup(self):
    # The SentencePiece model is loaded here rather than in __init__.
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)

  def _aligned_spans(self, feature: Features,
                     answer_set: Set[Text]) -> List[AnswerSpan]:
    """Byte-level answer matches mapped to tokens; rejected ones dropped."""
    aligned = (
        realign_answer_span(feature, answer_set, self._sentencepiece_processor,
                            candidate)
        for candidate in find_answer_spans(feature.context, answer_set))
    return [span for span in aligned if span]

  def process(
      self,
      element: Tuple[Text, List[Features]],
      answer_sets: Dict[Text, Set[Text]],
  ) -> Generator[Tuple[Features, List[AnswerSpan]], None, None]:
    """Yields (feature, answer_spans) for every feature of one question.

    Falls back to `_handle_exceptional_examples` when no span aligns. Counts
    answerless features, and counts and logs questions with no answer in any
    feature.
    """
    question_id, features = element
    answer_set = answer_sets[question_id]
    any_answer_found = False
    for feature in features:
      spans = (
          self._aligned_spans(feature, answer_set) or
          _handle_exceptional_examples(feature, self._sentencepiece_processor))
      if not spans:
        metrics.Metrics.counter('_', 'answerless_examples').inc()
      else:
        any_answer_found = True
      yield feature, spans
    if not any_answer_found:
      metrics.Metrics.counter('_', 'answerless_questions').inc()
      logging.error('Question %s has no answer.', question_id)


def make_example(
    features: Features,
    labels: Optional[List[AnswerSpan]] = None) -> Tuple[Text, Dict[Text, Any]]:
  """Converts Features (and optional answer labels) into a keyed example.

  Args:
    features: Features of one window.
    labels: Token-level answer spans. Duplicate (begin, end) pairs are
      collapsed. None or an empty list yields an empty answers array.

  Returns:
    ('<features.id>--<stride_index>', example), where `example` holds the
    feature fields plus 'answers', an int64 array of shape (num_answers, 2).
  """
  example = {
      'id': features.id,
      'qid': features.question_id,
      'question': features.question,
      'context': features.context,
      'token_ids': features.token_ids,
      'token_offsets': features.token_offsets,
      'segment_ids': features.segment_ids,
      'global_token_ids': features.global_token_ids,
  }
  if not labels:
    example['answers'] = np.zeros([0, 2], np.int64)
  else:
    unique_pairs = {(label.begin, label.end) for label in labels}
    example['answers'] = np.array([list(pair) for pair in unique_pairs],
                                  np.int64)
  metrics.Metrics.counter('_', 'examples').inc()
  return '{}--{}'.format(features.id, features.stride_index), example


def make_pipeline(root: beam.Pipeline, question_answers: List[QuestionAnswer],
                  answer: bool, max_num_tokens: int, max_num_global_tokens: int,
                  stride: int, sentencepiece_model_path: Text,
                  wikipedia_dir: Text, web_dir: Text):
  """Makes a Beam pipeline.

  Reads evidence for every question, windows it into Features and, when
  `answer` is true, labels each window with answer spans before converting
  it into a keyed example.

  Args:
    root: Pipeline to attach the transforms to.
    question_answers: Parsed questions (see `read_question_answers`).
    answer: Whether to compute answer labels (requires gold answers).
    max_num_tokens: Token budget per window (see MakeFeatures).
    max_num_global_tokens: Global-token budget per window.
    stride: Window stride in tokens (see MakeFeatures).
    sentencepiece_model_path: Path of the SentencePiece model.
    wikipedia_dir: Directory of Wikipedia evidence files.
    web_dir: Directory of web evidence files.

  Returns:
    PCollection of (key, example dict) pairs produced by `make_example`.
  """
  qa_pcoll = root | 'CreateQuestionAnswers' >> beam.Create(question_answers)
  read_evidence = ReadEvidence(wikipedia_dir=wikipedia_dir, web_dir=web_dir)
  make_features = MakeFeatures(
      sentencepiece_model_path=sentencepiece_model_path,
      max_num_tokens=max_num_tokens,
      max_num_global_tokens=max_num_global_tokens,
      stride=stride)
  features = (
      qa_pcoll
      | 'ReadEvidence' >> beam.ParDo(read_evidence)
      | 'MakeFeatures' >> beam.ParDo(make_features))
  if not answer:
    return features | 'MakeExamples' >> beam.Map(make_example)

  keyed_features = features | 'KeyFeature' >> beam.Map(
      lambda feature: (feature.question_id, feature))
  answer_sets = qa_pcoll | 'MakeAnswerSet' >> beam.Map(
      lambda qa: (qa.question.id, make_answer_set(qa.answer)))
  return (
      keyed_features
      | beam.GroupByKey()
      | 'FindAnswerSpans' >> beam.ParDo(
          FindAnswerSpans(sentencepiece_model_path),
          answer_sets=beam.pvalue.AsDict(answer_sets))
      | 'MakeExamplesWithLabels' >> beam.MapTuple(make_example))