datacoves/dbt-coves

View on GitHub
dbt_coves/tasks/generate/sources.py

Summary

Maintainability
C
1 day
Test Coverage
from __future__ import nested_scopes

import json
from pathlib import Path

import questionary
from questionary import Choice
from rich.console import Console

from dbt_coves.utils.jinja import render_template, render_template_file
from dbt_coves.utils.tracking import trackable

from .base import BaseGenerateTask

console = Console()


class GenerateSourcesTask(BaseGenerateTask):
    """
    Task that generate sources, models and model properties automatically
    """

    @classmethod
    def register_parser(cls, sub_parsers, base_subparser):
        subparser = sub_parsers.add_parser(
            "sources",
            parents=[base_subparser],
            help="Generate source dbt models by inspecting the database schemas and relations.",
        )
        subparser.add_argument(
            "--database",
            type=str,
            help="Database where source relations live, if different than target",
        )
        subparser.add_argument(
            "--schemas",
            type=str,
            help="Comma separated list of schemas where raw data resides, "
            "i.e. 'RAW_SALESFORCE,RAW_HUBSPOT'",
        )
        subparser.add_argument(
            "--select-relations",
            type=str,
            help="Comma separated list of relations where raw data resides, "
            "i.e. 'RAW_HUBSPOT_PRODUCTS,RAW_SALESFORCE_USERS'",
        )
        subparser.add_argument(
            "--exclude-relations",
            type=str,
            help="Filter relation(s) to exclude from source file(s) generation",
        )
        subparser.add_argument(
            "--sources-destination",
            type=str,
            help="Where sources yml files will be generated, default: "
            "'models/staging/{{schema}}/sources.yml'",
        )
        subparser.add_argument(
            "--models-destination",
            type=str,
            help="Where models sql files will be generated, default: "
            "'models/staging/{{schema}}/{{relation}}.sql'",
        )
        subparser.add_argument(
            "--model-props-destination",
            type=str,
            help="Where models yml files will be generated, default: "
            "'models/staging/{{schema}}/{{relation}}.yml'",
        )
        subparser.add_argument(
            "--update-strategy",
            type=str,
            help="Action to perform when a property file already exists: "
            "'update', 'recreate', 'fail', 'ask' (per file)",
        )
        subparser.add_argument(
            "--templates-folder",
            type=str,
            help="Folder with jinja templates that override default "
            "sources generation templates, i.e. 'templates'",
        )
        subparser.add_argument(
            "--metadata",
            type=str,
            help="Path to csv file containing metadata, i.e. 'metadata.csv'",
        )
        subparser.add_argument(
            "--flatten-json-fields",
            type=str,
            choices=["yes", "no", "ask"],
            help="Action to perform when JSON fields exist:" "'yes', 'no', 'ask' (per file)",
        )
        subparser.add_argument(
            "--overwrite-staging-models",
            help="Overwrite existent staging files",
            action="store_true",
            default=False,
        )
        subparser.add_argument(
            "--skip-model-props",
            help="Don't create model's property (yml) files",
            action="store_true",
            default=False,
        )
        subparser.add_argument(
            "--no-prompt",
            help="Silently generate source dbt models",
            action="store_true",
            default=False,
        )
        cls.arg_parser = base_subparser
        subparser.set_defaults(cls=cls, which="sources")
        return subparser

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.db = None

    def get_config_value(self, key):
        return self.coves_config.integrated["generate"]["sources"][key]

    def select_relations(self, rels):
        if self.no_prompt:
            return rels
        else:
            selected_rels = questionary.checkbox(
                "Which sources would you like to generate?",
                choices=[
                    Choice(f"[{rel.schema}] {rel.name}", checked=True, value=rel) for rel in rels
                ],
            ).ask()

            return selected_rels

    def generate(self, rels):
        self.flatten_json = self.get_config_value("flatten_json_fields")
        self.overwrite_sqls = self.get_config_value("overwrite_staging_models")
        models_destination = self.get_config_value("models_destination")
        options = {
            "override_all": "Yes" if self.overwrite_sqls else None,
            "flatten_all": "Yes" if self.flatten_json == "yes" else None,
            "model_prop_update_all": False,
            "model_prop_recreate_all": False,
            "source_prop_update_all": False,
            "source_prop_recreate_all": False,
        }
        if self.no_prompt:
            options["override_all"] = options["override_all"] or "No"
            options["flatten_all"] = options["flatten_all"] or "No"
        for rel in rels:
            model_dest = self.render_path_template(models_destination, rel)
            model_sql = Path(self.config.project_root).joinpath(model_dest)
            if not options["override_all"]:
                if model_sql.exists():
                    overwrite = questionary.select(
                        f"{model_dest} already exists. Would you like to overwrite it?",
                        choices=["No", "Yes", "No for all", "Yes for all"],
                    ).ask()
                    if overwrite == "Yes":
                        self.generate_model(
                            rel,
                            model_sql,
                            options,
                        )
                    elif overwrite == "No for all":
                        options["override_all"] = "No"
                    elif overwrite == "Yes for all":
                        options["override_all"] = "Yes"
                        self.generate_model(
                            rel,
                            model_sql,
                            options,
                        )
                else:
                    self.generate_model(
                        rel,
                        model_sql,
                        options,
                    )
            elif options["override_all"] == "Yes":
                self.generate_model(rel, model_sql, options)
            else:
                if not model_sql.exists():
                    self.generate_model(
                        rel,
                        model_sql,
                        options,
                    )

    def generate_model(self, relation, destination, options):
        destination.parent.mkdir(parents=True, exist_ok=True)
        columns = self.adapter.get_columns_in_relation(relation)
        nested_field_type = self.NESTED_FIELD_TYPES.get(self.adapter.__class__.__name__)
        nested = [col.name for col in columns if col.dtype == nested_field_type]
        if not options["flatten_all"] and self.flatten_json == "ask":
            if nested:
                field_nlg = "field"
                flatten_nlg = "flatten it"
                if len(nested) > 1:
                    field_nlg = "fields"
                    flatten_nlg = "flatten them"
                flatten = questionary.select(
                    f"{relation.name} contains the JSON {field_nlg} {', '.join(nested)}."
                    f" Would you like to {flatten_nlg}?",
                    choices=["No", "Yes", "No for all", "Yes for all"],
                ).ask()
                if flatten == "Yes":
                    self.render_templates(
                        relation,
                        columns,
                        destination,
                        options,
                        json_cols=nested,
                    )
                elif flatten == "No":
                    self.render_templates(relation, columns, destination, options)
                elif flatten == "No for all":
                    options["flatten_all"] = "No"
                    self.render_templates(relation, columns, destination, options)
                elif flatten == "Yes for all":
                    options["flatten_all"] = "Yes"
                    self.render_templates(
                        relation,
                        columns,
                        destination,
                        options,
                        json_cols=nested,
                    )
            else:
                self.render_templates(relation, columns, destination, options)
        elif options["flatten_all"] == "Yes":
            if nested:
                self.render_templates(
                    relation,
                    columns,
                    destination,
                    options,
                    json_cols=nested,
                )
        else:
            self.render_templates(relation, columns, destination, options)

    def render_path_template(self, destination_path, relation):
        template_context = {
            "schema": relation.schema.lower(),
            "relation": relation.name.lower(),
        }
        return render_template(destination_path, template_context)

    def get_templates_context(self, relation, columns, json_cols=None):
        metadata_cols = self.get_metadata_columns(relation, columns)
        context = {
            "relation": relation,
            "columns": metadata_cols,
            "nested": {},
            "adapter_name": self.adapter.__class__.__name__,
        }
        if json_cols:
            context["nested"] = self.get_nested_keys(json_cols, relation)
            # Removing original column with JSON data
            new_cols = []
            for col in metadata_cols:
                if col["id"] not in context["nested"]:
                    new_cols.append(col)
            context["columns"] = new_cols
        config_db = self.get_config_value("database")
        if config_db:
            context["source_database"] = config_db

        return context

    def render_templates_with_context(self, context, sql_destination, options):
        skip_model_props = self.get_config_value("skip_model_props")
        templates_folder = self.get_config_value("templates_folder")
        update_strategy = self.get_config_value("update_strategy")
        model_property_destination = self.get_config_value("model_props_destination")
        source_property_destination = self.get_config_value("sources_destination")
        rel = context["relation"]

        # Render model SQL
        render_template_file(
            "staging_model.sql",
            context,
            sql_destination,
            templates_folder=templates_folder,
        )
        console.print(f"Model [green][b]{sql_destination}[/green][/b] created")

        # Render model and source YMLs
        if not skip_model_props:
            model_yml_dest = self.render_path_template(model_property_destination, rel)
            model_yml_path = Path(self.config.project_root).joinpath(model_yml_dest)
            self.render_property_files(
                context,
                options,
                templates_folder,
                update_strategy,
                "model",
                model_yml_path,
                "staging_model_props.yml",
            )

        source_yml_dest = self.render_path_template(source_property_destination, rel)
        source_yml_path = Path(self.config.project_root).joinpath(source_yml_dest)
        self.render_property_files(
            context,
            options,
            templates_folder,
            update_strategy,
            "source",
            source_yml_path,
            "source_props.yml",
        )

    def get_nested_keys(self, json_cols, relation):
        config_db = self.get_config_value("database")
        if config_db:
            config_db += "."
        else:
            config_db = ""
        _, data = self.adapter.execute(
            f"SELECT {', '.join(json_cols)} FROM {config_db}{relation.schema}.{relation.name} \
                limit 1",
            fetch=True,
        )
        result = dict()
        if len(data.rows) > 0:
            for idx, json_col in enumerate(json_cols):
                value = data.columns[idx]
                try:
                    nested_key_names = list(json.loads(value[0]).keys())
                    result[json_col] = {}
                    for key_name in nested_key_names:
                        result[json_col][key_name] = self.get_default_metadata_item(key_name)
                    self.add_metadata_to_nested(relation, result, json_col)
                except TypeError:
                    console.print(
                        f"Column {json_col} in relation {relation.name} contains invalid JSON.\n"
                    )
        return result

    def add_metadata_to_nested(self, relation, data, col):
        """
        Adds metadata info to each nested field if metadata was provided.
        """
        metadata = self.get_metadata()
        if metadata:
            # Iterate over fields
            for item in data[col].keys():
                metadata_map_key_data = {
                    "database": relation.database,
                    "schema": relation.schema,
                    "relation": relation.name,
                    "column": col,
                    "key": item,
                }
                metadata_key = self.get_metadata_map_key(metadata_map_key_data)
                # Get metadata info or default and assign to the field.
                metadata_info = metadata.get(metadata_key)
                if metadata_info:
                    data[col][item].update(metadata_info)

    @trackable
    def run(self):
        self.no_prompt = self.get_config_value("no_prompt")
        config_database = self.get_config_value("database")
        self.db = config_database or self.config.credentials.database

        self.get_metadata()

        # initiate connection
        with self.adapter.connection_named("master"):
            filtered_schemas = self.get_schemas()
            if not filtered_schemas:
                return 0
            relations = self.get_relations(filtered_schemas)
            if relations:
                selected_relations = self.select_relations(relations)
                if selected_relations:
                    self.raise_duplicate_relations(selected_relations)
                    self.generate(selected_relations)
                else:
                    console.print("No relations selected for sources generation")
                    return 0
            else:
                schema_nlg = f"schema{'s' if len(filtered_schemas) > 1 else ''}"
                console.print(
                    f"No tables/views found in [u]{', '.join(filtered_schemas)}[/u] {schema_nlg}."
                )
        return 0