fedspendingtransparency/usaspending-api

View on GitHub
usaspending_api/conftest_helpers.py

Summary

Maintainability
A
2 hrs
Test Coverage
A
94%
import os
from builtins import Exception
from datetime import datetime, timezone
from pathlib import Path
from string import Template

from django.conf import settings
from django.core.management import call_command
from django.core.serializers.json import DjangoJSONEncoder, json
from django.db import DEFAULT_DB_ALIAS, connection
from elasticsearch import Elasticsearch
from pytest import Session

from usaspending_api.common.helpers.sql_helpers import ordered_dictionary_fetcher
from usaspending_api.common.helpers.text_helpers import generate_random_string
from usaspending_api.common.sqs.sqs_handler import (
    UNITTEST_FAKE_DEAD_LETTER_QUEUE_NAME,
    UNITTEST_FAKE_QUEUE_NAME,
    _FakeStatelessLoggingSQSDeadLetterQueue,
    _FakeUnitTestFileBackedSQSQueue,
)
from usaspending_api.etl.elasticsearch_loader_helpers import (
    TaskSpec,
    create_award_type_aliases,
    execute_sql_statement,
    transform_award_data,
    transform_covid19_faba_data,
    transform_transaction_data,
)
from usaspending_api.etl.elasticsearch_loader_helpers.index_config import create_load_alias


def is_pytest_xdist_parallel_sessions() -> bool:
    """Report whether the tests are executing inside a pytest-xdist parallel test session.

    The decision is based on the ``PYTEST_XDIST_WORKER_COUNT`` environment variable: any count above
    zero means a parallel session, even if only 1 worker (1 master, and 1 worker) is configured.
    When the variable is not set the count is taken to be zero.
    """
    configured_workers = os.environ.get("PYTEST_XDIST_WORKER_COUNT", 0)
    return int(configured_workers) > 0


def is_pytest_xdist_master_process(session: Session):
    """Tell whether ``session`` belongs to the 'master' process of a pytest-xdist parallel run.

    The result is True only when a parallel session is active AND the session's config carries no
    ``workerinput`` attribute (or carries it as ``None``), i.e. the session is not one of the spawned workers.

    This is useful when orchestrating cross-session (cross-worker) setup and teardown of fixtures that
    should only run once, which can be done from the master process.
    See: https://pytest-xdist.readthedocs.io/en/stable/how-to.html#making-session-scoped-fixtures-execute-only-once
    And: https://github.com/pytest-dev/pytest-xdist/issues/271#issuecomment-826396320
    """
    # Only sessions without a ``workerinput`` on their config are treated as the master.
    lacks_worker_input = getattr(session.config, "workerinput", None) is None
    return is_pytest_xdist_parallel_sessions() and lacks_worker_input


def transform_xdist_worker_id_to_django_test_db_id(worker_id: str):
    """Map a pytest-xdist global worker ID onto the suffix Django unittest uses for its test DBs.

    Worker IDs are zero-based while the returned suffixes are one-based,
    e.g. 'gw0' -> '1', 'gw1' -> '2', ...
    For the master process (``"master"``) or an empty/missing worker ID, an empty string is returned.
    """
    is_master = worker_id == "master"
    if is_master or not worker_id:
        return ""
    zero_based_index = int(worker_id.replace("gw", ""))
    return f"{zero_based_index + 1}"


def is_safe_for_xdist_setup_or_teardown(session: Session, worker_id: str):
    """Decide whether exactly-once setup or teardown logic may run in the current process.

    Use this in any session fixture whose setup must run EXACTLY ONCE for ALL workers
    if there are multiple parallel workers (e.g. via a pytest-xdist run with -n X or --numprocesses=X),
    and only perform that exactly-once setup logic when this evaluates to True.
    For teardown, to ensure the cleanup runs after all workers complete their session, put the cleanup logic in a
    ``pytest_sessionfinish(session, exitstatus)`` hook and guard it with this function.

    The answer is True when the session is the xdist master process, when ``worker_id`` is ``"master"``,
    or when ``worker_id`` is empty/missing.

    Example:
        Setup
        >>> @pytest.fixture(scope="session")
        >>> def do_some_setup_for_all_tests(session, worker_id):
        >>>     if not is_safe_for_xdist_setup_or_teardown(session, worker_id):
        >>>         yield
        >>>     else:
        >>>         # ... do setup here exactly once

        TearDown:
        >>> def pytest_sessionfinish(session, exitstatus):
        >>>     worker_id = get_xdist_worker_id(session)
        >>>     if is_safe_for_xdist_setup_or_teardown(session, worker_id):
        >>>         # ... do teardown here
    """
    if is_pytest_xdist_master_process(session):
        return True
    return worker_id == "master" or not worker_id


