fedspendingtransparency/usaspending-api

View on GitHub
usaspending_api/etl/elasticsearch_loader_helpers/controller_for_spark.py

Summary

Maintainability: A (0 mins)
Test Coverage: F (0%)
import logging
from math import ceil
from time import perf_counter
from typing import List, Tuple

from django.conf import settings
from pyspark.sql import SparkSession

from usaspending_api.broker.helpers.last_load_date import get_earliest_load_date, update_last_load_date
from usaspending_api.common.elasticsearch.client import instantiate_elasticsearch_client
from usaspending_api.common.etl.spark import build_ref_table_name_list, create_ref_temp_views
from usaspending_api.common.helpers.spark_helpers import clean_postgres_sql_for_spark_sql
from usaspending_api.etl.elasticsearch_loader_helpers import (
    TaskSpec,
    count_of_records_to_process_in_delta,
    delete_awards,
    delete_transactions,
    format_log,
    load_data,
    obtain_extract_all_partitions_sql,
)
from usaspending_api.etl.elasticsearch_loader_helpers.controller import AbstractElasticsearchIndexerController

# Shared "script" logger so output from this module lands with the rest of the management-command logging
logger = logging.getLogger(name="script")


class DeltaLakeElasticsearchIndexerController(AbstractElasticsearchIndexerController):
    """Controller for Spark-based Elasticsearch ETL that extracts data from Delta Lake.

    Instead of extracting each partition by ID range, this controller runs a single extract query through Spark SQL,
    repartitions the resulting DataFrame, and maps ``transform_and_load_partition`` over each partition on the
    Spark executors.
    """

    def __init__(self, config: dict, spark: SparkSession, spark_created_by_command: bool = False):
        """
        Args:
            config: ETL configuration handed to the base controller. Keys read by this class include "load_type",
                "partition_size", "partitions", "primary_key" and "extra_null_partition"; "processes" is overwritten
                in ``prepare_for_etl``.
            spark: SparkSession used for all SQL and DataFrame operations.
            spark_created_by_command: True when the caller created ``spark`` and wants ``cleanup`` to stop it.
        """
        super().__init__(config)
        self.spark = spark
        self.spark_created_by_command = spark_created_by_command

    def ensure_view_exists(self, sql_view_name: str, force_recreate=True) -> None:
        """Create the Spark TEMP VIEW ``sql_view_name`` from its Postgres SQL definition file.

        The SQL is read from ``settings.APP_DIR / "database_scripts" / "etl" / f"{sql_view_name}.sql"``, its
        ``DROP VIEW IF EXISTS`` statement is removed, and it is rewritten to Spark SQL before being executed.

        Args:
            sql_view_name: Name of the view, also used as the SQL file's base name.
            force_recreate: When False and the view already exists, return without doing anything.

        Raises:
            ValueError: If ``config["load_type"]`` is not one of the load types handled below.
        """
        # Only ask Spark whether the view exists when the answer matters; with force_recreate the check
        # would just be a wasted Spark job.
        if not force_recreate:
            view_exists = len(self.spark.sql(f"show views like '{sql_view_name}'").collect()) == 1
            if view_exists:
                return

        # Ensure reference tables the TEMP VIEW may depend on exist
        create_ref_temp_views(self.spark)

        view_file_path = settings.APP_DIR / "database_scripts" / "etl" / f"{sql_view_name}.sql"

        view_sql = view_file_path.read_text()

        # Find/replace SQL strings in Postgres-based SQL to make it Spark SQL compliant
        # WARNING: If the SQL changes, it must be tested to still be Spark SQL compliant, and changes here may be needed
        temp_view_select_sql = view_sql.replace(f"DROP VIEW IF EXISTS {sql_view_name};", "")

        identifier_replacements = {}
        if self.config["load_type"] == "transaction":
            identifier_replacements["transaction_search"] = "rpt.transaction_search"
        elif self.config["load_type"] == "award":
            identifier_replacements["award_search"] = "rpt.award_search"
        elif self.config["load_type"] == "covid19-faba":
            identifier_replacements["financial_accounts_by_awards"] = "int.financial_accounts_by_awards"
            identifier_replacements["vw_awards"] = "int.awards"
        elif self.config["load_type"] == "recipient":
            identifier_replacements["recipient_profile"] = "rpt.recipient_profile"
        elif self.config["load_type"] == "location":
            # Replace the Postgres regex operator with the Databricks regex operator
            identifier_replacements["~"] = "rlike"
            identifier_replacements["state_data"] = "global_temp.state_data"
        else:
            raise ValueError(
                f"Unrecognized load_type {self.config['load_type']}, or this function does not yet support it"
            )

        temp_view_select_sql = clean_postgres_sql_for_spark_sql(
            temp_view_select_sql, build_ref_table_name_list(), identifier_replacements
        )

        self.spark.sql(
            f"""
        {temp_view_select_sql}
        """
        )

    def _count_of_records_to_process(self, config) -> Tuple[int, int, int]:
        """Delegate record counting to the Delta-Lake-aware helper, using this controller's SparkSession."""
        return count_of_records_to_process_in_delta(config, self.spark)

    def determine_partitions(self) -> int:
        """Simple strategy to divide total record count by ideal partition size"""
        return ceil(self.record_count / self.config["partition_size"])

    def get_id_range_for_partition(self, partition_number: int) -> Tuple[int, int]:
        """Not supported: partitions come from repartitioning a Spark DataFrame, not from ID ranges."""
        raise NotImplementedError(
            "Delta Lake data is extracted into a Spark DataFrame that is partitioned. No need to get each partition "
            "by ID ranges."
        )

    def prepare_for_etl(self) -> None:
        """Set ``config["processes"]`` to the Spark cluster's default parallelism, then run the base preparation."""
        spark_executors = self.spark.sparkContext.defaultParallelism
        logger.info(
            format_log(
                "Overriding specified --processes and setting to configured executors "
                f"on the Spark cluster = {spark_executors}"
            )
        )
        # Store the same value that was logged, rather than reading defaultParallelism a second time
        self.config["processes"] = spark_executors
        super().prepare_for_etl()

    def configure_task(self, partition_number: int, task_name: str, is_null_partition: bool = False) -> TaskSpec:
        """Build the TaskSpec for one partition; ``is_null_partition`` is accepted for interface compatibility."""
        # Spark-based approach maps indexing functions to partitions of data, rather than extracting and processing
        # the data, so extract sql is not used
        return self._construct_task_spec(partition_number, task_name, extract_sql_str=None)

    def dispatch_tasks(self) -> None:
        """Extract all records with Spark SQL, repartition them, and index each partition on the executors.

        Logs the total number of documents indexed and failed, as reported back by each partition.
        """
        extract_sql = clean_postgres_sql_for_spark_sql(obtain_extract_all_partitions_sql(self.config))
        logger.info(format_log(f"Using extract_sql:\n{extract_sql}", action="Extract"))
        df = self.spark.sql(extract_sql)
        df_record_count = df.count()  # safe to doublecheck the count of the *actual* data being processed

        target_partitions = self.config["partitions"]
        partition_size = self.config["partition_size"]
        current_partitions = df.rdd.getNumPartitions()

        if self.config["extra_null_partition"]:
            # Data which may have a "NULL Partition" is parent-child grouped data, where child records are grouped by
            # the config["primary_key"], which is the PK field of the parent records.
            # It is imperative that only 1 indexing operation per parent document (per primary_key value) happen,
            # and encompass all of its nested child documents, otherwise subsequent indexing of that parent document
            # will overwrite the prior documents.
            # For this to happen, ALL data for a parent document (primary_key) must exist in the same partition of
            # data being transformed and loaded. This can be achieved by providing the grouping to repartition,
            # so all records with that same value will end up in the same partition
            partition_field = self.config["primary_key"]
            repartition_columns = [partition_field]
            msg = (
                f"Repartitioning {df_record_count} records from {current_partitions} partitions into "
                f"{target_partitions} partitions. Repartitioning by the {partition_field} field so that all "
                f"records that share the same value for this field will be in the same partition to support nesting "
                f"all child documents under the parent record. Partitions should generally be less than "
                f"{partition_size} records. Exceptions are if a parent has more than "
                f"{partition_size} child documents, and the 'NULL partition', which is all child "
                f"documents that have NULL for the {partition_field} field. Then handing each partition to available "
                f"executors for processing."
            )
        else:
            repartition_columns = []
            msg = (
                f"Repartitioning {df_record_count} records from {current_partitions} partitions into "
                f"{target_partitions} partitions to evenly balance no more than "
                f"{partition_size} "
                f"records per partition. Then handing each partition to available executors for processing."
            )

        logger.info(format_log(msg))
        logger.info(
            format_log("Partition-processing task logs will be embedded in executor stderr logs, and not appear here.")
        )
        df = df.repartition(target_partitions, *repartition_columns)

        # Create a clean/detached copy of this dict. Referencing self within the lambda will attempt to pickle the
        # self object, which has a reference to the SparkContext. SparkContext references CANNOT be pickled
        task_dict = {**self.tasks}

        # Map the indexing function to each of the created partitions of the DataFrame
        success_fail_stats = df.rdd.mapPartitionsWithIndex(
            lambda partition_idx, partition_data: transform_and_load_partition(
                partition_data=partition_data,
                task=task_dict[partition_idx],
            ),
            preservesPartitioning=True,
        ).collect()

        # Each partition contributes one (success_count, fail_count) pair
        successes = sum(stats[0] for stats in success_fail_stats)
        failures = sum(stats[1] for stats in success_fail_stats)
        msg = f"Total documents indexed: {successes}, total document fails: {failures}"
        logger.info(format_log(msg))

    def _run_award_deletes(self):
        """Delete award documents via ``delete_awards``, passing this controller's SparkSession."""
        client = instantiate_elasticsearch_client()
        delete_awards(
            client=client,
            config=self.config,
            fabs_external_data_load_date_key="transaction_fabs",
            fpds_external_data_load_date_key="transaction_fpds",
            spark=self.spark,
        )

    def _run_transaction_deletes(self):
        """Delete transaction documents, then record the "es_deletes" load date."""
        client = instantiate_elasticsearch_client()
        delete_transactions(
            client=client,
            config=self.config,
            fabs_external_data_load_date_key="transaction_fabs",
            fpds_external_data_load_date_key="transaction_fpds",
            spark=self.spark,
        )
        # Use the lesser of the fabs/fpds load dates as the es_deletes load date. This
        # ensures all records deleted since either job was run are taken into account
        # Using the loaded-from-DELTA-tables dates here, not the postgres table load dates
        last_db_delete_time = get_earliest_load_date(["transaction_fabs", "transaction_fpds"])
        update_last_load_date("es_deletes", last_db_delete_time)

    def cleanup(self) -> None:
        """Stop the SparkSession, but only if the calling command created it."""
        if self.spark_created_by_command:
            self.spark.stop()


