debatelab/deepa2

View on GitHub
deepa2/builder/nli_builder.py

Summary

Maintainability
D
1 day
Test Coverage
A
95%
"""Defines Builders for creating DeepA2 datasets from NLI-type data."""

from __future__ import annotations

import dataclasses
import logging
import random
import re
from typing import List, Dict, Union
import uuid

import datasets
import jinja2
import numpy as np
import pandas as pd
from tqdm import tqdm  # type: ignore

from deepa2.builder.core import (
    ArgdownStatement,
    Builder,
    Formalization,
    PreprocessedExample,
    QuotedStatement,
    RawExample,
    DeepA2Item,
)
from deepa2.config import template_dir, package_dir
import deepa2.builder.jinjafilters as jjfilters

# register tqdm's `progress_apply` on pandas objects; it is used
# throughout `ESNLIBuilder.preprocess` to report progress
tqdm.pandas()


@dataclasses.dataclass
class RawESNLIExample(RawExample):
    """
    Datatype describing a raw, unprocessed example
    in an e-snli dataset, possibly batched.

    Every field holds either a single value or, for a batch,
    a list of values.
    """

    # premise sentence; records are grouped by premise in `ESNLIBuilder.preprocess`
    premise: Union[str, List[str]]
    # hypothesis sentence that is related to the premise as stated by `label`
    hypothesis: Union[str, List[str]]
    # integer label; `ESNLIBuilder.preprocess` reads 0 as entailment,
    # 1 as neutral and 2 as contradiction
    label: Union[int, List[int]]
    # up to three explanations; `ESNLIBuilder.preprocess` treats
    # an empty string as a missing explanation
    explanation_1: Union[str, List[str]]
    explanation_2: Union[str, List[str]]
    explanation_3: Union[str, List[str]]


@dataclasses.dataclass
class PreprocessedESNLIExample(PreprocessedExample):
    """
    Datatype describing a preprocessed e-snli example.

    One example merges three raw records that share the same premise
    and carry the labels entailment, neutral and contradiction.
    """

    # premise shared by the three merged raw records
    premise: str
    # hypotheses of the entailment / neutral / contradiction record
    hypothesis_ent: str
    hypothesis_neu: str
    hypothesis_con: str
    # the three explanations of the entailment / neutral / contradiction
    # record; missing ones are replaced by the first explanation
    explanation_ent: List[str]
    explanation_neu: List[str]
    explanation_con: List[str]


