Cog-Creators/Red-DiscordBot

View on GitHub
.github/workflows/scripts/merge_requirements.py

Summary

Maintainability
A
0 mins
Test Coverage
from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple

from packaging.markers import Marker
from packaging.requirements import Requirement


# This script lives in .github/workflows/scripts/, so parents[3] of its path is the
# repository root. Make the path absolute *before* walking up: with a relative
# ``__file__`` that has fewer than four components, ``parents[3]`` would raise
# IndexError.
REQUIREMENTS_FOLDER = Path(__file__).absolute().parents[3] / "requirements"
# The merged output files are later opened by bare file name, relative to this folder.
os.chdir(REQUIREMENTS_FOLDER)


class RequirementData:
    """A parsed requirement together with the sources it was annotated with.

    Equality and hashing delegate to the wrapped ``packaging`` ``Requirement``
    object; the collected ``comments`` take no part in either, so the same
    requirement read from different files compares equal.
    """

    def __init__(self, requirement_string: str) -> None:
        self.req = Requirement(requirement_string)
        # Sources taken from "# via" annotation lines; not part of eq/hash.
        self.comments: set[str] = set()

    def __hash__(self) -> int:
        # NOTE: ``self.req`` is mutable (see the ``marker`` setter). Changing it
        # while this instance is a dict key or set member changes its hash.
        return hash(self.req)

    def __eq__(self, other: object) -> bool:
        # Defer to Python's default handling for foreign types instead of
        # failing with AttributeError on ``other.req``.
        if not isinstance(other, RequirementData):
            return NotImplemented
        return self.req == other.req

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self.req)!r})"

    @property
    def name(self) -> str:
        """The project name of the wrapped requirement."""
        return self.req.name

    @property
    def marker(self) -> Marker:
        """The environment marker of the wrapped requirement (may be ``None``)."""
        return self.req.marker

    @marker.setter
    def marker(self, value: Marker) -> None:
        self.req.marker = value


def get_requirements(fp: TextIO) -> List[RequirementData]:
    """Parse a compiled requirements file into ``RequirementData`` objects.

    Each non-empty line that starts with neither ``#`` nor a space is parsed
    as a requirement.

    Lines starting with ``"    # "`` that appear after a requirement are
    annotations for it. The bare ``via`` header line is ignored. A leading
    ``via `` is stripped, and the rest is added to that requirement's
    ``comments``.

    All other lines are ignored.
    """
    annotation_prefix = "    # "
    via_prefix = "via "
    parsed = []
    latest = None

    for raw_line in fp.read().splitlines():
        if latest is not None and raw_line.startswith(annotation_prefix):
            text = raw_line[len(annotation_prefix) :].strip()
            if text == "via":
                continue
            if text.startswith(via_prefix):
                text = text[len(via_prefix) :]
            latest.comments.add(text)
            continue

        # Blank lines, top-level comments and stray indented lines carry no requirement.
        if not raw_line or raw_line.startswith(("#", " ")):
            continue

        latest = RequirementData(raw_line)
        parsed.append(latest)

    return parsed


def iter_envs(envs: Iterable[str]) -> Iterable[Tuple[str, str]]:
    """Lazily split each ``<platform>-<python_version>`` env name into a pair.

    Only the first ``-`` separates the two parts. A name containing no ``-``
    raises ``ValueError`` once iteration reaches it.
    """

    def _split_env(env_name: str) -> Tuple[str, str]:
        platform, python_version = env_name.split("-", maxsplit=1)
        return platform, python_version

    yield from map(_split_env, envs)


def _marker_for_values(variable: str, present: Iterable[str], universe: Iterable[str]) -> str:
    """Build a marker restricting ``variable`` to the ``present`` values of ``universe``.

    The form that lists fewer values is used:

    - an ``or``-chain of ``==`` clauses over the present values, or
    - an ``and``-chain of ``!=`` clauses over the missing values.

    Returns an empty string when ``present`` covers all of ``universe``, i.e.
    the variable is irrelevant. Values are sorted so the output is identical
    across runs.
    """
    present_values = set(present)
    missing_values = set(universe) - present_values
    if len(present_values) < len(missing_values):
        return " or ".join(f"{variable} == '{value}'" for value in sorted(present_values))
    return " and ".join(f"{variable} != '{value}'" for value in sorted(missing_values))


