debatelab/deepa2
deepa2/builder/arg_kp_builder.py
"""Defines Builder for creating DeepA2 datasets from IBM Argument Key Point data."""

from __future__ import annotations

import dataclasses
import logging
import pathlib
import random
from typing import List, Dict, Union

import datasets

import pandas as pd
from tqdm import tqdm  # type: ignore

from deepa2 import (
    ArgdownStatement,
    Formalization,
    DeepA2Item,
)
from deepa2.builder import (
    RawExample,
    PreprocessedExample,
    DatasetLoader,
    Builder,
    PipedBuilder,
    Pipeline,
    Transformer,
)

tqdm.pandas()  # enables progress_apply on pandas apply/groupby calls


@dataclasses.dataclass
class RawArgKPExample(RawExample):  # pylint: disable=too-many-instance-attributes
    """
    Datatype describing a raw, unprocessed example
    in the ArgKP dataset as returned by dataset loader,
    possibly batched.
    """

    argument: Union[str, List[str]]
    topic: Union[str, List[str]]
    key_point: Union[str, List[str]]
    stance: Union[str, List[str]]  # pylint: disable=invalid-name


@dataclasses.dataclass
class PreprocessedArgKPExample(
    PreprocessedExample
):  # pylint: disable=too-many-instance-attributes
    """
    Datatype describing a preprocessed ArgKP example.
    """

    topic: str
    stance: int
    argument: str
    key_point: str


class ArgKPLoader(DatasetLoader):  # pylint: disable=too-few-public-methods
    """loads ArgKP raw data from multiple files"""

    _PATH_SPLITS = {
        "train": "kpm_data",
        "dev": "kpm_data",
        "test": "test_data",
    }

    _FILE_TEMPLATES = {
        "arguments": "arguments_{split}.csv",
        "key_points": "key_points_{split}.csv",
        "labels": "labels_{split}.csv",
    }
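    # e.g. split "train" resolves to kpm_data/arguments_train.csv,
    # kpm_data/key_points_train.csv and kpm_data/labels_train.csv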

    def __init__(self, **kwargs) -> None:
        super().__init__()
        if "path" not in kwargs:
            logging.info(
                "No ArgKP root path specified (via --path); assuming '.'"
            )
        self._sourcepath: str = kwargs.get("path", ".")

    def load_dataset(self) -> datasets.DatasetDict:
        """loads ArgQP dat from root directory by merging multiple csv files"""
        splits_mapping = {"train": "train", "dev": "validation", "test": "test"}
        dataset_dict = {}
        for split_key, target_key in splits_mapping.items():
            logging.info("Loading %s split", split_key)
            dataset_dict[target_key] = datasets.Dataset.from_pandas(
                self.load_split(split_key),
                preserve_index=False,
            )
        dataset_dict = datasets.DatasetDict(dataset_dict)
        return dataset_dict
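
    # A minimal usage sketch (assuming the ArgKP csv files are laid out
    # under <path>/kpm_data and <path>/test_data, as per _PATH_SPLITS;
    # the path below is hypothetical):
    #
    #   loader = ArgKPLoader(path="data/ArgKP")
    #   dataset_dict = loader.load_dataset()
    #   dataset_dict["validation"]  # datasets.Dataset built from the dev split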

    def load_split(self, split: str) -> pd.DataFrame:
        """loads split as dataframe"""
        data_path = pathlib.Path(self._sourcepath, self._PATH_SPLITS[split])
        files = {
            k: pathlib.Path(v.format(split=split))
            for k, v in self._FILE_TEMPLATES.items()
        }
        # load arguments
        df_arguments = pd.read_csv(data_path / files["arguments"])
        logging.debug("Loadad arguments dataframe: %s", df_arguments.head())
        # load key_points
        df_key_points = pd.read_csv(data_path / files["key_points"])
        logging.debug("Loadad keypoints dataframe: %s", df_key_points.head())
        # load labels, which match arguments with keypoints
        df_labels = pd.read_csv(data_path / files["labels"])
        logging.debug("Loadad labels dataframe: %s", df_labels.head())
        # keep only matching labels (label == 1)
        df_labels_m = df_labels[df_labels.label == 1]
        logging.debug("Matching labels dataframe: %s", df_labels_m.head())
        # matched arguments
        df_arguments_m = df_arguments[
            df_arguments.arg_id.isin(df_labels_m.arg_id.tolist())
        ]
        if len(df_arguments) != len(df_arguments_m):
            logging.info(
                "Removing %s arguments without keypoints.",
                len(df_arguments) - len(df_arguments_m),
            )
            df_arguments = df_arguments_m

        def append_keypoints(group: pd.DataFrame) -> pd.DataFrame:
            """returns df with all keypoints for a given argument group"""
            # collect the key points matched to this argument
            arg_id = group.arg_id.iloc[0]
            key_point_ids = df_labels_m[df_labels_m.arg_id == arg_id][
                "key_point_id"
            ].tolist()
            df_key_points_sel = df_key_points[
                df_key_points.key_point_id.isin(key_point_ids)
            ].copy()  # copy to avoid mutating a view of df_key_points
            df_key_points_sel["argument"] = group.argument.iloc[0]
            return df_key_points_sel

        # append keypoints to each argument
        df_merged = pd.DataFrame(
            df_arguments.groupby("arg_id", as_index=False).progress_apply(
                append_keypoints
            )
        )
        # drop unnecessary columns & reindex
        df_merged = df_merged[["argument", "topic", "key_point", "stance"]]
        df_merged = df_merged.reset_index(drop=True)

        return df_merged
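
    # Note: the merged frame has one row per matched (argument, key_point)
    # pair, so an argument matched to several key points occurs in several rows.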


class AddContext(Transformer):
    """adds context and gist"""

    _TEMPLATE_STRINGS = {
        "context": "{{ topic_str }}? {{ stance_str }}",
    }
    PRO_EXPRS = ["yes!", "absolutely!", "I agree!"]
    CON_EXPRS = ["no!", "not at all!", "I disagree!"]

    def __init__(self, builder: Builder) -> None:
        super().__init__(builder)
        self._random = random.Random()

    def _generate_context(
        self,
        stance: int,
        topic: str,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """generates context"""
        topic_str = topic
        stance_exprs = self.PRO_EXPRS if stance == 1 else self.CON_EXPRS
        stance_str = self._random.choice(stance_exprs)
        context = self.templates["context"].render(
            topic_str=topic_str,
            stance_str=stance_str,
        )
        return context

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgKPExample
    ) -> DeepA2Item:
        # print("da2_item: %s" % da2_item)
        # print("prep_example: %s" % prep_example)
        context = self._generate_context(**dataclasses.asdict(prep_example))
        da2_item.context = context
        da2_item.gist = prep_example.key_point
        return da2_item
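
    # Illustrative result (hypothetical example): topic "we should ban X"
    # with stance 1 yields a context such as "we should ban X? I agree!",
    # where the stance expression is drawn at random from PRO_EXPRS.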


class AddSourceText(Transformer):
    """add source_text"""

    _TEMPLATE_STRINGS = {
        "source_text": "{{ topic_str }}? {{ argument_str }}",
    }

    def _generate_source(
        self,
        argument: str,
        topic: str,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """generates source"""
        source = self.templates["source_text"].render(
            topic_str=topic,
            argument_str=argument,
        )
        return source

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgKPExample
    ) -> DeepA2Item:
        # print("da2_item: %s" % da2_item)
        # print("prep_example: %s" % prep_example)
        source = self._generate_source(**dataclasses.asdict(prep_example))
        da2_item.source_text = source
        return da2_item
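
    # Illustrative result (hypothetical example): topic "we should ban X"
    # and argument "banning x saves lives." yield the source text
    # "we should ban X? banning x saves lives."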


class AddArgument(Transformer):
    """adds premisses, conclusion, argdown, formalizations"""

    _TEMPLATE_STRINGS = {
        "negation": "it is not the case that {{ topic_str }}.",
        "connecting_premise": "if {{ key_point_str }} then {{ conclusion_str }}.",
        "argdown": "(1) {{ main_premise }} (2) {{ connecting_premise }} "
        "-- with modus ponens from (1) (2) -- (3) {{ conclusion }}",
    }

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgKPExample
    ) -> DeepA2Item:
        # use topic (possibly negated if stance == -1) as conclusion
        topic_str = prep_example.topic.lower().strip(". ")
        if prep_example.stance == 1:
            ctext = f"{topic_str}."
        else:
            ctext = self.templates["negation"].render(topic_str=topic_str)
        conclusion = [ArgdownStatement(text=ctext, ref_reco=3)]

        # use keypoint as first premise, add connecting premise
        premises = [
            ArgdownStatement(text=prep_example.key_point, ref_reco=1),
            ArgdownStatement(
                text=self.templates["connecting_premise"].render(
                    key_point_str=prep_example.key_point.strip("."),
                    conclusion_str=ctext.strip("."),
                ),
                ref_reco=2,
            ),
        ]

        # add argdown
        argdown = self.templates["argdown"].render(
            main_premise=premises[0].text,
            connecting_premise=premises[1].text,
            conclusion=conclusion[0].text,
        )

        da2_item.conclusion = conclusion
        da2_item.premises = premises
        da2_item.argdown_reconstruction = argdown

        return da2_item
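
    # Illustrative output for the hypothetical key point "subsidies help
    # farmers." and topic "we should subsidize agriculture" (stance 1):
    #   (1) subsidies help farmers. (2) if subsidies help farmers then we
    #   should subsidize agriculture. -- with modus ponens from (1) (2) --
    #   (3) we should subsidize agriculture.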


class AddFormalization(Transformer):
    """adds premisses, conclusion, argdown, formalizations"""

    _TEMPLATE_STRINGS = {
        "negation": "not {{ s }}",
        "connecting_premise": "p -> {{ s }}",
    }

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgKPExample
    ) -> DeepA2Item:
        # formalize main premise
        p = "p"  # pylint: disable=invalid-name
        # formalize conclusion
        q = "q"  # pylint: disable=invalid-name
        if prep_example.stance == -1:
            q = self.templates["negation"].render(s=q)  # pylint: disable=invalid-name

        da2_item.premises_formalized = [
            Formalization(form=p, ref_reco=1),
            Formalization(
                form=self.templates["connecting_premise"].render(s=q),
                ref_reco=2,
            ),
        ]
        da2_item.conclusion_formalized = [Formalization(form=q, ref_reco=3)]
        da2_item.misc_placeholders = ["p", "q"]
        da2_item.plchd_substitutions = [
            ("p", prep_example.key_point.strip(".")),
            ("q", prep_example.topic.strip(".")),
        ]

        return da2_item
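
    # E.g. with stance == -1 the scheme is formalized as:
    #   premises:   p (ref 1),  p -> not q (ref 2)
    #   conclusion: not q (ref 3)
    # where p and q are substituted by key point and (unnegated) topic.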


class ArgKPBuilder(PipedBuilder):
    """builds ArgKP dataset"""

    @staticmethod
    def preprocess(dataset: datasets.Dataset) -> datasets.Dataset:
        """normalizes raw ArgKP records (casing, punctuation, dtypes)"""
        df_raw = pd.DataFrame(dataset.to_pandas())
        # stance
        df_raw = df_raw.astype({"stance": "int32"})
        # arguments
        df_raw.argument = df_raw.argument.str.strip(" .").str.lower()
        df_raw.argument = df_raw.argument.apply(
            lambda s: s if s.endswith(("?", "!")) else s + "."
        )
        # key points
        df_raw.key_point = df_raw.key_point.str.strip(" .").str.lower()
        df_raw.key_point = df_raw.key_point.apply(
            lambda s: s if s.endswith(("?", "!")) else s + "."
        )
        # topics
        df_raw["topic"] = df_raw["topic"].str.lower()

        # remove spare columns
        field_names = [
            field.name for field in dataclasses.fields(PreprocessedArgKPExample)
        ]
        df_raw = df_raw[field_names]

        # return as Dataset
        return datasets.Dataset.from_pandas(df=df_raw, preserve_index=False)
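
    # E.g. a raw argument "Subsidies help farmers " (hypothetical value)
    # is normalized to "subsidies help farmers." by the steps above.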

    def _construct_pipeline(self, **kwargs) -> Pipeline:
        """chains the transformers, which are applied to each item in this order"""
        pipeline = Pipeline(
            [
                AddSourceText(self),
                AddContext(self),
                AddArgument(self),
                AddFormalization(self),
            ]
        )
        return pipeline

    def set_input(self, batched_input: Dict[str, List]) -> None:
        prep_example = PreprocessedArgKPExample.from_batch(batched_input)
        self._input = prep_example
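

# End-to-end sketch (hypothetical wiring; the actual entry point is the
# deepa2 CLI / PipedBuilder machinery):
#
#   raw = ArgKPLoader(path="data/ArgKP").load_dataset()
#   prep = ArgKPBuilder.preprocess(raw["train"])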