# elliotwutingfeng/Inversion-DNSBL-Generator
#
# View on GitHub
# modules/safebrowsing.py
#
# Summary
#
# Maintainability
# B
# 4 hrs
# Test Coverage
"""
Safe Browsing API helper class
"""

import asyncio
import base64
import itertools
import json
from collections.abc import Iterator

from dotenv import dotenv_values
from more_itertools import flatten
from more_itertools.more import chunked
from tqdm import tqdm  # type: ignore

from modules.utils.http_requests import get_async, post_async
from modules.utils.log import init_logger
from modules.utils.types import Vendors

# Parse `.env` a single time and share the result between vendors,
# instead of re-reading the file for every key lookup.
_DOTENV_VALUES = dotenv_values(".env")

# Safe Browsing API keys indexed by vendor name.
# A key that is absent from `.env`, or declared there without a value
# (which python-dotenv reports as None), is normalised to an empty string
# so that it is never rendered as the literal text "None" in a request URL.
SAFEBROWSING_API_KEYS = {
    "Google": _DOTENV_VALUES.get("GOOGLE_API_KEY") or "",
    "Yandex": _DOTENV_VALUES.get("YANDEX_API_KEY") or "",
}

# Module-wide logger shared by every SafeBrowsing instance.
logger = init_logger()


