datacoves/dbt-coves

View on GitHub
dbt_coves/tasks/generate/docs.py

Summary

Maintainability
A
0 mins
Test Coverage
from __future__ import nested_scopes

import json
from pathlib import Path
from typing import Any, Dict

from rich.console import Console

from dbt_coves.tasks.base import BaseConfiguredTask
from dbt_coves.tasks.setup.utils import print_row
from dbt_coves.utils.shell import prepare_cmd, run_and_capture_cwd
from dbt_coves.utils.tracking import trackable

console = Console()


class DbtCovesGenerateDocsException(Exception):
    pass


class GenerateDocsTask(BaseConfiguredTask):
    """
    Task that generates content on local catalog.json
    """

    @classmethod
    def register_parser(cls, sub_parsers, base_subparser):
        subparser = sub_parsers.add_parser(
            "docs",
            parents=[base_subparser],
            help="Merge models from a catalog.json into another one.",
        )
        subparser.add_argument(
            "--merge-deferred",
            action="store_true",
            help="Flag to merge deferred models and sources into local catalog",
            default=False,
        )
        subparser.add_argument(
            "--state",
            type=str,
            help="Catalog.json to use as reference for merging",
        )
        subparser.add_argument(
            "--dbt-args",
            type=str,
            help="A single, double-quoted, string of `dbt docs generate` args",
        )
        cls.arg_parser = base_subparser
        subparser.set_defaults(cls=cls, which="docs")
        return subparser

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_config_value(self, key):
        return self.coves_config.integrated["generate"]["docs"][key]

    def _generate_dbt_docs(self):
        command = prepare_cmd(self, ["docs", "generate"])
        output = run_and_capture_cwd(command, self.config.project_root)
        if output.returncode == 0:
            deps_status = "[green]SUCCESS :heavy_check_mark:[/green]"
        else:
            deps_status = "[red]FAIL :cross_mark:[/red]"
        print_row(
            "dbt docs generate",
            deps_status,
            new_section=True,
        )
        if output.returncode > 0:
            raise Exception("Error when running 'dbt docs generate'. Check logs.")

    def _fix_dbt_docs_links(self, docs_path: Path):
        dbt_docs_index_path = Path(docs_path, "index.html")
        with open(dbt_docs_index_path, "r+") as f:
            html_content = f.read()
            html_content.replace("</head>", "<base target='_blank'></head>")
            f.write(html_content)
        console.print(
            "[green]:heavy_check_mark:[/green] dbt docs updated. "
            "External links will now open in a new tab"
        )

    def _get_catalog_json(self, docs_folder: Path) -> Dict[str, Any]:
        """
        Open json at docs_folder/catalog.json
        """
        catalog_path = Path(docs_folder, "catalog.json")
        try:
            with open(catalog_path, "r") as f:
                return json.load(f)
        except FileNotFoundError:
            raise DbtCovesGenerateDocsException(f"Catalog.json not found at {catalog_path}")

    def _merge_catalogs(
        self, local_catalog: Dict[str, Any], target_catalog: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Merge nodes and sources from the state catalog.json into the local docs.
        """
        nodes_diff = set(target_catalog.get("nodes", {}).keys()) - set(
            local_catalog.get("nodes", {}).keys()
        )
        n_nodes_merged = len(nodes_diff)

        sources_diff = set(target_catalog.get("sources", {}).keys()) - set(
            local_catalog.get("sources", {}).keys()
        )
        n_sources_merged = len(sources_diff)

        local_catalog["nodes"].update(
            (key, value)
            for key, value in target_catalog.get("nodes", {}).items()
            if key not in local_catalog.get("nodes", {})
        )
        local_catalog["sources"].update(
            (key, value)
            for key, value in target_catalog.get("sources", {}).items()
            if key not in local_catalog.get("sources", {})
        )

        console.print(
            f"Merged [green]{n_nodes_merged} nodes[/green] and [green]{n_sources_merged} sources[/green] into",
            "your local catalog.json",
        )

    def _write_catalog_json(self, catalog: Dict[str, Any], docs_folder: Path):
        """
        Write the catalog.json to the docs_folder
        """
        catalog_path = Path(docs_folder, "catalog.json")
        with open(catalog_path, "w") as f:
            json.dump(catalog, f)

    def _merge_dbt_catalogs(self, local_path: Path, state_path: Path):
        """
        Merge the catalog.json from the stsate into the local docs.
        """
        # Get the source and target catalogs
        local_catalog = self._get_catalog_json(local_path)
        target_catalog = self._get_catalog_json(state_path)

        # Merge the catalogs
        self._merge_catalogs(local_catalog, target_catalog)

        # Write the merged catalog to the source
        self._write_catalog_json(local_catalog, local_path)

    @trackable
    def run(self):
        self._generate_dbt_docs()

        local_docs_path = Path(self.config.project_root, "target")
        self._fix_dbt_docs_links(local_docs_path)

        merge_deferred = self.get_config_value("merge_deferred")
        state_location = self.get_config_value("state")

        if merge_deferred:
            target_docs_path = Path(state_location, "catalog.json")
            if not state_location or not target_docs_path.exists():
                raise DbtCovesGenerateDocsException(
                    "A valid [red][i]--state[/i][/red] argument is required "
                    "when using [yellow]--merge-deferred[/yellow]"
                )

            self._merge_dbt_catalogs(local_docs_path, state_location)
        return 0