src/skjold/sources/github.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
import json
import os
import urllib.request
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple

import click
from packaging import specifiers
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version

from skjold.core import Dependency, SecurityAdvisory, SecurityAdvisorySource
from skjold.tasks import register_source


class GithubSecurityAdvisory(SecurityAdvisory):
    _json: Dict

    @classmethod
    def using(cls, json_: dict) -> "GithubSecurityAdvisory":
        obj = cls()
        obj._json = json_
        return obj

    @property
    def identifier(self) -> str:
        return str(self.__advisory["ghsaId"])

    @property
    def source(self) -> str:
        return "github"

    @property
    def references(self) -> List[str]:
        return [reference["url"] for reference in self.__advisory["references"]]

    @property
    def package_name(self) -> str:
        return str(self._json["node"]["package"]["name"])

    @property
    def canonical_name(self) -> NormalizedName:
        return canonicalize_name(self.package_name)

    @property
    def ecosystem(self) -> str:
        return str(self._json["node"]["package"]["ecosystem"])

    @property
    def severity(self) -> str:
        return str(self._json["node"]["severity"])

    @property
    def __advisory(self) -> Any:
        return self._json["node"]["advisory"]

    @property
    def summary(self) -> str:
        return str(self.__advisory["summary"])

    @property
    def first_patched_version(self) -> str:
        return str(self._json["node"]["firstPatchedVersion"]["identifier"])

    @property
    def vulnerable_version_range(self) -> specifiers.SpecifierSet:
        items = self._json["node"]["vulnerableVersionRange"].split(",")
        if len(items) > 2:
            raise ValueError(f"Found more than 2 version specifiers!")

        vulnerable_ranges = []
        for value in items:
            value = value.strip()
            if value.startswith("= "):
                vulnerable_ranges.append(value.replace("= ", "=="))
            else:
                vulnerable_ranges.append(value)

        return specifiers.SpecifierSet(",".join(vulnerable_ranges), prereleases=True)

    @property
    def vulnerable_versions(self) -> str:
        return str(self.vulnerable_version_range)

    def is_affected(self, version: str) -> bool:
        version_ = Version(version)
        return version_ in self.vulnerable_version_range

    @property
    def url(self) -> str:
        return f"https://github.com/advisories/{self.identifier}"


def _query_github_graphql(
    first: int = 10, after: Optional[str] = None
) -> Tuple[int, str, bool, dict]:
    _after = after and f'"{after}"' or "null"
    _limit = first and int(first) or 1

    try:
        _api_token = os.environ["SKJOLD_GITHUB_API_TOKEN"]
    except KeyError:
        raise click.UsageError(
            "It looks like you're trying to use the github source without providing an access token."
            "Skjold needs a token to access the Github GraphQL API."
            "You can generate a token at https://github.com/settings/tokens without giving it any permissions "
            "and make it available to skjold by setting SKJOLD_GITHUB_API_TOKEN in your environment."
        )

    query = f"""
    {{
        securityVulnerabilities(first: {_limit}, after: {_after}, ecosystem: PIP, orderBy: {{ field: UPDATED_AT, direction: DESC }}) {{
            pageInfo {{
                startCursor
                hasNextPage
                endCursor
            }}
            totalCount
            edges {{
                node {{
                    advisory {{
                        ghsaId
                        publishedAt
                        references {{
                            url
                        }}
                        summary
                    }}
                    firstPatchedVersion {{
                        identifier
                    }}
                    package {{
                        ecosystem
                        name
                    }}
                    severity
                    updatedAt
                    vulnerableVersionRange
                }}
            }}
        }}
    }}
    """
    payload = json.dumps({"query": query}).encode("utf-8")
    request_ = urllib.request.Request(
        url="https://api.github.com/graphql",
        data=payload,
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {_api_token}",
            "Content-Type": "application/json; charset=utf-8",
        },
    )
    with urllib.request.urlopen(request_) as response:
        _data = json.loads(response.read())

    has_next = _data["data"]["securityVulnerabilities"]["pageInfo"]["hasNextPage"]
    cursor = _data["data"]["securityVulnerabilities"]["pageInfo"]["endCursor"]
    data = _data["data"]["securityVulnerabilities"]["edges"]
    total_count = int(_data["data"]["securityVulnerabilities"]["totalCount"])

    return total_count, cursor, has_next, data


def _fetch_github_security_advisories(
    limit: int = 100,
) -> Iterator[Tuple[int, str, bool, dict]]:
    total_count, cursor, has_next, data = _query_github_graphql(limit, None)
    yield from data

    while has_next:
        total_count, cursor, has_next, data = _query_github_graphql(limit, cursor)
        yield from data


class Github(SecurityAdvisorySource):
    _name = "github"

    @property
    def name(self) -> str:
        return self._name

    @property
    def total_count(self) -> int:
        return len(self._advisories.keys())

    def populate_from_cache(self) -> None:
        self._advisories = defaultdict(list)

        with open(self.path, "rb") as fh:
            for item in json.load(fh):
                obj = GithubSecurityAdvisory.using(item)
                self._advisories[obj.canonical_name].append(obj)

    @property
    def path(self) -> str:
        return os.path.join(self._cache_dir, "github.cache")

    def update(self) -> None:
        data = list(_fetch_github_security_advisories())
        with open(self.path, "w") as fh:
            json.dump(data, fh)

    def has_security_advisory_for(self, dependency: Dependency) -> bool:
        return dependency.canonical_name in self.advisories.keys()

    def is_vulnerable_package(
        self, dependency: Dependency
    ) -> Tuple[bool, List[SecurityAdvisory]]:
        advisories = []
        for candidate in self.advisories[dependency.canonical_name]:
            if candidate.is_affected(dependency.version):
                advisories.append(candidate)

        return len(advisories) > 0, advisories


register_source("github", Github)