debatelab/deepa2

View on GitHub
deepa2/builder/arg_q_builder.py

Summary

Maintainability
A
45 mins
Test Coverage
D
66%
"""Defines Builder for creating DeepA2 datasets from IBM ArgQ data."""

from __future__ import annotations

import dataclasses
import logging
import pathlib
import random
import shutil
import sys
from typing import List, Dict, Union

import datasets

# import jinja2
import pandas as pd
from tqdm import tqdm  # type: ignore

from deepa2 import (
    QuotedStatement,
    DeepA2Item,
)
from deepa2.builder import (
    RawExample,
    PreprocessedExample,
    DatasetLoader,
    Builder,
    PipedBuilder,
    Pipeline,
    Transformer,
)
from deepa2.config import (
    data_dir,
)

# Registers tqdm's pandas integration (e.g. `progress_apply`) on pandas objects.
# NOTE(review): no `progress_*` method is called in the visible part of this
# module — confirm this import-time side effect is still needed.
tqdm.pandas()


@dataclasses.dataclass
class RawArgQExample(RawExample):  # pylint: disable=too-many-instance-attributes
    """
    Datatype describing a raw, unprocessed example
    in the Arg_Q_Rank dataset, possibly batched.

    Each field holds either a single value (one example) or a list of
    values (a batch). The fields are the columns that `ArgQLoader` keeps
    from the ArgQ csv file.
    """

    # text of the argument
    argument: Union[str, List[str]]
    # topic the argument belongs to
    topic: Union[str, List[str]]
    # split label as found in the csv: "train", "dev" or "test"
    # (named after the csv column, hence shadowing the builtin `set`)
    set: Union[str, List[str]]
    # NOTE(review): annotated as str, but `ArgQBuilder.preprocess` negates and
    # casts this column, i.e. treats it as numeric (presumably +1 / -1) — confirm.
    stance_WA: Union[str, List[str]]  # pylint: disable=invalid-name


@dataclasses.dataclass
class PreprocessedArgQExample(
    PreprocessedExample
):  # pylint: disable=too-many-instance-attributes
    """
    Datatype describing a preprocessed ArgQ example.

    `ArgQBuilder.preprocess` restricts its output columns to exactly
    these field names; `ArgQBuilder.set_input` builds instances from
    batches of that output via `from_batch`.
    """

    # lower-cased topic
    topic: str
    # copied from the raw `stance_WA` column; 1 is rendered as a "pro" stance
    stance: int
    # the example's own, normalized argument
    argument_stance_conf: str
    # a randomly drawn argument on the same topic with the opposite stance
    argument_stance_nonconf: str


class ArgQLoader(DatasetLoader):  # pylint: disable=too-few-public-methods
    """Loads the raw ArgQ csv file as a `datasets.DatasetDict`.

    Keyword Args:
        path: Location of the ArgQ csv file. When omitted or ``None``,
            ``./arg_quality_rank_30k.csv`` is used.
    """

    _ARG_Q_FILENAME = "arg_quality_rank_30k.csv"

    def __init__(self, **kwargs) -> None:
        super().__init__()
        path = kwargs.get("path")
        # Treat an explicit `path=None` (e.g. an unset CLI option) like a
        # missing path instead of failing later in `pathlib.Path(None)`.
        if path is None:
            logging.info("No ArgQ file-path specified (via --path), assuming path is .")
            path = f"./{self._ARG_Q_FILENAME}"
        self._sourcepath: str = str(path)

    def load_dataset(self) -> datasets.DatasetDict:
        """Copies the ArgQ csv into `data_dir` and splits it by its `set` column.

        The csv is copied to ``<data_dir>/raw/arg_q/all.csv`` and read back
        from there. Only the columns ``argument``, ``topic``, ``set`` and
        ``stance_WA`` are kept. Rows with ``set`` equal to ``train``, ``dev``
        and ``test`` become the ``train``, ``validation`` and ``test`` splits.

        Returns:
            The three splits as a `datasets.DatasetDict`.

        Exits the process with status -1 if no file exists at the source path.
        """
        # copy csv
        arg_q_source = pathlib.Path(self._sourcepath)
        if not arg_q_source.is_file():
            logging.error("No ArgQ file at %s found, exiting.", str(arg_q_source))
            sys.exit(-1)

        arg_q_raw = pathlib.Path(data_dir, "raw", "arg_q", "all.csv")
        arg_q_raw.parent.mkdir(parents=True, exist_ok=True)

        try:
            shutil.copy(arg_q_source, arg_q_raw)
        except shutil.SameFileError:
            # source already is the target (e.g. re-running on the copied
            # file); nothing to copy
            logging.info("ArgQ file %s already in place.", str(arg_q_raw))
        else:
            logging.info(
                "Copied ArgQ file %s to %s.", str(arg_q_source), str(arg_q_raw)
            )

        # load all argq data from disk and initialize dataset
        df_arg_q = pd.read_csv(arg_q_raw)
        df_arg_q = df_arg_q[["argument", "topic", "set", "stance_WA"]]

        # csv `set` value -> split name in the returned DatasetDict
        splits_mapping = {"train": "train", "dev": "validation", "test": "test"}
        return datasets.DatasetDict(
            {
                target_key: datasets.Dataset.from_pandas(
                    df_arg_q[df_arg_q["set"] == split_key],
                    preserve_index=False,
                )
                for split_key, target_key in splits_mapping.items()
            }
        )


