debatelab/deepa2

View on GitHub
deepa2/datastructures.py

Summary

Maintainability
A
3 hrs
Test Coverage
A
95%
"""Basic data structures"""

from abc import ABC
import dataclasses
import logging
from typing import Any, Callable, List, Dict, Tuple, Optional

import datasets
from datasets import Value, Sequence


def _da2_value(dtype: str) -> Value:
    """Returns a fresh scalar feature of the given dtype."""
    return Value(dtype=dtype, id=None)


def _da2_string_list():
    """Returns a variable-length sequence of strings."""
    return Sequence(feature=_da2_value("string"), length=-1, id=None)


def _da2_string_table():
    """Returns a variable-length sequence of variable-length string sequences."""
    return Sequence(feature=_da2_string_list(), length=-1, id=None)


def _da2_argdown_statements() -> list:
    """Returns the list-of-records schema with `explicit`, `ref_reco` and `text`."""
    return [
        {
            "explicit": _da2_value("bool"),
            "ref_reco": _da2_value("int64"),
            "text": _da2_value("string"),
        }
    ]


def _da2_formalizations() -> list:
    """Returns the list-of-records schema with `form` and `ref_reco`."""
    return [
        {
            "form": _da2_value("string"),
            "ref_reco": _da2_value("int64"),
        }
    ]


def _da2_quotes() -> list:
    """Returns the list-of-records schema with `ref_reco`, `starts_at` and `text`."""
    return [
        {
            "ref_reco": _da2_value("int64"),
            "starts_at": _da2_value("int64"),
            "text": _da2_value("string"),
        }
    ]


# Features of a DeepA2 dataset. The column names coincide with the field
# names of `DeepA2Item`; columns are listed in alphabetical order. Every
# helper call builds new feature objects, so no feature instance is shared
# between two columns.
DA2_FEATURES = datasets.Features(
    {
        "argdown_reconstruction": _da2_value("string"),
        "conclusion": _da2_argdown_statements(),
        "conclusion_formalized": _da2_formalizations(),
        "conjectures": _da2_quotes(),
        "context": _da2_value("string"),
        "entity_placeholders": _da2_string_list(),
        "erroneous_argdown": _da2_value("string"),
        "gist": _da2_value("string"),
        "intermediary_conclusions": _da2_argdown_statements(),
        "intermediary_conclusions_formalized": _da2_formalizations(),
        "metadata": _da2_string_table(),
        "misc_placeholders": _da2_string_list(),
        "plchd_substitutions": _da2_string_table(),
        "predicate_placeholders": _da2_string_list(),
        "premises": _da2_argdown_statements(),
        "premises_formalized": _da2_formalizations(),
        "reasons": _da2_quotes(),
        "source_paraphrase": _da2_value("string"),
        "source_text": _da2_value("string"),
        "title": _da2_value("string"),
    }
)


@dataclasses.dataclass
class DA2_ANGLES_MAP:  # pylint: disable=invalid-name,too-many-instance-attributes
    """maps key to DA2 features (`DeepA2Item`)

    Every attribute name is a short key and its default value is the name
    of the corresponding `DeepA2Item` field. `DeepA2Item.angles()` turns a
    default instance of this class into a plain dict via
    `dataclasses.asdict`, and `GenerativeMode.from_keys` uses that dict to
    resolve the keys of a formula such as `s+r => a`.
    """

    # informal presentation of the argument
    s: str = "source_text"
    t: str = "title"
    g: str = "gist"
    h: str = "source_paraphrase"
    x: str = "context"
    # argdown snippets
    a: str = "argdown_reconstruction"
    e: str = "erroneous_argdown"
    # verbatim quotes
    r: str = "reasons"
    j: str = "conjectures"
    # statements of the reconstruction
    p: str = "premises"
    i: str = "intermediary_conclusions"
    c: str = "conclusion"
    # formalizations (two-letter keys, prefixed with "f")
    fp: str = "premises_formalized"
    fi: str = "intermediary_conclusions_formalized"
    fc: str = "conclusion_formalized"
    # placeholder substitutions
    k: str = "plchd_substitutions"


