thesadru/apimodel

View on GitHub
apimodel/generator.py

Summary

Maintainability
C
1 day
Test Coverage
A
90%
"""Model generator from JSON data."""
from __future__ import annotations

import sys
import typing

from . import parser, tutils, utility

__all__ = ["generate_models"]

JSONType = typing.Union[None, str, int, float, bool, typing.Sequence["JSONType"], typing.Mapping[str, "JSONType"]]
RawSchema = typing.Union[
    str,
    typing.Tuple["RawSchema", ...],
    typing.List["RawSchema"],
    typing.Mapping[str, "RawSchema"],
]

MaybeUnion = tutils.MaybeSequence[str]
Field = typing.TypedDict("Field", {"alias": str, "type": MaybeUnion, "default": object, "array": bool}, total=False)
Schema = typing.Mapping[str, Field]

VersionInfo = typing.Union[typing.Tuple[int, ...], "sys._version_info"]

T = typing.TypeVar("T")


def to_pascal_case(string: str) -> str:
    """Turn snake case into pascal case."""
    return "".join(x[:1].upper() + x[1:] for x in string.split("_"))


def to_snake_case(string: str) -> str:
    """Turn camel case into snake case."""
    return "".join(
        ("_" if i and string[i].isupper() and not string[i : i + 2].isupper() else "") + x.lower()  # noqa: E203
        for i, x in enumerate(string)
    )


def format_field_type(
    field: typing.Union[str, Field],
    *,
    python: typing.Optional[VersionInfo] = None,
) -> str:
    """Format a field type."""
    if isinstance(field, str):
        return field

    assert "type" in field, "Incomplete Field"
    python = python or sys.version_info

    types: typing.Sequence[str]
    optional: bool = False

    if not isinstance(field["type"], str):  # union
        types, old_types = [x for x in field["type"] if "None" not in x], field["type"]
        if len(types) != len(old_types):
            optional = True
    elif field["type"] == "None":
        types = ("object",)
        optional = True
    else:
        types = (field["type"],)

    if python >= (3, 10):
        annotation = " | ".join(types)
        if optional:
            annotation += " | None"
    else:
        if len(types) == 1:
            annotation = types[0]
        else:
            annotation = f"typing.Union[{', '.join(types)}]"

        if optional:
            annotation = f"typing.Optional[{annotation}]"

    if field.get("array", False):
        if python >= (3, 9):
            annotation = f"list[{annotation}]"
        else:
            annotation = f"typing.Sequence[{annotation}]"

    return annotation


def format_field_default(field: Field) -> str:
    """Format a field default."""
    data = {k: v for k, v in field.items() if k not in ("type", "array", "default")}

    if "default" in field and not data:
        return str(field["default"])

    args: typing.Sequence[str] = []
    if "default" in field:
        args = [str(field["default"])]

    args += [f"{k}={v}" for k, v in data.items()]

    if not args:
        return ""

    return f"apimodel.Field({', '.join(args)})"


def join_mappings(mappings: typing.Sequence[typing.Mapping[str, T]]) -> typing.Mapping[str, tutils.MaybeTuple[T]]:
    """Join a sequence of mappings."""
    output: typing.Mapping[str, tutils.MaybeTuple[T]] = {}
    for index, mapping in enumerate(mappings):
        for k, value in mapping.items():
            # do not set missing values in the first run
            if index == 0:
                output[k] = value
                continue

            new_value = join_union(output.get(k, "None"), value)

            # if the current value should be an array and the old value was an array / missing
            # cast the new value to an array to prevent confusion with a union
            if isinstance(value, list) and (k not in output or isinstance(output[k], list)):
                if not tutils.generic_isinstance(new_value, typing.Sequence[object]):
                    new_value = [new_value]

                new_value = list(new_value)

            output[k] = new_value

    return output