class AddSourceText(Transformer):
    """Adds a source text to a DeepA2 item and records its conjecture.

    The conjecture is the topic followed by "?" and a randomly chosen pro or
    con expression. The source text is the conjecture followed by the
    example's (non-empty) arguments in random order.
    """

    _TEMPLATE_STRINGS = {
        "conjecture": "{{ topic_str }}? {{ stance_str }}",
        "source_text": '{{ conjecture }} {{ args | join(" ") }}',
    }
    PRO_EXPRS = ["yes!", "absolutely!", "I agree!"]
    CON_EXPRS = ["no!", "not at all!", "I disagree!"]

    def __init__(self, builder: Builder) -> None:
        super().__init__(builder)
        self._random = random.Random()

    def _generate_source(
        self,
        stance: int,
        topic: str,
        argument_stance_conf: str,
        argument_stance_nonconf: str,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Generates a source text and the conjecture it starts with.

        Args:
            stance: 1 selects a pro expression, any other value a con expression.
            topic: topic rendered as a question in the conjecture.
            argument_stance_conf: argument with the example's own stance.
            argument_stance_nonconf: argument with the opposite stance; may be
                empty, in which case it is left out of the source text.
            **kwargs: further fields of the preprocessed example (ignored).

        Returns:
            Tuple ``(source_text, conjecture)``.
        """
        stance_exprs = self.PRO_EXPRS if stance == 1 else self.CON_EXPRS
        stance_str = self._random.choice(stance_exprs)
        # Drop empty arguments so that joining them does not leave doubled or
        # trailing whitespace in the source text; shuffle the remaining ones.
        candidates = [
            arg for arg in (argument_stance_conf, argument_stance_nonconf) if arg
        ]
        args = self._random.sample(candidates, k=len(candidates))

        conjecture = self.templates["conjecture"].render(
            topic_str=topic,
            stance_str=stance_str,
        )
        source = self.templates["source_text"].render(
            conjecture=conjecture,
            args=args,
        )
        return source, conjecture

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgQExample
    ) -> DeepA2Item:
        """Sets `da2_item.source_text` and appends the conjecture to its metadata."""
        source, conjecture = self._generate_source(**dataclasses.asdict(prep_example))
        da2_item.source_text = source
        # kept in metadata so later transformers (AddConjectures) can reuse it
        da2_item.metadata.append(("conjecture", conjecture))

        return da2_item


class AddReasons(Transformer):
    """Adds the example's own argument as the single reason of a DeepA2 item."""

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgQExample
    ) -> DeepA2Item:
        # The argument is quoted without surrounding spaces and periods (such
        # as the final "." that preprocessing may append).
        # NOTE(review): `ref_reco=1` presumably points at the first statement
        # of the argument reconstruction — confirm against DeepA2Item docs.
        da2_item.reasons = [
            QuotedStatement(
                text=prep_example.argument_stance_conf.strip(" ."),
                ref_reco=1,
            )
        ]
        return da2_item


