fedspendingtransparency/usaspending-api

View on GitHub
usaspending_api/references/v2/views/filter_tree/psc_filter_tree.py

Summary

Maintainability
A
1 hr
Test Coverage
A
92%
import re

from django.db.models import Q
from string import ascii_uppercase, digits
from usaspending_api.references.models import PSC
from usaspending_api.references.v2.views.filter_tree.filter_tree import FilterTree


# Top-tier PSC groups, keyed by the name used as the top-tier node id.
# Each group carries:
#   "pattern"        - regex applied to the codes sitting directly under the group
#   "count_pattern"  - regex selecting the four-character codes counted for the group
#   "terms"          - the leading characters that identify a code as part of the group
#   "expanded_terms" - the same terms, each wrapped in its own single-element list
PSC_GROUPS = {
    # Codes beginning with "A"
    "Research and Development": {
        "pattern": r"^A.$",
        "count_pattern": r"^A...$",
        "terms": ["A"],
        "expanded_terms": [["A"]],
    },
    # Codes beginning with "B" through "Z" (every uppercase letter except "A")
    "Service": {
        "pattern": r"^[B-Z]$",
        "count_pattern": r"^[B-Z][A-Z0-9][A-Z0-9][A-Z0-9]$",
        "expanded_terms": [[letter] for letter in ascii_uppercase[1:]],
        "terms": list(ascii_uppercase[1:]),
    },
    # Codes beginning with a digit, "0" through "9"
    "Product": {
        "pattern": r"^\d\d$",
        "count_pattern": r"^\d\d\d\d$",
        "expanded_terms": [[digit] for digit in digits],
        "terms": list(digits),
    },
}