def join_union(*raw_values: T) -> tutils.MaybeTuple[T]:
    """Join a union with an emphasis on mappings.

    Return a tuple if there are multiple values.
    """
    values = utility.flatten_sequences(raw_values)
    values = tuple({repr(tp): tp for tp in values}.values())
    if "float" in values:  # pyright: ignore[reportUnnecessaryContains]  # pyright bug
        values = tuple(x for x in values if x != "int")

    if values and all(isinstance(value, typing.Mapping) for value in values):
        values = typing.cast("typing.Sequence[typing.Mapping[str, T]]", values)
        mapping = join_mappings(values)
        return typing.cast("T", mapping)

    if len(values) == 1:
        return values[0]

    return values


def recognize_json_type(value: JSONType) -> RawSchema:
    """Recognize JSON type of value."""
    if value in tutils.NoneTypes:
        return "None"

    if isinstance(value, str):
        try:
            parser.datetime_validator.synchronous(NotImplemented, value)
        except (ValueError, TypeError, OSError):
            return "str"
        else:
            return "datetime.datetime"

    if isinstance(value, typing.Sequence):
        values = [recognize_json_type(item) for item in value]
        clean = join_union(*values)
        return list(clean) if isinstance(clean, tuple) else [clean]

    if isinstance(value, typing.Mapping):
        return {name: recognize_json_type(item) for name, item in value.items()}

    return type(value).__name__


def add_schema(
    schema_name: str,
    raw_schema: typing.Mapping[str, RawSchema],
    schemas: typing.MutableMapping[str, Schema],
) -> Schema:
    """Add schema to a collection of previous schemas."""
    schema: Schema = {}

    for name, value in raw_schema.items():
        field: Field = {}

        name, old_name = to_snake_case(name), name
        if name != old_name:
            field["alias"] = '"' + old_name + '"'

        # tuple = union, list = array of union
        if isinstance(value, list):
            field["array"] = True
            if len(value) == 1:
                value = value[0]

        if isinstance(value, typing.Sequence) and not isinstance(value, str):
            union: typing.Sequence[str] = []
            for x in value:
                if isinstance(x, typing.Mapping):
                    unique_name = to_pascal_case(schema_name + "_" + name)
                    add_schema(unique_name, x, schemas)
                    union.append(unique_name)
                else:
                    if not isinstance(x, str):
                        raise ValueError("Found nesting in the raw schema.")

                    union.append(x)

            field["type"] = join_union(*union)
            schema[name] = field

        elif isinstance(value, typing.Mapping):
            unique_name = to_pascal_case(schema_name + "_" + name)
            add_schema(unique_name, value, schemas)

            field["type"] = unique_name
            schema[name] = field

        else:
            field["type"] = value
            schema[name] = field

        if "None" in field.get("type", ""):
            field["default"] = "None"

    if not schema_name:
        schema_name = "Root"

    schemas[schema_name] = schema
    return schema


def create_schemas(data: JSONType) -> typing.Mapping[str, Schema]:
    """Create raw schemas from json data."""
    raw_schema = recognize_json_type(data)
    if isinstance(raw_schema, typing.Sequence):
        raw_schema = {"field": raw_schema}

    schemas: typing.MutableMapping[str, Schema] = {}
    add_schema("", raw_schema, schemas)

    return schemas


def generate_models(
    data: JSONType,
    *,
    python: typing.Optional[VersionInfo] = None,
) -> str:
    """Generate model code from data."""
    schemas = create_schemas(data)

    code = "import typing\n\nimport apimodel\n\n"

    for schema_name, schema in schemas.items():
        code += f"class {schema_name}(apimodel.APIModel):\n"
        if len(schema) == 0:
            code += "    pass\n"

        for name, field in schema.items():
            value = format_field_type(field, python=python)
            default = format_field_default(field)

            code += f"    {name}: {value}"
            if default:
                code += " = " + default

            code += "\n"

        code += "\n\n"

    return code.strip() + "\n"