class AddConjectures(Transformer):
    """Adds the conjecture recorded in an item's metadata as its conjecture."""

    def transform(  # type: ignore[override]
        self, da2_item: DeepA2Item, prep_example: PreprocessedArgQExample
    ) -> DeepA2Item:
        # `metadata` holds (key, value) pairs; the "conjecture" entry is
        # appended by AddSourceText, so that transformer must run first.
        # If the key occurs more than once, the last entry wins.
        metadata = dict(da2_item.metadata)
        conjecture_text = metadata["conjecture"]
        # NOTE(review): `ref_reco=3` presumably points at the conclusion of the
        # argument reconstruction — confirm.
        da2_item.conjectures = [QuotedStatement(text=conjecture_text, ref_reco=3)]
        return da2_item


class ArgQBuilder(PipedBuilder):
    """Builds a DeepA2 dataset from ArgQ data."""

    @staticmethod
    def preprocess(dataset: datasets.Dataset) -> datasets.Dataset:
        """Normalizes arguments and pairs each with an opposing argument.

        Each argument is stripped of surrounding spaces and periods, lower-cased
        and terminated with "." unless it ends with "?" or "!". For every row,
        an argument on the same topic with the opposite stance is drawn at
        random, or ``""`` if there is none. Topics are lower-cased.

        Args:
            dataset: raw ArgQ split with (at least) the columns ``argument``,
                ``topic`` and ``stance_WA``.

        Returns:
            A dataset whose columns are exactly the fields of
            `PreprocessedArgQExample`.
        """
        df_raw = pd.DataFrame(dataset.to_pandas())
        # `astype` returns a new frame, so the result must be assigned
        df_raw = df_raw.astype({"stance_WA": "int32"})

        def add_final_punctuation(arg: str) -> str:
            # keep arguments that are empty after stripping empty (instead of
            # failing on `arg[-1]`)
            if not arg or arg.endswith(("?", "!")):
                return arg
            return arg + "."

        # preprocess arguments
        df_raw["argument"] = (
            df_raw["argument"].str.strip(" .").str.lower().map(add_final_punctuation)
        )

        # to each (topic, stance) pair, assign the list of corresponding arguments
        args_by_topicstance = {
            key: group.tolist()
            for key, group in df_raw.groupby(["topic", "stance_WA"])["argument"]
        }

        df_raw["argument_stance_conf"] = df_raw["argument"]

        def sample_ca(topic, stance) -> str:
            # select all args with same topic but OPPOSITE stance; a topic may
            # have no argument with the opposite stance at all
            cas = args_by_topicstance.get((topic, -stance), [])
            if not cas:
                return ""
            return random.choice(cas)

        # a list comprehension (unlike `DataFrame.apply(axis=1)`) also yields a
        # single column for an empty frame
        df_raw["argument_stance_nonconf"] = [
            sample_ca(topic, stance)
            for topic, stance in zip(df_raw["topic"], df_raw["stance_WA"])
        ]
        df_raw["stance"] = df_raw["stance_WA"]
        df_raw["topic"] = df_raw["topic"].str.lower()

        # remove spare columns
        field_names = [
            field.name for field in dataclasses.fields(PreprocessedArgQExample)
        ]
        df_raw = df_raw[field_names]

        # return as Dataset
        return datasets.Dataset.from_pandas(df=df_raw, preserve_index=False)

    def _construct_pipeline(self, **kwargs) -> Pipeline:
        """Returns the transformer pipeline: source text, reasons, conjectures.

        AddConjectures must run after AddSourceText, which stores the
        conjecture in the item's metadata.
        """
        pipeline = Pipeline(
            [
                AddSourceText(self),
                AddReasons(self),
                AddConjectures(self),
            ]
        )
        return pipeline

    def set_input(self, batched_input: Dict[str, List]) -> None:
        """Stores a batch of preprocessed examples as the builder's input."""
        prep_example = PreprocessedArgQExample.from_batch(batched_input)
        self._input = prep_example