HazyResearch/fonduer

View on GitHub
src/fonduer/candidates/mentions.py

Summary

Maintainability
D
2 days
Test Coverage
"""Fonduer mention."""
import logging
import re
from builtins import map, range
from typing import Any, Collection, Dict, Iterable, Iterator, List, Optional, Set, Union

from sqlalchemy.orm import Session

from fonduer.candidates.matchers import _Matcher
from fonduer.candidates.models import Candidate, Mention
from fonduer.candidates.models.candidate import candidate_subclasses
from fonduer.candidates.models.caption_mention import TemporaryCaptionMention
from fonduer.candidates.models.cell_mention import TemporaryCellMention
from fonduer.candidates.models.document_mention import TemporaryDocumentMention
from fonduer.candidates.models.figure_mention import TemporaryFigureMention
from fonduer.candidates.models.paragraph_mention import TemporaryParagraphMention
from fonduer.candidates.models.section_mention import TemporarySectionMention
from fonduer.candidates.models.span_mention import TemporarySpanMention
from fonduer.candidates.models.table_mention import TemporaryTableMention
from fonduer.candidates.models.temporary_context import TemporaryContext
from fonduer.parser.models import Context, Document, Sentence
from fonduer.utils.udf import UDF, UDFRunner
from fonduer.utils.utils import get_dict_of_stable_id

logger = logging.getLogger(__name__)