@dataclasses.dataclass
class ESNLIItemConfiguration:
    """
    Datatype describing the build-configuration of a single
    DeepA2Item in a DeepA2 datasets created from eSNLI data.

    `formal_scheme` lists the formal statements of the argument
    (two premises, then the conclusion); `placeholders` maps the
    propositional variables used in the scheme to the data fields
    ("{premise}", "{hypothesis}") they stand for.

    If `argdown_err_template_path` is left empty, one of the known
    erroneous argdown templates is drawn at random on initialization;
    an explicitly passed path is kept as is.
    """

    label: str
    argdown_template_path: str = "esnli/argdown_generic.txt"
    argdown_err_template_path: str = ""
    source_paraphrase_template_path: str = "esnli/source_paraphrase.txt"
    scheme_name: str = "modus ponens"
    formal_scheme: List = dataclasses.field(
        default_factory=lambda: ["{p}", "{p} -> {q}", "{q}"]
    )
    placeholders: Dict = dataclasses.field(
        default_factory=lambda: {"p": "{premise}", "q": "{hypothesis}"}
    )

    # maps each supported formal statement to a template which, after the
    # placeholders have been filled in, is a jinja2 expression
    # (not annotated on purpose: class-level constant, not a dataclass field)
    _nl_schemes_dict = {
        "{p}": "{{ {p} | lower }}",
        "{q}": "{{ {q} | lower }}",
        "¬{p}": "{{ {p} | negation }}",
        "¬{q}": "{{ {q} | negation }}",
        "{p} -> {q}": "{{ {p} | conditional({q}) }}",
        "{p} -> ¬{q}": "{{ {p} | conditional({q} | negation) }}",
    }

    # list of erroneous argdown templates
    _argdown_err_templates = [
        "esnli/argdown_err-01.txt",
        "esnli/argdown_err-02.txt",
        "esnli/argdown_err-03.txt",
        "esnli/argdown_err-04.txt",
        "esnli/argdown_err-05.txt",
        "esnli/argdown_err-06.txt",
        "esnli/argdown_err-07.txt",
        "esnli/argdown_err-08.txt",
        "esnli/argdown_err-09.txt",
    ]

    def __post_init__(self):
        # choose and set a random template, but never overwrite a template
        # path that has been passed explicitly to the constructor
        if not self.argdown_err_template_path:
            self.argdown_err_template_path = random.choice(
                self._argdown_err_templates
            )

    @property
    def nl_scheme(self) -> List[str]:
        """
        nl scheme to use with this configuration

        Returns one jinja2 expression (e.g. "{{ premise | lower }}") per
        statement in `formal_scheme`. Raises a KeyError if a statement of
        the formal scheme is not contained in `_nl_schemes_dict`.
        """
        # "{premise}" -> "premise": the braces would clash with str.format
        placeholders = {k: v.strip("{}") for k, v in self.placeholders.items()}
        nl_scheme = [
            self._nl_schemes_dict[s].format(**placeholders) for s in self.formal_scheme
        ]
        nl_scheme = [
            "{" + s + "}" for s in nl_scheme
        ]  # postprocess: re-add {} which got lost in previous format() call
        assert all(
            (s[:2] == "{{" and s[-2:] == "}}") for s in nl_scheme
        )  # jinja2 templates?
        return nl_scheme