class SafeBrowsing:
    """
    Safe Browsing API helper class
    """

    def __init__(self, vendor: Vendors) -> None:
        """Initialize Safe Browsing API helper class
        for a given `vendor` (e.g. "Google", "Yandex" etc.)

        Args:
            vendor (Vendors): Safe Browsing API vendor name (e.g. "Google", "Yandex" etc.)

        Raises:
            ValueError: `vendor` must be "Google" or "Yandex"
        """
        self.vendor = vendor

        if vendor not in ("Google", "Yandex"):
            raise ValueError('vendor must be "Google" or "Yandex"')

        endpoint_prefixes = {
            "Google": "https://safebrowsing.googleapis.com/v4/",
            "Yandex": "https://sba.yandex.net/v4/",
        }

        self.threatMatchesEndpoint = f"{endpoint_prefixes[vendor]}threatMatches:find?key={SAFEBROWSING_API_KEYS[vendor]}"
        self.threatListsEndpoint = f"{endpoint_prefixes[vendor]}threatLists?key={SAFEBROWSING_API_KEYS[vendor]}"
        self.threatListUpdatesEndpoint = f"{endpoint_prefixes[vendor]}threatListUpdates:fetch?key={SAFEBROWSING_API_KEYS[vendor]}"
        self.fullHashesEndpoint = f"{endpoint_prefixes[vendor]}fullHashes:find?key={SAFEBROWSING_API_KEYS[vendor]}"
        self.maximum_url_batch_size = {"Google": 500, "Yandex": 200}[vendor]
        # Even though Yandex API docs states maximum batch size limit as 500
        # Tested absolute maximum is batch size 300 (but fails often)
        # Somewhat stable: batch size 200
        # ¯\_(ツ)_/¯

    # Safe Browsing Lookup API
    def _threat_matches_payload(
        self,
        url_list: list[str],
    ) -> dict:
        """For a given list of URLs,
        generate a POST request payload for Safe Browsing API threatMatches endpoint.

        Google API Reference
        https://developers.google.com/safe-browsing/v4/lookup-api

        Yandex API Reference
        https://yandex.com/dev/safebrowsing/doc/quickstart/concepts/lookup.html

        Args:
            url_list (list[str]): URLs to add to Safe Browsing API threatMatches payload

        Returns:
            dict: Safe Browsing API threatMatches payload
        """
        return {
            "client": {
                "clientId": "yourcompanyname",
                "clientVersion": "1.5.2",
            },
            "threatInfo": {
                "threatTypes": [
                    "THREAT_TYPE_UNSPECIFIED",
                    "MALWARE",
                    "SOCIAL_ENGINEERING",
                    "UNWANTED_SOFTWARE",
                    "POTENTIALLY_HARMFUL_APPLICATION",
                ],
                "platformTypes": [
                    "PLATFORM_TYPE_UNSPECIFIED",
                    "WINDOWS",
                    "LINUX",
                    "ANDROID",
                    "OSX",
                    "IOS",
                    "ANY_PLATFORM",
                    "ALL_PLATFORMS",
                    "CHROME",
                ],
                "threatEntryTypes": [
                    "THREAT_ENTRY_TYPE_UNSPECIFIED",
                    "URL",
                    "EXECUTABLE",
                ],
                "threatEntries": [{"url": f"http://{url}"} for url in url_list],
            },
        }

    async def _threat_matches_lookup(
        self, url_batches: Iterator[list[str]]
    ) -> list[dict]:
        """POST every batch of URLs to the Safe Browsing API threatMatches
        endpoint and return the decoded JSON responses.

        Args:
            url_batches (Iterator[list[str]]): Batches of URLs to submit
            to Safe Browsing API threatMatches endpoint for inspection

        Returns:
            list[dict]: List of each URL batch's
            Safe Browsing API threatMatches response
        """
        # One JSON-encoded request body per batch of URLs
        request_bodies: list[bytes] = [
            json.dumps(self._threat_matches_payload(batch)).encode()
            for batch in url_batches
        ]
        # Every request body is sent to the same threatMatches endpoint
        targets: list[str] = [self.threatMatchesEndpoint] * len(request_bodies)

        raw_responses = await post_async(
            targets, request_bodies, max_concurrent_requests=10
        )

        decoded: list[dict] = []
        # Each response is unpacked as a pair; only its second item (the body) is used
        for _, body in raw_responses:
            decoded.append(json.loads(body))
        return decoded

    def lookup_malicious_urls(self, urls: set[str]) -> list[str]:
        """Identify all URLs in a given set of `urls` deemed by Safe Browsing API to be malicious.

        URLs are submitted with an `http://` scheme prepended
        (see `_threat_matches_payload`), so the scheme is stripped again from
        every matched URL before it is returned. Only a leading scheme is
        removed; `http://` or `https://` occurring later in the URL
        (e.g. inside a query string) is left untouched. Matches that carry
        no URL are skipped.

        Args:
            urls (set[str]): URLs to be submitted to Safe Browsing API

        Returns:
            list[str]: URLs deemed by Safe Browsing API to be malicious,
            deduplicated and in no particular order
        """
        logger.info("Verifying suspected %s malicious URLs", self.vendor)
        # Split URLs into batches of at most maximum_url_batch_size
        url_batches = chunked(urls, self.maximum_url_batch_size)
        # Ceiling division: number of batches needed to cover all URLs
        logger.info("%d batches", -(-len(urls) // self.maximum_url_batch_size))

        results = asyncio.get_event_loop().run_until_complete(
            self._threat_matches_lookup(url_batches)
        )

        unique_malicious_urls: set[str] = set()
        for res in results:
            # Responses without a "matches" entry contribute nothing
            if "matches" not in res:
                continue
            for match in res["matches"]:
                url = match.get("threat", {}).get("url", "")
                # Remove at most one leading scheme
                for scheme in ("https://", "http://"):
                    if url.startswith(scheme):
                        url = url[len(scheme) :]  # noqa: E203
                        break
                if url:
                    unique_malicious_urls.add(url)

        malicious_urls = list(unique_malicious_urls)

        logger.info(
            "%d URLs confirmed to be marked malicious by %s Safe Browsing API.",
            len(malicious_urls),
            self.vendor,
        )

        return malicious_urls

    # Safe Browsing Update API
    def retrieve_url_threatlist_combinations(self) -> list[dict]:
        """GET names of currently available Safe Browsing lists from threatLists endpoint.

        For Google, threatlists with substrings "ALLOWLIST" or "WHITELIST" in
        their threatType are omitted, and only threatlists whose
        threatEntryType is "URL" or "IP_RANGE" are kept.

        For any other vendor (i.e. Yandex), a fixed set of four URL
        threatlist combinations is returned instead of the lists reported by
        the endpoint.

        See https://github.com/googleapis/google-api-go-client/blob/b16a2d3
        1763144ab92c4eba73aa5fcc5b418789d/safebrowsing/v4/safebrowsing-api.json#L591

        Returns:
            list[dict]: Names of currently available Safe Browsing lists from threatLists endpoint;
            empty if the endpoint response body is `b"{}"`
        """
        event_loop = asyncio.get_event_loop()
        responses = event_loop.run_until_complete(get_async([self.threatListsEndpoint]))
        response_body = responses[self.threatListsEndpoint]

        # A bare `{}` body is treated as "endpoint unreachable"
        # (presumably what get_async yields on failure; verify against get_async)
        if response_body == b"{}":
            return []

        available_threatlists = json.loads(response_body)["threatLists"]

        if self.vendor != "Google":
            # Yandex API will return status code 204 with no content
            # if url_threatlist_combinations is too large
            return [
                {
                    "threatType": "ANY",
                    "platformType": "ANY_PLATFORM",
                    "threatEntryType": "URL",
                    "state": "",
                },
                {
                    "threatType": "UNWANTED_SOFTWARE",
                    "threatEntryType": "URL",
                    "platformType": "PLATFORM_TYPE_UNSPECIFIED",
                    "state": "",
                },
                {
                    "threatType": "MALWARE",
                    "threatEntryType": "URL",
                    "platformType": "PLATFORM_TYPE_UNSPECIFIED",
                    "state": "",
                },
                {
                    "threatType": "SOCIAL_ENGINEERING",
                    "threatEntryType": "URL",
                    "platformType": "PLATFORM_TYPE_UNSPECIFIED",
                    "state": "",
                },
            ]

        allowlist_terms = ("ALLOWLIST", "WHITELIST")
        accepted_entry_types = ("URL", "IP_RANGE")
        selected = []
        for threatlist in available_threatlists:
            # Skip allowlists/whitelists
            if any(
                term in threatlist.get("threatType", "") for term in allowlist_terms
            ):
                continue
            if threatlist.get("threatEntryType", "") in accepted_entry_types:
                selected.append(threatlist)
        return selected

    def retrieve_threat_list_updates(
        self, url_threatlist_combinations: list[dict]
    ) -> dict:
        """Return threatListUpdates endpoint JSON response
        in Dictionary-form for all available lists.

        Google API Reference
        https://developers.google.com/safe-browsing/v4/update-api
        Yandex API Reference
        https://yandex.com/dev/safebrowsing/doc/quickstart/concepts/update-threatlist.html

        Args:
            url_threatlist_combinations (list[dict]): Names of currently available
            Safe Browsing lists from threatLists endpoint

        Returns:
            dict: Dictionary-form of Safe Browsing API threatListUpdates.fetch JSON response
            https://developers.google.com/safe-browsing/v4/reference/rest/v4/threatListUpdates/fetch
            Empty dict if `url_threatlist_combinations` is empty or if the
            response has no "listUpdateResponses" entry.
        """
        if not url_threatlist_combinations:
            return {}  # Nothing to request

        req_body = {
            "client": {
                "clientId": "yourcompanyname",
                "clientVersion": "1.5.2",
            },
            "listUpdateRequests": url_threatlist_combinations,
        }
        payload: bytes = json.dumps(req_body).encode()
        # Single request; take the body (second item) of the first response
        res = asyncio.get_event_loop().run_until_complete(
            post_async([self.threatListUpdatesEndpoint], [payload])
        )[0][1]

        # Expected keys: 'listUpdateResponses', 'minimumWaitDuration'
        res_json = json.loads(res)
        if "listUpdateResponses" not in res_json:
            return {}

        # `minimumWaitDuration` is optional in the API response; its absence
        # must not discard an otherwise valid set of list updates.
        logger.info(
            "Minimum wait duration: %s",
            res_json.get("minimumWaitDuration", "not specified"),
        )
        return res_json

    def get_malicious_url_hash_prefixes(self, threat_list_updates: dict) -> set[str]:
        """Extract b64 encoded malicious URL hash prefixes from a
        Safe Browsing API threatListUpdates.fetch response.

        Every addition's raw hashes arrive as one b64 encoded blob of
        concatenated hash prefixes, each `prefixSize` bytes long. The blob is
        decoded, cut into `prefixSize`-byte pieces, and every piece is
        re-encoded to b64 on its own.

        Hashes can be anywhere from 4 to 32 bytes in size. A large majority are 4 bytes,
        but some hashes are lengthened if they collide with the hash of a popular URL.

        Args:
            threat_list_updates (dict): Dictionary-form of Safe Browsing API
            threatListUpdates.fetch JSON response

        Returns:
            set[str]: b64 encoded malicious URL hash prefixes from Safe Browsing API
        """
        logger.info("Downloading %s malicious URL hash prefixes", self.vendor)
        if threat_list_updates == {}:
            logger.info(
                "Downloading %s malicious URL hash prefixes...[DONE:NO THREAT LISTS FOUND]",
                self.vendor,
            )
            return set()

        collected: set[str] = set()

        for update in tqdm(threat_list_updates["listUpdateResponses"]):
            for addition in update.get("additions", []):
                raw_section = addition.get("rawHashes", dict())
                prefix_size: int = raw_section.get("prefixSize", 0)
                # Skip additions without a usable (positive integer) prefix size
                if not (isinstance(prefix_size, int) and prefix_size > 0):
                    continue
                # Decode the b64 blob of concatenated hash prefixes
                blob = base64.b64decode(raw_section.get("rawHashes", ""))
                # Cut the blob into prefix_size-byte pieces, re-encoding each to b64
                for offset in range(0, len(blob), prefix_size):
                    piece = blob[offset : offset + prefix_size]  # noqa: E203
                    collected.add(base64.b64encode(piece).decode())

        logger.info("Downloading %s malicious URL hash prefixes...[DONE]", self.vendor)
        return collected

    def get_malicious_url_full_hashes(
        self,
        hash_prefixes: set[str],
        url_threatlist_combinations: list[dict],
    ) -> Iterator[str]:
        """Download latest malicious URL full hashes from Safe Browsing API.

        Hash prefixes are submitted to the fullHashes endpoint in batches of
        at most `maximum_url_batch_size`. The threat types, platform types and
        threat entry types requested are the distinct values found in
        `url_threatlist_combinations`.

        Args:
            hash_prefixes (set[str]): b64 encoded malicious URL hash prefixes from Safe Browsing API
            url_threatlist_combinations (list[dict]): Names of currently available
            Safe Browsing lists from threatLists endpoint

        Returns:
            Iterator[str]: b64 encoded malicious URL full hashes from Safe Browsing API;
            matches without a hash are skipped
        """
        logger.info("Downloading %s malicious URL full hashes", self.vendor)

        # These values are identical for every batch, so build them once
        # instead of once per request payload.
        client = {
            "clientId": "yourcompanyname",
            "clientVersion": "1.5.2",
        }
        threat_types = list({x["threatType"] for x in url_threatlist_combinations})
        platform_types = list({x["platformType"] for x in url_threatlist_combinations})
        threat_entry_types = list(
            {x["threatEntryType"] for x in url_threatlist_combinations}
        )

        payloads: list[bytes] = [
            json.dumps(
                {
                    "client": client,
                    "clientStates": [""],
                    "threatInfo": {
                        "threatTypes": threat_types,
                        "platformTypes": platform_types,
                        "threatEntryTypes": threat_entry_types,
                        "threatEntries": [
                            {"hash": hash_prefix} for hash_prefix in prefix_batch
                        ],
                    },
                }
            ).encode()
            for prefix_batch in chunked(hash_prefixes, self.maximum_url_batch_size)
        ]

        # Every payload is sent to the same fullHashes endpoint
        endpoints: list[str] = [self.fullHashesEndpoint] * len(payloads)
        responses: list[tuple] = asyncio.get_event_loop().run_until_complete(
            post_async(endpoints, payloads, max_concurrent_requests=10)  # type:ignore
        )
        logger.info("Downloading %s malicious URL full hashes...[DONE]", self.vendor)

        # The second item of each response is its body; responses without
        # "matches" contribute nothing.
        threat_matches: Iterator[dict] = flatten(
            json.loads(response[1]).get("matches", dict()) for response in responses
        )
        # Look each hash up once, then drop matches that carry no hash
        candidate_hashes = (
            match.get("threat", {}).get("hash", "") for match in threat_matches
        )
        full_hashes: Iterator[str] = (
            full_hash for full_hash in candidate_hashes if full_hash != ""
        )

        return full_hashes