class PSCFilterTree(FilterTree):
    """Filter tree over Product and Service Codes (PSC).

    The top tier consists of the group names in ``PSC_GROUPS``; the tiers
    beneath it are ``PSC`` rows selected by code length and prefix. Every node
    is a dict with the keys ``id``, ``ancestors``, ``description``, ``count``
    and ``children`` (``None`` until children are attached).
    """

    @staticmethod
    def _make_node(node_id, ancestors, description, count):
        """Build a single tree node dict; ``children`` always starts as ``None``."""
        return {
            "id": node_id,
            "ancestors": ancestors,
            "description": description,
            "count": count,
            "children": None,
        }

    def raw_search(self, tiered_keys, child_layers, filter_string):
        """Return the nodes found beneath the path described by ``tiered_keys``.

        Args:
            tiered_keys: Path from the top tier downwards (group name first,
                then progressively longer codes). An empty list means the root.
            child_layers: How many layers below the path to include; ``-1`` is
                treated as 3.
            filter_string: Optional text matched case-insensitively against PSC
                codes and descriptions.

        Returns:
            A list of node dicts for the layer directly beneath the path, with
            deeper layers nested under ``children``. An empty list is returned
            when the path fails ``_path_is_valid``.
        """
        if not self._path_is_valid(tiered_keys):
            return []
        top = len(tiered_keys)
        bottom = (child_layers if child_layers != -1 else 3) + top

        # Each name gets its own list so that none of them alias one another.
        retval, tier3_nodes, tier2_nodes, tier1_nodes = [], [], [], []

        # With a two-element path, a "Product" group or an "AU" parent has the
        # length-4 codes directly beneath it, so tier 3 is what gets returned.
        parent_holds_tier3 = top == 2 and (tiered_keys[0] == "Product" or tiered_keys[1] == "AU")

        # Search from the deepest tier upwards: each higher tier is told which
        # lower-tier nodes were found so that their parents are included.
        if bottom >= 3 or parent_holds_tier3:
            tier3_nodes = self.tier_3_search(tiered_keys, filter_string)
        if bottom >= 2:
            tier2_nodes = self.tier_2_search(tiered_keys, filter_string, tier3_nodes)
        if bottom >= 1:
            tier1_nodes = self.tier_1_search(tiered_keys, filter_string, tier2_nodes)

        # Assemble the tree bottom-up. The order of the _combine_nodes calls
        # matters because a later call may replace the children set by an
        # earlier one (see _combine_nodes).
        if top == 3:
            retval = tier3_nodes
        if top <= 2:
            tier2_nodes = self._combine_nodes(tier2_nodes, tier3_nodes)
            if top == 2:
                retval = tier3_nodes if parent_holds_tier3 else tier2_nodes
        if top <= 1:
            tier1_nodes = self._combine_nodes(tier1_nodes, tier2_nodes)
            if top == 1:
                retval = tier1_nodes
        if top == 0:
            toptier_nodes = self.toptier_search(filter_string, tier1_nodes + tier2_nodes)
            toptier_nodes = self._combine_nodes(toptier_nodes, tier2_nodes)
            toptier_nodes = self._combine_nodes(toptier_nodes, tier1_nodes)
            retval = toptier_nodes

        return retval

    def tier_3_search(self, ancestor_array, filter_string) -> list:
        """Return nodes for the four-character PSC codes, sorted by code.

        Args:
            ancestor_array: Path leading to the requested layer. Only its last
                element is used: when longer than three characters it is looked
                up as a ``PSC_GROUPS`` name and that group's ``count_pattern``
                is applied, otherwise it is used as a code prefix.
            filter_string: Optional text that the code or the description must
                contain (case-insensitive).

        Returns:
            A list of node dicts whose ``count`` is always 0.
        """
        filters = [Q(length=4)]
        if ancestor_array:
            parent = ancestor_array[-1]
            if len(parent) > 3:
                # An unknown group name falls back to "(?!)", an empty negative
                # lookahead, which is presumably meant to match no code at all.
                filters.append(Q(code__iregex=PSC_GROUPS.get(parent, {}).get("count_pattern") or "(?!)"))
            else:
                filters.append(Q(code__startswith=parent))
        if filter_string:
            filters.append(Q(code__icontains=filter_string) | Q(description__icontains=filter_string))

        rnd_terms = PSC_GROUPS["Research and Development"]["terms"]
        nodes = []
        for psc in PSC.objects.filter(*filters):
            code = psc.code
            if code.isdigit():
                ancestors = ["Product", code[:2]]
            elif code[0] in rnd_terms:
                ancestors = ["Research and Development", code[:2]]
                # `AU` is a special case, it skips the length=3 codes, unlike other R&D PSCs
                if code[:2] != "AU":
                    ancestors.append(code[:3])
            else:
                ancestors = ["Service", code[:1], code[:2]]
            nodes.append(self._make_node(code, ancestors, psc.description, 0))
        return sorted(nodes, key=lambda node: node["id"])

    def tier_2_search(self, ancestor_array, filter_string, lower_tier_nodes=None) -> list:
        """Return nodes for the layer directly above the four-character codes.

        Candidate rows are two-character codes that do not start with the
        Research and Development term, plus three-character codes that do.
        A candidate is kept when it matches ANY of: the parent taken from
        ``ancestor_array``, the parent code of one of ``lower_tier_nodes``, or
        ``filter_string``. When none of those is supplied, every candidate is
        kept.

        Args:
            ancestor_array: Path leading to the requested layer; only its last
                element is used (group name if longer than three characters,
                otherwise a code prefix).
            filter_string: Optional text matched case-insensitively against
                code and description.
            lower_tier_nodes: Nodes from ``tier_3_search`` whose parents must
                be included in the result.

        Returns:
            A list of node dicts sorted by code.
        """
        rnd_terms = PSC_GROUPS["Research and Development"]["terms"]
        rnd_prefix = rnd_terms[0]
        filters = [
            (Q(length=2) & ~Q(code__startswith=rnd_prefix)) | (Q(length=3) & Q(code__startswith=rnd_prefix))
        ]

        query = Q()
        if ancestor_array:
            parent = ancestor_array[-1]
            if len(parent) > 3:
                query |= Q(code__iregex=PSC_GROUPS.get(parent, {}).get("pattern") or "(?!)")
            else:
                query |= Q(code__startswith=parent)
        if lower_tier_nodes:
            # Parent code of every lower-tier node, de-duplicated in first-seen order.
            lower_tier_codes = list(
                dict.fromkeys(
                    node["id"][:2]
                    # `AU` is a special case, it skips the length=3 codes, unlike other R&D PSCs
                    if node["id"][:2] == "AU" or node["id"][0] not in rnd_terms
                    else node["id"][:3]
                    for node in lower_tier_nodes
                )
            )
            query |= Q(code__in=lower_tier_codes)
        if filter_string:
            query |= Q(code__icontains=filter_string) | Q(description__icontains=filter_string)
        filters.append(query)

        nodes = []
        for psc in PSC.objects.filter(*filters):
            code = psc.code
            if code.isdigit():
                ancestors = ["Product"]
            elif code[0] in rnd_terms:
                ancestors = ["Research and Development", code[:2]]
            else:
                ancestors = ["Service", code[:1]]
            nodes.append(self._make_node(code, ancestors, psc.description, self.get_count([code], code)))
        return sorted(nodes, key=lambda node: node["id"])

    def tier_1_search(self, ancestor_array, filter_string, lower_tier_nodes=None) -> list:
        """Return nodes for the layer directly beneath the top-tier groups.

        Candidate rows are one-character codes listed in the Service terms,
        plus all two-character codes. When ``ancestor_array`` is given, its
        first element is looked up as a ``PSC_GROUPS`` name and the group's
        ``pattern`` must match. On top of that, a row is kept when it is the
        parent of one of ``lower_tier_nodes`` OR matches ``filter_string``;
        when neither is supplied no further restriction applies.

        Args:
            ancestor_array: Path leading to the requested layer; only its first
                element (the group name) is used.
            filter_string: Optional text matched case-insensitively against
                code and description.
            lower_tier_nodes: Nodes from ``tier_2_search``; each node's id
                minus its last character is treated as its parent code.

        Returns:
            A list of node dicts sorted by code.
        """
        filters = [(Q(length=1) & Q(code__in=PSC_GROUPS["Service"]["terms"])) | Q(length=2)]
        if ancestor_array:
            parent = ancestor_array[0]
            filters.append(Q(code__iregex=PSC_GROUPS.get(parent, {}).get("pattern") or "(?!)"))

        query = Q()
        if lower_tier_nodes:
            lower_tier_codes = list(dict.fromkeys(node["id"][:-1] for node in lower_tier_nodes))
            query |= Q(code__in=lower_tier_codes)
        if filter_string:
            query |= Q(code__icontains=filter_string) | Q(description__icontains=filter_string)
        filters.append(query)

        rnd_terms = PSC_GROUPS["Research and Development"]["terms"]
        nodes = []
        for psc in PSC.objects.filter(*filters):
            code = psc.code
            if code.isdigit():
                ancestors = ["Product"]
            elif code[0] in rnd_terms:
                ancestors = ["Research and Development"]
            else:
                ancestors = ["Service"]
            nodes.append(self._make_node(code, ancestors, psc.description, self.get_count([code], code)))
        return sorted(nodes, key=lambda node: node["id"])

    def toptier_search(self, filter_string, tier1_nodes=None):
        """Return the top-tier group nodes.

        Args:
            filter_string: Active search text, if any.
            tier1_nodes: Lower-tier nodes already found (``raw_search`` passes
                the tier-1 and tier-2 nodes concatenated).

        Returns:
            With neither a filter nor lower-tier nodes, one node per
            ``PSC_GROUPS`` entry in dictionary order. Otherwise only the groups
            whose ``terms`` contain the first character of at least one given
            node id, sorted by group name; this is an empty list when there is
            a filter but no lower-tier nodes.
        """
        if not filter_string and not tier1_nodes:
            return [self._make_node(key, [], "", self.get_count([], key)) for key in PSC_GROUPS]

        nodes = []
        if tier1_nodes:
            # Built once; the original rebuilt this set for every group.
            leading_chars = {node["id"][:1] for node in tier1_nodes}
            for key, group in PSC_GROUPS.items():
                if leading_chars.intersection(group["terms"]):
                    nodes.append(self._make_node(key, [], "", self.get_count([], key)))
        return sorted(nodes, key=lambda node: node["id"])

    def _combine_nodes(self, upper_tier, lower_tier):
        """Attach ``lower_tier`` nodes to their parents in ``upper_tier``.

        A lower node is attached to an upper node when the upper node's id
        appears in the lower node's ``ancestors`` and the lower node is not
        already one of the upper node's children. When at least one such node
        is found, the upper node's ``children`` is REPLACED by the sorted list
        of newly found nodes (previous children are not kept); otherwise
        ``children`` is left untouched.

        Returns:
            ``upper_tier`` itself, with its node dicts mutated in place.
        """
        for upper_node in upper_tier:
            current_children = upper_node["children"]
            known_ids = {child["id"] for child in current_children} if current_children is not None else set()
            new_children = [
                lower_node
                for lower_node in lower_tier
                if upper_node["id"] in lower_node["ancestors"] and lower_node["id"] not in known_ids
            ]
            if new_children:
                upper_node["children"] = sorted(new_children, key=lambda node: node["id"])
        return upper_tier

    def _path_is_valid(self, path: list) -> bool:
        """Check that ``path`` is a consistent chain of tier keys.

        Paths with fewer than two elements are always accepted. Otherwise the
        first element must be a ``PSC_GROUPS`` name, the second must match that
        group's ``pattern``, and every later element must start with the
        element before it.
        """
        if len(path) > 1:
            group = PSC_GROUPS.get(path[0])
            if group is None or not re.match(group["pattern"], path[1]):
                return False
            # Compare each code (from the second element on) with its successor.
            return all(child.startswith(parent) for parent, child in zip(path[1:], path[2:]))
        return True

    def get_count(self, tiered_keys: list, id) -> int:
        """Count the four-character PSC codes that fall under a node.

        Args:
            tiered_keys: When empty, ``id`` is treated as a ``PSC_GROUPS`` name
                and the group's ``count_pattern`` selects the rows. Otherwise
                ``id`` is treated as a code prefix.
            id: Group name or code prefix, depending on ``tiered_keys``.

        Returns:
            The number of matching ``PSC`` rows.
        """
        if not tiered_keys:
            pattern = PSC_GROUPS.get(id, {}).get("count_pattern") or "(?!)"
            return PSC.objects.filter(Q(code__iregex=pattern)).count()
        return PSC.objects.filter(Q(length=4), Q(code__startswith=id)).count()