scripts/stubgen.py
#!/usr/bin/env python3
"""Generate stub files for API classes that expose both an async and a sync interface."""
from __future__ import annotations
import re
import sys
from contextlib import suppress
from copy import copy
from pathlib import Path
from mypy.nodes import ConditionalExpr, Expression, ListExpr
from mypy.stubgen import (
ASTStubGenerator,
Options,
StubSource,
collect_build_targets,
generate_asts_for_modules,
generate_guarded,
mypy_options,
)
# Banner written at the top of every generated stub file: the text tells readers the file is
# machine-written and must not be edited by hand, and the "isort:skip_file" action comment
# tells isort to leave the file alone.
HEADER = (
    '"""\n'
    "@generated by stubgen. Do not edit manually!\n"
    "isort:skip_file\n"
    '"""\n'
)
class ApiStubGenerator(ASTStubGenerator):
    """Generate stub text from a mypy AST.

    On top of mypy's generator this infers types of conditional expressions and list literals,
    rewrites simple ``Union[...]`` annotations to ``a | b`` syntax and adds sync counterparts of
    ``async`` methods.
    """

    def get_str_type_of_node(self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True) -> str:
        """Get type of node as string.

        :param rvalue: Expression to infer the type of
        :param can_infer_optional: Passed through to recursive calls and to mypy's own inference
        :param can_be_any: Whether a catch-all type may be returned when the type cannot be inferred
        :return: The inferred type as string (may be empty if ``can_be_any`` is False)
        """
        if isinstance(rvalue, ConditionalExpr):
            # Infer both branches strictly, so an unknown branch yields "" instead of a catch-all type.
            if_type = self.get_str_type_of_node(rvalue.if_expr, can_infer_optional=can_infer_optional, can_be_any=False)
            else_type = self.get_str_type_of_node(rvalue.else_expr, can_infer_optional=can_infer_optional, can_be_any=False)
            if if_type and else_type and if_type != else_type:
                return f"{if_type} | {else_type}"
            # The parentheses matter: the conditional must only guard the final fallback. Unparenthesized it binds
            # to the whole expression and discards inferred types whenever can_be_any is False.
            # NOTE(review): "Any" is emitted without registering a typing import here - confirm the stub imports it.
            return if_type or else_type or ("Any" if can_be_any else "")
        if isinstance(rvalue, ListExpr):
            item_types = [
                self.get_str_type_of_node(item, can_infer_optional=can_infer_optional, can_be_any=can_be_any)
                for item in rvalue.items
            ]
            # dict.fromkeys de-duplicates in first-seen order. A set's iteration order of strings varies between
            # interpreter runs (hash randomization), which made the generated stub non-deterministic.
            unique_types = [item_type for item_type in dict.fromkeys(item_types) if item_type]
            if unique_types:
                return f"list[{' | '.join(unique_types)}]"
            # Empty list or no inferable item type: "list[]" would be invalid, so defer to mypy.
        return super().get_str_type_of_node(rvalue, can_infer_optional=can_infer_optional, can_be_any=can_be_any)

    def add_sync(self) -> None:
        """Add sync methods.

        For every generated line mentioning ``async``, append a copy with ``async_`` and ``async `` removed,
        e.g. ``async def async_foo(...)`` becomes ``def foo(...)``.
        """
        # Iterate over a snapshot because self.add() appends to self._output.
        for line in copy(self._output):
            if "async" not in line:
                continue
            sync_line = line.replace("async_", "").replace("async ", "")
            # A line the rewrite does not change would only be emitted a second time verbatim,
            # i.e. as a duplicate definition.
            if sync_line != line:
                self.add(sync_line)

    def fix_union_annotations(self) -> None:
        """Fix Union annotations.

        Rewrite every ``Union[a, b]`` of lowercase type names in the generated lines to ``a | b``.
        """
        for index, line in enumerate(self._output):
            # re.sub rewrites every Union on the line, not only the first one found.
            self._output[index] = re.sub(r"Union\[([a-z, ]+)\]", lambda match: match[1].replace(",", " |"), line)
def generate_stubs() -> None:
    """Generate stubs - main entry point for the program.

    Writes a ``.pyi`` stub for each configured API module. Every module is processed before exiting, so a single
    run brings all stubs up to date. The process exits with status 1 if any stub file was rewritten.
    """
    options = Options(
        pyversion=sys.version_info[:2],
        no_import=True,
        inspect=False,
        doc_dir="",
        search_path=[],
        interpreter=sys.executable,
        parse_only=False,
        ignore_errors=False,
        include_private=False,
        output_dir="",
        modules=[],
        packages=[],
        files=["devolo_plc_api/device_api/deviceapi.py", "devolo_plc_api/plcnet_api/plcnetapi.py"],
        verbose=False,
        quiet=True,
        export_less=True,
        include_docstrings=False,
    )
    mypy_opts = mypy_options(options)
    py_modules, _, _ = collect_build_targets(options, mypy_opts)
    generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose)
    any_changed = False
    for mod in py_modules:
        # Dotted module name -> path of the stub, e.g. a.b.c -> <output_dir>/a/b/c.pyi
        target = str(Path(options.output_dir) / f"{mod.module.replace('.', '/')}.pyi")
        with generate_guarded(mod.module, target, options.ignore_errors, options.verbose):
            changed = generate_stub_from_ast(
                mod, target, options.parse_only, options.include_private, options.export_less, exit_on_change=False
            )
            any_changed = any_changed or changed
    # Signal outdated stubs only after all of them have been regenerated.
    if any_changed:
        sys.exit(1)


def generate_stub_from_ast(
    mod: StubSource,
    target: str,
    parse_only: bool,
    include_private: bool,
    export_less: bool,
    exit_on_change: bool = True,
) -> bool:
    """Use analyzed (or just parsed) AST to generate type stub for single file.

    :param mod: Module to generate the stub for; nothing happens if it has no AST
    :param target: Path of the stub file to compare against and (re)write
    :param parse_only: True if the AST was only parsed, not semantically analyzed
    :param include_private: Passed to the stub generator
    :param export_less: Passed to the stub generator
    :param exit_on_change: Exit the process with status 1 right after rewriting ``target`` (the original
        behaviour). Pass False to receive the result as return value instead.
    :return: True if ``target`` was rewritten, False if it was up to date or the module has no AST
    """
    if mod.ast is None:
        return False
    gen = ApiStubGenerator(mod.runtime_all, include_private=include_private, analyzed=not parse_only, export_less=export_less)
    mod.ast.accept(gen)
    if "annotations" in mod.ast.future_import_flags:
        gen.add_import_line("from __future__ import annotations\n")
    gen.fix_union_annotations()
    gen.add_sync()
    new_output = HEADER + "".join(gen.output())

    target_path = Path(target)
    try:
        old_output = target_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        old_output = ""
    if new_output == old_output:
        return False
    target_path.parent.mkdir(parents=True, exist_ok=True)
    target_path.write_text(new_output, encoding="utf-8")
    if exit_on_change:
        sys.exit(1)
    return True
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    generate_stubs()