@dataclasses.dataclass
class BaseExample(ABC):
    """Abstract Base Example dataclass

    Provides helpers shared by all example dataclasses: building an
    example from a batch of size one and validating that the columns of a
    dataset match the fields of the example class.
    """

    @classmethod
    def from_batch(cls, batched_data: Dict[str, List]):
        """Unbatches data and returns a BaseExample

        Args:
            batched_data: maps field names to lists; only the first element
                of every list is used.

        Returns:
            An instance of `cls` initialized with the first element of
            every column.
        """
        unbatched_data = {k: v[0] for k, v in batched_data.items()}
        return cls(**unbatched_data)

    @classmethod
    def check_features(cls, dataset: "datasets.DatasetDict"):
        """Checks whether dataset has features of examples

        Every split must have exactly the fields of `cls` as columns; the
        `metadata` column may be missing.

        Args:
            dataset: mapping from split names to splits, where each split
                exposes its columns as `column_names`.

        Raises:
            ValueError: if the columns of any split match neither the
                fields of `cls` nor the fields of `cls` without `metadata`.
        """
        example_fields = [field.name for field in dataclasses.fields(cls)]
        # the accepted column sets do not depend on the split: build them once
        accepted_columns = (
            set(example_fields),
            set(example_fields) - {"metadata"},
        )
        for split in dataset.values():
            if set(split.column_names) not in accepted_columns:
                # report the columns of the split that actually failed
                logging.error(
                    "Features of dataset with examples (%s) "
                    "don't match example_class %s (%s).",
                    split.column_names,
                    cls,
                    example_fields,
                )
                raise ValueError(
                    "Features of dataset with examples don't match example_class."
                )


@dataclasses.dataclass
class DeepA2BaseItem(ABC):
    """Abstract BaseItem of da2 datastructures

    Declares no fields or methods of its own; it is the common base class
    of `QuotedStatement`, `ArgdownStatement`, `Formalization` and
    `DeepA2Item`.
    """


@dataclasses.dataclass
class QuotedStatement(DeepA2BaseItem):
    """dataclass representing verbatim quote in da2item

    Item type of `DeepA2Item.reasons` and `DeepA2Item.conjectures`.

    Attributes:
        text: the quoted text (defaults to the empty string)
        starts_at: presumably the position in the source text at which the
            quote begins; defaults to -1 -- TODO confirm exact semantics
        ref_reco: presumably refers to a statement of the argdown
            reconstruction; defaults to -1 -- TODO confirm exact semantics
    """

    text: Optional[str] = ""
    starts_at: int = -1
    ref_reco: int = -1


@dataclasses.dataclass
class ArgdownStatement(DeepA2BaseItem):
    """dataclass representing argdown statement in da2item

    Item type of `DeepA2Item.premises`,
    `DeepA2Item.intermediary_conclusions` and `DeepA2Item.conclusion`.

    Attributes:
        text: the text of the statement (defaults to the empty string)
        explicit: boolean flag that defaults to None; presumably tells
            whether the statement is explicitly stated in the source
            text -- TODO confirm
        ref_reco: presumably the label of the statement within the argdown
            reconstruction; defaults to -1 -- TODO confirm exact semantics
    """

    text: Optional[str] = ""
    explicit: Optional[bool] = None
    ref_reco: int = -1


@dataclasses.dataclass
class Formalization(DeepA2BaseItem):
    """dataclass representing formalization in da2item

    Item type of `DeepA2Item.premises_formalized`,
    `DeepA2Item.intermediary_conclusions_formalized` and
    `DeepA2Item.conclusion_formalized`.

    Attributes:
        form: the formalization as a string (defaults to the empty string)
        ref_reco: presumably refers to the formalized statement of the
            argdown reconstruction; defaults to -1 -- TODO confirm
    """

    form: str = ""
    ref_reco: int = -1


