datacoves/dbt-coves

View on GitHub
dbt_coves/tasks/generate/base.py

Summary

Maintainability
D
2 days
Test Coverage
import csv
import fnmatch
from pathlib import Path

import questionary
from rich.console import Console
from ruamel.yaml import YAML
from slugify import slugify

from dbt_coves.tasks.base import BaseConfiguredTask
from dbt_coves.utils.jinja import get_render_output, render_template_file
from dbt_coves.utils.yaml import open_yaml, save_yaml

# Shared rich console used for the user-facing messages printed by this module.
console = Console()
# Module-wide ruamel YAML instance; used below to parse rendered property templates.
yaml = YAML()
# Prefer block-style output over inline flow style.
yaml.default_flow_style = False
yaml.indent(mapping=2, sequence=4, offset=2)
# Keep the quoting found in loaded documents.
yaml.preserve_quotes = True


class BaseGeneratorException(Exception):
    """Error raised by generate tasks, e.g. when duplicate relations are selected."""


class BaseGenerateTask(BaseConfiguredTask):
    """
    Provides common functionality for all "Generate" sub tasks.

    The helpers below cover schema and relation selection, loading column
    metadata from a CSV file, and creating or updating property (yml) files
    from rendered templates.
    """

    # Not populated in this class; presumably assigned elsewhere — TODO confirm.
    arg_parser = None
    # Keyed by adapter class name; the values look like each warehouse's column
    # type for nested/semi-structured data — confirm against the subclasses.
    NESTED_FIELD_TYPES = {
        "SnowflakeAdapter": "VARIANT",
        "BigQueryAdapter": "STRUCT",
        "RedshiftAdapter": "SUPER",
    }

    def __init__(self, *args, **kwargs):
        """Initialise the parent task, then reset the generator's own state."""
        super().__init__(*args, **kwargs)
        # Property files written by this task instance (see create_property_file).
        self.prop_files_created_by_dbtcoves = set()
        # Metadata map cache; filled by get_metadata().
        self.metadata = None

    def get_schemas(self):
        """
        Resolve the configured schema selectors against the database.

        Selectors come from the task's ``schemas`` config value and are matched
        case-insensitively using shell-style wildcards (``fnmatch``). When no
        schema matches, the choice is delegated to ``select_schemas``.

        Returns:
            list: the matching schema names, de-duplicated and in the order the
            adapter listed them. For Snowflake adapters every matched name is
            wrapped in double quotes.

        Raises:
            SystemExit: when nothing matched and no schema was selected.
        """
        schema_name_selectors = list(self.get_config_value("schemas") or [])

        schemas = [
            schema
            for schema in self.adapter.list_schemas(self.db)
            # TODO: fix this for different adapters
            if schema != "INFORMATION_SCHEMA"
        ]

        # fnmatch already understands "*", so selectors are used as written;
        # rewriting them into regex-style ".*" would only match names that
        # contain a literal dot.
        lowered_selectors = [selector.lower() for selector in schema_name_selectors]

        def matches_any_selector(schema):
            name = schema.lower()
            return any(fnmatch.fnmatch(name, selector) for selector in lowered_selectors)

        # dict.fromkeys drops duplicates while keeping the listing order, so the
        # result is the same on every run.
        matched_schemas = list(
            dict.fromkeys(schema for schema in schemas if matches_any_selector(schema))
        )

        if "snowflake" in self.adapter.type().lower():
            filtered_schemas = [f'"{schema}"' for schema in matched_schemas]
        else:
            filtered_schemas = matched_schemas

        if not filtered_schemas:
            schema_nlg = f"schema{'s' if len(schema_name_selectors) > 1 else ''}"
            console.print(
                f"{schema_nlg} [u]{', '.join(schema_name_selectors)}[/u] not found in Database.\n"
            )
            # NOTE(review): schemas chosen here are returned unquoted, even on
            # Snowflake — confirm that is intended.
            filtered_schemas = self.select_schemas(schemas)
            if not filtered_schemas:
                console.print("No schemas selected")
                # Raised directly instead of calling exit(), which is a helper
                # injected by the site module for interactive sessions.
                raise SystemExit

        return filtered_schemas

    def select_schemas(self, schemas):
        """
        Pick the schemas to inspect out of ``schemas``.

        With ``no_prompt`` set the given schemas are returned as they are;
        otherwise the answer of an interactive checkbox prompt is returned.
        """
        if not self.no_prompt:
            prompt = questionary.checkbox(
                "Which schemas would you like to inspect?",
                choices=schemas,
            )
            return prompt.ask()
        return schemas

    def get_relations(self, filtered_schemas):
        """
        List the relations of ``filtered_schemas`` that the task should process.

        Relations whose name matches an ``exclude_relations`` pattern are
        dropped first; the remainder is then narrowed down to the names matching
        ``select_relations``. Both pattern lists use case-insensitive,
        shell-style (``fnmatch``) matching. When ``select_relations`` is empty,
        or its first entry is empty, every non-excluded relation is returned.

        Args:
            filtered_schemas: schema names passed one by one to
                ``adapter.list_relations``.

        Returns:
            list: relation objects, in the order the adapter listed them.
        """
        # A config value of None is treated like an empty list.
        rel_name_selectors = list(self.get_config_value("select_relations") or [])
        rel_excludes = list(self.get_config_value("exclude_relations") or [])

        listed_relations = []
        for schema in filtered_schemas:
            listed_relations += self.adapter.list_relations(self.db, schema)

        def matches_any(relation, patterns):
            # fnmatch handles "*" natively, so patterns are used as configured.
            name = relation.name.lower()
            return any(fnmatch.fnmatch(name, pattern.lower()) for pattern in patterns)

        # Exclusion works on the bare relation name, whatever its schema.
        excluded_names = {
            relation.name for relation in listed_relations if matches_any(relation, rel_excludes)
        }
        listed_relations = [
            relation for relation in listed_relations if relation.name not in excluded_names
        ]

        if not (rel_name_selectors and rel_name_selectors[0]):
            return listed_relations

        return [
            relation for relation in listed_relations if matches_any(relation, rel_name_selectors)
        ]

    def run(self) -> int:
        """Task entry point; not implemented in this base class."""
        raise NotImplementedError

    def get_metadata_map_key(self, row):
        """
        Build the lookup key that identifies a column in the metadata map.

        The key is the lower-cased ``database``, ``schema``, ``relation``,
        ``column`` and optional ``key`` values of ``row`` joined with ``-``.
        It contains no padding: the parts are joined explicitly rather than via
        a string literal continued across lines, which would embed the source
        indentation in every key.

        Args:
            row: mapping with the fields listed above; ``key`` may be absent
                or ``None``.

        Returns:
            str: the map key.

        Raises:
            KeyError: when one of the four required fields is missing.
        """
        parts = (
            row["database"],
            row["schema"],
            row["relation"],
            row["column"],
            # Optional field: absent or None both count as empty.
            row.get("key") or "",
        )
        return "-".join(part.lower() for part in parts)

    def get_metadata_map_item(self, row):
        """
        Turn a metadata row into the value stored in the metadata map.

        A ``None`` description is replaced by ``""`` on ``row`` itself; the
        returned dict carries the row's ``type`` and its stripped description.
        """
        if row["description"] is None:
            row["description"] = ""
        item = {"type": row["type"]}
        item["description"] = row["description"].strip()
        return item

    def get_default_metadata_item(self, name, type="varchar", description=""):
        """
        Build a column entry from its name alone.

        The ``id`` is the slugified name with ``_`` as separator; ``type`` and
        ``description`` are stored as given.
        """
        item = {"name": name}
        item["id"] = slugify(name, separator="_")
        item["type"] = type
        item["description"] = description
        return item

    def get_metadata(self):
        """
        Return the column metadata map, loading it on first use.

        If a ``metadata`` path is configured, the CSV file at that path
        (relative to the project root) is read and every row is stored under
        ``get_metadata_map_key(row)`` with the value
        ``get_metadata_map_item(row)``. Without a configured path the map is
        empty. The result is cached on ``self.metadata`` and returned as-is on
        later calls.

        Returns:
            dict: metadata map keyed by column key.

        Raises:
            Exception: when the file does not exist, or a row lacks a column
                the key/item builders need.
        """
        # Compare against None: an empty map is a valid, already-loaded result
        # and must not trigger another read of the file.
        if self.metadata is not None:
            return self.metadata

        path = self.get_config_value("metadata")

        metadata_map = {}
        if path:
            metadata_path = Path(self.config.project_root).joinpath(path)
            try:
                # newline="" lets the csv module do its own newline handling.
                with open(metadata_path, "r", newline="") as csvfile:
                    rows = csv.DictReader(csvfile, skipinitialspace=True)
                    for row in rows:
                        try:
                            metadata_map[
                                self.get_metadata_map_key(row)
                            ] = self.get_metadata_map_item(row)
                        except KeyError as e:
                            raise Exception(
                                f"Key {e} not found in {path}. Please check this sample metadata "
                                "file: https://raw.githubusercontent.com/datacoves/dbt-coves/main/"
                                "sample_metadata.csv."
                            ) from e
            except FileNotFoundError as e:
                raise Exception(f"Metadata file not found: {e}") from e

        self.metadata = metadata_map

        return metadata_map

    def get_config_value(self, key):
        """Look up ``key`` in the ``generate`` config section of the current task."""
        generate_config = self.coves_config.integrated["generate"]
        task_config = generate_config[self.args.task]
        return task_config[key]

    def render_templates(self, relation, columns, destination, options=None, json_cols=None):
        """
        Render the templates for ``relation`` into ``destination``.

        The destination's parent folder is created first; the context built by
        ``get_templates_context`` is then handed to
        ``render_templates_with_context``.
        """
        output_dir = destination.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        self.render_templates_with_context(
            self.get_templates_context(relation, columns, json_cols),
            destination,
            options,
        )

    def get_templates_context(self, relation, columns, json_cols=None):
        """
        Assemble the context dict handed to the templates.

        It holds the relation, its columns as returned by
        ``get_metadata_columns``, an empty ``nested`` mapping and the adapter's
        class name. ``json_cols`` is accepted but not used here.
        """
        context = {"relation": relation}
        context["columns"] = self.get_metadata_columns(relation, columns)
        context["nested"] = {}
        context["adapter_name"] = self.adapter.__class__.__name__
        return context

    def get_metadata_columns(self, relation, cols):
        """
        Build the column entries used to render ``relation``.

        Each column is looked up in the metadata map (see ``get_metadata``).
        A documented column yields a copy of its metadata entry completed with
        the column's ``name`` and ``id``; any other column yields the default
        entry from ``get_default_metadata_item``, typed from the column itself.

        The metadata map is never modified: entries are copied before the
        identity fields are added, so the cached map stays as it was loaded.

        Args:
            relation: object exposing ``database``, ``schema`` and ``name``.
            cols: column objects exposing ``name`` and ``dtype`` (``data_type``
                for BigQuery adapters).

        Returns:
            list: one dict per column, in the order of ``cols``.
        """
        metadata = self.get_metadata()
        # BigQuery adapters read the type from `data_type`, all others from `dtype`.
        uses_data_type = "BigQuery" in self.adapter.__class__.__name__
        metadata_cols = []
        for col in cols:
            new_col = None
            if metadata:
                metadata_key = self.get_metadata_map_key(
                    {
                        "database": relation.database,
                        "schema": relation.schema,
                        "relation": relation.name,
                        "column": col.name,
                    }
                )
                documented = metadata.get(metadata_key)
                if documented:
                    # Take the id from the default item so the slug rule lives
                    # in a single place.
                    column_id = self.get_default_metadata_item(col.name)["id"]
                    new_col = {**documented, "name": col.name, "id": column_id}
            if not new_col:
                col_type = col.data_type if uses_data_type else col.dtype
                new_col = self.get_default_metadata_item(col.name, type=col_type)
            metadata_cols.append(new_col)
        return metadata_cols

    def new_object_exists_in_current_yml(
        self,
        current_yml,
        template,
        context,
        templates_folder,
        resource_type,
    ):
        """
        Tell whether the object rendered from ``template`` is already in ``current_yml``.

        The template is rendered and parsed, and the objects listed under
        ``"<resource_type>s"`` in both documents are compared by ``name``. A
        document without that section is treated as having no objects.

        Returns:
            The first newly rendered object whose name is already present in
            ``current_yml``, or ``False`` when there is none.
        """
        rendered = get_render_output(
            template,
            context,
            templates_folder=templates_folder,
        )
        new_yml = yaml.load(rendered)
        resource_type_key = f"{resource_type}s"
        # A missing or empty section means "no objects", not an error.
        new_objects = new_yml.get(resource_type_key) or []
        current_objects = current_yml.get(resource_type_key) or []
        current_names = [curr_obj.get("name") for curr_obj in current_objects]
        for new_obj in new_objects:
            if new_obj.get("name") in current_names:
                return new_obj
        return False

    def create_property_file(self, template, context, yml_path, templates_folder):
        """Render a new property file at ``yml_path``, record it and report it."""
        self.render_property_file(template, context, yml_path, templates_folder)
        # Recorded so render_property_files can recognise files written by this task.
        self.prop_files_created_by_dbtcoves.add(yml_path)
        created_message = f"Property file [green][b]{yml_path}[/b][/green] created"
        console.print(created_message)

    def render_property_files(
        self,
        context,
        options,
        templates_folder,
        update_strategy,
        resource_type,
        yml_path,
        template,
    ):
        """
        Create ``yml_path`` or merge the rendered object into it.

        A missing or empty file is (re)created from the template. Otherwise the
        rendered object is appended when the file does not contain it yet, or
        recreated/updated according to, in this order: the "recreate all" flag
        in ``options``, the "update all" flag (or the file having been created
        by this task), and finally ``update_strategy`` ("ask", "update" or
        "recreate").

        Args:
            context: template context; must hold ``relation``. ``model`` is set
                on it to the relation's name.
            options: dict carrying the ``<resource_type>_prop_update_all`` and
                ``<resource_type>_prop_recreate_all`` flags; updated in place.
            templates_folder: forwarded to the rendering helpers.
            update_strategy: what to do with an object that already exists.
            resource_type: e.g. "source" or "model".
            yml_path: path object of the property file.
            template: template used to render the object.

        Raises:
            SystemExit: when the user cancels or aborts the prompt, or when
                ``update_strategy`` is not a known value.
        """
        context["model"] = context["relation"].name
        strategy_key_update_all = f"{resource_type}_prop_update_all"
        strategy_key_recreate_all = f"{resource_type}_prop_recreate_all"

        if self.no_prompt:
            # Nobody can answer a prompt, so "ask" behaves like "update".
            if update_strategy == "ask" or update_strategy == "update":
                options[strategy_key_update_all] = True
            if update_strategy == "recreate":
                options[strategy_key_recreate_all] = True

        if not yml_path.exists():
            self.create_property_file(template, context, yml_path, templates_folder)
            return

        current_yml = open_yaml(yml_path)
        if not current_yml:
            # target yml path exists but it's empty -> recreate file
            return self.create_property_file(template, context, yml_path, templates_folder)

        object_in_yml = self.new_object_exists_in_current_yml(
            current_yml,
            template,
            context,
            templates_folder,
            resource_type,
        )

        if not object_in_yml:
            sel_action = "create"
        elif options[strategy_key_recreate_all]:
            sel_action = "recreate"
        elif (
            options[strategy_key_update_all] or yml_path in self.prop_files_created_by_dbtcoves
        ):
            sel_action = "update"
        elif update_strategy == "ask":
            new_object_id = object_in_yml.get("name")
            console.print(
                f"{resource_type} [yellow][b]{new_object_id}[/b][/yellow] "
                f"already exists in [b][yellow]{yml_path}[/b][/yellow]."
            )
            action = questionary.select(
                "What would you like to do with it?",
                choices=[
                    "Update",
                    "Update all",
                    "Recreate",
                    "Recreate all",
                    "Skip",
                    "Cancel",
                ],
            ).ask()
            if action == "Recreate" or action == "Recreate all":
                sel_action = "recreate"
                if action == "Recreate all":
                    options[strategy_key_recreate_all] = True
            elif action == "Update" or action == "Update all":
                sel_action = "update"
                if action == "Update all":
                    options[strategy_key_update_all] = True
            elif action == "Skip":
                return
            else:
                # "Cancel", or no answer at all because the prompt was aborted:
                # stop here instead of rewriting the file with no action.
                raise SystemExit
        elif update_strategy == "update" or update_strategy == "recreate":
            sel_action = update_strategy
        else:
            console.print(f"Update strategy {update_strategy} not a valid option.")
            raise SystemExit

        self.modify_property_file(
            template,
            context,
            yml_path,
            current_yml,
            templates_folder,
            resource_type,
            sel_action,
        )

    def update_object_properties(self, current_object, new_object, resource_type):
        """
        Merge ``new_object`` into ``current_object`` with the merger for its type.

        "source" and "model" are handled; for any other ``resource_type`` the
        current object is returned untouched.
        """
        if resource_type == "source":
            return self.update_source_properties(current_object, new_object)
        if resource_type == "model":
            return self.update_model_properties(current_object, new_object)
        return current_object

    def modify_property_file(
        self,
        template,
        context,
        yml_path,
        current_yml,
        templates_folder,
        resource_type,
        action,
    ):
        """
        Apply ``action`` for the rendered object to ``current_yml`` and save it.

        The template is rendered and parsed, and the first object of its
        ``"<resource_type>s"`` section is used:

        - "create": the object is appended (the section is created when the
          existing document does not have it yet).
        - "recreate": every existing object with the same name is replaced.
        - "update": every existing object with the same name is merged through
          ``update_object_properties``.

        The outcome is printed and the document written back to ``yml_path``.
        """
        rendered = get_render_output(
            template,
            context,
            templates_folder=templates_folder,
        )
        new_yml = yaml.load(rendered)
        resource_type_key = resource_type + "s"
        new_object = new_yml.get(resource_type_key)[0]

        if action == "create":
            # The existing file may hold other resource types only.
            if current_yml.get(resource_type_key) is None:
                current_yml[resource_type_key] = []
            current_yml[resource_type_key].append(new_object)
        elif action == "recreate" or action == "update":
            current_objects = current_yml.get(resource_type_key) or []
            for idx, curr_obj in enumerate(current_objects):
                if curr_obj.get("name") != new_object.get("name"):
                    continue
                if action == "recreate":
                    current_objects[idx] = new_object
                else:
                    current_objects[idx] = self.update_object_properties(
                        curr_obj, new_object, resource_type
                    )

        # "{Model/Source} {name} created/recreated/updated on file {filepath}"
        console.print(
            f"{resource_type.capitalize()} [green][b]{new_object.get('name')}[/b][/green] "
            f"{action}d on file [green][b]{yml_path}[/b][/green]"
        )

        save_yaml(yml_path, current_yml)

    def render_property_file(self, template, context, model_yml, templates_folder):
        """Create the parent folder of ``model_yml`` and render ``template`` into it."""
        target_dir = model_yml.parent
        target_dir.mkdir(parents=True, exist_ok=True)
        render_template_file(template, context, model_yml, templates_folder=templates_folder)

    def update_model_columns(self, columns_a: list, columns_b: list):
        """
        Merge the columns of ``columns_b`` into ``columns_a`` in place.

        A column already present in ``columns_a`` (same ``name``) only gets its
        description replaced, and only when the new description is non-empty;
        everything else on it is left alone to avoid overriding tests. Columns
        not present yet are appended.

        Args:
            columns_a: existing column dicts; modified in place.
            columns_b: newly rendered column dicts; ``None`` is treated as
                "no columns".
        """
        # Index the existing columns once instead of rescanning them per column.
        existing_by_name = {}
        for current_column in columns_a:
            existing_by_name.setdefault(current_column.get("name"), []).append(current_column)

        for new_column in columns_b or []:
            matches = existing_by_name.get(new_column.get("name"))
            if matches is None:
                columns_a.append(new_column)
                continue
            new_description = new_column.get("description")
            if new_description:
                for current_column in matches:
                    current_column["description"] = new_description

    def update_model_properties(self, model_a: dict, model_b: dict):
        """
        Merge the newly rendered model ``model_b`` into the existing ``model_a``.

        A non-empty new description replaces the current one, and the new
        columns are merged through ``update_model_columns``. An existing model
        without a ``columns`` list gets one when there are columns to add.

        Returns:
            dict: ``model_a``, modified in place.
        """
        new_description = model_b.get("description")
        if new_description:
            model_a["description"] = new_description

        new_columns = model_b.get("columns")
        if new_columns:
            current_columns = model_a.get("columns")
            if current_columns is None:
                # Nothing to merge into yet: start an empty list on the model.
                current_columns = []
                model_a["columns"] = current_columns
            self.update_model_columns(current_columns, new_columns)
        return model_a

    def update_source_tables(self, tables_a: list, tables_b: list):
        """
        Merge the tables of ``tables_b`` into ``tables_a`` in place.

        A table already present in ``tables_a`` (same ``name``) only gets its
        description and identifier replaced, each only when the new value is
        non-empty; everything else on it is left alone to avoid overriding
        tests. Tables not present yet are appended.

        Args:
            tables_a: existing table dicts; modified in place.
            tables_b: newly rendered table dicts; ``None`` is treated as
                "no tables".
        """
        # Index the existing tables once instead of rescanning them per table.
        existing_by_name = {}
        for current_table in tables_a:
            existing_by_name.setdefault(current_table.get("name"), []).append(current_table)

        for new_table in tables_b or []:
            matches = existing_by_name.get(new_table.get("name"))
            if matches is None:
                tables_a.append(new_table)
                continue
            for current_table in matches:
                if new_table.get("description"):
                    current_table["description"] = new_table.get("description")
                if new_table.get("identifier"):
                    current_table["identifier"] = new_table.get("identifier")

    def update_source_properties(self, source_a: dict, source_b: dict):
        """
        Merge the newly rendered source ``source_b`` into the existing ``source_a``.

        The database is always taken from ``source_b``, the schema only when
        ``source_b`` has one, and the new tables are merged through
        ``update_source_tables``. An existing source without a ``tables`` list
        gets one when there are tables to add.

        Returns:
            dict: ``source_a``, modified in place.
        """
        # NOTE(review): unlike "schema", "database" is overwritten even when the
        # new source has none (it then becomes None) — confirm this is intended.
        source_a["database"] = source_b.get("database")
        new_schema = source_b.get("schema")
        if new_schema:
            source_a["schema"] = new_schema

        new_tables = source_b.get("tables")
        if new_tables:
            current_tables = source_a.get("tables")
            if current_tables is None:
                # Nothing to merge into yet: start an empty list on the source.
                current_tables = []
                source_a["tables"] = current_tables
            self.update_source_tables(current_tables, new_tables)
        return source_a

    def raise_duplicate_relations(self, relations):
        """
        Reject selections holding the same ``schema.name`` more than once.

        Names are compared case-insensitively. The check only runs when
        prompting is enabled (``no_prompt`` is false).

        Raises:
            BaseGeneratorException: listing the duplicated names in sorted order.
        """
        if self.no_prompt:
            return

        # Single pass: a name seen before is a duplicate.
        seen = set()
        duplicates = set()
        for rel in relations:
            qualified_name = f"{rel.schema.lower()}.{rel.name.lower()}"
            if qualified_name in seen:
                duplicates.add(qualified_name)
            seen.add(qualified_name)

        if duplicates:
            # Sorted so the message reads the same on every run.
            raise BaseGeneratorException(
                "Can't select multiple relations with the exact same name: "
                f"[red]{', '.join(sorted(duplicates))}[/red]"
            )