class TestElasticSearchIndex:
    """Test helper that owns one uniquely named Elasticsearch index of a given type.

    It can (re)create the index together with its aliases, optionally fill it with the rows produced
    by the matching ETL view, and delete it again.
    """

    def __init__(self, index_type):
        """
        We will be prefixing all aliases with the index name to ensure
        uniqueness; otherwise, we may end up with aliases representing more
        than one index which may throw off our search results.

        Args:
            index_type: Kind of index to manage. The values handled by this class are ``"award"``,
                ``"transaction"``, ``"covid19-faba"``, ``"recipient"`` and ``"location"``.
        """
        self.index_type = index_type
        self.index_name = self._generate_index_name()
        self.alias_prefix = self.index_name
        self.client = Elasticsearch([settings.ES_HOSTNAME], timeout=settings.ES_TIMEOUT)
        self.etl_config = {
            "load_type": self.index_type,
            "index_name": self.index_name,
            "query_alias_prefix": self.alias_prefix,
            "verbose": False,
            "verbosity": 0,
            "write_alias": self.index_name + "-load-alias",
            "process_deletes": True,
        }
        # Awards are keyed by award_id; every other index type falls back to transaction_id here.
        es_key_field = "award_id" if self.index_type == "award" else "transaction_id"
        self.worker = TaskSpec(
            base_table=None,
            base_table_id=None,
            execute_sql_func=execute_sql_statement,
            field_for_es_id=es_key_field,
            index=self.index_name,
            is_incremental=None,
            name=f"{self.index_type} test worker",
            partition_number=None,
            primary_key=es_key_field,
            sql=None,
            transform_func=None,
            view=None,
        )

    def delete_index(self):
        """Delete this helper's index, passing ``ignore_unavailable=True`` to the delete call."""
        self.client.indices.delete(self.index_name, ignore_unavailable=True)

    def update_index(self, load_index: bool = True, **options):
        """
        To ensure a fresh Elasticsearch index, delete the old one, run
        ``es_configure --template-only`` for this load type, re-create the
        Elasticsearch index, create aliases for the index, and add contents.

        Args:
            load_index: When False, the index and aliases are created but no documents are added.
            **options: Forwarded to ``_add_contents``; only the ``routing`` key is read there.
        """
        self.delete_index()
        call_command("es_configure", "--template-only", f"--load-type={self.index_type}")
        self.client.indices.create(index=self.index_name)
        create_award_type_aliases(self.client, self.etl_config)
        create_load_alias(self.client, self.etl_config)
        self.etl_config["max_query_size"] = self._get_max_query_size()
        if load_index:
            self._add_contents(**options)

    def _get_max_query_size(self):
        """Return the ``ES_<TYPE>_MAX_RESULT_WINDOW`` setting that matches this index type.

        An unrecognised index type maps to an empty ``<TYPE>``, so the settings lookup raises
        ``AttributeError`` unless a setting with that exact name exists.
        """
        setting_infix_by_type = {
            "award": "AWARDS",
            "covid19-faba": "COVID19_FABA",
            "transaction": "TRANSACTIONS",
            "recipient": "RECIPIENTS",
            "location": "LOCATIONS",
        }
        upper_name = setting_infix_by_type.get(self.index_type, "")
        return getattr(settings, f"ES_{upper_name}_MAX_RESULT_WINDOW")

    def _add_contents(self, **options):
        """
        Get all of the rows presented in the ETL view for this index type and stuff them into the
        Elasticsearch index. The view is only needed to load the rows into Elasticsearch so it is
        dropped after each use.

        Args:
            **options: ``routing`` may name the record field whose value is used as the routing value;
                it defaults to ``settings.ES_ROUTING_FIELD``.

        Raises:
            Exception: If ``self.index_type`` is not one of the supported index types.
        """
        # Settings are only touched inside the matching branch, so an unsupported type fails fast.
        if self.index_type == "award":
            view_name = settings.ES_AWARDS_ETL_VIEW_NAME
            es_id = f"{self.index_type}_id"
        elif self.index_type == "covid19-faba":
            view_name = settings.ES_COVID19_FABA_ETL_VIEW_NAME
            es_id = "financial_account_distinct_award_key"
        elif self.index_type == "transaction":
            view_name = settings.ES_TRANSACTIONS_ETL_VIEW_NAME
            es_id = f"{self.index_type}_id"
        elif self.index_type == "recipient":
            view_name = settings.ES_RECIPIENTS_ETL_VIEW_NAME
            es_id = "id"
        elif self.index_type == "location":
            view_name = settings.ES_LOCATIONS_ETL_VIEW_NAME
            es_id = "id"
        else:
            raise Exception("Invalid index type")
        view_sql_file = f"{view_name}.sql"

        # Read the view definition through a context manager so the file handle is always closed.
        view_sql_path = settings.APP_DIR / "database_scripts" / "etl" / view_sql_file
        with open(str(view_sql_path), "r") as view_sql_handle:
            view_sql = view_sql_handle.read()

        with connection.cursor() as cursor:
            cursor.execute(view_sql)
            cursor.execute(f"SELECT * FROM {view_name};")
            records = ordered_dictionary_fetcher(cursor)
            cursor.execute(f"DROP VIEW {view_name};")
            # Recipient and location records are indexed as fetched; the other types are transformed first.
            if self.index_type == "award":
                records = transform_award_data(self.worker, records)
            elif self.index_type == "transaction":
                records = transform_transaction_data(self.worker, records)
            elif self.index_type == "covid19-faba":
                records = transform_covid19_faba_data(
                    TaskSpec(
                        name="worker",
                        index=self.index_name,
                        sql=view_sql_file,
                        view=view_name,
                        base_table="financial_accounts_by_awards",
                        base_table_id="financial_accounts_by_awards_id",
                        field_for_es_id="financial_account_distinct_award_key",
                        primary_key="award_id",
                        partition_number=1,
                        is_incremental=False,
                    ),
                    records,
                )

        # The routing field is the same for every record, so resolve it once up front.
        routing_key = options.get("routing", settings.ES_ROUTING_FIELD)
        for record in records:
            routing_value = record.get(routing_key)

            # An explicit "_id" on the record wins (and is removed from the document body);
            # otherwise the per-type id field is used.
            if "_id" in record:
                es_id_value = record.pop("_id")
            else:
                es_id_value = record.get(es_id)

            self.client.index(
                index=self.index_name,
                body=json.dumps(record, cls=DjangoJSONEncoder),
                id=es_id_value,
                routing=routing_value,
            )
        # Force newly added documents to become searchable.
        self.client.indices.refresh(self.index_name)

    def _generate_index_name(self):
        """Build a unique index name: ``test-<UTC timestamp>-<random string>[-<type suffix>]``."""
        # NOTE(review): "location" has no suffix entry although other methods handle that type - confirm intended.
        suffix_setting_by_type = {
            "award": "ES_AWARDS_NAME_SUFFIX",
            "transaction": "ES_TRANSACTIONS_NAME_SUFFIX",
            "covid19-faba": "ES_COVID19_FABA_NAME_SUFFIX",
            "recipient": "ES_RECIPIENTS_NAME_SUFFIX",
        }
        suffix_setting = suffix_setting_by_type.get(self.index_type)
        required_suffix = ("-" + getattr(settings, suffix_setting)) if suffix_setting else ""
        return (
            f"test-{datetime.now(timezone.utc).strftime('%Y-%m-%d-%H-%M-%S-%f')}"
            f"-{generate_random_string()}"
            f"{required_suffix}"
        )