class MentionSpace:
    """Base class describing a **space** of Mention objects.

    Concrete spaces override *apply(x)*, which receives an object *x* and
    returns a generator over the mentions found in *x*.
    """

    def __init__(self) -> None:
        """Initialize mention space (this base class keeps no state)."""

    def apply(self, x: Context) -> Iterator[TemporaryContext]:
        """Take a Context and return a generator over its mentions.

        The base implementation is only a placeholder and always raises.

        :param x: The input Context.
        :yield: The mention generator.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class Ngrams(MentionSpace):
    """Define the space of Mentions as all n-grams in a Sentence.

    Define the space of Mentions as all n-grams (n_min <= n <= n_max) in a
    Sentence *x*, indexing by **character offset**.

    :param n_min: Lower limit for the generated n_grams.
    :param n_max: Upper limit for the generated n_grams.
    :param split_tokens: Tokens, on which unigrams are split into two separate
        unigrams.
    :type split_tokens: tuple, list of str.
    """

    def __init__(
        self, n_min: int = 1, n_max: int = 5, split_tokens: Collection[str] = ()
    ) -> None:
        """Initialize Ngrams."""
        MentionSpace.__init__(self)
        self.n_min = n_min
        self.n_max = n_max
        # One capturing group alternating over the escaped split tokens, or
        # None when no split tokens were given. Reverse sorting places a token
        # ahead of any of its own prefixes in the alternation.
        self.split_rgx = (
            r"(" + r"|".join(map(re.escape, sorted(split_tokens, reverse=True))) + r")"
            if split_tokens
            else None
        )

    def apply(self, context: Sentence) -> Iterator[TemporarySpanMention]:
        """Apply function takes a Sentence and return a mention generator.

        Each distinct span is yielded once. When split tokens are configured,
        every multi-character unigram is additionally cut at each split-token
        match and the resulting non-empty sub-spans are yielded too.

        :param context: The input Sentence.
        :yield: The mention generator.
        """
        # These are the character offset--**relative to the sentence
        # start**--for each _token_
        offsets = context.char_offsets

        # Loop over all n-grams in **reverse** order (to facilitate
        # longest-match semantics)
        L = len(offsets)
        seen: Set[TemporarySpanMention] = set()
        for j in range(self.n_min, self.n_max + 1)[::-1]:
            for i in range(L - j + 1):
                w = context.words[i + j - 1]
                start = offsets[i]
                end = offsets[i + j - 1] + len(w) - 1
                ts = TemporarySpanMention(
                    char_start=start, char_end=end, sentence=context
                )
                if ts not in seen:
                    seen.add(ts)
                    yield ts

                # Check for split: only unigrams longer than one character.
                # (j == 1 already implies n_min <= 1 <= n_max.)
                if j != 1 or self.split_rgx is None or end - start <= 0:
                    continue

                text = context.text[start - offsets[0] : end - offsets[0] + 1]
                # Candidate boundaries, as indices into ``text``: a piece may
                # begin at 0 or right after a match, and may stop right before
                # a match or at the end of the token.
                start_idxs = [0]
                end_idxs = []
                for m in re.finditer(self.split_rgx, text):
                    start_idxs.append(m.end())
                    end_idxs.append(m.start())
                end_idxs.append(len(text))
                for start_idx in start_idxs:
                    for end_idx in end_idxs:
                        if start_idx < end_idx:
                            # ``start_idx``/``end_idx`` are relative to the
                            # token, so shift them by the token's own offset
                            # to express the span in the same coordinates as
                            # the n-gram spans above.
                            ts = TemporarySpanMention(
                                char_start=start + start_idx,
                                char_end=start + end_idx - 1,
                                sentence=context,
                            )
                            if ts not in seen and ts.get_span():
                                seen.add(ts)
                                yield ts


class MentionNgrams(Ngrams):
    """Defines the **space** of Mentions as n-grams in a Document.

    Defines the space of Mentions as all n-grams (n_min <= n <= n_max) in a
    Document *x*, divided into Sentences inside of html elements (such as table
    cells).

    :param n_min: Lower limit for the generated n_grams.
    :param n_max: Upper limit for the generated n_grams.
    :param split_tokens: Tokens, on which unigrams are split into two separate
        unigrams.
    :type split_tokens: tuple, list of str.
    """

    def __init__(
        self, n_min: int = 1, n_max: int = 5, split_tokens: Collection[str] = ()
    ) -> None:
        """Initialize MentionNgrams."""
        Ngrams.__init__(self, n_min=n_min, n_max=n_max, split_tokens=split_tokens)

    def apply(self, doc: Document) -> Iterator[TemporarySpanMention]:
        """Generate MentionNgrams from a Document by parsing all of its Sentences.

        :param doc: The ``Document`` to parse.
        :yield: The n-gram mentions of each Sentence of ``doc``, sentence by
            sentence.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionNgrams.apply() must be of type Document"
            )

        for sentence in doc.sentences:
            # Delegate to the sentence-level generator of the parent class.
            yield from Ngrams.apply(self, sentence)


