fedspendingtransparency/usaspending-api
usaspending_api/common/csv_stream_s3_to_pg.py
"""
WARNING WARNING WARNING
!!!!!!!!!!!!!!!!!!!!!!!
This module must be managed very carefully and stay lean.

The main functions, copy_csv_from_s3_to_pg and copy_csvs_from_s3_to_pg, are used in distributed/parallel/multiprocess
execution (by Spark) and are pickled via cloudpickle. As such, this module must not presume that any setup code
(like Django setup or logging configuration) has already run, and must encapsulate those dependencies (like logging
config) on its own.

Adding new imports to this module may inadvertently introduce a dependency that cannot be pickled.

Even new imports added to the modules this one already imports could introduce such a problem.
"""
import boto3
import gzip
import logging
import psycopg2
import tempfile
import time

from typing import Iterable, List, Optional

from botocore.client import BaseClient

from usaspending_api.common.helpers.s3_helpers import download_s3_object
from usaspending_api.common.logging import AbbrevNamespaceUTCFormatter, ensure_logging
from usaspending_api.config import CONFIG
from usaspending_api.settings import LOGGING

logger = logging.getLogger(__name__)


def _get_boto3_s3_client() -> BaseClient:
    """Build an S3 client from CONFIG, pointed at either a non-AWS S3-compatible endpoint or AWS itself."""
    if not CONFIG.USE_AWS:
        # Non-AWS (e.g. local S3-compatible) endpoint: use the explicit credentials held in CONFIG and talk to the
        # configured endpoint over plain HTTP
        boto3_session = boto3.session.Session(
            region_name=CONFIG.AWS_REGION,
            aws_access_key_id=CONFIG.AWS_ACCESS_KEY.get_secret_value(),
            aws_secret_access_key=CONFIG.AWS_SECRET_KEY.get_secret_value(),
        )
        s3_client = boto3_session.client(
            service_name="s3",
            region_name=CONFIG.AWS_REGION,
            endpoint_url=f"http://{CONFIG.AWS_S3_ENDPOINT}",
        )
    else:
        # Real AWS: rely on the ambient credential chain and use HTTPS
        s3_client = boto3.client(
            service_name="s3",
            region_name=CONFIG.AWS_REGION,
            endpoint_url=f"https://{CONFIG.AWS_S3_ENDPOINT}",
        )
    return s3_client
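

# Illustrative sketch only (not used by this module): one quick way to smoke-test the client returned above against
# whichever endpoint CONFIG selects. ``list_buckets`` is a standard boto3 S3 client call; nothing here is specific to
# this project beyond ``_get_boto3_s3_client`` itself.
def _example_s3_client_smoke_check() -> List[str]:
    s3_client = _get_boto3_s3_client()
    # Return the names of the buckets visible to the configured credentials/endpoint
    return [bucket["Name"] for bucket in s3_client.list_buckets()["Buckets"]]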


def _download_and_copy(
    configured_logger: logging.Logger,
    cursor: psycopg2._psycopg.cursor,
    s3_client: BaseClient,
    s3_bucket_name: str,
    s3_obj_key: str,
    target_pg_table: str,
    ordered_col_names: List[str],
    gzipped: bool,
    partition_prefix: str = "",
):
    """Download a CSV file from S3, then COPY it into a Postgres table using the SQL bulk COPY command.

    The CSV must not include a header row of column names. This is a generator function: it yields the number of
    rows copied once the COPY for this object completes.
    """
    start = time.time()
    configured_logger.info(f"{partition_prefix}Starting write of {s3_obj_key}")

    with tempfile.NamedTemporaryFile(delete=True) as temp_file:
        download_s3_object(s3_bucket_name, s3_obj_key, temp_file.name, s3_client=s3_client)

        try:
            if gzipped:
                with gzip.open(temp_file.name, "rb") as csv_file:
                    cursor.copy_expert(
                        sql=f"COPY {target_pg_table} ({','.join(ordered_col_names)}) FROM STDIN (FORMAT CSV)",
                        file=csv_file,
                    )
            else:
                with open(temp_file.name, "r", encoding="utf-8") as csv_file:
                    cursor.copy_expert(
                        sql=f"COPY {target_pg_table} ({','.join(ordered_col_names)}) FROM STDIN (FORMAT CSV)",
                        file=csv_file,
                    )

            elapsed = time.time() - start
            rows_copied = cursor.rowcount
            configured_logger.info(
                f"{partition_prefix}Finished writing {rows_copied} row(s) in {elapsed:.3f}s for {s3_obj_key}"
            )
            yield rows_copied

        except Exception as exc:
            configured_logger.error(f"{partition_prefix}ERROR writing {s3_obj_key}")
            configured_logger.exception(exc)
            raise exc
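
# Illustrative note on the COPY statement built above: for hypothetical arguments
# target_pg_table="public.my_table" and ordered_col_names=["id", "name"], the statement handed to
# cursor.copy_expert is:
#
#     COPY public.my_table (id,name) FROM STDIN (FORMAT CSV)
#
# with the (possibly gzip-decompressed) CSV file streamed in as STDIN.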


def copy_csv_from_s3_to_pg(
    s3_bucket_name: str,
    s3_obj_key: str,
    db_dsn: str,
    target_pg_table: str,
    ordered_col_names: List[str],
    gzipped: bool = True,
    work_mem_override: Optional[int] = None,
):
    """Download a CSV file from S3 then stream it into a Postgres table using the SQL bulk COPY command

    WARNING: See note above in module docstring about this function being pickle-able, and maintaining a lean set of
    outward dependencies on other modules/code that require setup/config.
    """
    ensure_logging(logging_config_dict=LOGGING, formatter_class=AbbrevNamespaceUTCFormatter, logger_to_use=logger)
    try:
        with psycopg2.connect(dsn=db_dsn) as connection:
            connection.autocommit = True
            with connection.cursor() as cursor:
                if work_mem_override:
                    cursor.execute("SET work_mem TO %s", (work_mem_override,))
                s3_client = _get_boto3_s3_client()
                results_generator = _download_and_copy(
                    configured_logger=logger,
                    cursor=cursor,
                    s3_client=s3_client,
                    s3_bucket_name=s3_bucket_name,
                    s3_obj_key=s3_obj_key,
                    target_pg_table=target_pg_table,
                    gzipped=gzipped,
                    ordered_col_names=ordered_col_names,
                )
                return list(results_generator)[0]  # the generator yields exactly one row count for this single object
    except Exception as exc:
        logger.error(f"ERROR writing {s3_obj_key}")
        logger.exception(exc)
        raise exc
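

# Usage sketch (illustrative only; every literal below is hypothetical): how a caller might copy a single gzipped CSV
# object into a table with ``copy_csv_from_s3_to_pg``.
def _example_copy_single_csv() -> int:
    return copy_csv_from_s3_to_pg(
        s3_bucket_name="my-data-bucket",  # hypothetical bucket
        s3_obj_key="exports/part-0000.csv.gz",  # hypothetical object key
        db_dsn="postgresql://user:password@db-host:5432/my_db",  # hypothetical DSN
        target_pg_table="public.my_table",  # hypothetical table
        ordered_col_names=["id", "name"],  # hypothetical column order matching the CSV
        gzipped=True,
    )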


def copy_csvs_from_s3_to_pg(
    batch_num: int,
    s3_bucket_name: str,
    s3_obj_keys: Iterable[str],
    db_dsn: str,
    target_pg_table: str,
    ordered_col_names: List[str],
    gzipped: bool = True,
    work_mem_override: Optional[int] = None,
):
    """An optimized form of ``copy_csv_from_s3_to_pg`` that can save on runtime by instantiating the psycopg2 DB
    connection and s3_client only once per partition, where a partition could represent processing several files
    """
    ensure_logging(logging_config_dict=LOGGING, formatter_class=AbbrevNamespaceUTCFormatter, logger_to_use=logger)
    s3_obj_keys = list(s3_obj_keys)  # convert from Iterator (generator) to a concrete List
    batch_size = len(s3_obj_keys)
    batch_start = time.time()
    partition_prefix = f"Partition#{batch_num}: "
    logger.info(f"{partition_prefix}Starting write of a batch of {batch_size} on partition {batch_num}")
    try:
        with psycopg2.connect(dsn=db_dsn) as connection:
            connection.autocommit = True
            with connection.cursor() as cursor:
                if work_mem_override:
                    cursor.execute("SET work_mem TO %s", (work_mem_override,))
                s3_client = _get_boto3_s3_client()
                for s3_obj_key in s3_obj_keys:
                    yield from _download_and_copy(
                        configured_logger=logger,
                        cursor=cursor,
                        s3_client=s3_client,
                        s3_bucket_name=s3_bucket_name,
                        s3_obj_key=s3_obj_key,
                        target_pg_table=target_pg_table,
                        ordered_col_names=ordered_col_names,
                        gzipped=gzipped,
                        partition_prefix=partition_prefix,
                    )
                batch_elapsed = time.time() - batch_start
                logger.info(
                    f"{partition_prefix}Finished writing batch of {batch_size} "
                    f"{'file' if batch_size == 1 else 'files'} on partition {batch_num} in {batch_elapsed:.3f}s"
                )
    except Exception as exc_batch:
        logger.error(f"{partition_prefix}ERROR writing batch/partition number {batch_num}")
        logger.exception(exc_batch)
        raise exc_batch
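

# Usage sketch (illustrative only): how ``copy_csvs_from_s3_to_pg`` could be driven from Spark, the distributed
# context the module docstring warns about. The bucket, keys, DSN, table, and column names are hypothetical, and
# pyspark is imported inside the function on purpose so this module keeps its lean import surface.
def _example_copy_csvs_with_spark() -> int:
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    keys_rdd = spark.sparkContext.parallelize(
        ["exports/part-0000.csv.gz", "exports/part-0001.csv.gz"],  # hypothetical object keys
        numSlices=2,
    )
    # mapPartitionsWithIndex hands each partition's (index, iterator-of-keys) pair to the generator function;
    # cloudpickle ships the closure to the executors, which is why this module must stay lean and self-contained
    row_counts = keys_rdd.mapPartitionsWithIndex(
        lambda batch_num, keys: copy_csvs_from_s3_to_pg(
            batch_num=batch_num,
            s3_bucket_name="my-data-bucket",  # hypothetical bucket
            s3_obj_keys=keys,
            db_dsn="postgresql://user:password@db-host:5432/my_db",  # hypothetical DSN
            target_pg_table="public.my_table",  # hypothetical table
            ordered_col_names=["id", "name"],  # hypothetical column order matching the CSVs
        )
    ).collect()
    return sum(row_counts)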