def ensure_broker_server_dblink_exists():
    """Run the SQL scripts that add the database extensions and register the broker DB as a foreign server.

    The scripts are read from ``database_scripts/extensions/extensions.sql`` and
    ``database_scripts/servers/broker_server.sql`` under ``settings.APP_DIR``. The broker server script is
    treated as a ``string.Template`` and filled in from the connection settings in ``settings.DATABASES``.
    The connection settings for BOTH the USAspending and the Broker databases are needed, as the postgres
    ``SERVER`` setup needs to reference tokens from both; they are exposed to the template with the
    ``USASPENDING_DB_`` and ``BROKER_DB_`` prefixes respectively.

    NOTE: An alternative way to run this through bash scripting is to first ``EXPORT`` the referenced environment
    variables, then run::

        psql --username ${USASPENDING_DB_USER} --dbname ${USASPENDING_DB_NAME} --echo-all --set ON_ERROR_STOP=on --set VERBOSITY=verbose --file usaspending_api/database_scripts/extensions/extensions.sql
        eval "cat <<< \"$(<usaspending_api/database_scripts/servers/broker_server.sql)\"" > usaspending_api/database_scripts/servers/broker_server.sql
        psql --username ${USASPENDING_DB_USER} --dbname ${USASPENDING_DB_NAME} --echo-all --set ON_ERROR_STOP=on --set VERBOSITY=verbose --file usaspending_api/database_scripts/servers/broker_server.sql

    Raises:
        Exception: If either the default or the ``data_broker`` database is missing from ``settings.DATABASES``.
    """
    # Both databases must be configured before any token can be gathered
    for required_alias in (DEFAULT_DB_ALIAS, "data_broker"):
        if required_alias not in settings.DATABASES:
            raise Exception(f"'{required_alias}' database not configured in django settings.DATABASES")

    # Gather tokens from database connection settings, namespaced per database
    db_conn_tokens_dict = {}
    for token_prefix, db_alias in (("USASPENDING_DB_", DEFAULT_DB_ALIAS), ("BROKER_DB_", "data_broker")):
        for setting_key, setting_value in settings.DATABASES[db_alias].items():
            db_conn_tokens_dict[token_prefix + setting_key] = setting_value

    scripts_dir = settings.APP_DIR / "database_scripts"
    extensions_script_path = str(scripts_dir / "extensions" / "extensions.sql")
    broker_server_script_path = str(scripts_dir / "servers" / "broker_server.sql")

    with open(extensions_script_path) as extensions_file, open(broker_server_script_path) as broker_server_file:
        extensions_script = extensions_file.read()
        broker_server_script = broker_server_file.read()

    with connection.cursor() as cursor:
        # Ensure required extensions added to USAspending db
        cursor.execute(extensions_script)

        # Ensure foreign server setup to point to the broker DB for dblink to work
        rendered_broker_server_script = Template(broker_server_script).substitute(**db_conn_tokens_dict)
        cursor.execute(rendered_broker_server_script)


