airbnb/caravel

View on GitHub
scripts/erd/erd.py

Summary

Maintainability
A
1 hr
Test Coverage
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.
"""
This module contains utilities to auto-generate an
Entity-Relationship Diagram (ERD) from the SQLAlchemy
models and write it to a PlantUML file.
"""

import json
import os
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional

import click
import jinja2

from superset import db

# Group title -> names of the database tables that belong to that group.
# Tables that are not listed here are reported under "Uncategorized Models"
# by introspect_models().
GROUPINGS: dict[str, Iterable[str]] = {
    "Core": [
        "css_templates",
        "dynamic_plugin",
        "favstar",
        "dashboards",
        "slices",
        "user_attribute",
        "embedded_dashboards",
        "annotation",
        "annotation_layer",
        "tag",
        "tagged_object",
    ],
    "System": ["ssh_tunnels", "keyvalue", "cache_keys", "key_value", "logs"],
    "Alerts & Reports": ["report_recipient", "report_execution_log", "report_schedule"],
    "Inherited from Flask App Builder (FAB)": [
        "ab_user",
        "ab_permission",
        "ab_permission_view",
        "ab_view_menu",
        "ab_role",
        "ab_register_user",
    ],
    "SQL Lab": ["query", "saved_query", "tab_state", "table_schema"],
    "Data Assets": [
        "dbs",
        "table_columns",
        "sql_metrics",
        "tables",
        "row_level_security_filters",
        "sl_tables",
        "sl_datasets",
        "sl_columns",
        "database_user_oauth2_tokens",
    ],
}
# Table name to group name mapping (reversing the above one for easy lookup).
# Built with a comprehension so that no loop variables are left behind in the
# module namespace; if a table were listed twice, the last group still wins.
TABLE_TO_GROUP_MAP: dict[str, str] = {
    table_name: group_name
    for group_name, table_names in GROUPINGS.items()
    for table_name in table_names
}


def sort_data_structure(data):  # type: ignore
    """
    Return a copy of ``data`` in which every dict has its keys sorted.

    The copy is produced by a JSON round trip, so ``data`` must be JSON
    serializable, and the usual JSON conversions apply to the result
    (for instance tuples come back as lists and non-string dict keys
    come back as strings).
    """
    return json.loads(json.dumps(data, sort_keys=True))


def introspect_sqla_model(mapper: Any, seen: set[str]) -> dict[str, Any]:
    """
    Describes a single SQLAlchemy model as a plain data structure that
    can be passed to a jinja2 template, for instance.

    Parameters:
    -----------
    mapper: SQLAlchemy model mapper
    seen: set of "<table>-<table>" pair identifiers that were already
        emitted; it is updated in place so that any two tables are
        linked by at most one relationship across calls

    Returns:
    --------
    Dict[str, Any]: class name, table name, fields and relationships of
        the model, with dictionary keys sorted
    """
    table_name = mapper.persist_selectable.name
    class_name = mapper.class_.__name__

    # One entry per mapped column: its key and the string form of its type
    columns = [
        {"field_name": col.key, "type": str(col.type)} for col in mapper.columns
    ]

    links: list[dict[str, str]] = []
    for rel_name, rel in mapper.relationships.items():
        target = rel.mapper
        related_table = target.persist_selectable.name

        # The identifier is order-independent, so A->B and B->A collapse
        # into the same entry and only the first one encountered is kept
        pair_id = "-".join(sorted((table_name, related_table)))
        if pair_id in seen:
            continue
        seen.add(pair_id)

        direction = rel.direction.name
        link: dict[str, str] = {
            "relationship_name": rel_name,
            "related_model": target.class_.__name__,
            "type": direction,
            "related_table": related_table,
        }

        # Pick the connector: a secondary table (many-to-many) takes
        # precedence over the direction reported by SQLAlchemy
        if rel.secondary is not None:
            link["type"] = "many-to-many"
            link["secondary_table"] = rel.secondary.name
            link["squiggle"] = "}|--|{"
        elif direction == "MANYTOONE":
            link["squiggle"] = "}|--||"
        else:
            link["squiggle"] = "||--|{"

        links.append(link)

    return sort_data_structure(  # type: ignore
        {
            "class_name": class_name,
            "table_name": table_name,
            "fields": columns,
            "relationships": links,
        }
    )


def introspect_models() -> dict[str, list[dict[str, Any]]]:
    """
    Introspects SQLAlchemy models and returns a data structure that
    can be passed to a jinja2 template for rendering an ERD.

    Mappers are visited in a stable order (by class name, then module).
    The registry exposes them as an unordered collection, and because
    relationships are de-duplicated through a shared ``seen`` set, the
    visiting order decides which side of a relationship gets recorded.
    Sorting keeps the generated ERD identical from one run to the next.

    Returns:
    --------
    Dict[str, List[Dict[str, Any]]]: group name -> list of model
        descriptions, as produced by introspect_sqla_model
    """
    data: dict[str, list[dict[str, Any]]] = defaultdict(list)
    seen_models: set[str] = set()
    ordered_mappers = sorted(
        db.Model.registry.mappers,
        key=lambda mapper: (mapper.class_.__name__, mapper.class_.__module__),
    )
    for model in ordered_mappers:
        # Tables that are not part of any known group end up in a catch-all
        group_name = (
            TABLE_TO_GROUP_MAP.get(model.mapper.persist_selectable.name)
            or "Uncategorized Models"
        )
        model_data = introspect_sqla_model(model, seen_models)
        data[group_name].append(model_data)
    return data


def generate_erd(file_path: str) -> None:
    """
    Generates a PlantUML ERD of the models/database

    The "erd.template.puml" template located next to this script is
    rendered with the introspected model data and written to
    ``file_path``, replacing any existing file.

    Parameters:
    -----------
    file_path: str
        File path to write the ERD to
    """
    data = introspect_models()
    templates_path = os.path.dirname(__file__)
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_path))

    # Load the template
    template = env.get_template("erd.template.puml")
    rendered = template.render(data=data)
    # Write as UTF-8 explicitly so the output does not depend on the
    # locale's default encoding of the machine running the script
    with open(file_path, "w", encoding="utf-8") as f:
        click.secho(f"Writing to {file_path}...", fg="green")
        f.write(rendered)


@click.command()
@click.option(
    "--output",
    "-o",
    type=click.Path(dir_okay=False, writable=True),
    help="File to write the ERD to",
)
def erd(output: Optional[str] = None) -> None:
    """
    Generates a PlantUML ERD of the models/database

    Parameters:
    -----------
    output: str, optional
        File to write the ERD to, defaults to erd.puml next to this
        script if not provided
    """
    path = os.path.dirname(__file__)
    output = output or os.path.join(path, "erd.puml")

    # Imported here rather than at module level, so the application
    # factory is only loaded when the command actually runs
    from superset.app import create_app

    app = create_app()
    # The models are introspected inside an application context
    with app.app_context():
        generate_erd(output)


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    erd()