official/projects/triviaqa/dataset.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA: A Reading Comprehension Dataset."""
import functools
import json
import os
from absl import logging
import apache_beam as beam
import six
import tensorflow as tf, tf_keras
import tensorflow_datasets.public_api as tfds
from official.projects.triviaqa import preprocess
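# Example usage (a sketch; the paths and sizes below are placeholders, not
# defaults from this module):
#
#   builder = TriviaQA(config="rc.nocontext", data_dir="/tmp/tfds")
#   builder.download_and_prepare()  # Beam-based; large configs may need an explicit runner.
#   train_ds = builder.as_dataset(split="train")
#
# The BigBird config ("rc_wiki.preprocessed") additionally takes the tokenizer
# arguments accepted by TriviaQA.__init__:
#
#   builder = TriviaQA(
#       config="rc_wiki.preprocessed",
#       data_dir="/tmp/tfds",
#       sentencepiece_model_path="/path/to/spiece.model",
#       sequence_length=4096,
#       stride=2048)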
_CITATION = """
@article{2017arXivtriviaqa,
author = {{Joshi}, Mandar and {Choi}, Eunsol and {Weld},
Daniel and {Zettlemoyer}, Luke},
title = "{triviaqa: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension}",
journal = {arXiv e-prints},
year = 2017,
eid = {arXiv:1705.03551},
pages = {arXiv:1705.03551},
archivePrefix = {arXiv},
eprint = {1705.03551},
}
"""
_DOWNLOAD_URL_TMPL = (
"http://nlp.cs.washington.edu/triviaqa/data/triviaqa-{}.tar.gz")
_TRAIN_FILE_FORMAT = "*-train.json"
_VALIDATION_FILE_FORMAT = "*-dev.json"
_TEST_FILE_FORMAT = "*test-without-answers.json"
_WEB_EVIDENCE_DIR = "evidence/web"
_WIKI_EVIDENCE_DIR = "evidence/wikipedia"
_DESCRIPTION = """\
TriviaQA is a reading comprehension dataset containing over 650K
question-answer-evidence triples. TriviaQA includes 95K question-answer
pairs authored by trivia enthusiasts and independently gathered evidence
documents, six per question on average, that provide high quality distant
supervision for answering the questions.
"""
_RC_DESCRIPTION = """\
Question-answer pairs where all documents for a given question contain the
answer string(s).
"""
_UNFILTERED_DESCRIPTION = """\
110k question-answer pairs for open domain QA where not all documents for a
given question contain the answer string(s). This makes the unfiltered dataset
more appropriate for IR-style QA.
"""
_CONTEXT_ADDENDUM = "Includes context from Wikipedia and search results."
def _web_evidence_dir(tmp_dir):
return tf.io.gfile.glob(os.path.join(tmp_dir, _WEB_EVIDENCE_DIR))
def _wiki_evidence_dir(tmp_dir):
return tf.io.gfile.glob(os.path.join(tmp_dir, _WIKI_EVIDENCE_DIR))
class TriviaQAConfig(tfds.core.BuilderConfig):
"""BuilderConfig for TriviaQA."""
def __init__(self, *, unfiltered=False, exclude_context=False, **kwargs):
"""BuilderConfig for TriviaQA.
Args:
unfiltered: bool, whether to use the unfiltered version of the dataset,
intended for open-domain QA.
exclude_context: bool, whether to exclude Wikipedia and search context for
reduced size.
**kwargs: keyword arguments forwarded to super.
"""
name = "unfiltered" if unfiltered else "rc"
if exclude_context:
name += ".nocontext"
description = _UNFILTERED_DESCRIPTION if unfiltered else _RC_DESCRIPTION
if not exclude_context:
description += _CONTEXT_ADDENDUM
super(TriviaQAConfig, self).__init__(
name=name,
description=description,
version=tfds.core.Version("1.1.1"),
**kwargs)
self.unfiltered = unfiltered
self.exclude_context = exclude_context
class BigBirdTriviaQAConfig(tfds.core.BuilderConfig):
"""BuilderConfig for TriviaQA."""
def __init__(self, **kwargs):
"""BuilderConfig for TriviaQA.
Args:
**kwargs: keyword arguments forwarded to super.
"""
name = "rc_wiki.preprocessed"
description = _RC_DESCRIPTION
super(BigBirdTriviaQAConfig, self).__init__(
name=name,
description=description,
version=tfds.core.Version("1.1.1"),
**kwargs)
self.unfiltered = False
self.exclude_context = False
def configure(self,
sentencepiece_model_path,
sequence_length,
stride,
global_sequence_length=None):
"""Configures additional user-specified arguments."""
self.sentencepiece_model_path = sentencepiece_model_path
self.sequence_length = sequence_length
self.stride = stride
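    # If no global length is given, default to one global token per 16
    # sequence tokens plus a 64-token margin (see the assignment below).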
if global_sequence_length is None and sequence_length is not None:
self.global_sequence_length = sequence_length // 16 + 64
else:
self.global_sequence_length = global_sequence_length
logging.info(
"""
global_sequence_length: %s
sequence_length: %s
stride: %s
sentencepiece_model_path: %s""",
self.global_sequence_length, self.sequence_length,
self.stride, self.sentencepiece_model_path)
def validate(self):
"""Validates that user specifies valid arguments."""
if self.sequence_length is None:
raise ValueError("sequence_length must be specified for BigBird.")
if self.stride is None:
raise ValueError("stride must be specified for BigBird.")
if self.sentencepiece_model_path is None:
raise ValueError(
"sentencepiece_model_path must be specified for BigBird.")
def filter_files_for_big_bird(files):
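  """Keeps only the Wikipedia QA file; the BigBird config uses Wikipedia only."""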
filtered_files = [f for f in files if os.path.basename(f).startswith("wiki")]
  assert len(filtered_files) == 1, "There should only be one Wikipedia file."
return filtered_files
class TriviaQA(tfds.core.BeamBasedBuilder):
"""TriviaQA is a reading comprehension dataset.
  It contains over 650K question-answer-evidence triples.
"""
name = "bigbird_trivia_qa"
BUILDER_CONFIGS = [
BigBirdTriviaQAConfig(),
TriviaQAConfig(unfiltered=False, exclude_context=False), # rc
TriviaQAConfig(unfiltered=False, exclude_context=True), # rc.nocontext
TriviaQAConfig(unfiltered=True, exclude_context=False), # unfiltered
TriviaQAConfig(unfiltered=True, exclude_context=True),
      # unfiltered.nocontext
]
def __init__(self,
*,
sentencepiece_model_path=None,
sequence_length=None,
stride=None,
global_sequence_length=None,
**kwargs):
super(TriviaQA, self).__init__(**kwargs)
if isinstance(self.builder_config, BigBirdTriviaQAConfig):
self.builder_config.configure(
sentencepiece_model_path=sentencepiece_model_path,
sequence_length=sequence_length,
stride=stride,
global_sequence_length=global_sequence_length)
def _info(self):
if isinstance(self.builder_config, BigBirdTriviaQAConfig):
return tfds.core.DatasetInfo(
builder=self,
description=_DESCRIPTION,
supervised_keys=None,
homepage="http://nlp.cs.washington.edu/triviaqa/",
citation=_CITATION,
features=tfds.features.FeaturesDict({
"id": tfds.features.Text(),
"qid": tfds.features.Text(),
"question": tfds.features.Text(),
"context": tfds.features.Text(),
# Sequence features.
"token_ids": tfds.features.Tensor(shape=(None,), dtype=tf.int64),
"token_offsets":
tfds.features.Tensor(shape=(None,), dtype=tf.int64),
"segment_ids":
tfds.features.Tensor(shape=(None,), dtype=tf.int64),
"global_token_ids":
tfds.features.Tensor(shape=(None,), dtype=tf.int64),
# Start and end indices (inclusive).
"answers":
tfds.features.Tensor(shape=(None, 2), dtype=tf.int64),
}))
return tfds.core.DatasetInfo(
builder=self,
description=_DESCRIPTION,
features=tfds.features.FeaturesDict({
"question":
tfds.features.Text(),
"question_id":
tfds.features.Text(),
"question_source":
tfds.features.Text(),
"entity_pages":
tfds.features.Sequence({
"doc_source":
tfds.features.Text(),
"filename":
tfds.features.Text(),
"title":
tfds.features.Text(),
"wiki_context":
tfds.features.Text(),
}),
"search_results":
tfds.features.Sequence({
"description":
tfds.features.Text(),
"filename":
tfds.features.Text(),
"rank":
tf.int32,
"title":
tfds.features.Text(),
"url":
tfds.features.Text(),
"search_context":
tfds.features.Text(),
}),
"answer":
tfds.features.FeaturesDict({
"aliases":
tfds.features.Sequence(tfds.features.Text()),
"normalized_aliases":
tfds.features.Sequence(tfds.features.Text()),
"matched_wiki_entity_name":
tfds.features.Text(),
"normalized_matched_wiki_entity_name":
tfds.features.Text(),
"normalized_value":
tfds.features.Text(),
"type":
tfds.features.Text(),
"value":
tfds.features.Text(),
}),
}),
supervised_keys=None,
homepage="http://nlp.cs.washington.edu/triviaqa/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
cfg = self.builder_config
download_urls = dict()
if not (cfg.unfiltered and cfg.exclude_context):
download_urls["rc"] = _DOWNLOAD_URL_TMPL.format("rc")
if cfg.unfiltered:
download_urls["unfiltered"] = _DOWNLOAD_URL_TMPL.format("unfiltered")
file_paths = dl_manager.download_and_extract(download_urls)
qa_dir = (
os.path.join(file_paths["unfiltered"], "triviaqa-unfiltered")
if cfg.unfiltered else
os.path.join(file_paths["rc"], "qa"))
train_files = tf.io.gfile.glob(os.path.join(qa_dir, _TRAIN_FILE_FORMAT))
valid_files = tf.io.gfile.glob(
os.path.join(qa_dir, _VALIDATION_FILE_FORMAT))
test_files = tf.io.gfile.glob(os.path.join(qa_dir, _TEST_FILE_FORMAT))
if cfg.exclude_context:
web_evidence_dir = None
wiki_evidence_dir = None
else:
web_evidence_dir = os.path.join(file_paths["rc"], _WEB_EVIDENCE_DIR)
wiki_evidence_dir = os.path.join(file_paths["rc"], _WIKI_EVIDENCE_DIR)
if isinstance(cfg, BigBirdTriviaQAConfig):
train_files = filter_files_for_big_bird(train_files)
valid_files = filter_files_for_big_bird(valid_files)
test_files = filter_files_for_big_bird(test_files)
return [
tfds.core.SplitGenerator(
name=tfds.Split.TRAIN,
gen_kwargs={"files": train_files,
"web_dir": web_evidence_dir,
"wiki_dir": wiki_evidence_dir,
"answer": True}),
tfds.core.SplitGenerator(
name=tfds.Split.VALIDATION,
gen_kwargs={"files": valid_files,
"web_dir": web_evidence_dir,
"wiki_dir": wiki_evidence_dir,
"answer": True}),
tfds.core.SplitGenerator(
name=tfds.Split.TEST,
gen_kwargs={"files": test_files,
"web_dir": web_evidence_dir,
"wiki_dir": wiki_evidence_dir,
"answer": False}),
]
def _build_pcollection(self, pipeline, files, web_dir, wiki_dir, answer):
if isinstance(self.builder_config, BigBirdTriviaQAConfig):
self.builder_config.validate()
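      # filter_files_for_big_bird guarantees a single (Wikipedia) QA file,
      # so only files[0] is consumed here.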
question_answers = preprocess.read_question_answers(files[0])
return preprocess.make_pipeline(
pipeline,
question_answers=question_answers,
answer=answer,
max_num_tokens=self.builder_config.sequence_length,
max_num_global_tokens=self.builder_config.global_sequence_length,
stride=self.builder_config.stride,
sentencepiece_model_path=self.builder_config.sentencepiece_model_path,
wikipedia_dir=wiki_dir,
web_dir=web_dir)
parse_example_fn = functools.partial(parse_example,
self.builder_config.exclude_context,
web_dir, wiki_dir)
return (pipeline
| beam.Create(files)
| beam.ParDo(ReadQuestions())
| beam.Reshuffle()
| beam.Map(parse_example_fn))
class ReadQuestions(beam.DoFn):
"""Read questions from JSON."""
def process(self, file):
with tf.io.gfile.GFile(file) as f:
data = json.load(f)
for question in data["Data"]:
example = {"SourceFile": os.path.basename(file)}
example.update(question)
yield example
def parse_example(exclude_context, web_dir, wiki_dir, article):
"""Return a single example from an article JSON record."""
def _strip(collection):
return [item.strip() for item in collection]
if "Answer" in article:
answer = article["Answer"]
answer_dict = {
"aliases":
_strip(answer["Aliases"]),
"normalized_aliases":
_strip(answer["NormalizedAliases"]),
"matched_wiki_entity_name":
answer.get("MatchedWikiEntryName", "").strip(),
"normalized_matched_wiki_entity_name":
answer.get("NormalizedMatchedWikiEntryName", "").strip(),
"normalized_value":
answer["NormalizedValue"].strip(),
"type":
answer["Type"].strip(),
"value":
answer["Value"].strip(),
}
else:
answer_dict = {
"aliases": [],
"normalized_aliases": [],
"matched_wiki_entity_name": "<unk>",
"normalized_matched_wiki_entity_name": "<unk>",
"normalized_value": "<unk>",
"type": "",
"value": "<unk>",
}
if exclude_context:
article["SearchResults"] = []
article["EntityPages"] = []
def _add_context(collection, context_field, file_dir):
"""Adds context from file, or skips if file does not exist."""
new_items = []
for item in collection:
if "Filename" not in item:
logging.info("Missing context 'Filename', skipping.")
continue
new_item = item.copy()
fname = item["Filename"]
try:
with tf.io.gfile.GFile(os.path.join(file_dir, fname)) as f:
new_item[context_field] = f.read()
except (IOError, tf.errors.NotFoundError):
logging.info("File does not exist, skipping: %s", fname)
continue
new_items.append(new_item)
return new_items
def _strip_if_str(v):
return v.strip() if isinstance(v, six.string_types) else v
def _transpose_and_strip_dicts(dicts, field_names):
return {
tfds.core.naming.camelcase_to_snakecase(k):
[_strip_if_str(d[k]) for d in dicts] for k in field_names
}
search_results = _transpose_and_strip_dicts(
_add_context(article.get("SearchResults", []), "SearchContext", web_dir),
["Description", "Filename", "Rank", "Title", "Url", "SearchContext"])
entity_pages = _transpose_and_strip_dicts(
_add_context(article.get("EntityPages", []), "WikiContext", wiki_dir),
["DocSource", "Filename", "Title", "WikiContext"])
question = article["Question"].strip()
question_id = article["QuestionId"]
question_source = article["QuestionSource"].strip()
return f"{article['SourceFile']}_{question_id}", {
"entity_pages": entity_pages,
"search_results": search_results,
"question": question,
"question_id": question_id,
"question_source": question_source,
"answer": answer_dict,
}