class MentionFigures(MentionSpace):
    """Defines the space of Mentions as all figures in a Document *x*.

    :param types: If specified, only yield TemporaryFigureMentions whose url ends in
        one of the specified types. Example: types=["png", "jpg", "jpeg"].
        A single string is treated as one type.
    :type types: str, or list, tuple of str
    """

    def __init__(self, types: Optional[Union[str, Collection[str]]] = None) -> None:
        """Initialize MentionFigures."""
        MentionSpace.__init__(self)
        self.types: Optional[List[str]] = None
        if types is not None:
            # A bare string would otherwise be iterated character by
            # character, turning "png" into the suffixes "p", "n" and "g".
            if isinstance(types, str):
                types = [types]
            self.types = [t.strip().lower() for t in types]

    def apply(self, doc: Document) -> Iterator[TemporaryFigureMention]:
        """
        Generate MentionFigures from a Document by parsing all of its Figures.

        :param doc: The ``Document`` to parse.
        :yield: A ``TemporaryFigureMention`` for each figure of ``doc`` that
            passes the type filter (every figure when no types were given).
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionFigures.apply() must be of type Document"
            )

        for figure in doc.figures:
            # The url is lower-cased to match the lower-cased type suffixes.
            if self.types is None or any(
                figure.url.lower().endswith(suffix) for suffix in self.types
            ):
                yield TemporaryFigureMention(figure)


class MentionSentences(MentionSpace):
    """Mention space consisting of every sentence in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionSentences."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporarySpanMention]:
        """
        Yield one whole-sentence span mention per Sentence of a Document.

        Each span runs from character 0 to the last character of the
        sentence text.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionSentences.apply() must be of type Document"
            )

        yield from (
            TemporarySpanMention(
                char_start=0, char_end=len(sent.text) - 1, sentence=sent
            )
            for sent in doc.sentences
        )


class MentionParagraphs(MentionSpace):
    """Mention space consisting of every paragraph in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionParagraphs."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporaryParagraphMention]:
        """
        Yield one ``TemporaryParagraphMention`` per Paragraph of a Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionParagraphs.apply() must be of type Document"
            )

        yield from (TemporaryParagraphMention(para) for para in doc.paragraphs)


class MentionCaptions(MentionSpace):
    """Mention space consisting of every caption in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionCaptions."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporaryCaptionMention]:
        """
        Yield one ``TemporaryCaptionMention`` per Caption of a Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionCaptions.apply() must be of type Document"
            )

        yield from (TemporaryCaptionMention(cap) for cap in doc.captions)


class MentionCells(MentionSpace):
    """Mention space consisting of every cell in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionCells."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporaryCellMention]:
        """
        Yield one ``TemporaryCellMention`` per Cell of a Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionCells.apply() must be of type Document"
            )

        yield from (TemporaryCellMention(c) for c in doc.cells)


class MentionTables(MentionSpace):
    """Mention space consisting of every table in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionTables."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporaryTableMention]:
        """
        Yield one ``TemporaryTableMention`` per Table of a Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionTables.apply() must be of type Document"
            )

        yield from (TemporaryTableMention(tbl) for tbl in doc.tables)


class MentionSections(MentionSpace):
    """Mention space consisting of every section in a Document *x*."""

    def __init__(self) -> None:
        """Initialize MentionSections."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporarySectionMention]:
        """
        Yield one ``TemporarySectionMention`` per Section of a Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionSections.apply() must be of type Document"
            )

        yield from (TemporarySectionMention(sec) for sec in doc.sections)


class MentionDocuments(MentionSpace):
    """Mention space consisting of the Document *x* itself."""

    def __init__(self) -> None:
        """Initialize MentionDocuments."""
        super().__init__()

    def apply(self, doc: Document) -> Iterator[TemporaryDocumentMention]:
        """
        Yield a single ``TemporaryDocumentMention`` wrapping the Document.

        :param doc: The ``Document`` to parse.
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionDocuments.apply() must be of type Document"
            )

        whole_doc_mention = TemporaryDocumentMention(doc)
        yield whole_doc_mention


