libratom/lib/entities.py

# pylint: disable=broad-except,invalid-name,protected-access,consider-using-ternary
"""
Set of utility functions that use spaCy to perform named entity recognition
"""

import logging
import multiprocessing
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple

from sqlalchemy.orm.session import Session

from libratom.lib.base import AttachmentMetadata
from libratom.lib.concurrency import get_messages, imap_job, worker_init
from libratom.lib.constants import (
    RATOM_DB_COMMIT_BATCH_SIZE,
    RATOM_MSG_BATCH_SIZE,
    RATOM_MSG_PROGRESS_STEP,
    RATOM_SPACY_MODEL_MAX_LENGTH,
    BodyType,
)
from libratom.lib.core import get_cached_spacy_model
from libratom.lib.headers import (
    get_header_field_type_mapping,
    populate_header_field_types,
)
from libratom.lib.utils import cleanup_message_body
from libratom.models import Attachment, Entity, FileReport, HeaderField, Message

logger = logging.getLogger(__name__)


@imap_job
def process_message(
    filepath: str,
    message_id: int,
    body: str,
    body_type: BodyType,
    date: datetime,
    attachments: List[AttachmentMetadata],
    spacy_model_name: str,
    headers: Optional[str] = None,
    include_message_contents: bool = False,
) -> Tuple[Dict, Optional[str]]:
    """
    Job function for the worker processes
    """

    # Return basic types to avoid serialization issues
    res = {
        "filepath": filepath,
        "message_id": message_id,
        "date": date,
        "processing_start_time": datetime.utcnow(),
        "attachments": attachments,
    }

    try:
        # Extract entities from the message
        message_body = cleanup_message_body(
            body, body_type, RATOM_SPACY_MODEL_MAX_LENGTH
        )

        spacy_model = get_cached_spacy_model(spacy_model_name)
        doc = spacy_model(message_body)
        res["entities"] = [(ent.text, ent.label_) for ent in doc.ents]

        res["processing_end_time"] = datetime.utcnow()

        if include_message_contents:
            res["body"] = message_body
            res["headers"] = headers

        return res, None

    except Exception as exc:
        return res, str(exc)


def extract_entities(
    files: Iterable[Path],
    session: Session,
    spacy_model_name: str,
    include_message_contents: bool = False,
    jobs: Optional[int] = None,
    processing_progress_callback: Optional[Callable] = None,
    reporting_progress_callback: Optional[Callable] = None,
    **kwargs,
) -> int:
    """
    Main entity extraction function that extracts named entities from a given iterable of files

    Spawns multiple processes via multiprocessing.Pool
    """

    # Log the RATOM_* configuration settings for debugging
    for setting_name, setting_value in globals().items():
        if setting_name.startswith("RATOM_"):
            logger.debug(f"{setting_name}: {setting_value}")

    # Default progress callbacks to no-op
    processing_update_progress = processing_progress_callback or (lambda *_, **__: None)
    reporting_update_progress = reporting_progress_callback or (lambda *_, **__: None)

    # Load the file_report table for local lookup
    _file_reports = session.query(FileReport).all()  # noqa: F841

    # Add header field type table
    if include_message_contents:
        populate_header_field_types(session)

    # Cache header field types in a local mapping
    # (empty if the header field type table was not populated)
    header_field_type_mapping = get_header_field_type_mapping(session)

    # Start of multiprocessing
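    # The "spawn" start method avoids fork-related deadlocks with transformer
    # (_trf) pipelines (see the spaCy issue linked below)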
    ctx = multiprocessing.get_context(
        "spawn" if spacy_model_name.endswith("_trf") else None
    )  # https://github.com/explosion/spaCy/issues/6662

    with ctx.Pool(processes=jobs, initializer=worker_init) as pool:

        logger.debug(f"Starting pool with {pool._processes} processes")

        new_entities = []
        msg_count = 0

        try:
            for msg_count, worker_output in enumerate(
                pool.imap_unordered(
                    process_message,
                    get_messages(
                        files,
                        spacy_model_name=spacy_model_name,
                        progress_callback=processing_update_progress,
                        include_message_contents=include_message_contents,
                        with_headers=include_message_contents,
                        **kwargs,
                    ),
                    chunksize=RATOM_MSG_BATCH_SIZE,
                ),
                start=1,
            ):

                # Unpack worker job output
                res, error = worker_output

                if error:
                    logger.info(
                        # pylint: disable=consider-using-f-string
                        "Skipping message {message_id} from {filepath}".format(**res)
                    )
                    logger.debug(error)

                    continue

                # Extract results
                entities = res.pop("entities")
                message_id = res.pop("message_id")
                filepath = res.pop("filepath")
                attachments = res.pop("attachments")

                # Create new message instance
                message = Message(pff_identifier=message_id, **res)

                # Link message to a file_report
                try:
                    file_report = (
                        session.query(FileReport).filter_by(path=filepath).one()
                    )
                except Exception as exc:
                    file_report = None
                    logger.info(
                        f"Unable to link message id {message_id} to a file. Error: {exc}"
                    )

                message.file_report = file_report
                session.add(message)

                # Record attachment info
                session.add_all(
                    [
                        Attachment(
                            **asdict(attachment),
                            message=message,
                            file_report=file_report,
                        )
                        for attachment in attachments
                    ]
                )

                # Record header fields
                if include_message_contents:
                    header_fields = []

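                    # Split each raw header line on the first colon, e.g.
                    # "Subject: Weekly update" -> ("Subject", " Weekly update");
                    # lines with no colon are skipped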
                    for line in (res.get("headers") or "").splitlines():
                        try:
                            header_name, header_value = line.split(":", maxsplit=1)
                        except ValueError:
                            continue
                        if header_field_type := header_field_type_mapping.get(
                            header_name.lower()
                        ):
                            header_fields.append(
                                HeaderField(
                                    header_field_type=header_field_type,
                                    value=header_value,
                                    message=message,
                                )
                            )

                    session.add_all(header_fields)

                # Record entities info
                for entity in entities:
                    new_entities.append(
                        Entity(
                            text=entity[0],
                            label_=entity[1],
                            filepath=filepath,
                            message=message,
                            file_report=file_report,
                        )
                    )

                # Commit if we reach a certain amount of new entities
                if len(new_entities) >= RATOM_DB_COMMIT_BATCH_SIZE:
                    session.add_all(new_entities)
                    new_entities = []
                    try:
                        session.commit()
                    except Exception as exc:
                        logger.exception(exc)
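                        # Discard the failed batch so later batches can still commit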
                        session.rollback()

                # Update progress every N messages
                if not msg_count % RATOM_MSG_PROGRESS_STEP:
                    reporting_update_progress(RATOM_MSG_PROGRESS_STEP)

            # Add remaining new entities
            session.add_all(new_entities)

            # Update progress with remaining message count
            reporting_update_progress(msg_count % RATOM_MSG_PROGRESS_STEP)

        except KeyboardInterrupt:
            logger.warning("Cancelling running task")
            logger.info("Partial results written to database")
            logger.info("Terminating workers")

            # Clean up process pool
            pool.terminate()
            pool.join()

            return 1

    return 0
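

# A minimal usage sketch for this module (the database URL, model name, and
# input paths below are illustrative assumptions, not values defined here;
# schema setup is assumed to have been done beforehand):
#
#     from pathlib import Path
#
#     from sqlalchemy import create_engine
#     from sqlalchemy.orm import sessionmaker
#
#     engine = create_engine("sqlite:///ratom.db")
#     session = sessionmaker(bind=engine)()
#
#     status = extract_entities(
#         files=Path("data").glob("**/*.pst"),
#         session=session,
#         spacy_model_name="en_core_web_sm",
#         jobs=2,
#     )
#     session.commit()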