def _build_env_marker(
    present_envs: Iterable[str],
    all_envs: Iterable[str],
    all_platforms: Iterable[str],
    all_python_versions: Iterable[str],
) -> str:
    """Build a marker that holds on exactly ``present_envs`` out of ``all_envs``.

    Env names have the form ``<sys_platform>-<python_version>``. The caller
    handles the "present everywhere" case, so ``present_envs`` is assumed to be
    a strict subset of ``all_envs``.
    """
    present = set(present_envs)
    # {platform: [python_versions...]}
    python_versions_per_platform: Dict[str, List[str]] = {}
    # {python_version: [platforms...]}
    platforms_per_python_version: Dict[str, List[str]] = {}
    for platform_name, python_version in iter_envs(present):
        python_versions_per_platform.setdefault(platform_name, []).append(python_version)
        platforms_per_python_version.setdefault(python_version, []).append(platform_name)

    if (
        len(set(map(frozenset, python_versions_per_platform.values()))) == 1
        or len(set(map(frozenset, platforms_per_python_version.values()))) == 1
    ):
        # Either every present platform has the same Python version set, or every
        # present Python version has the same platform set. Either way the env set
        # is a cartesian product, so the platform and Python version conditions
        # can be generated separately and AND-ed together.
        python_version_marker = _marker_for_values(
            "python_version", platforms_per_python_version.keys(), all_python_versions
        )
        platform_marker = _marker_for_values(
            "sys_platform", python_versions_per_platform.keys(), all_platforms
        )
        if python_version_marker and platform_marker:
            return f"({python_version_marker}) and ({platform_marker})"
        return python_version_marker or platform_marker

    # Generic case: enumerate individual envs (included or excluded, whichever is fewer).
    missing = set(all_envs) - present
    if len(present) < len(missing):
        return " or ".join(
            f"(sys_platform == '{platform}' and python_version == '{python_version}')"
            for platform, python_version in iter_envs(sorted(present))
        )
    # Negation of (platform == P and version == V) is (platform != P OR version != V).
    return " and ".join(
        f"(sys_platform != '{platform}' or python_version != '{python_version}')"
        for platform, python_version in iter_envs(sorted(missing))
    )


def _write_requirements_file(filename: str, requirements: Iterable[RequirementData]) -> None:
    """Write ``requirements`` to ``filename``, each followed by its "# via" annotations."""
    with open(filename, "w", encoding="utf-8") as fp:
        for req in requirements:
            fp.write(f"{req.req}\n")
            comments = sorted(req.comments)
            if not comments:
                # Nothing to attribute; don't emit a dangling "# via" header.
                continue
            if len(comments) == 1:
                fp.write(f"    # via {comments[0]}\n")
            else:
                fp.write("    # via\n")
                for source in comments:
                    fp.write(f"    #   {source}\n")


names = ["base"]
names.extend(sorted(file.stem for file in REQUIREMENTS_FOLDER.glob("extra-*.in")))
# Filled while processing "base" (it doubles as that run's output list) and then
# consulted while processing each extra, so extras don't repeat base entries.
base_requirements: List[RequirementData] = []

for name in names:
    # {req_data: {env_name: RequirementData}}
    input_data: Dict[RequirementData, Dict[str, RequirementData]] = {}
    all_envs = set()
    all_platforms = set()
    all_python_versions = set()
    # Sorted so the result doesn't depend on filesystem listing order.
    for file in sorted(REQUIREMENTS_FOLDER.glob(f"*-{name}.txt")):
        platform_name, python_version, _ = file.stem.split("-", maxsplit=2)
        env_name = f"{platform_name}-{python_version}"
        all_envs.add(env_name)
        all_platforms.add(platform_name)
        all_python_versions.add(python_version)
        with file.open(encoding="utf-8") as fp:
            requirements = get_requirements(fp)

        for req in requirements:
            envs = input_data.setdefault(req, {})
            envs[env_name] = req

    output = base_requirements if name == "base" else []
    for req, envs in input_data.items():
        # Merge the "# via" sources collected from every env into the key instance.
        for other_req in envs.values():
            req.comments.update(other_req.comments)

        base_req = next(
            (candidate for candidate in base_requirements if candidate.name == req.name), None
        )
        if base_req is not None:
            # Compare with both markers stripped; a mismatch is a hard error.
            old_base_marker = base_req.marker
            old_req_marker = req.marker
            req.marker = base_req.marker = None
            if base_req.req != req.req:
                raise RuntimeError(f"Incompatible requirements for {req.name}.")

            base_req.marker = old_base_marker
            req.marker = old_req_marker
            if base_req.marker is None or base_req.marker == req.marker:
                continue

        if len(envs) == len(all_envs):
            # Present in every env: no environment marker needed.
            output.append(req)
            continue

        env_marker = _build_env_marker(envs, all_envs, all_platforms, all_python_versions)
        new_marker = f"({req.marker}) and ({env_marker})" if req.marker is not None else env_marker
        req.marker = Marker(new_marker)
        if base_req is not None and base_req.marker == req.marker:
            continue

        output.append(req)

    output.sort(key=lambda req: (req.marker is not None, req.name))
    _write_requirements_file(f"{name}.txt", output)