.github/workflows/scripts/merge_requirements.py
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple
from packaging.markers import Marker
from packaging.requirements import Requirement
REQUIREMENTS_FOLDER = Path(__file__).parents[3].absolute() / "requirements"
os.chdir(REQUIREMENTS_FOLDER)
class RequirementData:
def __init__(self, requirement_string: str) -> None:
self.req = Requirement(requirement_string)
self.comments = set()
def __hash__(self) -> int:
return hash(self.req)
def __eq__(self, other: RequirementData) -> bool:
return self.req == other.req
@property
def name(self) -> str:
return self.req.name
@property
def marker(self) -> Marker:
return self.req.marker
@marker.setter
def marker(self, value: Marker) -> None:
self.req.marker = value
def get_requirements(fp: TextIO) -> List[RequirementData]:
requirements = []
current = None
for line in fp.read().splitlines():
annotation_prefix = " # "
if line.startswith(annotation_prefix) and current is not None:
source = line[len(annotation_prefix) :].strip()
if source == "via":
continue
via_prefix = "via "
if source.startswith(via_prefix):
source = source[len(via_prefix) :]
current.comments.add(source)
elif line and not line.startswith(("#", " ")):
current = RequirementData(line)
requirements.append(current)
return requirements
def iter_envs(envs: Iterable[str]) -> Iterable[Tuple[str, str]]:
for env_name in envs:
platform, python_version = env_name.split("-", maxsplit=1)
yield (platform, python_version)
names = ["base"]
names.extend(file.stem for file in REQUIREMENTS_FOLDER.glob("extra-*.in"))
base_requirements: List[RequirementData] = []
for name in names:
# {req_data: {sys_platform: RequirementData}
input_data: Dict[RequirementData, Dict[str, RequirementData]] = {}
all_envs = set()
all_platforms = set()
all_python_versions = set()
for file in REQUIREMENTS_FOLDER.glob(f"*-{name}.txt"):
platform_name, python_version, _ = file.stem.split("-", maxsplit=2)
env_name = f"{platform_name}-{python_version}"
all_envs.add(env_name)
all_platforms.add(platform_name)
all_python_versions.add(python_version)
with file.open(encoding="utf-8") as fp:
requirements = get_requirements(fp)
for req in requirements:
envs = input_data.setdefault(req, {})
envs[env_name] = req
output = base_requirements if name == "base" else []
for req, envs in input_data.items():
# {platform: [python_versions...]}
python_versions_per_platform: Dict[str, List[str]] = {}
# {python_version: [platforms...]}
platforms_per_python_version: Dict[str, List[str]] = {}
platforms = python_versions_per_platform.keys()
python_versions = platforms_per_python_version.keys()
for env_name, other_req in envs.items():
platform_name, python_version = env_name.split("-", maxsplit=1)
python_versions_per_platform.setdefault(platform_name, []).append(python_version)
platforms_per_python_version.setdefault(python_version, []).append(platform_name)
req.comments.update(other_req.comments)
base_req = next(
(base_req for base_req in base_requirements if base_req.name == req.name), None
)
if base_req is not None:
old_base_marker = base_req.marker
old_req_marker = req.marker
req.marker = base_req.marker = None
if base_req.req != req.req:
raise RuntimeError(f"Incompatible requirements for {req.name}.")
base_req.marker = old_base_marker
req.marker = old_req_marker
if base_req.marker is None or base_req.marker == req.marker:
continue
if len(envs) == len(all_envs):
output.append(req)
continue
# At this point I'm wondering why I didn't just go for
# a more generic boolean algebra simplification (sympy.simplify_logic())...
if (
len(set(map(frozenset, python_versions_per_platform.values()))) == 1
or len(set(map(frozenset, platforms_per_python_version.values()))) == 1
):
# Either all platforms have the same Python version set
# or all Python versions have the same platform set.
# We can generate markers for platform (platform_marker) and Python
# (python_version_marker) version sets separately and then simply require
# that both markers are fulfilled at the same time (env_marker).
python_version_marker = (
# Requirement present on less Python versions than not.
" or ".join(
f"python_version == '{python_version}'" for python_version in python_versions
)
if len(python_versions) < len(all_python_versions - python_versions)
# Requirement present on more Python versions than not
# This may generate an empty string when Python version is irrelevant.
else " and ".join(
f"python_version != '{python_version}'"
for python_version in all_python_versions - python_versions
)
)
platform_marker = (
# Requirement present on less platforms than not.
" or ".join(f"sys_platform == '{platform}'" for platform in platforms)
if len(platforms) < len(all_platforms - platforms)
# Requirement present on more platforms than not
# This may generate an empty string when platform is irrelevant.
else " and ".join(
f"sys_platform != '{platform}'" for platform in all_platforms - platforms
)
)
if python_version_marker and platform_marker:
env_marker = f"({python_version_marker}) and ({platform_marker})"
else:
env_marker = python_version_marker or platform_marker
else:
# Fallback to generic case.
env_marker = (
# Requirement present on less envs than not.
" or ".join(
f"(sys_platform == '{platform}' and python_version == '{python_version}')"
for platform, python_version in iter_envs(envs)
)
if len(envs) < len(all_envs - envs.keys())
else " and ".join(
f"(sys_platform != '{platform}' and python_version != '{python_version}')"
for platform, python_version in iter_envs(all_envs - envs.keys())
)
)
new_marker = f"({req.marker}) and ({env_marker})" if req.marker is not None else env_marker
req.marker = Marker(new_marker)
if base_req is not None and base_req.marker == req.marker:
continue
output.append(req)
output.sort(key=lambda req: (req.marker is not None, req.name))
with open(f"{name}.txt", "w+", encoding="utf-8") as fp:
for req in output:
fp.write(str(req.req))
fp.write("\n")
comments = sorted(req.comments)
if len(comments) == 1:
source = comments[0]
fp.write(" # via ")
fp.write(source)
fp.write("\n")
else:
fp.write(" # via\n")
for source in comments:
fp.write(" # ")
fp.write(source)
fp.write("\n")