def transform_and_load_partition(task: TaskSpec, partition_data) -> List[Tuple[int, int]]:
    """Transform one partition's rows and index them into Elasticsearch.

    Mapped over DataFrame partitions on the Spark executors by ``dispatch_tasks``. Its log output therefore lands in
    the executor logs.

    Args:
        task: Spec for this partition. Its ``name``, ``partition_number`` and optional ``transform_func`` are used.
        partition_data: Iterable of Spark rows for this partition (each must support ``asDict()``).

    Returns:
        A one-element list ``[(success_count, fail_count)]``. It is a list because ``mapPartitionsWithIndex``
        expects an iterable back from each partition.

    Raises:
        Any exception from transforming or loading is logged with its traceback and then re-raised.
    """
    start = perf_counter()
    msg = f"Started processing on partition #{task.partition_number}: {task.name}"
    logger.info(format_log(msg, name=task.name))

    client = instantiate_elasticsearch_client()
    try:
        records = [row.asDict() for row in partition_data]
        if task.transform_func is not None:
            records = task.transform_func(task, records)
        if len(records) > 0:
            success, fail = load_data(task, records, client)
        else:
            logger.info(format_log("No records to index", name=task.name))
            success, fail = 0, 0
    except Exception:
        # logger.exception already attaches the active exception and its traceback. Passing the exception as an
        # extra positional argument would make logging treat it as a %-format argument and fail to format the message.
        logger.exception(format_log(f"{task.name} failed!", name=task.name))
        raise
    else:
        msg = f"Partition #{task.partition_number} was successfully processed in {perf_counter() - start:.2f}s"
        logger.info(format_log(msg, name=task.name))
    return [(success, fail)]