class MentionExtractor(UDFRunner):
    """An operator to extract Mention objects from a Context.

    :Example:

        Assuming we want to extract two types of ``Mentions``, a Part and a
        Temperature, and we have already defined Matchers to use::

            part_ngrams = MentionNgrams(n_max=3)
            temp_ngrams = MentionNgrams(n_max=2)

            Part = mention_subclass("Part")
            Temp = mention_subclass("Temp")

            mention_extractor = MentionExtractor(
                session,
                [Part, Temp],
                [part_ngrams, temp_ngrams],
                [part_matcher, temp_matcher]
            )

    :param session: An initialized database session.
    :param mention_classes: The type of relation to extract, defined using
        :func: fonduer.mentions.mention_subclass.
    :param mention_spaces: one or list of :class:`MentionSpace` objects, one for
        each relation argument. Defines space of Contexts to consider
    :param matchers: one or list of :class:`fonduer.matchers.Matcher` objects,
        one for each relation argument. Only tuples of Contexts for which each
        element is accepted by the corresponding Matcher will be returned as
        Mentions
    :param parallelism: The number of processes to use in parallel for calls
        to apply().
    :raises ValueError: If mention classes, spaces, and matchers are not the
        same length.
    """

    def __init__(
        self,
        session: Session,
        mention_classes: List[Mention],
        mention_spaces: List[MentionSpace],
        matchers: List[_Matcher],
        parallelism: int = 1,
    ):
        """Initialize the MentionExtractor."""
        super().__init__(
            session,
            MentionExtractorUDF,
            parallelism=parallelism,
            mention_classes=mention_classes,
            mention_spaces=mention_spaces,
            matchers=matchers,
        )
        # Check that arity is same
        arity = len(mention_classes)
        if not all(
            len(x) == arity  # type: ignore
            for x in [mention_classes, mention_spaces, matchers]
        ):
            raise ValueError(
                "Mismatched arity of mention classes, spaces, and matchers."
            )

        self.mention_classes = mention_classes

    def apply(  # type: ignore
        self,
        docs: Collection[Document],
        clear: bool = True,
        parallelism: Optional[int] = None,
        progress_bar: bool = True,
    ) -> None:
        """Run the MentionExtractor.

        :Example: To extract mentions from a set of training documents using
            4 cores::

                mention_extractor.apply(train_docs, parallelism=4)

        :param docs: Set of documents to extract from.
        :param clear: Whether or not to clear the existing Mentions
            beforehand.
        :param parallelism: How many threads to use for extraction. This will
            override the parallelism value used to initialize the
            MentionExtractor if it is provided.
        :param progress_bar: Whether or not to display a progress bar. The
            progress bar is measured per document.
        """
        super().apply(
            docs, clear=clear, parallelism=parallelism, progress_bar=progress_bar
        )

    def clear(self) -> None:  # type: ignore
        """Delete the Mentions of each class in the extractor.

        Candidates whose candidate subclass refers to one of these mention
        classes are deleted as well.
        """
        # Create set of candidate_subclasses associated with each mention_subclass
        cand_subclasses = set()
        for value in candidate_subclasses.values():
            # value[1] holds the mention classes followed by the table name.
            mentions, tablename = value[1][0], value[1][1]
            if any(mention in self.mention_classes for mention in mentions):
                cand_subclasses.add(tablename)

        # First, clear all the Mentions. This will cascade and remove the
        # mention_subclasses and corresponding candidate_subclasses.
        for mention_class in self.mention_classes:
            logger.info(f"Clearing table: {mention_class.__tablename__}")
            self.session.query(Mention).filter_by(
                type=mention_class.__tablename__
            ).delete(synchronize_session="fetch")

        # Next, clear the Candidates. This is done manually because we have
        # no cascading relationship from candidate_subclass to Candidate.
        for cand_subclass in cand_subclasses:
            logger.info(f"Cascading to clear table: {cand_subclass}")
            self.session.query(Candidate).filter_by(type=cand_subclass).delete(
                synchronize_session="fetch"
            )

    def clear_all(self) -> None:
        """Delete all Mentions (and all Candidates) from the database."""
        logger.info("Clearing ALL Mentions.")
        self.session.query(Mention).delete(synchronize_session="fetch")

        # With no Mentions, there should be no Candidates also
        self.session.query(Candidate).delete(synchronize_session="fetch")
        logger.info("Cleared ALL Mentions (and Candidates).")

    def get_mentions(
        self, docs: Union[Document, Iterable[Document], None] = None, sort: bool = False
    ) -> List[List[Mention]]:
        """Return a list of lists of the mentions associated with this extractor.

        Each list of the return will contain the Mentions for one of the
        mention classes associated with the MentionExtractor.

        :param docs: If provided, return Mentions from these documents. Else,
            return all Mentions. A single ``Document`` or any iterable of them
            (including a one-shot iterator) is accepted. Note that an empty
            collection is falsy and therefore also returns all Mentions.
        :param sort: If sort is True, then return all Mentions sorted by stable_id.
        :return: Mentions for each mention_class.
        """
        # ``None`` means "do not filter by document".
        doc_ids: Optional[List[Any]] = None
        if docs:
            doc_iter = docs if isinstance(docs, Iterable) else [docs]
            # Collect the ids once, up front: the same list is reused for every
            # mention class, so a generator passed as ``docs`` is not exhausted
            # after the first class.
            doc_ids = [doc.id for doc in doc_iter]

        result: List[List[Mention]] = []
        for mention_class in self.mention_classes:
            query = self.session.query(mention_class)
            if doc_ids is not None:
                query = query.filter(mention_class.document_id.in_(doc_ids))
            mentions = query.order_by(mention_class.id).all()
            if sort:
                # x[0] indexes into the mention; presumably its first
                # context -- verify against the Mention model.
                mentions = sorted(mentions, key=lambda x: x[0].get_stable_id())
            result.append(mentions)
        return result