@dataclasses.dataclass
class DeepA2Item(BaseExample, DeepA2BaseItem):  # pylint: disable=too-many-instance-attributes
    """
    Dataclass defining the structure of a DeepA2 example.

    All fields are optional and default to `None`, except `metadata`,
    which defaults to an empty list.

    Attributes:
        source_text: source text that informally presents the reconstructed argument
        title: telling title of the reconstructed argument
        gist: very succinct summary of the argument, main point of the argument
        source_paraphrase: a maximally clear, though conservative summary of the argument
            (i.e., leaves out distractors, focuses on one argument, leaves out redundant
            parts, streamlines the syntax, adds inference indicators, but does generally
            not explicate implicit premises)
        context: provides informal or semi-formal description of the argument's context,
            ideally an argdown snippet (without detailed reconstruction) sketching the
            dialectic neighbourhood
        argdown_reconstruction: argdown snippet with reconstruction of the argument
        erroneous_argdown: a flawed reconstruction, similar to the correct one
        reasons: a list of reason statements (verbatim quotes from `source_text`)
        conjectures: a list of conjectures (verbatim quotes from `source_text`)
        premises: the premises of `argdown_reconstruction`
        intermediary_conclusions: the intermediary conclusions of `argdown_reconstruction`
        conclusion: the conclusion of `argdown_reconstruction`
        premises_formalized: formalizations of the `premises`
        intermediary_conclusions_formalized: formalizations of the `intermediary_conclusions`
        conclusion_formalized: formalizations of the `conclusion`
        predicate_placeholders: placeholders in formalizations
        entity_placeholders: placeholders in formalizations
        misc_placeholders: placeholders in formalizations
        plchd_substitutions: substitutions for placeholders
        metadata: metadata

    """

    # informal presentation of the argument
    source_text: Optional[str] = None
    title: Optional[str] = None
    gist: Optional[str] = None
    source_paraphrase: Optional[str] = None
    context: Optional[str] = None

    # argdown snippets
    argdown_reconstruction: Optional[str] = None
    erroneous_argdown: Optional[str] = None

    # verbatim quotes
    reasons: Optional[List[QuotedStatement]] = None
    conjectures: Optional[List[QuotedStatement]] = None

    # statements of the reconstruction
    premises: Optional[List[ArgdownStatement]] = None
    intermediary_conclusions: Optional[List[ArgdownStatement]] = None
    conclusion: Optional[List[ArgdownStatement]] = None

    # formalizations, placeholders and their substitutions
    premises_formalized: Optional[List[Formalization]] = None
    intermediary_conclusions_formalized: Optional[List[Formalization]] = None
    conclusion_formalized: Optional[List[Formalization]] = None
    predicate_placeholders: Optional[List[str]] = None
    entity_placeholders: Optional[List[str]] = None
    misc_placeholders: Optional[List[str]] = None
    plchd_substitutions: Optional[List[Tuple[str, str]]] = None

    metadata: List[Tuple[str, Any]] = dataclasses.field(default_factory=list)

    @classmethod
    def from_batch(cls, batched_data: Dict[str, List]):
        """Builds a DeepA2Item from the first row of a batch

        The first element of every column is taken. Non-empty values of
        fields that hold quotes, argdown statements or formalizations are
        expected to be lists of dicts; each dict is turned into an instance
        of the matching item dataclass.
        """
        record = {key: column[0] for key, column in batched_data.items()}

        # item dataclass to instantiate for each field that holds records
        item_classes: Dict[str, Callable] = {
            "reasons": QuotedStatement,
            "conjectures": QuotedStatement,
            "premises": ArgdownStatement,
            "intermediary_conclusions": ArgdownStatement,
            "conclusion": ArgdownStatement,
            "premises_formalized": Formalization,
            "intermediary_conclusions_formalized": Formalization,
            "conclusion_formalized": Formalization,
        }

        # walk the fields in declaration order
        for field in dataclasses.fields(cls):
            item_class: Optional[Callable] = item_classes.get(field.name)
            if item_class is None:
                continue
            raw_items = record.get(field.name)
            if not raw_items:
                # missing, None or empty: leave untouched
                continue
            record[field.name] = [item_class(**raw_item) for raw_item in raw_items]

        return cls(**record)

    @staticmethod
    def angles() -> Dict[str, str]:
        """Returns the mapping from short keys to DeepA2Item field names"""
        angles_map = DA2_ANGLES_MAP()
        return dataclasses.asdict(angles_map)


@dataclasses.dataclass
class GenerativeMode:
    """Generative mode da2 datastructures

    Attributes:
        name: name of the mode; `from_keys` stores the keys formula here
        target: name of the `DeepA2Item` field that is to be generated
        input: names of the `DeepA2Item` fields that serve as input
        weight: weight of the mode, 1.0 by default
    """

    name: str
    target: str
    input: List[str]
    weight: float = 1.0

    @staticmethod
    def from_keys(name: str):
        """
        Tries to create GenerativeMode from keys formula `name`
        of the form

          `a+b+... => o`

        Whitespace around the individual keys is ignored. Returns `None`
        if `name` does not contain the separator " => " or if any of its
        keys is not a key of `DeepA2Item.angles()`.
        """
        separator = " => "
        if separator not in name:
            return None

        # anything after a second separator is disregarded
        formula_parts = name.split(separator)
        input_keys = [key.strip() for key in formula_parts[0].split("+")]
        target_key = formula_parts[1].strip()

        # resolve the key-to-field mapping once for all lookups below
        angles = DeepA2Item.angles()
        if not all(key in angles for key in (target_key, *input_keys)):
            return None

        return GenerativeMode(
            name=name,
            input=[angles[key] for key in input_keys],
            target=angles[target_key],
        )