class ESNLIBuilder(Builder):
    """
    eSNLI Builder preprocesses and transforms e-SNLI records into DeepA2 items.
    """

    @staticmethod
    def preprocess(dataset: datasets.Dataset) -> datasets.Dataset:
        """
        Merges raw e-SNLI records into one preprocessed example per triple.

        Duplicates, records without any explanation and records whose
        premise occurs less than three times are dropped. The remaining
        records are arranged as consecutive triples that share a premise
        and carry the three labels 0 (entailment), 1 (neutral) and
        2 (contradiction). Each triple is merged into one record with the
        fields of `PreprocessedESNLIExample`.

        Args:
            dataset: raw dataset with the fields of `RawESNLIExample`

        Returns:
            dataset with the fields of `PreprocessedESNLIExample`

        Raises:
            AssertionError: if the records cannot be arranged as triples
                with identical premise and three different labels
        """
        df_esnli = pd.DataFrame(dataset.to_pandas())
        df_esnli = df_esnli.drop_duplicates()
        # count explanations per row
        df_esnli["n_explanations"] = 3 - df_esnli[
            ["explanation_1", "explanation_2", "explanation_3"]
        ].eq("").sum(axis=1)
        # keep records with at least one explanation
        df_esnli = pd.DataFrame(df_esnli[df_esnli.n_explanations.ge(1)])
        # count how frequently premise occurs in the dataset (default = three times)
        counts = df_esnli.groupby(["premise"]).size()
        tqdm.write("Preprocessing 1/8")
        df_esnli["premise_counts"] = df_esnli.premise.progress_apply(
            lambda x: counts[x]
        )
        # drop records whose premise occurs less than 3 times
        df_esnli = pd.DataFrame(df_esnli[df_esnli.premise_counts.ge(3)])

        # we split df in two parts which will be processed separately and are finally merged

        # Split 1
        # get all rows whose premise occurs more than 3 times
        df_esnli_tmp = df_esnli[df_esnli.premise_counts.gt(3)].copy()
        df_esnli_tmp.reset_index(inplace=True)
        # for each premise, what is the minimum number of labels?
        df2 = df_esnli_tmp.groupby(["premise", "label"]).size().unstack()
        # premise/label combinations without any record show up as NaN
        df2.fillna(0, inplace=True)
        logging.debug(
            "Premises associated with no label: %s",
            int(df2.eq(0).any(axis=1).sum()),
        )
        # tells us how many records for each premise and label will go into
        # preprocessed esnli dataset; computed once rather than once per row
        min_counts_by_premise = df2.min(axis=1)
        tqdm.write("Preprocessing 2/8")
        df_esnli_tmp["min_label_counts"] = df_esnli_tmp.premise.progress_apply(
            lambda x: int(min_counts_by_premise[x])
        )
        # make sure that for each premise, we have the same number of records for labels 0,1,2
        tqdm.write("Preprocessing 3/8")
        if len(df_esnli_tmp) > 0:
            df_esnli_tmp = df_esnli_tmp.groupby(
                ["premise", "label"], as_index=False
            ).progress_apply(lambda x: x.iloc[: x.min_label_counts.iloc[0]])

        # reorder row so as to obtain alternating labels
        def reorder_premise_group(premise_group):
            return (
                premise_group.groupby("label")
                .apply(lambda g: g.reset_index(drop=True))
                .sort_index(level=1)
            )

        tqdm.write("Preprocessing 4/8")
        df_esnli_tmp = df_esnli_tmp.groupby(["premise"], as_index=False).progress_apply(
            reorder_premise_group
        )

        # Split 2
        # get all rows whose premise occurs exactly 3 times
        df_esnli_tmp2 = df_esnli[df_esnli.premise_counts.eq(3)].copy()
        # determine premises with incomplete labels (at least one label is missing)
        tqdm.write("Preprocessing 5/8")
        labels_complete = df_esnli_tmp2.groupby(["premise"]).progress_apply(
            lambda g: len(set(g["label"])) == 3
        )
        tqdm.write("Preprocessing 6/8")
        df_esnli_tmp2["complete"] = df_esnli_tmp2.premise.progress_apply(
            lambda x: labels_complete[x]
        )
        # retain only complete records
        df_esnli_tmp2 = df_esnli_tmp2[df_esnli_tmp2.complete]

        # Merge
        # restrict both splits to the raw fields, dropping all helper columns
        raw_columns = [field.name for field in dataclasses.fields(RawESNLIExample)]
        df_esnli_final: pd.DataFrame = pd.concat(
            [
                df_esnli_tmp2[raw_columns],
                df_esnli_tmp[raw_columns],
            ]
        )
        df_esnli_final.reset_index(drop=True, inplace=True)

        # Sanity check
        tqdm.write("Preprocessing 7/8")
        for start in tqdm(range(0, df_esnli_final.shape[0], 3)):
            triple = df_esnli_final.iloc[start : start + 3]
            assert len(set(triple.premise)) == 1
            assert len(set(triple.label)) == 3

        # Merge triples, creating a PreprocessedESNLIExample from each
        tqdm.write("Preprocessing 8/8")

        def merge_triple(triple: pd.DataFrame):
            # the sanity check above guarantees exactly one row per label
            row_ent = triple[triple.label.eq(0)].iloc[0]
            row_neu = triple[triple.label.eq(1)].iloc[0]
            row_con = triple[triple.label.eq(2)].iloc[0]

            def explanations_of(row) -> List[str]:
                explanations = [
                    row.explanation_1,
                    row.explanation_2,
                    row.explanation_3,
                ]
                # fill in explanations in case they are missing
                # we assume that "explanation_1" is given
                for i in [1, 2]:
                    if explanations[i] == "":
                        explanations[i] = explanations[0]
                return explanations

            preprocessed_example = PreprocessedESNLIExample(
                premise=triple.iloc[0].premise,
                hypothesis_ent=row_ent.hypothesis,
                hypothesis_neu=row_neu.hypothesis,
                hypothesis_con=row_con.hypothesis,
                explanation_ent=explanations_of(row_ent),
                explanation_neu=explanations_of(row_neu),
                explanation_con=explanations_of(row_con),
            )

            return pd.Series(dataclasses.asdict(preprocessed_example))

        # consecutive rows 0-2, 3-5, ... form the triples
        df_esnli_final["triple_id"] = np.repeat(
            np.arange(int(len(df_esnli_final) / 3)), 3
        )
        df_esnli_final = df_esnli_final.groupby("triple_id").progress_apply(
            merge_triple
        )
        df_esnli_final.reset_index(drop=True, inplace=True)
        logging.debug(
            "Head of preprocessed esnli dataframe:\n %s", df_esnli_final.head()
        )

        # create dataset
        dataset = datasets.Dataset.from_pandas(df_esnli_final)

        return dataset

    # stores argument configurations used for creating DeepA2 data records:
    # maps a label to the list of configurations `configure_product` cycles
    # through when it creates the records for that label
    # (there are no configurations for neutral examples)
    CONFIGURATIONS = {
        "entailment": [
            ESNLIItemConfiguration(
                label="entailment",
                scheme_name="modus ponens",
                formal_scheme=["{p}", "{p} -> {q}", "{q}"],
                placeholders={"p": "{premise}", "q": "{hypothesis}"},
            ),
        ],
        "contradiction": [
            # premise, hence not hypothesis
            ESNLIItemConfiguration(
                label="contradiction",
                scheme_name="modus ponens",
                formal_scheme=["{p}", "{p} -> ¬{q}", "¬{q}"],
                placeholders={"p": "{premise}", "q": "{hypothesis}"},
            ),
            # hypothesis, hence not premise
            ESNLIItemConfiguration(
                label="contradiction",
                scheme_name="modus ponens",
                formal_scheme=["{p}", "{p} -> ¬{q}", "¬{q}"],
                placeholders={"p": "{hypothesis}", "q": "{premise}"},
            ),
            # premise, hence not hypothesis (conditional starts with hypothesis)
            ESNLIItemConfiguration(
                label="contradiction",
                scheme_name="modus tollens",
                formal_scheme=["{q}", "{p} -> ¬{q}", "¬{p}"],
                placeholders={"q": "{premise}", "p": "{hypothesis}"},
            ),
            # hypothesis, hence not premise (conditional starts with premise)
            ESNLIItemConfiguration(
                label="contradiction",
                scheme_name="modus tollens",
                formal_scheme=["{q}", "{p} -> ¬{q}", "¬{p}"],
                placeholders={"q": "{hypothesis}", "p": "{premise}"},
            ),
        ],
    }

    def __init__(self, **kwargs) -> None:
        """
        Initialize eSNLI Builder.

        Sets up the jinja2 environment that loads templates from
        `template_dir` and registers the custom filters.

        Raises:
            ValueError: if `template_dir` has no "esnli" subdirectory
        """
        super().__init__(**kwargs)
        self._input: PreprocessedESNLIExample

        # check whether template files are accessible
        if not (template_dir / "esnli").exists():
            logging.debug("Package dir: %s", package_dir)
            logging.debug("Resolve template dir: %s", template_dir)
            logging.debug("List template dir: %s", list(template_dir.glob("*")))
            err_m = f'No "esnli" subdirectory in template_dir {template_dir.resolve()}'
            raise ValueError(err_m)
        # The builder renders plain text, not markup. `select_autoescape()`
        # enables autoescaping for templates created with `from_string` by
        # default, which would turn e.g. an apostrophe in a premise into
        # "&#39;" in every rendered statement -- hence it is switched off
        # for string templates (html/xml template files are still escaped).
        self._env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(template_dir),
            autoescape=jinja2.select_autoescape(default_for_string=False),
        )

        # register filters
        self._env.filters["lowerall"] = jjfilters.lowerall
        self._env.filters["negation"] = jjfilters.negation
        self._env.filters["conditional"] = jjfilters.conditional

    @property
    def input(self) -> PreprocessedESNLIExample:
        """The preprocessed e-SNLI example stored by the latest `set_input` call."""
        return self._input

    def set_input(self, batched_input: Dict[str, List]) -> None:
        """
        Stores the example that `PreprocessedESNLIExample.from_batch`
        creates from `batched_input` as the builder's current input.
        """
        # NOTE(review): `from_batch` is defined outside this file; presumably it
        # unpacks a batch that contains a single example -- confirm
        self._input = PreprocessedESNLIExample.from_batch(batched_input)

    def configure_product(self) -> None:
        """
        Appends one empty, configured DeepA2Item to the product for every
        combination of label, argument mask and distractor mask.

        The configurations stored for a label are assigned in rotation,
        driven by a counter that runs over all combinations.
        """
        # argument_mask specifies which part of argument will be dropped in source text;
        # distractor_mask specifies which distractors will be dropped in source text
        combinations = (
            (label, argument_mask, distractor_mask)
            for label in ["entailment", "contradiction"]
            for argument_mask in [[1, 1, 1], [0, 1, 1], [1, 0, 1], [1, 1, 0]]
            for distractor_mask in [[1, 1], [1, 0], [0, 1], [0, 0]]
        )
        for counter, (label, argument_mask, distractor_mask) in enumerate(
            combinations
        ):
            label_configs = self.CONFIGURATIONS[label]
            chosen_config = label_configs[counter % len(label_configs)]
            self._product.append(
                DeepA2Item(
                    metadata=[
                        ("id", str(uuid.uuid4())),
                        ("config", chosen_config),
                        ("argument_mask", argument_mask),
                        ("distractor_mask", distractor_mask),
                        ("label", label),
                    ]
                )
            )

    def produce_da2item(self) -> None:
        """Populates every record of the product, in product order."""
        position = 0
        while position < len(self._product):
            self.populate_record(position)
            position += 1

    def _map_data_to_roles(self, record: DeepA2Item, idx: int = 0) -> Dict:
        """
        Maps the current input to the argumentative roles of `record`.

        For an entailment record, the entailed hypothesis and its
        explanation make up the argument and a contradiction explanation
        serves as distractor; for any other label it is the other way
        round. The neutral hypothesis is always a distractor. `idx`
        selects, modulo 3, which of the three explanations is used.
        """
        example = self.input
        slot = idx % 3

        if dict(record.metadata)["label"] == "entailment":
            hypothesis = example.hypothesis_ent
            explanations = example.explanation_ent
            distracting_explanations = example.explanation_con
        else:  # label == contradiction
            hypothesis = example.hypothesis_con
            explanations = example.explanation_con
            distracting_explanations = example.explanation_ent

        return {
            "premise": example.premise,
            "hypothesis": hypothesis,
            # used in source text
            "premise_cond": explanations[slot],
            # used in source text
            "distractors": [example.hypothesis_neu, distracting_explanations[slot]],
        }

    def populate_record(  # pylint: disable=too-many-statements, too-many-locals
        self, idx: int
    ) -> None:
        """
        Populates the record at product index `idx`.

        Reads the record's metadata (`config`, `argument_mask`,
        `distractor_mask`) and sets, in this order: the argdown
        reconstruction and the erroneous argdown, premises and conclusion,
        their formalizations and the placeholder substitutions, the source
        text with reasons and conjectures, the gist and the source
        paraphrase. The distractors contained in the source text are
        appended to the record's metadata.
        """

        record = self._product[idx]
        metadata = dict(record.metadata)
        config = metadata["config"]
        argument_mask = metadata["argument_mask"]

        # Initialize: mapping input data to argumentative roles
        data = self._map_data_to_roles(record=record, idx=idx)

        # Step 1: construct argdown
        # argument list: premise 1, premise 2 (conditional), conclusion
        argument_list = [
            self._env.from_string(t).render(data) for t in config.nl_scheme
        ]
        argdown_args = {
            "premise1": argument_list[0],
            "premise2": argument_list[1],
            "conclusion": argument_list[-1],
            "scheme": config.scheme_name,
        }
        # argdown
        argdown_template = self._env.get_template(config.argdown_template_path)
        record.argdown_reconstruction = argdown_template.render(**argdown_args)
        # erroneous argdown
        argdown_err_template = self._env.get_template(config.argdown_err_template_path)
        record.erroneous_argdown = argdown_err_template.render(**argdown_args)

        # Step 2: premises and conclusion lists
        # a statement is explicit iff it is not masked in the source text
        statements = [
            ArgdownStatement(
                text=argument_list[i],
                explicit=bool(argument_mask[i]),
                ref_reco=i + 1,
            )
            for i in range(3)
        ]
        record.premises = statements[:2]
        record.conclusion = statements[2:]

        # Step 3: formalizations
        def formalize(position: int) -> Formalization:
            """formalizes the statement at `position` of the formal scheme"""
            form = re.sub(r"{|}", "", config.formal_scheme[position])  # remove brackets
            form = form.replace("¬", " not ")  # replace negator
            # collapse the double space after "->" and drop the leading space
            # which a negator at the start of the statement leaves behind
            form = form.replace("  ", " ").strip()
            return Formalization(form=form, ref_reco=position + 1)

        record.premises_formalized = [formalize(0), formalize(1)]
        record.conclusion_formalized = [formalize(2)]
        # placeholders
        record.misc_placeholders = list(config.placeholders)
        record.plchd_substitutions = [
            (k, v.format(**data)) for k, v in config.placeholders.items()
        ]

        # Step 4: source text, reasons, conjectures

        # 4.a) compile list with all sentences in source text
        source_text_list = []
        # add distractors
        for i, sentence in enumerate(data["distractors"]):
            if metadata["distractor_mask"][i]:
                source_text_list.append(["distractor", sentence])
        # add reasons
        argument_list2 = argument_list.copy()
        # replace conditional: the source text quotes the explanation instead
        argument_list2[1] = data["premise_cond"]
        for i, sentence in enumerate(argument_list2[:-1]):
            if argument_mask[i]:
                source_text_list.append(
                    [
                        "reason",
                        QuotedStatement(text=sentence, ref_reco=i + 1, starts_at=-1),
                    ]
                )
        # add conclusion
        if argument_mask[2]:
            source_text_list.append(
                [
                    "conjecture",
                    QuotedStatement(text=argument_list2[2], ref_reco=3, starts_at=-1),
                ]
            )
        # shuffle
        random.shuffle(source_text_list)

        # 4.b) walk through list and compile source text as well as reason, conclusions, distractors
        source_text = ""
        record.reasons = []
        record.conjectures = []
        distractors = []
        for kind, content in source_text_list:
            # offset at which the current sentence starts in the source text
            pointer = len(source_text)
            if kind == "distractor":
                source_text += content
                distractors.append(content)
            elif kind in ["reason", "conjecture"]:
                source_text += content.text
                content.starts_at = pointer
                if kind == "reason":
                    record.reasons.append(content)
                else:
                    record.conjectures.append(content)
            source_text += " "

        record.source_text = source_text.strip(" ")
        record.metadata.append(("distractors", distractors))

        # Step 5: gist, source_paraphrase, context, title
        # use premise2 as gist
        record.gist = data["premise_cond"]
        # source paraphrase
        sp_template = self._env.get_template(config.source_paraphrase_template_path)
        record.source_paraphrase = sp_template.render(
            premises=[d.text for d in record.reasons],
            conclusion=[d.text for d in record.conjectures],
        )

        # title, context
        #   - so far missing
        #   - so far missing

    def postprocess_da2item(self) -> None:
        """Does nothing: this builder performs no postprocessing."""
        pass

    def add_metadata_da2item(self) -> None:
        """Does nothing: this builder adds no further metadata in this step."""
        pass