class MentionExtractorUDF(UDF):
    """UDF for performing mention extraction."""

    def __init__(
        self,
        mention_classes: Union[Mention, List[Mention]],
        mention_spaces: Union[MentionSpace, List[MentionSpace]],
        matchers: Union[_Matcher, List[_Matcher]],
        **kwargs: Any,
    ):
        """Initialize the MentionExtractorUDF.

        Each of the three arguments may be a single object or a list/tuple;
        single objects are wrapped in a one-element list.

        :param mention_classes: Mention class(es) to instantiate.
        :param mention_spaces: Mention space(s), indexed in parallel with
            ``mention_classes``.
        :param matchers: Matcher(s), indexed in parallel with
            ``mention_classes``.
        :param kwargs: Forwarded to the parent ``UDF`` initializer.
        """
        self.mention_classes = (
            mention_classes
            if isinstance(mention_classes, (list, tuple))
            else [mention_classes]
        )
        self.mention_spaces = (
            mention_spaces
            if isinstance(mention_spaces, (list, tuple))
            else [mention_spaces]
        )
        self.matchers = matchers if isinstance(matchers, (list, tuple)) else [matchers]

        # Preallocates internal data structure
        # (not read anywhere else in this class).
        self.child_context_set: Set[TemporaryContext] = set()

        super().__init__(**kwargs)

    def apply(self, doc: Document, **kwargs: Any) -> Document:
        """Extract mentions from the given Document.

        :param doc: A document to process.
        :param kwargs: Unused here.
        :return: The same ``doc`` object that was passed in.
        """
        # Get a dict of stable_id of contexts.
        dict_of_stable_id: Dict[str, Context] = get_dict_of_stable_id(doc)

        # Iterate over each mention class
        for i, mention_class in enumerate(self.mention_classes):
            # Name of the Document attribute that lists this class's mentions;
            # it does not change while iterating the candidates below.
            mentions_attr = mention_class.__tablename__ + "s"

            # Generate TemporaryContexts that are children of the context using
            # the mention_space and filtered by the Matcher
            for child_context in self.matchers[i].apply(
                self.mention_spaces[i].apply(doc)
            ):
                # Skip if this temporary context is used by this mention class.
                # The attribute is looked up again on every pass, and the
                # generator lets any() stop at the first match.
                stable_id = child_context.get_stable_id()
                if any(
                    m.context.stable_id == stable_id
                    for m in getattr(doc, mentions_attr, ())
                ):
                    continue
                # Re-use a persisted context if exists.
                if stable_id in dict_of_stable_id:
                    context = dict_of_stable_id[stable_id]
                # Persist a temporary context.
                else:
                    context_type = child_context._get_table()
                    context = context_type(child_context)
                    dict_of_stable_id[stable_id] = context

                mention_args = {"document": doc, "context": context}

                # Add Mention to session. The new instance is not kept here;
                # presumably it is attached through its ``document``
                # relationship -- verify against the Mention model.
                mention_class(**mention_args)
        return doc