def get_unittest_fake_sqs_queue(*args, **kwargs):
    """Stand-in for sqs_handler.get_sqs_queue that hands back a fake queue used for unit testing.

    A ``queue_name`` keyword equal to ``UNITTEST_FAKE_DEAD_LETTER_QUEUE_NAME`` yields a new fake dead letter
    queue; any other call yields the shared file-backed fake queue instance. Positional arguments are ignored.
    """
    wants_dead_letter_queue = (
        "queue_name" in kwargs and kwargs["queue_name"] == UNITTEST_FAKE_DEAD_LETTER_QUEUE_NAME
    )
    if not wants_dead_letter_queue:
        return _FakeUnitTestFileBackedSQSQueue.instance()
    return _FakeStatelessLoggingSQSDeadLetterQueue()


def remove_unittest_queue_data_files(queue_to_tear_down):
    """Remove the local files that persist or guard the file-backed queue data of
    ``_FakeUnittestFileBackedSQSQueue``: the ``.lock`` file first, then the queue data file itself.
    Files that do not exist are skipped.
    """
    # Refuse to delete anything unless the queue URL ends in the unit test queue's name
    queue_name_from_url = queue_to_tear_down.url.rsplit("/", 1)[-1]
    assert queue_name_from_url == UNITTEST_FAKE_QUEUE_NAME

    data_file = queue_to_tear_down._QUEUE_DATA_FILE
    for doomed_path in (Path(data_file + ".lock"), Path(data_file)):
        if doomed_path.exists():
            doomed_path.unlink()